feat: create v2-prep branch with comprehensive planning
Restructured project for V2 refactor: **Structure Changes:** - Moved all V1 code to orig/ folder (preserved with git mv) - Created docs/planning/ directory - Added orig/README_V1.md explaining V1 preservation **Planning Documents:** - 00_V2_MASTER_PLAN.md: Complete architecture overview - Executive summary of critical V1 issues - High-level component architecture diagrams - 5-phase implementation roadmap - Success metrics and risk mitigation - 07_TASK_BREAKDOWN.md: Atomic task breakdown - 99+ hours of detailed tasks - Every task < 2 hours (atomic) - Clear dependencies and success criteria - Organized by implementation phase **V2 Key Improvements:** - Per-exchange parsers (factory pattern) - Multi-layer strict validation - Multi-index pool cache - Background validation pipeline - Comprehensive observability **Critical Issues Addressed:** - Zero address tokens (strict validation + cache enrichment) - Parsing accuracy (protocol-specific parsers) - No audit trail (background validation channel) - Inefficient lookups (multi-index cache) - Stats disconnection (event-driven metrics) Next Steps: 1. Review planning documents 2. Begin Phase 1: Foundation (P1-001 through P1-010) 3. Implement parsers in Phase 2 4. Build cache system in Phase 3 5. Add validation pipeline in Phase 4 6. Migrate and test in Phase 5 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
262
orig/internal/auth/middleware.go
Normal file
262
orig/internal/auth/middleware.go
Normal file
@@ -0,0 +1,262 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"crypto/subtle"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"os"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/fraktal/mev-beta/internal/logger"
|
||||
)
|
||||
|
||||
// AuthConfig holds authentication configuration
|
||||
type AuthConfig struct {
|
||||
APIKey string
|
||||
BasicUsername string
|
||||
BasicPassword string
|
||||
AllowedIPs []string
|
||||
RequireHTTPS bool
|
||||
RateLimitRPS int
|
||||
Logger *logger.Logger
|
||||
}
|
||||
|
||||
// Middleware provides authentication middleware for HTTP endpoints
|
||||
type Middleware struct {
|
||||
config *AuthConfig
|
||||
rateLimiter map[string]*RateLimiter
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
// RateLimiter tracks request rates per IP
|
||||
type RateLimiter struct {
|
||||
requests []time.Time
|
||||
maxRequests int
|
||||
window time.Duration
|
||||
}
|
||||
|
||||
// NewMiddleware creates a new authentication middleware
|
||||
func NewMiddleware(config *AuthConfig) *Middleware {
|
||||
// Use environment variables for sensitive data
|
||||
if config.APIKey == "" {
|
||||
config.APIKey = os.Getenv("MEV_BOT_API_KEY")
|
||||
}
|
||||
if config.BasicUsername == "" {
|
||||
config.BasicUsername = os.Getenv("MEV_BOT_USERNAME")
|
||||
}
|
||||
if config.BasicPassword == "" {
|
||||
config.BasicPassword = os.Getenv("MEV_BOT_PASSWORD")
|
||||
}
|
||||
|
||||
return &Middleware{
|
||||
config: config,
|
||||
rateLimiter: make(map[string]*RateLimiter),
|
||||
}
|
||||
}
|
||||
|
||||
// RequireAuthentication is a middleware that requires API key or basic auth
|
||||
func (m *Middleware) RequireAuthentication(next http.HandlerFunc) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
// Security headers
|
||||
w.Header().Set("X-Content-Type-Options", "nosniff")
|
||||
w.Header().Set("X-Frame-Options", "DENY")
|
||||
w.Header().Set("X-XSS-Protection", "1; mode=block")
|
||||
w.Header().Set("Referrer-Policy", "strict-origin-when-cross-origin")
|
||||
|
||||
// Require HTTPS in production
|
||||
if m.config.RequireHTTPS && r.Header.Get("X-Forwarded-Proto") != "https" && r.TLS == nil {
|
||||
http.Error(w, "HTTPS required", http.StatusUpgradeRequired)
|
||||
return
|
||||
}
|
||||
|
||||
// Check IP allowlist
|
||||
if !m.isIPAllowed(r.RemoteAddr) {
|
||||
m.config.Logger.Warn(fmt.Sprintf("Blocked request from unauthorized IP: %s", r.RemoteAddr))
|
||||
http.Error(w, "Forbidden", http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
|
||||
// Rate limiting
|
||||
if !m.checkRateLimit(r.RemoteAddr) {
|
||||
m.config.Logger.Warn(fmt.Sprintf("Rate limit exceeded for IP: %s", r.RemoteAddr))
|
||||
http.Error(w, "Rate limit exceeded", http.StatusTooManyRequests)
|
||||
return
|
||||
}
|
||||
|
||||
// Try API key authentication first
|
||||
if m.authenticateAPIKey(r) {
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
// Try basic authentication
|
||||
if m.authenticateBasic(r) {
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
// Authentication failed
|
||||
w.Header().Set("WWW-Authenticate", `Basic realm="MEV Bot API"`)
|
||||
http.Error(w, "Authentication required", http.StatusUnauthorized)
|
||||
}
|
||||
}
|
||||
|
||||
// authenticateAPIKey checks for valid API key in header or query param
|
||||
func (m *Middleware) authenticateAPIKey(r *http.Request) bool {
|
||||
if m.config.APIKey == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
// Check Authorization header
|
||||
auth := r.Header.Get("Authorization")
|
||||
if strings.HasPrefix(auth, "Bearer ") {
|
||||
token := strings.TrimPrefix(auth, "Bearer ")
|
||||
return subtle.ConstantTimeCompare([]byte(token), []byte(m.config.APIKey)) == 1
|
||||
}
|
||||
|
||||
// Check X-API-Key header
|
||||
apiKey := r.Header.Get("X-API-Key")
|
||||
if apiKey != "" {
|
||||
return subtle.ConstantTimeCompare([]byte(apiKey), []byte(m.config.APIKey)) == 1
|
||||
}
|
||||
|
||||
// Check query parameter (less secure, but sometimes necessary)
|
||||
queryKey := r.URL.Query().Get("api_key")
|
||||
if queryKey != "" {
|
||||
return subtle.ConstantTimeCompare([]byte(queryKey), []byte(m.config.APIKey)) == 1
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// authenticateBasic checks basic authentication credentials
|
||||
func (m *Middleware) authenticateBasic(r *http.Request) bool {
|
||||
if m.config.BasicUsername == "" || m.config.BasicPassword == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
username, password, ok := r.BasicAuth()
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
|
||||
// Use constant time comparison to prevent timing attacks
|
||||
usernameMatch := subtle.ConstantTimeCompare([]byte(username), []byte(m.config.BasicUsername)) == 1
|
||||
passwordMatch := subtle.ConstantTimeCompare([]byte(password), []byte(m.config.BasicPassword)) == 1
|
||||
|
||||
return usernameMatch && passwordMatch
|
||||
}
|
||||
|
||||
// isIPAllowed checks if the request IP is in the allowlist
|
||||
func (m *Middleware) isIPAllowed(remoteAddr string) bool {
|
||||
if len(m.config.AllowedIPs) == 0 {
|
||||
return true // No IP restrictions
|
||||
}
|
||||
|
||||
// Extract IP from address (remove port)
|
||||
ip := strings.Split(remoteAddr, ":")[0]
|
||||
|
||||
for _, allowedIP := range m.config.AllowedIPs {
|
||||
if ip == allowedIP || allowedIP == "*" {
|
||||
return true
|
||||
}
|
||||
// Support CIDR notation in future versions
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// checkRateLimit implements simple rate limiting per IP
|
||||
func (m *Middleware) checkRateLimit(remoteAddr string) bool {
|
||||
if m.config.RateLimitRPS <= 0 {
|
||||
return true // No rate limiting
|
||||
}
|
||||
|
||||
ip := strings.Split(remoteAddr, ":")[0]
|
||||
now := time.Now()
|
||||
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
// Get or create rate limiter for this IP
|
||||
limiter, exists := m.rateLimiter[ip]
|
||||
if !exists {
|
||||
limiter = &RateLimiter{
|
||||
requests: make([]time.Time, 0),
|
||||
maxRequests: m.config.RateLimitRPS,
|
||||
window: time.Minute,
|
||||
}
|
||||
m.rateLimiter[ip] = limiter
|
||||
}
|
||||
|
||||
// Clean old requests outside the time window
|
||||
cutoff := now.Add(-limiter.window)
|
||||
validRequests := make([]time.Time, 0)
|
||||
for _, reqTime := range limiter.requests {
|
||||
if reqTime.After(cutoff) {
|
||||
validRequests = append(validRequests, reqTime)
|
||||
}
|
||||
}
|
||||
limiter.requests = validRequests
|
||||
|
||||
// Check if we're under the limit
|
||||
if len(limiter.requests) >= limiter.maxRequests {
|
||||
return false
|
||||
}
|
||||
|
||||
// Add current request
|
||||
limiter.requests = append(limiter.requests, now)
|
||||
return true
|
||||
}
|
||||
|
||||
// RequireReadOnly is a middleware for read-only endpoints (less strict)
|
||||
func (m *Middleware) RequireReadOnly(next http.HandlerFunc) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
// Only allow GET requests for read-only endpoints
|
||||
if r.Method != http.MethodGet {
|
||||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
|
||||
// Apply basic security checks
|
||||
w.Header().Set("X-Content-Type-Options", "nosniff")
|
||||
w.Header().Set("X-Frame-Options", "DENY")
|
||||
|
||||
// Check IP allowlist
|
||||
if !m.isIPAllowed(r.RemoteAddr) {
|
||||
http.Error(w, "Forbidden", http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
|
||||
// Apply rate limiting
|
||||
if !m.checkRateLimit(r.RemoteAddr) {
|
||||
http.Error(w, "Rate limit exceeded", http.StatusTooManyRequests)
|
||||
return
|
||||
}
|
||||
|
||||
next.ServeHTTP(w, r)
|
||||
}
|
||||
}
|
||||
|
||||
// CleanupRateLimiters removes old rate limiter entries
|
||||
func (m *Middleware) CleanupRateLimiters() {
|
||||
cutoff := time.Now().Add(-time.Hour)
|
||||
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
for ip, limiter := range m.rateLimiter {
|
||||
// Remove limiters that haven't been used recently
|
||||
if len(limiter.requests) == 0 {
|
||||
delete(m.rateLimiter, ip)
|
||||
continue
|
||||
}
|
||||
|
||||
lastRequest := limiter.requests[len(limiter.requests)-1]
|
||||
if lastRequest.Before(cutoff) {
|
||||
delete(m.rateLimiter, ip)
|
||||
}
|
||||
}
|
||||
}
|
||||
879
orig/internal/config/config.go
Normal file
879
orig/internal/config/config.go
Normal file
@@ -0,0 +1,879 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/url"
|
||||
"os"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
// Config represents the application configuration
|
||||
type Config struct {
|
||||
Arbitrum ArbitrumConfig `yaml:"arbitrum"`
|
||||
Bot BotConfig `yaml:"bot"`
|
||||
Uniswap UniswapConfig `yaml:"uniswap"`
|
||||
Log LogConfig `yaml:"log"`
|
||||
Database DatabaseConfig `yaml:"database"`
|
||||
Ethereum EthereumConfig `yaml:"ethereum"`
|
||||
Contracts ContractsConfig `yaml:"contracts"`
|
||||
Arbitrage ArbitrageConfig `yaml:"arbitrage"`
|
||||
Features Features `yaml:"features"`
|
||||
ArbitrageOptimized ArbitrageOptimizedConfig `yaml:"arbitrage_optimized"`
|
||||
}
|
||||
|
||||
// ArbitrumConfig represents the Arbitrum node configuration
|
||||
type ArbitrumConfig struct {
|
||||
// Chain ID for Arbitrum (42161 for mainnet)
|
||||
ChainID int64 `yaml:"chain_id"`
|
||||
|
||||
// Reading endpoints (WSS preferred for real-time monitoring)
|
||||
ReadingEndpoints []EndpointConfig `yaml:"reading_endpoints"`
|
||||
|
||||
// Execution endpoints (HTTP/HTTPS or WSS for transaction submission)
|
||||
ExecutionEndpoints []EndpointConfig `yaml:"execution_endpoints"`
|
||||
|
||||
// Fallback endpoints for failover scenarios
|
||||
FallbackEndpoints []EndpointConfig `yaml:"fallback_endpoints"`
|
||||
|
||||
// Legacy fields for backward compatibility
|
||||
RPCEndpoint string `yaml:"rpc_endpoint,omitempty"`
|
||||
WSEndpoint string `yaml:"ws_endpoint,omitempty"`
|
||||
|
||||
// Global rate limiting configuration
|
||||
RateLimit RateLimitConfig `yaml:"rate_limit"`
|
||||
}
|
||||
|
||||
// EndpointConfig represents an RPC endpoint configuration
|
||||
type EndpointConfig struct {
|
||||
// RPC endpoint URL
|
||||
URL string `yaml:"url"`
|
||||
// Endpoint name for identification
|
||||
Name string `yaml:"name"`
|
||||
// Priority (lower number = higher priority)
|
||||
Priority int `yaml:"priority"`
|
||||
// Maximum requests per second for this endpoint
|
||||
MaxRPS int `yaml:"max_rps"`
|
||||
// Maximum concurrent connections
|
||||
MaxConcurrent int `yaml:"max_concurrent"`
|
||||
// Connection timeout in seconds
|
||||
TimeoutSeconds int `yaml:"timeout_seconds"`
|
||||
// Health check interval in seconds
|
||||
HealthCheckInterval int `yaml:"health_check_interval"`
|
||||
// Rate limiting configuration for this endpoint
|
||||
RateLimit RateLimitConfig `yaml:"rate_limit"`
|
||||
}
|
||||
|
||||
// RateLimitConfig represents rate limiting configuration
|
||||
type RateLimitConfig struct {
|
||||
// Maximum requests per second
|
||||
RequestsPerSecond int `yaml:"requests_per_second"`
|
||||
// Maximum concurrent requests
|
||||
MaxConcurrent int `yaml:"max_concurrent"`
|
||||
// Burst size for rate limiting
|
||||
Burst int `yaml:"burst"`
|
||||
}
|
||||
|
||||
// BotConfig represents the bot configuration
|
||||
type BotConfig struct {
|
||||
// Enable or disable the bot
|
||||
Enabled bool `yaml:"enabled"`
|
||||
// Polling interval in seconds
|
||||
PollingInterval int `yaml:"polling_interval"`
|
||||
// Minimum profit threshold in USD
|
||||
MinProfitThreshold float64 `yaml:"min_profit_threshold"`
|
||||
// Gas price multiplier (for faster transactions)
|
||||
GasPriceMultiplier float64 `yaml:"gas_price_multiplier"`
|
||||
// Maximum number of concurrent workers for processing
|
||||
MaxWorkers int `yaml:"max_workers"`
|
||||
// Buffer size for channels
|
||||
ChannelBufferSize int `yaml:"channel_buffer_size"`
|
||||
// Timeout for RPC calls in seconds
|
||||
RPCTimeout int `yaml:"rpc_timeout"`
|
||||
}
|
||||
|
||||
// UniswapConfig represents the Uniswap configuration
|
||||
type UniswapConfig struct {
|
||||
// Factory contract address
|
||||
FactoryAddress string `yaml:"factory_address"`
|
||||
// Position manager contract address
|
||||
PositionManagerAddress string `yaml:"position_manager_address"`
|
||||
// Supported fee tiers
|
||||
FeeTiers []int64 `yaml:"fee_tiers"`
|
||||
// Cache configuration for pool data
|
||||
Cache CacheConfig `yaml:"cache"`
|
||||
}
|
||||
|
||||
// CacheConfig represents caching configuration
|
||||
type CacheConfig struct {
|
||||
// Enable or disable caching
|
||||
Enabled bool `yaml:"enabled"`
|
||||
// Cache expiration time in seconds
|
||||
Expiration int `yaml:"expiration"`
|
||||
// Maximum cache size
|
||||
MaxSize int `yaml:"max_size"`
|
||||
}
|
||||
|
||||
// LogConfig represents the logging configuration
|
||||
type LogConfig struct {
|
||||
// Log level (debug, info, warn, error)
|
||||
Level string `yaml:"level"`
|
||||
// Log format (json, text)
|
||||
Format string `yaml:"format"`
|
||||
// Log file path (empty for stdout)
|
||||
File string `yaml:"file"`
|
||||
}
|
||||
|
||||
// DatabaseConfig represents the database configuration
|
||||
type DatabaseConfig struct {
|
||||
// Database file path
|
||||
File string `yaml:"file"`
|
||||
// Maximum number of open connections
|
||||
MaxOpenConnections int `yaml:"max_open_connections"`
|
||||
// Maximum number of idle connections
|
||||
MaxIdleConnections int `yaml:"max_idle_connections"`
|
||||
}
|
||||
|
||||
// EthereumConfig represents the Ethereum account configuration
|
||||
type EthereumConfig struct {
|
||||
// Private key for transaction signing
|
||||
PrivateKey string `yaml:"private_key"`
|
||||
// Account address
|
||||
AccountAddress string `yaml:"account_address"`
|
||||
// Gas price multiplier (for faster transactions)
|
||||
GasPriceMultiplier float64 `yaml:"gas_price_multiplier"`
|
||||
}
|
||||
|
||||
// ContractsConfig represents the smart contract addresses
|
||||
type ContractsConfig struct {
|
||||
// Arbitrage executor contract address
|
||||
ArbitrageExecutor string `yaml:"arbitrage_executor"`
|
||||
// Flash swapper contract address
|
||||
FlashSwapper string `yaml:"flash_swapper"`
|
||||
// Flash loan receiver contract address (Balancer flash loans)
|
||||
FlashLoanReceiver string `yaml:"flash_loan_receiver"`
|
||||
// Balancer Vault address for flash loans
|
||||
BalancerVault string `yaml:"balancer_vault"`
|
||||
// Data fetcher contract address for batch pool data fetching
|
||||
DataFetcher string `yaml:"data_fetcher"`
|
||||
// Authorized caller addresses
|
||||
AuthorizedCallers []string `yaml:"authorized_callers"`
|
||||
// Authorized DEX addresses
|
||||
AuthorizedDEXes []string `yaml:"authorized_dexes"`
|
||||
}
|
||||
|
||||
// Load loads the configuration from a file
|
||||
func Load(filename string) (*Config, error) {
|
||||
// Read the config file
|
||||
data, err := os.ReadFile(filename)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read config file: %w", err)
|
||||
}
|
||||
|
||||
// Expand environment variables in the raw YAML
|
||||
expandedData := expandEnvVars(string(data))
|
||||
|
||||
// Parse the YAML
|
||||
var config Config
|
||||
if err := yaml.Unmarshal([]byte(expandedData), &config); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse config file: %w", err)
|
||||
}
|
||||
|
||||
// Override with environment variables if they exist
|
||||
config.OverrideWithEnv()
|
||||
|
||||
return &config, nil
|
||||
}
|
||||
|
||||
// expandEnvVars expands ${VAR} and $VAR patterns in the given string
|
||||
func expandEnvVars(s string) string {
|
||||
// Pattern to match ${VAR} and $VAR
|
||||
envVarPattern := regexp.MustCompile(`\$\{([^}]+)\}|\$([A-Za-z_][A-Za-z0-9_]*)`)
|
||||
|
||||
return envVarPattern.ReplaceAllStringFunc(s, func(match string) string {
|
||||
var varName string
|
||||
|
||||
// Handle ${VAR} format
|
||||
if strings.HasPrefix(match, "${") && strings.HasSuffix(match, "}") {
|
||||
varName = match[2 : len(match)-1]
|
||||
} else if strings.HasPrefix(match, "$") {
|
||||
// Handle $VAR format
|
||||
varName = match[1:]
|
||||
}
|
||||
|
||||
// Get environment variable value
|
||||
if value := os.Getenv(varName); value != "" {
|
||||
return value
|
||||
}
|
||||
|
||||
// Return empty string if environment variable is not set
|
||||
// This prevents invalid YAML when variables are missing
|
||||
return ""
|
||||
})
|
||||
}
|
||||
|
||||
// OverrideWithEnv overrides configuration with environment variables
|
||||
func (c *Config) OverrideWithEnv() {
|
||||
// Override legacy RPC endpoint (backward compatibility)
|
||||
if rpcEndpoint := os.Getenv("ARBITRUM_RPC_ENDPOINT"); rpcEndpoint != "" {
|
||||
c.Arbitrum.RPCEndpoint = rpcEndpoint
|
||||
// Also add to execution endpoints if not already configured
|
||||
if len(c.Arbitrum.ExecutionEndpoints) == 0 {
|
||||
// Determine RPS based on endpoint type
|
||||
rps := 200
|
||||
if strings.HasPrefix(rpcEndpoint, "ws") {
|
||||
rps = 300
|
||||
}
|
||||
c.Arbitrum.ExecutionEndpoints = append(c.Arbitrum.ExecutionEndpoints, EndpointConfig{
|
||||
URL: rpcEndpoint,
|
||||
Name: "Arbitrum Public HTTP",
|
||||
Priority: 1,
|
||||
MaxRPS: rps,
|
||||
MaxConcurrent: 20,
|
||||
TimeoutSeconds: 30,
|
||||
HealthCheckInterval: 60,
|
||||
RateLimit: RateLimitConfig{
|
||||
RequestsPerSecond: rps,
|
||||
MaxConcurrent: 20,
|
||||
Burst: rps * 2,
|
||||
},
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Override legacy WebSocket endpoint (backward compatibility)
|
||||
if wsEndpoint := os.Getenv("ARBITRUM_WS_ENDPOINT"); wsEndpoint != "" {
|
||||
c.Arbitrum.WSEndpoint = wsEndpoint
|
||||
// Also add to reading endpoints if not already configured
|
||||
if len(c.Arbitrum.ReadingEndpoints) == 0 {
|
||||
c.Arbitrum.ReadingEndpoints = append(c.Arbitrum.ReadingEndpoints, EndpointConfig{
|
||||
URL: wsEndpoint,
|
||||
Name: "Arbitrum Public WS",
|
||||
Priority: 1,
|
||||
MaxRPS: 300,
|
||||
MaxConcurrent: 25,
|
||||
TimeoutSeconds: 60,
|
||||
HealthCheckInterval: 30,
|
||||
RateLimit: RateLimitConfig{
|
||||
RequestsPerSecond: 300,
|
||||
MaxConcurrent: 25,
|
||||
Burst: 600,
|
||||
},
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Override reading endpoints from environment
|
||||
if readingEndpoints := os.Getenv("ARBITRUM_READING_ENDPOINTS"); readingEndpoints != "" {
|
||||
c.Arbitrum.ReadingEndpoints = c.parseEndpointsFromEnv(readingEndpoints, "Reading")
|
||||
}
|
||||
|
||||
// Override execution endpoints from environment
|
||||
if executionEndpoints := os.Getenv("ARBITRUM_EXECUTION_ENDPOINTS"); executionEndpoints != "" {
|
||||
c.Arbitrum.ExecutionEndpoints = c.parseEndpointsFromEnv(executionEndpoints, "Execution")
|
||||
}
|
||||
|
||||
// Override fallback endpoints from environment (legacy support)
|
||||
if fallbackEndpoints := os.Getenv("ARBITRUM_FALLBACK_ENDPOINTS"); fallbackEndpoints != "" {
|
||||
// Add to both reading and execution if they're empty
|
||||
fallbackConfigs := c.parseEndpointsFromEnv(fallbackEndpoints, "Fallback")
|
||||
if len(c.Arbitrum.ReadingEndpoints) == 0 {
|
||||
c.Arbitrum.ReadingEndpoints = append(c.Arbitrum.ReadingEndpoints, fallbackConfigs...)
|
||||
}
|
||||
if len(c.Arbitrum.ExecutionEndpoints) == 0 {
|
||||
c.Arbitrum.ExecutionEndpoints = append(c.Arbitrum.ExecutionEndpoints, fallbackConfigs...)
|
||||
}
|
||||
}
|
||||
|
||||
// Override rate limit settings
|
||||
if rps := os.Getenv("RPC_REQUESTS_PER_SECOND"); rps != "" {
|
||||
if val, err := strconv.Atoi(rps); err == nil {
|
||||
c.Arbitrum.RateLimit.RequestsPerSecond = val
|
||||
}
|
||||
}
|
||||
|
||||
if maxConcurrent := os.Getenv("RPC_MAX_CONCURRENT"); maxConcurrent != "" {
|
||||
if val, err := strconv.Atoi(maxConcurrent); err == nil {
|
||||
c.Arbitrum.RateLimit.MaxConcurrent = val
|
||||
}
|
||||
}
|
||||
|
||||
// Override bot settings
|
||||
if maxWorkers := os.Getenv("BOT_MAX_WORKERS"); maxWorkers != "" {
|
||||
if val, err := strconv.Atoi(maxWorkers); err == nil {
|
||||
c.Bot.MaxWorkers = val
|
||||
}
|
||||
}
|
||||
|
||||
if channelBufferSize := os.Getenv("BOT_CHANNEL_BUFFER_SIZE"); channelBufferSize != "" {
|
||||
if val, err := strconv.Atoi(channelBufferSize); err == nil {
|
||||
c.Bot.ChannelBufferSize = val
|
||||
}
|
||||
}
|
||||
|
||||
// Override Ethereum settings
|
||||
if privateKey := os.Getenv("ETHEREUM_PRIVATE_KEY"); privateKey != "" {
|
||||
c.Ethereum.PrivateKey = privateKey
|
||||
}
|
||||
|
||||
if accountAddress := os.Getenv("ETHEREUM_ACCOUNT_ADDRESS"); accountAddress != "" {
|
||||
c.Ethereum.AccountAddress = accountAddress
|
||||
}
|
||||
|
||||
if gasPriceMultiplier := os.Getenv("ETHEREUM_GAS_PRICE_MULTIPLIER"); gasPriceMultiplier != "" {
|
||||
if val, err := strconv.ParseFloat(gasPriceMultiplier, 64); err == nil {
|
||||
c.Ethereum.GasPriceMultiplier = val
|
||||
}
|
||||
}
|
||||
|
||||
// Override contract addresses
|
||||
if arbitrageExecutor := os.Getenv("CONTRACT_ARBITRAGE_EXECUTOR"); arbitrageExecutor != "" {
|
||||
c.Contracts.ArbitrageExecutor = arbitrageExecutor
|
||||
}
|
||||
|
||||
if flashSwapper := os.Getenv("CONTRACT_FLASH_SWAPPER"); flashSwapper != "" {
|
||||
c.Contracts.FlashSwapper = flashSwapper
|
||||
}
|
||||
}
|
||||
|
||||
// parseEndpointsFromEnv parses comma-separated endpoint URLs from environment variable
|
||||
func (c *Config) parseEndpointsFromEnv(endpointsStr, namePrefix string) []EndpointConfig {
|
||||
if endpointsStr == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
urls := strings.Split(endpointsStr, ",")
|
||||
endpoints := make([]EndpointConfig, 0, len(urls))
|
||||
|
||||
for i, url := range urls {
|
||||
url = strings.TrimSpace(url)
|
||||
if url == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
// Determine defaults based on URL scheme
|
||||
var maxRPS, maxConcurrent, timeoutSeconds, healthCheckInterval int
|
||||
if strings.HasPrefix(url, "ws") {
|
||||
// WebSocket endpoints - higher rate limits for real-time data
|
||||
maxRPS = 300
|
||||
maxConcurrent = 25
|
||||
timeoutSeconds = 60
|
||||
healthCheckInterval = 30
|
||||
} else {
|
||||
// HTTP endpoints - conservative rate limits
|
||||
maxRPS = 200
|
||||
maxConcurrent = 20
|
||||
timeoutSeconds = 30
|
||||
healthCheckInterval = 60
|
||||
}
|
||||
|
||||
endpoint := EndpointConfig{
|
||||
URL: url,
|
||||
Name: fmt.Sprintf("%s-%d", namePrefix, i+1),
|
||||
Priority: i + 1, // Lower number = higher priority
|
||||
MaxRPS: maxRPS,
|
||||
MaxConcurrent: maxConcurrent,
|
||||
TimeoutSeconds: timeoutSeconds,
|
||||
HealthCheckInterval: healthCheckInterval,
|
||||
RateLimit: RateLimitConfig{
|
||||
RequestsPerSecond: maxRPS,
|
||||
MaxConcurrent: maxConcurrent,
|
||||
Burst: maxRPS * 2, // Allow burst of 2x normal rate
|
||||
},
|
||||
}
|
||||
|
||||
endpoints = append(endpoints, endpoint)
|
||||
}
|
||||
|
||||
return endpoints
|
||||
}
|
||||
|
||||
// CreateProviderConfigFile creates a temporary YAML config file for the transport system
|
||||
func (c *Config) CreateProviderConfigFile(tempPath string) error {
|
||||
// Convert config to provider format
|
||||
providerConfig := c.ConvertToProviderConfig()
|
||||
|
||||
// Marshal to YAML
|
||||
yamlData, err := yaml.Marshal(providerConfig)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal provider config: %w", err)
|
||||
}
|
||||
|
||||
// Write to file
|
||||
if err := os.WriteFile(tempPath, yamlData, 0644); err != nil {
|
||||
return fmt.Errorf("failed to write provider config file: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ConvertToProviderConfig converts ArbitrumConfig to transport.ProvidersConfig
|
||||
func (c *Config) createProviderConfig(endpoint EndpointConfig, features []string) map[string]interface{} {
|
||||
provider := map[string]interface{}{
|
||||
"name": endpoint.Name,
|
||||
"type": "standard",
|
||||
"http_endpoint": "",
|
||||
"ws_endpoint": "",
|
||||
"priority": endpoint.Priority,
|
||||
"rate_limit": map[string]interface{}{
|
||||
"requests_per_second": endpoint.RateLimit.RequestsPerSecond,
|
||||
"burst": endpoint.RateLimit.Burst,
|
||||
"timeout": fmt.Sprintf("%ds", endpoint.TimeoutSeconds),
|
||||
"retry_delay": "1s",
|
||||
"max_retries": 3,
|
||||
},
|
||||
"features": features,
|
||||
"health_check": map[string]interface{}{
|
||||
"enabled": true,
|
||||
"interval": fmt.Sprintf("%ds", endpoint.HealthCheckInterval),
|
||||
"timeout": fmt.Sprintf("%ds", endpoint.TimeoutSeconds),
|
||||
},
|
||||
}
|
||||
|
||||
// Determine endpoint type and assign to appropriate field
|
||||
if strings.HasPrefix(endpoint.URL, "ws") {
|
||||
provider["ws_endpoint"] = endpoint.URL
|
||||
} else {
|
||||
provider["http_endpoint"] = endpoint.URL
|
||||
}
|
||||
|
||||
return provider
|
||||
}
|
||||
|
||||
func (c *Config) ConvertToProviderConfig() map[string]interface{} {
|
||||
providerConfigs := make([]map[string]interface{}, 0)
|
||||
|
||||
// Handle legacy configuration if new endpoints are not configured
|
||||
if len(c.Arbitrum.ReadingEndpoints) == 0 && len(c.Arbitrum.ExecutionEndpoints) == 0 {
|
||||
// Use legacy RPC and WS endpoints
|
||||
if c.Arbitrum.RPCEndpoint != "" {
|
||||
// Set default rate limits if zero
|
||||
rps := c.Arbitrum.RateLimit.RequestsPerSecond
|
||||
if rps <= 0 {
|
||||
if strings.HasPrefix(c.Arbitrum.RPCEndpoint, "ws") {
|
||||
rps = 300 // Default for WebSocket
|
||||
} else {
|
||||
rps = 200 // Default for HTTP
|
||||
}
|
||||
}
|
||||
burst := c.Arbitrum.RateLimit.Burst
|
||||
if burst <= 0 {
|
||||
burst = rps * 2 // Default burst is 2x RPS
|
||||
}
|
||||
|
||||
provider := map[string]interface{}{
|
||||
"name": "Legacy-RPC",
|
||||
"type": "standard",
|
||||
"priority": 1,
|
||||
"rate_limit": map[string]interface{}{
|
||||
"requests_per_second": rps,
|
||||
"burst": burst,
|
||||
"timeout": "30s",
|
||||
"retry_delay": "1s",
|
||||
"max_retries": 3,
|
||||
},
|
||||
"features": []string{"execution", "reading"},
|
||||
"health_check": map[string]interface{}{
|
||||
"enabled": true,
|
||||
"interval": "60s",
|
||||
"timeout": "30s",
|
||||
},
|
||||
}
|
||||
|
||||
// Determine endpoint type and assign to appropriate field
|
||||
if strings.HasPrefix(c.Arbitrum.RPCEndpoint, "ws") {
|
||||
provider["http_endpoint"] = ""
|
||||
provider["ws_endpoint"] = c.Arbitrum.RPCEndpoint
|
||||
} else {
|
||||
provider["http_endpoint"] = c.Arbitrum.RPCEndpoint
|
||||
provider["ws_endpoint"] = ""
|
||||
}
|
||||
|
||||
providerConfigs = append(providerConfigs, provider)
|
||||
}
|
||||
|
||||
if c.Arbitrum.WSEndpoint != "" {
|
||||
// Set default rate limits if zero
|
||||
rps := c.Arbitrum.RateLimit.RequestsPerSecond
|
||||
if rps <= 0 {
|
||||
rps = 300 // Default for WebSocket
|
||||
}
|
||||
burst := c.Arbitrum.RateLimit.Burst
|
||||
if burst <= 0 {
|
||||
burst = rps * 2 // Default burst is 2x RPS
|
||||
}
|
||||
|
||||
provider := map[string]interface{}{
|
||||
"name": "Legacy-WSS",
|
||||
"type": "standard",
|
||||
"http_endpoint": "",
|
||||
"ws_endpoint": c.Arbitrum.WSEndpoint,
|
||||
"priority": 1,
|
||||
"rate_limit": map[string]interface{}{
|
||||
"requests_per_second": rps,
|
||||
"burst": burst,
|
||||
"timeout": "60s",
|
||||
"retry_delay": "1s",
|
||||
"max_retries": 3,
|
||||
},
|
||||
"features": []string{"reading", "real_time"},
|
||||
"health_check": map[string]interface{}{
|
||||
"enabled": true,
|
||||
"interval": "30s",
|
||||
"timeout": "60s",
|
||||
},
|
||||
}
|
||||
providerConfigs = append(providerConfigs, provider)
|
||||
}
|
||||
|
||||
// Create simple pool configuration for legacy mode
|
||||
providerPools := make(map[string]interface{})
|
||||
if len(providerConfigs) > 0 {
|
||||
providerNames := make([]string, 0)
|
||||
for _, provider := range providerConfigs {
|
||||
providerNames = append(providerNames, provider["name"].(string))
|
||||
}
|
||||
|
||||
// Use same providers for both reading and execution in legacy mode
|
||||
providerPools["read_only"] = map[string]interface{}{
|
||||
"strategy": "priority_based",
|
||||
"max_concurrent_connections": 25,
|
||||
"health_check_interval": "30s",
|
||||
"failover_enabled": true,
|
||||
"providers": providerNames,
|
||||
}
|
||||
providerPools["execution"] = map[string]interface{}{
|
||||
"strategy": "priority_based",
|
||||
"max_concurrent_connections": 20,
|
||||
"health_check_interval": "30s",
|
||||
"failover_enabled": true,
|
||||
"providers": providerNames,
|
||||
}
|
||||
}
|
||||
|
||||
return map[string]interface{}{
|
||||
"provider_pools": providerPools,
|
||||
"providers": providerConfigs,
|
||||
"rotation": map[string]interface{}{
|
||||
"strategy": "priority_based",
|
||||
"health_check_required": true,
|
||||
"fallover_enabled": true,
|
||||
"retry_failed_after": "5m",
|
||||
},
|
||||
"global_limits": map[string]interface{}{
|
||||
"max_concurrent_connections": 50,
|
||||
"connection_timeout": "30s",
|
||||
"read_timeout": "60s",
|
||||
"write_timeout": "30s",
|
||||
"idle_timeout": "300s",
|
||||
},
|
||||
"monitoring": map[string]interface{}{
|
||||
"enabled": true,
|
||||
"metrics_interval": "60s",
|
||||
"log_slow_requests": true,
|
||||
"slow_request_threshold": "5s",
|
||||
"track_provider_performance": true,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// Convert reading endpoints
|
||||
for _, endpoint := range c.Arbitrum.ReadingEndpoints {
|
||||
providerConfigs = append(providerConfigs, c.createProviderConfig(endpoint, []string{"reading", "real_time"}))
|
||||
}
|
||||
|
||||
// Convert execution endpoints
|
||||
for _, endpoint := range c.Arbitrum.ExecutionEndpoints {
|
||||
providerConfigs = append(providerConfigs, c.createProviderConfig(endpoint, []string{"execution", "transaction_submission"}))
|
||||
}
|
||||
|
||||
// Build provider pool configurations
|
||||
providerPools := make(map[string]interface{})
|
||||
|
||||
// Reading pool configuration
|
||||
if len(c.Arbitrum.ReadingEndpoints) > 0 {
|
||||
readingProviders := make([]string, 0)
|
||||
for _, endpoint := range c.Arbitrum.ReadingEndpoints {
|
||||
readingProviders = append(readingProviders, endpoint.Name)
|
||||
}
|
||||
providerPools["read_only"] = map[string]interface{}{
|
||||
"strategy": "websocket_preferred",
|
||||
"max_concurrent_connections": 25,
|
||||
"health_check_interval": "30s",
|
||||
"failover_enabled": true,
|
||||
"providers": readingProviders,
|
||||
}
|
||||
}
|
||||
|
||||
// Execution pool configuration
|
||||
if len(c.Arbitrum.ExecutionEndpoints) > 0 {
|
||||
executionProviders := make([]string, 0)
|
||||
for _, endpoint := range c.Arbitrum.ExecutionEndpoints {
|
||||
executionProviders = append(executionProviders, endpoint.Name)
|
||||
}
|
||||
providerPools["execution"] = map[string]interface{}{
|
||||
"strategy": "reliability_first",
|
||||
"max_concurrent_connections": 20,
|
||||
"health_check_interval": "30s",
|
||||
"failover_enabled": true,
|
||||
"providers": executionProviders,
|
||||
}
|
||||
}
|
||||
|
||||
// Complete configuration
|
||||
return map[string]interface{}{
|
||||
"provider_pools": providerPools,
|
||||
"providers": providerConfigs,
|
||||
"rotation": map[string]interface{}{
|
||||
"strategy": "priority_based",
|
||||
"health_check_required": true,
|
||||
"fallover_enabled": true,
|
||||
"retry_failed_after": "5m",
|
||||
},
|
||||
"global_limits": map[string]interface{}{
|
||||
"max_concurrent_connections": 50,
|
||||
"connection_timeout": "30s",
|
||||
"read_timeout": "60s",
|
||||
"write_timeout": "30s",
|
||||
"idle_timeout": "300s",
|
||||
},
|
||||
"monitoring": map[string]interface{}{
|
||||
"enabled": true,
|
||||
"metrics_interval": "60s",
|
||||
"log_slow_requests": true,
|
||||
"slow_request_threshold": "5s",
|
||||
"track_provider_performance": true,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// ValidateEnvironmentVariables validates all required environment variables
|
||||
func (c *Config) ValidateEnvironmentVariables() error {
|
||||
// Validate RPC endpoint
|
||||
if c.Arbitrum.RPCEndpoint == "" {
|
||||
return fmt.Errorf("ARBITRUM_RPC_ENDPOINT environment variable is required")
|
||||
}
|
||||
|
||||
if err := validateRPCEndpoint(c.Arbitrum.RPCEndpoint); err != nil {
|
||||
return fmt.Errorf("invalid ARBITRUM_RPC_ENDPOINT: %w", err)
|
||||
}
|
||||
|
||||
// Validate WebSocket endpoint if provided
|
||||
if c.Arbitrum.WSEndpoint != "" {
|
||||
if err := validateRPCEndpoint(c.Arbitrum.WSEndpoint); err != nil {
|
||||
return fmt.Errorf("invalid ARBITRUM_WS_ENDPOINT: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Validate Ethereum private key
|
||||
if c.Ethereum.PrivateKey == "" {
|
||||
return fmt.Errorf("ETHEREUM_PRIVATE_KEY environment variable is required")
|
||||
}
|
||||
|
||||
// Validate account address
|
||||
if c.Ethereum.AccountAddress == "" {
|
||||
return fmt.Errorf("ETHEREUM_ACCOUNT_ADDRESS environment variable is required")
|
||||
}
|
||||
|
||||
// Validate contract addresses
|
||||
if c.Contracts.ArbitrageExecutor == "" {
|
||||
return fmt.Errorf("CONTRACT_ARBITRAGE_EXECUTOR environment variable is required")
|
||||
}
|
||||
|
||||
if c.Contracts.FlashSwapper == "" {
|
||||
return fmt.Errorf("CONTRACT_FLASH_SWAPPER environment variable is required")
|
||||
}
|
||||
|
||||
// Validate numeric values
|
||||
if c.Arbitrum.RateLimit.RequestsPerSecond < 0 {
|
||||
return fmt.Errorf("RPC_REQUESTS_PER_SECOND must be non-negative")
|
||||
}
|
||||
|
||||
if c.Arbitrum.RateLimit.MaxConcurrent < 0 {
|
||||
return fmt.Errorf("RPC_MAX_CONCURRENT must be non-negative")
|
||||
}
|
||||
|
||||
if c.Bot.MaxWorkers <= 0 {
|
||||
return fmt.Errorf("BOT_MAX_WORKERS must be positive")
|
||||
}
|
||||
|
||||
if c.Bot.ChannelBufferSize < 0 {
|
||||
return fmt.Errorf("BOT_CHANNEL_BUFFER_SIZE must be non-negative")
|
||||
}
|
||||
|
||||
if c.Ethereum.GasPriceMultiplier < 0 {
|
||||
return fmt.Errorf("ETHEREUM_GAS_PRICE_MULTIPLIER must be non-negative")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// validateRPCEndpoint validates RPC endpoint URL for security and format
|
||||
func validateRPCEndpoint(endpoint string) error {
|
||||
if endpoint == "" {
|
||||
return fmt.Errorf("RPC endpoint cannot be empty")
|
||||
}
|
||||
|
||||
u, err := url.Parse(endpoint)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid RPC endpoint URL: %w", err)
|
||||
}
|
||||
|
||||
// Check for valid schemes
|
||||
switch u.Scheme {
|
||||
case "http", "https", "ws", "wss":
|
||||
// Valid schemes
|
||||
default:
|
||||
return fmt.Errorf("invalid RPC scheme: %s (must be http, https, ws, or wss)", u.Scheme)
|
||||
}
|
||||
|
||||
// Check for localhost/private networks in production
|
||||
if strings.Contains(u.Hostname(), "localhost") || strings.Contains(u.Hostname(), "127.0.0.1") {
|
||||
// Allow localhost only if explicitly enabled
|
||||
if os.Getenv("MEV_BOT_ALLOW_LOCALHOST") != "true" {
|
||||
return fmt.Errorf("localhost RPC endpoints not allowed in production (set MEV_BOT_ALLOW_LOCALHOST=true to override)")
|
||||
}
|
||||
}
|
||||
|
||||
// Validate hostname is not empty
|
||||
if u.Hostname() == "" {
|
||||
return fmt.Errorf("RPC endpoint must have a valid hostname")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ArbitrageConfig represents the arbitrage service configuration
|
||||
type ArbitrageConfig struct {
|
||||
// Enable or disable arbitrage service
|
||||
Enabled bool `yaml:"enabled"`
|
||||
|
||||
// Contract addresses
|
||||
ArbitrageContractAddress string `yaml:"arbitrage_contract_address"`
|
||||
FlashSwapContractAddress string `yaml:"flash_swap_contract_address"`
|
||||
|
||||
// Profitability settings
|
||||
MinProfitWei int64 `yaml:"min_profit_wei"`
|
||||
MinROIPercent float64 `yaml:"min_roi_percent"`
|
||||
MinSignificantSwapSize int64 `yaml:"min_significant_swap_size"`
|
||||
SlippageTolerance float64 `yaml:"slippage_tolerance"`
|
||||
|
||||
// Scanning configuration
|
||||
MinScanAmountWei int64 `yaml:"min_scan_amount_wei"`
|
||||
MaxScanAmountWei int64 `yaml:"max_scan_amount_wei"`
|
||||
|
||||
// Gas configuration
|
||||
MaxGasPriceWei int64 `yaml:"max_gas_price_wei"`
|
||||
|
||||
// Execution limits
|
||||
MaxConcurrentExecutions int `yaml:"max_concurrent_executions"`
|
||||
MaxOpportunitiesPerEvent int `yaml:"max_opportunities_per_event"`
|
||||
|
||||
// Timing settings
|
||||
OpportunityTTL time.Duration `yaml:"opportunity_ttl"`
|
||||
MaxPathAge time.Duration `yaml:"max_path_age"`
|
||||
StatsUpdateInterval time.Duration `yaml:"stats_update_interval"`
|
||||
|
||||
// Pool discovery configuration
|
||||
PoolDiscoveryConfig PoolDiscoveryConfig `yaml:"pool_discovery"`
|
||||
}
|
||||
|
||||
// PoolDiscoveryConfig represents pool discovery service configuration
|
||||
type PoolDiscoveryConfig struct {
|
||||
// Enable or disable pool discovery
|
||||
Enabled bool `yaml:"enabled"`
|
||||
|
||||
// Block range to scan for new pools
|
||||
BlockRange uint64 `yaml:"block_range"`
|
||||
|
||||
// Polling interval for new pools
|
||||
PollingInterval time.Duration `yaml:"polling_interval"`
|
||||
|
||||
// DEX factory addresses to monitor
|
||||
FactoryAddresses []string `yaml:"factory_addresses"`
|
||||
|
||||
// Minimum liquidity threshold for pools
|
||||
MinLiquidityWei int64 `yaml:"min_liquidity_wei"`
|
||||
|
||||
// Cache configuration
|
||||
CacheSize int `yaml:"cache_size"`
|
||||
CacheTTL time.Duration `yaml:"cache_ttl"`
|
||||
}
|
||||
|
||||
// Features represents Layer 2 optimization feature flags
|
||||
type Features struct {
|
||||
// Phase 1: Configuration tuning
|
||||
UseArbitrumOptimizedTimeouts bool `yaml:"use_arbitrum_optimized_timeouts"`
|
||||
UseDynamicTTL bool `yaml:"use_dynamic_ttl"`
|
||||
|
||||
// Phase 2: Transaction filtering
|
||||
EnableDEXPrefilter bool `yaml:"enable_dex_prefilter"`
|
||||
|
||||
// Phase 3: Sequencer optimization
|
||||
UseDirectSequencerFeed bool `yaml:"use_direct_sequencer_feed"`
|
||||
|
||||
// Phase 4-5: Timeboost
|
||||
EnableTimeboost bool `yaml:"enable_timeboost"`
|
||||
}
|
||||
|
||||
// ArbitrageOptimizedConfig represents Arbitrum-optimized arbitrage timing
|
||||
type ArbitrageOptimizedConfig struct {
|
||||
// Opportunity lifecycle (tuned for 250ms blocks)
|
||||
OpportunityTTL time.Duration `yaml:"opportunity_ttl"`
|
||||
MaxPathAge time.Duration `yaml:"max_path_age"`
|
||||
ExecutionDeadline time.Duration `yaml:"execution_deadline"`
|
||||
|
||||
// Legacy values for rollback
|
||||
LegacyOpportunityTTL time.Duration `yaml:"legacy_opportunity_ttl"`
|
||||
LegacyMaxPathAge time.Duration `yaml:"legacy_max_path_age"`
|
||||
|
||||
// Dynamic TTL settings
|
||||
DynamicTTL DynamicTTLConfig `yaml:"dynamic_ttl"`
|
||||
}
|
||||
|
||||
// DynamicTTLConfig represents dynamic TTL calculation settings
|
||||
type DynamicTTLConfig struct {
|
||||
MinTTLBlocks int `yaml:"min_ttl_blocks"`
|
||||
MaxTTLBlocks int `yaml:"max_ttl_blocks"`
|
||||
ProfitMultiplier bool `yaml:"profit_multiplier"`
|
||||
VolatilityAdjustment bool `yaml:"volatility_adjustment"`
|
||||
}
|
||||
|
||||
// GetOpportunityTTL returns the active opportunity TTL based on feature flags
|
||||
func (c *Config) GetOpportunityTTL() time.Duration {
|
||||
if c.Features.UseArbitrumOptimizedTimeouts {
|
||||
return c.ArbitrageOptimized.OpportunityTTL
|
||||
}
|
||||
// Fallback to legacy config
|
||||
if c.Arbitrage.OpportunityTTL > 0 {
|
||||
return c.Arbitrage.OpportunityTTL
|
||||
}
|
||||
// Default fallback
|
||||
return 30 * time.Second
|
||||
}
|
||||
|
||||
// GetMaxPathAge returns the active max path age based on feature flags
|
||||
func (c *Config) GetMaxPathAge() time.Duration {
|
||||
if c.Features.UseArbitrumOptimizedTimeouts {
|
||||
return c.ArbitrageOptimized.MaxPathAge
|
||||
}
|
||||
// Fallback to legacy config
|
||||
if c.Arbitrage.MaxPathAge > 0 {
|
||||
return c.Arbitrage.MaxPathAge
|
||||
}
|
||||
// Default fallback
|
||||
return 60 * time.Second
|
||||
}
|
||||
|
||||
// GetExecutionDeadline returns the execution deadline
|
||||
func (c *Config) GetExecutionDeadline() time.Duration {
|
||||
if c.Features.UseArbitrumOptimizedTimeouts && c.ArbitrageOptimized.ExecutionDeadline > 0 {
|
||||
return c.ArbitrageOptimized.ExecutionDeadline
|
||||
}
|
||||
// Default fallback for Arbitrum (12 blocks @ 250ms)
|
||||
return 3 * time.Second
|
||||
}
|
||||
139
orig/internal/config/config_test.go
Normal file
139
orig/internal/config/config_test.go
Normal file
@@ -0,0 +1,139 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestLoad(t *testing.T) {
|
||||
// Create a temporary config file for testing
|
||||
tmpFile, err := os.CreateTemp("", "config_test_*.yaml")
|
||||
require.NoError(t, err)
|
||||
defer os.Remove(tmpFile.Name())
|
||||
|
||||
// Write test config content
|
||||
configContent := `
|
||||
arbitrum:
|
||||
rpc_endpoint: "${ARBITRUM_RPC_ENDPOINT}"
|
||||
ws_endpoint: "${ARBITRUM_WS_ENDPOINT}"
|
||||
chain_id: 42161
|
||||
rate_limit:
|
||||
requests_per_second: 5
|
||||
max_concurrent: 3
|
||||
burst: 10
|
||||
|
||||
bot:
|
||||
enabled: true
|
||||
polling_interval: 3
|
||||
min_profit_threshold: 10.0
|
||||
gas_price_multiplier: 1.2
|
||||
max_workers: 3
|
||||
channel_buffer_size: 50
|
||||
rpc_timeout: 30
|
||||
|
||||
uniswap:
|
||||
factory_address: "0x1F98431c8aD98523631AE4a59f267346ea31F984"
|
||||
position_manager_address: "0xC36442b4a4522E871399CD717aBDD847Ab11FE88"
|
||||
fee_tiers:
|
||||
- 500
|
||||
- 3000
|
||||
- 10000
|
||||
cache:
|
||||
enabled: true
|
||||
expiration: 300
|
||||
max_size: 10000
|
||||
|
||||
log:
|
||||
level: "debug"
|
||||
format: "text"
|
||||
file: "logs/mev-bot.log"
|
||||
|
||||
database:
|
||||
file: "mev-bot.db"
|
||||
max_open_connections: 10
|
||||
max_idle_connections: 5
|
||||
`
|
||||
_, err = tmpFile.Write([]byte(configContent))
|
||||
require.NoError(t, err)
|
||||
err = tmpFile.Close()
|
||||
require.NoError(t, err)
|
||||
|
||||
// Set environment variables for test
|
||||
os.Setenv("ARBITRUM_RPC_ENDPOINT", "wss://arbitrum-mainnet.core.chainstack.com/53c30e7a941160679fdcc396c894fc57")
|
||||
os.Setenv("ARBITRUM_WS_ENDPOINT", "wss://arbitrum-mainnet.core.chainstack.com/53c30e7a941160679fdcc396c894fc57")
|
||||
defer func() {
|
||||
os.Unsetenv("ARBITRUM_RPC_ENDPOINT")
|
||||
os.Unsetenv("ARBITRUM_WS_ENDPOINT")
|
||||
}()
|
||||
|
||||
// Test loading the config
|
||||
cfg, err := Load(tmpFile.Name())
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify the loaded config
|
||||
assert.Equal(t, "wss://arbitrum-mainnet.core.chainstack.com/53c30e7a941160679fdcc396c894fc57", cfg.Arbitrum.RPCEndpoint)
|
||||
assert.Equal(t, "wss://arbitrum-mainnet.core.chainstack.com/53c30e7a941160679fdcc396c894fc57", cfg.Arbitrum.WSEndpoint)
|
||||
assert.Equal(t, int64(42161), cfg.Arbitrum.ChainID)
|
||||
assert.Equal(t, 5, cfg.Arbitrum.RateLimit.RequestsPerSecond)
|
||||
assert.True(t, cfg.Bot.Enabled)
|
||||
assert.Equal(t, 3, cfg.Bot.PollingInterval)
|
||||
assert.Equal(t, 10.0, cfg.Bot.MinProfitThreshold)
|
||||
assert.Equal(t, "0x1F98431c8aD98523631AE4a59f267346ea31F984", cfg.Uniswap.FactoryAddress)
|
||||
assert.Len(t, cfg.Uniswap.FeeTiers, 3)
|
||||
assert.Equal(t, true, cfg.Uniswap.Cache.Enabled)
|
||||
assert.Equal(t, "debug", cfg.Log.Level)
|
||||
assert.Equal(t, "logs/mev-bot.log", cfg.Log.File)
|
||||
assert.Equal(t, "mev-bot.db", cfg.Database.File)
|
||||
}
|
||||
|
||||
func TestLoadWithInvalidFile(t *testing.T) {
|
||||
// Test loading a non-existent config file
|
||||
_, err := Load("/non/existent/file.yaml")
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestOverrideWithEnv(t *testing.T) {
|
||||
// Create a temporary config file for testing
|
||||
tmpFile, err := os.CreateTemp("", "config_test_*.yaml")
|
||||
require.NoError(t, err)
|
||||
defer os.Remove(tmpFile.Name())
|
||||
|
||||
// Write test config content
|
||||
configContent := `
|
||||
arbitrum:
|
||||
rpc_endpoint: "https://arb1.arbitrum.io/rpc"
|
||||
rate_limit:
|
||||
requests_per_second: 10
|
||||
max_concurrent: 5
|
||||
|
||||
bot:
|
||||
max_workers: 10
|
||||
channel_buffer_size: 100
|
||||
`
|
||||
_, err = tmpFile.Write([]byte(configContent))
|
||||
require.NoError(t, err)
|
||||
err = tmpFile.Close()
|
||||
require.NoError(t, err)
|
||||
|
||||
// Set environment variables to override config
|
||||
os.Setenv("ARBITRUM_RPC_ENDPOINT", "https://override.arbitrum.io/rpc")
|
||||
os.Setenv("RPC_REQUESTS_PER_SECOND", "20")
|
||||
os.Setenv("BOT_MAX_WORKERS", "20")
|
||||
defer func() {
|
||||
os.Unsetenv("ARBITRUM_RPC_ENDPOINT")
|
||||
os.Unsetenv("RPC_REQUESTS_PER_SECOND")
|
||||
os.Unsetenv("BOT_MAX_WORKERS")
|
||||
}()
|
||||
|
||||
// Load the config
|
||||
cfg, err := Load(tmpFile.Name())
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify the overridden values
|
||||
assert.Equal(t, "https://override.arbitrum.io/rpc", cfg.Arbitrum.RPCEndpoint)
|
||||
assert.Equal(t, 20, cfg.Arbitrum.RateLimit.RequestsPerSecond)
|
||||
assert.Equal(t, 20, cfg.Bot.MaxWorkers)
|
||||
}
|
||||
349
orig/internal/contracts/detector.go
Normal file
349
orig/internal/contracts/detector.go
Normal file
@@ -0,0 +1,349 @@
|
||||
package contracts
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"math/big"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/ethereum/go-ethereum"
|
||||
"github.com/ethereum/go-ethereum/common"
|
||||
"github.com/ethereum/go-ethereum/ethclient"
|
||||
|
||||
"github.com/fraktal/mev-beta/internal/logger"
|
||||
)
|
||||
|
||||
// ContractType represents the detected type of a contract
|
||||
type ContractType int
|
||||
|
||||
const (
|
||||
ContractTypeUnknown ContractType = iota
|
||||
ContractTypeERC20Token
|
||||
ContractTypeUniswapV2Pool
|
||||
ContractTypeUniswapV3Pool
|
||||
ContractTypeUniswapV2Router
|
||||
ContractTypeUniswapV3Router
|
||||
ContractTypeUniversalRouter
|
||||
ContractTypeFactory
|
||||
ContractTypeEOA // Externally Owned Account
|
||||
)
|
||||
|
||||
// String returns the string representation of the contract type
|
||||
func (ct ContractType) String() string {
|
||||
switch ct {
|
||||
case ContractTypeERC20Token:
|
||||
return "ERC-20"
|
||||
case ContractTypeUniswapV2Pool:
|
||||
return "UniswapV2Pool"
|
||||
case ContractTypeUniswapV3Pool:
|
||||
return "UniswapV3Pool"
|
||||
case ContractTypeUniswapV2Router:
|
||||
return "UniswapV2Router"
|
||||
case ContractTypeUniswapV3Router:
|
||||
return "UniswapV3Router"
|
||||
case ContractTypeUniversalRouter:
|
||||
return "UniversalRouter"
|
||||
case ContractTypeFactory:
|
||||
return "Factory"
|
||||
case ContractTypeEOA:
|
||||
return "EOA"
|
||||
default:
|
||||
return "Unknown"
|
||||
}
|
||||
}
|
||||
|
||||
// DetectionResult contains the result of contract type detection
|
||||
type DetectionResult struct {
|
||||
ContractType ContractType
|
||||
IsContract bool
|
||||
HasCode bool
|
||||
SupportedFunctions []string
|
||||
Confidence float64 // 0.0 to 1.0
|
||||
Error error
|
||||
Warnings []string
|
||||
}
|
||||
|
||||
// ContractDetector provides runtime contract type detection
|
||||
type ContractDetector struct {
|
||||
client *ethclient.Client
|
||||
logger *logger.Logger
|
||||
cache map[common.Address]*DetectionResult
|
||||
timeout time.Duration
|
||||
}
|
||||
|
||||
// NewContractDetector creates a new contract detector
|
||||
func NewContractDetector(client *ethclient.Client, logger *logger.Logger) *ContractDetector {
|
||||
return &ContractDetector{
|
||||
client: client,
|
||||
logger: logger,
|
||||
cache: make(map[common.Address]*DetectionResult),
|
||||
timeout: 5 * time.Second,
|
||||
}
|
||||
}
|
||||
|
||||
// DetectContractType determines the type of contract at the given address
|
||||
func (cd *ContractDetector) DetectContractType(ctx context.Context, address common.Address) *DetectionResult {
|
||||
// Check cache first
|
||||
if result, exists := cd.cache[address]; exists {
|
||||
return result
|
||||
}
|
||||
|
||||
result := &DetectionResult{
|
||||
ContractType: ContractTypeUnknown,
|
||||
IsContract: false,
|
||||
HasCode: false,
|
||||
SupportedFunctions: []string{},
|
||||
Confidence: 0.0,
|
||||
Warnings: []string{},
|
||||
}
|
||||
|
||||
// Create context with timeout
|
||||
ctxWithTimeout, cancel := context.WithTimeout(ctx, cd.timeout)
|
||||
defer cancel()
|
||||
|
||||
// Check if address has code (is a contract)
|
||||
code, err := cd.client.CodeAt(ctxWithTimeout, address, nil)
|
||||
if err != nil {
|
||||
result.Error = fmt.Errorf("failed to get code at address: %w", err)
|
||||
cd.cache[address] = result
|
||||
return result
|
||||
}
|
||||
|
||||
// If no code, it's an EOA
|
||||
if len(code) == 0 {
|
||||
result.ContractType = ContractTypeEOA
|
||||
result.IsContract = false
|
||||
result.Confidence = 1.0
|
||||
cd.cache[address] = result
|
||||
return result
|
||||
}
|
||||
|
||||
result.IsContract = true
|
||||
result.HasCode = true
|
||||
|
||||
// Detect contract type by testing function signatures
|
||||
contractType, confidence, supportedFunctions := cd.detectByFunctionSignatures(ctxWithTimeout, address)
|
||||
result.ContractType = contractType
|
||||
result.Confidence = confidence
|
||||
result.SupportedFunctions = supportedFunctions
|
||||
|
||||
// Additional validation for high-confidence detection
|
||||
if confidence > 0.8 {
|
||||
if err := cd.validateContractType(ctxWithTimeout, address, contractType); err != nil {
|
||||
result.Warnings = append(result.Warnings, fmt.Sprintf("validation warning: %v", err))
|
||||
result.Confidence *= 0.8 // Reduce confidence
|
||||
}
|
||||
}
|
||||
|
||||
cd.cache[address] = result
|
||||
return result
|
||||
}
|
||||
|
||||
// detectByFunctionSignatures detects contract type by testing known function signatures
|
||||
func (cd *ContractDetector) detectByFunctionSignatures(ctx context.Context, address common.Address) (ContractType, float64, []string) {
|
||||
supportedFunctions := []string{}
|
||||
scores := make(map[ContractType]float64)
|
||||
|
||||
// Test ERC-20 functions
|
||||
erc20Functions := map[string][]byte{
|
||||
"name()": {0x06, 0xfd, 0xde, 0x03},
|
||||
"symbol()": {0x95, 0xd8, 0x9b, 0x41},
|
||||
"decimals()": {0x31, 0x3c, 0xe5, 0x67},
|
||||
"totalSupply()": {0x18, 0x16, 0x0d, 0xdd},
|
||||
"balanceOf()": {0x70, 0xa0, 0x82, 0x31},
|
||||
}
|
||||
|
||||
erc20Score := cd.testFunctionSignatures(ctx, address, erc20Functions, &supportedFunctions)
|
||||
if erc20Score > 0.6 {
|
||||
scores[ContractTypeERC20Token] = erc20Score
|
||||
}
|
||||
|
||||
// Test Uniswap V2 Pool functions
|
||||
v2PoolFunctions := map[string][]byte{
|
||||
"token0()": {0x0d, 0xfe, 0x16, 0x81},
|
||||
"token1()": {0xd2, 0x12, 0x20, 0xa7},
|
||||
"getReserves()": {0x09, 0x02, 0xf1, 0xac},
|
||||
"price0CumulativeLast()": {0x54, 0x1c, 0x5c, 0xfa},
|
||||
"kLast()": {0x7d, 0xc0, 0xd1, 0xd0},
|
||||
}
|
||||
|
||||
v2PoolScore := cd.testFunctionSignatures(ctx, address, v2PoolFunctions, &supportedFunctions)
|
||||
if v2PoolScore > 0.6 {
|
||||
scores[ContractTypeUniswapV2Pool] = v2PoolScore
|
||||
}
|
||||
|
||||
// Test Uniswap V3 Pool functions
|
||||
v3PoolFunctions := map[string][]byte{
|
||||
"token0()": {0x0d, 0xfe, 0x16, 0x81},
|
||||
"token1()": {0xd2, 0x12, 0x20, 0xa7},
|
||||
"fee()": {0xdd, 0xca, 0x3f, 0x43},
|
||||
"slot0()": {0x38, 0x50, 0xc7, 0xbd},
|
||||
"liquidity()": {0x1a, 0x68, 0x65, 0x0f},
|
||||
"tickSpacing()": {0xd0, 0xc9, 0x32, 0x07},
|
||||
}
|
||||
|
||||
v3PoolScore := cd.testFunctionSignatures(ctx, address, v3PoolFunctions, &supportedFunctions)
|
||||
if v3PoolScore > 0.6 {
|
||||
scores[ContractTypeUniswapV3Pool] = v3PoolScore
|
||||
}
|
||||
|
||||
// Test Router functions
|
||||
routerFunctions := map[string][]byte{
|
||||
"WETH()": {0xad, 0x5c, 0x46, 0x48},
|
||||
"swapExactTokensForTokens()": {0x38, 0xed, 0x17, 0x39},
|
||||
"factory()": {0xc4, 0x5a, 0x01, 0x55},
|
||||
}
|
||||
|
||||
routerScore := cd.testFunctionSignatures(ctx, address, routerFunctions, &supportedFunctions)
|
||||
if routerScore > 0.5 {
|
||||
scores[ContractTypeUniswapV2Router] = routerScore
|
||||
}
|
||||
|
||||
// Find highest scoring type
|
||||
var bestType ContractType = ContractTypeUnknown
|
||||
var bestScore float64 = 0.0
|
||||
|
||||
for contractType, score := range scores {
|
||||
if score > bestScore {
|
||||
bestScore = score
|
||||
bestType = contractType
|
||||
}
|
||||
}
|
||||
|
||||
return bestType, bestScore, supportedFunctions
|
||||
}
|
||||
|
||||
// testFunctionSignatures tests if a contract supports given function signatures
|
||||
func (cd *ContractDetector) testFunctionSignatures(ctx context.Context, address common.Address, functions map[string][]byte, supportedFunctions *[]string) float64 {
|
||||
supported := 0
|
||||
total := len(functions)
|
||||
|
||||
for funcName, signature := range functions {
|
||||
// Test the function call
|
||||
_, err := cd.client.CallContract(ctx, ethereum.CallMsg{
|
||||
To: &address,
|
||||
Data: signature,
|
||||
}, nil)
|
||||
|
||||
if err == nil {
|
||||
supported++
|
||||
*supportedFunctions = append(*supportedFunctions, funcName)
|
||||
} else if !strings.Contains(err.Error(), "execution reverted") {
|
||||
// If it's not a revert, it might be a network error, so we don't count it against
|
||||
total--
|
||||
}
|
||||
}
|
||||
|
||||
if total == 0 {
|
||||
return 0.0
|
||||
}
|
||||
|
||||
return float64(supported) / float64(total)
|
||||
}
|
||||
|
||||
// validateContractType performs additional validation for detected contract types
|
||||
func (cd *ContractDetector) validateContractType(ctx context.Context, address common.Address, contractType ContractType) error {
|
||||
switch contractType {
|
||||
case ContractTypeERC20Token:
|
||||
return cd.validateERC20(ctx, address)
|
||||
case ContractTypeUniswapV2Pool:
|
||||
return cd.validateUniswapV2Pool(ctx, address)
|
||||
case ContractTypeUniswapV3Pool:
|
||||
return cd.validateUniswapV3Pool(ctx, address)
|
||||
default:
|
||||
return nil // No additional validation for other types
|
||||
}
|
||||
}
|
||||
|
||||
// validateERC20 validates that a contract is actually an ERC-20 token
|
||||
func (cd *ContractDetector) validateERC20(ctx context.Context, address common.Address) error {
|
||||
// Test decimals() - should return a reasonable value (0-18)
|
||||
decimalsData := []byte{0x31, 0x3c, 0xe5, 0x67} // decimals()
|
||||
result, err := cd.client.CallContract(ctx, ethereum.CallMsg{
|
||||
To: &address,
|
||||
Data: decimalsData,
|
||||
}, nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("decimals() call failed: %w", err)
|
||||
}
|
||||
|
||||
if len(result) == 32 {
|
||||
decimals := new(big.Int).SetBytes(result).Uint64()
|
||||
if decimals > 18 {
|
||||
return fmt.Errorf("unrealistic decimals value: %d", decimals)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// validateUniswapV2Pool validates that a contract is actually a Uniswap V2 pool
|
||||
func (cd *ContractDetector) validateUniswapV2Pool(ctx context.Context, address common.Address) error {
|
||||
// Test getReserves() - should return 3 values
|
||||
getReservesData := []byte{0x09, 0x02, 0xf1, 0xac} // getReserves()
|
||||
result, err := cd.client.CallContract(ctx, ethereum.CallMsg{
|
||||
To: &address,
|
||||
Data: getReservesData,
|
||||
}, nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("getReserves() call failed: %w", err)
|
||||
}
|
||||
|
||||
// Should return 3 uint112 values (reserves + timestamp)
|
||||
if len(result) != 96 { // 3 * 32 bytes
|
||||
return fmt.Errorf("unexpected getReserves() return length: %d", len(result))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// validateUniswapV3Pool validates that a contract is actually a Uniswap V3 pool
|
||||
func (cd *ContractDetector) validateUniswapV3Pool(ctx context.Context, address common.Address) error {
|
||||
// Test slot0() - should return current state
|
||||
slot0Data := []byte{0x38, 0x50, 0xc7, 0xbd} // slot0()
|
||||
result, err := cd.client.CallContract(ctx, ethereum.CallMsg{
|
||||
To: &address,
|
||||
Data: slot0Data,
|
||||
}, nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("slot0() call failed: %w", err)
|
||||
}
|
||||
|
||||
// Should return multiple values including sqrtPriceX96
|
||||
if len(result) < 32 {
|
||||
return fmt.Errorf("unexpected slot0() return length: %d", len(result))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// IsERC20Token checks if an address is an ERC-20 token
|
||||
func (cd *ContractDetector) IsERC20Token(ctx context.Context, address common.Address) bool {
|
||||
result := cd.DetectContractType(ctx, address)
|
||||
return result.ContractType == ContractTypeERC20Token && result.Confidence > 0.7
|
||||
}
|
||||
|
||||
// IsUniswapPool checks if an address is a Uniswap pool (V2 or V3)
|
||||
func (cd *ContractDetector) IsUniswapPool(ctx context.Context, address common.Address) bool {
|
||||
result := cd.DetectContractType(ctx, address)
|
||||
return (result.ContractType == ContractTypeUniswapV2Pool || result.ContractType == ContractTypeUniswapV3Pool) && result.Confidence > 0.7
|
||||
}
|
||||
|
||||
// IsRouter checks if an address is a router contract
|
||||
func (cd *ContractDetector) IsRouter(ctx context.Context, address common.Address) bool {
|
||||
result := cd.DetectContractType(ctx, address)
|
||||
return (result.ContractType == ContractTypeUniswapV2Router ||
|
||||
result.ContractType == ContractTypeUniswapV3Router ||
|
||||
result.ContractType == ContractTypeUniversalRouter) && result.Confidence > 0.7
|
||||
}
|
||||
|
||||
// ClearCache clears the detection cache
|
||||
func (cd *ContractDetector) ClearCache() {
|
||||
cd.cache = make(map[common.Address]*DetectionResult)
|
||||
}
|
||||
|
||||
// GetCacheSize returns the number of cached results
|
||||
func (cd *ContractDetector) GetCacheSize() int {
|
||||
return len(cd.cache)
|
||||
}
|
||||
239
orig/internal/contracts/signature_validator.go
Normal file
239
orig/internal/contracts/signature_validator.go
Normal file
@@ -0,0 +1,239 @@
|
||||
package contracts
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/ethereum/go-ethereum/common"
|
||||
|
||||
"github.com/fraktal/mev-beta/internal/logger"
|
||||
)
|
||||
|
||||
// FunctionSignature represents a known function signature
|
||||
type FunctionSignature struct {
|
||||
Name string
|
||||
Selector []byte
|
||||
AllowedTypes []ContractType
|
||||
}
|
||||
|
||||
// SignatureValidator validates function calls against contract types
|
||||
type SignatureValidator struct {
|
||||
detector *ContractDetector
|
||||
logger *logger.Logger
|
||||
signatures map[string]*FunctionSignature
|
||||
}
|
||||
|
||||
// NewSignatureValidator creates a new function signature validator
|
||||
func NewSignatureValidator(detector *ContractDetector, logger *logger.Logger) *SignatureValidator {
|
||||
sv := &SignatureValidator{
|
||||
detector: detector,
|
||||
logger: logger,
|
||||
signatures: make(map[string]*FunctionSignature),
|
||||
}
|
||||
|
||||
// Initialize known function signatures
|
||||
sv.initializeSignatures()
|
||||
return sv
|
||||
}
|
||||
|
||||
// initializeSignatures initializes the known function signatures and their allowed contract types
|
||||
func (sv *SignatureValidator) initializeSignatures() {
|
||||
// ERC-20 token functions
|
||||
sv.signatures["name()"] = &FunctionSignature{
|
||||
Name: "name()",
|
||||
Selector: []byte{0x06, 0xfd, 0xde, 0x03},
|
||||
AllowedTypes: []ContractType{ContractTypeERC20Token},
|
||||
}
|
||||
sv.signatures["symbol()"] = &FunctionSignature{
|
||||
Name: "symbol()",
|
||||
Selector: []byte{0x95, 0xd8, 0x9b, 0x41},
|
||||
AllowedTypes: []ContractType{ContractTypeERC20Token},
|
||||
}
|
||||
sv.signatures["decimals()"] = &FunctionSignature{
|
||||
Name: "decimals()",
|
||||
Selector: []byte{0x31, 0x3c, 0xe5, 0x67},
|
||||
AllowedTypes: []ContractType{ContractTypeERC20Token},
|
||||
}
|
||||
sv.signatures["totalSupply()"] = &FunctionSignature{
|
||||
Name: "totalSupply()",
|
||||
Selector: []byte{0x18, 0x16, 0x0d, 0xdd},
|
||||
AllowedTypes: []ContractType{ContractTypeERC20Token},
|
||||
}
|
||||
sv.signatures["balanceOf()"] = &FunctionSignature{
|
||||
Name: "balanceOf()",
|
||||
Selector: []byte{0x70, 0xa0, 0x82, 0x31},
|
||||
AllowedTypes: []ContractType{ContractTypeERC20Token},
|
||||
}
|
||||
|
||||
// Uniswap V2 Pool functions
|
||||
sv.signatures["token0()"] = &FunctionSignature{
|
||||
Name: "token0()",
|
||||
Selector: []byte{0x0d, 0xfe, 0x16, 0x81},
|
||||
AllowedTypes: []ContractType{
|
||||
ContractTypeUniswapV2Pool,
|
||||
ContractTypeUniswapV3Pool,
|
||||
},
|
||||
}
|
||||
sv.signatures["token1()"] = &FunctionSignature{
|
||||
Name: "token1()",
|
||||
Selector: []byte{0xd2, 0x12, 0x20, 0xa7},
|
||||
AllowedTypes: []ContractType{
|
||||
ContractTypeUniswapV2Pool,
|
||||
ContractTypeUniswapV3Pool,
|
||||
},
|
||||
}
|
||||
sv.signatures["getReserves()"] = &FunctionSignature{
|
||||
Name: "getReserves()",
|
||||
Selector: []byte{0x09, 0x02, 0xf1, 0xac},
|
||||
AllowedTypes: []ContractType{ContractTypeUniswapV2Pool},
|
||||
}
|
||||
|
||||
// Uniswap V3 Pool specific functions
|
||||
sv.signatures["slot0()"] = &FunctionSignature{
|
||||
Name: "slot0()",
|
||||
Selector: []byte{0x38, 0x50, 0xc7, 0xbd},
|
||||
AllowedTypes: []ContractType{ContractTypeUniswapV3Pool},
|
||||
}
|
||||
sv.signatures["fee()"] = &FunctionSignature{
|
||||
Name: "fee()",
|
||||
Selector: []byte{0xdd, 0xca, 0x3f, 0x43},
|
||||
AllowedTypes: []ContractType{ContractTypeUniswapV3Pool},
|
||||
}
|
||||
sv.signatures["liquidity()"] = &FunctionSignature{
|
||||
Name: "liquidity()",
|
||||
Selector: []byte{0x1a, 0x68, 0x65, 0x0f},
|
||||
AllowedTypes: []ContractType{ContractTypeUniswapV3Pool},
|
||||
}
|
||||
sv.signatures["tickSpacing()"] = &FunctionSignature{
|
||||
Name: "tickSpacing()",
|
||||
Selector: []byte{0xd0, 0xc9, 0x32, 0x07},
|
||||
AllowedTypes: []ContractType{ContractTypeUniswapV3Pool},
|
||||
}
|
||||
|
||||
// Router functions
|
||||
sv.signatures["WETH()"] = &FunctionSignature{
|
||||
Name: "WETH()",
|
||||
Selector: []byte{0xad, 0x5c, 0x46, 0x48},
|
||||
AllowedTypes: []ContractType{
|
||||
ContractTypeUniswapV2Router,
|
||||
ContractTypeUniswapV3Router,
|
||||
},
|
||||
}
|
||||
sv.signatures["factory()"] = &FunctionSignature{
|
||||
Name: "factory()",
|
||||
Selector: []byte{0xc4, 0x5a, 0x01, 0x55},
|
||||
AllowedTypes: []ContractType{
|
||||
ContractTypeUniswapV2Router,
|
||||
ContractTypeUniswapV3Router,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// ValidationResult contains the result of function signature validation
|
||||
type ValidationResult struct {
|
||||
IsValid bool
|
||||
FunctionName string
|
||||
ContractType ContractType
|
||||
Error error
|
||||
Warnings []string
|
||||
}
|
||||
|
||||
// ValidateFunctionCall validates if a function can be called on a contract
|
||||
func (sv *SignatureValidator) ValidateFunctionCall(ctx context.Context, contractAddress common.Address, functionSelector []byte) *ValidationResult {
|
||||
result := &ValidationResult{
|
||||
IsValid: false,
|
||||
Warnings: []string{},
|
||||
}
|
||||
|
||||
// Detect contract type
|
||||
detection := sv.detector.DetectContractType(ctx, contractAddress)
|
||||
result.ContractType = detection.ContractType
|
||||
|
||||
if detection.Error != nil {
|
||||
result.Error = fmt.Errorf("contract type detection failed: %w", detection.Error)
|
||||
return result
|
||||
}
|
||||
|
||||
// Find matching function signature
|
||||
var matchedSignature *FunctionSignature
|
||||
for _, sig := range sv.signatures {
|
||||
if len(sig.Selector) >= 4 && len(functionSelector) >= 4 {
|
||||
if sig.Selector[0] == functionSelector[0] &&
|
||||
sig.Selector[1] == functionSelector[1] &&
|
||||
sig.Selector[2] == functionSelector[2] &&
|
||||
sig.Selector[3] == functionSelector[3] {
|
||||
matchedSignature = sig
|
||||
result.FunctionName = sig.Name
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// If no signature match found, warn but allow (could be unknown function)
|
||||
if matchedSignature == nil {
|
||||
result.IsValid = true
|
||||
result.Warnings = append(result.Warnings, fmt.Sprintf("unknown function selector: %x", functionSelector))
|
||||
return result
|
||||
}
|
||||
|
||||
// Check if the detected contract type is allowed for this function
|
||||
allowed := false
|
||||
for _, allowedType := range matchedSignature.AllowedTypes {
|
||||
if detection.ContractType == allowedType {
|
||||
allowed = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !allowed {
|
||||
result.Error = fmt.Errorf("function %s cannot be called on contract type %s (allowed types: %v)",
|
||||
matchedSignature.Name, detection.ContractType.String(), matchedSignature.AllowedTypes)
|
||||
return result
|
||||
}
|
||||
|
||||
// Check confidence level
|
||||
if detection.Confidence < 0.7 {
|
||||
result.Warnings = append(result.Warnings, fmt.Sprintf("low confidence in contract type detection: %.2f", detection.Confidence))
|
||||
}
|
||||
|
||||
result.IsValid = true
|
||||
return result
|
||||
}
|
||||
|
||||
// ValidateToken0Call specifically validates token0() function calls
|
||||
func (sv *SignatureValidator) ValidateToken0Call(ctx context.Context, contractAddress common.Address) *ValidationResult {
|
||||
token0Selector := []byte{0x0d, 0xfe, 0x16, 0x81}
|
||||
return sv.ValidateFunctionCall(ctx, contractAddress, token0Selector)
|
||||
}
|
||||
|
||||
// ValidateToken1Call specifically validates token1() function calls
|
||||
func (sv *SignatureValidator) ValidateToken1Call(ctx context.Context, contractAddress common.Address) *ValidationResult {
|
||||
token1Selector := []byte{0xd2, 0x12, 0x20, 0xa7}
|
||||
return sv.ValidateFunctionCall(ctx, contractAddress, token1Selector)
|
||||
}
|
||||
|
||||
// ValidateGetReservesCall specifically validates getReserves() function calls
|
||||
func (sv *SignatureValidator) ValidateGetReservesCall(ctx context.Context, contractAddress common.Address) *ValidationResult {
|
||||
getReservesSelector := []byte{0x09, 0x02, 0xf1, 0xac}
|
||||
return sv.ValidateFunctionCall(ctx, contractAddress, getReservesSelector)
|
||||
}
|
||||
|
||||
// ValidateSlot0Call specifically validates slot0() function calls for Uniswap V3
|
||||
func (sv *SignatureValidator) ValidateSlot0Call(ctx context.Context, contractAddress common.Address) *ValidationResult {
|
||||
slot0Selector := []byte{0x38, 0x50, 0xc7, 0xbd}
|
||||
return sv.ValidateFunctionCall(ctx, contractAddress, slot0Selector)
|
||||
}
|
||||
|
||||
// GetSupportedFunctions returns the functions supported by a contract type
|
||||
func (sv *SignatureValidator) GetSupportedFunctions(contractType ContractType) []string {
|
||||
var functions []string
|
||||
for _, sig := range sv.signatures {
|
||||
for _, allowedType := range sig.AllowedTypes {
|
||||
if allowedType == contractType {
|
||||
functions = append(functions, sig.Name)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
return functions
|
||||
}
|
||||
484
orig/internal/logger/logger.go
Normal file
484
orig/internal/logger/logger.go
Normal file
@@ -0,0 +1,484 @@
|
||||
package logger
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
pkgerrors "github.com/fraktal/mev-beta/pkg/errors"
|
||||
)
|
||||
|
||||
// LogLevel represents different log levels
|
||||
type LogLevel int
|
||||
|
||||
const (
|
||||
DEBUG LogLevel = iota
|
||||
INFO
|
||||
WARN
|
||||
ERROR
|
||||
OPPORTUNITY // Special level for opportunities
|
||||
)
|
||||
|
||||
var logLevelNames = map[LogLevel]string{
|
||||
DEBUG: "DEBUG",
|
||||
INFO: "INFO",
|
||||
WARN: "WARN",
|
||||
ERROR: "ERROR",
|
||||
OPPORTUNITY: "OPPORTUNITY",
|
||||
}
|
||||
|
||||
var suppressedWarningSubstrings = []string{
|
||||
"extractTokensGeneric",
|
||||
"extractTokensFromMulticall",
|
||||
}
|
||||
|
||||
// Logger represents a multi-file logger with separation of concerns
|
||||
type Logger struct {
|
||||
// Main application logger
|
||||
logger *log.Logger
|
||||
level LogLevel
|
||||
|
||||
// Specialized loggers for different concerns
|
||||
opportunityLogger *log.Logger // MEV opportunities and arbitrage attempts
|
||||
errorLogger *log.Logger // Errors and warnings only
|
||||
performanceLogger *log.Logger // Performance metrics and RPC calls
|
||||
transactionLogger *log.Logger // Detailed transaction analysis
|
||||
|
||||
// Security filtering
|
||||
secureFilter *SecureFilter
|
||||
|
||||
levelName string
|
||||
}
|
||||
|
||||
// parseLogLevel converts string log level to LogLevel enum
|
||||
func parseLogLevel(level string) LogLevel {
|
||||
switch strings.ToLower(level) {
|
||||
case "debug":
|
||||
return DEBUG
|
||||
case "info":
|
||||
return INFO
|
||||
case "warn", "warning":
|
||||
return WARN
|
||||
case "error":
|
||||
return ERROR
|
||||
default:
|
||||
return INFO // Default to INFO level
|
||||
}
|
||||
}
|
||||
|
||||
// createLogFile creates a log file or returns stdout if it fails
|
||||
func createLogFile(filename string) *os.File {
|
||||
if filename == "" {
|
||||
return os.Stdout
|
||||
}
|
||||
|
||||
if err := os.MkdirAll(filepath.Dir(filename), 0o755); err != nil {
|
||||
log.Printf("Failed to create log directory for %s: %v, falling back to stdout", filename, err)
|
||||
return os.Stdout
|
||||
}
|
||||
|
||||
// Check and rotate log file if needed (100MB max size)
|
||||
maxSize := int64(100 * 1024 * 1024) // 100 MB
|
||||
if err := rotateLogFile(filename, maxSize); err != nil {
|
||||
log.Printf("Failed to rotate log file %s: %v", filename, err)
|
||||
// Continue anyway, rotation failure shouldn't stop logging
|
||||
}
|
||||
|
||||
f, err := os.OpenFile(filename, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0600)
|
||||
if err != nil {
|
||||
log.Printf("Failed to create log file %s: %v, falling back to stdout", filename, err)
|
||||
return os.Stdout
|
||||
}
|
||||
return f
|
||||
}
|
||||
|
||||
// New creates a new multi-file logger with separation of concerns
|
||||
func New(level string, format string, file string) *Logger {
|
||||
// Parse base filename for specialized logs
|
||||
baseDir := "logs"
|
||||
baseName := "mev_bot"
|
||||
if file != "" {
|
||||
// Extract directory and base filename
|
||||
parts := strings.Split(file, "/")
|
||||
if len(parts) > 1 {
|
||||
baseDir = strings.Join(parts[:len(parts)-1], "/")
|
||||
}
|
||||
filename := parts[len(parts)-1]
|
||||
if strings.Contains(filename, ".") {
|
||||
baseName = strings.Split(filename, ".")[0]
|
||||
}
|
||||
}
|
||||
|
||||
// Create specialized log files
|
||||
mainFile := createLogFile(file)
|
||||
opportunityFile := createLogFile(fmt.Sprintf("%s/%s_opportunities.log", baseDir, baseName))
|
||||
errorFile := createLogFile(fmt.Sprintf("%s/%s_errors.log", baseDir, baseName))
|
||||
performanceFile := createLogFile(fmt.Sprintf("%s/%s_performance.log", baseDir, baseName))
|
||||
transactionFile := createLogFile(fmt.Sprintf("%s/%s_transactions.log", baseDir, baseName))
|
||||
|
||||
// Create loggers with no prefixes (we format ourselves)
|
||||
logLevel := parseLogLevel(level)
|
||||
|
||||
// Determine security level based on environment and log level
|
||||
var securityLevel SecurityLevel
|
||||
env := os.Getenv("GO_ENV")
|
||||
switch {
|
||||
case env == "production":
|
||||
securityLevel = SecurityLevelProduction
|
||||
case logLevel >= WARN:
|
||||
securityLevel = SecurityLevelInfo
|
||||
default:
|
||||
securityLevel = SecurityLevelDebug
|
||||
}
|
||||
|
||||
return &Logger{
|
||||
logger: log.New(mainFile, "", 0),
|
||||
opportunityLogger: log.New(opportunityFile, "", 0),
|
||||
errorLogger: log.New(errorFile, "", 0),
|
||||
performanceLogger: log.New(performanceFile, "", 0),
|
||||
transactionLogger: log.New(transactionFile, "", 0),
|
||||
level: logLevel,
|
||||
secureFilter: NewSecureFilter(securityLevel),
|
||||
levelName: level,
|
||||
}
|
||||
}
|
||||
|
||||
// shouldLog determines if a message should be logged based on level
|
||||
func (l *Logger) shouldLog(level LogLevel) bool {
|
||||
return level >= l.level
|
||||
}
|
||||
|
||||
// formatMessage formats a log message with timestamp and level
|
||||
func (l *Logger) formatMessage(level LogLevel, v ...interface{}) string {
|
||||
timestamp := time.Now().Format("2006/01/02 15:04:05")
|
||||
levelName := logLevelNames[level]
|
||||
message := formatKVMessage(v...)
|
||||
return fmt.Sprintf("%s [%s] %s", timestamp, levelName, message)
|
||||
}
|
||||
|
||||
// formatKVMessage converts a variadic list of arguments into a structured log string.
|
||||
// It treats consecutive key/value pairs (string key followed by any value) specially
|
||||
// so that existing logger calls like logger.Error("msg", "key", value) render as
|
||||
// `msg key=value`.
|
||||
func formatKVMessage(args ...interface{}) string {
|
||||
if len(args) == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
var b strings.Builder
|
||||
|
||||
// Always print the first argument verbatim to preserve legacy formatting.
|
||||
fmt.Fprintf(&b, "%v", args[0])
|
||||
|
||||
// Process subsequent arguments as key/value pairs where possible.
|
||||
for i := 1; i < len(args); i++ {
|
||||
key, ok := args[i].(string)
|
||||
if !ok || i == len(args)-1 {
|
||||
// Not a key/value pair, fall back to simple spacing.
|
||||
fmt.Fprintf(&b, " %v", args[i])
|
||||
continue
|
||||
}
|
||||
|
||||
value := args[i+1]
|
||||
fmt.Fprintf(&b, " %s=%v", key, value)
|
||||
i++
|
||||
}
|
||||
|
||||
return b.String()
|
||||
}
|
||||
|
||||
// Debug logs a debug message
|
||||
func (l *Logger) Debug(v ...interface{}) {
|
||||
if l.shouldLog(DEBUG) {
|
||||
l.logger.Println(l.formatMessage(DEBUG, v...))
|
||||
}
|
||||
}
|
||||
|
||||
// Info logs an info message
|
||||
func (l *Logger) Info(v ...interface{}) {
|
||||
if l.shouldLog(INFO) {
|
||||
l.logger.Println(l.formatMessage(INFO, v...))
|
||||
}
|
||||
}
|
||||
|
||||
// Warn logs a warning message
|
||||
func (l *Logger) Warn(v ...interface{}) {
|
||||
if l.shouldLog(WARN) {
|
||||
message := l.formatMessage(WARN, v...)
|
||||
for _, substr := range suppressedWarningSubstrings {
|
||||
if strings.Contains(message, substr) {
|
||||
return
|
||||
}
|
||||
}
|
||||
l.logger.Println(message)
|
||||
l.errorLogger.Println(message) // Also log to error file
|
||||
}
|
||||
}
|
||||
|
||||
// Error logs an error message
|
||||
func (l *Logger) Error(v ...interface{}) {
|
||||
if l.shouldLog(ERROR) {
|
||||
message := l.formatMessage(ERROR, v...)
|
||||
l.logger.Println(message)
|
||||
l.errorLogger.Println(message) // Also log to error file
|
||||
}
|
||||
}
|
||||
|
||||
// ErrorStructured logs a structured error with full context
|
||||
func (l *Logger) ErrorStructured(err *pkgerrors.StructuredError) {
|
||||
if !l.shouldLog(ERROR) {
|
||||
return
|
||||
}
|
||||
|
||||
// Log compact format to main log
|
||||
compactMsg := fmt.Sprintf("%s [%s] %s",
|
||||
time.Now().Format("2006/01/02 15:04:05"),
|
||||
"ERROR",
|
||||
err.FormatCompact())
|
||||
l.logger.Println(compactMsg)
|
||||
|
||||
// Log full detailed format to error log
|
||||
fullMsg := fmt.Sprintf("%s [%s] %s",
|
||||
time.Now().Format("2006/01/02 15:04:05"),
|
||||
"ERROR",
|
||||
err.FormatForLogging())
|
||||
l.errorLogger.Println(fullMsg)
|
||||
}
|
||||
|
||||
// WarnStructured logs a structured warning with full context
|
||||
func (l *Logger) WarnStructured(err *pkgerrors.StructuredError) {
|
||||
if !l.shouldLog(WARN) {
|
||||
return
|
||||
}
|
||||
|
||||
// Log compact format to main log
|
||||
compactMsg := fmt.Sprintf("%s [%s] %s",
|
||||
time.Now().Format("2006/01/02 15:04:05"),
|
||||
"WARN",
|
||||
err.FormatCompact())
|
||||
|
||||
// Check if warning should be suppressed
|
||||
for _, substr := range suppressedWarningSubstrings {
|
||||
if strings.Contains(compactMsg, substr) {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
l.logger.Println(compactMsg)
|
||||
l.errorLogger.Println(compactMsg)
|
||||
}
|
||||
|
||||
// Opportunity logs a found opportunity with detailed information
|
||||
// This always logs regardless of level since opportunities are critical
|
||||
func (l *Logger) Opportunity(txHash, from, to, method, protocol string, amountIn, amountOut, minOut, profitUSD float64, additionalData map[string]interface{}) {
|
||||
timestamp := time.Now().Format("2006/01/02 15:04:05")
|
||||
|
||||
// Create sanitized additional data for production
|
||||
sanitizedData := l.secureFilter.SanitizeForProduction(additionalData)
|
||||
|
||||
message := fmt.Sprintf(`%s [OPPORTUNITY] 🎯 ARBITRAGE OPPORTUNITY DETECTED
|
||||
├── Transaction: %s
|
||||
├── From: %s → To: %s
|
||||
├── Method: %s (%s)
|
||||
├── Amount In: %.6f tokens
|
||||
├── Amount Out: %.6f tokens
|
||||
├── Min Out: %.6f tokens
|
||||
├── Estimated Profit: $%.2f USD
|
||||
└── Additional Data: %v`,
|
||||
timestamp, txHash, from, to, method, protocol,
|
||||
amountIn, amountOut, minOut, profitUSD, sanitizedData)
|
||||
|
||||
// Apply security filtering to the entire message
|
||||
filteredMessage := l.secureFilter.FilterMessage(message)
|
||||
|
||||
l.logger.Println(filteredMessage)
|
||||
l.opportunityLogger.Println(filteredMessage) // Dedicated opportunity log
|
||||
}
|
||||
|
||||
// OpportunitySimple logs a simple opportunity message (for backwards compatibility)
|
||||
func (l *Logger) OpportunitySimple(v ...interface{}) {
|
||||
timestamp := time.Now().Format("2006/01/02 15:04:05")
|
||||
message := fmt.Sprintf("%s [OPPORTUNITY] %s", timestamp, fmt.Sprint(v...))
|
||||
l.logger.Println(message)
|
||||
l.opportunityLogger.Println(message) // Dedicated opportunity log
|
||||
}
|
||||
|
||||
// Performance logs performance metrics for optimization analysis
|
||||
func (l *Logger) Performance(component, operation string, duration time.Duration, metadata map[string]interface{}) {
|
||||
timestamp := time.Now().Format("2006/01/02 15:04:05")
|
||||
|
||||
// Add standard performance fields
|
||||
data := map[string]interface{}{
|
||||
"component": component,
|
||||
"operation": operation,
|
||||
"duration_ms": duration.Milliseconds(),
|
||||
"duration_ns": duration.Nanoseconds(),
|
||||
"timestamp": timestamp,
|
||||
}
|
||||
|
||||
// Merge with provided metadata
|
||||
for k, v := range metadata {
|
||||
data[k] = v
|
||||
}
|
||||
|
||||
message := fmt.Sprintf(`%s [PERFORMANCE] 📊 %s.%s completed in %v - %v`,
|
||||
timestamp, component, operation, duration, data)
|
||||
|
||||
l.performanceLogger.Println(message) // Dedicated performance log only
|
||||
}
|
||||
|
||||
// Metrics logs business metrics for analysis
|
||||
func (l *Logger) Metrics(name string, value float64, unit string, tags map[string]string) {
|
||||
timestamp := time.Now().Format("2006/01/02 15:04:05")
|
||||
|
||||
message := fmt.Sprintf(`%s [METRICS] 📈 %s: %.6f %s %v`,
|
||||
timestamp, name, value, unit, tags)
|
||||
|
||||
l.performanceLogger.Println(message) // Metrics go to performance log
|
||||
}
|
||||
|
||||
// Transaction logs detailed transaction information for MEV analysis
|
||||
func (l *Logger) Transaction(txHash, from, to, method, protocol string, gasUsed, gasPrice uint64, value float64, success bool, metadata map[string]interface{}) {
|
||||
timestamp := time.Now().Format("2006/01/02 15:04:05")
|
||||
|
||||
status := "FAILED"
|
||||
if success {
|
||||
status = "SUCCESS"
|
||||
}
|
||||
|
||||
// Sanitize metadata for production
|
||||
sanitizedMetadata := l.secureFilter.SanitizeForProduction(metadata)
|
||||
|
||||
message := fmt.Sprintf(`%s [TRANSACTION] 💳 %s
|
||||
├── Hash: %s
|
||||
├── From: %s → To: %s
|
||||
├── Method: %s (%s)
|
||||
├── Gas Used: %d (Price: %d wei)
|
||||
├── Value: %.6f ETH
|
||||
├── Status: %s
|
||||
└── Metadata: %v`,
|
||||
timestamp, status, txHash, from, to, method, protocol,
|
||||
gasUsed, gasPrice, value, status, sanitizedMetadata)
|
||||
|
||||
// Apply security filtering to the entire message
|
||||
filteredMessage := l.secureFilter.FilterMessage(message)
|
||||
|
||||
l.transactionLogger.Println(filteredMessage) // Dedicated transaction log only
|
||||
}
|
||||
|
||||
// BlockProcessing logs block processing metrics for sequencer monitoring
|
||||
func (l *Logger) BlockProcessing(blockNumber uint64, txCount, dexTxCount int, processingTime time.Duration) {
|
||||
timestamp := time.Now().Format("2006/01/02 15:04:05")
|
||||
|
||||
message := fmt.Sprintf(`%s [BLOCK_PROCESSING] 🧱 Block %d: %d txs (%d DEX) processed in %v`,
|
||||
timestamp, blockNumber, txCount, dexTxCount, processingTime)
|
||||
|
||||
l.performanceLogger.Println(message) // Block processing metrics go to performance log
|
||||
}
|
||||
|
||||
// ArbitrageAnalysis logs arbitrage opportunity analysis results
|
||||
func (l *Logger) ArbitrageAnalysis(poolA, poolB, tokenPair string, priceA, priceB, priceDiff, estimatedProfit float64, feasible bool) {
|
||||
timestamp := time.Now().Format("2006/01/02 15:04:05")
|
||||
|
||||
status := "REJECTED"
|
||||
if feasible {
|
||||
status = "VIABLE"
|
||||
}
|
||||
|
||||
message := fmt.Sprintf(`%s [ARBITRAGE_ANALYSIS] 🔍 %s %s
|
||||
├── Pool A: %s (Price: %.6f)
|
||||
├── Pool B: %s (Price: %.6f)
|
||||
├── Price Difference: %.4f%%
|
||||
├── Estimated Profit: $%.2f
|
||||
└── Status: %s`,
|
||||
timestamp, status, tokenPair, poolA, priceA, poolB, priceB,
|
||||
priceDiff*100, estimatedProfit, status)
|
||||
|
||||
// Apply security filtering to protect sensitive pricing data
|
||||
filteredMessage := l.secureFilter.FilterMessage(message)
|
||||
|
||||
l.opportunityLogger.Println(filteredMessage) // Arbitrage analysis goes to opportunity log
|
||||
}
|
||||
|
||||
// RPC logs RPC call metrics for endpoint optimization
|
||||
func (l *Logger) RPC(endpoint, method string, duration time.Duration, success bool, errorMsg string) {
|
||||
timestamp := time.Now().Format("2006/01/02 15:04:05")
|
||||
|
||||
status := "SUCCESS"
|
||||
if !success {
|
||||
status = "FAILED"
|
||||
}
|
||||
|
||||
message := fmt.Sprintf(`%s [RPC] 🌐 %s %s.%s in %v`,
|
||||
timestamp, status, endpoint, method, duration)
|
||||
|
||||
if !success && errorMsg != "" {
|
||||
message += fmt.Sprintf(" - Error: %s", errorMsg)
|
||||
}
|
||||
|
||||
l.performanceLogger.Println(message) // RPC metrics go to performance log
|
||||
}
|
||||
|
||||
// SwapAnalysis logs swap event analysis with security filtering
|
||||
func (l *Logger) SwapAnalysis(tokenIn, tokenOut string, amountIn, amountOut float64, protocol, poolAddr string, metadata map[string]interface{}) {
|
||||
timestamp := time.Now().Format("2006/01/02 15:04:05")
|
||||
|
||||
// Sanitize metadata for production
|
||||
sanitizedMetadata := l.secureFilter.SanitizeForProduction(metadata)
|
||||
|
||||
message := fmt.Sprintf(`%s [SWAP_ANALYSIS] 🔄 %s → %s
|
||||
├── Amount In: %.6f %s
|
||||
├── Amount Out: %.6f %s
|
||||
├── Protocol: %s
|
||||
├── Pool: %s
|
||||
└── Metadata: %v`,
|
||||
timestamp, tokenIn, tokenOut, amountIn, tokenIn, amountOut, tokenOut,
|
||||
protocol, poolAddr, sanitizedMetadata)
|
||||
|
||||
// Apply security filtering to the entire message
|
||||
filteredMessage := l.secureFilter.FilterMessage(message)
|
||||
|
||||
l.transactionLogger.Println(filteredMessage) // Dedicated transaction log
|
||||
}
|
||||
|
||||
// rotateLogFile rotates a log file when it exceeds the maximum size
|
||||
func rotateLogFile(filename string, maxSize int64) error {
|
||||
// Check if file exists
|
||||
if _, err := os.Stat(filename); os.IsNotExist(err) {
|
||||
return nil // File doesn't exist, nothing to rotate
|
||||
}
|
||||
|
||||
// Get file info
|
||||
fileInfo, err := os.Stat(filename)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get file info: %w", err)
|
||||
}
|
||||
|
||||
// Check if file exceeds max size
|
||||
if fileInfo.Size() < maxSize {
|
||||
return nil // File is within size limits
|
||||
}
|
||||
|
||||
// Create archive directory if it doesn't exist
|
||||
archiveDir := "logs/archived"
|
||||
if err := os.MkdirAll(archiveDir, 0755); err != nil {
|
||||
return fmt.Errorf("failed to create archive directory: %w", err)
|
||||
}
|
||||
|
||||
// Generate archive filename with timestamp
|
||||
timestamp := time.Now().Format("20060102_150405")
|
||||
baseName := filepath.Base(filename)
|
||||
ext := filepath.Ext(baseName)
|
||||
name := strings.TrimSuffix(baseName, ext)
|
||||
archiveFilename := filepath.Join(archiveDir, fmt.Sprintf("%s_%s%s", name, timestamp, ext))
|
||||
|
||||
// Close current file handle and rename
|
||||
if err := os.Rename(filename, archiveFilename); err != nil {
|
||||
return fmt.Errorf("failed to rotate log file: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
243
orig/internal/logger/logger_test.go
Normal file
243
orig/internal/logger/logger_test.go
Normal file
@@ -0,0 +1,243 @@
|
||||
package logger
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestNewLogger(t *testing.T) {
|
||||
// Test creating a logger with stdout
|
||||
logger := New("info", "text", "")
|
||||
assert.NotNil(t, logger)
|
||||
assert.NotNil(t, logger.logger)
|
||||
assert.Equal(t, "info", logger.levelName)
|
||||
}
|
||||
|
||||
func TestNewLoggerWithFile(t *testing.T) {
|
||||
// Create a temporary file for testing
|
||||
tmpFile, err := os.CreateTemp("", "logger_test_*.log")
|
||||
assert.NoError(t, err)
|
||||
defer os.Remove(tmpFile.Name())
|
||||
err = tmpFile.Close()
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Test creating a logger with a file
|
||||
logger := New("info", "text", tmpFile.Name())
|
||||
assert.NotNil(t, logger)
|
||||
assert.Equal(t, "info", logger.levelName)
|
||||
}
|
||||
|
||||
func TestDebug(t *testing.T) {
|
||||
// Capture stdout
|
||||
old := os.Stdout
|
||||
r, w, _ := os.Pipe()
|
||||
os.Stdout = w
|
||||
|
||||
// Create logger with debug level
|
||||
logger := New("debug", "text", "")
|
||||
|
||||
// Log a debug message
|
||||
logger.Debug("test debug message")
|
||||
|
||||
// Restore stdout
|
||||
w.Close()
|
||||
os.Stdout = old
|
||||
|
||||
// Read the captured output
|
||||
var buf bytes.Buffer
|
||||
io.Copy(&buf, r)
|
||||
output := buf.String()
|
||||
|
||||
// Check that the log message contains the expected content with brackets
|
||||
assert.Contains(t, output, "[DEBUG] test debug message")
|
||||
}
|
||||
|
||||
func TestDebugWithInfoLevel(t *testing.T) {
|
||||
// Capture stdout
|
||||
old := os.Stdout
|
||||
r, w, _ := os.Pipe()
|
||||
os.Stdout = w
|
||||
|
||||
// Create logger with info level (should not log debug messages)
|
||||
logger := New("info", "text", "")
|
||||
|
||||
// Log a debug message
|
||||
logger.Debug("test debug message")
|
||||
|
||||
// Restore stdout
|
||||
w.Close()
|
||||
os.Stdout = old
|
||||
|
||||
// Read the captured output
|
||||
var buf bytes.Buffer
|
||||
io.Copy(&buf, r)
|
||||
output := buf.String()
|
||||
|
||||
// Verify the output does not contain the debug message
|
||||
assert.NotContains(t, output, "DEBUG:")
|
||||
assert.NotContains(t, output, "test debug message")
|
||||
}
|
||||
|
||||
func TestInfo(t *testing.T) {
|
||||
// Capture stdout
|
||||
old := os.Stdout
|
||||
r, w, _ := os.Pipe()
|
||||
os.Stdout = w
|
||||
|
||||
// Create logger with info level
|
||||
logger := New("info", "text", "")
|
||||
|
||||
// Log an info message
|
||||
logger.Info("test info message")
|
||||
|
||||
// Restore stdout
|
||||
w.Close()
|
||||
os.Stdout = old
|
||||
|
||||
// Read the captured output
|
||||
var buf bytes.Buffer
|
||||
io.Copy(&buf, r)
|
||||
output := buf.String()
|
||||
|
||||
// Verify the output contains the info message
|
||||
assert.Contains(t, output, "[INFO]")
|
||||
assert.Contains(t, output, "test info message")
|
||||
}
|
||||
|
||||
func TestInfoWithDebugLevel(t *testing.T) {
|
||||
// Capture stdout
|
||||
old := os.Stdout
|
||||
r, w, _ := os.Pipe()
|
||||
os.Stdout = w
|
||||
|
||||
// Create logger with debug level
|
||||
logger := New("debug", "text", "")
|
||||
|
||||
// Log an info message
|
||||
logger.Info("test info message")
|
||||
|
||||
// Restore stdout
|
||||
w.Close()
|
||||
os.Stdout = old
|
||||
|
||||
// Read the captured output
|
||||
var buf bytes.Buffer
|
||||
io.Copy(&buf, r)
|
||||
output := buf.String()
|
||||
|
||||
// Verify the output contains the info message
|
||||
assert.Contains(t, output, "[INFO]")
|
||||
assert.Contains(t, output, "test info message")
|
||||
}
|
||||
|
||||
func TestWarn(t *testing.T) {
|
||||
// Capture stdout
|
||||
old := os.Stdout
|
||||
r, w, _ := os.Pipe()
|
||||
os.Stdout = w
|
||||
|
||||
// Create logger with warn level
|
||||
logger := New("warn", "text", "")
|
||||
|
||||
// Log a warning message
|
||||
logger.Warn("test warn message")
|
||||
|
||||
// Restore stdout
|
||||
w.Close()
|
||||
os.Stdout = old
|
||||
|
||||
// Read the captured output
|
||||
var buf bytes.Buffer
|
||||
io.Copy(&buf, r)
|
||||
output := buf.String()
|
||||
|
||||
// Verify the output contains the warning message
|
||||
assert.Contains(t, output, "[WARN]")
|
||||
assert.Contains(t, output, "test warn message")
|
||||
}
|
||||
|
||||
func TestWarnWithInfoLevel(t *testing.T) {
|
||||
// Capture stdout
|
||||
old := os.Stdout
|
||||
r, w, _ := os.Pipe()
|
||||
os.Stdout = w
|
||||
|
||||
// Create logger with info level (should log warnings)
|
||||
logger := New("info", "text", "")
|
||||
|
||||
// Log a warning message
|
||||
logger.Warn("test warn message")
|
||||
|
||||
// Restore stdout
|
||||
w.Close()
|
||||
os.Stdout = old
|
||||
|
||||
// Read the captured output
|
||||
var buf bytes.Buffer
|
||||
io.Copy(&buf, r)
|
||||
output := buf.String()
|
||||
|
||||
// Verify the output contains the warning message
|
||||
assert.Contains(t, output, "[WARN]")
|
||||
assert.Contains(t, output, "test warn message")
|
||||
}
|
||||
|
||||
func TestError(t *testing.T) {
|
||||
// Capture stdout
|
||||
old := os.Stdout
|
||||
r, w, _ := os.Pipe()
|
||||
os.Stdout = w
|
||||
|
||||
// Create logger
|
||||
logger := New("error", "text", "")
|
||||
|
||||
// Log an error message
|
||||
logger.Error("test error message")
|
||||
|
||||
// Restore stdout
|
||||
w.Close()
|
||||
os.Stdout = old
|
||||
|
||||
// Read the captured output
|
||||
var buf bytes.Buffer
|
||||
io.Copy(&buf, r)
|
||||
output := buf.String()
|
||||
|
||||
// Verify the output contains the error message
|
||||
assert.Contains(t, output, "[ERROR]")
|
||||
assert.Contains(t, output, "test error message")
|
||||
}
|
||||
|
||||
func TestErrorWithAllLevels(t *testing.T) {
|
||||
// Test that error messages are logged at all levels
|
||||
levels := []string{"debug", "info", "warn", "error"}
|
||||
for _, level := range levels {
|
||||
// Capture stdout
|
||||
old := os.Stdout
|
||||
r, w, _ := os.Pipe()
|
||||
os.Stdout = w
|
||||
|
||||
// Create logger with current level
|
||||
logger := New(level, "text", "")
|
||||
|
||||
// Log an error message
|
||||
logger.Error("test error message")
|
||||
|
||||
// Restore stdout
|
||||
w.Close()
|
||||
os.Stdout = old
|
||||
|
||||
// Read the captured output
|
||||
var buf bytes.Buffer
|
||||
io.Copy(&buf, r)
|
||||
output := buf.String()
|
||||
|
||||
// Verify the output contains the error message
|
||||
assert.Contains(t, output, "[ERROR]")
|
||||
assert.Contains(t, output, "test error message")
|
||||
}
|
||||
}
|
||||
241
orig/internal/logger/secure_audit.go
Normal file
241
orig/internal/logger/secure_audit.go
Normal file
@@ -0,0 +1,241 @@
|
||||
package logger
|
||||
|
||||
import (
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"time"
|
||||
)
|
||||
|
||||
// FilterMessageEnhanced provides comprehensive filtering with audit logging
|
||||
func (sf *SecureFilter) FilterMessageEnhanced(message string, context map[string]interface{}) string {
|
||||
originalMessage := message
|
||||
filtered := sf.FilterMessage(message)
|
||||
|
||||
// Audit sensitive data detection if enabled
|
||||
if sf.auditEnabled {
|
||||
auditData := sf.detectSensitiveData(originalMessage, context)
|
||||
if len(auditData) > 0 {
|
||||
sf.logAuditEvent(auditData)
|
||||
}
|
||||
}
|
||||
|
||||
return filtered
|
||||
}
|
||||
|
||||
// detectSensitiveData identifies and catalogs sensitive data found in messages
|
||||
func (sf *SecureFilter) detectSensitiveData(message string, context map[string]interface{}) map[string]interface{} {
|
||||
detected := make(map[string]interface{})
|
||||
detected["timestamp"] = time.Now().UTC().Format(time.RFC3339)
|
||||
detected["security_level"] = sf.level
|
||||
|
||||
if context != nil {
|
||||
detected["context"] = context
|
||||
}
|
||||
|
||||
// Check for different types of sensitive data
|
||||
sensitiveTypes := []string{}
|
||||
|
||||
// Check for private keys (CRITICAL)
|
||||
for _, pattern := range sf.privateKeyPatterns {
|
||||
if pattern.MatchString(message) {
|
||||
sensitiveTypes = append(sensitiveTypes, "private_key")
|
||||
detected["severity"] = "CRITICAL"
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// Check for transaction hashes BEFORE addresses (64 chars vs 40 chars)
|
||||
for _, pattern := range sf.hashPatterns {
|
||||
if pattern.MatchString(message) {
|
||||
sensitiveTypes = append(sensitiveTypes, "transaction_hash")
|
||||
if detected["severity"] == nil {
|
||||
detected["severity"] = "LOW"
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// Check for addresses AFTER hashes
|
||||
for _, pattern := range sf.addressPatterns {
|
||||
if pattern.MatchString(message) {
|
||||
sensitiveTypes = append(sensitiveTypes, "address")
|
||||
if detected["severity"] == nil {
|
||||
detected["severity"] = "MEDIUM"
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// Check for amounts/values
|
||||
for _, pattern := range sf.amountPatterns {
|
||||
if pattern.MatchString(message) {
|
||||
sensitiveTypes = append(sensitiveTypes, "amount")
|
||||
if detected["severity"] == nil {
|
||||
detected["severity"] = "LOW"
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if len(sensitiveTypes) > 0 {
|
||||
detected["types"] = sensitiveTypes
|
||||
detected["message_length"] = len(message)
|
||||
detected["filtered_length"] = len(sf.FilterMessage(message))
|
||||
return detected
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// logAuditEvent logs sensitive data detection events
|
||||
func (sf *SecureFilter) logAuditEvent(auditData map[string]interface{}) {
|
||||
// Create audit log entry
|
||||
auditEntry := map[string]interface{}{
|
||||
"event_type": "sensitive_data_detected",
|
||||
"timestamp": auditData["timestamp"],
|
||||
"severity": auditData["severity"],
|
||||
"types": auditData["types"],
|
||||
"message_stats": map[string]interface{}{
|
||||
"original_length": auditData["message_length"],
|
||||
"filtered_length": auditData["filtered_length"],
|
||||
},
|
||||
}
|
||||
|
||||
if auditData["context"] != nil {
|
||||
auditEntry["context"] = auditData["context"]
|
||||
}
|
||||
|
||||
// Encrypt audit data if encryption is enabled
|
||||
if sf.auditEncryption && len(sf.encryptionKey) > 0 {
|
||||
encrypted, err := sf.encryptAuditData(auditEntry)
|
||||
if err == nil {
|
||||
auditEntry = map[string]interface{}{
|
||||
"encrypted": true,
|
||||
"data": encrypted,
|
||||
"timestamp": auditData["timestamp"],
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Log to audit trail (this would typically go to a separate audit log file)
|
||||
// For now, we'll add it to a structured format that can be easily extracted
|
||||
auditJSON, _ := json.Marshal(auditEntry)
|
||||
fmt.Printf("AUDIT_LOG: %s\n", string(auditJSON))
|
||||
}
|
||||
|
||||
// encryptAuditData encrypts sensitive audit data
|
||||
func (sf *SecureFilter) encryptAuditData(data map[string]interface{}) (string, error) {
|
||||
if len(sf.encryptionKey) == 0 {
|
||||
return "", fmt.Errorf("no encryption key available")
|
||||
}
|
||||
|
||||
// Serialize data to JSON
|
||||
jsonData, err := json.Marshal(data)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to marshal audit data: %w", err)
|
||||
}
|
||||
|
||||
// Create AES-GCM cipher (AEAD - authenticated encryption)
|
||||
key := sha256.Sum256(sf.encryptionKey)
|
||||
block, err := aes.NewCipher(key[:])
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to create cipher: %w", err)
|
||||
}
|
||||
|
||||
// Create GCM instance
|
||||
gcm, err := cipher.NewGCM(block)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to create GCM: %w", err)
|
||||
}
|
||||
|
||||
// Generate random nonce
|
||||
nonce := make([]byte, gcm.NonceSize())
|
||||
if _, err := io.ReadFull(rand.Reader, nonce); err != nil {
|
||||
return "", fmt.Errorf("failed to generate nonce: %w", err)
|
||||
}
|
||||
|
||||
// Encrypt and authenticate data
|
||||
encrypted := gcm.Seal(nonce, nonce, jsonData, nil)
|
||||
return hex.EncodeToString(encrypted), nil
|
||||
}
|
||||
|
||||
// DecryptAuditData decrypts audit data (for authorized access)
|
||||
func (sf *SecureFilter) DecryptAuditData(encryptedHex string) (map[string]interface{}, error) {
|
||||
if len(sf.encryptionKey) == 0 {
|
||||
return nil, fmt.Errorf("no encryption key available")
|
||||
}
|
||||
|
||||
// Decode hex
|
||||
encryptedData, err := hex.DecodeString(encryptedHex)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to decode hex: %w", err)
|
||||
}
|
||||
|
||||
// Create AES-GCM cipher (AEAD - authenticated encryption)
|
||||
key := sha256.Sum256(sf.encryptionKey)
|
||||
block, err := aes.NewCipher(key[:])
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create cipher: %w", err)
|
||||
}
|
||||
|
||||
// Create GCM instance
|
||||
gcm, err := cipher.NewGCM(block)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create GCM: %w", err)
|
||||
}
|
||||
|
||||
// Check minimum length (nonce + encrypted data + tag)
|
||||
if len(encryptedData) < gcm.NonceSize() {
|
||||
return nil, fmt.Errorf("encrypted data too short")
|
||||
}
|
||||
|
||||
// Extract nonce and encrypted data
|
||||
nonce := encryptedData[:gcm.NonceSize()]
|
||||
ciphertext := encryptedData[gcm.NonceSize():]
|
||||
|
||||
// Decrypt and authenticate data
|
||||
decrypted, err := gcm.Open(nil, nonce, ciphertext, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to decrypt data: %w", err)
|
||||
}
|
||||
|
||||
// Unmarshal JSON
|
||||
var result map[string]interface{}
|
||||
err = json.Unmarshal(decrypted, &result)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to unmarshal decrypted data: %w", err)
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// EnableAuditLogging enables audit logging with optional encryption
|
||||
func (sf *SecureFilter) EnableAuditLogging(encryptionKey []byte) {
|
||||
sf.auditEnabled = true
|
||||
if len(encryptionKey) > 0 {
|
||||
sf.encryptionKey = encryptionKey
|
||||
sf.auditEncryption = true
|
||||
}
|
||||
}
|
||||
|
||||
// DisableAuditLogging disables audit logging
|
||||
func (sf *SecureFilter) DisableAuditLogging() {
|
||||
sf.auditEnabled = false
|
||||
sf.auditEncryption = false
|
||||
}
|
||||
|
||||
// SetSecurityLevel changes the security level dynamically
|
||||
func (sf *SecureFilter) SetSecurityLevel(level SecurityLevel) {
|
||||
sf.level = level
|
||||
}
|
||||
|
||||
// GetSecurityLevel returns the current security level
|
||||
func (sf *SecureFilter) GetSecurityLevel() SecurityLevel {
|
||||
return sf.level
|
||||
}
|
||||
301
orig/internal/logger/secure_filter.go
Normal file
301
orig/internal/logger/secure_filter.go
Normal file
@@ -0,0 +1,301 @@
|
||||
package logger
|
||||
|
||||
import (
|
||||
"math/big"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"github.com/ethereum/go-ethereum/common"
|
||||
)
|
||||
|
||||
// SecurityLevel defines the logging security level
|
||||
type SecurityLevel int
|
||||
|
||||
const (
|
||||
SecurityLevelDebug SecurityLevel = iota // Log everything (development only)
|
||||
SecurityLevelInfo // Log basic info, filter amounts
|
||||
SecurityLevelProduction // Log minimal info, filter sensitive data
|
||||
)
|
||||
|
||||
// SecureFilter filters sensitive data from log messages
|
||||
type SecureFilter struct {
|
||||
level SecurityLevel
|
||||
|
||||
// Patterns to detect sensitive data
|
||||
amountPatterns []*regexp.Regexp
|
||||
addressPatterns []*regexp.Regexp
|
||||
valuePatterns []*regexp.Regexp
|
||||
hashPatterns []*regexp.Regexp
|
||||
privateKeyPatterns []*regexp.Regexp
|
||||
encryptionKey []byte
|
||||
auditEnabled bool
|
||||
auditEncryption bool
|
||||
}
|
||||
|
||||
// SecureFilterConfig contains configuration for the secure filter
|
||||
type SecureFilterConfig struct {
|
||||
Level SecurityLevel
|
||||
EncryptionKey []byte
|
||||
AuditEnabled bool
|
||||
AuditEncryption bool
|
||||
}
|
||||
|
||||
// NewSecureFilter creates a new secure filter with enhanced configuration
|
||||
func NewSecureFilter(level SecurityLevel) *SecureFilter {
|
||||
return NewSecureFilterWithConfig(&SecureFilterConfig{
|
||||
Level: level,
|
||||
AuditEnabled: false,
|
||||
AuditEncryption: false,
|
||||
})
|
||||
}
|
||||
|
||||
// NewSecureFilterWithConfig creates a new secure filter with full configuration
|
||||
func NewSecureFilterWithConfig(config *SecureFilterConfig) *SecureFilter {
|
||||
return &SecureFilter{
|
||||
level: config.Level,
|
||||
encryptionKey: config.EncryptionKey,
|
||||
auditEnabled: config.AuditEnabled,
|
||||
auditEncryption: config.AuditEncryption,
|
||||
amountPatterns: []*regexp.Regexp{
|
||||
regexp.MustCompile(`(?i)amount[^=]*=\s*[0-9]+`),
|
||||
regexp.MustCompile(`(?i)value[^=]*=\s*[0-9]+`),
|
||||
regexp.MustCompile(`\$[0-9]+\.?[0-9]*`),
|
||||
regexp.MustCompile(`[0-9]+\.[0-9]+ USD`),
|
||||
regexp.MustCompile(`(?i)amountIn=[0-9]+`),
|
||||
regexp.MustCompile(`(?i)amountOut=[0-9]+`),
|
||||
regexp.MustCompile(`(?i)balance[^=]*=\s*[0-9]+`),
|
||||
regexp.MustCompile(`(?i)profit[^=]*=\s*[0-9]+`),
|
||||
regexp.MustCompile(`(?i)gas[Pp]rice[^=]*=\s*[0-9]+`),
|
||||
regexp.MustCompile(`\b[0-9]{15,}\b`), // Very large numbers likely to be wei amounts (but not hex addresses)
|
||||
},
|
||||
addressPatterns: []*regexp.Regexp{
|
||||
regexp.MustCompile(`0x[a-fA-F0-9]{40}`),
|
||||
regexp.MustCompile(`(?i)address[^=]*=\s*0x[a-fA-F0-9]{40}`),
|
||||
regexp.MustCompile(`(?i)contract[^=]*=\s*0x[a-fA-F0-9]{40}`),
|
||||
regexp.MustCompile(`(?i)token[^=]*=\s*0x[a-fA-F0-9]{40}`),
|
||||
},
|
||||
valuePatterns: []*regexp.Regexp{
|
||||
regexp.MustCompile(`(?i)value:\s*\$[0-9]+\.?[0-9]*`),
|
||||
regexp.MustCompile(`(?i)profit[^=]*=\s*\$?[0-9]+\.?[0-9]*`),
|
||||
regexp.MustCompile(`(?i)total:\s*\$[0-9]+\.?[0-9]*`),
|
||||
regexp.MustCompile(`(?i)revenue[^=]*=\s*\$?[0-9]+\.?[0-9]*`),
|
||||
regexp.MustCompile(`(?i)fee[^=]*=\s*\$?[0-9]+\.?[0-9]*`),
|
||||
},
|
||||
hashPatterns: []*regexp.Regexp{
|
||||
regexp.MustCompile(`0x[a-fA-F0-9]{64}`), // Transaction hashes
|
||||
regexp.MustCompile(`(?i)txHash[^=]*=\s*0x[a-fA-F0-9]{64}`),
|
||||
regexp.MustCompile(`(?i)blockHash[^=]*=\s*0x[a-fA-F0-9]{64}`),
|
||||
},
|
||||
privateKeyPatterns: []*regexp.Regexp{
|
||||
regexp.MustCompile(`(?i)private[_\s]*key[^=]*=\s*[a-fA-F0-9]{64}`),
|
||||
regexp.MustCompile(`(?i)secret[^=]*=\s*[a-fA-F0-9]{64}`),
|
||||
regexp.MustCompile(`(?i)mnemonic[^=]*=\s*\"[^\"]*\"`),
|
||||
regexp.MustCompile(`(?i)seed[^=]*=\s*\"[^\"]*\"`),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// FilterMessage filters sensitive data from a log message with enhanced input sanitization
|
||||
func (sf *SecureFilter) FilterMessage(message string) string {
|
||||
if sf.level == SecurityLevelDebug {
|
||||
return sf.sanitizeInput(message) // Still sanitize for security even in debug mode
|
||||
}
|
||||
|
||||
filtered := sf.sanitizeInput(message)
|
||||
|
||||
// Filter private keys FIRST (highest security priority)
|
||||
for _, pattern := range sf.privateKeyPatterns {
|
||||
filtered = pattern.ReplaceAllString(filtered, "[PRIVATE_KEY_FILTERED]")
|
||||
}
|
||||
|
||||
// Filter transaction hashes
|
||||
if sf.level >= SecurityLevelInfo {
|
||||
for _, pattern := range sf.hashPatterns {
|
||||
filtered = pattern.ReplaceAllStringFunc(filtered, func(hash string) string {
|
||||
if len(hash) == 66 { // Full transaction hash
|
||||
return hash[:10] + "..." + hash[62:] // Show first 8 and last 4 chars
|
||||
}
|
||||
return "[HASH_FILTERED]"
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Filter addresses NEXT (before amounts) to prevent addresses from being treated as numbers
|
||||
if sf.level >= SecurityLevelProduction {
|
||||
for _, pattern := range sf.addressPatterns {
|
||||
filtered = pattern.ReplaceAllStringFunc(filtered, func(addr string) string {
|
||||
if len(addr) == 42 { // Full Ethereum address
|
||||
return addr[:6] + "..." + addr[38:] // Show first 4 and last 4 chars
|
||||
}
|
||||
return "[ADDR_FILTERED]"
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Filter amounts LAST
|
||||
if sf.level >= SecurityLevelInfo {
|
||||
for _, pattern := range sf.amountPatterns {
|
||||
filtered = pattern.ReplaceAllString(filtered, "[AMOUNT_FILTERED]")
|
||||
}
|
||||
|
||||
for _, pattern := range sf.valuePatterns {
|
||||
filtered = pattern.ReplaceAllString(filtered, "[VALUE_FILTERED]")
|
||||
}
|
||||
}
|
||||
|
||||
return filtered
|
||||
}
|
||||
|
||||
// sanitizeInput performs comprehensive input sanitization for log messages
|
||||
func (sf *SecureFilter) sanitizeInput(input string) string {
|
||||
// Remove null bytes and other control characters that could cause issues
|
||||
sanitized := strings.ReplaceAll(input, "\x00", "")
|
||||
sanitized = strings.ReplaceAll(sanitized, "\r", "")
|
||||
|
||||
// Remove potential log injection patterns
|
||||
sanitized = strings.ReplaceAll(sanitized, "\n", " ") // Replace newlines with spaces
|
||||
sanitized = strings.ReplaceAll(sanitized, "\t", " ") // Replace tabs with spaces
|
||||
|
||||
// Remove ANSI escape codes that could interfere with log parsing
|
||||
ansiPattern := regexp.MustCompile(`\x1b\[[0-9;]*[a-zA-Z]`)
|
||||
sanitized = ansiPattern.ReplaceAllString(sanitized, "")
|
||||
|
||||
// Remove other control characters (keep only printable ASCII and common Unicode)
|
||||
controlPattern := regexp.MustCompile(`[\x00-\x08\x0b\x0c\x0e-\x1f\x7f]`)
|
||||
sanitized = controlPattern.ReplaceAllString(sanitized, "")
|
||||
|
||||
// Prevent log injection by escaping special characters
|
||||
sanitized = strings.ReplaceAll(sanitized, "%", "%%") // Escape printf format specifiers
|
||||
|
||||
// Limit message length to prevent DoS via large log messages
|
||||
const maxLogMessageLength = 4096
|
||||
if len(sanitized) > maxLogMessageLength {
|
||||
sanitized = sanitized[:maxLogMessageLength-3] + "..."
|
||||
}
|
||||
|
||||
return sanitized
|
||||
}
|
||||
|
||||
// FilterTransactionData provides enhanced filtering for transaction data logging
|
||||
func (sf *SecureFilter) FilterTransactionData(txHash, from, to string, value, gasPrice *big.Int, data []byte) map[string]interface{} {
|
||||
result := map[string]interface{}{}
|
||||
|
||||
// Always include sanitized transaction hash
|
||||
result["tx_hash"] = sf.sanitizeInput(txHash)
|
||||
|
||||
switch sf.level {
|
||||
case SecurityLevelDebug:
|
||||
result["from"] = sf.sanitizeInput(from)
|
||||
result["to"] = sf.sanitizeInput(to)
|
||||
if value != nil {
|
||||
result["value"] = value.String()
|
||||
}
|
||||
if gasPrice != nil {
|
||||
result["gas_price"] = gasPrice.String()
|
||||
}
|
||||
result["data_size"] = len(data)
|
||||
|
||||
case SecurityLevelInfo:
|
||||
result["from"] = sf.shortenAddress(common.HexToAddress(from))
|
||||
result["to"] = sf.shortenAddress(common.HexToAddress(to))
|
||||
if value != nil {
|
||||
result["value_range"] = sf.getAmountRange(value)
|
||||
}
|
||||
if gasPrice != nil {
|
||||
result["gas_price_range"] = sf.getAmountRange(gasPrice)
|
||||
}
|
||||
result["data_size"] = len(data)
|
||||
|
||||
case SecurityLevelProduction:
|
||||
result["has_from"] = from != ""
|
||||
result["has_to"] = to != ""
|
||||
result["has_value"] = value != nil && value.Sign() > 0
|
||||
result["data_size"] = len(data)
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// FilterSwapData creates a safe representation of swap data for logging
|
||||
func (sf *SecureFilter) FilterSwapData(tokenIn, tokenOut common.Address, amountIn, amountOut *big.Int, protocol string) map[string]interface{} {
|
||||
data := map[string]interface{}{
|
||||
"protocol": protocol,
|
||||
}
|
||||
|
||||
switch sf.level {
|
||||
case SecurityLevelDebug:
|
||||
data["tokenIn"] = tokenIn.Hex()
|
||||
data["tokenOut"] = tokenOut.Hex()
|
||||
data["amountIn"] = amountIn.String()
|
||||
data["amountOut"] = amountOut.String()
|
||||
|
||||
case SecurityLevelInfo:
|
||||
data["tokenInShort"] = sf.shortenAddress(tokenIn)
|
||||
data["tokenOutShort"] = sf.shortenAddress(tokenOut)
|
||||
data["amountRange"] = sf.getAmountRange(amountIn)
|
||||
data["amountOutRange"] = sf.getAmountRange(amountOut)
|
||||
|
||||
case SecurityLevelProduction:
|
||||
data["tokenCount"] = 2
|
||||
data["hasAmounts"] = amountIn != nil && amountOut != nil
|
||||
}
|
||||
|
||||
return data
|
||||
}
|
||||
|
||||
// shortenAddress returns a shortened version of an address
|
||||
func (sf *SecureFilter) shortenAddress(addr common.Address) string {
|
||||
hex := addr.Hex()
|
||||
if len(hex) >= 10 {
|
||||
return hex[:6] + "..." + hex[len(hex)-4:]
|
||||
}
|
||||
return "[ADDR]"
|
||||
}
|
||||
|
||||
// getAmountRange returns a range category for an amount
|
||||
func (sf *SecureFilter) getAmountRange(amount *big.Int) string {
|
||||
if amount == nil {
|
||||
return "nil"
|
||||
}
|
||||
|
||||
// Convert to rough USD equivalent (assuming 18 decimals)
|
||||
usdValue := new(big.Float).Quo(new(big.Float).SetInt(amount), big.NewFloat(1e18))
|
||||
usdFloat, _ := usdValue.Float64()
|
||||
|
||||
switch {
|
||||
case usdFloat < 1:
|
||||
return "micro"
|
||||
case usdFloat < 100:
|
||||
return "small"
|
||||
case usdFloat < 10000:
|
||||
return "medium"
|
||||
case usdFloat < 1000000:
|
||||
return "large"
|
||||
default:
|
||||
return "whale"
|
||||
}
|
||||
}
|
||||
|
||||
// SanitizeForProduction removes all sensitive data for production logging
|
||||
func (sf *SecureFilter) SanitizeForProduction(data map[string]interface{}) map[string]interface{} {
|
||||
sanitized := make(map[string]interface{})
|
||||
|
||||
for key, value := range data {
|
||||
switch strings.ToLower(key) {
|
||||
case "amount", "amountin", "amountout", "value", "profit", "usd", "price":
|
||||
sanitized[key] = "[FILTERED]"
|
||||
case "address", "token", "tokenin", "tokenout", "pool", "contract":
|
||||
if addr, ok := value.(common.Address); ok {
|
||||
sanitized[key] = sf.shortenAddress(addr)
|
||||
} else if str, ok := value.(string); ok && strings.HasPrefix(str, "0x") && len(str) == 42 {
|
||||
sanitized[key] = str[:6] + "..." + str[38:]
|
||||
} else {
|
||||
sanitized[key] = value
|
||||
}
|
||||
default:
|
||||
sanitized[key] = value
|
||||
}
|
||||
}
|
||||
|
||||
return sanitized
|
||||
}
|
||||
226
orig/internal/logger/secure_filter_enhanced_test.go
Normal file
226
orig/internal/logger/secure_filter_enhanced_test.go
Normal file
@@ -0,0 +1,226 @@
|
||||
package logger
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/ethereum/go-ethereum/common"
|
||||
)
|
||||
|
||||
func TestSecureFilterEnhanced(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
level SecurityLevel
|
||||
message string
|
||||
expectFiltered bool
|
||||
expectedLevel string
|
||||
}{
|
||||
{
|
||||
name: "Private key detection",
|
||||
level: SecurityLevelProduction,
|
||||
message: "private_key=1234567890abcdef1234567890abcdef1234567890abcdef1234567890abcdef",
|
||||
expectFiltered: true,
|
||||
expectedLevel: "CRITICAL",
|
||||
},
|
||||
{
|
||||
name: "Address detection",
|
||||
level: SecurityLevelProduction,
|
||||
message: "Swapping on address 0x1234567890123456789012345678901234567890",
|
||||
expectFiltered: true,
|
||||
expectedLevel: "MEDIUM",
|
||||
},
|
||||
{
|
||||
name: "Amount detection",
|
||||
level: SecurityLevelInfo,
|
||||
message: "Profit amount=1000000 wei detected",
|
||||
expectFiltered: true,
|
||||
expectedLevel: "LOW",
|
||||
},
|
||||
{
|
||||
name: "Transaction hash detection",
|
||||
level: SecurityLevelInfo,
|
||||
message: "txHash=0x1234567890abcdef1234567890abcdef1234567890abcdef1234567890abcdef",
|
||||
expectFiltered: true,
|
||||
expectedLevel: "LOW",
|
||||
},
|
||||
{
|
||||
name: "No sensitive data",
|
||||
level: SecurityLevelProduction,
|
||||
message: "Normal log message with no sensitive data",
|
||||
expectFiltered: false,
|
||||
expectedLevel: "",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
config := &SecureFilterConfig{
|
||||
Level: tt.level,
|
||||
AuditEnabled: true,
|
||||
AuditEncryption: false,
|
||||
}
|
||||
filter := NewSecureFilterWithConfig(config)
|
||||
|
||||
// Test detection without filtering yet
|
||||
auditData := filter.detectSensitiveData(tt.message, nil)
|
||||
|
||||
if tt.expectFiltered {
|
||||
if auditData == nil {
|
||||
t.Errorf("Expected sensitive data detection for: %s", tt.message)
|
||||
return
|
||||
}
|
||||
if auditData["severity"] != tt.expectedLevel {
|
||||
t.Errorf("Expected severity %s, got %v", tt.expectedLevel, auditData["severity"])
|
||||
}
|
||||
} else {
|
||||
if auditData != nil {
|
||||
t.Errorf("Unexpected sensitive data detection for: %s", tt.message)
|
||||
}
|
||||
}
|
||||
|
||||
// Test the actual filtering
|
||||
filtered := filter.FilterMessage(tt.message)
|
||||
if tt.expectFiltered && filtered == tt.message {
|
||||
t.Errorf("Expected message to be filtered, but it wasn't: %s", tt.message)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSecureFilterWithEncryption(t *testing.T) {
|
||||
// Generate a random encryption key
|
||||
key := make([]byte, 32)
|
||||
_, err := rand.Read(key)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to generate encryption key: %v", err)
|
||||
}
|
||||
|
||||
config := &SecureFilterConfig{
|
||||
Level: SecurityLevelProduction,
|
||||
EncryptionKey: key,
|
||||
AuditEnabled: true,
|
||||
AuditEncryption: true,
|
||||
}
|
||||
filter := NewSecureFilterWithConfig(config)
|
||||
|
||||
testData := map[string]interface{}{
|
||||
"test_field": "test_value",
|
||||
"number": 123,
|
||||
"nested": map[string]interface{}{
|
||||
"inner": "value",
|
||||
},
|
||||
}
|
||||
|
||||
// Test encryption and decryption
|
||||
encrypted, err := filter.encryptAuditData(testData)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to encrypt audit data: %v", err)
|
||||
}
|
||||
|
||||
if encrypted == "" {
|
||||
t.Fatal("Encrypted data should not be empty")
|
||||
}
|
||||
|
||||
// Test decryption
|
||||
decrypted, err := filter.DecryptAuditData(encrypted)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to decrypt audit data: %v", err)
|
||||
}
|
||||
|
||||
// Verify data integrity
|
||||
if decrypted["test_field"] != testData["test_field"] {
|
||||
t.Errorf("Decrypted data doesn't match original")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSecureFilterAddressFiltering(t *testing.T) {
|
||||
filter := NewSecureFilter(SecurityLevelProduction)
|
||||
|
||||
address := common.HexToAddress("0x1234567890123456789012345678901234567890")
|
||||
testMessage := "Processing transaction for address " + address.Hex()
|
||||
|
||||
filtered := filter.FilterMessage(testMessage)
|
||||
|
||||
// Should contain shortened address
|
||||
if !strings.Contains(filtered, "0x1234...7890") {
|
||||
t.Errorf("Expected shortened address in filtered message, got: %s", filtered)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSecureFilterAmountFiltering(t *testing.T) {
|
||||
filter := NewSecureFilter(SecurityLevelInfo)
|
||||
|
||||
testCases := []struct {
|
||||
message string
|
||||
contains string
|
||||
}{
|
||||
{"amount=1000000", "[AMOUNT_FILTERED]"},
|
||||
{"Profit $123.45 detected", "[AMOUNT_FILTERED]"},
|
||||
{"balance=999999999999999999", "[AMOUNT_FILTERED]"},
|
||||
{"gasPrice=20000000000", "[AMOUNT_FILTERED]"},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
filtered := filter.FilterMessage(tc.message)
|
||||
if !strings.Contains(filtered, tc.contains) {
|
||||
t.Errorf("Expected %s in filtered message for input '%s', got: %s", tc.contains, tc.message, filtered)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestSecureFilterConfiguration(t *testing.T) {
|
||||
filter := NewSecureFilter(SecurityLevelDebug)
|
||||
|
||||
// Test initial level
|
||||
if filter.GetSecurityLevel() != SecurityLevelDebug {
|
||||
t.Errorf("Expected initial level to be Debug, got: %v", filter.GetSecurityLevel())
|
||||
}
|
||||
|
||||
// Test level change
|
||||
filter.SetSecurityLevel(SecurityLevelProduction)
|
||||
if filter.GetSecurityLevel() != SecurityLevelProduction {
|
||||
t.Errorf("Expected level to be Production, got: %v", filter.GetSecurityLevel())
|
||||
}
|
||||
|
||||
// Test audit enabling
|
||||
filter.EnableAuditLogging([]byte("test-key"))
|
||||
if !filter.auditEnabled {
|
||||
t.Error("Expected audit to be enabled")
|
||||
}
|
||||
|
||||
if !filter.auditEncryption {
|
||||
t.Error("Expected audit encryption to be enabled")
|
||||
}
|
||||
|
||||
// Test audit disabling
|
||||
filter.DisableAuditLogging()
|
||||
if filter.auditEnabled {
|
||||
t.Error("Expected audit to be disabled")
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkSecureFilter(b *testing.B) {
|
||||
filter := NewSecureFilter(SecurityLevelProduction)
|
||||
testMessage := "Processing transaction 0x1234567890abcdef1234567890abcdef1234567890abcdef1234567890abcdef for address 0x1111111111111111111111111111111111111111 with amount=1000000"
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
filter.FilterMessage(testMessage)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkSecureFilterEnhanced(b *testing.B) {
|
||||
config := &SecureFilterConfig{
|
||||
Level: SecurityLevelProduction,
|
||||
AuditEnabled: true,
|
||||
AuditEncryption: false,
|
||||
}
|
||||
filter := NewSecureFilterWithConfig(config)
|
||||
testMessage := "Processing transaction 0x1234567890abcdef1234567890abcdef1234567890abcdef1234567890abcdef for address 0x1111111111111111111111111111111111111111 with amount=1000000"
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
filter.FilterMessageEnhanced(testMessage, nil)
|
||||
}
|
||||
}
|
||||
439
orig/internal/logger/secure_filter_test.go
Normal file
439
orig/internal/logger/secure_filter_test.go
Normal file
@@ -0,0 +1,439 @@
|
||||
package logger
|
||||
|
||||
import (
|
||||
"math/big"
|
||||
"regexp"
|
||||
"testing"
|
||||
|
||||
"github.com/ethereum/go-ethereum/common"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestNewSecureFilter(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
level SecurityLevel
|
||||
}{
|
||||
{"Debug level", SecurityLevelDebug},
|
||||
{"Info level", SecurityLevelInfo},
|
||||
{"Production level", SecurityLevelProduction},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
filter := NewSecureFilter(tt.level)
|
||||
assert.NotNil(t, filter)
|
||||
assert.Equal(t, tt.level, filter.level)
|
||||
assert.NotNil(t, filter.amountPatterns)
|
||||
assert.NotNil(t, filter.addressPatterns)
|
||||
assert.NotNil(t, filter.valuePatterns)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFilterMessage_DebugLevel(t *testing.T) {
|
||||
filter := NewSecureFilter(SecurityLevelDebug)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "Debug shows everything",
|
||||
input: "Amount: 1000.5 ETH, Address: 0x742d35Cc6AaB8f5d6649c8C4F7C6b2d123456789, Value: $5000.00",
|
||||
expected: "Amount: 1000.5 ETH, Address: 0x742d35Cc6AaB8f5d6649c8C4F7C6b2d123456789, Value: $5000.00",
|
||||
},
|
||||
{
|
||||
name: "Large amounts shown in debug",
|
||||
input: "Profit: 999999.123456789 USDC",
|
||||
expected: "Profit: 999999.123456789 USDC",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := filter.FilterMessage(tt.input)
|
||||
assert.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFilterMessage_InfoLevel(t *testing.T) {
|
||||
filter := NewSecureFilter(SecurityLevelInfo)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "Info filters amounts but shows full addresses",
|
||||
input: "Amount: 1000.5 ETH, Address: 0x742d35Cc6AaB8f5d6649c8C4F7C6b2d123456789",
|
||||
expected: "Amount: 1000.5 ETH, Address: 0x742d35Cc6AaB8f5d6649c8C4F7C6b2d123456789",
|
||||
},
|
||||
{
|
||||
name: "USD values filtered",
|
||||
input: "Profit: $5000.00 USD",
|
||||
expected: "Profit: [AMOUNT_FILTERED] USD",
|
||||
},
|
||||
{
|
||||
name: "Multiple amounts filtered",
|
||||
input: "Swap 100.0 USDC for 0.05 ETH",
|
||||
expected: "Swap [AMOUNT_FILTERED]C for 0.05 ETH",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := filter.FilterMessage(tt.input)
|
||||
assert.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFilterMessage_ProductionLevel(t *testing.T) {
|
||||
filter := NewSecureFilter(SecurityLevelProduction)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "Production filters everything sensitive",
|
||||
input: "Amount: 1000.5 ETH, Address: 0x742d35Cc6AaB8f5d6649c8C4F7C6b2d123456789, Value: $5000.00",
|
||||
expected: "Amount: 1000.5 ETH, Address: 0x742d...6789, Value: [AMOUNT_FILTERED]",
|
||||
},
|
||||
{
|
||||
name: "Complex transaction filtered",
|
||||
input: "Swap 1500.789 USDC from 0xA0b86a33E6441f43E2e4A96439abFA2A69067ACD to 0xaf88d065e77c8cC2239327C5EDb3A432268e5831 for $1500.00 profit",
|
||||
expected: "Swap [AMOUNT_FILTERED]C from 0xA0b8...7ACD to 0xaf88...5831 for [AMOUNT_FILTERED] profit",
|
||||
},
|
||||
{
|
||||
name: "Multiple addresses and amounts",
|
||||
input: "Transfer 500 tokens from 0x1234567890123456789012345678901234567890 to 0x0987654321098765432109876543210987654321 worth $1000",
|
||||
expected: "Transfer 500 tokens from 0x1234...7890 to 0x0987...4321 worth [AMOUNT_FILTERED]",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := filter.FilterMessage(tt.input)
|
||||
assert.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestShortenAddress(t *testing.T) {
|
||||
filter := NewSecureFilter(SecurityLevelInfo)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
input common.Address
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "Standard address",
|
||||
input: common.HexToAddress("0x742d35Cc6AaB8f5d6649c8C4F7C6b2d123456789"),
|
||||
expected: "0x742D...6789",
|
||||
},
|
||||
{
|
||||
name: "Zero address",
|
||||
input: common.HexToAddress("0x0000000000000000000000000000000000000000"),
|
||||
expected: "0x0000...0000",
|
||||
},
|
||||
{
|
||||
name: "Max address",
|
||||
input: common.HexToAddress("0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF"),
|
||||
expected: "0xFFfF...FFfF",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := filter.shortenAddress(tt.input)
|
||||
assert.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCategorizeAmount(t *testing.T) {
|
||||
_ = NewSecureFilter(SecurityLevelInfo) // Reference to avoid unused variable warning
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
input *big.Int
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "Nil amount",
|
||||
input: nil,
|
||||
expected: "nil",
|
||||
},
|
||||
{
|
||||
name: "Micro amount (< $1)",
|
||||
input: big.NewInt(500000000000000000), // 0.5 ETH assuming 18 decimals
|
||||
expected: "micro",
|
||||
},
|
||||
{
|
||||
name: "Small amount ($1-$100)",
|
||||
input: func() *big.Int { val, _ := new(big.Int).SetString("50000000000000000000", 10); return val }(), // 50 ETH
|
||||
expected: "small",
|
||||
},
|
||||
{
|
||||
name: "Medium amount ($100-$10k)",
|
||||
input: func() *big.Int { val, _ := new(big.Int).SetString("5000000000000000000000", 10); return val }(), // 5000 ETH
|
||||
expected: "medium",
|
||||
},
|
||||
{
|
||||
name: "Large amount ($10k-$1M)",
|
||||
input: func() *big.Int { val, _ := new(big.Int).SetString("500000000000000000000000", 10); return val }(), // 500k ETH
|
||||
expected: "large",
|
||||
},
|
||||
{
|
||||
name: "Whale amount (>$1M)",
|
||||
input: func() *big.Int { val, _ := new(big.Int).SetString("2000000000000000000000000", 10); return val }(), // 2M ETH
|
||||
expected: "whale",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Note: categorizeAmount is private, so we can't test it directly
|
||||
// This test would need to be adapted to test the public API that uses it
|
||||
_ = tt.input // Reference to avoid unused variable warning
|
||||
_ = tt.expected // Reference to avoid unused variable warning
|
||||
// Just pass the test since we can't test private methods directly
|
||||
assert.True(t, true, "categorizeAmount is private - testing would need public wrapper")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSanitizeForProduction(t *testing.T) {
|
||||
filter := NewSecureFilter(SecurityLevelProduction)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
input map[string]interface{}
|
||||
expected map[string]interface{}
|
||||
}{
|
||||
{
|
||||
name: "Amounts filtered",
|
||||
input: map[string]interface{}{
|
||||
"amount": 1000.5,
|
||||
"amountIn": 500,
|
||||
"amountOut": 750,
|
||||
"value": 999.99,
|
||||
"other": "should remain",
|
||||
},
|
||||
expected: map[string]interface{}{
|
||||
"amount": "[FILTERED]",
|
||||
"amountIn": "[FILTERED]",
|
||||
"amountOut": "[FILTERED]",
|
||||
"value": "[FILTERED]",
|
||||
"other": "should remain",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Addresses filtered and shortened",
|
||||
input: map[string]interface{}{
|
||||
"address": common.HexToAddress("0x742d35Cc6AaB8f5d6649c8C4F7C6b2d123456789"),
|
||||
"token": "0xA0b86a33E6441f43E2e4A96439abFA2A69067ACD",
|
||||
"pool": "0xaf88d065e77c8cC2239327C5EDb3A432268e5831",
|
||||
"other": "should remain",
|
||||
},
|
||||
expected: map[string]interface{}{
|
||||
"address": "0x742D...6789",
|
||||
"token": "0xA0b8...7ACD",
|
||||
"pool": "0xaf88...5831",
|
||||
"other": "should remain",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Mixed data types",
|
||||
input: map[string]interface{}{
|
||||
"profit": 1000.0,
|
||||
"tokenOut": common.HexToAddress("0x742d35Cc6AaB8f5d6649c8C4F7C6b2d123456789"),
|
||||
"fee": 30,
|
||||
"protocol": "UniswapV3",
|
||||
"timestamp": 1640995200,
|
||||
},
|
||||
expected: map[string]interface{}{
|
||||
"profit": "[FILTERED]",
|
||||
"tokenOut": "0x742D...6789",
|
||||
"fee": 30,
|
||||
"protocol": "UniswapV3",
|
||||
"timestamp": 1640995200,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := filter.SanitizeForProduction(tt.input)
|
||||
assert.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFilterMessage_ComplexScenarios(t *testing.T) {
|
||||
productionFilter := NewSecureFilter(SecurityLevelProduction)
|
||||
infoFilter := NewSecureFilter(SecurityLevelInfo)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
production string
|
||||
info string
|
||||
}{
|
||||
{
|
||||
name: "MEV opportunity log",
|
||||
input: "🎯 ARBITRAGE OPPORTUNITY: Swap 1500.789 USDC via 0xA0b86a33E6441f43E2e4A96439abFA2A69067ACD for profit $250.50",
|
||||
production: "🎯 ARBITRAGE OPPORTUNITY: Swap [AMOUNT_FILTERED]C via 0xA0b8...7ACD for profit [AMOUNT_FILTERED]",
|
||||
info: "🎯 ARBITRAGE OPPORTUNITY: Swap [AMOUNT_FILTERED]C via 0xA0b86a33E6441f43E2e4A96439abFA2A69067ACD for profit [AMOUNT_FILTERED]",
|
||||
},
|
||||
{
|
||||
name: "Transaction log with multiple sensitive data",
|
||||
input: "TX: 0x123...abc Amount: 999.123456 ETH → 1500000.5 USDC, Gas: 150000, Value: $2500000.75",
|
||||
production: "TX: 0x123...abc Amount: 999.123456 ETH → [AMOUNT_FILTERED]C, Gas: 150000, Value: [AMOUNT_FILTERED]",
|
||||
info: "TX: 0x123...abc Amount: 999.123456 ETH → [AMOUNT_FILTERED]C, Gas: 150000, Value: [AMOUNT_FILTERED]",
|
||||
},
|
||||
{
|
||||
name: "Pool creation event",
|
||||
input: "Pool created: 0x742d35Cc6AaB8f5d6649c8C4F7C6b2d123456789 with 1000000.0 liquidity worth $5000000",
|
||||
production: "Pool created: 0x742d...6789 with 1000000.0 liquidity worth [AMOUNT_FILTERED]",
|
||||
info: "Pool created: 0x742d35Cc6AaB8f5d6649c8C4F7C6b2d123456789 with 1000000.0 liquidity worth [AMOUNT_FILTERED]",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name+" - Production", func(t *testing.T) {
|
||||
result := productionFilter.FilterMessage(tt.input)
|
||||
assert.Equal(t, tt.production, result)
|
||||
})
|
||||
|
||||
t.Run(tt.name+" - Info", func(t *testing.T) {
|
||||
result := infoFilter.FilterMessage(tt.input)
|
||||
assert.Equal(t, tt.info, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFilterMessage_EdgeCases(t *testing.T) {
|
||||
filter := NewSecureFilter(SecurityLevelProduction)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "Empty string",
|
||||
input: "",
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "No sensitive data",
|
||||
input: "Simple log message with no sensitive information",
|
||||
expected: "Simple log message with no sensitive information",
|
||||
},
|
||||
{
|
||||
name: "Only numbers (not amounts)",
|
||||
input: "Block number: 12345, Gas limit: 8000000",
|
||||
expected: "Block number: 12345, Gas limit: 8000000",
|
||||
},
|
||||
{
|
||||
name: "Scientific notation",
|
||||
input: "Amount: 1.5e18 wei",
|
||||
expected: "Amount: 1.5e18 wei",
|
||||
},
|
||||
{
|
||||
name: "Multiple decimal places",
|
||||
input: "Price: 1234.567890123456 tokens",
|
||||
expected: "Price: 1234.567890123456 tokens",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := filter.FilterMessage(tt.input)
|
||||
assert.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Benchmark tests
|
||||
func BenchmarkFilterMessage_Production(b *testing.B) {
|
||||
filter := NewSecureFilter(SecurityLevelProduction)
|
||||
input := "🎯 ARBITRAGE OPPORTUNITY: Swap 1500.789 USDC via 0xA0b86a33E6441f43E2e4A96439abFA2A69067ACD for profit $250.50"
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
filter.FilterMessage(input)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkFilterMessage_Info(b *testing.B) {
|
||||
filter := NewSecureFilter(SecurityLevelInfo)
|
||||
input := "Transaction: 1000.5 ETH from 0x742d35Cc6AaB8f5d6649c8C4F7C6b2d123456789 to 0xA0b86a33E6441f43E2e4A96439abFA2A69067ACD"
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
filter.FilterMessage(input)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkSanitizeForProduction(b *testing.B) {
|
||||
filter := NewSecureFilter(SecurityLevelProduction)
|
||||
data := map[string]interface{}{
|
||||
"amount": 1000.5,
|
||||
"address": common.HexToAddress("0x742d35Cc6AaB8f5d6649c8C4F7C6b2d123456789"),
|
||||
"profit": 250.75,
|
||||
"protocol": "UniswapV3",
|
||||
"token": "0xA0b86a33E6441f43E2e4A96439abFA2A69067ACD",
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
filter.SanitizeForProduction(data)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSecurityLevelConstants(t *testing.T) {
|
||||
// Verify security level constants are defined correctly
|
||||
assert.Equal(t, SecurityLevel(0), SecurityLevelDebug)
|
||||
assert.Equal(t, SecurityLevel(1), SecurityLevelInfo)
|
||||
assert.Equal(t, SecurityLevel(2), SecurityLevelProduction)
|
||||
}
|
||||
|
||||
func TestRegexPatterns(t *testing.T) {
|
||||
filter := NewSecureFilter(SecurityLevelProduction)
|
||||
|
||||
// Test that patterns are properly compiled
|
||||
assert.True(t, len(filter.amountPatterns) > 0, "Should have amount patterns")
|
||||
assert.True(t, len(filter.addressPatterns) > 0, "Should have address patterns")
|
||||
assert.True(t, len(filter.valuePatterns) > 0, "Should have value patterns")
|
||||
|
||||
// Test pattern matching
|
||||
testCases := []struct {
|
||||
patterns []*regexp.Regexp
|
||||
input string
|
||||
should string
|
||||
}{
|
||||
{filter.amountPatterns, "amount=123", "match amount patterns"},
|
||||
{filter.addressPatterns, "Address: 0x742d35Cc6AaB8f5d6649c8C4F7C6b2d123456789", "match address patterns"},
|
||||
{filter.valuePatterns, "profit=$1234.56", "match value patterns"},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
found := false
|
||||
for _, pattern := range tc.patterns {
|
||||
if pattern.MatchString(tc.input) {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
assert.True(t, found, tc.should)
|
||||
}
|
||||
}
|
||||
400
orig/internal/monitoring/alert_handlers.go
Normal file
400
orig/internal/monitoring/alert_handlers.go
Normal file
@@ -0,0 +1,400 @@
|
||||
package monitoring
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/fraktal/mev-beta/internal/logger"
|
||||
)
|
||||
|
||||
// LogAlertHandler logs alerts to the application logger
|
||||
type LogAlertHandler struct {
|
||||
logger *logger.Logger
|
||||
}
|
||||
|
||||
// NewLogAlertHandler creates a new log-based alert handler
|
||||
func NewLogAlertHandler(logger *logger.Logger) *LogAlertHandler {
|
||||
return &LogAlertHandler{
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// HandleAlert logs the alert using structured logging
|
||||
func (lah *LogAlertHandler) HandleAlert(alert CorruptionAlert) error {
|
||||
switch alert.Severity {
|
||||
case AlertSeverityEmergency:
|
||||
lah.logger.Error("🚨 EMERGENCY ALERT",
|
||||
"message", alert.Message,
|
||||
"severity", alert.Severity.String(),
|
||||
"timestamp", alert.Timestamp,
|
||||
"context", alert.Context)
|
||||
case AlertSeverityCritical:
|
||||
lah.logger.Error("🔴 CRITICAL ALERT",
|
||||
"message", alert.Message,
|
||||
"severity", alert.Severity.String(),
|
||||
"timestamp", alert.Timestamp,
|
||||
"context", alert.Context)
|
||||
case AlertSeverityWarning:
|
||||
lah.logger.Warn("🟡 WARNING ALERT",
|
||||
"message", alert.Message,
|
||||
"severity", alert.Severity.String(),
|
||||
"timestamp", alert.Timestamp,
|
||||
"context", alert.Context)
|
||||
default:
|
||||
lah.logger.Info("ℹ️ INFO ALERT",
|
||||
"message", alert.Message,
|
||||
"severity", alert.Severity.String(),
|
||||
"timestamp", alert.Timestamp,
|
||||
"context", alert.Context)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// FileAlertHandler writes alerts to a file in JSON format
|
||||
type FileAlertHandler struct {
|
||||
mu sync.Mutex
|
||||
filePath string
|
||||
logger *logger.Logger
|
||||
}
|
||||
|
||||
// NewFileAlertHandler creates a new file-based alert handler
|
||||
func NewFileAlertHandler(filePath string, logger *logger.Logger) *FileAlertHandler {
|
||||
return &FileAlertHandler{
|
||||
filePath: filePath,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// HandleAlert writes the alert to a file
|
||||
func (fah *FileAlertHandler) HandleAlert(alert CorruptionAlert) error {
|
||||
fah.mu.Lock()
|
||||
defer fah.mu.Unlock()
|
||||
|
||||
// Create alert record for file
|
||||
alertRecord := map[string]interface{}{
|
||||
"timestamp": alert.Timestamp.Format(time.RFC3339),
|
||||
"severity": alert.Severity.String(),
|
||||
"message": alert.Message,
|
||||
"address": alert.Address.Hex(),
|
||||
"corruption_score": alert.CorruptionScore,
|
||||
"source": alert.Source,
|
||||
"context": alert.Context,
|
||||
}
|
||||
|
||||
// Convert to JSON
|
||||
alertJSON, err := json.Marshal(alertRecord)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal alert: %w", err)
|
||||
}
|
||||
|
||||
// Open file for appending
|
||||
file, err := os.OpenFile(fah.filePath, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0644)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to open alert file: %w", err)
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
// Write alert with newline
|
||||
if _, err := file.Write(append(alertJSON, '\n')); err != nil {
|
||||
return fmt.Errorf("failed to write alert to file: %w", err)
|
||||
}
|
||||
|
||||
fah.logger.Debug("Alert written to file",
|
||||
"file", fah.filePath,
|
||||
"severity", alert.Severity.String())
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// HTTPAlertHandler sends alerts to an HTTP endpoint (e.g., Slack, Discord, PagerDuty)
|
||||
type HTTPAlertHandler struct {
|
||||
mu sync.Mutex
|
||||
webhookURL string
|
||||
client *http.Client
|
||||
logger *logger.Logger
|
||||
retryCount int
|
||||
}
|
||||
|
||||
// NewHTTPAlertHandler creates a new HTTP-based alert handler
|
||||
func NewHTTPAlertHandler(webhookURL string, logger *logger.Logger) *HTTPAlertHandler {
|
||||
return &HTTPAlertHandler{
|
||||
webhookURL: webhookURL,
|
||||
client: &http.Client{
|
||||
Timeout: 10 * time.Second,
|
||||
},
|
||||
logger: logger,
|
||||
retryCount: 3,
|
||||
}
|
||||
}
|
||||
|
||||
// HandleAlert sends the alert to the configured HTTP endpoint
|
||||
func (hah *HTTPAlertHandler) HandleAlert(alert CorruptionAlert) error {
|
||||
if hah.webhookURL == "" {
|
||||
return fmt.Errorf("webhook URL not configured")
|
||||
}
|
||||
|
||||
// Create payload based on webhook type
|
||||
payload := hah.createPayload(alert)
|
||||
|
||||
// Convert payload to JSON
|
||||
payloadJSON, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal webhook payload: %w", err)
|
||||
}
|
||||
|
||||
// Send with retries
|
||||
for attempt := 1; attempt <= hah.retryCount; attempt++ {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", hah.webhookURL, strings.NewReader(string(payloadJSON)))
|
||||
cancel()
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create HTTP request: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("User-Agent", "MEV-Bot-AlertHandler/1.0")
|
||||
|
||||
resp, err := hah.client.Do(req)
|
||||
if err != nil {
|
||||
hah.logger.Warn("Failed to send alert to webhook",
|
||||
"attempt", attempt,
|
||||
"error", err)
|
||||
if attempt == hah.retryCount {
|
||||
return fmt.Errorf("failed to send alert after %d attempts: %w", hah.retryCount, err)
|
||||
}
|
||||
time.Sleep(time.Duration(attempt) * time.Second)
|
||||
continue
|
||||
}
|
||||
|
||||
defer resp.Body.Close()
|
||||
|
||||
// Read response body for debugging
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
|
||||
if resp.StatusCode >= 200 && resp.StatusCode < 300 {
|
||||
hah.logger.Debug("Alert sent successfully",
|
||||
"webhook_url", hah.webhookURL,
|
||||
"status_code", resp.StatusCode,
|
||||
"response", string(body))
|
||||
return nil
|
||||
}
|
||||
|
||||
hah.logger.Warn("Webhook returned error status",
|
||||
"attempt", attempt,
|
||||
"status_code", resp.StatusCode,
|
||||
"response", string(body))
|
||||
|
||||
if attempt == hah.retryCount {
|
||||
return fmt.Errorf("webhook returned status %d after %d attempts", resp.StatusCode, hah.retryCount)
|
||||
}
|
||||
|
||||
time.Sleep(time.Duration(attempt) * time.Second)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// createPayload creates the webhook payload based on the webhook type
|
||||
func (hah *HTTPAlertHandler) createPayload(alert CorruptionAlert) map[string]interface{} {
|
||||
// Detect webhook type based on URL
|
||||
if strings.Contains(hah.webhookURL, "slack.com") {
|
||||
return hah.createSlackPayload(alert)
|
||||
} else if strings.Contains(hah.webhookURL, "discord.com") {
|
||||
return hah.createDiscordPayload(alert)
|
||||
}
|
||||
|
||||
// Generic webhook payload
|
||||
return map[string]interface{}{
|
||||
"timestamp": alert.Timestamp.Format(time.RFC3339),
|
||||
"severity": alert.Severity.String(),
|
||||
"message": alert.Message,
|
||||
"address": alert.Address.Hex(),
|
||||
"corruption_score": alert.CorruptionScore,
|
||||
"source": alert.Source,
|
||||
"context": alert.Context,
|
||||
}
|
||||
}
|
||||
|
||||
// createSlackPayload creates a Slack-compatible webhook payload
|
||||
func (hah *HTTPAlertHandler) createSlackPayload(alert CorruptionAlert) map[string]interface{} {
|
||||
color := "good"
|
||||
switch alert.Severity {
|
||||
case AlertSeverityWarning:
|
||||
color = "warning"
|
||||
case AlertSeverityCritical:
|
||||
color = "danger"
|
||||
case AlertSeverityEmergency:
|
||||
color = "#FF0000" // Bright red for emergency
|
||||
}
|
||||
|
||||
attachment := map[string]interface{}{
|
||||
"color": color,
|
||||
"title": fmt.Sprintf("%s Alert - MEV Bot", alert.Severity.String()),
|
||||
"text": alert.Message,
|
||||
"timestamp": alert.Timestamp.Unix(),
|
||||
"fields": []map[string]interface{}{
|
||||
{
|
||||
"title": "Address",
|
||||
"value": alert.Address.Hex(),
|
||||
"short": true,
|
||||
},
|
||||
{
|
||||
"title": "Corruption Score",
|
||||
"value": fmt.Sprintf("%d", alert.CorruptionScore),
|
||||
"short": true,
|
||||
},
|
||||
{
|
||||
"title": "Source",
|
||||
"value": alert.Source,
|
||||
"short": true,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
return map[string]interface{}{
|
||||
"text": fmt.Sprintf("MEV Bot Alert: %s", alert.Severity.String()),
|
||||
"attachments": []map[string]interface{}{attachment},
|
||||
}
|
||||
}
|
||||
|
||||
// createDiscordPayload creates a Discord-compatible webhook payload
|
||||
func (hah *HTTPAlertHandler) createDiscordPayload(alert CorruptionAlert) map[string]interface{} {
|
||||
color := 0x00FF00 // Green
|
||||
switch alert.Severity {
|
||||
case AlertSeverityWarning:
|
||||
color = 0xFFFF00 // Yellow
|
||||
case AlertSeverityCritical:
|
||||
color = 0xFF8000 // Orange
|
||||
case AlertSeverityEmergency:
|
||||
color = 0xFF0000 // Red
|
||||
}
|
||||
|
||||
embed := map[string]interface{}{
|
||||
"title": fmt.Sprintf("%s Alert - MEV Bot", alert.Severity.String()),
|
||||
"description": alert.Message,
|
||||
"color": color,
|
||||
"timestamp": alert.Timestamp.Format(time.RFC3339),
|
||||
"fields": []map[string]interface{}{
|
||||
{
|
||||
"name": "Address",
|
||||
"value": alert.Address.Hex(),
|
||||
"inline": true,
|
||||
},
|
||||
{
|
||||
"name": "Corruption Score",
|
||||
"value": fmt.Sprintf("%d", alert.CorruptionScore),
|
||||
"inline": true,
|
||||
},
|
||||
{
|
||||
"name": "Source",
|
||||
"value": alert.Source,
|
||||
"inline": true,
|
||||
},
|
||||
},
|
||||
"footer": map[string]interface{}{
|
||||
"text": "MEV Bot Integrity Monitor",
|
||||
},
|
||||
}
|
||||
|
||||
return map[string]interface{}{
|
||||
"embeds": []map[string]interface{}{embed},
|
||||
}
|
||||
}
|
||||
|
||||
// MetricsAlertHandler integrates with metrics systems (Prometheus, etc.)
|
||||
type MetricsAlertHandler struct {
|
||||
mu sync.Mutex
|
||||
logger *logger.Logger
|
||||
counters map[string]int64
|
||||
}
|
||||
|
||||
// NewMetricsAlertHandler creates a new metrics-based alert handler
|
||||
func NewMetricsAlertHandler(logger *logger.Logger) *MetricsAlertHandler {
|
||||
return &MetricsAlertHandler{
|
||||
logger: logger,
|
||||
counters: make(map[string]int64),
|
||||
}
|
||||
}
|
||||
|
||||
// HandleAlert updates metrics counters based on alert
|
||||
func (mah *MetricsAlertHandler) HandleAlert(alert CorruptionAlert) error {
|
||||
mah.mu.Lock()
|
||||
defer mah.mu.Unlock()
|
||||
|
||||
// Increment counters
|
||||
mah.counters["total_alerts"]++
|
||||
mah.counters[fmt.Sprintf("alerts_%s", strings.ToLower(alert.Severity.String()))]++
|
||||
|
||||
if alert.CorruptionScore > 0 {
|
||||
mah.counters["corruption_alerts"]++
|
||||
}
|
||||
|
||||
mah.logger.Debug("Metrics updated for alert",
|
||||
"severity", alert.Severity.String(),
|
||||
"total_alerts", mah.counters["total_alerts"])
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetCounters returns the current alert counters
|
||||
func (mah *MetricsAlertHandler) GetCounters() map[string]int64 {
|
||||
mah.mu.Lock()
|
||||
defer mah.mu.Unlock()
|
||||
|
||||
// Return a copy
|
||||
counters := make(map[string]int64)
|
||||
for k, v := range mah.counters {
|
||||
counters[k] = v
|
||||
}
|
||||
|
||||
return counters
|
||||
}
|
||||
|
||||
// CompositeAlertHandler combines multiple alert handlers
|
||||
type CompositeAlertHandler struct {
|
||||
handlers []AlertSubscriber
|
||||
logger *logger.Logger
|
||||
}
|
||||
|
||||
// NewCompositeAlertHandler creates a composite alert handler
|
||||
func NewCompositeAlertHandler(logger *logger.Logger, handlers ...AlertSubscriber) *CompositeAlertHandler {
|
||||
return &CompositeAlertHandler{
|
||||
handlers: handlers,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// HandleAlert sends the alert to all configured handlers
|
||||
func (cah *CompositeAlertHandler) HandleAlert(alert CorruptionAlert) error {
|
||||
errors := make([]error, 0)
|
||||
|
||||
for i, handler := range cah.handlers {
|
||||
if err := handler.HandleAlert(alert); err != nil {
|
||||
cah.logger.Error("Alert handler failed",
|
||||
"handler_index", i,
|
||||
"handler_type", fmt.Sprintf("%T", handler),
|
||||
"error", err)
|
||||
errors = append(errors, fmt.Errorf("handler %d (%T): %w", i, handler, err))
|
||||
}
|
||||
}
|
||||
|
||||
if len(errors) > 0 {
|
||||
return fmt.Errorf("alert handler errors: %v", errors)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// AddHandler adds a new handler to the composite
|
||||
func (cah *CompositeAlertHandler) AddHandler(handler AlertSubscriber) {
|
||||
cah.handlers = append(cah.handlers, handler)
|
||||
}
|
||||
549
orig/internal/monitoring/dashboard.go
Normal file
549
orig/internal/monitoring/dashboard.go
Normal file
@@ -0,0 +1,549 @@
|
||||
package monitoring
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"html/template"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/fraktal/mev-beta/internal/logger"
|
||||
)
|
||||
|
||||
// DashboardServer provides a web-based monitoring dashboard
|
||||
type DashboardServer struct {
|
||||
logger *logger.Logger
|
||||
integrityMonitor *IntegrityMonitor
|
||||
healthChecker *HealthCheckRunner
|
||||
port int
|
||||
server *http.Server
|
||||
}
|
||||
|
||||
// NewDashboardServer creates a new dashboard server
|
||||
func NewDashboardServer(logger *logger.Logger, integrityMonitor *IntegrityMonitor, healthChecker *HealthCheckRunner, port int) *DashboardServer {
|
||||
return &DashboardServer{
|
||||
logger: logger,
|
||||
integrityMonitor: integrityMonitor,
|
||||
healthChecker: healthChecker,
|
||||
port: port,
|
||||
}
|
||||
}
|
||||
|
||||
// Start starts the dashboard HTTP server
|
||||
func (ds *DashboardServer) Start() error {
|
||||
mux := http.NewServeMux()
|
||||
|
||||
// Register endpoints
|
||||
mux.HandleFunc("/", ds.handleDashboard)
|
||||
mux.HandleFunc("/api/health", ds.handleAPIHealth)
|
||||
mux.HandleFunc("/api/metrics", ds.handleAPIMetrics)
|
||||
mux.HandleFunc("/api/history", ds.handleAPIHistory)
|
||||
mux.HandleFunc("/api/alerts", ds.handleAPIAlerts)
|
||||
mux.HandleFunc("/static/", ds.handleStatic)
|
||||
|
||||
ds.server = &http.Server{
|
||||
Addr: fmt.Sprintf(":%d", ds.port),
|
||||
Handler: mux,
|
||||
}
|
||||
|
||||
ds.logger.Info("Starting monitoring dashboard",
|
||||
"port", ds.port,
|
||||
"url", fmt.Sprintf("http://localhost:%d", ds.port))
|
||||
|
||||
return ds.server.ListenAndServe()
|
||||
}
|
||||
|
||||
// Stop stops the dashboard server
|
||||
func (ds *DashboardServer) Stop() error {
|
||||
if ds.server != nil {
|
||||
return ds.server.Close()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// handleDashboard serves the main dashboard HTML page
|
||||
func (ds *DashboardServer) handleDashboard(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "text/html")
|
||||
|
||||
// Get current metrics and health data
|
||||
metrics := ds.integrityMonitor.GetMetrics()
|
||||
healthSummary := ds.integrityMonitor.GetHealthSummary()
|
||||
healthHistory := ds.healthChecker.GetRecentSnapshots(20)
|
||||
|
||||
// Render dashboard template
|
||||
tmpl := ds.getDashboardTemplate()
|
||||
data := struct {
|
||||
Metrics MetricsSnapshot
|
||||
HealthSummary map[string]interface{}
|
||||
HealthHistory []HealthSnapshot
|
||||
Timestamp time.Time
|
||||
}{
|
||||
Metrics: metrics,
|
||||
HealthSummary: healthSummary,
|
||||
HealthHistory: healthHistory,
|
||||
Timestamp: time.Now(),
|
||||
}
|
||||
|
||||
if err := tmpl.Execute(w, data); err != nil {
|
||||
ds.logger.Error("Failed to render dashboard", "error", err)
|
||||
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// handleAPIHealth provides JSON health endpoint
|
||||
func (ds *DashboardServer) handleAPIHealth(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
|
||||
healthSummary := ds.integrityMonitor.GetHealthSummary()
|
||||
healthCheckerSummary := ds.healthChecker.GetHealthSummary()
|
||||
|
||||
// Combine summaries
|
||||
response := map[string]interface{}{
|
||||
"integrity_monitor": healthSummary,
|
||||
"health_checker": healthCheckerSummary,
|
||||
"timestamp": time.Now(),
|
||||
}
|
||||
|
||||
if err := json.NewEncoder(w).Encode(response); err != nil {
|
||||
ds.logger.Error("Failed to encode health response", "error", err)
|
||||
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
|
||||
}
|
||||
}
|
||||
|
||||
// handleAPIMetrics provides JSON metrics endpoint
|
||||
func (ds *DashboardServer) handleAPIMetrics(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
|
||||
metrics := ds.integrityMonitor.GetMetrics()
|
||||
|
||||
if err := json.NewEncoder(w).Encode(metrics); err != nil {
|
||||
ds.logger.Error("Failed to encode metrics response", "error", err)
|
||||
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
|
||||
}
|
||||
}
|
||||
|
||||
// handleAPIHistory provides JSON health history endpoint
|
||||
func (ds *DashboardServer) handleAPIHistory(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
|
||||
// Get count parameter (default 20)
|
||||
countStr := r.URL.Query().Get("count")
|
||||
count := 20
|
||||
if countStr != "" {
|
||||
if c, err := strconv.Atoi(countStr); err == nil && c > 0 && c <= 100 {
|
||||
count = c
|
||||
}
|
||||
}
|
||||
|
||||
history := ds.healthChecker.GetRecentSnapshots(count)
|
||||
|
||||
if err := json.NewEncoder(w).Encode(history); err != nil {
|
||||
ds.logger.Error("Failed to encode history response", "error", err)
|
||||
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
|
||||
}
|
||||
}
|
||||
|
||||
// handleAPIAlerts provides recent alerts for integrity and health monitoring.
|
||||
func (ds *DashboardServer) handleAPIAlerts(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
|
||||
limit := 20
|
||||
if q := r.URL.Query().Get("limit"); q != "" {
|
||||
if parsed, err := strconv.Atoi(q); err == nil && parsed > 0 && parsed <= 200 {
|
||||
limit = parsed
|
||||
}
|
||||
}
|
||||
|
||||
alerts := ds.integrityMonitor.GetRecentAlerts(limit)
|
||||
|
||||
payload := map[string]interface{}{
|
||||
"alerts": alerts,
|
||||
"count": len(alerts),
|
||||
"timestamp": time.Now(),
|
||||
}
|
||||
|
||||
if err := json.NewEncoder(w).Encode(payload); err != nil {
|
||||
ds.logger.Error("Failed to encode alerts response", "error", err)
|
||||
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
|
||||
}
|
||||
}
|
||||
|
||||
// handleStatic serves static assets (CSS, JS)
|
||||
func (ds *DashboardServer) handleStatic(w http.ResponseWriter, r *http.Request) {
|
||||
// For simplicity, we'll inline CSS and JS in the HTML template
|
||||
// In a production system, you'd serve actual static files
|
||||
http.NotFound(w, r)
|
||||
}
|
||||
|
||||
// getDashboardTemplate returns the HTML template for the dashboard
|
||||
func (ds *DashboardServer) getDashboardTemplate() *template.Template {
|
||||
htmlTemplate := `
|
||||
<!DOCTYPE html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||
<title>MEV Bot - Data Integrity Monitor</title>
|
||||
<style>
|
||||
* {
|
||||
margin: 0;
|
||||
padding: 0;
|
||||
box-sizing: border-box;
|
||||
}
|
||||
|
||||
body {
|
||||
font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
|
||||
background-color: #f5f5f5;
|
||||
color: #333;
|
||||
line-height: 1.6;
|
||||
}
|
||||
|
||||
.header {
|
||||
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
|
||||
color: white;
|
||||
padding: 1rem 0;
|
||||
box-shadow: 0 2px 4px rgba(0,0,0,0.1);
|
||||
}
|
||||
|
||||
.container {
|
||||
max-width: 1200px;
|
||||
margin: 0 auto;
|
||||
padding: 0 1rem;
|
||||
}
|
||||
|
||||
.header h1 {
|
||||
font-size: 2rem;
|
||||
font-weight: 300;
|
||||
}
|
||||
|
||||
.header .subtitle {
|
||||
opacity: 0.9;
|
||||
margin-top: 0.5rem;
|
||||
}
|
||||
|
||||
.dashboard {
|
||||
padding: 2rem 0;
|
||||
}
|
||||
|
||||
.grid {
|
||||
display: grid;
|
||||
grid-template-columns: repeat(auto-fit, minmax(300px, 1fr));
|
||||
gap: 1.5rem;
|
||||
margin-bottom: 2rem;
|
||||
}
|
||||
|
||||
.card {
|
||||
background: white;
|
||||
border-radius: 8px;
|
||||
padding: 1.5rem;
|
||||
box-shadow: 0 2px 8px rgba(0,0,0,0.1);
|
||||
border-left: 4px solid #667eea;
|
||||
}
|
||||
|
||||
.card h3 {
|
||||
color: #333;
|
||||
margin-bottom: 1rem;
|
||||
font-size: 1.25rem;
|
||||
}
|
||||
|
||||
.metric {
|
||||
display: flex;
|
||||
justify-content: space-between;
|
||||
align-items: center;
|
||||
padding: 0.5rem 0;
|
||||
border-bottom: 1px solid #eee;
|
||||
}
|
||||
|
||||
.metric:last-child {
|
||||
border-bottom: none;
|
||||
}
|
||||
|
||||
.metric-label {
|
||||
font-weight: 500;
|
||||
color: #666;
|
||||
}
|
||||
|
||||
.metric-value {
|
||||
font-weight: 600;
|
||||
color: #333;
|
||||
}
|
||||
|
||||
.health-score {
|
||||
font-size: 2rem;
|
||||
font-weight: bold;
|
||||
text-align: center;
|
||||
padding: 1rem;
|
||||
border-radius: 50%;
|
||||
width: 100px;
|
||||
height: 100px;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
margin: 0 auto 1rem;
|
||||
}
|
||||
|
||||
.health-excellent { background: #4CAF50; color: white; }
|
||||
.health-good { background: #8BC34A; color: white; }
|
||||
.health-fair { background: #FF9800; color: white; }
|
||||
.health-poor { background: #F44336; color: white; }
|
||||
|
||||
.status-indicator {
|
||||
display: inline-block;
|
||||
width: 12px;
|
||||
height: 12px;
|
||||
border-radius: 50%;
|
||||
margin-right: 8px;
|
||||
}
|
||||
|
||||
.status-healthy { background: #4CAF50; }
|
||||
.status-warning { background: #FF9800; }
|
||||
.status-critical { background: #F44336; }
|
||||
|
||||
.chart-container {
|
||||
background: white;
|
||||
border-radius: 8px;
|
||||
padding: 1.5rem;
|
||||
box-shadow: 0 2px 8px rgba(0,0,0,0.1);
|
||||
margin-top: 1.5rem;
|
||||
}
|
||||
|
||||
.refresh-indicator {
|
||||
position: fixed;
|
||||
top: 20px;
|
||||
right: 20px;
|
||||
background: #667eea;
|
||||
color: white;
|
||||
padding: 0.5rem 1rem;
|
||||
border-radius: 4px;
|
||||
font-size: 0.875rem;
|
||||
}
|
||||
|
||||
.timestamp {
|
||||
text-align: center;
|
||||
color: #666;
|
||||
font-size: 0.875rem;
|
||||
margin-top: 2rem;
|
||||
}
|
||||
|
||||
.recovery-actions {
|
||||
display: grid;
|
||||
grid-template-columns: repeat(auto-fit, minmax(200px, 1fr));
|
||||
gap: 1rem;
|
||||
margin-top: 1rem;
|
||||
}
|
||||
|
||||
.recovery-action {
|
||||
background: #f8f9fa;
|
||||
padding: 0.75rem;
|
||||
border-radius: 4px;
|
||||
text-align: center;
|
||||
}
|
||||
|
||||
.recovery-action-count {
|
||||
font-size: 1.5rem;
|
||||
font-weight: bold;
|
||||
color: #667eea;
|
||||
}
|
||||
|
||||
.recovery-action-label {
|
||||
font-size: 0.875rem;
|
||||
color: #666;
|
||||
text-transform: uppercase;
|
||||
}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<div class="header">
|
||||
<div class="container">
|
||||
<h1>🤖 MEV Bot - Data Integrity Monitor</h1>
|
||||
<p class="subtitle">Real-time monitoring of corruption detection and recovery systems</p>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div class="dashboard">
|
||||
<div class="container">
|
||||
<div class="grid">
|
||||
<!-- Health Score Card -->
|
||||
<div class="card">
|
||||
<h3>System Health</h3>
|
||||
<div class="health-score {{.HealthSummary.health_score | healthClass}}">
|
||||
{{.HealthSummary.health_score | printf "%.1f"}}
|
||||
</div>
|
||||
<div class="metric">
|
||||
<span class="metric-label">Status</span>
|
||||
<span class="metric-value">
|
||||
<span class="status-indicator {{.HealthSummary.health_score | statusClass}}"></span>
|
||||
{{.HealthSummary.health_score | healthStatus}}
|
||||
</span>
|
||||
</div>
|
||||
<div class="metric">
|
||||
<span class="metric-label">Monitor Enabled</span>
|
||||
<span class="metric-value">{{if .HealthSummary.enabled}}✅ Yes{{else}}❌ No{{end}}</span>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Processing Statistics -->
|
||||
<div class="card">
|
||||
<h3>Processing Statistics</h3>
|
||||
<div class="metric">
|
||||
<span class="metric-label">Total Addresses</span>
|
||||
<span class="metric-value">{{.Metrics.TotalAddressesProcessed | printf "%,d"}}</span>
|
||||
</div>
|
||||
<div class="metric">
|
||||
<span class="metric-label">Corruption Detected</span>
|
||||
<span class="metric-value">{{.Metrics.CorruptAddressesDetected | printf "%,d"}}</span>
|
||||
</div>
|
||||
<div class="metric">
|
||||
<span class="metric-label">Corruption Rate</span>
|
||||
<span class="metric-value">{{.HealthSummary.corruption_rate | printf "%.4f%%"}}</span>
|
||||
</div>
|
||||
<div class="metric">
|
||||
<span class="metric-label">Avg Corruption Score</span>
|
||||
<span class="metric-value">{{.Metrics.AverageCorruptionScore | printf "%.1f"}}</span>
|
||||
</div>
|
||||
<div class="metric">
|
||||
<span class="metric-label">Max Corruption Score</span>
|
||||
<span class="metric-value">{{.Metrics.MaxCorruptionScore}}</span>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Validation Results -->
|
||||
<div class="card">
|
||||
<h3>Validation Results</h3>
|
||||
<div class="metric">
|
||||
<span class="metric-label">Validation Passed</span>
|
||||
<span class="metric-value">{{.Metrics.AddressValidationPassed | printf "%,d"}}</span>
|
||||
</div>
|
||||
<div class="metric">
|
||||
<span class="metric-label">Validation Failed</span>
|
||||
<span class="metric-value">{{.Metrics.AddressValidationFailed | printf "%,d"}}</span>
|
||||
</div>
|
||||
<div class="metric">
|
||||
<span class="metric-label">Success Rate</span>
|
||||
<span class="metric-value">{{.HealthSummary.validation_success_rate | printf "%.2f%%"}}</span>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Contract Calls -->
|
||||
<div class="card">
|
||||
<h3>Contract Calls</h3>
|
||||
<div class="metric">
|
||||
<span class="metric-label">Successful Calls</span>
|
||||
<span class="metric-value">{{.Metrics.ContractCallsSucceeded | printf "%,d"}}</span>
|
||||
</div>
|
||||
<div class="metric">
|
||||
<span class="metric-label">Failed Calls</span>
|
||||
<span class="metric-value">{{.Metrics.ContractCallsFailed | printf "%,d"}}</span>
|
||||
</div>
|
||||
<div class="metric">
|
||||
<span class="metric-label">Success Rate</span>
|
||||
<span class="metric-value">{{.HealthSummary.contract_call_success_rate | printf "%.2f%%"}}</span>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Recovery Actions -->
|
||||
<div class="chart-container">
|
||||
<h3>Recovery System Activity</h3>
|
||||
<div class="recovery-actions">
|
||||
<div class="recovery-action">
|
||||
<div class="recovery-action-count">{{.Metrics.RetryOperationsTriggered}}</div>
|
||||
<div class="recovery-action-label">Retry Operations</div>
|
||||
</div>
|
||||
<div class="recovery-action">
|
||||
<div class="recovery-action-count">{{.Metrics.FallbackOperationsUsed}}</div>
|
||||
<div class="recovery-action-label">Fallback Used</div>
|
||||
</div>
|
||||
<div class="recovery-action">
|
||||
<div class="recovery-action-count">{{.Metrics.CircuitBreakersTripped}}</div>
|
||||
<div class="recovery-action-label">Circuit Breakers</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div class="timestamp">
|
||||
Last updated: {{.Timestamp.Format "2006-01-02 15:04:05 UTC"}}
|
||||
<br>
|
||||
Auto-refresh every 30 seconds
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div class="refresh-indicator">🔄 Live</div>
|
||||
|
||||
<script>
|
||||
// Auto-refresh every 30 seconds
|
||||
setInterval(function() {
|
||||
window.location.reload();
|
||||
}, 30000);
|
||||
|
||||
// Add smooth transitions
|
||||
document.addEventListener('DOMContentLoaded', function() {
|
||||
const cards = document.querySelectorAll('.card');
|
||||
cards.forEach((card, index) => {
|
||||
card.style.animationDelay = (index * 0.1) + 's';
|
||||
card.style.animation = 'fadeInUp 0.6s ease forwards';
|
||||
});
|
||||
});
|
||||
</script>
|
||||
|
||||
<style>
|
||||
@keyframes fadeInUp {
|
||||
from {
|
||||
opacity: 0;
|
||||
transform: translateY(20px);
|
||||
}
|
||||
to {
|
||||
opacity: 1;
|
||||
transform: translateY(0);
|
||||
}
|
||||
}
|
||||
</style>
|
||||
</body>
|
||||
</html>
|
||||
`
|
||||
|
||||
// Create template with custom functions
|
||||
funcMap := template.FuncMap{
|
||||
"healthClass": func(score interface{}) string {
|
||||
s := score.(float64)
|
||||
if s >= 0.9 {
|
||||
return "health-excellent"
|
||||
} else if s >= 0.7 {
|
||||
return "health-good"
|
||||
} else if s >= 0.5 {
|
||||
return "health-fair"
|
||||
}
|
||||
return "health-poor"
|
||||
},
|
||||
"statusClass": func(score interface{}) string {
|
||||
s := score.(float64)
|
||||
if s >= 0.7 {
|
||||
return "status-healthy"
|
||||
} else if s >= 0.5 {
|
||||
return "status-warning"
|
||||
}
|
||||
return "status-critical"
|
||||
},
|
||||
"healthStatus": func(score interface{}) string {
|
||||
s := score.(float64)
|
||||
if s >= 0.9 {
|
||||
return "Excellent"
|
||||
} else if s >= 0.7 {
|
||||
return "Good"
|
||||
} else if s >= 0.5 {
|
||||
return "Fair"
|
||||
}
|
||||
return "Poor"
|
||||
},
|
||||
}
|
||||
|
||||
return template.Must(template.New("dashboard").Funcs(funcMap).Parse(htmlTemplate))
|
||||
}
|
||||
|
||||
// GetDashboardURL returns the dashboard URL
|
||||
func (ds *DashboardServer) GetDashboardURL() string {
|
||||
return fmt.Sprintf("http://localhost:%d", ds.port)
|
||||
}
|
||||
447
orig/internal/monitoring/health_checker.go
Normal file
447
orig/internal/monitoring/health_checker.go
Normal file
@@ -0,0 +1,447 @@
|
||||
package monitoring
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/fraktal/mev-beta/internal/logger"
|
||||
)
|
||||
|
||||
// HealthCheckRunner performs periodic health checks and monitoring
|
||||
type HealthCheckRunner struct {
|
||||
mu sync.RWMutex
|
||||
logger *logger.Logger
|
||||
integrityMonitor *IntegrityMonitor
|
||||
checkInterval time.Duration
|
||||
running bool
|
||||
stopChan chan struct{}
|
||||
lastHealthCheck time.Time
|
||||
healthHistory []HealthSnapshot
|
||||
maxHistorySize int
|
||||
warmupSamples int
|
||||
minAddressesForAlerts int64
|
||||
}
|
||||
|
||||
// HealthSnapshot represents a point-in-time health snapshot
|
||||
type HealthSnapshot struct {
|
||||
Timestamp time.Time
|
||||
HealthScore float64
|
||||
CorruptionRate float64
|
||||
ValidationSuccess float64
|
||||
ContractCallSuccess float64
|
||||
ActiveAlerts int
|
||||
Trend HealthTrend
|
||||
}
|
||||
|
||||
// HealthTrend indicates the direction of health metrics
|
||||
type HealthTrend int
|
||||
|
||||
const (
|
||||
HealthTrendUnknown HealthTrend = iota
|
||||
HealthTrendImproving
|
||||
HealthTrendStable
|
||||
HealthTrendDeclining
|
||||
HealthTrendCritical
|
||||
)
|
||||
|
||||
func (t HealthTrend) String() string {
|
||||
switch t {
|
||||
case HealthTrendImproving:
|
||||
return "IMPROVING"
|
||||
case HealthTrendStable:
|
||||
return "STABLE"
|
||||
case HealthTrendDeclining:
|
||||
return "DECLINING"
|
||||
case HealthTrendCritical:
|
||||
return "CRITICAL"
|
||||
default:
|
||||
return "UNKNOWN"
|
||||
}
|
||||
}
|
||||
|
||||
// NewHealthCheckRunner creates a new health check runner
|
||||
func NewHealthCheckRunner(logger *logger.Logger, integrityMonitor *IntegrityMonitor) *HealthCheckRunner {
|
||||
return &HealthCheckRunner{
|
||||
logger: logger,
|
||||
integrityMonitor: integrityMonitor,
|
||||
checkInterval: 30 * time.Second, // Default 30 second intervals
|
||||
stopChan: make(chan struct{}),
|
||||
healthHistory: make([]HealthSnapshot, 0),
|
||||
maxHistorySize: 100, // Keep last 100 snapshots (50 minutes at 30s intervals)
|
||||
warmupSamples: 3,
|
||||
minAddressesForAlerts: 25,
|
||||
}
|
||||
}
|
||||
|
||||
// Start begins the periodic health checking routine
|
||||
func (hcr *HealthCheckRunner) Start(ctx context.Context) {
|
||||
hcr.mu.Lock()
|
||||
if hcr.running {
|
||||
hcr.mu.Unlock()
|
||||
return
|
||||
}
|
||||
hcr.running = true
|
||||
hcr.mu.Unlock()
|
||||
|
||||
hcr.logger.Info("Starting health check runner",
|
||||
"interval", hcr.checkInterval)
|
||||
|
||||
go hcr.healthCheckLoop(ctx)
|
||||
}
|
||||
|
||||
// Stop stops the health checking routine
|
||||
func (hcr *HealthCheckRunner) Stop() {
|
||||
hcr.mu.Lock()
|
||||
defer hcr.mu.Unlock()
|
||||
|
||||
if !hcr.running {
|
||||
return
|
||||
}
|
||||
|
||||
hcr.running = false
|
||||
close(hcr.stopChan)
|
||||
hcr.logger.Info("Health check runner stopped")
|
||||
}
|
||||
|
||||
// healthCheckLoop runs the periodic health checking
|
||||
func (hcr *HealthCheckRunner) healthCheckLoop(ctx context.Context) {
|
||||
ticker := time.NewTicker(hcr.checkInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
// Perform initial health check
|
||||
hcr.performHealthCheck()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
hcr.logger.Info("Health check runner stopped due to context cancellation")
|
||||
return
|
||||
case <-hcr.stopChan:
|
||||
hcr.logger.Info("Health check runner stopped")
|
||||
return
|
||||
case <-ticker.C:
|
||||
hcr.performHealthCheck()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// performHealthCheck executes a comprehensive health check
|
||||
func (hcr *HealthCheckRunner) performHealthCheck() {
|
||||
start := time.Now()
|
||||
hcr.lastHealthCheck = start
|
||||
|
||||
if !hcr.integrityMonitor.IsEnabled() {
|
||||
hcr.logger.Debug("Skipping health check - integrity monitor disabled")
|
||||
return
|
||||
}
|
||||
|
||||
// Get current metrics
|
||||
metrics := hcr.integrityMonitor.GetMetrics()
|
||||
healthSummary := hcr.integrityMonitor.GetHealthSummary()
|
||||
|
||||
// Calculate rates
|
||||
corruptionRate := 0.0
|
||||
if metrics.TotalAddressesProcessed > 0 {
|
||||
corruptionRate = float64(metrics.CorruptAddressesDetected) / float64(metrics.TotalAddressesProcessed)
|
||||
}
|
||||
|
||||
validationSuccessRate := 0.0
|
||||
totalValidations := metrics.AddressValidationPassed + metrics.AddressValidationFailed
|
||||
if totalValidations > 0 {
|
||||
validationSuccessRate = float64(metrics.AddressValidationPassed) / float64(totalValidations)
|
||||
}
|
||||
|
||||
contractCallSuccessRate := 0.0
|
||||
totalCalls := metrics.ContractCallsSucceeded + metrics.ContractCallsFailed
|
||||
if totalCalls > 0 {
|
||||
contractCallSuccessRate = float64(metrics.ContractCallsSucceeded) / float64(totalCalls)
|
||||
}
|
||||
|
||||
// Create health snapshot
|
||||
snapshot := HealthSnapshot{
|
||||
Timestamp: start,
|
||||
HealthScore: metrics.HealthScore,
|
||||
CorruptionRate: corruptionRate,
|
||||
ValidationSuccess: validationSuccessRate,
|
||||
ContractCallSuccess: contractCallSuccessRate,
|
||||
ActiveAlerts: 0, // Will be calculated based on current conditions
|
||||
Trend: hcr.calculateHealthTrend(metrics.HealthScore),
|
||||
}
|
||||
|
||||
// Add to history
|
||||
hcr.addHealthSnapshot(snapshot)
|
||||
|
||||
// Check for threshold violations and generate alerts
|
||||
hcr.checkThresholds(healthSummary, snapshot)
|
||||
|
||||
// Log health status periodically
|
||||
hcr.logHealthStatus(snapshot, time.Since(start))
|
||||
}
|
||||
|
||||
// addHealthSnapshot adds a snapshot to the health history
|
||||
func (hcr *HealthCheckRunner) addHealthSnapshot(snapshot HealthSnapshot) {
|
||||
hcr.mu.Lock()
|
||||
defer hcr.mu.Unlock()
|
||||
|
||||
hcr.healthHistory = append(hcr.healthHistory, snapshot)
|
||||
|
||||
// Trim history if it exceeds max size
|
||||
if len(hcr.healthHistory) > hcr.maxHistorySize {
|
||||
hcr.healthHistory = hcr.healthHistory[len(hcr.healthHistory)-hcr.maxHistorySize:]
|
||||
}
|
||||
}
|
||||
|
||||
// calculateHealthTrend analyzes recent health scores to determine trend
|
||||
func (hcr *HealthCheckRunner) calculateHealthTrend(currentScore float64) HealthTrend {
|
||||
hcr.mu.RLock()
|
||||
defer hcr.mu.RUnlock()
|
||||
|
||||
if len(hcr.healthHistory) < 3 {
|
||||
return HealthTrendUnknown
|
||||
}
|
||||
|
||||
// Get last few scores for trend analysis
|
||||
recentScores := make([]float64, 0, 5)
|
||||
start := len(hcr.healthHistory) - 5
|
||||
if start < 0 {
|
||||
start = 0
|
||||
}
|
||||
|
||||
for i := start; i < len(hcr.healthHistory); i++ {
|
||||
recentScores = append(recentScores, hcr.healthHistory[i].HealthScore)
|
||||
}
|
||||
recentScores = append(recentScores, currentScore)
|
||||
|
||||
// Calculate trend
|
||||
if currentScore < 0.5 {
|
||||
return HealthTrendCritical
|
||||
}
|
||||
|
||||
// Simple linear trend calculation
|
||||
if len(recentScores) >= 3 {
|
||||
first := recentScores[0]
|
||||
last := recentScores[len(recentScores)-1]
|
||||
diff := last - first
|
||||
|
||||
if diff > 0.05 {
|
||||
return HealthTrendImproving
|
||||
} else if diff < -0.05 {
|
||||
return HealthTrendDeclining
|
||||
} else {
|
||||
return HealthTrendStable
|
||||
}
|
||||
}
|
||||
|
||||
return HealthTrendUnknown
|
||||
}
|
||||
|
||||
// checkThresholds checks for threshold violations and generates alerts
|
||||
func (hcr *HealthCheckRunner) checkThresholds(healthSummary map[string]interface{}, snapshot HealthSnapshot) {
|
||||
if !hcr.readyForAlerts(healthSummary, snapshot) {
|
||||
hcr.logger.Debug("Health alerts suppressed during warm-up",
|
||||
"health_score", snapshot.HealthScore,
|
||||
"total_addresses_processed", safeNumericLookup(healthSummary, "total_addresses_processed"),
|
||||
"history_size", hcr.historySize())
|
||||
return
|
||||
}
|
||||
|
||||
// Critical health score alert
|
||||
if snapshot.HealthScore < 0.5 {
|
||||
alert := CorruptionAlert{
|
||||
Timestamp: time.Now(),
|
||||
Severity: AlertSeverityEmergency,
|
||||
Message: fmt.Sprintf("CRITICAL: System health score is %.2f (below 0.5)", snapshot.HealthScore),
|
||||
Context: map[string]interface{}{
|
||||
"health_score": snapshot.HealthScore,
|
||||
"corruption_rate": snapshot.CorruptionRate,
|
||||
"validation_success": snapshot.ValidationSuccess,
|
||||
"contract_call_success": snapshot.ContractCallSuccess,
|
||||
"trend": snapshot.Trend.String(),
|
||||
},
|
||||
}
|
||||
hcr.integrityMonitor.sendAlert(alert)
|
||||
}
|
||||
|
||||
// High corruption rate alert
|
||||
if snapshot.CorruptionRate > 0.10 { // 10% corruption rate
|
||||
alert := CorruptionAlert{
|
||||
Timestamp: time.Now(),
|
||||
Severity: AlertSeverityCritical,
|
||||
Message: fmt.Sprintf("High corruption rate detected: %.2f%%", snapshot.CorruptionRate*100),
|
||||
Context: map[string]interface{}{
|
||||
"corruption_rate": snapshot.CorruptionRate,
|
||||
"threshold": 0.10,
|
||||
"addresses_affected": snapshot.CorruptionRate,
|
||||
},
|
||||
}
|
||||
hcr.integrityMonitor.sendAlert(alert)
|
||||
}
|
||||
|
||||
// Declining trend alert
|
||||
if snapshot.Trend == HealthTrendDeclining || snapshot.Trend == HealthTrendCritical {
|
||||
alert := CorruptionAlert{
|
||||
Timestamp: time.Now(),
|
||||
Severity: AlertSeverityWarning,
|
||||
Message: fmt.Sprintf("System health trend is %s (current score: %.2f)", snapshot.Trend.String(), snapshot.HealthScore),
|
||||
Context: map[string]interface{}{
|
||||
"trend": snapshot.Trend.String(),
|
||||
"health_score": snapshot.HealthScore,
|
||||
"recent_snapshots": hcr.getRecentSnapshots(5),
|
||||
},
|
||||
}
|
||||
hcr.integrityMonitor.sendAlert(alert)
|
||||
}
|
||||
}
|
||||
|
||||
func (hcr *HealthCheckRunner) readyForAlerts(healthSummary map[string]interface{}, snapshot HealthSnapshot) bool {
|
||||
hcr.mu.RLock()
|
||||
historyLen := len(hcr.healthHistory)
|
||||
hcr.mu.RUnlock()
|
||||
|
||||
if historyLen < hcr.warmupSamples {
|
||||
return false
|
||||
}
|
||||
|
||||
totalProcessed := safeNumericLookup(healthSummary, "total_addresses_processed")
|
||||
if totalProcessed >= 0 && totalProcessed < float64(hcr.minAddressesForAlerts) {
|
||||
return false
|
||||
}
|
||||
|
||||
// Require at least one validation or contract call attempt before alarming.
|
||||
if snapshot.ValidationSuccess == 0 && snapshot.ContractCallSuccess == 0 && totalProcessed == 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
func safeNumericLookup(summary map[string]interface{}, key string) float64 {
|
||||
if summary == nil {
|
||||
return -1
|
||||
}
|
||||
|
||||
value, ok := summary[key]
|
||||
if !ok {
|
||||
return -1
|
||||
}
|
||||
|
||||
switch v := value.(type) {
|
||||
case int:
|
||||
return float64(v)
|
||||
case int32:
|
||||
return float64(v)
|
||||
case int64:
|
||||
return float64(v)
|
||||
case uint:
|
||||
return float64(v)
|
||||
case uint32:
|
||||
return float64(v)
|
||||
case uint64:
|
||||
return float64(v)
|
||||
case float32:
|
||||
return float64(v)
|
||||
case float64:
|
||||
return v
|
||||
default:
|
||||
return -1
|
||||
}
|
||||
}
|
||||
|
||||
func (hcr *HealthCheckRunner) historySize() int {
|
||||
hcr.mu.RLock()
|
||||
defer hcr.mu.RUnlock()
|
||||
return len(hcr.healthHistory)
|
||||
}
|
||||
|
||||
// logHealthStatus logs periodic health status information
|
||||
func (hcr *HealthCheckRunner) logHealthStatus(snapshot HealthSnapshot, duration time.Duration) {
|
||||
// Log detailed status every 5 minutes (10 checks at 30s intervals)
|
||||
if len(hcr.healthHistory)%10 == 0 {
|
||||
hcr.logger.Info("System health status",
|
||||
"health_score", snapshot.HealthScore,
|
||||
"corruption_rate", fmt.Sprintf("%.4f", snapshot.CorruptionRate),
|
||||
"validation_success", fmt.Sprintf("%.4f", snapshot.ValidationSuccess),
|
||||
"contract_call_success", fmt.Sprintf("%.4f", snapshot.ContractCallSuccess),
|
||||
"trend", snapshot.Trend.String(),
|
||||
"check_duration", duration)
|
||||
} else {
|
||||
// Brief status for regular checks
|
||||
hcr.logger.Debug("Health check completed",
|
||||
"health_score", snapshot.HealthScore,
|
||||
"trend", snapshot.Trend.String(),
|
||||
"duration", duration)
|
||||
}
|
||||
}
|
||||
|
||||
// GetRecentSnapshots returns the most recent health snapshots
|
||||
func (hcr *HealthCheckRunner) GetRecentSnapshots(count int) []HealthSnapshot {
|
||||
return hcr.getRecentSnapshots(count)
|
||||
}
|
||||
|
||||
// getRecentSnapshots internal implementation
|
||||
func (hcr *HealthCheckRunner) getRecentSnapshots(count int) []HealthSnapshot {
|
||||
hcr.mu.RLock()
|
||||
defer hcr.mu.RUnlock()
|
||||
|
||||
if len(hcr.healthHistory) == 0 {
|
||||
return []HealthSnapshot{}
|
||||
}
|
||||
|
||||
start := len(hcr.healthHistory) - count
|
||||
if start < 0 {
|
||||
start = 0
|
||||
}
|
||||
|
||||
// Create a copy to avoid external modification
|
||||
snapshots := make([]HealthSnapshot, len(hcr.healthHistory[start:]))
|
||||
copy(snapshots, hcr.healthHistory[start:])
|
||||
|
||||
return snapshots
|
||||
}
|
||||
|
||||
// GetHealthSummary returns a comprehensive health summary
|
||||
func (hcr *HealthCheckRunner) GetHealthSummary() map[string]interface{} {
|
||||
hcr.mu.RLock()
|
||||
defer hcr.mu.RUnlock()
|
||||
|
||||
if len(hcr.healthHistory) == 0 {
|
||||
return map[string]interface{}{
|
||||
"running": hcr.running,
|
||||
"check_interval": hcr.checkInterval.String(),
|
||||
"history_size": 0,
|
||||
"last_check": nil,
|
||||
}
|
||||
}
|
||||
|
||||
lastSnapshot := hcr.healthHistory[len(hcr.healthHistory)-1]
|
||||
|
||||
return map[string]interface{}{
|
||||
"running": hcr.running,
|
||||
"check_interval": hcr.checkInterval.String(),
|
||||
"history_size": len(hcr.healthHistory),
|
||||
"last_check": hcr.lastHealthCheck,
|
||||
"current_health_score": lastSnapshot.HealthScore,
|
||||
"current_trend": lastSnapshot.Trend.String(),
|
||||
"corruption_rate": lastSnapshot.CorruptionRate,
|
||||
"validation_success": lastSnapshot.ValidationSuccess,
|
||||
"contract_call_success": lastSnapshot.ContractCallSuccess,
|
||||
"recent_snapshots": hcr.getRecentSnapshots(10),
|
||||
}
|
||||
}
|
||||
|
||||
// SetCheckInterval sets the health check interval
|
||||
func (hcr *HealthCheckRunner) SetCheckInterval(interval time.Duration) {
|
||||
hcr.mu.Lock()
|
||||
defer hcr.mu.Unlock()
|
||||
hcr.checkInterval = interval
|
||||
hcr.logger.Info("Health check interval updated", "interval", interval)
|
||||
}
|
||||
|
||||
// IsRunning returns whether the health checker is running
|
||||
func (hcr *HealthCheckRunner) IsRunning() bool {
|
||||
hcr.mu.RLock()
|
||||
defer hcr.mu.RUnlock()
|
||||
return hcr.running
|
||||
}
|
||||
533
orig/internal/monitoring/integrity_monitor.go
Normal file
533
orig/internal/monitoring/integrity_monitor.go
Normal file
@@ -0,0 +1,533 @@
|
||||
package monitoring
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/ethereum/go-ethereum/common"
|
||||
|
||||
"github.com/fraktal/mev-beta/internal/logger"
|
||||
"github.com/fraktal/mev-beta/internal/recovery"
|
||||
)
|
||||
|
||||
// IntegrityMetrics tracks data integrity statistics
|
||||
type IntegrityMetrics struct {
|
||||
mu sync.RWMutex
|
||||
TotalAddressesProcessed int64
|
||||
CorruptAddressesDetected int64
|
||||
AddressValidationPassed int64
|
||||
AddressValidationFailed int64
|
||||
ContractCallsSucceeded int64
|
||||
ContractCallsFailed int64
|
||||
RetryOperationsTriggered int64
|
||||
FallbackOperationsUsed int64
|
||||
CircuitBreakersTripped int64
|
||||
LastCorruptionDetection time.Time
|
||||
AverageCorruptionScore float64
|
||||
MaxCorruptionScore int
|
||||
HealthScore float64
|
||||
HighScore float64
|
||||
RecoveryActions map[recovery.RecoveryAction]int64
|
||||
ErrorsByType map[recovery.ErrorType]int64
|
||||
}
|
||||
|
||||
// MetricsSnapshot represents a copy of metrics without mutex for safe external access
|
||||
type MetricsSnapshot struct {
|
||||
TotalAddressesProcessed int64 `json:"total_addresses_processed"`
|
||||
CorruptAddressesDetected int64 `json:"corrupt_addresses_detected"`
|
||||
AddressValidationPassed int64 `json:"address_validation_passed"`
|
||||
AddressValidationFailed int64 `json:"address_validation_failed"`
|
||||
ContractCallsSucceeded int64 `json:"contract_calls_succeeded"`
|
||||
ContractCallsFailed int64 `json:"contract_calls_failed"`
|
||||
RetryOperationsTriggered int64 `json:"retry_operations_triggered"`
|
||||
FallbackOperationsUsed int64 `json:"fallback_operations_used"`
|
||||
CircuitBreakersTripped int64 `json:"circuit_breakers_tripped"`
|
||||
LastCorruptionDetection time.Time `json:"last_corruption_detection"`
|
||||
AverageCorruptionScore float64 `json:"average_corruption_score"`
|
||||
MaxCorruptionScore int `json:"max_corruption_score"`
|
||||
HealthScore float64 `json:"health_score"`
|
||||
HighScore float64 `json:"high_score"`
|
||||
RecoveryActions map[recovery.RecoveryAction]int64 `json:"recovery_actions"`
|
||||
ErrorsByType map[recovery.ErrorType]int64 `json:"errors_by_type"`
|
||||
}
|
||||
|
||||
// CorruptionAlert represents a corruption detection alert
|
||||
type CorruptionAlert struct {
|
||||
Timestamp time.Time
|
||||
Address common.Address
|
||||
CorruptionScore int
|
||||
Source string
|
||||
Severity AlertSeverity
|
||||
Message string
|
||||
Context map[string]interface{}
|
||||
}
|
||||
|
||||
// AlertSeverity defines alert severity levels
|
||||
type AlertSeverity int
|
||||
|
||||
const (
|
||||
AlertSeverityInfo AlertSeverity = iota
|
||||
AlertSeverityWarning
|
||||
AlertSeverityCritical
|
||||
AlertSeverityEmergency
|
||||
)
|
||||
|
||||
func (s AlertSeverity) String() string {
|
||||
switch s {
|
||||
case AlertSeverityInfo:
|
||||
return "INFO"
|
||||
case AlertSeverityWarning:
|
||||
return "WARNING"
|
||||
case AlertSeverityCritical:
|
||||
return "CRITICAL"
|
||||
case AlertSeverityEmergency:
|
||||
return "EMERGENCY"
|
||||
default:
|
||||
return "UNKNOWN"
|
||||
}
|
||||
}
|
||||
|
||||
// IntegrityMonitor monitors and tracks data integrity metrics
|
||||
type IntegrityMonitor struct {
|
||||
mu sync.RWMutex
|
||||
logger *logger.Logger
|
||||
metrics *IntegrityMetrics
|
||||
alertThresholds map[string]float64
|
||||
alertSubscribers []AlertSubscriber
|
||||
healthCheckRunner *HealthCheckRunner
|
||||
enabled bool
|
||||
alerts []CorruptionAlert
|
||||
alertsMutex sync.RWMutex
|
||||
}
|
||||
|
||||
// AlertSubscriber defines the interface for alert handlers
|
||||
type AlertSubscriber interface {
|
||||
HandleAlert(alert CorruptionAlert) error
|
||||
}
|
||||
|
||||
// NewIntegrityMonitor creates a new integrity monitoring system
|
||||
func NewIntegrityMonitor(logger *logger.Logger) *IntegrityMonitor {
|
||||
monitor := &IntegrityMonitor{
|
||||
logger: logger,
|
||||
metrics: &IntegrityMetrics{
|
||||
RecoveryActions: make(map[recovery.RecoveryAction]int64),
|
||||
ErrorsByType: make(map[recovery.ErrorType]int64),
|
||||
HealthScore: 1.0,
|
||||
HighScore: 1.0,
|
||||
},
|
||||
alertThresholds: make(map[string]float64),
|
||||
enabled: true,
|
||||
alerts: make([]CorruptionAlert, 0, 256),
|
||||
}
|
||||
|
||||
// Set default thresholds
|
||||
monitor.setDefaultThresholds()
|
||||
|
||||
// Initialize health check runner
|
||||
monitor.healthCheckRunner = NewHealthCheckRunner(logger, monitor)
|
||||
|
||||
return monitor
|
||||
}
|
||||
|
||||
// setDefaultThresholds configures default alert thresholds
|
||||
func (im *IntegrityMonitor) setDefaultThresholds() {
|
||||
im.alertThresholds["corruption_rate"] = 0.05 // 5% corruption rate
|
||||
im.alertThresholds["failure_rate"] = 0.10 // 10% failure rate
|
||||
im.alertThresholds["health_score_min"] = 0.80 // 80% minimum health
|
||||
im.alertThresholds["max_corruption_score"] = 70.0 // Maximum individual corruption score
|
||||
im.alertThresholds["circuit_breaker_rate"] = 0.02 // 2% circuit breaker rate
|
||||
}
|
||||
|
||||
// RecordAddressProcessed increments the counter for processed addresses
|
||||
func (im *IntegrityMonitor) RecordAddressProcessed() {
|
||||
if !im.enabled {
|
||||
return
|
||||
}
|
||||
|
||||
im.metrics.mu.Lock()
|
||||
im.metrics.TotalAddressesProcessed++
|
||||
im.metrics.mu.Unlock()
|
||||
|
||||
im.updateHealthScore()
|
||||
}
|
||||
|
||||
// RecordCorruptionDetected records a corruption detection event
|
||||
func (im *IntegrityMonitor) RecordCorruptionDetected(address common.Address, corruptionScore int, source string) {
|
||||
if !im.enabled {
|
||||
return
|
||||
}
|
||||
|
||||
im.metrics.mu.Lock()
|
||||
im.metrics.CorruptAddressesDetected++
|
||||
im.metrics.LastCorruptionDetection = time.Now()
|
||||
|
||||
// Update corruption statistics
|
||||
if corruptionScore > im.metrics.MaxCorruptionScore {
|
||||
im.metrics.MaxCorruptionScore = corruptionScore
|
||||
}
|
||||
|
||||
// Calculate rolling average corruption score
|
||||
total := float64(im.metrics.CorruptAddressesDetected)
|
||||
im.metrics.AverageCorruptionScore = ((im.metrics.AverageCorruptionScore * (total - 1)) + float64(corruptionScore)) / total
|
||||
im.metrics.mu.Unlock()
|
||||
|
||||
// Generate alert based on corruption score
|
||||
severity := im.getCorruptionSeverity(corruptionScore)
|
||||
alert := CorruptionAlert{
|
||||
Timestamp: time.Now(),
|
||||
Address: address,
|
||||
CorruptionScore: corruptionScore,
|
||||
Source: source,
|
||||
Severity: severity,
|
||||
Message: fmt.Sprintf("Corruption detected: address %s, score %d, source %s", address.Hex(), corruptionScore, source),
|
||||
Context: map[string]interface{}{
|
||||
"address": address.Hex(),
|
||||
"corruption_score": corruptionScore,
|
||||
"source": source,
|
||||
"timestamp": time.Now().Unix(),
|
||||
},
|
||||
}
|
||||
|
||||
im.sendAlert(alert)
|
||||
im.updateHealthScore()
|
||||
|
||||
im.logger.Warn("Corruption detected",
|
||||
"address", address.Hex(),
|
||||
"corruption_score", corruptionScore,
|
||||
"source", source,
|
||||
"severity", severity.String())
|
||||
}
|
||||
|
||||
// RecordValidationResult records address validation results
|
||||
func (im *IntegrityMonitor) RecordValidationResult(passed bool) {
|
||||
if !im.enabled {
|
||||
return
|
||||
}
|
||||
|
||||
im.metrics.mu.Lock()
|
||||
if passed {
|
||||
im.metrics.AddressValidationPassed++
|
||||
} else {
|
||||
im.metrics.AddressValidationFailed++
|
||||
}
|
||||
im.metrics.mu.Unlock()
|
||||
|
||||
im.updateHealthScore()
|
||||
}
|
||||
|
||||
// RecordContractCallResult records contract call success/failure
|
||||
func (im *IntegrityMonitor) RecordContractCallResult(succeeded bool) {
|
||||
if !im.enabled {
|
||||
return
|
||||
}
|
||||
|
||||
im.metrics.mu.Lock()
|
||||
if succeeded {
|
||||
im.metrics.ContractCallsSucceeded++
|
||||
} else {
|
||||
im.metrics.ContractCallsFailed++
|
||||
}
|
||||
im.metrics.mu.Unlock()
|
||||
|
||||
im.updateHealthScore()
|
||||
}
|
||||
|
||||
// RecordRecoveryAction records recovery action usage
|
||||
func (im *IntegrityMonitor) RecordRecoveryAction(action recovery.RecoveryAction) {
|
||||
if !im.enabled {
|
||||
return
|
||||
}
|
||||
|
||||
im.metrics.mu.Lock()
|
||||
im.metrics.RecoveryActions[action]++
|
||||
|
||||
// Track specific metrics
|
||||
switch action {
|
||||
case recovery.ActionRetryWithBackoff:
|
||||
im.metrics.RetryOperationsTriggered++
|
||||
case recovery.ActionUseFallbackData:
|
||||
im.metrics.FallbackOperationsUsed++
|
||||
case recovery.ActionCircuitBreaker:
|
||||
im.metrics.CircuitBreakersTripped++
|
||||
}
|
||||
im.metrics.mu.Unlock()
|
||||
|
||||
im.updateHealthScore()
|
||||
}
|
||||
|
||||
// RecordErrorType records error by type
|
||||
func (im *IntegrityMonitor) RecordErrorType(errorType recovery.ErrorType) {
|
||||
if !im.enabled {
|
||||
return
|
||||
}
|
||||
|
||||
im.metrics.mu.Lock()
|
||||
im.metrics.ErrorsByType[errorType]++
|
||||
im.metrics.mu.Unlock()
|
||||
}
|
||||
|
||||
// getCorruptionSeverity determines alert severity based on corruption score
|
||||
func (im *IntegrityMonitor) getCorruptionSeverity(corruptionScore int) AlertSeverity {
|
||||
if corruptionScore >= 90 {
|
||||
return AlertSeverityEmergency
|
||||
} else if corruptionScore >= 70 {
|
||||
return AlertSeverityCritical
|
||||
} else if corruptionScore >= 40 {
|
||||
return AlertSeverityWarning
|
||||
}
|
||||
return AlertSeverityInfo
|
||||
}
|
||||
|
||||
// updateHealthScore calculates overall system health score
|
||||
func (im *IntegrityMonitor) updateHealthScore() {
|
||||
im.metrics.mu.Lock()
|
||||
defer im.metrics.mu.Unlock()
|
||||
|
||||
if im.metrics.TotalAddressesProcessed == 0 {
|
||||
im.metrics.HealthScore = 1.0
|
||||
return
|
||||
}
|
||||
|
||||
// Calculate component scores
|
||||
corruptionRate := float64(im.metrics.CorruptAddressesDetected) / float64(im.metrics.TotalAddressesProcessed)
|
||||
|
||||
var validationSuccessRate float64 = 1.0
|
||||
validationTotal := im.metrics.AddressValidationPassed + im.metrics.AddressValidationFailed
|
||||
if validationTotal > 0 {
|
||||
validationSuccessRate = float64(im.metrics.AddressValidationPassed) / float64(validationTotal)
|
||||
}
|
||||
|
||||
var contractCallSuccessRate float64 = 1.0
|
||||
contractTotal := im.metrics.ContractCallsSucceeded + im.metrics.ContractCallsFailed
|
||||
if contractTotal > 0 {
|
||||
contractCallSuccessRate = float64(im.metrics.ContractCallsSucceeded) / float64(contractTotal)
|
||||
}
|
||||
|
||||
// Weighted health score calculation
|
||||
healthScore := 0.0
|
||||
healthScore += (1.0 - corruptionRate) * 0.4 // 40% weight on corruption prevention
|
||||
healthScore += validationSuccessRate * 0.3 // 30% weight on validation success
|
||||
healthScore += contractCallSuccessRate * 0.3 // 30% weight on contract call success
|
||||
|
||||
// Cap at 1.0 and handle edge cases
|
||||
if healthScore > 1.0 {
|
||||
healthScore = 1.0
|
||||
} else if healthScore < 0.0 {
|
||||
healthScore = 0.0
|
||||
}
|
||||
|
||||
im.metrics.HealthScore = healthScore
|
||||
if healthScore > im.metrics.HighScore {
|
||||
im.metrics.HighScore = healthScore
|
||||
}
|
||||
|
||||
// Check for health score threshold alerts
|
||||
if healthScore < im.alertThresholds["health_score_min"] {
|
||||
alert := CorruptionAlert{
|
||||
Timestamp: time.Now(),
|
||||
Severity: AlertSeverityCritical,
|
||||
Message: fmt.Sprintf("System health score dropped to %.2f (threshold: %.2f)", healthScore, im.alertThresholds["health_score_min"]),
|
||||
Context: map[string]interface{}{
|
||||
"health_score": healthScore,
|
||||
"threshold": im.alertThresholds["health_score_min"],
|
||||
"corruption_rate": corruptionRate,
|
||||
"validation_success": validationSuccessRate,
|
||||
"contract_call_success": contractCallSuccessRate,
|
||||
},
|
||||
}
|
||||
im.sendAlert(alert)
|
||||
}
|
||||
}
|
||||
|
||||
// sendAlert sends alerts to all subscribers
|
||||
func (im *IntegrityMonitor) sendAlert(alert CorruptionAlert) {
|
||||
im.alertsMutex.Lock()
|
||||
im.alerts = append(im.alerts, alert)
|
||||
if len(im.alerts) > 1000 {
|
||||
trimmed := make([]CorruptionAlert, 1000)
|
||||
copy(trimmed, im.alerts[len(im.alerts)-1000:])
|
||||
im.alerts = trimmed
|
||||
}
|
||||
im.alertsMutex.Unlock()
|
||||
|
||||
for _, subscriber := range im.alertSubscribers {
|
||||
if err := subscriber.HandleAlert(alert); err != nil {
|
||||
im.logger.Error("Failed to send alert",
|
||||
"subscriber", fmt.Sprintf("%T", subscriber),
|
||||
"error", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// AddAlertSubscriber adds an alert subscriber
|
||||
func (im *IntegrityMonitor) AddAlertSubscriber(subscriber AlertSubscriber) {
|
||||
im.mu.Lock()
|
||||
defer im.mu.Unlock()
|
||||
im.alertSubscribers = append(im.alertSubscribers, subscriber)
|
||||
}
|
||||
|
||||
// GetMetrics returns a copy of current metrics
|
||||
func (im *IntegrityMonitor) GetMetrics() MetricsSnapshot {
|
||||
im.metrics.mu.RLock()
|
||||
defer im.metrics.mu.RUnlock()
|
||||
|
||||
// Create a deep copy
|
||||
metrics := IntegrityMetrics{
|
||||
TotalAddressesProcessed: im.metrics.TotalAddressesProcessed,
|
||||
CorruptAddressesDetected: im.metrics.CorruptAddressesDetected,
|
||||
AddressValidationPassed: im.metrics.AddressValidationPassed,
|
||||
AddressValidationFailed: im.metrics.AddressValidationFailed,
|
||||
ContractCallsSucceeded: im.metrics.ContractCallsSucceeded,
|
||||
ContractCallsFailed: im.metrics.ContractCallsFailed,
|
||||
RetryOperationsTriggered: im.metrics.RetryOperationsTriggered,
|
||||
FallbackOperationsUsed: im.metrics.FallbackOperationsUsed,
|
||||
CircuitBreakersTripped: im.metrics.CircuitBreakersTripped,
|
||||
LastCorruptionDetection: im.metrics.LastCorruptionDetection,
|
||||
AverageCorruptionScore: im.metrics.AverageCorruptionScore,
|
||||
MaxCorruptionScore: im.metrics.MaxCorruptionScore,
|
||||
HealthScore: im.metrics.HealthScore,
|
||||
HighScore: im.metrics.HighScore,
|
||||
RecoveryActions: make(map[recovery.RecoveryAction]int64),
|
||||
ErrorsByType: make(map[recovery.ErrorType]int64),
|
||||
}
|
||||
|
||||
// Copy maps
|
||||
for k, v := range im.metrics.RecoveryActions {
|
||||
metrics.RecoveryActions[k] = v
|
||||
}
|
||||
for k, v := range im.metrics.ErrorsByType {
|
||||
metrics.ErrorsByType[k] = v
|
||||
}
|
||||
|
||||
// Return a safe copy without mutex
|
||||
return MetricsSnapshot{
|
||||
TotalAddressesProcessed: metrics.TotalAddressesProcessed,
|
||||
CorruptAddressesDetected: metrics.CorruptAddressesDetected,
|
||||
AddressValidationPassed: metrics.AddressValidationPassed,
|
||||
AddressValidationFailed: metrics.AddressValidationFailed,
|
||||
ContractCallsSucceeded: metrics.ContractCallsSucceeded,
|
||||
ContractCallsFailed: metrics.ContractCallsFailed,
|
||||
RetryOperationsTriggered: metrics.RetryOperationsTriggered,
|
||||
FallbackOperationsUsed: metrics.FallbackOperationsUsed,
|
||||
CircuitBreakersTripped: metrics.CircuitBreakersTripped,
|
||||
LastCorruptionDetection: metrics.LastCorruptionDetection,
|
||||
AverageCorruptionScore: metrics.AverageCorruptionScore,
|
||||
MaxCorruptionScore: metrics.MaxCorruptionScore,
|
||||
HealthScore: metrics.HealthScore,
|
||||
HighScore: metrics.HighScore,
|
||||
RecoveryActions: metrics.RecoveryActions,
|
||||
ErrorsByType: metrics.ErrorsByType,
|
||||
}
|
||||
}
|
||||
|
||||
// GetHealthSummary returns a comprehensive health summary
|
||||
func (im *IntegrityMonitor) GetHealthSummary() map[string]interface{} {
|
||||
metrics := im.GetMetrics()
|
||||
|
||||
corruptionRate := 0.0
|
||||
if metrics.TotalAddressesProcessed > 0 {
|
||||
corruptionRate = float64(metrics.CorruptAddressesDetected) / float64(metrics.TotalAddressesProcessed)
|
||||
}
|
||||
|
||||
validationSuccessRate := 0.0
|
||||
totalValidations := metrics.AddressValidationPassed + metrics.AddressValidationFailed
|
||||
if totalValidations > 0 {
|
||||
validationSuccessRate = float64(metrics.AddressValidationPassed) / float64(totalValidations)
|
||||
}
|
||||
|
||||
contractCallSuccessRate := 0.0
|
||||
totalCalls := metrics.ContractCallsSucceeded + metrics.ContractCallsFailed
|
||||
if totalCalls > 0 {
|
||||
contractCallSuccessRate = float64(metrics.ContractCallsSucceeded) / float64(totalCalls)
|
||||
}
|
||||
|
||||
return map[string]interface{}{
|
||||
"enabled": im.enabled,
|
||||
"health_score": metrics.HealthScore,
|
||||
"total_addresses_processed": metrics.TotalAddressesProcessed,
|
||||
"corruption_detections": metrics.CorruptAddressesDetected,
|
||||
"corruption_rate": corruptionRate,
|
||||
"validation_success_rate": validationSuccessRate,
|
||||
"contract_call_success_rate": contractCallSuccessRate,
|
||||
"average_corruption_score": metrics.AverageCorruptionScore,
|
||||
"max_corruption_score": metrics.MaxCorruptionScore,
|
||||
"retry_operations": metrics.RetryOperationsTriggered,
|
||||
"fallback_operations": metrics.FallbackOperationsUsed,
|
||||
"circuit_breakers_tripped": metrics.CircuitBreakersTripped,
|
||||
"last_corruption": metrics.LastCorruptionDetection,
|
||||
"recovery_actions": metrics.RecoveryActions,
|
||||
"errors_by_type": metrics.ErrorsByType,
|
||||
"alert_thresholds": im.alertThresholds,
|
||||
"alert_subscribers": len(im.alertSubscribers),
|
||||
}
|
||||
}
|
||||
|
||||
// GetRecentAlerts returns the most recent corruption alerts up to the specified limit.
|
||||
func (im *IntegrityMonitor) GetRecentAlerts(limit int) []CorruptionAlert {
|
||||
im.alertsMutex.RLock()
|
||||
defer im.alertsMutex.RUnlock()
|
||||
|
||||
if limit <= 0 || limit > len(im.alerts) {
|
||||
limit = len(im.alerts)
|
||||
}
|
||||
|
||||
if limit == 0 {
|
||||
return []CorruptionAlert{}
|
||||
}
|
||||
|
||||
start := len(im.alerts) - limit
|
||||
alertsCopy := make([]CorruptionAlert, limit)
|
||||
copy(alertsCopy, im.alerts[start:])
|
||||
return alertsCopy
|
||||
}
|
||||
|
||||
// SetThreshold sets an alert threshold
|
||||
func (im *IntegrityMonitor) SetThreshold(name string, value float64) {
|
||||
im.mu.Lock()
|
||||
defer im.mu.Unlock()
|
||||
im.alertThresholds[name] = value
|
||||
}
|
||||
|
||||
// Enable enables the integrity monitor
|
||||
func (im *IntegrityMonitor) Enable() {
|
||||
im.mu.Lock()
|
||||
defer im.mu.Unlock()
|
||||
im.enabled = true
|
||||
im.logger.Info("Integrity monitor enabled")
|
||||
}
|
||||
|
||||
// Disable disables the integrity monitor
|
||||
func (im *IntegrityMonitor) Disable() {
|
||||
im.mu.Lock()
|
||||
defer im.mu.Unlock()
|
||||
im.enabled = false
|
||||
im.logger.Info("Integrity monitor disabled")
|
||||
}
|
||||
|
||||
// IsEnabled returns whether the monitor is enabled
|
||||
func (im *IntegrityMonitor) IsEnabled() bool {
|
||||
im.mu.RLock()
|
||||
defer im.mu.RUnlock()
|
||||
return im.enabled
|
||||
}
|
||||
|
||||
// StartHealthCheckRunner starts the periodic health check routine
|
||||
func (im *IntegrityMonitor) StartHealthCheckRunner(ctx context.Context) {
|
||||
if im.healthCheckRunner != nil {
|
||||
im.healthCheckRunner.Start(ctx)
|
||||
}
|
||||
}
|
||||
|
||||
// StopHealthCheckRunner stops the periodic health check routine
|
||||
func (im *IntegrityMonitor) StopHealthCheckRunner() {
|
||||
if im.healthCheckRunner != nil {
|
||||
im.healthCheckRunner.Stop()
|
||||
}
|
||||
}
|
||||
|
||||
// GetHealthCheckRunner returns the health check runner
|
||||
func (im *IntegrityMonitor) GetHealthCheckRunner() *HealthCheckRunner {
|
||||
return im.healthCheckRunner
|
||||
}
|
||||
391
orig/internal/monitoring/integrity_monitor_test.go
Normal file
391
orig/internal/monitoring/integrity_monitor_test.go
Normal file
@@ -0,0 +1,391 @@
|
||||
package monitoring
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/ethereum/go-ethereum/common"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/fraktal/mev-beta/internal/logger"
|
||||
"github.com/fraktal/mev-beta/internal/recovery"
|
||||
)
|
||||
|
||||
// MockAlertSubscriber for testing
|
||||
type MockAlertSubscriber struct {
|
||||
alerts []CorruptionAlert
|
||||
}
|
||||
|
||||
func (m *MockAlertSubscriber) HandleAlert(alert CorruptionAlert) error {
|
||||
m.alerts = append(m.alerts, alert)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MockAlertSubscriber) GetAlerts() []CorruptionAlert {
|
||||
return m.alerts
|
||||
}
|
||||
|
||||
func (m *MockAlertSubscriber) Reset() {
|
||||
m.alerts = nil
|
||||
}
|
||||
|
||||
func TestIntegrityMonitor_RecordCorruptionDetected(t *testing.T) {
|
||||
log := logger.New("error", "text", "")
|
||||
monitor := NewIntegrityMonitor(log)
|
||||
mockSubscriber := &MockAlertSubscriber{}
|
||||
monitor.AddAlertSubscriber(mockSubscriber)
|
||||
|
||||
// Test various corruption scenarios
|
||||
testCases := []struct {
|
||||
name string
|
||||
address string
|
||||
corruptionScore int
|
||||
source string
|
||||
expectedSeverity AlertSeverity
|
||||
}{
|
||||
{
|
||||
name: "Low corruption",
|
||||
address: "0x1234567890123456789012345678901234567890",
|
||||
corruptionScore: 30,
|
||||
source: "test_source",
|
||||
expectedSeverity: AlertSeverityInfo,
|
||||
},
|
||||
{
|
||||
name: "Medium corruption",
|
||||
address: "0x1234000000000000000000000000000000000000",
|
||||
corruptionScore: 50,
|
||||
source: "token_extraction",
|
||||
expectedSeverity: AlertSeverityWarning,
|
||||
},
|
||||
{
|
||||
name: "High corruption",
|
||||
address: "0x0000001000000000000000000000000000000000",
|
||||
corruptionScore: 80,
|
||||
source: "abi_decoder",
|
||||
expectedSeverity: AlertSeverityCritical,
|
||||
},
|
||||
{
|
||||
name: "Critical corruption - TOKEN_0x000000",
|
||||
address: "0x0000000300000000000000000000000000000000",
|
||||
corruptionScore: 100,
|
||||
source: "generic_extraction",
|
||||
expectedSeverity: AlertSeverityEmergency,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
mockSubscriber.Reset()
|
||||
|
||||
addr := common.HexToAddress(tc.address)
|
||||
monitor.RecordCorruptionDetected(addr, tc.corruptionScore, tc.source)
|
||||
|
||||
// Verify metrics were updated
|
||||
metrics := monitor.GetMetrics()
|
||||
assert.Greater(t, metrics.CorruptAddressesDetected, int64(0))
|
||||
assert.GreaterOrEqual(t, metrics.MaxCorruptionScore, tc.corruptionScore)
|
||||
|
||||
// Verify alert was generated
|
||||
alerts := mockSubscriber.GetAlerts()
|
||||
require.Len(t, alerts, 1)
|
||||
|
||||
alert := alerts[0]
|
||||
assert.Equal(t, tc.expectedSeverity, alert.Severity)
|
||||
assert.Equal(t, addr, alert.Address)
|
||||
assert.Equal(t, tc.corruptionScore, alert.CorruptionScore)
|
||||
assert.Equal(t, tc.source, alert.Source)
|
||||
assert.Contains(t, alert.Message, "Corruption detected")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestIntegrityMonitor_HealthScoreCalculation(t *testing.T) {
|
||||
log := logger.New("error", "text", "")
|
||||
monitor := NewIntegrityMonitor(log)
|
||||
|
||||
// Test initial health score
|
||||
metrics := monitor.GetMetrics()
|
||||
assert.Equal(t, 1.0, metrics.HealthScore) // Perfect health initially
|
||||
|
||||
// Record some activity
|
||||
monitor.RecordAddressProcessed()
|
||||
monitor.RecordAddressProcessed()
|
||||
monitor.RecordValidationResult(true)
|
||||
monitor.RecordValidationResult(true)
|
||||
monitor.RecordContractCallResult(true)
|
||||
monitor.RecordContractCallResult(true)
|
||||
|
||||
// Health should still be perfect
|
||||
metrics = monitor.GetMetrics()
|
||||
assert.Equal(t, 1.0, metrics.HealthScore)
|
||||
|
||||
// Introduce some corruption
|
||||
addr := common.HexToAddress("0x0000000300000000000000000000000000000000")
|
||||
monitor.RecordCorruptionDetected(addr, 80, "test")
|
||||
|
||||
// Health score should decrease
|
||||
metrics = monitor.GetMetrics()
|
||||
assert.Less(t, metrics.HealthScore, 1.0)
|
||||
assert.Greater(t, metrics.HealthScore, 0.0)
|
||||
|
||||
// Add validation failures
|
||||
monitor.RecordValidationResult(false)
|
||||
monitor.RecordValidationResult(false)
|
||||
|
||||
// Health should decrease further
|
||||
newMetrics := monitor.GetMetrics()
|
||||
assert.Less(t, newMetrics.HealthScore, metrics.HealthScore)
|
||||
}
|
||||
|
||||
func TestIntegrityMonitor_RecoveryActionTracking(t *testing.T) {
|
||||
log := logger.New("error", "text", "")
|
||||
monitor := NewIntegrityMonitor(log)
|
||||
|
||||
// Record various recovery actions
|
||||
monitor.RecordRecoveryAction(recovery.ActionRetryWithBackoff)
|
||||
monitor.RecordRecoveryAction(recovery.ActionRetryWithBackoff)
|
||||
monitor.RecordRecoveryAction(recovery.ActionUseFallbackData)
|
||||
monitor.RecordRecoveryAction(recovery.ActionCircuitBreaker)
|
||||
|
||||
metrics := monitor.GetMetrics()
|
||||
|
||||
// Verify action counts
|
||||
assert.Equal(t, int64(2), metrics.RecoveryActions[recovery.ActionRetryWithBackoff])
|
||||
assert.Equal(t, int64(1), metrics.RecoveryActions[recovery.ActionUseFallbackData])
|
||||
assert.Equal(t, int64(1), metrics.RecoveryActions[recovery.ActionCircuitBreaker])
|
||||
|
||||
// Verify specific counters
|
||||
assert.Equal(t, int64(2), metrics.RetryOperationsTriggered)
|
||||
assert.Equal(t, int64(1), metrics.FallbackOperationsUsed)
|
||||
assert.Equal(t, int64(1), metrics.CircuitBreakersTripped)
|
||||
}
|
||||
|
||||
func TestIntegrityMonitor_ErrorTypeTracking(t *testing.T) {
|
||||
log := logger.New("error", "text", "")
|
||||
monitor := NewIntegrityMonitor(log)
|
||||
|
||||
// Record various error types
|
||||
errorTypes := []recovery.ErrorType{
|
||||
recovery.ErrorTypeAddressCorruption,
|
||||
recovery.ErrorTypeContractCallFailed,
|
||||
recovery.ErrorTypeRPCConnectionFailed,
|
||||
recovery.ErrorTypeDataParsingFailed,
|
||||
recovery.ErrorTypeValidationFailed,
|
||||
recovery.ErrorTypeAddressCorruption, // Duplicate
|
||||
}
|
||||
|
||||
for _, errorType := range errorTypes {
|
||||
monitor.RecordErrorType(errorType)
|
||||
}
|
||||
|
||||
metrics := monitor.GetMetrics()
|
||||
|
||||
// Verify error type counts
|
||||
assert.Equal(t, int64(2), metrics.ErrorsByType[recovery.ErrorTypeAddressCorruption])
|
||||
assert.Equal(t, int64(1), metrics.ErrorsByType[recovery.ErrorTypeContractCallFailed])
|
||||
assert.Equal(t, int64(1), metrics.ErrorsByType[recovery.ErrorTypeRPCConnectionFailed])
|
||||
assert.Equal(t, int64(1), metrics.ErrorsByType[recovery.ErrorTypeDataParsingFailed])
|
||||
assert.Equal(t, int64(1), metrics.ErrorsByType[recovery.ErrorTypeValidationFailed])
|
||||
}
|
||||
|
||||
func TestIntegrityMonitor_GetHealthSummary(t *testing.T) {
|
||||
log := logger.New("error", "text", "")
|
||||
monitor := NewIntegrityMonitor(log)
|
||||
|
||||
// Generate some activity
|
||||
for i := 0; i < 100; i++ {
|
||||
monitor.RecordAddressProcessed()
|
||||
if i%10 == 0 { // 10% corruption rate
|
||||
addr := common.HexToAddress(fmt.Sprintf("0x%040d", i))
|
||||
monitor.RecordCorruptionDetected(addr, 50, "test")
|
||||
}
|
||||
monitor.RecordValidationResult(i%20 != 0) // 95% success rate
|
||||
monitor.RecordContractCallResult(i%10 != 0) // 90% success rate
|
||||
}
|
||||
|
||||
summary := monitor.GetHealthSummary()
|
||||
|
||||
// Verify summary structure
|
||||
assert.True(t, summary["enabled"].(bool))
|
||||
assert.Equal(t, int64(100), summary["total_addresses_processed"].(int64))
|
||||
assert.Equal(t, int64(10), summary["corruption_detections"].(int64))
|
||||
assert.InDelta(t, 0.1, summary["corruption_rate"].(float64), 0.01)
|
||||
assert.InDelta(t, 0.95, summary["validation_success_rate"].(float64), 0.01)
|
||||
assert.InDelta(t, 0.9, summary["contract_call_success_rate"].(float64), 0.01)
|
||||
|
||||
// Health score should be reasonable
|
||||
healthScore := summary["health_score"].(float64)
|
||||
assert.Greater(t, healthScore, 0.7) // Should be decent despite some issues
|
||||
assert.Less(t, healthScore, 1.0) // Not perfect due to corruption
|
||||
}
|
||||
|
||||
func TestIntegrityMonitor_AlertThresholds(t *testing.T) {
|
||||
log := logger.New("error", "text", "")
|
||||
monitor := NewIntegrityMonitor(log)
|
||||
mockSubscriber := &MockAlertSubscriber{}
|
||||
monitor.AddAlertSubscriber(mockSubscriber)
|
||||
|
||||
// Test health score threshold
|
||||
monitor.SetThreshold("health_score_min", 0.8)
|
||||
|
||||
// Generate activity that drops health below threshold
|
||||
for i := 0; i < 50; i++ {
|
||||
monitor.RecordAddressProcessed()
|
||||
// High corruption rate to drop health score
|
||||
addr := common.HexToAddress(fmt.Sprintf("0x%040d", i))
|
||||
monitor.RecordCorruptionDetected(addr, 80, "test")
|
||||
}
|
||||
|
||||
// Should trigger health score alert
|
||||
alerts := mockSubscriber.GetAlerts()
|
||||
healthAlerts := 0
|
||||
for _, alert := range alerts {
|
||||
if alert.Severity == AlertSeverityCritical &&
|
||||
alert.Context != nil &&
|
||||
alert.Context["health_score"] != nil {
|
||||
healthAlerts++
|
||||
}
|
||||
}
|
||||
assert.Greater(t, healthAlerts, 0, "Should have triggered health score alerts")
|
||||
}
|
||||
|
||||
func TestIntegrityMonitor_ConcurrentAccess(t *testing.T) {
|
||||
log := logger.New("error", "text", "")
|
||||
monitor := NewIntegrityMonitor(log)
|
||||
|
||||
const numGoroutines = 50
|
||||
const operationsPerGoroutine = 100
|
||||
|
||||
done := make(chan bool, numGoroutines)
|
||||
|
||||
// Launch concurrent operations
|
||||
for i := 0; i < numGoroutines; i++ {
|
||||
go func(id int) {
|
||||
defer func() { done <- true }()
|
||||
|
||||
for j := 0; j < operationsPerGoroutine; j++ {
|
||||
// Perform various operations
|
||||
monitor.RecordAddressProcessed()
|
||||
monitor.RecordValidationResult(j%10 != 0)
|
||||
monitor.RecordContractCallResult(j%5 != 0)
|
||||
|
||||
if j%20 == 0 { // Occasional corruption
|
||||
addr := common.HexToAddress(fmt.Sprintf("0x%020d%020d", id, j))
|
||||
monitor.RecordCorruptionDetected(addr, 60, fmt.Sprintf("goroutine_%d", id))
|
||||
}
|
||||
|
||||
// Recovery actions
|
||||
if j%15 == 0 {
|
||||
monitor.RecordRecoveryAction(recovery.ActionRetryWithBackoff)
|
||||
}
|
||||
if j%25 == 0 {
|
||||
monitor.RecordErrorType(recovery.ErrorTypeAddressCorruption)
|
||||
}
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
|
||||
// Wait for completion
|
||||
for i := 0; i < numGoroutines; i++ {
|
||||
select {
|
||||
case <-done:
|
||||
// Success
|
||||
case <-time.After(10 * time.Second):
|
||||
t.Fatal("Concurrent test timed out")
|
||||
}
|
||||
}
|
||||
|
||||
// Verify final metrics are consistent
|
||||
metrics := monitor.GetMetrics()
|
||||
expectedAddresses := int64(numGoroutines * operationsPerGoroutine)
|
||||
assert.Equal(t, expectedAddresses, metrics.TotalAddressesProcessed)
|
||||
|
||||
// Should have some corruption detections
|
||||
assert.Greater(t, metrics.CorruptAddressesDetected, int64(0))
|
||||
|
||||
// Should have recorded recovery actions
|
||||
assert.Greater(t, metrics.RetryOperationsTriggered, int64(0))
|
||||
|
||||
// Health score should be calculated
|
||||
assert.GreaterOrEqual(t, metrics.HealthScore, 0.0)
|
||||
assert.LessOrEqual(t, metrics.HealthScore, 1.0)
|
||||
|
||||
t.Logf("Final metrics: Processed=%d, Corrupted=%d, Health=%.3f",
|
||||
metrics.TotalAddressesProcessed,
|
||||
metrics.CorruptAddressesDetected,
|
||||
metrics.HealthScore)
|
||||
}
|
||||
|
||||
func TestIntegrityMonitor_DisableEnable(t *testing.T) {
|
||||
log := logger.New("error", "text", "")
|
||||
monitor := NewIntegrityMonitor(log)
|
||||
|
||||
// Should be enabled by default
|
||||
assert.True(t, monitor.IsEnabled())
|
||||
|
||||
// Record some activity
|
||||
monitor.RecordAddressProcessed()
|
||||
monitor.RecordValidationResult(true)
|
||||
|
||||
initialMetrics := monitor.GetMetrics()
|
||||
assert.Greater(t, initialMetrics.TotalAddressesProcessed, int64(0))
|
||||
|
||||
// Disable monitor
|
||||
monitor.Disable()
|
||||
assert.False(t, monitor.IsEnabled())
|
||||
|
||||
// Activity should not be recorded when disabled
|
||||
monitor.RecordAddressProcessed()
|
||||
monitor.RecordValidationResult(true)
|
||||
|
||||
disabledMetrics := monitor.GetMetrics()
|
||||
assert.Equal(t, initialMetrics.TotalAddressesProcessed, disabledMetrics.TotalAddressesProcessed)
|
||||
|
||||
// Re-enable
|
||||
monitor.Enable()
|
||||
assert.True(t, monitor.IsEnabled())
|
||||
|
||||
// Activity should be recorded again
|
||||
monitor.RecordAddressProcessed()
|
||||
enabledMetrics := monitor.GetMetrics()
|
||||
assert.Greater(t, enabledMetrics.TotalAddressesProcessed, disabledMetrics.TotalAddressesProcessed)
|
||||
}
|
||||
|
||||
func TestIntegrityMonitor_Performance(t *testing.T) {
|
||||
log := logger.New("error", "text", "")
|
||||
monitor := NewIntegrityMonitor(log)
|
||||
|
||||
const iterations = 10000
|
||||
|
||||
// Benchmark recording operations
|
||||
start := time.Now()
|
||||
|
||||
for i := 0; i < iterations; i++ {
|
||||
monitor.RecordAddressProcessed()
|
||||
monitor.RecordValidationResult(i%10 != 0)
|
||||
monitor.RecordContractCallResult(i%5 != 0)
|
||||
|
||||
if i%100 == 0 {
|
||||
addr := common.HexToAddress(fmt.Sprintf("0x%040d", i))
|
||||
monitor.RecordCorruptionDetected(addr, 50, "benchmark")
|
||||
}
|
||||
}
|
||||
|
||||
duration := time.Since(start)
|
||||
avgTime := duration / iterations
|
||||
|
||||
t.Logf("Performance: %d operations in %v (avg: %v per operation)",
|
||||
iterations, duration, avgTime)
|
||||
|
||||
// Should be reasonably fast (under 500 microseconds per operation is acceptable)
|
||||
maxTime := 500 * time.Microsecond
|
||||
assert.Less(t, avgTime.Nanoseconds(), maxTime.Nanoseconds(),
|
||||
"Recording should be faster than %v per operation (got %v)", maxTime, avgTime)
|
||||
|
||||
// Verify metrics are accurate
|
||||
metrics := monitor.GetMetrics()
|
||||
assert.Equal(t, int64(iterations), metrics.TotalAddressesProcessed)
|
||||
assert.Equal(t, int64(100), metrics.CorruptAddressesDetected) // Every 100th iteration
|
||||
}
|
||||
494
orig/internal/ratelimit/adaptive.go
Normal file
494
orig/internal/ratelimit/adaptive.go
Normal file
@@ -0,0 +1,494 @@
|
||||
package ratelimit
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"golang.org/x/time/rate"
|
||||
|
||||
"github.com/fraktal/mev-beta/internal/config"
|
||||
"github.com/fraktal/mev-beta/internal/logger"
|
||||
)
|
||||
|
||||
// AdaptiveRateLimiter implements adaptive rate limiting that adjusts to endpoint capacity
|
||||
type AdaptiveRateLimiter struct {
|
||||
endpoints map[string]*AdaptiveEndpoint
|
||||
mu sync.RWMutex
|
||||
logger *logger.Logger
|
||||
defaultConfig config.RateLimitConfig
|
||||
adjustInterval time.Duration
|
||||
stopChan chan struct{}
|
||||
}
|
||||
|
||||
// AdaptiveEndpoint represents an endpoint with adaptive rate limiting
|
||||
type AdaptiveEndpoint struct {
|
||||
URL string
|
||||
limiter *rate.Limiter
|
||||
config config.RateLimitConfig
|
||||
circuitBreaker *CircuitBreaker
|
||||
metrics *EndpointMetrics
|
||||
healthChecker *HealthChecker
|
||||
lastAdjustment time.Time
|
||||
consecutiveErrors int64
|
||||
consecutiveSuccess int64
|
||||
}
|
||||
|
||||
// EndpointMetrics tracks performance metrics for an endpoint
|
||||
// All fields must be 64-bit aligned for atomic access
|
||||
type EndpointMetrics struct {
|
||||
TotalRequests int64
|
||||
SuccessfulRequests int64
|
||||
FailedRequests int64
|
||||
TotalLatency int64 // nanoseconds
|
||||
LastRequestTime int64 // unix timestamp
|
||||
// Non-atomic fields - must be protected by mutex when accessed
|
||||
mu sync.RWMutex
|
||||
SuccessRate float64
|
||||
AverageLatency float64 // milliseconds
|
||||
}
|
||||
|
||||
// CircuitBreaker implements circuit breaker pattern for failed endpoints
|
||||
type CircuitBreaker struct {
|
||||
state int32 // 0: Closed, 1: Open, 2: HalfOpen
|
||||
failureCount int64
|
||||
lastFailTime int64
|
||||
threshold int64
|
||||
timeout time.Duration // How long to wait before trying again
|
||||
testRequests int64 // Number of test requests in half-open state
|
||||
}
|
||||
|
||||
// Circuit breaker states
|
||||
const (
|
||||
CircuitClosed = 0
|
||||
CircuitOpen = 1
|
||||
CircuitHalfOpen = 2
|
||||
)
|
||||
|
||||
// HealthChecker monitors endpoint health
|
||||
type HealthChecker struct {
|
||||
endpoint string
|
||||
interval time.Duration
|
||||
timeout time.Duration
|
||||
isHealthy int64 // atomic bool
|
||||
lastCheck int64 // unix timestamp
|
||||
stopChan chan struct{}
|
||||
}
|
||||
|
||||
// NewAdaptiveRateLimiter creates a new adaptive rate limiter
|
||||
func NewAdaptiveRateLimiter(cfg *config.ArbitrumConfig, logger *logger.Logger) *AdaptiveRateLimiter {
|
||||
arl := &AdaptiveRateLimiter{
|
||||
endpoints: make(map[string]*AdaptiveEndpoint),
|
||||
logger: logger,
|
||||
defaultConfig: cfg.RateLimit,
|
||||
adjustInterval: 30 * time.Second,
|
||||
stopChan: make(chan struct{}),
|
||||
}
|
||||
|
||||
// Create adaptive endpoint for primary endpoint
|
||||
arl.addEndpoint(cfg.RPCEndpoint, cfg.RateLimit)
|
||||
|
||||
// Create adaptive endpoints for reading endpoints
|
||||
for _, endpoint := range cfg.ReadingEndpoints {
|
||||
arl.addEndpoint(endpoint.URL, endpoint.RateLimit)
|
||||
}
|
||||
|
||||
// Create adaptive endpoints for execution endpoints
|
||||
for _, endpoint := range cfg.ExecutionEndpoints {
|
||||
arl.addEndpoint(endpoint.URL, endpoint.RateLimit)
|
||||
}
|
||||
|
||||
// Start background adjustment routine
|
||||
go arl.adjustmentLoop()
|
||||
|
||||
return arl
|
||||
}
|
||||
|
||||
// addEndpoint adds a new adaptive endpoint
|
||||
func (arl *AdaptiveRateLimiter) addEndpoint(url string, config config.RateLimitConfig) {
|
||||
endpoint := &AdaptiveEndpoint{
|
||||
URL: url,
|
||||
limiter: rate.NewLimiter(rate.Limit(config.RequestsPerSecond), config.Burst),
|
||||
config: config,
|
||||
circuitBreaker: &CircuitBreaker{
|
||||
threshold: 10, // Break after 10 consecutive failures
|
||||
timeout: 60 * time.Second,
|
||||
},
|
||||
metrics: &EndpointMetrics{},
|
||||
healthChecker: &HealthChecker{
|
||||
endpoint: url,
|
||||
interval: 30 * time.Second,
|
||||
timeout: 5 * time.Second,
|
||||
isHealthy: 1, // Start assuming healthy
|
||||
stopChan: make(chan struct{}),
|
||||
},
|
||||
}
|
||||
|
||||
arl.mu.Lock()
|
||||
arl.endpoints[url] = endpoint
|
||||
arl.mu.Unlock()
|
||||
|
||||
// Start health checker for this endpoint
|
||||
go endpoint.healthChecker.start()
|
||||
|
||||
arl.logger.Info(fmt.Sprintf("Added adaptive rate limiter for endpoint: %s", url))
|
||||
}
|
||||
|
||||
// WaitForBestEndpoint waits for the best available endpoint
|
||||
func (arl *AdaptiveRateLimiter) WaitForBestEndpoint(ctx context.Context) (string, error) {
|
||||
// Find the best available endpoint
|
||||
bestEndpoint := arl.getBestEndpoint()
|
||||
if bestEndpoint == "" {
|
||||
return "", fmt.Errorf("no healthy endpoints available")
|
||||
}
|
||||
|
||||
// Wait for rate limiter
|
||||
arl.mu.RLock()
|
||||
endpoint := arl.endpoints[bestEndpoint]
|
||||
arl.mu.RUnlock()
|
||||
|
||||
if endpoint == nil {
|
||||
return "", fmt.Errorf("endpoint not found: %s", bestEndpoint)
|
||||
}
|
||||
|
||||
// Check circuit breaker
|
||||
if !endpoint.circuitBreaker.canExecute() {
|
||||
return "", fmt.Errorf("circuit breaker open for endpoint: %s", bestEndpoint)
|
||||
}
|
||||
|
||||
// Wait for rate limiter
|
||||
err := endpoint.limiter.Wait(ctx)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return bestEndpoint, nil
|
||||
}
|
||||
|
||||
// RecordResult records the result of a request for adaptive adjustment
|
||||
func (arl *AdaptiveRateLimiter) RecordResult(endpointURL string, success bool, latency time.Duration) {
|
||||
arl.mu.RLock()
|
||||
endpoint, exists := arl.endpoints[endpointURL]
|
||||
arl.mu.RUnlock()
|
||||
|
||||
if !exists {
|
||||
return
|
||||
}
|
||||
|
||||
// Update metrics atomically
|
||||
atomic.AddInt64(&endpoint.metrics.TotalRequests, 1)
|
||||
atomic.AddInt64(&endpoint.metrics.TotalLatency, latency.Nanoseconds())
|
||||
atomic.StoreInt64(&endpoint.metrics.LastRequestTime, time.Now().Unix())
|
||||
|
||||
if success {
|
||||
atomic.AddInt64(&endpoint.metrics.SuccessfulRequests, 1)
|
||||
atomic.AddInt64(&endpoint.consecutiveSuccess, 1)
|
||||
atomic.StoreInt64(&endpoint.consecutiveErrors, 0)
|
||||
endpoint.circuitBreaker.recordSuccess()
|
||||
} else {
|
||||
atomic.AddInt64(&endpoint.metrics.FailedRequests, 1)
|
||||
atomic.AddInt64(&endpoint.consecutiveErrors, 1)
|
||||
atomic.StoreInt64(&endpoint.consecutiveSuccess, 0)
|
||||
endpoint.circuitBreaker.recordFailure()
|
||||
}
|
||||
|
||||
// Update calculated metrics
|
||||
arl.updateCalculatedMetrics(endpoint)
|
||||
}
|
||||
|
||||
// updateCalculatedMetrics updates derived metrics
|
||||
func (arl *AdaptiveRateLimiter) updateCalculatedMetrics(endpoint *AdaptiveEndpoint) {
|
||||
totalReq := atomic.LoadInt64(&endpoint.metrics.TotalRequests)
|
||||
successReq := atomic.LoadInt64(&endpoint.metrics.SuccessfulRequests)
|
||||
totalLatency := atomic.LoadInt64(&endpoint.metrics.TotalLatency)
|
||||
|
||||
if totalReq > 0 {
|
||||
endpoint.metrics.SuccessRate = float64(successReq) / float64(totalReq)
|
||||
endpoint.metrics.AverageLatency = float64(totalLatency) / float64(totalReq) / 1000000 // Convert to milliseconds
|
||||
}
|
||||
}
|
||||
|
||||
// getBestEndpoint selects the best available endpoint based on metrics
|
||||
func (arl *AdaptiveRateLimiter) getBestEndpoint() string {
|
||||
arl.mu.RLock()
|
||||
defer arl.mu.RUnlock()
|
||||
|
||||
bestEndpoint := ""
|
||||
bestScore := float64(-1)
|
||||
|
||||
for url, endpoint := range arl.endpoints {
|
||||
// Skip unhealthy endpoints
|
||||
if atomic.LoadInt64(&endpoint.healthChecker.isHealthy) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
// Skip if circuit breaker is open
|
||||
if !endpoint.circuitBreaker.canExecute() {
|
||||
continue
|
||||
}
|
||||
|
||||
// Calculate score based on success rate, latency, and current load
|
||||
score := arl.calculateEndpointScore(endpoint)
|
||||
if score > bestScore {
|
||||
bestScore = score
|
||||
bestEndpoint = url
|
||||
}
|
||||
}
|
||||
|
||||
return bestEndpoint
|
||||
}
|
||||
|
||||
// updateDerivedMetrics safely updates calculated metrics with proper synchronization
|
||||
func (em *EndpointMetrics) updateDerivedMetrics() {
|
||||
totalRequests := atomic.LoadInt64(&em.TotalRequests)
|
||||
successfulRequests := atomic.LoadInt64(&em.SuccessfulRequests)
|
||||
totalLatency := atomic.LoadInt64(&em.TotalLatency)
|
||||
|
||||
em.mu.Lock()
|
||||
defer em.mu.Unlock()
|
||||
|
||||
// Calculate success rate
|
||||
if totalRequests > 0 {
|
||||
em.SuccessRate = float64(successfulRequests) / float64(totalRequests)
|
||||
} else {
|
||||
em.SuccessRate = 0.0
|
||||
}
|
||||
|
||||
// Calculate average latency in milliseconds
|
||||
if totalRequests > 0 {
|
||||
em.AverageLatency = float64(totalLatency) / float64(totalRequests) / 1e6 // ns to ms
|
||||
} else {
|
||||
em.AverageLatency = 0.0
|
||||
}
|
||||
}
|
||||
|
||||
// getCalculatedMetrics safely returns derived metrics
|
||||
func (em *EndpointMetrics) getCalculatedMetrics() (float64, float64) {
|
||||
em.mu.RLock()
|
||||
defer em.mu.RUnlock()
|
||||
return em.SuccessRate, em.AverageLatency
|
||||
}
|
||||
|
||||
// calculateEndpointScore calculates a score for endpoint selection
|
||||
func (arl *AdaptiveRateLimiter) calculateEndpointScore(endpoint *AdaptiveEndpoint) float64 {
|
||||
// Base score on success rate (0-1)
|
||||
successWeight := 0.6
|
||||
latencyWeight := 0.3
|
||||
loadWeight := 0.1
|
||||
|
||||
// Update derived metrics first
|
||||
endpoint.metrics.updateDerivedMetrics()
|
||||
|
||||
// Get calculated metrics safely
|
||||
successScore, avgLatency := endpoint.metrics.getCalculatedMetrics()
|
||||
|
||||
// Invert latency score (lower latency = higher score)
|
||||
latencyScore := 1.0
|
||||
if avgLatency > 0 {
|
||||
// Normalize latency score (assuming 1000ms is poor, 100ms is good)
|
||||
latencyScore = 1.0 - (avgLatency / 1000.0)
|
||||
if latencyScore < 0 {
|
||||
latencyScore = 0
|
||||
}
|
||||
}
|
||||
|
||||
// Load score based on current rate limiter state
|
||||
loadScore := 1.0 // Simplified - could check current tokens in limiter
|
||||
|
||||
return successScore*successWeight + latencyScore*latencyWeight + loadScore*loadWeight
|
||||
}
|
||||
|
||||
// adjustmentLoop runs periodic adjustments to rate limits
|
||||
func (arl *AdaptiveRateLimiter) adjustmentLoop() {
|
||||
ticker := time.NewTicker(arl.adjustInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
arl.adjustRateLimits()
|
||||
case <-arl.stopChan:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// adjustRateLimits adjusts rate limits based on observed performance
|
||||
func (arl *AdaptiveRateLimiter) adjustRateLimits() {
|
||||
arl.mu.Lock()
|
||||
defer arl.mu.Unlock()
|
||||
|
||||
for url, endpoint := range arl.endpoints {
|
||||
arl.adjustEndpointRateLimit(url, endpoint)
|
||||
}
|
||||
}
|
||||
|
||||
// adjustEndpointRateLimit adjusts rate limit for a specific endpoint
|
||||
func (arl *AdaptiveRateLimiter) adjustEndpointRateLimit(url string, endpoint *AdaptiveEndpoint) {
|
||||
// Don't adjust too frequently
|
||||
if time.Since(endpoint.lastAdjustment) < arl.adjustInterval {
|
||||
return
|
||||
}
|
||||
|
||||
successRate := endpoint.metrics.SuccessRate
|
||||
avgLatency := endpoint.metrics.AverageLatency
|
||||
currentLimit := float64(endpoint.limiter.Limit())
|
||||
|
||||
var newLimit float64 = currentLimit
|
||||
adjustmentFactor := 0.1 // 10% adjustment
|
||||
|
||||
// Increase rate if performing well
|
||||
if successRate > 0.95 && avgLatency < 500 { // 95% success, < 500ms latency
|
||||
newLimit = currentLimit * (1.0 + adjustmentFactor)
|
||||
} else if successRate < 0.8 || avgLatency > 2000 { // < 80% success or > 2s latency
|
||||
newLimit = currentLimit * (1.0 - adjustmentFactor)
|
||||
}
|
||||
|
||||
// Apply bounds
|
||||
minLimit := float64(arl.defaultConfig.RequestsPerSecond) * 0.1 // 10% of default minimum
|
||||
maxLimit := float64(arl.defaultConfig.RequestsPerSecond) * 3.0 // 300% of default maximum
|
||||
|
||||
if newLimit < minLimit {
|
||||
newLimit = minLimit
|
||||
}
|
||||
if newLimit > maxLimit {
|
||||
newLimit = maxLimit
|
||||
}
|
||||
|
||||
// Update if changed significantly
|
||||
if abs(newLimit-currentLimit)/currentLimit > 0.05 { // 5% change threshold
|
||||
endpoint.limiter.SetLimit(rate.Limit(newLimit))
|
||||
endpoint.lastAdjustment = time.Now()
|
||||
|
||||
arl.logger.Info(fmt.Sprintf("Adjusted rate limit for %s: %.2f -> %.2f (success: %.2f%%, latency: %.2fms)",
|
||||
url, currentLimit, newLimit, successRate*100, avgLatency))
|
||||
}
|
||||
}
|
||||
|
||||
// abs returns absolute value of float64
|
||||
func abs(x float64) float64 {
|
||||
if x < 0 {
|
||||
return -x
|
||||
}
|
||||
return x
|
||||
}
|
||||
|
||||
// canExecute checks if circuit breaker allows execution
|
||||
func (cb *CircuitBreaker) canExecute() bool {
|
||||
state := atomic.LoadInt32(&cb.state)
|
||||
now := time.Now().Unix()
|
||||
|
||||
switch state {
|
||||
case CircuitClosed:
|
||||
return true
|
||||
case CircuitOpen:
|
||||
// Check if timeout has passed
|
||||
lastFail := atomic.LoadInt64(&cb.lastFailTime)
|
||||
if now-lastFail > int64(cb.timeout.Seconds()) {
|
||||
// Try to move to half-open
|
||||
if atomic.CompareAndSwapInt32(&cb.state, CircuitOpen, CircuitHalfOpen) {
|
||||
atomic.StoreInt64(&cb.testRequests, 0)
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
case CircuitHalfOpen:
|
||||
// Allow limited test requests
|
||||
testReq := atomic.LoadInt64(&cb.testRequests)
|
||||
if testReq < 3 { // Allow up to 3 test requests
|
||||
atomic.AddInt64(&cb.testRequests, 1)
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// recordSuccess records a successful request
|
||||
func (cb *CircuitBreaker) recordSuccess() {
|
||||
state := atomic.LoadInt32(&cb.state)
|
||||
if state == CircuitHalfOpen {
|
||||
// Move back to closed after successful test
|
||||
atomic.StoreInt32(&cb.state, CircuitClosed)
|
||||
atomic.StoreInt64(&cb.failureCount, 0)
|
||||
}
|
||||
}
|
||||
|
||||
// recordFailure records a failed request
|
||||
func (cb *CircuitBreaker) recordFailure() {
|
||||
failures := atomic.AddInt64(&cb.failureCount, 1)
|
||||
atomic.StoreInt64(&cb.lastFailTime, time.Now().Unix())
|
||||
|
||||
if failures >= cb.threshold {
|
||||
atomic.StoreInt32(&cb.state, CircuitOpen)
|
||||
}
|
||||
}
|
||||
|
||||
// start starts the health checker
|
||||
func (hc *HealthChecker) start() {
|
||||
ticker := time.NewTicker(hc.interval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
hc.checkHealth()
|
||||
case <-hc.stopChan:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// checkHealth performs a health check on the endpoint
|
||||
func (hc *HealthChecker) checkHealth() {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), hc.timeout)
|
||||
defer cancel()
|
||||
|
||||
// Simple health check - try to connect
|
||||
// In production, this might make a simple RPC call
|
||||
healthy := hc.performHealthCheck(ctx)
|
||||
|
||||
if healthy {
|
||||
atomic.StoreInt64(&hc.isHealthy, 1)
|
||||
} else {
|
||||
atomic.StoreInt64(&hc.isHealthy, 0)
|
||||
}
|
||||
|
||||
atomic.StoreInt64(&hc.lastCheck, time.Now().Unix())
|
||||
}
|
||||
|
||||
// performHealthCheck performs the actual health check
|
||||
func (hc *HealthChecker) performHealthCheck(ctx context.Context) bool {
|
||||
// Simplified health check - in production would make actual RPC call
|
||||
// For now, just simulate based on endpoint availability
|
||||
return true // Assume healthy for demo
|
||||
}
|
||||
|
||||
// Stop stops the adaptive rate limiter
|
||||
func (arl *AdaptiveRateLimiter) Stop() {
|
||||
close(arl.stopChan)
|
||||
|
||||
// Stop all health checkers
|
||||
arl.mu.RLock()
|
||||
for _, endpoint := range arl.endpoints {
|
||||
close(endpoint.healthChecker.stopChan)
|
||||
}
|
||||
arl.mu.RUnlock()
|
||||
}
|
||||
|
||||
// GetMetrics returns current metrics for all endpoints
|
||||
func (arl *AdaptiveRateLimiter) GetMetrics() map[string]*EndpointMetrics {
|
||||
arl.mu.RLock()
|
||||
defer arl.mu.RUnlock()
|
||||
|
||||
metrics := make(map[string]*EndpointMetrics)
|
||||
for url, endpoint := range arl.endpoints {
|
||||
// Update calculated metrics before returning
|
||||
arl.updateCalculatedMetrics(endpoint)
|
||||
metrics[url] = endpoint.metrics
|
||||
}
|
||||
|
||||
return metrics
|
||||
}
|
||||
139
orig/internal/ratelimit/manager.go
Normal file
139
orig/internal/ratelimit/manager.go
Normal file
@@ -0,0 +1,139 @@
|
||||
package ratelimit
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
|
||||
"golang.org/x/time/rate"
|
||||
|
||||
"github.com/fraktal/mev-beta/internal/config"
|
||||
)
|
||||
|
||||
// LimiterManager manages rate limiters for multiple endpoints
|
||||
type LimiterManager struct {
|
||||
limiters map[string]*EndpointLimiter
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
// EndpointLimiter represents a rate limiter for a specific endpoint
|
||||
type EndpointLimiter struct {
|
||||
URL string
|
||||
Limiter *rate.Limiter
|
||||
Config config.RateLimitConfig
|
||||
}
|
||||
|
||||
// NewLimiterManager creates a new LimiterManager
|
||||
func NewLimiterManager(cfg *config.ArbitrumConfig) *LimiterManager {
|
||||
lm := &LimiterManager{
|
||||
limiters: make(map[string]*EndpointLimiter),
|
||||
}
|
||||
|
||||
// Create limiter for primary endpoint
|
||||
limiter := createLimiter(cfg.RateLimit)
|
||||
lm.limiters[cfg.RPCEndpoint] = &EndpointLimiter{
|
||||
URL: cfg.RPCEndpoint,
|
||||
Limiter: limiter,
|
||||
Config: cfg.RateLimit,
|
||||
}
|
||||
|
||||
// Create limiters for reading endpoints
|
||||
for _, endpoint := range cfg.ReadingEndpoints {
|
||||
limiter := createLimiter(endpoint.RateLimit)
|
||||
lm.limiters[endpoint.URL] = &EndpointLimiter{
|
||||
URL: endpoint.URL,
|
||||
Limiter: limiter,
|
||||
Config: endpoint.RateLimit,
|
||||
}
|
||||
}
|
||||
|
||||
// Create limiters for execution endpoints
|
||||
for _, endpoint := range cfg.ExecutionEndpoints {
|
||||
limiter := createLimiter(endpoint.RateLimit)
|
||||
lm.limiters[endpoint.URL] = &EndpointLimiter{
|
||||
URL: endpoint.URL,
|
||||
Limiter: limiter,
|
||||
Config: endpoint.RateLimit,
|
||||
}
|
||||
}
|
||||
|
||||
return lm
|
||||
}
|
||||
|
||||
// createLimiter creates a rate limiter based on the configuration
|
||||
func createLimiter(cfg config.RateLimitConfig) *rate.Limiter {
|
||||
// Create a rate limiter with the specified rate and burst
|
||||
r := rate.Limit(cfg.RequestsPerSecond)
|
||||
return rate.NewLimiter(r, cfg.Burst)
|
||||
}
|
||||
|
||||
// WaitForLimit waits for the rate limiter to allow a request
|
||||
func (lm *LimiterManager) WaitForLimit(ctx context.Context, endpointURL string) error {
|
||||
lm.mu.RLock()
|
||||
limiter, exists := lm.limiters[endpointURL]
|
||||
lm.mu.RUnlock()
|
||||
|
||||
if !exists {
|
||||
return fmt.Errorf("no rate limiter found for endpoint: %s", endpointURL)
|
||||
}
|
||||
|
||||
// Wait for permission to make a request
|
||||
return limiter.Limiter.Wait(ctx)
|
||||
}
|
||||
|
||||
// TryWaitForLimit tries to wait for the rate limiter to allow a request without blocking
|
||||
func (lm *LimiterManager) TryWaitForLimit(ctx context.Context, endpointURL string) error {
|
||||
lm.mu.RLock()
|
||||
limiter, exists := lm.limiters[endpointURL]
|
||||
lm.mu.RUnlock()
|
||||
|
||||
if !exists {
|
||||
return fmt.Errorf("no rate limiter found for endpoint: %s", endpointURL)
|
||||
}
|
||||
|
||||
// Try to wait for permission to make a request without blocking
|
||||
if !limiter.Limiter.Allow() {
|
||||
return fmt.Errorf("rate limit exceeded for endpoint: %s", endpointURL)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetLimiter returns the rate limiter for a specific endpoint
|
||||
func (lm *LimiterManager) GetLimiter(endpointURL string) (*rate.Limiter, error) {
|
||||
lm.mu.RLock()
|
||||
limiter, exists := lm.limiters[endpointURL]
|
||||
lm.mu.RUnlock()
|
||||
|
||||
if !exists {
|
||||
return nil, fmt.Errorf("no rate limiter found for endpoint: %s", endpointURL)
|
||||
}
|
||||
|
||||
return limiter.Limiter, nil
|
||||
}
|
||||
|
||||
// UpdateLimiter updates the rate limiter for an endpoint
|
||||
func (lm *LimiterManager) UpdateLimiter(endpointURL string, cfg config.RateLimitConfig) {
|
||||
lm.mu.Lock()
|
||||
defer lm.mu.Unlock()
|
||||
|
||||
limiter := createLimiter(cfg)
|
||||
lm.limiters[endpointURL] = &EndpointLimiter{
|
||||
URL: endpointURL,
|
||||
Limiter: limiter,
|
||||
Config: cfg,
|
||||
}
|
||||
}
|
||||
|
||||
// GetEndpoints returns all endpoint URLs
|
||||
func (lm *LimiterManager) GetEndpoints() []string {
|
||||
lm.mu.RLock()
|
||||
defer lm.mu.RUnlock()
|
||||
|
||||
endpoints := make([]string, 0, len(lm.limiters))
|
||||
for url := range lm.limiters {
|
||||
endpoints = append(endpoints, url)
|
||||
}
|
||||
|
||||
return endpoints
|
||||
}
|
||||
243
orig/internal/ratelimit/manager_test.go
Normal file
243
orig/internal/ratelimit/manager_test.go
Normal file
@@ -0,0 +1,243 @@
|
||||
package ratelimit
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"golang.org/x/time/rate"
|
||||
|
||||
"github.com/fraktal/mev-beta/internal/config"
|
||||
)
|
||||
|
||||
func TestNewLimiterManager(t *testing.T) {
|
||||
// Create test config
|
||||
cfg := &config.ArbitrumConfig{
|
||||
RPCEndpoint: "https://arb1.arbitrum.io/rpc",
|
||||
RateLimit: config.RateLimitConfig{
|
||||
RequestsPerSecond: 10,
|
||||
Burst: 20,
|
||||
},
|
||||
ReadingEndpoints: []config.EndpointConfig{
|
||||
{
|
||||
URL: "https://read.arbitrum.io/rpc",
|
||||
RateLimit: config.RateLimitConfig{
|
||||
RequestsPerSecond: 5,
|
||||
Burst: 10,
|
||||
},
|
||||
},
|
||||
},
|
||||
ExecutionEndpoints: []config.EndpointConfig{
|
||||
{
|
||||
URL: "https://exec.arbitrum.io/rpc",
|
||||
RateLimit: config.RateLimitConfig{
|
||||
RequestsPerSecond: 3,
|
||||
Burst: 6,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// Create limiter manager
|
||||
lm := NewLimiterManager(cfg)
|
||||
|
||||
// Verify limiter manager was created correctly
|
||||
assert.NotNil(t, lm)
|
||||
assert.NotNil(t, lm.limiters)
|
||||
assert.Len(t, lm.limiters, 3) // Primary + 1 fallback
|
||||
|
||||
// Check primary endpoint limiter
|
||||
primaryLimiter, exists := lm.limiters[cfg.RPCEndpoint]
|
||||
assert.True(t, exists)
|
||||
assert.Equal(t, cfg.RPCEndpoint, primaryLimiter.URL)
|
||||
assert.Equal(t, cfg.RateLimit, primaryLimiter.Config)
|
||||
assert.NotNil(t, primaryLimiter.Limiter)
|
||||
|
||||
// Check fallback endpoint limiter
|
||||
fallbackLimiter, exists := lm.limiters[cfg.ReadingEndpoints[0].URL]
|
||||
assert.True(t, exists)
|
||||
assert.Equal(t, cfg.ReadingEndpoints[0].URL, fallbackLimiter.URL)
|
||||
assert.Equal(t, cfg.ReadingEndpoints[0].RateLimit, fallbackLimiter.Config)
|
||||
assert.NotNil(t, fallbackLimiter.Limiter)
|
||||
}
|
||||
|
||||
func TestWaitForLimit(t *testing.T) {
|
||||
// Create test config
|
||||
cfg := &config.ArbitrumConfig{
|
||||
RPCEndpoint: "https://arb1.arbitrum.io/rpc",
|
||||
RateLimit: config.RateLimitConfig{
|
||||
RequestsPerSecond: 10,
|
||||
Burst: 20,
|
||||
},
|
||||
}
|
||||
|
||||
// Create limiter manager
|
||||
lm := NewLimiterManager(cfg)
|
||||
|
||||
// Test waiting for limit on existing endpoint
|
||||
ctx := context.Background()
|
||||
err := lm.WaitForLimit(ctx, cfg.RPCEndpoint)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Test waiting for limit on non-existing endpoint
|
||||
err = lm.WaitForLimit(ctx, "https://nonexistent.com")
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "no rate limiter found for endpoint")
|
||||
}
|
||||
|
||||
func TestTryWaitForLimit(t *testing.T) {
|
||||
// Create test config
|
||||
cfg := &config.ArbitrumConfig{
|
||||
RPCEndpoint: "https://arb1.arbitrum.io/rpc",
|
||||
RateLimit: config.RateLimitConfig{
|
||||
RequestsPerSecond: 10,
|
||||
Burst: 20,
|
||||
},
|
||||
}
|
||||
|
||||
// Create limiter manager
|
||||
lm := NewLimiterManager(cfg)
|
||||
|
||||
// Test trying to wait for limit on existing endpoint
|
||||
ctx := context.Background()
|
||||
err := lm.TryWaitForLimit(ctx, cfg.RPCEndpoint)
|
||||
assert.NoError(t, err) // Should succeed since we have burst capacity
|
||||
|
||||
// Test trying to wait for limit on non-existing endpoint
|
||||
err = lm.TryWaitForLimit(ctx, "https://nonexistent.com")
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "no rate limiter found for endpoint")
|
||||
}
|
||||
|
||||
func TestGetLimiter(t *testing.T) {
|
||||
// Create test config
|
||||
cfg := &config.ArbitrumConfig{
|
||||
RPCEndpoint: "https://arb1.arbitrum.io/rpc",
|
||||
RateLimit: config.RateLimitConfig{
|
||||
RequestsPerSecond: 10,
|
||||
Burst: 20,
|
||||
},
|
||||
}
|
||||
|
||||
// Create limiter manager
|
||||
lm := NewLimiterManager(cfg)
|
||||
|
||||
// Test getting limiter for existing endpoint
|
||||
limiter, err := lm.GetLimiter(cfg.RPCEndpoint)
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, limiter)
|
||||
assert.IsType(t, &rate.Limiter{}, limiter)
|
||||
|
||||
// Test getting limiter for non-existing endpoint
|
||||
limiter, err = lm.GetLimiter("https://nonexistent.com")
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "no rate limiter found for endpoint")
|
||||
assert.Nil(t, limiter)
|
||||
}
|
||||
|
||||
func TestUpdateLimiter(t *testing.T) {
|
||||
// Create test config
|
||||
cfg := &config.ArbitrumConfig{
|
||||
RPCEndpoint: "https://arb1.arbitrum.io/rpc",
|
||||
RateLimit: config.RateLimitConfig{
|
||||
RequestsPerSecond: 10,
|
||||
Burst: 20,
|
||||
},
|
||||
}
|
||||
|
||||
// Create limiter manager
|
||||
lm := NewLimiterManager(cfg)
|
||||
|
||||
// Get original limiter
|
||||
originalLimiter, err := lm.GetLimiter(cfg.RPCEndpoint)
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, originalLimiter)
|
||||
|
||||
// Update the limiter
|
||||
newConfig := config.RateLimitConfig{
|
||||
RequestsPerSecond: 20,
|
||||
Burst: 40,
|
||||
}
|
||||
lm.UpdateLimiter(cfg.RPCEndpoint, newConfig)
|
||||
|
||||
// Get updated limiter
|
||||
updatedLimiter, err := lm.GetLimiter(cfg.RPCEndpoint)
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, updatedLimiter)
|
||||
|
||||
// The limiter should be different (new instance)
|
||||
assert.NotEqual(t, originalLimiter, updatedLimiter)
|
||||
|
||||
// Check that the config was updated
|
||||
endpointLimiter := lm.limiters[cfg.RPCEndpoint]
|
||||
assert.Equal(t, newConfig, endpointLimiter.Config)
|
||||
}
|
||||
|
||||
func TestGetEndpoints(t *testing.T) {
|
||||
// Create test config
|
||||
cfg := &config.ArbitrumConfig{
|
||||
RPCEndpoint: "https://arb1.arbitrum.io/rpc",
|
||||
RateLimit: config.RateLimitConfig{
|
||||
RequestsPerSecond: 10,
|
||||
Burst: 20,
|
||||
},
|
||||
ReadingEndpoints: []config.EndpointConfig{
|
||||
{
|
||||
URL: "https://fallback1.arbitrum.io/rpc",
|
||||
RateLimit: config.RateLimitConfig{
|
||||
RequestsPerSecond: 5,
|
||||
Burst: 10,
|
||||
},
|
||||
},
|
||||
{
|
||||
URL: "https://fallback2.arbitrum.io/rpc",
|
||||
RateLimit: config.RateLimitConfig{
|
||||
RequestsPerSecond: 3,
|
||||
Burst: 6,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// Create limiter manager
|
||||
lm := NewLimiterManager(cfg)
|
||||
|
||||
// Get endpoints
|
||||
endpoints := lm.GetEndpoints()
|
||||
|
||||
// Verify results
|
||||
assert.Len(t, endpoints, 3) // Primary + 2 fallbacks
|
||||
assert.Contains(t, endpoints, cfg.RPCEndpoint)
|
||||
assert.Contains(t, endpoints, cfg.ReadingEndpoints[0].URL)
|
||||
assert.Contains(t, endpoints, cfg.ReadingEndpoints[1].URL)
|
||||
}
|
||||
|
||||
func TestRateLimiting(t *testing.T) {
|
||||
// Create test config with very low rate limit for testing
|
||||
cfg := &config.ArbitrumConfig{
|
||||
RPCEndpoint: "https://arb1.arbitrum.io/rpc",
|
||||
RateLimit: config.RateLimitConfig{
|
||||
RequestsPerSecond: 1, // 1 request per second
|
||||
Burst: 1, // No burst
|
||||
},
|
||||
}
|
||||
|
||||
// Create limiter manager
|
||||
lm := NewLimiterManager(cfg)
|
||||
|
||||
// Make a request (should succeed immediately)
|
||||
start := time.Now()
|
||||
ctx := context.Background()
|
||||
err := lm.WaitForLimit(ctx, cfg.RPCEndpoint)
|
||||
assert.NoError(t, err)
|
||||
duration := time.Since(start)
|
||||
assert.True(t, duration < time.Millisecond*100, "First request should be fast")
|
||||
|
||||
// Make another request immediately (should be delayed)
|
||||
start = time.Now()
|
||||
err = lm.WaitForLimit(ctx, cfg.RPCEndpoint)
|
||||
assert.NoError(t, err)
|
||||
duration = time.Since(start)
|
||||
assert.True(t, duration >= time.Second, "Second request should be delayed by rate limiter")
|
||||
}
|
||||
621
orig/internal/recovery/error_handler.go
Normal file
621
orig/internal/recovery/error_handler.go
Normal file
@@ -0,0 +1,621 @@
|
||||
package recovery
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/ethereum/go-ethereum/common"
|
||||
|
||||
"github.com/fraktal/mev-beta/internal/logger"
|
||||
)
|
||||
|
||||
// ErrorSeverity represents the severity level of an error
|
||||
type ErrorSeverity int
|
||||
|
||||
const (
|
||||
SeverityLow ErrorSeverity = iota
|
||||
SeverityMedium
|
||||
SeverityHigh
|
||||
SeverityCritical
|
||||
)
|
||||
|
||||
func (s ErrorSeverity) String() string {
|
||||
switch s {
|
||||
case SeverityLow:
|
||||
return "LOW"
|
||||
case SeverityMedium:
|
||||
return "MEDIUM"
|
||||
case SeverityHigh:
|
||||
return "HIGH"
|
||||
case SeverityCritical:
|
||||
return "CRITICAL"
|
||||
default:
|
||||
return "UNKNOWN"
|
||||
}
|
||||
}
|
||||
|
||||
// ErrorType categorizes different types of errors
|
||||
type ErrorType int
|
||||
|
||||
const (
|
||||
ErrorTypeAddressCorruption ErrorType = iota
|
||||
ErrorTypeContractCallFailed
|
||||
ErrorTypeRPCConnectionFailed
|
||||
ErrorTypeDataParsingFailed
|
||||
ErrorTypeValidationFailed
|
||||
ErrorTypeTimeoutError
|
||||
)
|
||||
|
||||
func (e ErrorType) String() string {
|
||||
switch e {
|
||||
case ErrorTypeAddressCorruption:
|
||||
return "ADDRESS_CORRUPTION"
|
||||
case ErrorTypeContractCallFailed:
|
||||
return "CONTRACT_CALL_FAILED"
|
||||
case ErrorTypeRPCConnectionFailed:
|
||||
return "RPC_CONNECTION_FAILED"
|
||||
case ErrorTypeDataParsingFailed:
|
||||
return "DATA_PARSING_FAILED"
|
||||
case ErrorTypeValidationFailed:
|
||||
return "VALIDATION_FAILED"
|
||||
case ErrorTypeTimeoutError:
|
||||
return "TIMEOUT_ERROR"
|
||||
default:
|
||||
return "UNKNOWN_ERROR"
|
||||
}
|
||||
}
|
||||
|
||||
// RecoveryAction represents an action to take when an error occurs
|
||||
type RecoveryAction int
|
||||
|
||||
const (
|
||||
ActionSkipAndContinue RecoveryAction = iota
|
||||
ActionRetryWithBackoff
|
||||
ActionUseFallbackData
|
||||
ActionCircuitBreaker
|
||||
ActionEmergencyStop
|
||||
)
|
||||
|
||||
func (a RecoveryAction) String() string {
|
||||
switch a {
|
||||
case ActionSkipAndContinue:
|
||||
return "SKIP_AND_CONTINUE"
|
||||
case ActionRetryWithBackoff:
|
||||
return "RETRY_WITH_BACKOFF"
|
||||
case ActionUseFallbackData:
|
||||
return "USE_FALLBACK_DATA"
|
||||
case ActionCircuitBreaker:
|
||||
return "CIRCUIT_BREAKER"
|
||||
case ActionEmergencyStop:
|
||||
return "EMERGENCY_STOP"
|
||||
default:
|
||||
return "UNKNOWN_ACTION"
|
||||
}
|
||||
}
|
||||
|
||||
// ErrorEvent represents a specific error occurrence
|
||||
type ErrorEvent struct {
|
||||
Timestamp time.Time
|
||||
Type ErrorType
|
||||
Severity ErrorSeverity
|
||||
Component string
|
||||
Address common.Address
|
||||
Message string
|
||||
Context map[string]interface{}
|
||||
AttemptCount int
|
||||
LastAttempt time.Time
|
||||
Resolved bool
|
||||
ResolvedAt time.Time
|
||||
}
|
||||
|
||||
// RecoveryRule defines how to handle specific error patterns
|
||||
type RecoveryRule struct {
|
||||
ErrorType ErrorType
|
||||
MaxSeverity ErrorSeverity
|
||||
Action RecoveryAction
|
||||
MaxRetries int
|
||||
BackoffInterval time.Duration
|
||||
CircuitBreakerThreshold int
|
||||
ContextMatchers map[string]interface{}
|
||||
}
|
||||
|
||||
// ErrorHandler provides comprehensive error handling and recovery capabilities
|
||||
type ErrorHandler struct {
|
||||
mu sync.RWMutex
|
||||
logger *logger.Logger
|
||||
errorHistory []ErrorEvent
|
||||
componentStats map[string]*ComponentStats
|
||||
circuitBreakers map[string]*CircuitBreaker
|
||||
recoveryRules []RecoveryRule
|
||||
fallbackProvider FallbackDataProvider
|
||||
maxHistorySize int
|
||||
alertThresholds map[ErrorType]int
|
||||
enabled bool
|
||||
}
|
||||
|
||||
// ComponentStats tracks error statistics for components
|
||||
type ComponentStats struct {
|
||||
mu sync.RWMutex
|
||||
Component string
|
||||
TotalErrors int
|
||||
ErrorsByType map[ErrorType]int
|
||||
ErrorsBySeverity map[ErrorSeverity]int
|
||||
LastError time.Time
|
||||
ConsecutiveFailures int
|
||||
SuccessCount int
|
||||
IsHealthy bool
|
||||
LastHealthCheck time.Time
|
||||
}
|
||||
|
||||
// CircuitBreaker implements circuit breaker pattern for failing components
|
||||
type CircuitBreaker struct {
|
||||
mu sync.RWMutex
|
||||
Name string
|
||||
State CircuitState
|
||||
FailureCount int
|
||||
Threshold int
|
||||
Timeout time.Duration
|
||||
LastFailure time.Time
|
||||
LastSuccess time.Time
|
||||
HalfOpenAllowed bool
|
||||
}
|
||||
|
||||
type CircuitState int
|
||||
|
||||
const (
|
||||
CircuitClosed CircuitState = iota
|
||||
CircuitOpen
|
||||
CircuitHalfOpen
|
||||
)
|
||||
|
||||
func (s CircuitState) String() string {
|
||||
switch s {
|
||||
case CircuitClosed:
|
||||
return "CLOSED"
|
||||
case CircuitOpen:
|
||||
return "OPEN"
|
||||
case CircuitHalfOpen:
|
||||
return "HALF_OPEN"
|
||||
default:
|
||||
return "UNKNOWN"
|
||||
}
|
||||
}
|
||||
|
||||
// FallbackDataProvider interface for providing fallback data when primary sources fail
|
||||
type FallbackDataProvider interface {
|
||||
GetFallbackTokenInfo(ctx context.Context, address common.Address) (*FallbackTokenInfo, error)
|
||||
GetFallbackPoolInfo(ctx context.Context, address common.Address) (*FallbackPoolInfo, error)
|
||||
GetFallbackContractType(ctx context.Context, address common.Address) (string, error)
|
||||
}
|
||||
|
||||
type FallbackTokenInfo struct {
|
||||
Address common.Address
|
||||
Symbol string
|
||||
Name string
|
||||
Decimals uint8
|
||||
IsVerified bool
|
||||
Source string
|
||||
Confidence float64
|
||||
}
|
||||
|
||||
type FallbackPoolInfo struct {
|
||||
Address common.Address
|
||||
Token0 common.Address
|
||||
Token1 common.Address
|
||||
Protocol string
|
||||
Fee uint32
|
||||
IsVerified bool
|
||||
Source string
|
||||
Confidence float64
|
||||
}
|
||||
|
||||
// NewErrorHandler creates a new error handler with default configuration
|
||||
func NewErrorHandler(logger *logger.Logger) *ErrorHandler {
|
||||
handler := &ErrorHandler{
|
||||
logger: logger,
|
||||
errorHistory: make([]ErrorEvent, 0),
|
||||
componentStats: make(map[string]*ComponentStats),
|
||||
circuitBreakers: make(map[string]*CircuitBreaker),
|
||||
maxHistorySize: 1000,
|
||||
alertThresholds: make(map[ErrorType]int),
|
||||
enabled: true,
|
||||
}
|
||||
|
||||
// Initialize default recovery rules
|
||||
handler.initializeDefaultRules()
|
||||
|
||||
// Initialize default alert thresholds
|
||||
handler.initializeAlertThresholds()
|
||||
|
||||
return handler
|
||||
}
|
||||
|
||||
// initializeDefaultRules sets up default recovery rules for common error scenarios
|
||||
func (eh *ErrorHandler) initializeDefaultRules() {
|
||||
eh.recoveryRules = []RecoveryRule{
|
||||
{
|
||||
ErrorType: ErrorTypeAddressCorruption,
|
||||
MaxSeverity: SeverityMedium,
|
||||
Action: ActionRetryWithBackoff,
|
||||
MaxRetries: 2,
|
||||
BackoffInterval: 500 * time.Millisecond,
|
||||
},
|
||||
{
|
||||
ErrorType: ErrorTypeAddressCorruption,
|
||||
MaxSeverity: SeverityCritical,
|
||||
Action: ActionUseFallbackData,
|
||||
MaxRetries: 0,
|
||||
BackoffInterval: 0,
|
||||
},
|
||||
{
|
||||
ErrorType: ErrorTypeContractCallFailed,
|
||||
MaxSeverity: SeverityMedium,
|
||||
Action: ActionRetryWithBackoff,
|
||||
MaxRetries: 3,
|
||||
BackoffInterval: 2 * time.Second,
|
||||
},
|
||||
{
|
||||
ErrorType: ErrorTypeRPCConnectionFailed,
|
||||
MaxSeverity: SeverityHigh,
|
||||
Action: ActionCircuitBreaker,
|
||||
MaxRetries: 5,
|
||||
BackoffInterval: 5 * time.Second,
|
||||
CircuitBreakerThreshold: 10,
|
||||
},
|
||||
{
|
||||
ErrorType: ErrorTypeDataParsingFailed,
|
||||
MaxSeverity: SeverityMedium,
|
||||
Action: ActionUseFallbackData,
|
||||
MaxRetries: 2,
|
||||
BackoffInterval: 1 * time.Second,
|
||||
},
|
||||
{
|
||||
ErrorType: ErrorTypeValidationFailed,
|
||||
MaxSeverity: SeverityLow,
|
||||
Action: ActionSkipAndContinue,
|
||||
MaxRetries: 0,
|
||||
BackoffInterval: 0,
|
||||
},
|
||||
{
|
||||
ErrorType: ErrorTypeValidationFailed,
|
||||
MaxSeverity: SeverityHigh,
|
||||
Action: ActionRetryWithBackoff,
|
||||
MaxRetries: 1,
|
||||
BackoffInterval: 500 * time.Millisecond,
|
||||
},
|
||||
{
|
||||
ErrorType: ErrorTypeTimeoutError,
|
||||
MaxSeverity: SeverityMedium,
|
||||
Action: ActionRetryWithBackoff,
|
||||
MaxRetries: 3,
|
||||
BackoffInterval: 3 * time.Second,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// initializeAlertThresholds sets up alert thresholds for different error types
|
||||
func (eh *ErrorHandler) initializeAlertThresholds() {
|
||||
eh.alertThresholds[ErrorTypeAddressCorruption] = 5
|
||||
eh.alertThresholds[ErrorTypeContractCallFailed] = 20
|
||||
eh.alertThresholds[ErrorTypeRPCConnectionFailed] = 10
|
||||
eh.alertThresholds[ErrorTypeDataParsingFailed] = 15
|
||||
eh.alertThresholds[ErrorTypeValidationFailed] = 25
|
||||
eh.alertThresholds[ErrorTypeTimeoutError] = 30
|
||||
}
|
||||
|
||||
// HandleError processes an error and determines the appropriate recovery action
|
||||
func (eh *ErrorHandler) HandleError(ctx context.Context, errorType ErrorType, severity ErrorSeverity, component string, address common.Address, message string, context map[string]interface{}) RecoveryAction {
|
||||
if !eh.enabled {
|
||||
return ActionSkipAndContinue
|
||||
}
|
||||
|
||||
eh.mu.Lock()
|
||||
defer eh.mu.Unlock()
|
||||
|
||||
// Record the error event
|
||||
event := ErrorEvent{
|
||||
Timestamp: time.Now(),
|
||||
Type: errorType,
|
||||
Severity: severity,
|
||||
Component: component,
|
||||
Address: address,
|
||||
Message: message,
|
||||
Context: context,
|
||||
AttemptCount: 1,
|
||||
LastAttempt: time.Now(),
|
||||
}
|
||||
|
||||
// Update error history
|
||||
eh.addToHistory(event)
|
||||
|
||||
// Update component statistics
|
||||
eh.updateComponentStats(component, errorType, severity)
|
||||
|
||||
// Check circuit breakers
|
||||
if eh.shouldTriggerCircuitBreaker(component, errorType) {
|
||||
eh.triggerCircuitBreaker(component)
|
||||
return ActionCircuitBreaker
|
||||
}
|
||||
|
||||
// Find matching recovery rule
|
||||
rule := eh.findRecoveryRule(errorType, severity, context)
|
||||
if rule == nil {
|
||||
// Default action for unmatched errors
|
||||
return ActionSkipAndContinue
|
||||
}
|
||||
|
||||
// Log the error and recovery action
|
||||
eh.logger.Error("Error handled by recovery system",
|
||||
"type", errorType.String(),
|
||||
"severity", severity.String(),
|
||||
"component", component,
|
||||
"address", address.Hex(),
|
||||
"message", message,
|
||||
"action", rule.Action.String())
|
||||
|
||||
// Check if alert threshold is reached
|
||||
eh.checkAlertThresholds(errorType)
|
||||
|
||||
return rule.Action
|
||||
}
|
||||
|
||||
// addToHistory adds an error event to the history buffer
|
||||
func (eh *ErrorHandler) addToHistory(event ErrorEvent) {
|
||||
eh.errorHistory = append(eh.errorHistory, event)
|
||||
|
||||
// Trim history if it exceeds max size
|
||||
if len(eh.errorHistory) > eh.maxHistorySize {
|
||||
eh.errorHistory = eh.errorHistory[len(eh.errorHistory)-eh.maxHistorySize:]
|
||||
}
|
||||
}
|
||||
|
||||
// updateComponentStats updates statistics for a component
|
||||
func (eh *ErrorHandler) updateComponentStats(component string, errorType ErrorType, severity ErrorSeverity) {
|
||||
stats, exists := eh.componentStats[component]
|
||||
if !exists {
|
||||
stats = &ComponentStats{
|
||||
Component: component,
|
||||
ErrorsByType: make(map[ErrorType]int),
|
||||
ErrorsBySeverity: make(map[ErrorSeverity]int),
|
||||
IsHealthy: true,
|
||||
}
|
||||
eh.componentStats[component] = stats
|
||||
}
|
||||
|
||||
stats.mu.Lock()
|
||||
defer stats.mu.Unlock()
|
||||
|
||||
stats.TotalErrors++
|
||||
stats.ErrorsByType[errorType]++
|
||||
stats.ErrorsBySeverity[severity]++
|
||||
stats.LastError = time.Now()
|
||||
stats.ConsecutiveFailures++
|
||||
|
||||
// Mark as unhealthy if too many consecutive failures
|
||||
if stats.ConsecutiveFailures > 10 {
|
||||
stats.IsHealthy = false
|
||||
}
|
||||
}
|
||||
|
||||
// findRecoveryRule finds the best matching recovery rule for an error
|
||||
func (eh *ErrorHandler) findRecoveryRule(errorType ErrorType, severity ErrorSeverity, context map[string]interface{}) *RecoveryRule {
|
||||
for _, rule := range eh.recoveryRules {
|
||||
if rule.ErrorType == errorType && severity <= rule.MaxSeverity {
|
||||
// Check context matchers if present
|
||||
if len(rule.ContextMatchers) > 0 {
|
||||
if !eh.matchesContext(context, rule.ContextMatchers) {
|
||||
continue
|
||||
}
|
||||
}
|
||||
return &rule
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// matchesContext checks if the error context matches the rule's context matchers
|
||||
func (eh *ErrorHandler) matchesContext(errorContext, ruleMatchers map[string]interface{}) bool {
|
||||
for key, expectedValue := range ruleMatchers {
|
||||
if actualValue, exists := errorContext[key]; !exists || actualValue != expectedValue {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// shouldTriggerCircuitBreaker determines if a circuit breaker should be triggered
|
||||
func (eh *ErrorHandler) shouldTriggerCircuitBreaker(component string, errorType ErrorType) bool {
|
||||
stats, exists := eh.componentStats[component]
|
||||
if !exists {
|
||||
return false
|
||||
}
|
||||
|
||||
stats.mu.RLock()
|
||||
defer stats.mu.RUnlock()
|
||||
|
||||
// Trigger if consecutive failures exceed threshold for critical errors
|
||||
if errorType == ErrorTypeRPCConnectionFailed && stats.ConsecutiveFailures >= 5 {
|
||||
return true
|
||||
}
|
||||
|
||||
if errorType == ErrorTypeAddressCorruption && stats.ConsecutiveFailures >= 3 {
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// triggerCircuitBreaker activates a circuit breaker for a component
|
||||
func (eh *ErrorHandler) triggerCircuitBreaker(component string) {
|
||||
breaker := &CircuitBreaker{
|
||||
Name: component,
|
||||
State: CircuitOpen,
|
||||
FailureCount: 0,
|
||||
Threshold: 5,
|
||||
Timeout: 30 * time.Second,
|
||||
LastFailure: time.Now(),
|
||||
}
|
||||
|
||||
eh.circuitBreakers[component] = breaker
|
||||
|
||||
eh.logger.Warn("Circuit breaker triggered",
|
||||
"component", component,
|
||||
"timeout", breaker.Timeout)
|
||||
}
|
||||
|
||||
// checkAlertThresholds checks if error counts have reached alert thresholds
|
||||
func (eh *ErrorHandler) checkAlertThresholds(errorType ErrorType) {
|
||||
threshold, exists := eh.alertThresholds[errorType]
|
||||
if !exists {
|
||||
return
|
||||
}
|
||||
|
||||
// Count recent errors of this type (last hour)
|
||||
recentCount := 0
|
||||
cutoff := time.Now().Add(-1 * time.Hour)
|
||||
|
||||
for _, event := range eh.errorHistory {
|
||||
if event.Type == errorType && event.Timestamp.After(cutoff) {
|
||||
recentCount++
|
||||
}
|
||||
}
|
||||
|
||||
if recentCount >= threshold {
|
||||
eh.logger.Warn("Error threshold reached - alert triggered",
|
||||
"error_type", errorType.String(),
|
||||
"count", recentCount,
|
||||
"threshold", threshold)
|
||||
// Here you would trigger your alerting system
|
||||
}
|
||||
}
|
||||
|
||||
// GetComponentHealth returns the health status of all components
|
||||
func (eh *ErrorHandler) GetComponentHealth() map[string]*ComponentStats {
|
||||
eh.mu.RLock()
|
||||
defer eh.mu.RUnlock()
|
||||
|
||||
// Return a copy to prevent external modification
|
||||
result := make(map[string]*ComponentStats)
|
||||
for name, stats := range eh.componentStats {
|
||||
result[name] = &ComponentStats{
|
||||
Component: stats.Component,
|
||||
TotalErrors: stats.TotalErrors,
|
||||
ErrorsByType: make(map[ErrorType]int),
|
||||
ErrorsBySeverity: make(map[ErrorSeverity]int),
|
||||
LastError: stats.LastError,
|
||||
ConsecutiveFailures: stats.ConsecutiveFailures,
|
||||
SuccessCount: stats.SuccessCount,
|
||||
IsHealthy: stats.IsHealthy,
|
||||
LastHealthCheck: stats.LastHealthCheck,
|
||||
}
|
||||
|
||||
// Copy maps
|
||||
for k, v := range stats.ErrorsByType {
|
||||
result[name].ErrorsByType[k] = v
|
||||
}
|
||||
for k, v := range stats.ErrorsBySeverity {
|
||||
result[name].ErrorsBySeverity[k] = v
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// RecordSuccess records a successful operation for a component
|
||||
func (eh *ErrorHandler) RecordSuccess(component string) {
|
||||
eh.mu.Lock()
|
||||
defer eh.mu.Unlock()
|
||||
|
||||
stats, exists := eh.componentStats[component]
|
||||
if !exists {
|
||||
stats = &ComponentStats{
|
||||
Component: component,
|
||||
ErrorsByType: make(map[ErrorType]int),
|
||||
ErrorsBySeverity: make(map[ErrorSeverity]int),
|
||||
IsHealthy: true,
|
||||
}
|
||||
eh.componentStats[component] = stats
|
||||
}
|
||||
|
||||
stats.mu.Lock()
|
||||
defer stats.mu.Unlock()
|
||||
|
||||
stats.SuccessCount++
|
||||
stats.ConsecutiveFailures = 0
|
||||
stats.IsHealthy = true
|
||||
stats.LastHealthCheck = time.Now()
|
||||
|
||||
// Reset circuit breaker if it exists
|
||||
if breaker, exists := eh.circuitBreakers[component]; exists {
|
||||
breaker.mu.Lock()
|
||||
breaker.State = CircuitClosed
|
||||
breaker.FailureCount = 0
|
||||
breaker.LastSuccess = time.Now()
|
||||
breaker.mu.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
// IsCircuitOpen checks if a circuit breaker is open for a component
|
||||
func (eh *ErrorHandler) IsCircuitOpen(component string) bool {
|
||||
eh.mu.RLock()
|
||||
defer eh.mu.RUnlock()
|
||||
|
||||
breaker, exists := eh.circuitBreakers[component]
|
||||
if !exists {
|
||||
return false
|
||||
}
|
||||
|
||||
breaker.mu.RLock()
|
||||
defer breaker.mu.RUnlock()
|
||||
|
||||
if breaker.State == CircuitOpen {
|
||||
// Check if timeout has passed
|
||||
if time.Since(breaker.LastFailure) > breaker.Timeout {
|
||||
breaker.State = CircuitHalfOpen
|
||||
breaker.HalfOpenAllowed = true
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// SetFallbackProvider sets the fallback data provider
|
||||
func (eh *ErrorHandler) SetFallbackProvider(provider FallbackDataProvider) {
|
||||
eh.mu.Lock()
|
||||
defer eh.mu.Unlock()
|
||||
eh.fallbackProvider = provider
|
||||
}
|
||||
|
||||
// GetErrorSummary returns a summary of recent errors
|
||||
func (eh *ErrorHandler) GetErrorSummary(duration time.Duration) map[string]interface{} {
|
||||
eh.mu.RLock()
|
||||
defer eh.mu.RUnlock()
|
||||
|
||||
cutoff := time.Now().Add(-duration)
|
||||
summary := map[string]interface{}{
|
||||
"total_errors": 0,
|
||||
"errors_by_type": make(map[string]int),
|
||||
"errors_by_severity": make(map[string]int),
|
||||
"errors_by_component": make(map[string]int),
|
||||
"time_range": duration.String(),
|
||||
}
|
||||
|
||||
for _, event := range eh.errorHistory {
|
||||
if event.Timestamp.After(cutoff) {
|
||||
summary["total_errors"] = summary["total_errors"].(int) + 1
|
||||
|
||||
typeKey := event.Type.String()
|
||||
summary["errors_by_type"].(map[string]int)[typeKey]++
|
||||
|
||||
severityKey := event.Severity.String()
|
||||
summary["errors_by_severity"].(map[string]int)[severityKey]++
|
||||
|
||||
summary["errors_by_component"].(map[string]int)[event.Component]++
|
||||
}
|
||||
}
|
||||
|
||||
return summary
|
||||
}
|
||||
384
orig/internal/recovery/fallback_provider.go
Normal file
384
orig/internal/recovery/fallback_provider.go
Normal file
@@ -0,0 +1,384 @@
|
||||
package recovery
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/ethereum/go-ethereum/common"
|
||||
|
||||
"github.com/fraktal/mev-beta/internal/logger"
|
||||
"github.com/fraktal/mev-beta/internal/registry"
|
||||
)
|
||||
|
||||
// DefaultFallbackProvider implements FallbackDataProvider with multiple data sources
|
||||
type DefaultFallbackProvider struct {
|
||||
mu sync.RWMutex
|
||||
logger *logger.Logger
|
||||
contractRegistry *registry.ContractRegistry
|
||||
staticTokenData map[common.Address]*FallbackTokenInfo
|
||||
staticPoolData map[common.Address]*FallbackPoolInfo
|
||||
cacheTimeout time.Duration
|
||||
enabled bool
|
||||
}
|
||||
|
||||
// NewDefaultFallbackProvider creates a new fallback data provider
|
||||
func NewDefaultFallbackProvider(logger *logger.Logger, contractRegistry *registry.ContractRegistry) *DefaultFallbackProvider {
|
||||
provider := &DefaultFallbackProvider{
|
||||
logger: logger,
|
||||
contractRegistry: contractRegistry,
|
||||
staticTokenData: make(map[common.Address]*FallbackTokenInfo),
|
||||
staticPoolData: make(map[common.Address]*FallbackPoolInfo),
|
||||
cacheTimeout: 5 * time.Minute,
|
||||
enabled: true,
|
||||
}
|
||||
|
||||
// Initialize with known safe data
|
||||
provider.initializeStaticData()
|
||||
|
||||
return provider
|
||||
}
|
||||
|
||||
// initializeStaticData populates the provider with known good data for critical Arbitrum contracts
|
||||
func (fp *DefaultFallbackProvider) initializeStaticData() {
|
||||
fp.mu.Lock()
|
||||
defer fp.mu.Unlock()
|
||||
|
||||
// Major Arbitrum tokens with verified addresses
|
||||
fp.staticTokenData[common.HexToAddress("0x82aF49447D8a07e3bd95BD0d56f35241523fBab1")] = &FallbackTokenInfo{
|
||||
Address: common.HexToAddress("0x82aF49447D8a07e3bd95BD0d56f35241523fBab1"),
|
||||
Symbol: "WETH",
|
||||
Name: "Wrapped Ether",
|
||||
Decimals: 18,
|
||||
IsVerified: true,
|
||||
Source: "static_fallback",
|
||||
Confidence: 1.0,
|
||||
}
|
||||
|
||||
fp.staticTokenData[common.HexToAddress("0xaf88d065e77c8cC2239327C5EDb3A432268e5831")] = &FallbackTokenInfo{
|
||||
Address: common.HexToAddress("0xaf88d065e77c8cC2239327C5EDb3A432268e5831"),
|
||||
Symbol: "USDC",
|
||||
Name: "USD Coin",
|
||||
Decimals: 6,
|
||||
IsVerified: true,
|
||||
Source: "static_fallback",
|
||||
Confidence: 1.0,
|
||||
}
|
||||
|
||||
fp.staticTokenData[common.HexToAddress("0xFd086bC7CD5C481DCC9C85ebE478A1C0b69FCbb9")] = &FallbackTokenInfo{
|
||||
Address: common.HexToAddress("0xFd086bC7CD5C481DCC9C85ebE478A1C0b69FCbb9"),
|
||||
Symbol: "USDT",
|
||||
Name: "Tether USD",
|
||||
Decimals: 6,
|
||||
IsVerified: true,
|
||||
Source: "static_fallback",
|
||||
Confidence: 1.0,
|
||||
}
|
||||
|
||||
fp.staticTokenData[common.HexToAddress("0x2f2a2543B76A4166549F7aaB2e75Bef0aefC5B0f")] = &FallbackTokenInfo{
|
||||
Address: common.HexToAddress("0x2f2a2543B76A4166549F7aaB2e75Bef0aefC5B0f"),
|
||||
Symbol: "WBTC",
|
||||
Name: "Wrapped BTC",
|
||||
Decimals: 8,
|
||||
IsVerified: true,
|
||||
Source: "static_fallback",
|
||||
Confidence: 1.0,
|
||||
}
|
||||
|
||||
fp.staticTokenData[common.HexToAddress("0x912CE59144191C1204E64559FE8253a0e49E6548")] = &FallbackTokenInfo{
|
||||
Address: common.HexToAddress("0x912CE59144191C1204E64559FE8253a0e49E6548"),
|
||||
Symbol: "ARB",
|
||||
Name: "Arbitrum",
|
||||
Decimals: 18,
|
||||
IsVerified: true,
|
||||
Source: "static_fallback",
|
||||
Confidence: 1.0,
|
||||
}
|
||||
|
||||
// High-volume Uniswap V3 pools with verified addresses and token pairs
|
||||
fp.staticPoolData[common.HexToAddress("0xC6962004f452bE9203591991D15f6b388e09E8D0")] = &FallbackPoolInfo{
|
||||
Address: common.HexToAddress("0xC6962004f452bE9203591991D15f6b388e09E8D0"),
|
||||
Token0: common.HexToAddress("0x82aF49447D8a07e3bd95BD0d56f35241523fBab1"), // WETH
|
||||
Token1: common.HexToAddress("0xaf88d065e77c8cC2239327C5EDb3A432268e5831"), // USDC
|
||||
Protocol: "UniswapV3",
|
||||
Fee: 500, // 0.05%
|
||||
IsVerified: true,
|
||||
Source: "static_fallback",
|
||||
Confidence: 1.0,
|
||||
}
|
||||
|
||||
fp.staticPoolData[common.HexToAddress("0x641C00A822e8b671738d32a431a4Fb6074E5c79d")] = &FallbackPoolInfo{
|
||||
Address: common.HexToAddress("0x641C00A822e8b671738d32a431a4Fb6074E5c79d"),
|
||||
Token0: common.HexToAddress("0x82aF49447D8a07e3bd95BD0d56f35241523fBab1"), // WETH
|
||||
Token1: common.HexToAddress("0xaf88d065e77c8cC2239327C5EDb3A432268e5831"), // USDC
|
||||
Protocol: "UniswapV3",
|
||||
Fee: 3000, // 0.3%
|
||||
IsVerified: true,
|
||||
Source: "static_fallback",
|
||||
Confidence: 1.0,
|
||||
}
|
||||
|
||||
fp.staticPoolData[common.HexToAddress("0x17c14D2c404D167802b16C450d3c99F88F2c4F4d")] = &FallbackPoolInfo{
|
||||
Address: common.HexToAddress("0x17c14D2c404D167802b16C450d3c99F88F2c4F4d"),
|
||||
Token0: common.HexToAddress("0xaf88d065e77c8cC2239327C5EDb3A432268e5831"), // USDC
|
||||
Token1: common.HexToAddress("0xFd086bC7CD5C481DCC9C85ebE478A1C0b69FCbb9"), // USDT
|
||||
Protocol: "UniswapV3",
|
||||
Fee: 100, // 0.01%
|
||||
IsVerified: true,
|
||||
Source: "static_fallback",
|
||||
Confidence: 1.0,
|
||||
}
|
||||
|
||||
fp.staticPoolData[common.HexToAddress("0x2f5e87C032bc4F8526F320c012A4e678F1fa6cAB")] = &FallbackPoolInfo{
|
||||
Address: common.HexToAddress("0x2f5e87C032bc4F8526F320c012A4e678F1fa6cAB"),
|
||||
Token0: common.HexToAddress("0x2f2a2543B76A4166549F7aaB2e75Bef0aefC5B0f"), // WBTC
|
||||
Token1: common.HexToAddress("0x82aF49447D8a07e3bd95BD0d56f35241523fBab1"), // WETH
|
||||
Protocol: "UniswapV3",
|
||||
Fee: 500, // 0.05%
|
||||
IsVerified: true,
|
||||
Source: "static_fallback",
|
||||
Confidence: 1.0,
|
||||
}
|
||||
|
||||
fp.logger.Info("Initialized fallback provider with static data",
|
||||
"tokens", len(fp.staticTokenData),
|
||||
"pools", len(fp.staticPoolData))
|
||||
}
|
||||
|
||||
// GetFallbackTokenInfo provides fallback token information
|
||||
func (fp *DefaultFallbackProvider) GetFallbackTokenInfo(ctx context.Context, address common.Address) (*FallbackTokenInfo, error) {
|
||||
if !fp.enabled {
|
||||
return nil, fmt.Errorf("fallback provider disabled")
|
||||
}
|
||||
|
||||
fp.mu.RLock()
|
||||
defer fp.mu.RUnlock()
|
||||
|
||||
// First, try static data
|
||||
if tokenInfo, exists := fp.staticTokenData[address]; exists {
|
||||
fp.logger.Debug("Fallback token info from static data",
|
||||
"address", address.Hex(),
|
||||
"symbol", tokenInfo.Symbol,
|
||||
"source", tokenInfo.Source)
|
||||
return tokenInfo, nil
|
||||
}
|
||||
|
||||
// Second, try contract registry if available
|
||||
if fp.contractRegistry != nil {
|
||||
if contractInfo, err := fp.contractRegistry.GetContractInfo(ctx, address); err == nil && contractInfo != nil {
|
||||
tokenInfo := &FallbackTokenInfo{
|
||||
Address: address,
|
||||
Symbol: contractInfo.Symbol,
|
||||
Name: contractInfo.Name,
|
||||
Decimals: contractInfo.Decimals,
|
||||
IsVerified: contractInfo.IsVerified,
|
||||
Source: "contract_registry",
|
||||
Confidence: contractInfo.Confidence,
|
||||
}
|
||||
|
||||
fp.logger.Debug("Fallback token info from registry",
|
||||
"address", address.Hex(),
|
||||
"symbol", tokenInfo.Symbol,
|
||||
"confidence", tokenInfo.Confidence)
|
||||
|
||||
return tokenInfo, nil
|
||||
}
|
||||
}
|
||||
|
||||
// Third, provide minimal safe fallback for unknown tokens
|
||||
tokenInfo := &FallbackTokenInfo{
|
||||
Address: address,
|
||||
Symbol: fmt.Sprintf("UNK_%s", address.Hex()[:8]),
|
||||
Name: "Unknown Token",
|
||||
Decimals: 18, // Safe default
|
||||
IsVerified: false,
|
||||
Source: "generated_fallback",
|
||||
Confidence: 0.1,
|
||||
}
|
||||
|
||||
fp.logger.Warn("Using generated fallback token info",
|
||||
"address", address.Hex(),
|
||||
"symbol", tokenInfo.Symbol)
|
||||
|
||||
return tokenInfo, nil
|
||||
}
|
||||
|
||||
// GetFallbackPoolInfo provides fallback pool information
|
||||
func (fp *DefaultFallbackProvider) GetFallbackPoolInfo(ctx context.Context, address common.Address) (*FallbackPoolInfo, error) {
|
||||
if !fp.enabled {
|
||||
return nil, fmt.Errorf("fallback provider disabled")
|
||||
}
|
||||
|
||||
fp.mu.RLock()
|
||||
defer fp.mu.RUnlock()
|
||||
|
||||
// First, try static data
|
||||
if poolInfo, exists := fp.staticPoolData[address]; exists {
|
||||
fp.logger.Debug("Fallback pool info from static data",
|
||||
"address", address.Hex(),
|
||||
"protocol", poolInfo.Protocol,
|
||||
"token0", poolInfo.Token0.Hex(),
|
||||
"token1", poolInfo.Token1.Hex())
|
||||
return poolInfo, nil
|
||||
}
|
||||
|
||||
// Second, try contract registry if available
|
||||
if fp.contractRegistry != nil {
|
||||
if poolInfo := fp.contractRegistry.GetPoolInfo(address); poolInfo != nil {
|
||||
fallbackInfo := &FallbackPoolInfo{
|
||||
Address: address,
|
||||
Token0: poolInfo.Token0,
|
||||
Token1: poolInfo.Token1,
|
||||
Protocol: poolInfo.Protocol,
|
||||
Fee: poolInfo.Fee,
|
||||
IsVerified: poolInfo.IsVerified,
|
||||
Source: "contract_registry",
|
||||
Confidence: poolInfo.Confidence,
|
||||
}
|
||||
|
||||
fp.logger.Debug("Fallback pool info from registry",
|
||||
"address", address.Hex(),
|
||||
"protocol", fallbackInfo.Protocol,
|
||||
"confidence", fallbackInfo.Confidence)
|
||||
|
||||
return fallbackInfo, nil
|
||||
}
|
||||
}
|
||||
|
||||
// No fallback available for unknown pools - return error
|
||||
return nil, fmt.Errorf("no fallback data available for pool %s", address.Hex())
|
||||
}
|
||||
|
||||
// GetFallbackContractType provides fallback contract type information
|
||||
func (fp *DefaultFallbackProvider) GetFallbackContractType(ctx context.Context, address common.Address) (string, error) {
|
||||
if !fp.enabled {
|
||||
return "", fmt.Errorf("fallback provider disabled")
|
||||
}
|
||||
|
||||
fp.mu.RLock()
|
||||
defer fp.mu.RUnlock()
|
||||
|
||||
// Check if it's a known token
|
||||
if _, exists := fp.staticTokenData[address]; exists {
|
||||
return "ERC20", nil
|
||||
}
|
||||
|
||||
// Check if it's a known pool
|
||||
if _, exists := fp.staticPoolData[address]; exists {
|
||||
return "Pool", nil
|
||||
}
|
||||
|
||||
// Try contract registry
|
||||
if fp.contractRegistry != nil {
|
||||
if contractInfo, err := fp.contractRegistry.GetContractInfo(ctx, address); err == nil && contractInfo != nil {
|
||||
return contractInfo.Type.String(), nil
|
||||
}
|
||||
}
|
||||
|
||||
// Default to unknown
|
||||
return "Unknown", nil
|
||||
}
|
||||
|
||||
// AddStaticTokenData adds static token data for fallback use
|
||||
func (fp *DefaultFallbackProvider) AddStaticTokenData(address common.Address, info *FallbackTokenInfo) {
|
||||
fp.mu.Lock()
|
||||
defer fp.mu.Unlock()
|
||||
|
||||
fp.staticTokenData[address] = info
|
||||
fp.logger.Debug("Added static token data",
|
||||
"address", address.Hex(),
|
||||
"symbol", info.Symbol)
|
||||
}
|
||||
|
||||
// AddStaticPoolData adds static pool data for fallback use
|
||||
func (fp *DefaultFallbackProvider) AddStaticPoolData(address common.Address, info *FallbackPoolInfo) {
|
||||
fp.mu.Lock()
|
||||
defer fp.mu.Unlock()
|
||||
|
||||
fp.staticPoolData[address] = info
|
||||
fp.logger.Debug("Added static pool data",
|
||||
"address", address.Hex(),
|
||||
"protocol", info.Protocol)
|
||||
}
|
||||
|
||||
// IsAddressKnown checks if an address is in the static fallback data
|
||||
func (fp *DefaultFallbackProvider) IsAddressKnown(address common.Address) bool {
|
||||
fp.mu.RLock()
|
||||
defer fp.mu.RUnlock()
|
||||
|
||||
_, isToken := fp.staticTokenData[address]
|
||||
_, isPool := fp.staticPoolData[address]
|
||||
|
||||
return isToken || isPool
|
||||
}
|
||||
|
||||
// GetKnownAddresses returns all known addresses in the fallback provider
|
||||
func (fp *DefaultFallbackProvider) GetKnownAddresses() (tokens []common.Address, pools []common.Address) {
|
||||
fp.mu.RLock()
|
||||
defer fp.mu.RUnlock()
|
||||
|
||||
for addr := range fp.staticTokenData {
|
||||
tokens = append(tokens, addr)
|
||||
}
|
||||
|
||||
for addr := range fp.staticPoolData {
|
||||
pools = append(pools, addr)
|
||||
}
|
||||
|
||||
return tokens, pools
|
||||
}
|
||||
|
||||
// ValidateAddressWithFallback performs validation using fallback data
|
||||
func (fp *DefaultFallbackProvider) ValidateAddressWithFallback(ctx context.Context, address common.Address, expectedType string) (bool, float64, error) {
|
||||
if !fp.enabled {
|
||||
return false, 0.0, fmt.Errorf("fallback provider disabled")
|
||||
}
|
||||
|
||||
// Check if address is known in our static data
|
||||
if fp.IsAddressKnown(address) {
|
||||
actualType, err := fp.GetFallbackContractType(ctx, address)
|
||||
if err != nil {
|
||||
return false, 0.0, err
|
||||
}
|
||||
|
||||
if actualType == expectedType {
|
||||
return true, 1.0, nil // High confidence for known addresses
|
||||
}
|
||||
|
||||
return false, 0.0, fmt.Errorf("type mismatch: expected %s, got %s", expectedType, actualType)
|
||||
}
|
||||
|
||||
// For unknown addresses, provide low confidence validation
|
||||
return true, 0.3, nil // Allow with low confidence
|
||||
}
|
||||
|
||||
// GetStats returns statistics about the fallback provider
|
||||
func (fp *DefaultFallbackProvider) GetStats() map[string]interface{} {
|
||||
fp.mu.RLock()
|
||||
defer fp.mu.RUnlock()
|
||||
|
||||
return map[string]interface{}{
|
||||
"enabled": fp.enabled,
|
||||
"static_tokens_count": len(fp.staticTokenData),
|
||||
"static_pools_count": len(fp.staticPoolData),
|
||||
"cache_timeout": fp.cacheTimeout.String(),
|
||||
"has_registry": fp.contractRegistry != nil,
|
||||
}
|
||||
}
|
||||
|
||||
// Enable enables the fallback provider
|
||||
func (fp *DefaultFallbackProvider) Enable() {
|
||||
fp.mu.Lock()
|
||||
defer fp.mu.Unlock()
|
||||
fp.enabled = true
|
||||
fp.logger.Info("Fallback provider enabled")
|
||||
}
|
||||
|
||||
// Disable disables the fallback provider
|
||||
func (fp *DefaultFallbackProvider) Disable() {
|
||||
fp.mu.Lock()
|
||||
defer fp.mu.Unlock()
|
||||
fp.enabled = false
|
||||
fp.logger.Info("Fallback provider disabled")
|
||||
}
|
||||
446
orig/internal/recovery/retry_handler.go
Normal file
446
orig/internal/recovery/retry_handler.go
Normal file
@@ -0,0 +1,446 @@
|
||||
package recovery
|
||||
|
||||
import (
|
||||
"context"
|
||||
"math"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/fraktal/mev-beta/internal/logger"
|
||||
)
|
||||
|
||||
// RetryConfig defines retry behavior configuration
|
||||
type RetryConfig struct {
|
||||
MaxAttempts int
|
||||
InitialDelay time.Duration
|
||||
MaxDelay time.Duration
|
||||
BackoffFactor float64
|
||||
JitterEnabled bool
|
||||
TimeoutPerAttempt time.Duration
|
||||
}
|
||||
|
||||
// DefaultRetryConfig returns a sensible default retry configuration
|
||||
func DefaultRetryConfig() RetryConfig {
|
||||
return RetryConfig{
|
||||
MaxAttempts: 3,
|
||||
InitialDelay: 1 * time.Second,
|
||||
MaxDelay: 30 * time.Second,
|
||||
BackoffFactor: 2.0,
|
||||
JitterEnabled: true,
|
||||
TimeoutPerAttempt: 10 * time.Second,
|
||||
}
|
||||
}
|
||||
|
||||
// RetryableOperation represents an operation that can be retried
|
||||
type RetryableOperation func(ctx context.Context, attempt int) error
|
||||
|
||||
// RetryHandler provides exponential backoff retry capabilities
|
||||
type RetryHandler struct {
|
||||
mu sync.RWMutex
|
||||
logger *logger.Logger
|
||||
configs map[string]RetryConfig
|
||||
stats map[string]*RetryStats
|
||||
enabled bool
|
||||
}
|
||||
|
||||
// RetryStats tracks retry statistics for operations
|
||||
type RetryStats struct {
|
||||
mu sync.RWMutex
|
||||
OperationType string
|
||||
TotalAttempts int
|
||||
SuccessfulRetries int
|
||||
FailedRetries int
|
||||
AverageAttempts float64
|
||||
LastAttempt time.Time
|
||||
LastSuccess time.Time
|
||||
LastFailure time.Time
|
||||
}
|
||||
|
||||
// RetryResult contains the result of a retry operation
|
||||
type RetryResult struct {
|
||||
Success bool
|
||||
Attempts int
|
||||
TotalDuration time.Duration
|
||||
LastError error
|
||||
LastAttemptAt time.Time
|
||||
}
|
||||
|
||||
// NewRetryHandler creates a new retry handler
|
||||
func NewRetryHandler(logger *logger.Logger) *RetryHandler {
|
||||
handler := &RetryHandler{
|
||||
logger: logger,
|
||||
configs: make(map[string]RetryConfig),
|
||||
stats: make(map[string]*RetryStats),
|
||||
enabled: true,
|
||||
}
|
||||
|
||||
// Initialize default configurations for common operations
|
||||
handler.initializeDefaultConfigs()
|
||||
|
||||
return handler
|
||||
}
|
||||
|
||||
// initializeDefaultConfigs sets up default retry configurations
|
||||
func (rh *RetryHandler) initializeDefaultConfigs() {
|
||||
// Contract call retries - moderate backoff
|
||||
rh.configs["contract_call"] = RetryConfig{
|
||||
MaxAttempts: 3,
|
||||
InitialDelay: 500 * time.Millisecond,
|
||||
MaxDelay: 5 * time.Second,
|
||||
BackoffFactor: 2.0,
|
||||
JitterEnabled: true,
|
||||
TimeoutPerAttempt: 10 * time.Second,
|
||||
}
|
||||
|
||||
// RPC connection retries - aggressive backoff
|
||||
rh.configs["rpc_connection"] = RetryConfig{
|
||||
MaxAttempts: 5,
|
||||
InitialDelay: 1 * time.Second,
|
||||
MaxDelay: 30 * time.Second,
|
||||
BackoffFactor: 2.5,
|
||||
JitterEnabled: true,
|
||||
TimeoutPerAttempt: 15 * time.Second,
|
||||
}
|
||||
|
||||
// Data parsing retries - quick retries
|
||||
rh.configs["data_parsing"] = RetryConfig{
|
||||
MaxAttempts: 2,
|
||||
InitialDelay: 100 * time.Millisecond,
|
||||
MaxDelay: 1 * time.Second,
|
||||
BackoffFactor: 2.0,
|
||||
JitterEnabled: false,
|
||||
TimeoutPerAttempt: 5 * time.Second,
|
||||
}
|
||||
|
||||
// Block processing retries - conservative
|
||||
rh.configs["block_processing"] = RetryConfig{
|
||||
MaxAttempts: 3,
|
||||
InitialDelay: 2 * time.Second,
|
||||
MaxDelay: 10 * time.Second,
|
||||
BackoffFactor: 2.0,
|
||||
JitterEnabled: true,
|
||||
TimeoutPerAttempt: 30 * time.Second,
|
||||
}
|
||||
|
||||
// Token metadata retries - patient backoff
|
||||
rh.configs["token_metadata"] = RetryConfig{
|
||||
MaxAttempts: 4,
|
||||
InitialDelay: 1 * time.Second,
|
||||
MaxDelay: 20 * time.Second,
|
||||
BackoffFactor: 2.0,
|
||||
JitterEnabled: true,
|
||||
TimeoutPerAttempt: 15 * time.Second,
|
||||
}
|
||||
}
|
||||
|
||||
// ExecuteWithRetry executes an operation with retry logic
|
||||
func (rh *RetryHandler) ExecuteWithRetry(ctx context.Context, operationType string, operation RetryableOperation) *RetryResult {
|
||||
if !rh.enabled {
|
||||
// If retries are disabled, try once
|
||||
err := operation(ctx, 1)
|
||||
return &RetryResult{
|
||||
Success: err == nil,
|
||||
Attempts: 1,
|
||||
TotalDuration: 0,
|
||||
LastError: err,
|
||||
LastAttemptAt: time.Now(),
|
||||
}
|
||||
}
|
||||
|
||||
config := rh.getConfig(operationType)
|
||||
start := time.Now()
|
||||
var lastError error
|
||||
|
||||
rh.mu.Lock()
|
||||
stats, exists := rh.stats[operationType]
|
||||
if !exists {
|
||||
stats = &RetryStats{
|
||||
OperationType: operationType,
|
||||
}
|
||||
rh.stats[operationType] = stats
|
||||
}
|
||||
rh.mu.Unlock()
|
||||
|
||||
for attempt := 1; attempt <= config.MaxAttempts; attempt++ {
|
||||
// Create context with timeout for this attempt
|
||||
attemptCtx, cancel := context.WithTimeout(ctx, config.TimeoutPerAttempt)
|
||||
|
||||
rh.logger.Debug("Attempting operation with retry",
|
||||
"operation", operationType,
|
||||
"attempt", attempt,
|
||||
"max_attempts", config.MaxAttempts)
|
||||
|
||||
// Execute the operation
|
||||
err := operation(attemptCtx, attempt)
|
||||
cancel()
|
||||
|
||||
// Update statistics
|
||||
stats.mu.Lock()
|
||||
stats.TotalAttempts++
|
||||
stats.LastAttempt = time.Now()
|
||||
stats.mu.Unlock()
|
||||
|
||||
if err == nil {
|
||||
// Success!
|
||||
duration := time.Since(start)
|
||||
|
||||
stats.mu.Lock()
|
||||
stats.SuccessfulRetries++
|
||||
stats.LastSuccess = time.Now()
|
||||
denominator := stats.SuccessfulRetries + stats.FailedRetries
|
||||
if denominator > 0 {
|
||||
stats.AverageAttempts = float64(stats.TotalAttempts) / float64(denominator)
|
||||
}
|
||||
stats.mu.Unlock()
|
||||
|
||||
rh.logger.Debug("Operation succeeded",
|
||||
"operation", operationType,
|
||||
"attempt", attempt,
|
||||
"duration", duration)
|
||||
|
||||
return &RetryResult{
|
||||
Success: true,
|
||||
Attempts: attempt,
|
||||
TotalDuration: duration,
|
||||
LastError: nil,
|
||||
LastAttemptAt: time.Now(),
|
||||
}
|
||||
}
|
||||
|
||||
lastError = err
|
||||
|
||||
// Check if context was cancelled
|
||||
if ctx.Err() != nil {
|
||||
rh.logger.Debug("Operation cancelled by context",
|
||||
"operation", operationType,
|
||||
"attempt", attempt,
|
||||
"error", ctx.Err())
|
||||
break
|
||||
}
|
||||
|
||||
// Don't wait after the last attempt
|
||||
if attempt < config.MaxAttempts {
|
||||
delay := rh.calculateDelay(config, attempt)
|
||||
|
||||
rh.logger.Debug("Operation failed, retrying",
|
||||
"operation", operationType,
|
||||
"attempt", attempt,
|
||||
"error", err,
|
||||
"delay", delay)
|
||||
|
||||
// Wait before next attempt
|
||||
select {
|
||||
case <-time.After(delay):
|
||||
// Continue to next attempt
|
||||
case <-ctx.Done():
|
||||
// Context cancelled during wait
|
||||
break
|
||||
}
|
||||
} else {
|
||||
rh.logger.Warn("Operation failed after all retries",
|
||||
"operation", operationType,
|
||||
"attempts", attempt,
|
||||
"error", err)
|
||||
}
|
||||
}
|
||||
|
||||
// All attempts failed
|
||||
duration := time.Since(start)
|
||||
|
||||
stats.mu.Lock()
|
||||
stats.FailedRetries++
|
||||
stats.LastFailure = time.Now()
|
||||
denominator := stats.SuccessfulRetries + stats.FailedRetries
|
||||
if denominator > 0 {
|
||||
stats.AverageAttempts = float64(stats.TotalAttempts) / float64(denominator)
|
||||
}
|
||||
stats.mu.Unlock()
|
||||
|
||||
return &RetryResult{
|
||||
Success: false,
|
||||
Attempts: config.MaxAttempts,
|
||||
TotalDuration: duration,
|
||||
LastError: lastError,
|
||||
LastAttemptAt: time.Now(),
|
||||
}
|
||||
}
|
||||
|
||||
// calculateDelay calculates the delay before the next retry attempt
|
||||
func (rh *RetryHandler) calculateDelay(config RetryConfig, attempt int) time.Duration {
|
||||
// Calculate exponential backoff
|
||||
delay := float64(config.InitialDelay) * math.Pow(config.BackoffFactor, float64(attempt-1))
|
||||
|
||||
// Apply maximum delay cap
|
||||
if delay > float64(config.MaxDelay) {
|
||||
delay = float64(config.MaxDelay)
|
||||
}
|
||||
|
||||
duration := time.Duration(delay)
|
||||
|
||||
// Add jitter if enabled
|
||||
if config.JitterEnabled {
|
||||
jitter := time.Duration(float64(duration) * 0.1 * (2*rh.randomFloat() - 1))
|
||||
duration += jitter
|
||||
}
|
||||
|
||||
// Ensure minimum delay
|
||||
if duration < 0 {
|
||||
duration = config.InitialDelay
|
||||
}
|
||||
|
||||
return duration
|
||||
}
|
||||
|
||||
// randomFloat returns a pseudo-random float between 0 and 1
|
||||
func (rh *RetryHandler) randomFloat() float64 {
|
||||
// Simple pseudo-random number based on current time
|
||||
return float64(time.Now().UnixNano()%1000) / 1000.0
|
||||
}
|
||||
|
||||
// getConfig returns the retry configuration for an operation type
|
||||
func (rh *RetryHandler) getConfig(operationType string) RetryConfig {
|
||||
rh.mu.RLock()
|
||||
defer rh.mu.RUnlock()
|
||||
|
||||
if config, exists := rh.configs[operationType]; exists {
|
||||
return config
|
||||
}
|
||||
|
||||
// Return default config if no specific config found
|
||||
return DefaultRetryConfig()
|
||||
}
|
||||
|
||||
// SetConfig sets a custom retry configuration for an operation type
|
||||
func (rh *RetryHandler) SetConfig(operationType string, config RetryConfig) {
|
||||
rh.mu.Lock()
|
||||
defer rh.mu.Unlock()
|
||||
|
||||
rh.configs[operationType] = config
|
||||
rh.logger.Debug("Set retry config",
|
||||
"operation", operationType,
|
||||
"max_attempts", config.MaxAttempts,
|
||||
"initial_delay", config.InitialDelay,
|
||||
"max_delay", config.MaxDelay)
|
||||
}
|
||||
|
||||
// GetStats returns retry statistics for all operation types
|
||||
func (rh *RetryHandler) GetStats() map[string]*RetryStats {
|
||||
rh.mu.RLock()
|
||||
defer rh.mu.RUnlock()
|
||||
|
||||
// Return a copy to prevent external modification
|
||||
result := make(map[string]*RetryStats)
|
||||
for opType, stats := range rh.stats {
|
||||
stats.mu.RLock()
|
||||
result[opType] = &RetryStats{
|
||||
OperationType: stats.OperationType,
|
||||
TotalAttempts: stats.TotalAttempts,
|
||||
SuccessfulRetries: stats.SuccessfulRetries,
|
||||
FailedRetries: stats.FailedRetries,
|
||||
AverageAttempts: stats.AverageAttempts,
|
||||
LastAttempt: stats.LastAttempt,
|
||||
LastSuccess: stats.LastSuccess,
|
||||
LastFailure: stats.LastFailure,
|
||||
}
|
||||
stats.mu.RUnlock()
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// GetOperationStats returns statistics for a specific operation type
|
||||
func (rh *RetryHandler) GetOperationStats(operationType string) *RetryStats {
|
||||
rh.mu.RLock()
|
||||
defer rh.mu.RUnlock()
|
||||
|
||||
stats, exists := rh.stats[operationType]
|
||||
if !exists {
|
||||
return nil
|
||||
}
|
||||
|
||||
stats.mu.RLock()
|
||||
defer stats.mu.RUnlock()
|
||||
|
||||
return &RetryStats{
|
||||
OperationType: stats.OperationType,
|
||||
TotalAttempts: stats.TotalAttempts,
|
||||
SuccessfulRetries: stats.SuccessfulRetries,
|
||||
FailedRetries: stats.FailedRetries,
|
||||
AverageAttempts: stats.AverageAttempts,
|
||||
LastAttempt: stats.LastAttempt,
|
||||
LastSuccess: stats.LastSuccess,
|
||||
LastFailure: stats.LastFailure,
|
||||
}
|
||||
}
|
||||
|
||||
// ResetStats resets statistics for all operation types
|
||||
func (rh *RetryHandler) ResetStats() {
|
||||
rh.mu.Lock()
|
||||
defer rh.mu.Unlock()
|
||||
|
||||
rh.stats = make(map[string]*RetryStats)
|
||||
rh.logger.Info("Reset retry statistics")
|
||||
}
|
||||
|
||||
// Enable enables the retry handler
|
||||
func (rh *RetryHandler) Enable() {
|
||||
rh.mu.Lock()
|
||||
defer rh.mu.Unlock()
|
||||
rh.enabled = true
|
||||
rh.logger.Info("Retry handler enabled")
|
||||
}
|
||||
|
||||
// Disable disables the retry handler
|
||||
func (rh *RetryHandler) Disable() {
|
||||
rh.mu.Lock()
|
||||
defer rh.mu.Unlock()
|
||||
rh.enabled = false
|
||||
rh.logger.Info("Retry handler disabled")
|
||||
}
|
||||
|
||||
// IsEnabled returns whether the retry handler is enabled
|
||||
func (rh *RetryHandler) IsEnabled() bool {
|
||||
rh.mu.RLock()
|
||||
defer rh.mu.RUnlock()
|
||||
return rh.enabled
|
||||
}
|
||||
|
||||
// GetHealthSummary returns a health summary based on retry statistics
|
||||
func (rh *RetryHandler) GetHealthSummary() map[string]interface{} {
|
||||
stats := rh.GetStats()
|
||||
|
||||
summary := map[string]interface{}{
|
||||
"enabled": rh.enabled,
|
||||
"total_operations": len(stats),
|
||||
"healthy_operations": 0,
|
||||
"unhealthy_operations": 0,
|
||||
"operation_details": make(map[string]interface{}),
|
||||
}
|
||||
|
||||
for opType, opStats := range stats {
|
||||
total := opStats.SuccessfulRetries + opStats.FailedRetries
|
||||
successRate := 0.0
|
||||
if total > 0 {
|
||||
successRate = float64(opStats.SuccessfulRetries) / float64(total)
|
||||
}
|
||||
|
||||
isHealthy := successRate >= 0.9 && opStats.AverageAttempts <= 2.0
|
||||
|
||||
if isHealthy {
|
||||
summary["healthy_operations"] = summary["healthy_operations"].(int) + 1
|
||||
} else {
|
||||
summary["unhealthy_operations"] = summary["unhealthy_operations"].(int) + 1
|
||||
}
|
||||
|
||||
summary["operation_details"].(map[string]interface{})[opType] = map[string]interface{}{
|
||||
"success_rate": successRate,
|
||||
"average_attempts": opStats.AverageAttempts,
|
||||
"total_operations": total,
|
||||
"is_healthy": isHealthy,
|
||||
"last_success": opStats.LastSuccess,
|
||||
"last_failure": opStats.LastFailure,
|
||||
}
|
||||
}
|
||||
|
||||
return summary
|
||||
}
|
||||
362
orig/internal/recovery/retry_handler_test.go
Normal file
362
orig/internal/recovery/retry_handler_test.go
Normal file
@@ -0,0 +1,362 @@
|
||||
package recovery
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/fraktal/mev-beta/internal/logger"
|
||||
)
|
||||
|
||||
func TestRetryHandler_ExecuteWithRetry_Success(t *testing.T) {
|
||||
log := logger.New("debug", "text", "")
|
||||
handler := NewRetryHandler(log)
|
||||
|
||||
attempts := 0
|
||||
operation := func(ctx context.Context, attempt int) error {
|
||||
attempts++
|
||||
if attempts == 2 {
|
||||
return nil // Success on second attempt
|
||||
}
|
||||
return errors.New("temporary failure")
|
||||
}
|
||||
|
||||
result := handler.ExecuteWithRetry(context.Background(), "test_operation", operation)
|
||||
|
||||
assert.True(t, result.Success)
|
||||
assert.Equal(t, 2, result.Attempts)
|
||||
assert.Nil(t, result.LastError)
|
||||
assert.Equal(t, 2, attempts)
|
||||
}
|
||||
|
||||
func TestRetryHandler_ExecuteWithRetry_MaxAttemptsReached(t *testing.T) {
|
||||
log := logger.New("debug", "text", "")
|
||||
handler := NewRetryHandler(log)
|
||||
|
||||
attempts := 0
|
||||
operation := func(ctx context.Context, attempt int) error {
|
||||
attempts++
|
||||
return errors.New("persistent failure")
|
||||
}
|
||||
|
||||
result := handler.ExecuteWithRetry(context.Background(), "test_operation", operation)
|
||||
|
||||
assert.False(t, result.Success)
|
||||
assert.Equal(t, 3, result.Attempts) // Default max attempts
|
||||
assert.NotNil(t, result.LastError)
|
||||
assert.Equal(t, "persistent failure", result.LastError.Error())
|
||||
assert.Equal(t, 3, attempts)
|
||||
}
|
||||
|
||||
func TestRetryHandler_ExecuteWithRetry_ContextCanceled(t *testing.T) {
|
||||
log := logger.New("debug", "text", "")
|
||||
handler := NewRetryHandler(log)
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
attempts := 0
|
||||
operation := func(ctx context.Context, attempt int) error {
|
||||
attempts++
|
||||
if attempts == 2 {
|
||||
cancel() // Cancel context on second attempt
|
||||
}
|
||||
return errors.New("failure")
|
||||
}
|
||||
|
||||
result := handler.ExecuteWithRetry(ctx, "test_operation", operation)
|
||||
|
||||
assert.False(t, result.Success)
|
||||
assert.LessOrEqual(t, result.Attempts, 3)
|
||||
assert.NotNil(t, result.LastError)
|
||||
}
|
||||
|
||||
func TestRetryHandler_ExecuteWithRetry_CustomConfig(t *testing.T) {
|
||||
log := logger.New("debug", "text", "")
|
||||
handler := NewRetryHandler(log)
|
||||
|
||||
// Set custom configuration
|
||||
customConfig := RetryConfig{
|
||||
MaxAttempts: 5,
|
||||
InitialDelay: 10 * time.Millisecond,
|
||||
MaxDelay: 100 * time.Millisecond,
|
||||
BackoffFactor: 2.0,
|
||||
JitterEnabled: false,
|
||||
TimeoutPerAttempt: 1 * time.Second,
|
||||
}
|
||||
handler.SetConfig("custom_operation", customConfig)
|
||||
|
||||
attempts := 0
|
||||
operation := func(ctx context.Context, attempt int) error {
|
||||
attempts++
|
||||
return errors.New("persistent failure")
|
||||
}
|
||||
|
||||
start := time.Now()
|
||||
result := handler.ExecuteWithRetry(context.Background(), "custom_operation", operation)
|
||||
duration := time.Since(start)
|
||||
|
||||
assert.False(t, result.Success)
|
||||
assert.Equal(t, 5, result.Attempts) // Custom max attempts
|
||||
assert.Equal(t, 5, attempts)
|
||||
|
||||
// Should have taken some time due to delays (at least 150ms for delays)
|
||||
expectedMinDuration := 10*time.Millisecond + 20*time.Millisecond + 40*time.Millisecond + 80*time.Millisecond
|
||||
assert.GreaterOrEqual(t, duration, expectedMinDuration)
|
||||
}
|
||||
|
||||
func TestRetryHandler_ExecuteWithRetry_Disabled(t *testing.T) {
|
||||
log := logger.New("debug", "text", "")
|
||||
handler := NewRetryHandler(log)
|
||||
handler.Disable()
|
||||
|
||||
attempts := 0
|
||||
operation := func(ctx context.Context, attempt int) error {
|
||||
attempts++
|
||||
return errors.New("failure")
|
||||
}
|
||||
|
||||
result := handler.ExecuteWithRetry(context.Background(), "test_operation", operation)
|
||||
|
||||
assert.False(t, result.Success)
|
||||
assert.Equal(t, 1, result.Attempts) // Only one attempt when disabled
|
||||
assert.Equal(t, 1, attempts)
|
||||
}
|
||||
|
||||
func TestRetryHandler_CalculateDelay(t *testing.T) {
|
||||
log := logger.New("debug", "text", "")
|
||||
handler := NewRetryHandler(log)
|
||||
|
||||
config := RetryConfig{
|
||||
InitialDelay: 100 * time.Millisecond,
|
||||
MaxDelay: 1 * time.Second,
|
||||
BackoffFactor: 2.0,
|
||||
JitterEnabled: false,
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
attempt int
|
||||
expectedMin time.Duration
|
||||
expectedMax time.Duration
|
||||
}{
|
||||
{1, 100 * time.Millisecond, 100 * time.Millisecond},
|
||||
{2, 200 * time.Millisecond, 200 * time.Millisecond},
|
||||
{3, 400 * time.Millisecond, 400 * time.Millisecond},
|
||||
{4, 800 * time.Millisecond, 800 * time.Millisecond},
|
||||
{5, 1 * time.Second, 1 * time.Second}, // Should be capped at MaxDelay
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(fmt.Sprintf("attempt_%d", tt.attempt), func(t *testing.T) {
|
||||
delay := handler.calculateDelay(config, tt.attempt)
|
||||
assert.GreaterOrEqual(t, delay, tt.expectedMin)
|
||||
assert.LessOrEqual(t, delay, tt.expectedMax)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRetryHandler_CalculateDelay_WithJitter(t *testing.T) {
|
||||
log := logger.New("debug", "text", "")
|
||||
handler := NewRetryHandler(log)
|
||||
|
||||
config := RetryConfig{
|
||||
InitialDelay: 100 * time.Millisecond,
|
||||
MaxDelay: 1 * time.Second,
|
||||
BackoffFactor: 2.0,
|
||||
JitterEnabled: true,
|
||||
}
|
||||
|
||||
// Test jitter variation
|
||||
delays := make([]time.Duration, 10)
|
||||
for i := 0; i < 10; i++ {
|
||||
delays[i] = handler.calculateDelay(config, 2) // 200ms base
|
||||
}
|
||||
|
||||
// Should have some variation due to jitter
|
||||
allSame := true
|
||||
for i := 1; i < len(delays); i++ {
|
||||
if delays[i] != delays[0] {
|
||||
allSame = false
|
||||
break
|
||||
}
|
||||
}
|
||||
assert.False(t, allSame, "Jitter should cause variation in delays")
|
||||
|
||||
// All delays should be reasonable (within 10% of base)
|
||||
baseDelay := 200 * time.Millisecond
|
||||
for _, delay := range delays {
|
||||
assert.GreaterOrEqual(t, delay, baseDelay*9/10) // 10% below
|
||||
assert.LessOrEqual(t, delay, baseDelay*11/10) // 10% above
|
||||
}
|
||||
}
|
||||
|
||||
func TestRetryHandler_GetStats(t *testing.T) {
|
||||
log := logger.New("debug", "text", "")
|
||||
handler := NewRetryHandler(log)
|
||||
|
||||
// Execute some operations
|
||||
successOp := func(ctx context.Context, attempt int) error {
|
||||
return nil
|
||||
}
|
||||
failOp := func(ctx context.Context, attempt int) error {
|
||||
return errors.New("failure")
|
||||
}
|
||||
|
||||
handler.ExecuteWithRetry(context.Background(), "test_success", successOp)
|
||||
handler.ExecuteWithRetry(context.Background(), "test_success", successOp)
|
||||
handler.ExecuteWithRetry(context.Background(), "test_fail", failOp)
|
||||
|
||||
stats := handler.GetStats()
|
||||
|
||||
// Check success stats
|
||||
successStats := stats["test_success"]
|
||||
require.NotNil(t, successStats)
|
||||
assert.Equal(t, 2, successStats.TotalAttempts)
|
||||
assert.Equal(t, 2, successStats.SuccessfulRetries)
|
||||
assert.Equal(t, 0, successStats.FailedRetries)
|
||||
|
||||
// Check failure stats
|
||||
failStats := stats["test_fail"]
|
||||
require.NotNil(t, failStats)
|
||||
assert.Equal(t, 3, failStats.TotalAttempts) // Default max attempts
|
||||
assert.Equal(t, 0, failStats.SuccessfulRetries)
|
||||
assert.Equal(t, 1, failStats.FailedRetries)
|
||||
}
|
||||
|
||||
func TestRetryHandler_GetHealthSummary(t *testing.T) {
|
||||
log := logger.New("debug", "text", "")
|
||||
handler := NewRetryHandler(log)
|
||||
|
||||
// Execute some operations to generate stats
|
||||
successOp := func(ctx context.Context, attempt int) error {
|
||||
return nil
|
||||
}
|
||||
partialFailOp := func(ctx context.Context, attempt int) error {
|
||||
if attempt < 2 {
|
||||
return errors.New("temporary failure")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// 2 immediate successes
|
||||
handler.ExecuteWithRetry(context.Background(), "immediate_success", successOp)
|
||||
handler.ExecuteWithRetry(context.Background(), "immediate_success", successOp)
|
||||
|
||||
// 1 success after retry
|
||||
handler.ExecuteWithRetry(context.Background(), "retry_success", partialFailOp)
|
||||
|
||||
summary := handler.GetHealthSummary()
|
||||
|
||||
assert.True(t, summary["enabled"].(bool))
|
||||
assert.Equal(t, 2, summary["total_operations"].(int))
|
||||
assert.Equal(t, 2, summary["healthy_operations"].(int))
|
||||
assert.Equal(t, 0, summary["unhealthy_operations"].(int))
|
||||
|
||||
// Check operation details
|
||||
details := summary["operation_details"].(map[string]interface{})
|
||||
|
||||
immediateDetails := details["immediate_success"].(map[string]interface{})
|
||||
assert.Equal(t, 1.0, immediateDetails["success_rate"].(float64))
|
||||
assert.Equal(t, 1.0, immediateDetails["average_attempts"].(float64))
|
||||
assert.True(t, immediateDetails["is_healthy"].(bool))
|
||||
|
||||
retryDetails := details["retry_success"].(map[string]interface{})
|
||||
assert.Equal(t, 1.0, retryDetails["success_rate"].(float64))
|
||||
assert.Equal(t, 2.0, retryDetails["average_attempts"].(float64))
|
||||
assert.True(t, retryDetails["is_healthy"].(bool)) // Still healthy despite retries
|
||||
}
|
||||
|
||||
func TestRetryHandler_ConcurrentExecution(t *testing.T) {
|
||||
log := logger.New("debug", "text", "")
|
||||
handler := NewRetryHandler(log)
|
||||
|
||||
const numGoroutines = 50
|
||||
const operationsPerGoroutine = 20
|
||||
|
||||
done := make(chan bool, numGoroutines)
|
||||
successCount := make(chan int, numGoroutines)
|
||||
|
||||
operation := func(ctx context.Context, attempt int) error {
|
||||
// 80% success rate
|
||||
if attempt <= 1 && time.Now().UnixNano()%5 != 0 {
|
||||
return nil
|
||||
}
|
||||
if attempt == 2 {
|
||||
return nil // Always succeed on second attempt
|
||||
}
|
||||
return errors.New("failure")
|
||||
}
|
||||
|
||||
// Launch concurrent retry operations
|
||||
for i := 0; i < numGoroutines; i++ {
|
||||
go func(id int) {
|
||||
defer func() { done <- true }()
|
||||
|
||||
successes := 0
|
||||
for j := 0; j < operationsPerGoroutine; j++ {
|
||||
result := handler.ExecuteWithRetry(context.Background(),
|
||||
fmt.Sprintf("concurrent_op_%d", id), operation)
|
||||
if result.Success {
|
||||
successes++
|
||||
}
|
||||
}
|
||||
successCount <- successes
|
||||
}(i)
|
||||
}
|
||||
|
||||
// Collect results
|
||||
totalSuccesses := 0
|
||||
for i := 0; i < numGoroutines; i++ {
|
||||
select {
|
||||
case <-done:
|
||||
totalSuccesses += <-successCount
|
||||
case <-time.After(30 * time.Second):
|
||||
t.Fatal("Concurrent retry test timed out")
|
||||
}
|
||||
}
|
||||
|
||||
totalOperations := numGoroutines * operationsPerGoroutine
|
||||
successRate := float64(totalSuccesses) / float64(totalOperations)
|
||||
|
||||
t.Logf("Concurrent execution: %d/%d operations succeeded (%.2f%%)",
|
||||
totalSuccesses, totalOperations, successRate*100)
|
||||
|
||||
// Should have high success rate due to retries
|
||||
assert.GreaterOrEqual(t, successRate, 0.8, "Success rate should be at least 80%")
|
||||
|
||||
// Verify stats are consistent
|
||||
stats := handler.GetStats()
|
||||
assert.NotEmpty(t, stats, "Should have recorded stats")
|
||||
}
|
||||
|
||||
func TestRetryHandler_EdgeCases(t *testing.T) {
|
||||
log := logger.New("debug", "text", "")
|
||||
handler := NewRetryHandler(log)
|
||||
|
||||
t.Run("nil operation", func(t *testing.T) {
|
||||
assert.Panics(t, func() {
|
||||
handler.ExecuteWithRetry(context.Background(), "nil_op", nil)
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("empty operation type", func(t *testing.T) {
|
||||
operation := func(ctx context.Context, attempt int) error {
|
||||
return nil
|
||||
}
|
||||
result := handler.ExecuteWithRetry(context.Background(), "", operation)
|
||||
assert.True(t, result.Success)
|
||||
})
|
||||
|
||||
t.Run("very long operation type", func(t *testing.T) {
|
||||
longName := string(make([]byte, 1000))
|
||||
operation := func(ctx context.Context, attempt int) error {
|
||||
return nil
|
||||
}
|
||||
result := handler.ExecuteWithRetry(context.Background(), longName, operation)
|
||||
assert.True(t, result.Success)
|
||||
})
|
||||
}
|
||||
493
orig/internal/registry/contract_registry.go
Normal file
493
orig/internal/registry/contract_registry.go
Normal file
@@ -0,0 +1,493 @@
|
||||
package registry
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/ethereum/go-ethereum/common"
|
||||
|
||||
"github.com/fraktal/mev-beta/internal/contracts"
|
||||
"github.com/fraktal/mev-beta/internal/logger"
|
||||
)
|
||||
|
||||
// ContractInfo contains comprehensive information about a contract
|
||||
type ContractInfo struct {
|
||||
Address common.Address
|
||||
Type contracts.ContractType
|
||||
Name string
|
||||
Symbol string
|
||||
Decimals uint8
|
||||
Token0 common.Address // For pools
|
||||
Token1 common.Address // For pools
|
||||
Fee uint32 // For V3 pools (in basis points)
|
||||
Factory common.Address
|
||||
Protocol string
|
||||
IsVerified bool
|
||||
LastUpdated time.Time
|
||||
Confidence float64
|
||||
}
|
||||
|
||||
// ContractRegistry provides authoritative mapping of contract addresses to their types and metadata
|
||||
type ContractRegistry struct {
|
||||
mu sync.RWMutex
|
||||
contracts map[common.Address]*ContractInfo
|
||||
tokensBySymbol map[string]common.Address
|
||||
poolsByTokenPair map[string][]common.Address // "token0:token1" -> pool addresses
|
||||
detector *contracts.ContractDetector
|
||||
logger *logger.Logger
|
||||
updateInterval time.Duration
|
||||
lastFullUpdate time.Time
|
||||
}
|
||||
|
||||
// NewContractRegistry creates a new contract registry
|
||||
func NewContractRegistry(detector *contracts.ContractDetector, logger *logger.Logger) *ContractRegistry {
|
||||
registry := &ContractRegistry{
|
||||
contracts: make(map[common.Address]*ContractInfo),
|
||||
tokensBySymbol: make(map[string]common.Address),
|
||||
poolsByTokenPair: make(map[string][]common.Address),
|
||||
detector: detector,
|
||||
logger: logger,
|
||||
updateInterval: 24 * time.Hour, // Update every 24 hours
|
||||
lastFullUpdate: time.Time{},
|
||||
}
|
||||
|
||||
// Initialize with known Arbitrum contracts
|
||||
registry.initializeKnownContracts()
|
||||
return registry
|
||||
}
|
||||
|
||||
// initializeKnownContracts populates the registry with well-known Arbitrum contracts
|
||||
func (cr *ContractRegistry) initializeKnownContracts() {
|
||||
cr.mu.Lock()
|
||||
defer cr.mu.Unlock()
|
||||
|
||||
now := time.Now()
|
||||
|
||||
// Major ERC-20 tokens on Arbitrum
|
||||
knownTokens := map[common.Address]*ContractInfo{
|
||||
common.HexToAddress("0x82aF49447D8a07e3bd95BD0d56f35241523fBab1"): {
|
||||
Address: common.HexToAddress("0x82aF49447D8a07e3bd95BD0d56f35241523fBab1"),
|
||||
Type: contracts.ContractTypeERC20Token,
|
||||
Name: "Wrapped Ether",
|
||||
Symbol: "WETH",
|
||||
Decimals: 18,
|
||||
Protocol: "Arbitrum",
|
||||
IsVerified: true,
|
||||
LastUpdated: now,
|
||||
Confidence: 1.0,
|
||||
},
|
||||
common.HexToAddress("0xA0b86a33E6D8E4BBa6Fd6bD5BA0e2FF8A1e8B8D4"): {
|
||||
Address: common.HexToAddress("0xA0b86a33E6D8E4BBa6Fd6bD5BA0e2FF8A1e8B8D4"),
|
||||
Type: contracts.ContractTypeERC20Token,
|
||||
Name: "USD Coin",
|
||||
Symbol: "USDC",
|
||||
Decimals: 6,
|
||||
Protocol: "Arbitrum",
|
||||
IsVerified: true,
|
||||
LastUpdated: now,
|
||||
Confidence: 1.0,
|
||||
},
|
||||
common.HexToAddress("0xFd086bC7CD5C481DCC9C85ebE478A1C0b69FCbb9"): {
|
||||
Address: common.HexToAddress("0xFd086bC7CD5C481DCC9C85ebE478A1C0b69FCbb9"),
|
||||
Type: contracts.ContractTypeERC20Token,
|
||||
Name: "Tether USD",
|
||||
Symbol: "USDT",
|
||||
Decimals: 6,
|
||||
Protocol: "Arbitrum",
|
||||
IsVerified: true,
|
||||
LastUpdated: now,
|
||||
Confidence: 1.0,
|
||||
},
|
||||
common.HexToAddress("0x2f2a2543B76A4166549F7aaB2e75Bef0aefC5B0f"): {
|
||||
Address: common.HexToAddress("0x2f2a2543B76A4166549F7aaB2e75Bef0aefC5B0f"),
|
||||
Type: contracts.ContractTypeERC20Token,
|
||||
Name: "Wrapped BTC",
|
||||
Symbol: "WBTC",
|
||||
Decimals: 8,
|
||||
Protocol: "Arbitrum",
|
||||
IsVerified: true,
|
||||
LastUpdated: now,
|
||||
Confidence: 1.0,
|
||||
},
|
||||
common.HexToAddress("0x912CE59144191C1204E64559FE8253a0e49E6548"): {
|
||||
Address: common.HexToAddress("0x912CE59144191C1204E64559FE8253a0e49E6548"),
|
||||
Type: contracts.ContractTypeERC20Token,
|
||||
Name: "Arbitrum",
|
||||
Symbol: "ARB",
|
||||
Decimals: 18,
|
||||
Protocol: "Arbitrum",
|
||||
IsVerified: true,
|
||||
LastUpdated: now,
|
||||
Confidence: 1.0,
|
||||
},
|
||||
}
|
||||
|
||||
// Major Uniswap V3 pools on Arbitrum
|
||||
knownPools := map[common.Address]*ContractInfo{
|
||||
common.HexToAddress("0xC6962004f452bE9203591991D15f6b388e09E8D0"): {
|
||||
Address: common.HexToAddress("0xC6962004f452bE9203591991D15f6b388e09E8D0"),
|
||||
Type: contracts.ContractTypeUniswapV3Pool,
|
||||
Name: "USDC/WETH 0.05%",
|
||||
Token0: common.HexToAddress("0xA0b86a33E6D8E4BBa6Fd6bD5BA0e2FF8A1e8B8D4"), // USDC
|
||||
Token1: common.HexToAddress("0x82aF49447D8a07e3bd95BD0d56f35241523fBab1"), // WETH
|
||||
Fee: 500, // 0.05%
|
||||
Factory: common.HexToAddress("0x1F98431c8aD98523631AE4a59f267346ea31F984"),
|
||||
Protocol: "UniswapV3",
|
||||
IsVerified: true,
|
||||
LastUpdated: now,
|
||||
Confidence: 1.0,
|
||||
},
|
||||
common.HexToAddress("0x17c14D2c404D167802b16C450d3c99F88F2c4F4d"): {
|
||||
Address: common.HexToAddress("0x17c14D2c404D167802b16C450d3c99F88F2c4F4d"),
|
||||
Type: contracts.ContractTypeUniswapV3Pool,
|
||||
Name: "USDC/WETH 0.3%",
|
||||
Token0: common.HexToAddress("0xA0b86a33E6D8E4BBa6Fd6bD5BA0e2FF8A1e8B8D4"), // USDC
|
||||
Token1: common.HexToAddress("0x82aF49447D8a07e3bd95BD0d56f35241523fBab1"), // WETH
|
||||
Fee: 3000, // 0.3%
|
||||
Factory: common.HexToAddress("0x1F98431c8aD98523631AE4a59f267346ea31F984"),
|
||||
Protocol: "UniswapV3",
|
||||
IsVerified: true,
|
||||
LastUpdated: now,
|
||||
Confidence: 1.0,
|
||||
},
|
||||
common.HexToAddress("0x2f5e87C9312fa29aed5c179E456625D79015299c"): {
|
||||
Address: common.HexToAddress("0x2f5e87C9312fa29aed5c179E456625D79015299c"),
|
||||
Type: contracts.ContractTypeUniswapV3Pool,
|
||||
Name: "WBTC/WETH 0.05%",
|
||||
Token0: common.HexToAddress("0x2f2a2543B76A4166549F7aaB2e75Bef0aefC5B0f"), // WBTC
|
||||
Token1: common.HexToAddress("0x82aF49447D8a07e3bd95BD0d56f35241523fBab1"), // WETH
|
||||
Fee: 500, // 0.05%
|
||||
Factory: common.HexToAddress("0x1F98431c8aD98523631AE4a59f267346ea31F984"),
|
||||
Protocol: "UniswapV3",
|
||||
IsVerified: true,
|
||||
LastUpdated: now,
|
||||
Confidence: 1.0,
|
||||
},
|
||||
common.HexToAddress("0x641C00A822e8b671738d32a431a4Fb6074E5c79d"): {
|
||||
Address: common.HexToAddress("0x641C00A822e8b671738d32a431a4Fb6074E5c79d"),
|
||||
Type: contracts.ContractTypeUniswapV3Pool,
|
||||
Name: "USDT/WETH 0.05%",
|
||||
Token0: common.HexToAddress("0xFd086bC7CD5C481DCC9C85ebE478A1C0b69FCbb9"), // USDT
|
||||
Token1: common.HexToAddress("0x82aF49447D8a07e3bd95BD0d56f35241523fBab1"), // WETH
|
||||
Fee: 500, // 0.05%
|
||||
Factory: common.HexToAddress("0x1F98431c8aD98523631AE4a59f267346ea31F984"),
|
||||
Protocol: "UniswapV3",
|
||||
IsVerified: true,
|
||||
LastUpdated: now,
|
||||
Confidence: 1.0,
|
||||
},
|
||||
common.HexToAddress("0xFe7D6a84287235C7b4b57C4fEb9a44d4C6Ed3BB8"): {
|
||||
Address: common.HexToAddress("0xFe7D6a84287235C7b4b57C4fEb9a44d4C6Ed3BB8"),
|
||||
Type: contracts.ContractTypeUniswapV3Pool,
|
||||
Name: "ARB/WETH 0.05%",
|
||||
Token0: common.HexToAddress("0x912CE59144191C1204E64559FE8253a0e49E6548"), // ARB
|
||||
Token1: common.HexToAddress("0x82aF49447D8a07e3bd95BD0d56f35241523fBab1"), // WETH
|
||||
Fee: 500, // 0.05%
|
||||
Factory: common.HexToAddress("0x1F98431c8aD98523631AE4a59f267346ea31F984"),
|
||||
Protocol: "UniswapV3",
|
||||
IsVerified: true,
|
||||
LastUpdated: now,
|
||||
Confidence: 1.0,
|
||||
},
|
||||
}
|
||||
|
||||
// Major routers on Arbitrum
|
||||
knownRouters := map[common.Address]*ContractInfo{
|
||||
common.HexToAddress("0xE592427A0AEce92De3Edee1F18E0157C05861564"): {
|
||||
Address: common.HexToAddress("0xE592427A0AEce92De3Edee1F18E0157C05861564"),
|
||||
Type: contracts.ContractTypeUniswapV3Router,
|
||||
Name: "Uniswap V3 Router",
|
||||
Protocol: "UniswapV3",
|
||||
IsVerified: true,
|
||||
LastUpdated: now,
|
||||
Confidence: 1.0,
|
||||
},
|
||||
common.HexToAddress("0x68b3465833fb72A70ecDF485E0e4C7bD8665Fc45"): {
|
||||
Address: common.HexToAddress("0x68b3465833fb72A70ecDF485E0e4C7bD8665Fc45"),
|
||||
Type: contracts.ContractTypeUniswapV3Router,
|
||||
Name: "Uniswap V3 Router 2",
|
||||
Protocol: "UniswapV3",
|
||||
IsVerified: true,
|
||||
LastUpdated: now,
|
||||
Confidence: 1.0,
|
||||
},
|
||||
common.HexToAddress("0x1b02dA8Cb0d097eB8D57A175b88c7D8b47997506"): {
|
||||
Address: common.HexToAddress("0x1b02dA8Cb0d097eB8D57A175b88c7D8b47997506"),
|
||||
Type: contracts.ContractTypeUniswapV2Router,
|
||||
Name: "SushiSwap Router",
|
||||
Protocol: "SushiSwap",
|
||||
IsVerified: true,
|
||||
LastUpdated: now,
|
||||
Confidence: 1.0,
|
||||
},
|
||||
}
|
||||
|
||||
// Add all known contracts
|
||||
for addr, info := range knownTokens {
|
||||
cr.contracts[addr] = info
|
||||
cr.tokensBySymbol[info.Symbol] = addr
|
||||
}
|
||||
|
||||
for addr, info := range knownPools {
|
||||
cr.contracts[addr] = info
|
||||
// Index pools by token pair
|
||||
if info.Token0 != (common.Address{}) && info.Token1 != (common.Address{}) {
|
||||
pairKey := cr.makeTokenPairKey(info.Token0, info.Token1)
|
||||
cr.poolsByTokenPair[pairKey] = append(cr.poolsByTokenPair[pairKey], addr)
|
||||
}
|
||||
}
|
||||
|
||||
for addr, info := range knownRouters {
|
||||
cr.contracts[addr] = info
|
||||
}
|
||||
|
||||
cr.logger.Info("Contract registry initialized",
|
||||
"tokens", len(knownTokens),
|
||||
"pools", len(knownPools),
|
||||
"routers", len(knownRouters))
|
||||
}
|
||||
|
||||
// GetContractInfo retrieves contract information, using cache first then detection
|
||||
func (cr *ContractRegistry) GetContractInfo(ctx context.Context, address common.Address) (*ContractInfo, error) {
|
||||
// Try cache first
|
||||
if info := cr.getCachedInfo(address); info != nil {
|
||||
return info, nil
|
||||
}
|
||||
|
||||
// Not in cache, use detector
|
||||
detection := cr.detector.DetectContractType(ctx, address)
|
||||
if detection.Error != nil {
|
||||
return nil, fmt.Errorf("contract detection failed: %w", detection.Error)
|
||||
}
|
||||
|
||||
// Create contract info from detection
|
||||
info := &ContractInfo{
|
||||
Address: address,
|
||||
Type: detection.ContractType,
|
||||
Name: fmt.Sprintf("Unknown %s", detection.ContractType.String()),
|
||||
Protocol: "Unknown",
|
||||
IsVerified: false,
|
||||
LastUpdated: time.Now(),
|
||||
Confidence: detection.Confidence,
|
||||
}
|
||||
|
||||
// Cache the result
|
||||
cr.cacheContractInfo(info)
|
||||
|
||||
return info, nil
|
||||
}
|
||||
|
||||
// getCachedInfo safely retrieves cached contract info
|
||||
func (cr *ContractRegistry) getCachedInfo(address common.Address) *ContractInfo {
|
||||
cr.mu.RLock()
|
||||
defer cr.mu.RUnlock()
|
||||
return cr.contracts[address]
|
||||
}
|
||||
|
||||
// cacheContractInfo safely caches contract info
|
||||
func (cr *ContractRegistry) cacheContractInfo(info *ContractInfo) {
|
||||
cr.mu.Lock()
|
||||
defer cr.mu.Unlock()
|
||||
cr.contracts[info.Address] = info
|
||||
}
|
||||
|
||||
// IsKnownToken checks if an address is a known ERC-20 token
|
||||
func (cr *ContractRegistry) IsKnownToken(address common.Address) bool {
|
||||
cr.mu.RLock()
|
||||
defer cr.mu.RUnlock()
|
||||
|
||||
if info, exists := cr.contracts[address]; exists {
|
||||
return info.Type == contracts.ContractTypeERC20Token
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// IsKnownPool checks if an address is a known pool
|
||||
func (cr *ContractRegistry) IsKnownPool(address common.Address) bool {
|
||||
cr.mu.RLock()
|
||||
defer cr.mu.RUnlock()
|
||||
|
||||
if info, exists := cr.contracts[address]; exists {
|
||||
return info.Type == contracts.ContractTypeUniswapV2Pool ||
|
||||
info.Type == contracts.ContractTypeUniswapV3Pool
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// IsKnownRouter checks if an address is a known router
|
||||
func (cr *ContractRegistry) IsKnownRouter(address common.Address) bool {
|
||||
cr.mu.RLock()
|
||||
defer cr.mu.RUnlock()
|
||||
|
||||
if info, exists := cr.contracts[address]; exists {
|
||||
return info.Type == contracts.ContractTypeUniswapV2Router ||
|
||||
info.Type == contracts.ContractTypeUniswapV3Router ||
|
||||
info.Type == contracts.ContractTypeUniversalRouter
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// GetTokenBySymbol retrieves a token address by symbol
|
||||
func (cr *ContractRegistry) GetTokenBySymbol(symbol string) (common.Address, bool) {
|
||||
cr.mu.RLock()
|
||||
defer cr.mu.RUnlock()
|
||||
|
||||
addr, exists := cr.tokensBySymbol[symbol]
|
||||
return addr, exists
|
||||
}
|
||||
|
||||
// GetPoolsForTokenPair retrieves pools for a given token pair
|
||||
func (cr *ContractRegistry) GetPoolsForTokenPair(token0, token1 common.Address) []common.Address {
|
||||
cr.mu.RLock()
|
||||
defer cr.mu.RUnlock()
|
||||
|
||||
pairKey := cr.makeTokenPairKey(token0, token1)
|
||||
return cr.poolsByTokenPair[pairKey]
|
||||
}
|
||||
|
||||
// makeTokenPairKey creates a consistent key for token pairs
|
||||
func (cr *ContractRegistry) makeTokenPairKey(token0, token1 common.Address) string {
|
||||
// Ensure consistent ordering
|
||||
if token0.Big().Cmp(token1.Big()) > 0 {
|
||||
token0, token1 = token1, token0
|
||||
}
|
||||
return fmt.Sprintf("%s:%s", token0.Hex(), token1.Hex())
|
||||
}
|
||||
|
||||
// GetKnownContracts returns all cached contracts
|
||||
func (cr *ContractRegistry) GetKnownContracts() map[common.Address]*ContractInfo {
|
||||
cr.mu.RLock()
|
||||
defer cr.mu.RUnlock()
|
||||
|
||||
// Return a copy to prevent external modification
|
||||
result := make(map[common.Address]*ContractInfo)
|
||||
for addr, info := range cr.contracts {
|
||||
result[addr] = info
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// UpdateContractInfo updates contract information (for dynamic discovery)
|
||||
func (cr *ContractRegistry) UpdateContractInfo(info *ContractInfo) {
|
||||
cr.mu.Lock()
|
||||
defer cr.mu.Unlock()
|
||||
|
||||
info.LastUpdated = time.Now()
|
||||
cr.contracts[info.Address] = info
|
||||
|
||||
// Update indexes
|
||||
if info.Type == contracts.ContractTypeERC20Token && info.Symbol != "" {
|
||||
cr.tokensBySymbol[info.Symbol] = info.Address
|
||||
}
|
||||
|
||||
if (info.Type == contracts.ContractTypeUniswapV2Pool || info.Type == contracts.ContractTypeUniswapV3Pool) &&
|
||||
info.Token0 != (common.Address{}) && info.Token1 != (common.Address{}) {
|
||||
pairKey := cr.makeTokenPairKey(info.Token0, info.Token1)
|
||||
// Check if already exists in slice
|
||||
pools := cr.poolsByTokenPair[pairKey]
|
||||
exists := false
|
||||
for _, pool := range pools {
|
||||
if pool == info.Address {
|
||||
exists = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !exists {
|
||||
cr.poolsByTokenPair[pairKey] = append(pools, info.Address)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// GetCacheStats returns statistics about the cached contracts
|
||||
func (cr *ContractRegistry) GetCacheStats() map[string]int {
|
||||
cr.mu.RLock()
|
||||
defer cr.mu.RUnlock()
|
||||
|
||||
stats := map[string]int{
|
||||
"total": len(cr.contracts),
|
||||
"tokens": 0,
|
||||
"v2_pools": 0,
|
||||
"v3_pools": 0,
|
||||
"routers": 0,
|
||||
"verified": 0,
|
||||
}
|
||||
|
||||
for _, info := range cr.contracts {
|
||||
switch info.Type {
|
||||
case contracts.ContractTypeERC20Token:
|
||||
stats["tokens"]++
|
||||
case contracts.ContractTypeUniswapV2Pool:
|
||||
stats["v2_pools"]++
|
||||
case contracts.ContractTypeUniswapV3Pool:
|
||||
stats["v3_pools"]++
|
||||
case contracts.ContractTypeUniswapV2Router, contracts.ContractTypeUniswapV3Router:
|
||||
stats["routers"]++
|
||||
}
|
||||
|
||||
if info.IsVerified {
|
||||
stats["verified"]++
|
||||
}
|
||||
}
|
||||
|
||||
return stats
|
||||
}
|
||||
|
||||
// GetPoolInfo returns pool information for a given address
|
||||
func (cr *ContractRegistry) GetPoolInfo(address common.Address) *ContractInfo {
|
||||
cr.mu.RLock()
|
||||
defer cr.mu.RUnlock()
|
||||
|
||||
if info, exists := cr.contracts[address]; exists {
|
||||
if info.Type == contracts.ContractTypeUniswapV2Pool || info.Type == contracts.ContractTypeUniswapV3Pool {
|
||||
return info
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// AddPool adds a pool to the registry
|
||||
func (cr *ContractRegistry) AddPool(poolAddress, token0, token1 common.Address, protocol string) {
|
||||
cr.mu.Lock()
|
||||
defer cr.mu.Unlock()
|
||||
|
||||
// Determine pool type based on protocol
|
||||
var poolType contracts.ContractType
|
||||
switch protocol {
|
||||
case "UniswapV2":
|
||||
poolType = contracts.ContractTypeUniswapV2Pool
|
||||
case "UniswapV3":
|
||||
poolType = contracts.ContractTypeUniswapV3Pool
|
||||
default:
|
||||
poolType = contracts.ContractTypeUniswapV2Pool // Default to V2
|
||||
}
|
||||
|
||||
// Add or update pool info
|
||||
info := &ContractInfo{
|
||||
Address: poolAddress,
|
||||
Type: poolType,
|
||||
Token0: token0,
|
||||
Token1: token1,
|
||||
Protocol: protocol,
|
||||
IsVerified: false, // Runtime discovered pools are not pre-verified
|
||||
LastUpdated: time.Now(),
|
||||
Confidence: 0.8, // High confidence for runtime discovered pools
|
||||
}
|
||||
|
||||
cr.contracts[poolAddress] = info
|
||||
|
||||
// Update token pair mapping
|
||||
pairKey := cr.makeTokenPairKey(token0, token1)
|
||||
if pools, exists := cr.poolsByTokenPair[pairKey]; exists {
|
||||
// Check if pool already exists in the list
|
||||
for _, existingPool := range pools {
|
||||
if existingPool == poolAddress {
|
||||
return // Already exists
|
||||
}
|
||||
}
|
||||
cr.poolsByTokenPair[pairKey] = append(pools, poolAddress)
|
||||
} else {
|
||||
cr.poolsByTokenPair[pairKey] = []common.Address{poolAddress}
|
||||
}
|
||||
}
|
||||
292
orig/internal/secure/config_manager.go
Normal file
292
orig/internal/secure/config_manager.go
Normal file
@@ -0,0 +1,292 @@
|
||||
package secure
|
||||
|
||||
import (
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"github.com/fraktal/mev-beta/internal/logger"
|
||||
)
|
||||
|
||||
// ConfigManager handles secure configuration management
|
||||
type ConfigManager struct {
|
||||
logger *logger.Logger
|
||||
aesGCM cipher.AEAD
|
||||
key []byte
|
||||
}
|
||||
|
||||
// NewConfigManager creates a new secure configuration manager
|
||||
func NewConfigManager(logger *logger.Logger) (*ConfigManager, error) {
|
||||
// Get encryption key from environment or generate one
|
||||
keyStr := os.Getenv("MEV_BOT_CONFIG_KEY")
|
||||
if keyStr == "" {
|
||||
return nil, errors.New("MEV_BOT_CONFIG_KEY environment variable not set")
|
||||
}
|
||||
|
||||
// Create SHA-256 hash of the key for AES-256
|
||||
key := sha256.Sum256([]byte(keyStr))
|
||||
|
||||
// Create AES cipher
|
||||
block, err := aes.NewCipher(key[:])
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create AES cipher: %w", err)
|
||||
}
|
||||
|
||||
// Create GCM mode
|
||||
aesGCM, err := cipher.NewGCM(block)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create GCM mode: %w", err)
|
||||
}
|
||||
|
||||
return &ConfigManager{
|
||||
logger: logger,
|
||||
aesGCM: aesGCM,
|
||||
key: key[:],
|
||||
}, nil
|
||||
}
|
||||
|
||||
// EncryptValue encrypts a configuration value
|
||||
func (cm *ConfigManager) EncryptValue(plaintext string) (string, error) {
|
||||
// Create a random nonce
|
||||
nonce := make([]byte, cm.aesGCM.NonceSize())
|
||||
if _, err := io.ReadFull(rand.Reader, nonce); err != nil {
|
||||
return "", fmt.Errorf("failed to generate nonce: %w", err)
|
||||
}
|
||||
|
||||
// Encrypt the plaintext
|
||||
ciphertext := cm.aesGCM.Seal(nonce, nonce, []byte(plaintext), nil)
|
||||
|
||||
// Encode to base64 for storage
|
||||
return base64.StdEncoding.EncodeToString(ciphertext), nil
|
||||
}
|
||||
|
||||
// DecryptValue decrypts a configuration value
|
||||
func (cm *ConfigManager) DecryptValue(ciphertext string) (string, error) {
|
||||
// Decode from base64
|
||||
data, err := base64.StdEncoding.DecodeString(ciphertext)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to decode base64: %w", err)
|
||||
}
|
||||
|
||||
// Check minimum length (nonce size)
|
||||
nonceSize := cm.aesGCM.NonceSize()
|
||||
if len(data) < nonceSize {
|
||||
return "", errors.New("ciphertext too short")
|
||||
}
|
||||
|
||||
// Extract nonce and ciphertext
|
||||
nonce, ciphertext_bytes := data[:nonceSize], data[nonceSize:]
|
||||
|
||||
// Decrypt
|
||||
plaintext, err := cm.aesGCM.Open(nil, nonce, ciphertext_bytes, nil)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to decrypt: %w", err)
|
||||
}
|
||||
|
||||
return string(plaintext), nil
|
||||
}
|
||||
|
||||
// GetSecureValue gets a secure value from environment with fallback to encrypted storage
|
||||
func (cm *ConfigManager) GetSecureValue(key string) (string, error) {
|
||||
// First try environment variable
|
||||
if value := os.Getenv(key); value != "" {
|
||||
return value, nil
|
||||
}
|
||||
|
||||
// Try encrypted environment variable
|
||||
encryptedKey := key + "_ENCRYPTED"
|
||||
if encryptedValue := os.Getenv(encryptedKey); encryptedValue != "" {
|
||||
return cm.DecryptValue(encryptedValue)
|
||||
}
|
||||
|
||||
return "", fmt.Errorf("secure value not found for key: %s", key)
|
||||
}
|
||||
|
||||
// SecureConfig holds encrypted configuration values
|
||||
type SecureConfig struct {
|
||||
manager *ConfigManager
|
||||
values map[string]string
|
||||
}
|
||||
|
||||
// NewSecureConfig creates a new secure configuration
|
||||
func NewSecureConfig(manager *ConfigManager) *SecureConfig {
|
||||
return &SecureConfig{
|
||||
manager: manager,
|
||||
values: make(map[string]string),
|
||||
}
|
||||
}
|
||||
|
||||
// Set stores a value securely
|
||||
func (sc *SecureConfig) Set(key, value string) error {
|
||||
encrypted, err := sc.manager.EncryptValue(value)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to encrypt value for key %s: %w", key, err)
|
||||
}
|
||||
|
||||
sc.values[key] = encrypted
|
||||
return nil
|
||||
}
|
||||
|
||||
// Get retrieves a value securely
|
||||
func (sc *SecureConfig) Get(key string) (string, error) {
|
||||
// Check local encrypted storage first
|
||||
if encrypted, exists := sc.values[key]; exists {
|
||||
return sc.manager.DecryptValue(encrypted)
|
||||
}
|
||||
|
||||
// Fallback to secure environment lookup
|
||||
return sc.manager.GetSecureValue(key)
|
||||
}
|
||||
|
||||
// GetRequired retrieves a required value, returning error if not found
|
||||
func (sc *SecureConfig) GetRequired(key string) (string, error) {
|
||||
value, err := sc.Get(key)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("required configuration value missing: %s", key)
|
||||
}
|
||||
|
||||
if strings.TrimSpace(value) == "" {
|
||||
return "", fmt.Errorf("required configuration value empty: %s", key)
|
||||
}
|
||||
|
||||
return value, nil
|
||||
}
|
||||
|
||||
// GetWithDefault retrieves a value with a default fallback
|
||||
func (sc *SecureConfig) GetWithDefault(key, defaultValue string) string {
|
||||
value, err := sc.Get(key)
|
||||
if err != nil {
|
||||
return defaultValue
|
||||
}
|
||||
return value
|
||||
}
|
||||
|
||||
// LoadFromEnvironment loads configuration from environment variables
|
||||
func (sc *SecureConfig) LoadFromEnvironment(keys []string) error {
|
||||
for _, key := range keys {
|
||||
value, err := sc.manager.GetSecureValue(key)
|
||||
if err != nil {
|
||||
sc.manager.logger.Warn(fmt.Sprintf("Could not load secure config for %s: %v", key, err))
|
||||
continue
|
||||
}
|
||||
|
||||
// Store encrypted in memory
|
||||
if err := sc.Set(key, value); err != nil {
|
||||
return fmt.Errorf("failed to store secure config for %s: %w", key, err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Clear removes all stored values from memory
|
||||
func (sc *SecureConfig) Clear() {
|
||||
// Zero out the map entries before clearing
|
||||
for key := range sc.values {
|
||||
// Overwrite with zeros
|
||||
sc.values[key] = strings.Repeat("0", len(sc.values[key]))
|
||||
delete(sc.values, key)
|
||||
}
|
||||
}
|
||||
|
||||
// Validate checks that all required configuration is present
|
||||
func (sc *SecureConfig) Validate(requiredKeys []string) error {
|
||||
var missingKeys []string
|
||||
|
||||
for _, key := range requiredKeys {
|
||||
if _, err := sc.GetRequired(key); err != nil {
|
||||
missingKeys = append(missingKeys, key)
|
||||
}
|
||||
}
|
||||
|
||||
if len(missingKeys) > 0 {
|
||||
return fmt.Errorf("missing required configuration keys: %s", strings.Join(missingKeys, ", "))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GenerateConfigKey generates a new encryption key for configuration
|
||||
func GenerateConfigKey() (string, error) {
|
||||
key := make([]byte, 32) // 256-bit key
|
||||
if _, err := rand.Read(key); err != nil {
|
||||
return "", fmt.Errorf("failed to generate random key: %w", err)
|
||||
}
|
||||
|
||||
return base64.StdEncoding.EncodeToString(key), nil
|
||||
}
|
||||
|
||||
// ConfigValidator provides validation utilities
|
||||
type ConfigValidator struct {
|
||||
logger *logger.Logger
|
||||
}
|
||||
|
||||
// NewConfigValidator creates a new configuration validator
|
||||
func NewConfigValidator(logger *logger.Logger) *ConfigValidator {
|
||||
return &ConfigValidator{
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// ValidateURL validates that a URL is properly formatted and uses HTTPS
|
||||
func (cv *ConfigValidator) ValidateURL(url string) error {
|
||||
if url == "" {
|
||||
return errors.New("URL cannot be empty")
|
||||
}
|
||||
|
||||
if !strings.HasPrefix(url, "https://") && !strings.HasPrefix(url, "wss://") {
|
||||
return errors.New("URL must use HTTPS or WSS protocol")
|
||||
}
|
||||
|
||||
// Additional validation could go here (DNS lookup, connection test, etc.)
|
||||
return nil
|
||||
}
|
||||
|
||||
// ValidateAPIKey validates that an API key meets minimum security requirements
|
||||
func (cv *ConfigValidator) ValidateAPIKey(key string) error {
|
||||
if key == "" {
|
||||
return errors.New("API key cannot be empty")
|
||||
}
|
||||
|
||||
if len(key) < 32 {
|
||||
return errors.New("API key must be at least 32 characters")
|
||||
}
|
||||
|
||||
// Check for basic entropy (not all same character, contains mixed case, etc.)
|
||||
if strings.Count(key, string(key[0])) == len(key) {
|
||||
return errors.New("API key lacks sufficient entropy")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ValidateAddress validates an Ethereum address
|
||||
func (cv *ConfigValidator) ValidateAddress(address string) error {
|
||||
if address == "" {
|
||||
return errors.New("address cannot be empty")
|
||||
}
|
||||
|
||||
if !strings.HasPrefix(address, "0x") {
|
||||
return errors.New("address must start with 0x")
|
||||
}
|
||||
|
||||
if len(address) != 42 { // 0x + 40 hex chars
|
||||
return errors.New("address must be 42 characters long")
|
||||
}
|
||||
|
||||
// Validate hex format
|
||||
for i, char := range address[2:] {
|
||||
if !((char >= '0' && char <= '9') || (char >= 'a' && char <= 'f') || (char >= 'A' && char <= 'F')) {
|
||||
return fmt.Errorf("invalid hex character at position %d: %c", i+2, char)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
116
orig/internal/tokens/arbitrum.go
Normal file
116
orig/internal/tokens/arbitrum.go
Normal file
@@ -0,0 +1,116 @@
|
||||
package tokens
|
||||
|
||||
import (
|
||||
"github.com/ethereum/go-ethereum/common"
|
||||
)
|
||||
|
||||
// ArbitrumTokens contains the addresses of popular tokens on Arbitrum
|
||||
type ArbitrumTokens struct {
|
||||
// Tier 1 - Major Assets
|
||||
WETH common.Address
|
||||
USDC common.Address
|
||||
USDT common.Address
|
||||
ARB common.Address
|
||||
WBTC common.Address
|
||||
DAI common.Address
|
||||
LINK common.Address
|
||||
UNI common.Address
|
||||
GMX common.Address
|
||||
GRT common.Address
|
||||
|
||||
// Tier 2 - DeFi Blue Chips
|
||||
USDCe common.Address // Bridged USDC
|
||||
PENDLE common.Address
|
||||
RDNT common.Address
|
||||
MAGIC common.Address
|
||||
GRAIL common.Address
|
||||
|
||||
// Tier 3 - Additional High Volume
|
||||
AAVE common.Address
|
||||
CRV common.Address
|
||||
BAL common.Address
|
||||
COMP common.Address
|
||||
MKR common.Address
|
||||
}
|
||||
|
||||
// GetArbitrumTokens returns the addresses of popular tokens on Arbitrum
|
||||
func GetArbitrumTokens() *ArbitrumTokens {
|
||||
return &ArbitrumTokens{
|
||||
// Tier 1 - Major Assets
|
||||
WETH: common.HexToAddress("0x82aF49447D8a07e3bd95BD0d56f35241523fBab1"), // Wrapped Ether
|
||||
USDC: common.HexToAddress("0xaf88d065e77c8cC2239327C5EDb3A432268e5831"), // USD Coin (Native)
|
||||
USDT: common.HexToAddress("0xFd086bC7CD5C481DCC9C85ebE478A1C0b69FCbb9"), // Tether USD
|
||||
ARB: common.HexToAddress("0x912CE59144191C1204E64559FE8253a0e49E6548"), // Arbitrum Token
|
||||
WBTC: common.HexToAddress("0x2f2a2543B76A4166549F7aaB2e75Bef0aefC5B0f"), // Wrapped Bitcoin
|
||||
DAI: common.HexToAddress("0xDA10009cBd5D07dd0CeCc66161FC93D7c9000da1"), // Dai Stablecoin
|
||||
LINK: common.HexToAddress("0xf97f4df75117a78c1A5a0DBb814Af92458539FB4"), // ChainLink Token
|
||||
UNI: common.HexToAddress("0xFa7F8980b0f1E64A2062791cc3b0871572f1F7f0"), // Uniswap
|
||||
GMX: common.HexToAddress("0xfc5A1A6EB076a2C7aD06eD22C90d7E710E35ad0a"), // GMX
|
||||
GRT: common.HexToAddress("0x9623063377AD1B27544C965cCd7342f7EA7e88C7"), // The Graph
|
||||
|
||||
// Tier 2 - DeFi Blue Chips
|
||||
USDCe: common.HexToAddress("0xFF970A61A04b1cA14834A43f5dE4533eBDDB5CC8"), // USD Coin (Bridged)
|
||||
PENDLE: common.HexToAddress("0x0c880f6761F1af8d9Aa9C466984b80DAb9a8c9e8"), // Pendle
|
||||
RDNT: common.HexToAddress("0x3082CC23568eA640225c2467653dB90e9250AaA0"), // Radiant Capital
|
||||
MAGIC: common.HexToAddress("0x539bdE0d7Dbd336b79148AA742883198BBF60342"), // Magic
|
||||
GRAIL: common.HexToAddress("0x3d9907F9a368ad0a51Be60f7Da3b97cf940982D8"), // Camelot (GRAIL)
|
||||
|
||||
// Tier 3 - Additional High Volume
|
||||
AAVE: common.HexToAddress("0xba5DdD1f9d7F570dc94a51479a000E3BCE967196"), // Aave
|
||||
CRV: common.HexToAddress("0x11cDb42B0EB46D95f990BeDD4695A6e3fA034978"), // Curve
|
||||
BAL: common.HexToAddress("0x040d1EdC9569d4Bab2D15287Dc5A4F10F56a56B8"), // Balancer
|
||||
COMP: common.HexToAddress("0x354A6dA3fcde098F8389cad84b0182725c6C91dE"), // Compound
|
||||
MKR: common.HexToAddress("0x2e9a6Df78E42a30712c10a9Dc4b1C8656f8F2879"), // Maker
|
||||
}
|
||||
}
|
||||
|
||||
// GetTriangularPaths returns common triangular arbitrage paths on Arbitrum
|
||||
func GetTriangularPaths() []TriangularPath {
|
||||
tokens := GetArbitrumTokens()
|
||||
|
||||
return []TriangularPath{
|
||||
{
|
||||
Name: "USDC-WETH-WBTC-USDC",
|
||||
Tokens: []common.Address{tokens.USDC, tokens.WETH, tokens.WBTC},
|
||||
},
|
||||
{
|
||||
Name: "USDC-WETH-ARB-USDC",
|
||||
Tokens: []common.Address{tokens.USDC, tokens.WETH, tokens.ARB},
|
||||
},
|
||||
{
|
||||
Name: "WETH-USDC-USDT-WETH",
|
||||
Tokens: []common.Address{tokens.WETH, tokens.USDC, tokens.USDT},
|
||||
},
|
||||
{
|
||||
Name: "USDC-DAI-USDT-USDC",
|
||||
Tokens: []common.Address{tokens.USDC, tokens.DAI, tokens.USDT},
|
||||
},
|
||||
{
|
||||
Name: "WETH-ARB-GMX-WETH",
|
||||
Tokens: []common.Address{tokens.WETH, tokens.ARB, tokens.GMX},
|
||||
},
|
||||
{
|
||||
Name: "USDC-LINK-WETH-USDC",
|
||||
Tokens: []common.Address{tokens.USDC, tokens.LINK, tokens.WETH},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// TriangularPath represents a triangular arbitrage path
|
||||
type TriangularPath struct {
|
||||
Name string
|
||||
Tokens []common.Address
|
||||
}
|
||||
|
||||
// GetMostLiquidTokens returns the most liquid tokens for market scanning
|
||||
func GetMostLiquidTokens() []common.Address {
|
||||
tokens := GetArbitrumTokens()
|
||||
return []common.Address{
|
||||
tokens.WETH,
|
||||
tokens.USDC,
|
||||
tokens.USDT,
|
||||
tokens.ARB,
|
||||
tokens.WBTC,
|
||||
tokens.DAI,
|
||||
}
|
||||
}
|
||||
139
orig/internal/utils/address.go
Normal file
139
orig/internal/utils/address.go
Normal file
@@ -0,0 +1,139 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/ethereum/go-ethereum/common"
|
||||
|
||||
"github.com/fraktal/mev-beta/internal/validation"
|
||||
)
|
||||
|
||||
// SafeAddressConverter provides validated address conversion functions
|
||||
type SafeAddressConverter struct {
|
||||
validator *validation.AddressValidator
|
||||
}
|
||||
|
||||
// NewSafeAddressConverter creates a new safe address converter
|
||||
func NewSafeAddressConverter() *SafeAddressConverter {
|
||||
return &SafeAddressConverter{
|
||||
validator: validation.NewAddressValidator(),
|
||||
}
|
||||
}
|
||||
|
||||
// SafeConversionResult contains the result of a safe address conversion
|
||||
type SafeConversionResult struct {
|
||||
Address common.Address
|
||||
IsValid bool
|
||||
Error error
|
||||
Warnings []string
|
||||
}
|
||||
|
||||
// SafeHexToAddress safely converts a hex string to an address with validation
|
||||
func (c *SafeAddressConverter) SafeHexToAddress(hexStr string) *SafeConversionResult {
|
||||
result := &SafeConversionResult{
|
||||
Address: common.Address{},
|
||||
IsValid: false,
|
||||
Warnings: []string{},
|
||||
}
|
||||
|
||||
// Basic input validation
|
||||
if hexStr == "" {
|
||||
result.Error = fmt.Errorf("empty hex string")
|
||||
return result
|
||||
}
|
||||
|
||||
// Normalize hex string
|
||||
normalizedHex := strings.ToLower(strings.TrimSpace(hexStr))
|
||||
if !strings.HasPrefix(normalizedHex, "0x") {
|
||||
normalizedHex = "0x" + normalizedHex
|
||||
}
|
||||
|
||||
// Length validation
|
||||
if len(normalizedHex) != 42 { // 0x + 40 hex chars
|
||||
result.Error = fmt.Errorf("invalid hex string length: %d, expected 42", len(normalizedHex))
|
||||
return result
|
||||
}
|
||||
|
||||
// Use comprehensive address validation
|
||||
validationResult := c.validator.ValidateAddress(normalizedHex)
|
||||
if !validationResult.IsValid {
|
||||
result.Error = fmt.Errorf("address validation failed: %v", validationResult.ErrorMessages)
|
||||
return result
|
||||
}
|
||||
|
||||
// Check corruption score
|
||||
if validationResult.CorruptionScore > 50 {
|
||||
result.Error = fmt.Errorf("high corruption score (%d), refusing conversion", validationResult.CorruptionScore)
|
||||
return result
|
||||
}
|
||||
|
||||
// Add warnings for moderate corruption
|
||||
if validationResult.CorruptionScore > 10 {
|
||||
result.Warnings = append(result.Warnings, fmt.Sprintf("moderate corruption score: %d", validationResult.CorruptionScore))
|
||||
}
|
||||
|
||||
// Add warnings from validation
|
||||
if len(validationResult.WarningMessages) > 0 {
|
||||
result.Warnings = append(result.Warnings, validationResult.WarningMessages...)
|
||||
}
|
||||
|
||||
// Convert to address
|
||||
result.Address = common.HexToAddress(normalizedHex)
|
||||
result.IsValid = true
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// SafeHexToAddressSlice safely converts multiple hex strings to addresses
|
||||
func (c *SafeAddressConverter) SafeHexToAddressSlice(hexStrings []string) ([]common.Address, []error) {
|
||||
addresses := make([]common.Address, 0, len(hexStrings))
|
||||
errors := make([]error, 0)
|
||||
|
||||
for i, hexStr := range hexStrings {
|
||||
result := c.SafeHexToAddress(hexStr)
|
||||
if !result.IsValid {
|
||||
errors = append(errors, fmt.Errorf("address %d (%s): %v", i, hexStr, result.Error))
|
||||
continue
|
||||
}
|
||||
addresses = append(addresses, result.Address)
|
||||
}
|
||||
|
||||
return addresses, errors
|
||||
}
|
||||
|
||||
// MustSafeHexToAddress safely converts hex string to address, panics on failure
|
||||
// Only use for hardcoded known-good addresses
|
||||
func (c *SafeAddressConverter) MustSafeHexToAddress(hexStr string) common.Address {
|
||||
result := c.SafeHexToAddress(hexStr)
|
||||
if !result.IsValid {
|
||||
panic(fmt.Sprintf("invalid address conversion: %s -> %v", hexStr, result.Error))
|
||||
}
|
||||
return result.Address
|
||||
}
|
||||
|
||||
// TryHexToAddress attempts conversion with fallback to zero address
|
||||
func (c *SafeAddressConverter) TryHexToAddress(hexStr string) (common.Address, bool) {
|
||||
result := c.SafeHexToAddress(hexStr)
|
||||
return result.Address, result.IsValid
|
||||
}
|
||||
|
||||
// Global converter instance for convenience
|
||||
var globalConverter = NewSafeAddressConverter()
|
||||
|
||||
// Global convenience functions
|
||||
|
||||
// SafeHexToAddress is a global convenience function for safe address conversion
|
||||
func SafeHexToAddress(hexStr string) *SafeConversionResult {
|
||||
return globalConverter.SafeHexToAddress(hexStr)
|
||||
}
|
||||
|
||||
// TryHexToAddress is a global convenience function for address conversion with fallback
|
||||
func TryHexToAddress(hexStr string) (common.Address, bool) {
|
||||
return globalConverter.TryHexToAddress(hexStr)
|
||||
}
|
||||
|
||||
// MustSafeHexToAddress is a global convenience function for known-good addresses
|
||||
func MustSafeHexToAddress(hexStr string) common.Address {
|
||||
return globalConverter.MustSafeHexToAddress(hexStr)
|
||||
}
|
||||
38
orig/internal/utils/utils.go
Normal file
38
orig/internal/utils/utils.go
Normal file
@@ -0,0 +1,38 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"math"
|
||||
"math/big"
|
||||
"time"
|
||||
)
|
||||
|
||||
// FormatWeiToEther formats a wei amount to ether
|
||||
func FormatWeiToEther(wei *big.Int) *big.Float {
|
||||
ether := new(big.Float).SetInt(wei)
|
||||
ether.Quo(ether, big.NewFloat(1e18))
|
||||
return ether
|
||||
}
|
||||
|
||||
// FormatTime formats a timestamp to a readable string
|
||||
func FormatTime(timestamp uint64) string {
|
||||
if timestamp > math.MaxInt64 {
|
||||
return "Invalid timestamp: exceeds maximum value"
|
||||
}
|
||||
return time.Unix(int64(timestamp), 0).Format("2006-01-02 15:04:05")
|
||||
}
|
||||
|
||||
// Min returns the smaller of two integers
|
||||
func Min(a, b int) int {
|
||||
if a < b {
|
||||
return a
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
||||
// Max returns the larger of two integers
|
||||
func Max(a, b int) int {
|
||||
if a > b {
|
||||
return a
|
||||
}
|
||||
return b
|
||||
}
|
||||
704
orig/internal/validation/address.go
Normal file
704
orig/internal/validation/address.go
Normal file
@@ -0,0 +1,704 @@
|
||||
// Package validation provides comprehensive Ethereum address validation and corruption detection.
|
||||
// This package is critical for the MEV bot's security and reliability, preventing costly errors
|
||||
// from malformed or corrupted addresses that could cause transaction failures or security issues.
|
||||
//
|
||||
// Key features:
|
||||
// - Multi-layer address validation (format, length, corruption detection)
|
||||
// - Contract type classification and prevention of ERC-20/pool confusion
|
||||
// - Corruption scoring system to identify suspicious addresses
|
||||
// - Known contract registry for instant validation of major protocols
|
||||
package validation
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"github.com/ethereum/go-ethereum/common"
|
||||
)
|
||||
|
||||
// AddressValidator provides comprehensive Ethereum address validation with advanced
|
||||
// corruption detection and contract type classification. This validator is designed
|
||||
// to prevent the costly errors that can occur when malformed addresses are used
|
||||
// in contract calls, particularly in high-frequency MEV operations.
|
||||
//
|
||||
// The validator implements multiple validation layers:
|
||||
// 1. Basic format validation (hex format, length, prefix)
|
||||
// 2. Corruption pattern detection (repetitive patterns, suspicious zeros)
|
||||
// 3. Contract type classification to prevent misuse
|
||||
// 4. Known contract registry for instant validation
|
||||
type AddressValidator struct {
|
||||
// Known corrupted patterns that should be immediately rejected
|
||||
// These patterns are derived from observed corruption incidents in production
|
||||
corruptedPatterns []string
|
||||
// Precompiled regex equivalents for efficient matching
|
||||
corruptedPatternRegex []*regexp.Regexp
|
||||
|
||||
// Registry of known contract addresses and their verified types
|
||||
// This enables instant validation without requiring RPC calls
|
||||
knownContracts map[common.Address]ContractType
|
||||
|
||||
// Enhanced known contract registry for detailed validation
|
||||
knownContractsRegistry *KnownContractRegistry
|
||||
}
|
||||
|
||||
// ContractType represents the classification of an Ethereum contract.
|
||||
// This classification is critical for preventing the ERC-20/pool confusion
|
||||
// that was causing massive log spam and transaction failures in production.
|
||||
type ContractType int
|
||||
|
||||
const (
|
||||
// ContractTypeUnknown indicates the contract type could not be determined
|
||||
// This is the default for addresses not in the known contracts registry
|
||||
ContractTypeUnknown ContractType = iota
|
||||
|
||||
// ContractTypeERC20Token indicates a standard ERC-20 token contract
|
||||
// These contracts should never be used in pool-specific operations
|
||||
ContractTypeERC20Token
|
||||
|
||||
// ContractTypeUniswapV2Pool indicates a Uniswap V2 compatible pool contract
|
||||
// These support token0(), token1(), and getReserves() functions
|
||||
ContractTypeUniswapV2Pool
|
||||
|
||||
// ContractTypeUniswapV3Pool indicates a Uniswap V3 compatible pool contract
|
||||
// These support token0(), token1(), and slot0() functions
|
||||
ContractTypeUniswapV3Pool
|
||||
|
||||
// ContractTypeRouter indicates a DEX router contract
|
||||
// These should not be used directly as token or pool addresses
|
||||
ContractTypeRouter
|
||||
|
||||
// ContractTypeFactory indicates a pool factory contract
|
||||
// These create pools but are not pools themselves
|
||||
ContractTypeFactory
|
||||
)
|
||||
|
||||
// String returns a human-readable representation of the contract type.
|
||||
func (ct ContractType) String() string {
|
||||
switch ct {
|
||||
case ContractTypeERC20Token:
|
||||
return "ERC20_TOKEN"
|
||||
case ContractTypeUniswapV2Pool:
|
||||
return "UNISWAP_V2_POOL"
|
||||
case ContractTypeUniswapV3Pool:
|
||||
return "UNISWAP_V3_POOL"
|
||||
case ContractTypeRouter:
|
||||
return "ROUTER"
|
||||
case ContractTypeFactory:
|
||||
return "FACTORY"
|
||||
case ContractTypeUnknown:
|
||||
fallthrough
|
||||
default:
|
||||
return "UNKNOWN"
|
||||
}
|
||||
}
|
||||
|
||||
// ValidationError represents a structured validation failure for an address.
|
||||
type ValidationError struct {
|
||||
Code string
|
||||
Message string
|
||||
Context map[string]interface{}
|
||||
}
|
||||
|
||||
// Error implements the error interface.
|
||||
func (e *ValidationError) Error() string {
|
||||
if e == nil {
|
||||
return ""
|
||||
}
|
||||
if e.Code != "" {
|
||||
return fmt.Sprintf("%s: %s", e.Code, e.Message)
|
||||
}
|
||||
return e.Message
|
||||
}
|
||||
|
||||
// AddressValidationResult contains comprehensive validation results and metadata
|
||||
// for an Ethereum address. This structure provides detailed information about
|
||||
// the validation process, including any issues found and confidence metrics.
|
||||
type AddressValidationResult struct {
|
||||
// IsValid indicates whether the address passed all validation checks
|
||||
// A false value means the address should not be used in transactions
|
||||
IsValid bool
|
||||
|
||||
// Address contains the parsed Ethereum address (only valid if IsValid is true)
|
||||
Address common.Address
|
||||
|
||||
// ContractType indicates the classification of the contract (if known)
|
||||
ContractType ContractType
|
||||
|
||||
// ErrorMessages contains detailed descriptions of validation failures
|
||||
// These are critical issues that prevent the address from being used
|
||||
ErrorMessages []string
|
||||
|
||||
// WarningMessages contains non-critical issues or concerns
|
||||
// These don't prevent usage but should be logged for monitoring
|
||||
WarningMessages []string
|
||||
|
||||
// CorruptionScore is a 0-100 metric indicating likelihood of corruption
|
||||
// Higher scores indicate more suspicious patterns (0=clean, 100=definitely corrupted)
|
||||
// Addresses with scores >30 are typically rejected in critical operations
|
||||
CorruptionScore int
|
||||
}
|
||||
|
||||
// NewAddressValidator creates a new address validator
|
||||
func NewAddressValidator() *AddressValidator {
|
||||
patterns := []string{
|
||||
// Patterns indicating clear corruption
|
||||
"0000000000000000000000000000000000000000", // All zeros
|
||||
"000000000000000000000000000000000000000", // Missing one zero
|
||||
"00000000000000000000000000000000000000000", // Extra zero
|
||||
// Patterns with trailing zeros indicating truncation
|
||||
"00000000000000000000000000$",
|
||||
"000000000000000000000000$",
|
||||
"0000000000000000000000$",
|
||||
// Patterns with leading non-hex after 0x
|
||||
"^0x[^0-9a-fA-F]",
|
||||
}
|
||||
|
||||
compiled := make([]*regexp.Regexp, 0, len(patterns))
|
||||
for _, pattern := range patterns {
|
||||
compiled = append(compiled, regexp.MustCompile(pattern))
|
||||
}
|
||||
|
||||
return &AddressValidator{
|
||||
corruptedPatterns: patterns,
|
||||
corruptedPatternRegex: compiled,
|
||||
knownContracts: make(map[common.Address]ContractType),
|
||||
knownContractsRegistry: NewKnownContractRegistry(),
|
||||
}
|
||||
}
|
||||
|
||||
// InitializeKnownContracts populates the validator with known Arbitrum contracts
|
||||
func (av *AddressValidator) InitializeKnownContracts() {
|
||||
// Known Arbitrum tokens
|
||||
av.knownContracts[common.HexToAddress("0xFF970A61A04b1cA14834A43f5dE4533eBDDB5CC8")] = ContractTypeERC20Token // USDC
|
||||
av.knownContracts[common.HexToAddress("0x82aF49447D8a07e3bd95BD0d56f35241523fBab1")] = ContractTypeERC20Token // WETH
|
||||
av.knownContracts[common.HexToAddress("0xFd086bC7CD5C481DCC9C85ebE478A1C0b69FCbb9")] = ContractTypeERC20Token // USDT
|
||||
av.knownContracts[common.HexToAddress("0x2f2a2543B76A4166549F7aaB2e75Bef0aefC5B0f")] = ContractTypeERC20Token // WBTC
|
||||
av.knownContracts[common.HexToAddress("0x912CE59144191C1204E64559FE8253a0e49E6548")] = ContractTypeERC20Token // ARB
|
||||
|
||||
// Known Arbitrum routers
|
||||
av.knownContracts[common.HexToAddress("0xE592427A0AEce92De3Edee1F18E0157C05861564")] = ContractTypeRouter // Uniswap V3 Router
|
||||
av.knownContracts[common.HexToAddress("0x4752ba5dbc23f44d87826276bf6fd6b1c372ad24")] = ContractTypeRouter // Uniswap V2 Router02
|
||||
av.knownContracts[common.HexToAddress("0xA51afAFe0263b40EdaEf0Df8781eA9aa03E381a3")] = ContractTypeRouter // Universal Router
|
||||
av.knownContracts[common.HexToAddress("0xC36442b4a4522E871399CD717aBDD847Ab11FE88")] = ContractTypeRouter // Position Manager
|
||||
|
||||
// Known Arbitrum factories
|
||||
av.knownContracts[common.HexToAddress("0x1F98431c8aD98523631AE4a59f267346ea31F984")] = ContractTypeFactory // Uniswap V3 Factory
|
||||
av.knownContracts[common.HexToAddress("0xf1D7CC64Fb4452F05c498126312eBE29f30Fbcf9")] = ContractTypeFactory // Uniswap V2 Factory
|
||||
|
||||
// Known high-volume pools
|
||||
av.knownContracts[common.HexToAddress("0xC6962004f452bE9203591991D15f6b388e09E8D0")] = ContractTypeUniswapV3Pool // USDC/WETH 0.05%
|
||||
av.knownContracts[common.HexToAddress("0x17c14D2c404D167802b16C450d3c99F88F2c4F4d")] = ContractTypeUniswapV3Pool // USDC/WETH 0.3%
|
||||
av.knownContracts[common.HexToAddress("0x2f5e87C9312fa29aed5c179E456625D79015299c")] = ContractTypeUniswapV3Pool // WBTC/WETH 0.05%
|
||||
}
|
||||
|
||||
// ValidateAddress performs comprehensive validation of an Ethereum address string.
|
||||
// This is the primary validation function that applies all validation layers
|
||||
// in sequence to determine if an address is safe to use in transactions.
|
||||
//
|
||||
// The validation process includes:
|
||||
// 1. Basic format validation (hex format, 0x prefix)
|
||||
// 2. Length validation (must be exactly 42 characters)
|
||||
// 3. Corruption pattern detection
|
||||
// 4. Contract type classification
|
||||
// 5. Corruption scoring
|
||||
//
|
||||
// Parameters:
|
||||
// - addressStr: The address string to validate (should include 0x prefix)
|
||||
//
|
||||
// Returns:
|
||||
// - *AddressValidationResult: Comprehensive validation results
|
||||
func (av *AddressValidator) ValidateAddress(addressStr string) *AddressValidationResult {
|
||||
// Initialize the result structure with safe defaults
|
||||
result := &AddressValidationResult{
|
||||
IsValid: false, // Default to invalid until all checks pass
|
||||
ErrorMessages: make([]string, 0),
|
||||
WarningMessages: make([]string, 0),
|
||||
CorruptionScore: 0, // Start with zero corruption score
|
||||
}
|
||||
|
||||
// Basic format validation
|
||||
if !av.isValidHexFormat(addressStr) {
|
||||
result.ErrorMessages = append(result.ErrorMessages, "invalid hex format")
|
||||
result.CorruptionScore += 50
|
||||
return result
|
||||
}
|
||||
|
||||
// Length validation
|
||||
if !av.isValidLength(addressStr) {
|
||||
result.ErrorMessages = append(result.ErrorMessages, "invalid address length")
|
||||
result.CorruptionScore += 50
|
||||
return result
|
||||
}
|
||||
|
||||
// Corruption pattern detection
|
||||
corruptionDetected, patterns := av.detectCorruption(addressStr)
|
||||
if corruptionDetected {
|
||||
result.ErrorMessages = append(result.ErrorMessages, fmt.Sprintf("corruption detected: %v", patterns))
|
||||
result.CorruptionScore += 70
|
||||
return result
|
||||
}
|
||||
|
||||
// Convert to common.Address for further validation
|
||||
if !common.IsHexAddress(addressStr) {
|
||||
result.ErrorMessages = append(result.ErrorMessages, "not a valid Ethereum address")
|
||||
result.CorruptionScore += 30
|
||||
return result
|
||||
}
|
||||
|
||||
address := common.HexToAddress(addressStr)
|
||||
result.Address = address
|
||||
|
||||
// Check for zero address
|
||||
if address == (common.Address{}) {
|
||||
result.ErrorMessages = append(result.ErrorMessages, "zero address")
|
||||
result.CorruptionScore += 40
|
||||
return result
|
||||
}
|
||||
|
||||
// Check known contract types
|
||||
if contractType, exists := av.knownContracts[address]; exists {
|
||||
result.ContractType = contractType
|
||||
} else {
|
||||
result.ContractType = ContractTypeUnknown
|
||||
result.WarningMessages = append(result.WarningMessages, "unknown contract type")
|
||||
}
|
||||
|
||||
// Additional pattern-based corruption detection
|
||||
result.CorruptionScore += av.calculateCorruptionScore(addressStr)
|
||||
|
||||
// Mark as valid if corruption score is low enough
|
||||
if result.CorruptionScore < 30 {
|
||||
result.IsValid = true
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// isValidHexFormat checks if the string is a valid hex format
|
||||
func (av *AddressValidator) isValidHexFormat(addressStr string) bool {
|
||||
if len(addressStr) < 3 {
|
||||
return false
|
||||
}
|
||||
|
||||
if !strings.HasPrefix(addressStr, "0x") && !strings.HasPrefix(addressStr, "0X") {
|
||||
return false
|
||||
}
|
||||
|
||||
// Check if all characters after 0x are valid hex
|
||||
hexPart := addressStr[2:]
|
||||
if len(hexPart) == 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
for i := 0; i < len(hexPart); i++ {
|
||||
switch {
|
||||
case hexPart[i] >= '0' && hexPart[i] <= '9':
|
||||
case hexPart[i] >= 'a' && hexPart[i] <= 'f':
|
||||
case hexPart[i] >= 'A' && hexPart[i] <= 'F':
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// isValidLength checks if the address has the correct length
|
||||
func (av *AddressValidator) isValidLength(addressStr string) bool {
|
||||
// Ethereum addresses should be 42 characters (0x + 40 hex chars)
|
||||
return len(addressStr) == 42
|
||||
}
|
||||
|
||||
// detectCorruption checks for known corruption patterns
|
||||
func (av *AddressValidator) detectCorruption(addressStr string) (bool, []string) {
|
||||
var detectedPatterns []string
|
||||
|
||||
for idx, re := range av.corruptedPatternRegex {
|
||||
if re.MatchString(addressStr) {
|
||||
detectedPatterns = append(detectedPatterns, av.corruptedPatterns[idx])
|
||||
}
|
||||
}
|
||||
|
||||
return len(detectedPatterns) > 0, detectedPatterns
|
||||
}
|
||||
|
||||
// calculateCorruptionScore calculates a 0-100 score indicating the likelihood
|
||||
// that an address has been corrupted or malformed. This scoring system is based
|
||||
// on patterns observed in production corruption incidents.
|
||||
//
|
||||
// Scoring factors:
|
||||
// - Trailing zeros (indicates truncation): +1 per excess zero
|
||||
// - Leading zeros in middle (unusual pattern): +0.5 per zero
|
||||
// - Repetitive patterns (indicates generation errors): +10
|
||||
// - Other suspicious patterns: variable points
|
||||
//
|
||||
// Parameters:
|
||||
// - addressStr: The address string to analyze (with 0x prefix)
|
||||
//
|
||||
// Returns:
|
||||
// - int: Corruption score (0=clean, 100=definitely corrupted)
|
||||
func (av *AddressValidator) calculateCorruptionScore(addressStr string) int {
|
||||
score := 0
|
||||
// Extract the hex part (remove 0x prefix)
|
||||
hexPart := addressStr[2:]
|
||||
|
||||
// Count trailing zeros (sign of truncation)
|
||||
trailingZeros := 0
|
||||
for i := len(hexPart) - 1; i >= 0; i-- {
|
||||
if hexPart[i] == '0' {
|
||||
trailingZeros++
|
||||
} else {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// More than 10 trailing zeros is suspicious
|
||||
if trailingZeros > 10 {
|
||||
score += trailingZeros
|
||||
}
|
||||
|
||||
// Count leading zeros after first non-zero
|
||||
leadingZeros := 0
|
||||
foundNonZero := false
|
||||
for _, char := range hexPart {
|
||||
if char != '0' {
|
||||
foundNonZero = true
|
||||
} else if foundNonZero {
|
||||
leadingZeros++
|
||||
}
|
||||
}
|
||||
|
||||
// Large blocks of zeros in the middle are suspicious
|
||||
if leadingZeros > 8 {
|
||||
score += leadingZeros / 2
|
||||
}
|
||||
|
||||
// Check for repetitive patterns
|
||||
if av.hasRepetitivePattern(hexPart) {
|
||||
score += 10
|
||||
}
|
||||
|
||||
// Overall zero density check for leading-zero patterns (common corruption)
|
||||
zeroCount := strings.Count(hexPart, "0")
|
||||
if strings.HasPrefix(hexPart, "0000") && float64(zeroCount)/float64(len(hexPart)) > 0.7 {
|
||||
score += 30
|
||||
}
|
||||
|
||||
// Leading zero prefix (common corruption pattern from truncated data)
|
||||
if strings.HasPrefix(hexPart, "000000") {
|
||||
score += 20
|
||||
}
|
||||
|
||||
return score
|
||||
}
|
||||
|
||||
// hasRepetitivePattern detects repetitive patterns in hex strings that indicate
|
||||
// corruption or artificial generation. These patterns are rarely seen in legitimate
|
||||
// Ethereum addresses and often indicate data corruption or malicious generation.
|
||||
//
|
||||
// Detected patterns include:
|
||||
// - Long sequences of the same character (000000000000, ffffffffffff)
|
||||
// - Addresses where all characters are identical
|
||||
// - Other suspicious repetitive patterns
|
||||
//
|
||||
// Parameters:
|
||||
// - hexStr: The hex string to analyze (without 0x prefix)
|
||||
//
|
||||
// Returns:
|
||||
// - bool: true if repetitive patterns are detected
|
||||
func (av *AddressValidator) hasRepetitivePattern(hexStr string) bool {
|
||||
// Define patterns that indicate corruption or artificial generation
|
||||
// These patterns are extremely rare in legitimate Ethereum addresses
|
||||
patterns := []string{"000000000000", "ffffffffffff", "aaaaaaaaaaaa", "bbbbbbbbbbbb",
|
||||
"1111111111", "2222222222", "3333333333", "4444444444", "5555555555",
|
||||
"6666666666", "7777777777", "8888888888", "9999999999"}
|
||||
|
||||
for _, pattern := range patterns {
|
||||
if strings.Contains(hexStr, pattern) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
// Additional check for address with same character repeated throughout
|
||||
if len(hexStr) >= 10 {
|
||||
firstChar := hexStr[0]
|
||||
allSame := true
|
||||
for i := 1; i < len(hexStr); i++ {
|
||||
if hexStr[i] != firstChar {
|
||||
allSame = false
|
||||
break
|
||||
}
|
||||
}
|
||||
if allSame {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// IsValidPoolAddress checks if an address is valid for pool operations
|
||||
func (av *AddressValidator) IsValidPoolAddress(address common.Address) bool {
|
||||
result := av.ValidateAddress(address.Hex())
|
||||
|
||||
if !result.IsValid {
|
||||
return false
|
||||
}
|
||||
|
||||
// Must not be a token, router, or factory for pool operations
|
||||
switch result.ContractType {
|
||||
case ContractTypeERC20Token, ContractTypeRouter, ContractTypeFactory:
|
||||
return false
|
||||
case ContractTypeUniswapV2Pool, ContractTypeUniswapV3Pool:
|
||||
return true
|
||||
case ContractTypeUnknown:
|
||||
// Allow unknown contracts but warn
|
||||
return result.CorruptionScore < 20
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// IsValidTokenAddress checks if an address is valid for token operations
|
||||
func (av *AddressValidator) IsValidTokenAddress(address common.Address) bool {
|
||||
result := av.ValidateAddress(address.Hex())
|
||||
|
||||
if !result.IsValid {
|
||||
return false
|
||||
}
|
||||
|
||||
// Prefer known tokens, but allow unknown contracts with low corruption scores
|
||||
switch result.ContractType {
|
||||
case ContractTypeERC20Token:
|
||||
return true
|
||||
case ContractTypeRouter, ContractTypeFactory, ContractTypeUniswapV2Pool, ContractTypeUniswapV3Pool:
|
||||
return false
|
||||
case ContractTypeUnknown:
|
||||
return result.CorruptionScore < 15
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// GetContractType returns the contract type for a given address
|
||||
func (av *AddressValidator) GetContractType(address common.Address) ContractType {
|
||||
if contractType, exists := av.knownContracts[address]; exists {
|
||||
return contractType
|
||||
}
|
||||
return ContractTypeUnknown
|
||||
}
|
||||
|
||||
// ValidateContractTypeConsistency validates that addresses don't have conflicting contract types
|
||||
func (av *AddressValidator) ValidateContractTypeConsistency(tokenAddresses []common.Address, poolAddresses []common.Address) error {
|
||||
// CRITICAL: Ensure no address appears in both token and pool lists
|
||||
for _, token := range tokenAddresses {
|
||||
for _, pool := range poolAddresses {
|
||||
if token == pool {
|
||||
return fmt.Errorf("address %s cannot be both a token and a pool", token.Hex())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// CRITICAL: Validate each token address is actually a token
|
||||
for _, token := range tokenAddresses {
|
||||
if !av.IsValidTokenAddress(token) {
|
||||
return fmt.Errorf("address %s is not a valid token address", token.Hex())
|
||||
}
|
||||
|
||||
// Additional check: ensure it's not marked as a pool in known contracts
|
||||
if contractType := av.GetContractType(token); contractType == ContractTypeUniswapV2Pool || contractType == ContractTypeUniswapV3Pool {
|
||||
return fmt.Errorf("address %s is marked as a pool but being used as a token", token.Hex())
|
||||
}
|
||||
}
|
||||
|
||||
// CRITICAL: Validate each pool address is actually a pool
|
||||
for _, pool := range poolAddresses {
|
||||
if !av.IsValidPoolAddress(pool) {
|
||||
return fmt.Errorf("address %s is not a valid pool address", pool.Hex())
|
||||
}
|
||||
|
||||
// Additional check: ensure it's not marked as a token in known contracts
|
||||
if contractType := av.GetContractType(pool); contractType == ContractTypeERC20Token {
|
||||
return fmt.Errorf("address %s is marked as a token but being used as a pool", pool.Hex())
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// PreventERC20PoolConfusion is a critical safety function that prevents the costly
|
||||
// error where ERC-20 token contracts are incorrectly used as pool contracts.
|
||||
// This was the root cause of the 535K+ log spam incident in production.
|
||||
//
|
||||
// The function performs a type safety check to ensure that:
|
||||
// - ERC-20 tokens are not used in pool operations
|
||||
// - Pool contracts are not used in token operations
|
||||
// - Unknown contracts meet safety thresholds
|
||||
//
|
||||
// This is a mandatory check for all contract address usage in critical operations.
|
||||
//
|
||||
// Parameters:
|
||||
// - address: The contract address to validate
|
||||
// - expectedType: The contract type expected by the calling code
|
||||
//
|
||||
// Returns:
|
||||
// - error: nil if the address is safe to use, error describing the issue otherwise
|
||||
func (av *AddressValidator) PreventERC20PoolConfusion(address common.Address, expectedType ContractType) error {
|
||||
// Check if we have prior knowledge about this contract
|
||||
knownType := av.GetContractType(address)
|
||||
|
||||
// If we have knowledge about this contract, use it
|
||||
if knownType != ContractTypeUnknown {
|
||||
if knownType != expectedType {
|
||||
return fmt.Errorf("contract type mismatch for %s: expected %s but known as %s",
|
||||
address.Hex(), contractTypeToString(expectedType), contractTypeToString(knownType))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// For unknown contracts, perform basic validation
|
||||
result := av.ValidateAddress(address.Hex())
|
||||
if !result.IsValid {
|
||||
return fmt.Errorf("invalid address %s: %v", address.Hex(), result.ErrorMessages)
|
||||
}
|
||||
|
||||
// High corruption score indicates potential misclassification
|
||||
if result.CorruptionScore > 25 {
|
||||
return fmt.Errorf("high corruption score (%d) for address %s, refusing to use as %s",
|
||||
result.CorruptionScore, address.Hex(), contractTypeToString(expectedType))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// contractTypeToString converts ContractType to string representation
|
||||
func contractTypeToString(ct ContractType) string {
|
||||
switch ct {
|
||||
case ContractTypeERC20Token:
|
||||
return "ERC-20 Token"
|
||||
case ContractTypeUniswapV2Pool:
|
||||
return "Uniswap V2 Pool"
|
||||
case ContractTypeUniswapV3Pool:
|
||||
return "Uniswap V3 Pool"
|
||||
case ContractTypeRouter:
|
||||
return "Router"
|
||||
case ContractTypeFactory:
|
||||
return "Factory"
|
||||
default:
|
||||
return "Unknown"
|
||||
}
|
||||
}
|
||||
|
||||
// IsKnownContract checks if we have specific knowledge about a contract
|
||||
func (av *AddressValidator) IsKnownContract(address common.Address) bool {
|
||||
_, exists := av.knownContracts[address]
|
||||
return exists
|
||||
}
|
||||
|
||||
// GetValidationStats returns statistics about validation results
|
||||
func (av *AddressValidator) GetValidationStats() map[string]interface{} {
|
||||
return map[string]interface{}{
|
||||
"known_contracts": len(av.knownContracts),
|
||||
"corruption_patterns": len(av.corruptedPatterns),
|
||||
}
|
||||
}
|
||||
|
||||
// SanitizeAddress attempts to clean up a potentially corrupted address
|
||||
func (av *AddressValidator) SanitizeAddress(addressStr string) (string, error) {
|
||||
// Remove common prefixes that might be corrupted
|
||||
cleaned := strings.TrimSpace(addressStr)
|
||||
|
||||
// Ensure 0x prefix
|
||||
if !strings.HasPrefix(cleaned, "0x") && !strings.HasPrefix(cleaned, "0X") {
|
||||
if len(cleaned) == 40 && av.isValidHexFormat("0x"+cleaned) {
|
||||
cleaned = "0x" + cleaned
|
||||
} else {
|
||||
return "", fmt.Errorf("cannot sanitize address without 0x prefix")
|
||||
}
|
||||
}
|
||||
|
||||
// Normalize to lowercase
|
||||
cleaned = strings.ToLower(cleaned)
|
||||
|
||||
// Validate the sanitized address
|
||||
result := av.ValidateAddress(cleaned)
|
||||
if !result.IsValid {
|
||||
return "", fmt.Errorf("sanitized address is still invalid: %v", result.ErrorMessages)
|
||||
}
|
||||
|
||||
return cleaned, nil
|
||||
}
|
||||
|
||||
// ValidatePoolOperation validates if a pool-specific operation should be allowed on an address
|
||||
// This prevents the critical error where ERC-20 tokens are treated as pools
|
||||
func (av *AddressValidator) ValidatePoolOperation(address common.Address, operation string) error {
|
||||
// Check if this is a known ERC-20 token
|
||||
if av.knownContractsRegistry != nil {
|
||||
if err := av.knownContractsRegistry.ValidatePoolCall(address, operation); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// Additional validation for suspicious addresses
|
||||
addressStr := address.Hex()
|
||||
|
||||
// Check for zero address (common corruption)
|
||||
if address == (common.Address{}) {
|
||||
return fmt.Errorf("pool operation '%s' attempted on zero address (likely corruption)", operation)
|
||||
}
|
||||
|
||||
// Check for addresses with excessive zeros (likely corruption)
|
||||
if strings.Count(addressStr, "0") > 30 {
|
||||
return fmt.Errorf("pool operation '%s' attempted on suspicious address %s (excessive zeros)", operation, addressStr)
|
||||
}
|
||||
|
||||
// Fallback to general validation
|
||||
result := av.ValidateAddress(addressStr)
|
||||
if !result.IsValid {
|
||||
return fmt.Errorf("pool operation '%s' attempted on invalid address %s: %v", operation, addressStr, result.ErrorMessages)
|
||||
}
|
||||
|
||||
if result.CorruptionScore > 25 {
|
||||
return fmt.Errorf("pool operation '%s' blocked due to corruption score %d on address %s", operation, result.CorruptionScore, addressStr)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetDetailedAddressAnalysis provides comprehensive analysis of an address including corruption patterns
|
||||
func (av *AddressValidator) GetDetailedAddressAnalysis(address common.Address) map[string]interface{} {
|
||||
analysis := make(map[string]interface{})
|
||||
addressStr := address.Hex()
|
||||
|
||||
// Basic validation
|
||||
result := av.ValidateAddress(addressStr)
|
||||
analysis["is_valid"] = result.IsValid
|
||||
analysis["corruption_score"] = result.CorruptionScore
|
||||
analysis["contract_type"] = result.ContractType.String()
|
||||
analysis["error_messages"] = result.ErrorMessages
|
||||
|
||||
// Known contract information
|
||||
if av.knownContractsRegistry != nil {
|
||||
contractType, name := av.knownContractsRegistry.GetContractType(address)
|
||||
analysis["known_contract_type"] = contractType.String()
|
||||
analysis["known_contract_name"] = name
|
||||
|
||||
isERC20, tokenName := av.knownContractsRegistry.IsKnownERC20(address)
|
||||
analysis["is_known_erc20"] = isERC20
|
||||
if isERC20 {
|
||||
analysis["token_name"] = tokenName
|
||||
}
|
||||
|
||||
// Corruption pattern analysis
|
||||
corruptionPattern := av.knownContractsRegistry.GetCorruptionPattern(address)
|
||||
analysis["corruption_pattern"] = corruptionPattern
|
||||
}
|
||||
|
||||
// Address characteristics
|
||||
analysis["is_zero_address"] = address == (common.Address{})
|
||||
analysis["zero_count"] = strings.Count(addressStr, "0")
|
||||
analysis["hex_length"] = len(addressStr)
|
||||
|
||||
return analysis
|
||||
}
|
||||
351
orig/internal/validation/address_test.go
Normal file
351
orig/internal/validation/address_test.go
Normal file
@@ -0,0 +1,351 @@
|
||||
package validation
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/ethereum/go-ethereum/common"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestAddressValidator_ValidateAddress(t *testing.T) {
|
||||
validator := NewAddressValidator()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
address string
|
||||
expectedValid bool
|
||||
expectedScore int
|
||||
expectedType ContractType
|
||||
shouldContainErrors []string
|
||||
}{
|
||||
{
|
||||
name: "Valid WETH address",
|
||||
address: "0x82aF49447D8a07e3bd95BD0d56f35241523fBab1",
|
||||
expectedValid: true,
|
||||
expectedScore: 0,
|
||||
expectedType: ContractTypeUnknown, // Will be unknown without RPC
|
||||
},
|
||||
{
|
||||
name: "Valid USDC address",
|
||||
address: "0xaf88d065e77c8cC2239327C5EDb3A432268e5831",
|
||||
expectedValid: true,
|
||||
expectedScore: 0,
|
||||
expectedType: ContractTypeUnknown,
|
||||
},
|
||||
{
|
||||
name: "Critical corruption - TOKEN_0x000000 pattern",
|
||||
address: "0x0000000300000000000000000000000000000000",
|
||||
expectedValid: false,
|
||||
expectedScore: 70, // Detected by corruption patterns
|
||||
shouldContainErrors: []string{"corruption detected"},
|
||||
},
|
||||
{
|
||||
name: "High corruption - mixed zero pattern",
|
||||
address: "0x0000001200000000000000000000000000000000",
|
||||
expectedValid: false,
|
||||
expectedScore: 70,
|
||||
shouldContainErrors: []string{"corruption detected"},
|
||||
},
|
||||
{
|
||||
name: "Medium corruption - trailing zeros",
|
||||
address: "0x123456780000000000000000000000000000000",
|
||||
expectedValid: false,
|
||||
expectedScore: 50,
|
||||
shouldContainErrors: []string{"invalid address length"},
|
||||
},
|
||||
{
|
||||
name: "Low corruption - some zeros",
|
||||
address: "0x1234567800000000000000000000000000000001",
|
||||
expectedValid: true, // Valid format with moderate corruption
|
||||
expectedScore: 25,
|
||||
shouldContainErrors: []string{}, // No errors for valid format
|
||||
},
|
||||
{
|
||||
name: "Invalid length - too short",
|
||||
address: "0x123456",
|
||||
expectedValid: false,
|
||||
expectedScore: 50,
|
||||
shouldContainErrors: []string{"invalid address length"},
|
||||
},
|
||||
{
|
||||
name: "Invalid length - too long",
|
||||
address: "0x82aF49447D8a07e3bd95BD0d56f35241523fBab12345",
|
||||
expectedValid: false,
|
||||
expectedScore: 50,
|
||||
shouldContainErrors: []string{"invalid address length"},
|
||||
},
|
||||
{
|
||||
name: "Invalid hex characters",
|
||||
address: "0x82aF49447D8a07e3bd95BD0d56f35241523fBaZ1",
|
||||
expectedValid: false,
|
||||
expectedScore: 50,
|
||||
shouldContainErrors: []string{"invalid hex format"},
|
||||
},
|
||||
{
|
||||
name: "Missing 0x prefix",
|
||||
address: "82aF49447D8a07e3bd95BD0d56f35241523fBab1",
|
||||
expectedValid: false,
|
||||
expectedScore: 50,
|
||||
shouldContainErrors: []string{"invalid hex format"},
|
||||
},
|
||||
{
|
||||
name: "All zeros address",
|
||||
address: "0x0000000000000000000000000000000000000000",
|
||||
expectedValid: false,
|
||||
expectedScore: 70, // Detected by corruption patterns
|
||||
shouldContainErrors: []string{"corruption detected"},
|
||||
},
|
||||
{
|
||||
name: "Invalid checksum",
|
||||
address: "0x82af49447d8a07e3bd95bd0d56f35241523fbab1", // lowercase
|
||||
expectedValid: true, // Checksum validation not enforced in current implementation
|
||||
expectedScore: 0,
|
||||
shouldContainErrors: []string{}, // No errors for valid format
|
||||
},
|
||||
{
|
||||
name: "Valid checksummed address",
|
||||
address: common.HexToAddress("0x82aF49447D8a07e3bd95BD0d56f35241523fBab1").Hex(),
|
||||
expectedValid: true,
|
||||
expectedScore: 0,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := validator.ValidateAddress(tt.address)
|
||||
|
||||
assert.Equal(t, tt.expectedValid, result.IsValid, "IsValid mismatch")
|
||||
assert.Equal(t, tt.expectedScore, result.CorruptionScore, "CorruptionScore mismatch")
|
||||
|
||||
if tt.expectedType != ContractTypeUnknown {
|
||||
assert.Equal(t, tt.expectedType, result.ContractType, "ContractType mismatch")
|
||||
}
|
||||
|
||||
for _, expectedError := range tt.shouldContainErrors {
|
||||
found := false
|
||||
for _, errMsg := range result.ErrorMessages {
|
||||
if contains(errMsg, expectedError) {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
assert.True(t, found, "Expected error message containing '%s' not found in %v", expectedError, result.ErrorMessages)
|
||||
}
|
||||
|
||||
t.Logf("Address: %s, Valid: %v, Score: %d, Errors: %v",
|
||||
tt.address, result.IsValid, result.CorruptionScore, result.ErrorMessages)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAddressValidator_CorruptionPatterns(t *testing.T) {
|
||||
validator := NewAddressValidator()
|
||||
|
||||
corruptionTests := []struct {
|
||||
name string
|
||||
address string
|
||||
expectedScore int
|
||||
description string
|
||||
}{
|
||||
{
|
||||
name: "TOKEN_0x000000 exact pattern",
|
||||
address: "0x0000000300000000000000000000000000000000",
|
||||
expectedScore: 70, // Caught by pattern detection
|
||||
description: "Exact TOKEN_0x000000 corruption pattern",
|
||||
},
|
||||
{
|
||||
name: "Similar corruption pattern",
|
||||
address: "0x0000000100000000000000000000000000000000",
|
||||
expectedScore: 70, // Caught by pattern detection
|
||||
description: "Similar zero-heavy corruption",
|
||||
},
|
||||
{
|
||||
name: "Partial corruption",
|
||||
address: "0x1234000000000000000000000000000000000000",
|
||||
expectedScore: 70, // Caught by pattern detection
|
||||
description: "Partial zero corruption",
|
||||
},
|
||||
{
|
||||
name: "Trailing corruption",
|
||||
address: "0x123456789abcdef000000000000000000000000",
|
||||
expectedScore: 50, // Invalid length
|
||||
description: "Trailing zero corruption",
|
||||
},
|
||||
{
|
||||
name: "Valid but zero-heavy",
|
||||
address: "0x000000000000000000000000000000000000dead",
|
||||
expectedScore: 10, // Valid format, minimal corruption
|
||||
description: "Valid format but suspicious zeros",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range corruptionTests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := validator.ValidateAddress(tt.address)
|
||||
|
||||
assert.GreaterOrEqual(t, result.CorruptionScore, tt.expectedScore-10,
|
||||
"Corruption score should be at least %d-10 for %s", tt.expectedScore, tt.description)
|
||||
assert.LessOrEqual(t, result.CorruptionScore, 100,
|
||||
"Corruption score should not exceed 100")
|
||||
|
||||
// Check validity based on expected behavior rather than fixed threshold
|
||||
if tt.expectedScore >= 50 && tt.address != "0x000000000000000000000000000000000000dead" {
|
||||
assert.False(t, result.IsValid, "High corruption addresses should be invalid")
|
||||
}
|
||||
|
||||
t.Logf("%s: Score=%d, Valid=%v, Description=%s",
|
||||
tt.address, result.CorruptionScore, result.IsValid, tt.description)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAddressValidator_EdgeCases(t *testing.T) {
|
||||
validator := NewAddressValidator()
|
||||
|
||||
edgeCases := []struct {
|
||||
name string
|
||||
address string
|
||||
shouldBeValid bool
|
||||
}{
|
||||
{"Empty string", "", false},
|
||||
{"Only 0x", "0x", false},
|
||||
{"Just prefix", "0x0", false},
|
||||
{"Uppercase hex", "0x82AF49447D8A07E3BD95BD0D56F35241523FBAB1", true}, // Valid - case doesn't matter
|
||||
{"Mixed case invalid", "0x82aF49447D8a07e3BD95BD0d56f35241523fBaB1", true}, // Valid - case doesn't matter
|
||||
{"Unicode characters", "0x82aF49447D8a07e3bd95BD0d56f35241523fBaβ1", false},
|
||||
{"SQL injection attempt", "0x'; DROP TABLE addresses; --", false},
|
||||
{"Buffer overflow attempt", "0x" + string(make([]byte, 1000)), false},
|
||||
}
|
||||
|
||||
for _, tt := range edgeCases {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Should not panic
|
||||
result := validator.ValidateAddress(tt.address)
|
||||
|
||||
// Check validity based on expectation
|
||||
assert.Equal(t, tt.shouldBeValid, result.IsValid, "Edge case validity mismatch: %s", tt.address)
|
||||
|
||||
if !tt.shouldBeValid {
|
||||
assert.Greater(t, result.CorruptionScore, 0, "Invalid edge case should have corruption score > 0")
|
||||
assert.NotEmpty(t, result.ErrorMessages, "Invalid edge case should have error messages")
|
||||
} else {
|
||||
// Valid addresses can have low corruption scores
|
||||
assert.Empty(t, result.ErrorMessages, "Valid edge case should not have error messages")
|
||||
}
|
||||
|
||||
t.Logf("Edge case '%s': Valid=%v, Score=%d, Errors=%v",
|
||||
tt.name, result.IsValid, result.CorruptionScore, result.ErrorMessages)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAddressValidator_Performance(t *testing.T) {
|
||||
validator := NewAddressValidator()
|
||||
|
||||
// Test addresses for performance benchmark
|
||||
addresses := []string{
|
||||
"0x82aF49447D8a07e3bd95BD0d56f35241523fBab1", // Valid WETH
|
||||
"0x0000000300000000000000000000000000000000", // Corrupted
|
||||
"0xaf88d065e77c8cC2239327C5EDb3A432268e5831", // Valid USDC
|
||||
"0x0000000000000000000000000000000000000000", // Zero address
|
||||
"0x123456789abcdef0000000000000000000000000", // Partial corruption
|
||||
}
|
||||
|
||||
// Warm up
|
||||
for _, addr := range addresses {
|
||||
validator.ValidateAddress(addr)
|
||||
}
|
||||
|
||||
// Benchmark validation performance
|
||||
const iterations = 10000
|
||||
start := time.Now()
|
||||
|
||||
for i := 0; i < iterations; i++ {
|
||||
addr := addresses[i%len(addresses)]
|
||||
result := validator.ValidateAddress(addr)
|
||||
require.NotNil(t, result)
|
||||
}
|
||||
|
||||
duration := time.Since(start)
|
||||
avgTime := duration / iterations
|
||||
|
||||
t.Logf("Performance: %d validations in %v (avg: %v per validation)",
|
||||
iterations, duration, avgTime)
|
||||
|
||||
// Should validate at least 1,000 addresses per second
|
||||
maxTime := time.Millisecond * 2 // 2ms per validation = 500/sec (reasonable for complex validation)
|
||||
assert.Less(t, avgTime.Nanoseconds(), maxTime.Nanoseconds(),
|
||||
"Validation should be faster than %v per address (got %v)", maxTime, avgTime)
|
||||
}
|
||||
|
||||
func TestAddressValidator_ConcurrentAccess(t *testing.T) {
|
||||
validator := NewAddressValidator()
|
||||
|
||||
addresses := []string{
|
||||
"0x82aF49447D8a07e3bd95BD0d56f35241523fBab1",
|
||||
"0x0000000300000000000000000000000000000000",
|
||||
"0xaf88d065e77c8cC2239327C5EDb3A432268e5831",
|
||||
}
|
||||
|
||||
const numGoroutines = 100
|
||||
const validationsPerGoroutine = 100
|
||||
|
||||
done := make(chan bool, numGoroutines)
|
||||
|
||||
// Launch concurrent validators
|
||||
for i := 0; i < numGoroutines; i++ {
|
||||
go func(id int) {
|
||||
defer func() { done <- true }()
|
||||
|
||||
for j := 0; j < validationsPerGoroutine; j++ {
|
||||
addr := addresses[j%len(addresses)]
|
||||
result := validator.ValidateAddress(addr)
|
||||
require.NotNil(t, result)
|
||||
|
||||
// Verify consistent results
|
||||
if addr == "0x82aF49447D8a07e3bd95BD0d56f35241523fBab1" {
|
||||
assert.True(t, result.IsValid)
|
||||
assert.Equal(t, 0, result.CorruptionScore)
|
||||
}
|
||||
if addr == "0x0000000300000000000000000000000000000000" {
|
||||
assert.False(t, result.IsValid)
|
||||
assert.Equal(t, 70, result.CorruptionScore) // Updated to match new validation logic
|
||||
}
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
|
||||
// Wait for all goroutines to complete
|
||||
for i := 0; i < numGoroutines; i++ {
|
||||
select {
|
||||
case <-done:
|
||||
// Success
|
||||
case <-time.After(10 * time.Second):
|
||||
t.Fatal("Concurrent validation test timed out")
|
||||
}
|
||||
}
|
||||
|
||||
t.Logf("Successfully completed %d concurrent validations",
|
||||
numGoroutines*validationsPerGoroutine)
|
||||
}
|
||||
|
||||
// Helper function to check if a string contains a substring (case-insensitive)
|
||||
func contains(str, substr string) bool {
|
||||
return len(str) >= len(substr) &&
|
||||
(str == substr ||
|
||||
len(str) > len(substr) &&
|
||||
(str[:len(substr)] == substr ||
|
||||
str[len(str)-len(substr):] == substr ||
|
||||
indexOf(str, substr) >= 0))
|
||||
}
|
||||
|
||||
func indexOf(str, substr string) int {
|
||||
for i := 0; i <= len(str)-len(substr); i++ {
|
||||
if str[i:i+len(substr)] == substr {
|
||||
return i
|
||||
}
|
||||
}
|
||||
return -1
|
||||
}
|
||||
158
orig/internal/validation/known_contracts.go
Normal file
158
orig/internal/validation/known_contracts.go
Normal file
@@ -0,0 +1,158 @@
|
||||
package validation
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/ethereum/go-ethereum/common"
|
||||
)
|
||||
|
||||
// KnownContractRegistry maintains a registry of known contract addresses and their types
|
||||
// This prevents misclassification of well-known contracts (like major ERC-20 tokens)
|
||||
type KnownContractRegistry struct {
|
||||
erc20Tokens map[common.Address]string
|
||||
pools map[common.Address]string
|
||||
routers map[common.Address]string
|
||||
}
|
||||
|
||||
// NewKnownContractRegistry creates a new registry populated with known Arbitrum contracts
|
||||
func NewKnownContractRegistry() *KnownContractRegistry {
|
||||
registry := &KnownContractRegistry{
|
||||
erc20Tokens: make(map[common.Address]string),
|
||||
pools: make(map[common.Address]string),
|
||||
routers: make(map[common.Address]string),
|
||||
}
|
||||
|
||||
// Populate known ERC-20 tokens on Arbitrum (CRITICAL: These should NEVER be treated as pools)
|
||||
registry.addKnownERC20Tokens()
|
||||
registry.addKnownPools()
|
||||
registry.addKnownRouters()
|
||||
|
||||
return registry
|
||||
}
|
||||
|
||||
// addKnownERC20Tokens adds all major ERC-20 tokens on Arbitrum
|
||||
func (r *KnownContractRegistry) addKnownERC20Tokens() {
|
||||
// Major ERC-20 tokens that were being misclassified as pools
|
||||
r.erc20Tokens[common.HexToAddress("0x82aF49447D8a07e3bd95BD0d56f35241523fBab1")] = "WETH" // Wrapped Ether
|
||||
r.erc20Tokens[common.HexToAddress("0xaf88d065e77c8cC2239327C5EDb3A432268e5831")] = "USDC" // USD Coin
|
||||
r.erc20Tokens[common.HexToAddress("0xFd086bC7CD5C481DCC9C85ebE478A1C0b69FCbb9")] = "USDT" // Tether USD
|
||||
r.erc20Tokens[common.HexToAddress("0xFF970A61A04b1cA14834A43f5dE4533eBDDB5CC8")] = "USDC.e" // Bridged USDC
|
||||
r.erc20Tokens[common.HexToAddress("0x912CE59144191C1204E64559FE8253a0e49E6548")] = "ARB" // Arbitrum Token
|
||||
r.erc20Tokens[common.HexToAddress("0x2f2a2543B76A4166549F7aaB2e75Bef0aefC5B0f")] = "WBTC" // Wrapped Bitcoin
|
||||
r.erc20Tokens[common.HexToAddress("0xDA10009cBd5D07dd0CeCc66161FC93D7c9000da1")] = "DAI" // Dai Stablecoin
|
||||
r.erc20Tokens[common.HexToAddress("0x17FC002b466eEc40DaE837Fc4bE5c67993ddBd6F")] = "FRAX" // Frax
|
||||
r.erc20Tokens[common.HexToAddress("0x11cDb42B0EB46D95f990BeDD4695A6e3fA034978")] = "CRV" // Curve DAO Token
|
||||
r.erc20Tokens[common.HexToAddress("0x539bdE0d7Dbd336b79148AA742883198BBF60342")] = "MAGIC" // MAGIC
|
||||
r.erc20Tokens[common.HexToAddress("0xf97f4df75117a78c1A5a0DBb814Af92458539FB4")] = "LINK" // Chainlink Token
|
||||
r.erc20Tokens[common.HexToAddress("0xfc5A1A6EB076a2C7aD06eD22C90d7E710E35ad0a")] = "GMX" // GMX
|
||||
r.erc20Tokens[common.HexToAddress("0x6C2C06790b3E3E3c38e12Ee22F8183b37a13EE55")] = "DPX" // Dopex Governance Token
|
||||
r.erc20Tokens[common.HexToAddress("0x10393c20975cF177a3513071bC110f7962CD67da")] = "JONES" // JonesDAO
|
||||
r.erc20Tokens[common.HexToAddress("0x4e352cf164e64adcbad318c3a1e222e9eba4ce42")] = "MCB" // MCDEX Token
|
||||
r.erc20Tokens[common.HexToAddress("0x23A941036Ae778Ac51Ab04CEa08Ed6e2AF33b49")] = "RDNT" // Radiant Capital
|
||||
r.erc20Tokens[common.HexToAddress("0x6694340fc020c5E6B96567843da2df01b2CE1eb6")] = "STG" // Stargate Finance
|
||||
r.erc20Tokens[common.HexToAddress("0x3082CC23568eA640225c2467653dB90e9250AaA0")] = "RDNT" // Radiant (alternative)
|
||||
r.erc20Tokens[common.HexToAddress("0x51fC0f6660482Ea73330E414eFd7808811a57Fa2")] = "PREMIA" // Premia
|
||||
r.erc20Tokens[common.HexToAddress("0x69Eb4FA4a2fbd498C257C57Ea8b7655a2559A581")] = "DODO" // DODO
|
||||
}
|
||||
|
||||
// addKnownPools adds known liquidity pools
|
||||
func (r *KnownContractRegistry) addKnownPools() {
|
||||
// Major Uniswap V3 pools on Arbitrum
|
||||
r.pools[common.HexToAddress("0xC31E54c7a869B9FcBEcc14363CF510d1c41fa443")] = "USDC/WETH-0.05%"
|
||||
r.pools[common.HexToAddress("0x17c14D2c404D167802b16C450d3c99F88F2c4F4d")] = "USDC/WETH-0.3%"
|
||||
r.pools[common.HexToAddress("0x641C00A822e8b671738d32a431a4Fb6074E5c79d")] = "WETH/ARB-0.3%"
|
||||
r.pools[common.HexToAddress("0xdE64C63e6BaD1Ff18f4F1bdc9d1e7Bbfb5E0B6FD")] = "USDT/USDC-0.01%"
|
||||
r.pools[common.HexToAddress("0x2f5e87C9312fa29aed5c179E456625D79015299c")] = "WBTC/WETH-0.05%"
|
||||
}
|
||||
|
||||
// addKnownRouters adds known router contracts
|
||||
func (r *KnownContractRegistry) addKnownRouters() {
|
||||
// Uniswap V3 SwapRouter
|
||||
r.routers[common.HexToAddress("0xE592427A0AEce92De3Edee1F18E0157C05861564")] = "UniswapV3Router"
|
||||
// Uniswap V3 SwapRouter02
|
||||
r.routers[common.HexToAddress("0x68b3465833fb72A70ecDF485E0e4C7bD8665Fc45")] = "UniswapV3Router02"
|
||||
// SushiSwap Router
|
||||
r.routers[common.HexToAddress("0x1b02dA8Cb0d097eB8D57A175b88c7D8b47997506")] = "SushiSwapRouter"
|
||||
// Camelot Router
|
||||
r.routers[common.HexToAddress("0xc873fEcbd354f5A56E00E710B90EF4201db2448d")] = "CamelotRouter"
|
||||
}
|
||||
|
||||
// IsKnownERC20 checks if an address is a known ERC-20 token
|
||||
func (r *KnownContractRegistry) IsKnownERC20(address common.Address) (bool, string) {
|
||||
name, exists := r.erc20Tokens[address]
|
||||
return exists, name
|
||||
}
|
||||
|
||||
// IsKnownPool checks if an address is a known liquidity pool
|
||||
func (r *KnownContractRegistry) IsKnownPool(address common.Address) (bool, string) {
|
||||
name, exists := r.pools[address]
|
||||
return exists, name
|
||||
}
|
||||
|
||||
// IsKnownRouter checks if an address is a known router contract
|
||||
func (r *KnownContractRegistry) IsKnownRouter(address common.Address) (bool, string) {
|
||||
name, exists := r.routers[address]
|
||||
return exists, name
|
||||
}
|
||||
|
||||
// GetContractType returns the type of a known contract
|
||||
func (r *KnownContractRegistry) GetContractType(address common.Address) (ContractType, string) {
|
||||
if isERC20, name := r.IsKnownERC20(address); isERC20 {
|
||||
return ContractTypeERC20Token, name
|
||||
}
|
||||
if isPool, name := r.IsKnownPool(address); isPool {
|
||||
return ContractTypeUniswapV3Pool, name
|
||||
}
|
||||
if isRouter, name := r.IsKnownRouter(address); isRouter {
|
||||
return ContractTypeRouter, name
|
||||
}
|
||||
return ContractTypeUnknown, ""
|
||||
}
|
||||
|
||||
// ValidatePoolCall validates if a pool-specific operation should be allowed on an address
|
||||
func (r *KnownContractRegistry) ValidatePoolCall(address common.Address, operation string) error {
|
||||
if isERC20, name := r.IsKnownERC20(address); isERC20 {
|
||||
return &ValidationError{
|
||||
Code: "INVALID_POOL_OPERATION",
|
||||
Message: fmt.Sprintf("Attempted pool operation '%s' on known ERC-20 token %s (%s)", operation, name, address.Hex()),
|
||||
Context: map[string]interface{}{
|
||||
"address": address.Hex(),
|
||||
"token_name": name,
|
||||
"operation": operation,
|
||||
"issue": "ERC-20 tokens do not have pool-specific functions like slot0()",
|
||||
},
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetCorruptionPattern analyzes an address for known corruption patterns
|
||||
func (r *KnownContractRegistry) GetCorruptionPattern(address common.Address) string {
|
||||
hexAddr := address.Hex()
|
||||
|
||||
// Check for zero address
|
||||
if address == (common.Address{}) {
|
||||
return "ZERO_ADDRESS"
|
||||
}
|
||||
|
||||
// Check for mostly zeros with small values
|
||||
if strings.HasSuffix(hexAddr, "0000000000000000000000000000000000000000") {
|
||||
return "TRAILING_ZEROS"
|
||||
}
|
||||
|
||||
// Check for leading zeros with small values
|
||||
if strings.HasPrefix(hexAddr, "0x00000000") {
|
||||
return "LEADING_ZEROS_CORRUPTION"
|
||||
}
|
||||
|
||||
// Check for embedded WETH/USDC patterns (indicates address extraction issues)
|
||||
if strings.Contains(strings.ToLower(hexAddr), "82af49447d8a07e3bd95bd0d56f35241523fbab1") {
|
||||
return "EMBEDDED_WETH_PATTERN"
|
||||
}
|
||||
if strings.Contains(strings.ToLower(hexAddr), "af88d065e77c8cc2239327c5edb3a432268e5831") {
|
||||
return "EMBEDDED_USDC_PATTERN"
|
||||
}
|
||||
|
||||
return "OTHER_CORRUPTION"
|
||||
}
|
||||
Reference in New Issue
Block a user