feat: create v2-prep branch with comprehensive planning

Restructured project for V2 refactor:

**Structure Changes:**
- Moved all V1 code to orig/ folder (preserved with git mv)
- Created docs/planning/ directory
- Added orig/README_V1.md explaining V1 preservation

**Planning Documents:**
- 00_V2_MASTER_PLAN.md: Complete architecture overview
  - Executive summary of critical V1 issues
  - High-level component architecture diagrams
  - 5-phase implementation roadmap
  - Success metrics and risk mitigation

- 07_TASK_BREAKDOWN.md: Atomic task breakdown
  - 99+ hours of detailed tasks
  - Every task < 2 hours (atomic)
  - Clear dependencies and success criteria
  - Organized by implementation phase

**V2 Key Improvements:**
- Per-exchange parsers (factory pattern)
- Multi-layer strict validation
- Multi-index pool cache
- Background validation pipeline
- Comprehensive observability

**Critical Issues Addressed:**
- Zero address tokens (strict validation + cache enrichment)
- Parsing accuracy (protocol-specific parsers)
- No audit trail (background validation channel)
- Inefficient lookups (multi-index cache)
- Stats disconnection (event-driven metrics)

Next Steps:
1. Review planning documents
2. Begin Phase 1: Foundation (P1-001 through P1-010)
3. Implement parsers in Phase 2
4. Build cache system in Phase 3
5. Add validation pipeline in Phase 4
6. Migrate and test in Phase 5

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
Administrator
2025-11-10 10:14:26 +01:00
parent 1773daffe7
commit 803de231ba
411 changed files with 20390 additions and 8680 deletions

View File

@@ -0,0 +1,262 @@
package auth
import (
	"crypto/subtle"
	"fmt"
	"net"
	"net/http"
	"os"
	"strings"
	"sync"
	"time"

	"github.com/fraktal/mev-beta/internal/logger"
)
// AuthConfig holds authentication configuration for the HTTP middleware.
// Blank credential fields are filled from environment variables by
// NewMiddleware (MEV_BOT_API_KEY, MEV_BOT_USERNAME, MEV_BOT_PASSWORD).
type AuthConfig struct {
	// APIKey is the shared secret accepted as a Bearer token, the
	// X-API-Key header, or the api_key query parameter.
	APIKey string
	// BasicUsername and BasicPassword enable HTTP Basic auth when both
	// are non-empty.
	BasicUsername string
	BasicPassword string
	// AllowedIPs restricts clients by IP. Empty means no restriction;
	// a "*" entry matches any address.
	AllowedIPs []string
	// RequireHTTPS rejects requests that are neither TLS connections nor
	// proxied with X-Forwarded-Proto: https.
	RequireHTTPS bool
	// RateLimitRPS caps requests per client IP (applied over a one-minute
	// window by checkRateLimit); <= 0 disables rate limiting.
	RateLimitRPS int
	Logger       *logger.Logger
}
// Middleware provides authentication middleware for HTTP endpoints,
// combining credential checks, IP allowlisting and per-IP rate limiting.
type Middleware struct {
	config *AuthConfig
	// rateLimiter maps client IP -> sliding-window limiter state.
	rateLimiter map[string]*RateLimiter
	// mu guards rateLimiter and the limiter values reached through it.
	mu sync.RWMutex
}
// RateLimiter tracks request timestamps for a single client IP within a
// sliding time window.
type RateLimiter struct {
	// requests holds timestamps of requests still inside the window.
	requests []time.Time
	// maxRequests is the maximum number of requests allowed per window.
	maxRequests int
	// window is the sliding-window length (set to one minute by
	// checkRateLimit).
	window time.Duration
}
// NewMiddleware creates an authentication middleware from config. Any blank
// credential field is filled from the corresponding environment variable
// (MEV_BOT_API_KEY, MEV_BOT_USERNAME, MEV_BOT_PASSWORD); the passed config
// is mutated in place.
func NewMiddleware(config *AuthConfig) *Middleware {
	// envDefault keeps an explicitly configured value, otherwise falls
	// back to the named environment variable.
	envDefault := func(current, envKey string) string {
		if current != "" {
			return current
		}
		return os.Getenv(envKey)
	}
	config.APIKey = envDefault(config.APIKey, "MEV_BOT_API_KEY")
	config.BasicUsername = envDefault(config.BasicUsername, "MEV_BOT_USERNAME")
	config.BasicPassword = envDefault(config.BasicPassword, "MEV_BOT_PASSWORD")
	return &Middleware{
		config:      config,
		rateLimiter: make(map[string]*RateLimiter),
	}
}
// RequireAuthentication wraps next with the full authentication stack.
// Checks run in a deliberate order: security headers are always set, then
// HTTPS enforcement, IP allowlisting, rate limiting, and finally API-key
// or Basic credential verification. The first failing check ends the
// request. NOTE(review): config.Logger is dereferenced on the blocked
// paths; a nil Logger would panic — confirm callers always supply one.
func (m *Middleware) RequireAuthentication(next http.HandlerFunc) http.HandlerFunc {
	return func(w http.ResponseWriter, r *http.Request) {
		// Defensive security headers on every response, including errors.
		w.Header().Set("X-Content-Type-Options", "nosniff")
		w.Header().Set("X-Frame-Options", "DENY")
		w.Header().Set("X-XSS-Protection", "1; mode=block")
		w.Header().Set("Referrer-Policy", "strict-origin-when-cross-origin")
		// Require HTTPS in production: accept either a terminated TLS
		// connection or an upstream proxy asserting X-Forwarded-Proto.
		if m.config.RequireHTTPS && r.Header.Get("X-Forwarded-Proto") != "https" && r.TLS == nil {
			http.Error(w, "HTTPS required", http.StatusUpgradeRequired)
			return
		}
		// IP allowlist (no-op when the allowlist is empty).
		if !m.isIPAllowed(r.RemoteAddr) {
			m.config.Logger.Warn(fmt.Sprintf("Blocked request from unauthorized IP: %s", r.RemoteAddr))
			http.Error(w, "Forbidden", http.StatusForbidden)
			return
		}
		// Per-IP rate limiting.
		if !m.checkRateLimit(r.RemoteAddr) {
			m.config.Logger.Warn(fmt.Sprintf("Rate limit exceeded for IP: %s", r.RemoteAddr))
			http.Error(w, "Rate limit exceeded", http.StatusTooManyRequests)
			return
		}
		// API-key authentication takes precedence over Basic auth.
		if m.authenticateAPIKey(r) {
			next.ServeHTTP(w, r)
			return
		}
		if m.authenticateBasic(r) {
			next.ServeHTTP(w, r)
			return
		}
		// Both schemes failed: challenge the client for Basic credentials.
		w.Header().Set("WWW-Authenticate", `Basic realm="MEV Bot API"`)
		http.Error(w, "Authentication required", http.StatusUnauthorized)
	}
}
// authenticateAPIKey reports whether the request carries a valid API key.
// The key may arrive as a Bearer token in the Authorization header, in the
// X-API-Key header, or as the api_key query parameter (weakest option,
// since URLs tend to be logged). All comparisons are constant time.
func (m *Middleware) authenticateAPIKey(r *http.Request) bool {
	expected := m.config.APIKey
	if expected == "" {
		return false
	}
	matches := func(candidate string) bool {
		return subtle.ConstantTimeCompare([]byte(candidate), []byte(expected)) == 1
	}
	// A Bearer-form Authorization header is authoritative: its verdict is
	// final and no other source is consulted.
	if header := r.Header.Get("Authorization"); strings.HasPrefix(header, "Bearer ") {
		return matches(strings.TrimPrefix(header, "Bearer "))
	}
	if key := r.Header.Get("X-API-Key"); key != "" {
		return matches(key)
	}
	if key := r.URL.Query().Get("api_key"); key != "" {
		return matches(key)
	}
	return false
}
// authenticateBasic reports whether the request carries valid HTTP Basic
// credentials. It is disabled (always false) unless both a username and a
// password are configured. Comparisons are constant time to resist timing
// attacks.
func (m *Middleware) authenticateBasic(r *http.Request) bool {
	wantUser, wantPass := m.config.BasicUsername, m.config.BasicPassword
	if wantUser == "" || wantPass == "" {
		return false
	}
	user, pass, ok := r.BasicAuth()
	if !ok {
		return false
	}
	// Evaluate both comparisons unconditionally so the timing profile does
	// not reveal which credential was wrong.
	userOK := subtle.ConstantTimeCompare([]byte(user), []byte(wantUser)) == 1
	passOK := subtle.ConstantTimeCompare([]byte(pass), []byte(wantPass)) == 1
	return userOK && passOK
}
// isIPAllowed reports whether the request's remote address is permitted.
// An empty allowlist disables IP filtering, and a "*" entry matches any
// address.
func (m *Middleware) isIPAllowed(remoteAddr string) bool {
	if len(m.config.AllowedIPs) == 0 {
		return true // No IP restrictions configured
	}
	// Strip the port with net.SplitHostPort so IPv6 literals such as
	// "[::1]:8080" are handled correctly — the previous naive split on
	// ":" produced "[" for any bracketed IPv6 address. Fall back to the
	// raw value when the address carries no port.
	ip, _, err := net.SplitHostPort(remoteAddr)
	if err != nil {
		ip = remoteAddr
	}
	for _, allowedIP := range m.config.AllowedIPs {
		if ip == allowedIP || allowedIP == "*" {
			return true
		}
		// TODO: support CIDR notation (net.ParseCIDR) in a future version.
	}
	return false
}
// checkRateLimit applies per-IP sliding-window rate limiting and reports
// whether the request may proceed. RateLimitRPS <= 0 disables limiting.
// NOTE(review): maxRequests comes from RateLimitRPS but is applied over a
// one-minute window, so the effective limit is per minute, not per second
// — confirm this is intentional.
func (m *Middleware) checkRateLimit(remoteAddr string) bool {
	if m.config.RateLimitRPS <= 0 {
		return true // Rate limiting disabled
	}
	// Strip the port with net.SplitHostPort so IPv6 literals such as
	// "[::1]:8080" key the limiter map correctly (a naive ":" split
	// yielded "[" for every IPv6 client, merging them into one bucket).
	ip, _, err := net.SplitHostPort(remoteAddr)
	if err != nil {
		ip = remoteAddr
	}
	now := time.Now()
	m.mu.Lock()
	defer m.mu.Unlock()
	// Get or create this IP's limiter state.
	limiter, exists := m.rateLimiter[ip]
	if !exists {
		limiter = &RateLimiter{
			maxRequests: m.config.RateLimitRPS,
			window:      time.Minute,
		}
		m.rateLimiter[ip] = limiter
	}
	// Drop timestamps outside the window, filtering in place to reuse the
	// backing array instead of allocating a fresh slice on every request.
	cutoff := now.Add(-limiter.window)
	valid := limiter.requests[:0]
	for _, reqTime := range limiter.requests {
		if reqTime.After(cutoff) {
			valid = append(valid, reqTime)
		}
	}
	limiter.requests = valid
	// Reject when the window is already full.
	if len(limiter.requests) >= limiter.maxRequests {
		return false
	}
	// Record the current request.
	limiter.requests = append(limiter.requests, now)
	return true
}
// RequireReadOnly wraps next with a lighter check set for read-only
// endpoints: GET-only method enforcement, basic security headers, IP
// allowlisting and rate limiting — but no credential verification.
func (m *Middleware) RequireReadOnly(next http.HandlerFunc) http.HandlerFunc {
	return func(w http.ResponseWriter, r *http.Request) {
		// Read-only endpoints accept GET exclusively.
		if r.Method != http.MethodGet {
			http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
			return
		}
		// Subset of the security headers set by RequireAuthentication.
		w.Header().Set("X-Content-Type-Options", "nosniff")
		w.Header().Set("X-Frame-Options", "DENY")
		// Same IP allowlist as the authenticated path (empty = allow all).
		if !m.isIPAllowed(r.RemoteAddr) {
			http.Error(w, "Forbidden", http.StatusForbidden)
			return
		}
		// Per-IP rate limiting shared with the authenticated path.
		if !m.checkRateLimit(r.RemoteAddr) {
			http.Error(w, "Rate limit exceeded", http.StatusTooManyRequests)
			return
		}
		next.ServeHTTP(w, r)
	}
}
// CleanupRateLimiters prunes per-IP limiter entries that are empty or whose
// most recent request is more than an hour old. Call periodically to keep
// the limiter map from growing without bound.
func (m *Middleware) CleanupRateLimiters() {
	staleBefore := time.Now().Add(-time.Hour)
	m.mu.Lock()
	defer m.mu.Unlock()
	for ip, limiter := range m.rateLimiter {
		// Deleting map entries during range is well-defined in Go.
		n := len(limiter.requests)
		if n == 0 || limiter.requests[n-1].Before(staleBefore) {
			delete(m.rateLimiter, ip)
		}
	}
}

View File

@@ -0,0 +1,879 @@
package config
import (
"fmt"
"net/url"
"os"
"regexp"
"strconv"
"strings"
"time"
"gopkg.in/yaml.v3"
)
// Config represents the application configuration. It is populated from a
// YAML file by Load, with ${VAR}/$VAR expansion applied to the raw file
// and selected fields then overridable via environment variables
// (see OverrideWithEnv).
type Config struct {
	Arbitrum           ArbitrumConfig           `yaml:"arbitrum"`
	Bot                BotConfig                `yaml:"bot"`
	Uniswap            UniswapConfig            `yaml:"uniswap"`
	Log                LogConfig                `yaml:"log"`
	Database           DatabaseConfig           `yaml:"database"`
	Ethereum           EthereumConfig           `yaml:"ethereum"`
	Contracts          ContractsConfig          `yaml:"contracts"`
	Arbitrage          ArbitrageConfig          `yaml:"arbitrage"`
	Features           Features                 `yaml:"features"`
	ArbitrageOptimized ArbitrageOptimizedConfig `yaml:"arbitrage_optimized"`
}
// ArbitrumConfig represents the Arbitrum node configuration.
type ArbitrumConfig struct {
	// Chain ID for Arbitrum (42161 for mainnet)
	ChainID int64 `yaml:"chain_id"`
	// Reading endpoints (WSS preferred for real-time monitoring)
	ReadingEndpoints []EndpointConfig `yaml:"reading_endpoints"`
	// Execution endpoints (HTTP/HTTPS or WSS for transaction submission)
	ExecutionEndpoints []EndpointConfig `yaml:"execution_endpoints"`
	// Fallback endpoints for failover scenarios
	FallbackEndpoints []EndpointConfig `yaml:"fallback_endpoints"`
	// Legacy single-endpoint fields for backward compatibility; used by
	// ConvertToProviderConfig only when the endpoint lists above are empty.
	RPCEndpoint string `yaml:"rpc_endpoint,omitempty"`
	WSEndpoint  string `yaml:"ws_endpoint,omitempty"`
	// Global rate limiting configuration (per-endpoint RateLimit settings
	// take effect within each EndpointConfig).
	RateLimit RateLimitConfig `yaml:"rate_limit"`
}
// EndpointConfig represents an RPC endpoint configuration. Timing fields
// are expressed in whole seconds.
type EndpointConfig struct {
	// RPC endpoint URL; a "ws"/"wss" scheme marks it as a WebSocket
	// endpoint throughout this package.
	URL string `yaml:"url"`
	// Endpoint name for identification
	Name string `yaml:"name"`
	// Priority (lower number = higher priority)
	Priority int `yaml:"priority"`
	// Maximum requests per second for this endpoint
	MaxRPS int `yaml:"max_rps"`
	// Maximum concurrent connections
	MaxConcurrent int `yaml:"max_concurrent"`
	// Connection timeout in seconds
	TimeoutSeconds int `yaml:"timeout_seconds"`
	// Health check interval in seconds
	HealthCheckInterval int `yaml:"health_check_interval"`
	// Rate limiting configuration for this endpoint
	RateLimit RateLimitConfig `yaml:"rate_limit"`
}
// RateLimitConfig represents rate limiting configuration, applied either
// globally (ArbitrumConfig.RateLimit) or per endpoint.
type RateLimitConfig struct {
	// Maximum requests per second
	RequestsPerSecond int `yaml:"requests_per_second"`
	// Maximum concurrent requests
	MaxConcurrent int `yaml:"max_concurrent"`
	// Burst size for rate limiting (defaults elsewhere to 2x RPS)
	Burst int `yaml:"burst"`
}
// BotConfig represents the bot configuration. Interval and timeout fields
// are expressed in whole seconds.
type BotConfig struct {
	// Enable or disable the bot
	Enabled bool `yaml:"enabled"`
	// Polling interval in seconds
	PollingInterval int `yaml:"polling_interval"`
	// Minimum profit threshold in USD
	MinProfitThreshold float64 `yaml:"min_profit_threshold"`
	// Gas price multiplier (for faster transactions)
	GasPriceMultiplier float64 `yaml:"gas_price_multiplier"`
	// Maximum number of concurrent workers for processing
	MaxWorkers int `yaml:"max_workers"`
	// Buffer size for channels
	ChannelBufferSize int `yaml:"channel_buffer_size"`
	// Timeout for RPC calls in seconds
	RPCTimeout int `yaml:"rpc_timeout"`
}
// UniswapConfig represents the Uniswap configuration.
type UniswapConfig struct {
	// Factory contract address
	FactoryAddress string `yaml:"factory_address"`
	// Position manager contract address
	PositionManagerAddress string `yaml:"position_manager_address"`
	// Supported fee tiers (presumably in hundredths of a bip, e.g. 3000 =
	// 0.3%, per Uniswap V3 convention — confirm against consumers)
	FeeTiers []int64 `yaml:"fee_tiers"`
	// Cache configuration for pool data
	Cache CacheConfig `yaml:"cache"`
}
// CacheConfig represents caching configuration for pool data.
type CacheConfig struct {
	// Enable or disable caching
	Enabled bool `yaml:"enabled"`
	// Cache expiration time in seconds
	Expiration int `yaml:"expiration"`
	// Maximum cache size (number of entries)
	MaxSize int `yaml:"max_size"`
}
// LogConfig represents the logging configuration.
type LogConfig struct {
	// Log level (debug, info, warn, error)
	Level string `yaml:"level"`
	// Log format (json, text)
	Format string `yaml:"format"`
	// Log file path (empty for stdout)
	File string `yaml:"file"`
}
// DatabaseConfig represents the database configuration (a file-backed
// database, given the single File path).
type DatabaseConfig struct {
	// Database file path
	File string `yaml:"file"`
	// Maximum number of open connections
	MaxOpenConnections int `yaml:"max_open_connections"`
	// Maximum number of idle connections
	MaxIdleConnections int `yaml:"max_idle_connections"`
}
// EthereumConfig represents the Ethereum account configuration.
type EthereumConfig struct {
	// Private key for transaction signing. NOTE(review): storing this in
	// the YAML file is risky; prefer supplying it via the
	// ETHEREUM_PRIVATE_KEY environment variable (see OverrideWithEnv).
	PrivateKey string `yaml:"private_key"`
	// Account address
	AccountAddress string `yaml:"account_address"`
	// Gas price multiplier (for faster transactions)
	GasPriceMultiplier float64 `yaml:"gas_price_multiplier"`
}
// ContractsConfig represents the smart contract addresses the bot
// interacts with.
type ContractsConfig struct {
	// Arbitrage executor contract address
	ArbitrageExecutor string `yaml:"arbitrage_executor"`
	// Flash swapper contract address
	FlashSwapper string `yaml:"flash_swapper"`
	// Flash loan receiver contract address (Balancer flash loans)
	FlashLoanReceiver string `yaml:"flash_loan_receiver"`
	// Balancer Vault address for flash loans
	BalancerVault string `yaml:"balancer_vault"`
	// Data fetcher contract address for batch pool data fetching
	DataFetcher string `yaml:"data_fetcher"`
	// Authorized caller addresses
	AuthorizedCallers []string `yaml:"authorized_callers"`
	// Authorized DEX addresses
	AuthorizedDEXes []string `yaml:"authorized_dexes"`
}
// Load reads, environment-expands, and parses the YAML configuration at
// filename, then applies environment-variable overrides on top of the
// file's values.
func Load(filename string) (*Config, error) {
	raw, err := os.ReadFile(filename)
	if err != nil {
		return nil, fmt.Errorf("failed to read config file: %w", err)
	}
	// Substitute ${VAR}/$VAR references before YAML parsing so secrets can
	// live in the environment rather than on disk.
	expanded := expandEnvVars(string(raw))
	cfg := &Config{}
	if err := yaml.Unmarshal([]byte(expanded), cfg); err != nil {
		return nil, fmt.Errorf("failed to parse config file: %w", err)
	}
	// Explicitly-set environment variables win over values from the file.
	cfg.OverrideWithEnv()
	return cfg, nil
}
// expandEnvVars expands ${VAR} and $VAR patterns in the given string
func expandEnvVars(s string) string {
// Pattern to match ${VAR} and $VAR
envVarPattern := regexp.MustCompile(`\$\{([^}]+)\}|\$([A-Za-z_][A-Za-z0-9_]*)`)
return envVarPattern.ReplaceAllStringFunc(s, func(match string) string {
var varName string
// Handle ${VAR} format
if strings.HasPrefix(match, "${") && strings.HasSuffix(match, "}") {
varName = match[2 : len(match)-1]
} else if strings.HasPrefix(match, "$") {
// Handle $VAR format
varName = match[1:]
}
// Get environment variable value
if value := os.Getenv(varName); value != "" {
return value
}
// Return empty string if environment variable is not set
// This prevents invalid YAML when variables are missing
return ""
})
}
// OverrideWithEnv overrides configuration values with environment
// variables. Legacy single-endpoint variables also seed the endpoint
// lists when those lists are empty; list-valued variables replace the
// corresponding lists outright. Scalar overrides silently ignore
// unparseable numeric values, keeping whatever the file provided.
func (c *Config) OverrideWithEnv() {
	// Legacy single RPC endpoint (backward compatibility).
	if rpcEndpoint := os.Getenv("ARBITRUM_RPC_ENDPOINT"); rpcEndpoint != "" {
		c.Arbitrum.RPCEndpoint = rpcEndpoint
		// Also seed the execution endpoint list if nothing is configured.
		if len(c.Arbitrum.ExecutionEndpoints) == 0 {
			// WebSocket endpoints get a higher RPS default than HTTP.
			rps := 200
			if strings.HasPrefix(rpcEndpoint, "ws") {
				rps = 300
			}
			// NOTE(review): the Name is "Arbitrum Public HTTP" even when
			// the URL is a WebSocket endpoint — confirm this is harmless.
			c.Arbitrum.ExecutionEndpoints = append(c.Arbitrum.ExecutionEndpoints, EndpointConfig{
				URL:                 rpcEndpoint,
				Name:                "Arbitrum Public HTTP",
				Priority:            1,
				MaxRPS:              rps,
				MaxConcurrent:       20,
				TimeoutSeconds:      30,
				HealthCheckInterval: 60,
				RateLimit: RateLimitConfig{
					RequestsPerSecond: rps,
					MaxConcurrent:     20,
					Burst:             rps * 2,
				},
			})
		}
	}
	// Legacy single WebSocket endpoint (backward compatibility).
	if wsEndpoint := os.Getenv("ARBITRUM_WS_ENDPOINT"); wsEndpoint != "" {
		c.Arbitrum.WSEndpoint = wsEndpoint
		// Also seed the reading endpoint list if nothing is configured.
		if len(c.Arbitrum.ReadingEndpoints) == 0 {
			c.Arbitrum.ReadingEndpoints = append(c.Arbitrum.ReadingEndpoints, EndpointConfig{
				URL:                 wsEndpoint,
				Name:                "Arbitrum Public WS",
				Priority:            1,
				MaxRPS:              300,
				MaxConcurrent:       25,
				TimeoutSeconds:      60,
				HealthCheckInterval: 30,
				RateLimit: RateLimitConfig{
					RequestsPerSecond: 300,
					MaxConcurrent:     25,
					Burst:             600,
				},
			})
		}
	}
	// Comma-separated endpoint lists replace the configured lists.
	if readingEndpoints := os.Getenv("ARBITRUM_READING_ENDPOINTS"); readingEndpoints != "" {
		c.Arbitrum.ReadingEndpoints = c.parseEndpointsFromEnv(readingEndpoints, "Reading")
	}
	if executionEndpoints := os.Getenv("ARBITRUM_EXECUTION_ENDPOINTS"); executionEndpoints != "" {
		c.Arbitrum.ExecutionEndpoints = c.parseEndpointsFromEnv(executionEndpoints, "Execution")
	}
	// Fallback endpoints (legacy support): only fill lists that are
	// still empty after the overrides above.
	if fallbackEndpoints := os.Getenv("ARBITRUM_FALLBACK_ENDPOINTS"); fallbackEndpoints != "" {
		fallbackConfigs := c.parseEndpointsFromEnv(fallbackEndpoints, "Fallback")
		if len(c.Arbitrum.ReadingEndpoints) == 0 {
			c.Arbitrum.ReadingEndpoints = append(c.Arbitrum.ReadingEndpoints, fallbackConfigs...)
		}
		if len(c.Arbitrum.ExecutionEndpoints) == 0 {
			c.Arbitrum.ExecutionEndpoints = append(c.Arbitrum.ExecutionEndpoints, fallbackConfigs...)
		}
	}
	// Global rate limit overrides (invalid integers are ignored).
	if rps := os.Getenv("RPC_REQUESTS_PER_SECOND"); rps != "" {
		if val, err := strconv.Atoi(rps); err == nil {
			c.Arbitrum.RateLimit.RequestsPerSecond = val
		}
	}
	if maxConcurrent := os.Getenv("RPC_MAX_CONCURRENT"); maxConcurrent != "" {
		if val, err := strconv.Atoi(maxConcurrent); err == nil {
			c.Arbitrum.RateLimit.MaxConcurrent = val
		}
	}
	// Bot worker/channel overrides.
	if maxWorkers := os.Getenv("BOT_MAX_WORKERS"); maxWorkers != "" {
		if val, err := strconv.Atoi(maxWorkers); err == nil {
			c.Bot.MaxWorkers = val
		}
	}
	if channelBufferSize := os.Getenv("BOT_CHANNEL_BUFFER_SIZE"); channelBufferSize != "" {
		if val, err := strconv.Atoi(channelBufferSize); err == nil {
			c.Bot.ChannelBufferSize = val
		}
	}
	// Ethereum account overrides (preferred over file values for secrets).
	if privateKey := os.Getenv("ETHEREUM_PRIVATE_KEY"); privateKey != "" {
		c.Ethereum.PrivateKey = privateKey
	}
	if accountAddress := os.Getenv("ETHEREUM_ACCOUNT_ADDRESS"); accountAddress != "" {
		c.Ethereum.AccountAddress = accountAddress
	}
	if gasPriceMultiplier := os.Getenv("ETHEREUM_GAS_PRICE_MULTIPLIER"); gasPriceMultiplier != "" {
		if val, err := strconv.ParseFloat(gasPriceMultiplier, 64); err == nil {
			c.Ethereum.GasPriceMultiplier = val
		}
	}
	// Contract address overrides.
	if arbitrageExecutor := os.Getenv("CONTRACT_ARBITRAGE_EXECUTOR"); arbitrageExecutor != "" {
		c.Contracts.ArbitrageExecutor = arbitrageExecutor
	}
	if flashSwapper := os.Getenv("CONTRACT_FLASH_SWAPPER"); flashSwapper != "" {
		c.Contracts.FlashSwapper = flashSwapper
	}
}
// parseEndpointsFromEnv parses a comma-separated list of endpoint URLs
// (from an environment variable) into EndpointConfig values. List position
// sets priority (first entry = highest) and the URL scheme picks the
// defaults: WebSocket endpoints get higher rate limits and longer timeouts
// than HTTP ones. Blank entries are skipped.
func (c *Config) parseEndpointsFromEnv(endpointsStr, namePrefix string) []EndpointConfig {
	if endpointsStr == "" {
		return nil
	}
	rawURLs := strings.Split(endpointsStr, ",")
	endpoints := make([]EndpointConfig, 0, len(rawURLs))
	// Loop variable renamed from "url" so it no longer shadows the
	// imported net/url package.
	for i, rawURL := range rawURLs {
		rawURL = strings.TrimSpace(rawURL)
		if rawURL == "" {
			continue
		}
		// Determine defaults based on URL scheme.
		var maxRPS, maxConcurrent, timeoutSeconds, healthCheckInterval int
		if strings.HasPrefix(rawURL, "ws") {
			// WebSocket endpoints - higher rate limits for real-time data
			maxRPS = 300
			maxConcurrent = 25
			timeoutSeconds = 60
			healthCheckInterval = 30
		} else {
			// HTTP endpoints - conservative rate limits
			maxRPS = 200
			maxConcurrent = 20
			timeoutSeconds = 30
			healthCheckInterval = 60
		}
		endpoint := EndpointConfig{
			URL:                 rawURL,
			Name:                fmt.Sprintf("%s-%d", namePrefix, i+1),
			Priority:            i + 1, // Lower number = higher priority
			MaxRPS:              maxRPS,
			MaxConcurrent:       maxConcurrent,
			TimeoutSeconds:      timeoutSeconds,
			HealthCheckInterval: healthCheckInterval,
			RateLimit: RateLimitConfig{
				RequestsPerSecond: maxRPS,
				MaxConcurrent:     maxConcurrent,
				Burst:             maxRPS * 2, // Allow burst of 2x normal rate
			},
		}
		endpoints = append(endpoints, endpoint)
	}
	return endpoints
}
// CreateProviderConfigFile writes a transport-system provider configuration
// derived from this Config to tempPath as YAML.
func (c *Config) CreateProviderConfigFile(tempPath string) error {
	// Convert config to the transport system's provider format.
	providerConfig := c.ConvertToProviderConfig()
	yamlData, err := yaml.Marshal(providerConfig)
	if err != nil {
		return fmt.Errorf("failed to marshal provider config: %w", err)
	}
	// 0600 rather than 0644: the serialized endpoint URLs commonly embed
	// provider API keys, so the file must not be world-readable.
	if err := os.WriteFile(tempPath, yamlData, 0600); err != nil {
		return fmt.Errorf("failed to write provider config file: %w", err)
	}
	return nil
}
// createProviderConfig builds a single provider entry (as a generic map)
// for the transport system from one endpoint definition. The URL scheme
// decides whether the address is exposed as http_endpoint or ws_endpoint;
// the other field is left empty. (The previous comment on this function
// described ConvertToProviderConfig instead.)
func (c *Config) createProviderConfig(endpoint EndpointConfig, features []string) map[string]interface{} {
	provider := map[string]interface{}{
		"name":          endpoint.Name,
		"type":          "standard",
		"http_endpoint": "",
		"ws_endpoint":   "",
		"priority":      endpoint.Priority,
		"rate_limit": map[string]interface{}{
			"requests_per_second": endpoint.RateLimit.RequestsPerSecond,
			"burst":               endpoint.RateLimit.Burst,
			"timeout":             fmt.Sprintf("%ds", endpoint.TimeoutSeconds),
			"retry_delay":         "1s",
			"max_retries":         3,
		},
		"features": features,
		"health_check": map[string]interface{}{
			"enabled":  true,
			"interval": fmt.Sprintf("%ds", endpoint.HealthCheckInterval),
			"timeout":  fmt.Sprintf("%ds", endpoint.TimeoutSeconds),
		},
	}
	// A "ws"/"wss" scheme marks a WebSocket endpoint; anything else is
	// treated as HTTP.
	if strings.HasPrefix(endpoint.URL, "ws") {
		provider["ws_endpoint"] = endpoint.URL
	} else {
		provider["http_endpoint"] = endpoint.URL
	}
	return provider
}
// ConvertToProviderConfig converts the Arbitrum configuration into the
// transport system's generic map format. When neither reading nor
// execution endpoint lists are configured it builds (and returns early
// with) a legacy configuration from the single RPCEndpoint/WSEndpoint
// fields; otherwise it converts the endpoint lists into providers with
// separate read_only and execution pools.
// NOTE(review): the "fallover_enabled" rotation key looks like a typo for
// "failover_enabled" (the pool maps use the latter spelling). It is left
// unchanged here because the transport system may consume the key as-is —
// confirm against the consumer before renaming.
func (c *Config) ConvertToProviderConfig() map[string]interface{} {
	providerConfigs := make([]map[string]interface{}, 0)
	// Legacy path: no endpoint lists configured at all.
	if len(c.Arbitrum.ReadingEndpoints) == 0 && len(c.Arbitrum.ExecutionEndpoints) == 0 {
		// Build a provider from the legacy RPC endpoint, if set.
		if c.Arbitrum.RPCEndpoint != "" {
			// Default rate limits when the config leaves them zero:
			// 300 RPS for WebSocket URLs, 200 for HTTP.
			rps := c.Arbitrum.RateLimit.RequestsPerSecond
			if rps <= 0 {
				if strings.HasPrefix(c.Arbitrum.RPCEndpoint, "ws") {
					rps = 300 // Default for WebSocket
				} else {
					rps = 200 // Default for HTTP
				}
			}
			burst := c.Arbitrum.RateLimit.Burst
			if burst <= 0 {
				burst = rps * 2 // Default burst is 2x RPS
			}
			provider := map[string]interface{}{
				"name":     "Legacy-RPC",
				"type":     "standard",
				"priority": 1,
				"rate_limit": map[string]interface{}{
					"requests_per_second": rps,
					"burst":               burst,
					"timeout":             "30s",
					"retry_delay":         "1s",
					"max_retries":         3,
				},
				"features": []string{"execution", "reading"},
				"health_check": map[string]interface{}{
					"enabled":  true,
					"interval": "60s",
					"timeout":  "30s",
				},
			}
			// Assign the URL to http_endpoint or ws_endpoint by scheme.
			if strings.HasPrefix(c.Arbitrum.RPCEndpoint, "ws") {
				provider["http_endpoint"] = ""
				provider["ws_endpoint"] = c.Arbitrum.RPCEndpoint
			} else {
				provider["http_endpoint"] = c.Arbitrum.RPCEndpoint
				provider["ws_endpoint"] = ""
			}
			providerConfigs = append(providerConfigs, provider)
		}
		// Build a provider from the legacy WebSocket endpoint, if set.
		if c.Arbitrum.WSEndpoint != "" {
			rps := c.Arbitrum.RateLimit.RequestsPerSecond
			if rps <= 0 {
				rps = 300 // Default for WebSocket
			}
			burst := c.Arbitrum.RateLimit.Burst
			if burst <= 0 {
				burst = rps * 2 // Default burst is 2x RPS
			}
			provider := map[string]interface{}{
				"name":          "Legacy-WSS",
				"type":          "standard",
				"http_endpoint": "",
				"ws_endpoint":   c.Arbitrum.WSEndpoint,
				"priority":      1,
				"rate_limit": map[string]interface{}{
					"requests_per_second": rps,
					"burst":               burst,
					"timeout":             "60s",
					"retry_delay":         "1s",
					"max_retries":         3,
				},
				"features": []string{"reading", "real_time"},
				"health_check": map[string]interface{}{
					"enabled":  true,
					"interval": "30s",
					"timeout":  "60s",
				},
			}
			providerConfigs = append(providerConfigs, provider)
		}
		// Legacy mode: the same providers serve both pools.
		providerPools := make(map[string]interface{})
		if len(providerConfigs) > 0 {
			providerNames := make([]string, 0)
			for _, provider := range providerConfigs {
				providerNames = append(providerNames, provider["name"].(string))
			}
			providerPools["read_only"] = map[string]interface{}{
				"strategy":                   "priority_based",
				"max_concurrent_connections": 25,
				"health_check_interval":      "30s",
				"failover_enabled":           true,
				"providers":                  providerNames,
			}
			providerPools["execution"] = map[string]interface{}{
				"strategy":                   "priority_based",
				"max_concurrent_connections": 20,
				"health_check_interval":      "30s",
				"failover_enabled":           true,
				"providers":                  providerNames,
			}
		}
		// Early return with the complete legacy configuration.
		return map[string]interface{}{
			"provider_pools": providerPools,
			"providers":      providerConfigs,
			"rotation": map[string]interface{}{
				"strategy":              "priority_based",
				"health_check_required": true,
				"fallover_enabled":      true,
				"retry_failed_after":    "5m",
			},
			"global_limits": map[string]interface{}{
				"max_concurrent_connections": 50,
				"connection_timeout":         "30s",
				"read_timeout":               "60s",
				"write_timeout":              "30s",
				"idle_timeout":               "300s",
			},
			"monitoring": map[string]interface{}{
				"enabled":                    true,
				"metrics_interval":           "60s",
				"log_slow_requests":          true,
				"slow_request_threshold":     "5s",
				"track_provider_performance": true,
			},
		}
	}
	// Modern path: convert each configured endpoint into a provider.
	for _, endpoint := range c.Arbitrum.ReadingEndpoints {
		providerConfigs = append(providerConfigs, c.createProviderConfig(endpoint, []string{"reading", "real_time"}))
	}
	for _, endpoint := range c.Arbitrum.ExecutionEndpoints {
		providerConfigs = append(providerConfigs, c.createProviderConfig(endpoint, []string{"execution", "transaction_submission"}))
	}
	// Build pool configurations referencing providers by name.
	providerPools := make(map[string]interface{})
	if len(c.Arbitrum.ReadingEndpoints) > 0 {
		readingProviders := make([]string, 0)
		for _, endpoint := range c.Arbitrum.ReadingEndpoints {
			readingProviders = append(readingProviders, endpoint.Name)
		}
		providerPools["read_only"] = map[string]interface{}{
			"strategy":                   "websocket_preferred",
			"max_concurrent_connections": 25,
			"health_check_interval":      "30s",
			"failover_enabled":           true,
			"providers":                  readingProviders,
		}
	}
	if len(c.Arbitrum.ExecutionEndpoints) > 0 {
		executionProviders := make([]string, 0)
		for _, endpoint := range c.Arbitrum.ExecutionEndpoints {
			executionProviders = append(executionProviders, endpoint.Name)
		}
		providerPools["execution"] = map[string]interface{}{
			"strategy":                   "reliability_first",
			"max_concurrent_connections": 20,
			"health_check_interval":      "30s",
			"failover_enabled":           true,
			"providers":                  executionProviders,
		}
	}
	// Complete configuration.
	return map[string]interface{}{
		"provider_pools": providerPools,
		"providers":      providerConfigs,
		"rotation": map[string]interface{}{
			"strategy":              "priority_based",
			"health_check_required": true,
			"fallover_enabled":      true,
			"retry_failed_after":    "5m",
		},
		"global_limits": map[string]interface{}{
			"max_concurrent_connections": 50,
			"connection_timeout":         "30s",
			"read_timeout":               "60s",
			"write_timeout":              "30s",
			"idle_timeout":               "300s",
		},
		"monitoring": map[string]interface{}{
			"enabled":                    true,
			"metrics_interval":           "60s",
			"log_slow_requests":          true,
			"slow_request_threshold":     "5s",
			"track_provider_performance": true,
		},
	}
}
// ValidateEnvironmentVariables validates required configuration values,
// returning the first problem found (check order is therefore observable
// to callers). Despite the name, it validates the Config fields — which
// may have come from the file or from environment overrides.
func (c *Config) ValidateEnvironmentVariables() error {
	// RPC endpoint is mandatory and must be well-formed.
	if c.Arbitrum.RPCEndpoint == "" {
		return fmt.Errorf("ARBITRUM_RPC_ENDPOINT environment variable is required")
	}
	if err := validateRPCEndpoint(c.Arbitrum.RPCEndpoint); err != nil {
		return fmt.Errorf("invalid ARBITRUM_RPC_ENDPOINT: %w", err)
	}
	// WebSocket endpoint is optional but must be well-formed when present.
	if c.Arbitrum.WSEndpoint != "" {
		if err := validateRPCEndpoint(c.Arbitrum.WSEndpoint); err != nil {
			return fmt.Errorf("invalid ARBITRUM_WS_ENDPOINT: %w", err)
		}
	}
	// Signing credentials are mandatory (presence only — format is not
	// checked here).
	if c.Ethereum.PrivateKey == "" {
		return fmt.Errorf("ETHEREUM_PRIVATE_KEY environment variable is required")
	}
	if c.Ethereum.AccountAddress == "" {
		return fmt.Errorf("ETHEREUM_ACCOUNT_ADDRESS environment variable is required")
	}
	// Core contract addresses are mandatory.
	if c.Contracts.ArbitrageExecutor == "" {
		return fmt.Errorf("CONTRACT_ARBITRAGE_EXECUTOR environment variable is required")
	}
	if c.Contracts.FlashSwapper == "" {
		return fmt.Errorf("CONTRACT_FLASH_SWAPPER environment variable is required")
	}
	// Numeric sanity checks (zero is allowed where it means "disabled").
	if c.Arbitrum.RateLimit.RequestsPerSecond < 0 {
		return fmt.Errorf("RPC_REQUESTS_PER_SECOND must be non-negative")
	}
	if c.Arbitrum.RateLimit.MaxConcurrent < 0 {
		return fmt.Errorf("RPC_MAX_CONCURRENT must be non-negative")
	}
	if c.Bot.MaxWorkers <= 0 {
		return fmt.Errorf("BOT_MAX_WORKERS must be positive")
	}
	if c.Bot.ChannelBufferSize < 0 {
		return fmt.Errorf("BOT_CHANNEL_BUFFER_SIZE must be non-negative")
	}
	if c.Ethereum.GasPriceMultiplier < 0 {
		return fmt.Errorf("ETHEREUM_GAS_PRICE_MULTIPLIER must be non-negative")
	}
	return nil
}
// validateRPCEndpoint validates RPC endpoint URL for security and format
func validateRPCEndpoint(endpoint string) error {
if endpoint == "" {
return fmt.Errorf("RPC endpoint cannot be empty")
}
u, err := url.Parse(endpoint)
if err != nil {
return fmt.Errorf("invalid RPC endpoint URL: %w", err)
}
// Check for valid schemes
switch u.Scheme {
case "http", "https", "ws", "wss":
// Valid schemes
default:
return fmt.Errorf("invalid RPC scheme: %s (must be http, https, ws, or wss)", u.Scheme)
}
// Check for localhost/private networks in production
if strings.Contains(u.Hostname(), "localhost") || strings.Contains(u.Hostname(), "127.0.0.1") {
// Allow localhost only if explicitly enabled
if os.Getenv("MEV_BOT_ALLOW_LOCALHOST") != "true" {
return fmt.Errorf("localhost RPC endpoints not allowed in production (set MEV_BOT_ALLOW_LOCALHOST=true to override)")
}
}
// Validate hostname is not empty
if u.Hostname() == "" {
return fmt.Errorf("RPC endpoint must have a valid hostname")
}
return nil
}
// ArbitrageConfig represents the arbitrage service configuration.
// NOTE(review): the *Wei fields are int64, which caps at roughly 9.2e18
// wei (~9.2 ETH) — confirm that is sufficient for the configured limits.
type ArbitrageConfig struct {
	// Enable or disable arbitrage service
	Enabled bool `yaml:"enabled"`
	// Contract addresses
	ArbitrageContractAddress string `yaml:"arbitrage_contract_address"`
	FlashSwapContractAddress string `yaml:"flash_swap_contract_address"`
	// Profitability settings
	MinProfitWei           int64   `yaml:"min_profit_wei"`
	MinROIPercent          float64 `yaml:"min_roi_percent"`
	MinSignificantSwapSize int64   `yaml:"min_significant_swap_size"`
	SlippageTolerance      float64 `yaml:"slippage_tolerance"`
	// Scanning configuration (bounds of the probed trade sizes, in wei)
	MinScanAmountWei int64 `yaml:"min_scan_amount_wei"`
	MaxScanAmountWei int64 `yaml:"max_scan_amount_wei"`
	// Gas configuration
	MaxGasPriceWei int64 `yaml:"max_gas_price_wei"`
	// Execution limits
	MaxConcurrentExecutions  int `yaml:"max_concurrent_executions"`
	MaxOpportunitiesPerEvent int `yaml:"max_opportunities_per_event"`
	// Timing settings (legacy values; superseded by ArbitrageOptimized
	// when the corresponding feature flag is enabled — see the Get*
	// accessors on Config)
	OpportunityTTL      time.Duration `yaml:"opportunity_ttl"`
	MaxPathAge          time.Duration `yaml:"max_path_age"`
	StatsUpdateInterval time.Duration `yaml:"stats_update_interval"`
	// Pool discovery configuration
	PoolDiscoveryConfig PoolDiscoveryConfig `yaml:"pool_discovery"`
}
// PoolDiscoveryConfig represents pool discovery service configuration.
type PoolDiscoveryConfig struct {
	// Enable or disable pool discovery
	Enabled bool `yaml:"enabled"`
	// Block range to scan for new pools
	BlockRange uint64 `yaml:"block_range"`
	// Polling interval for new pools
	PollingInterval time.Duration `yaml:"polling_interval"`
	// DEX factory addresses to monitor
	FactoryAddresses []string `yaml:"factory_addresses"`
	// Minimum liquidity threshold for pools, in wei
	MinLiquidityWei int64 `yaml:"min_liquidity_wei"`
	// Cache configuration (entry count and time-to-live)
	CacheSize int           `yaml:"cache_size"`
	CacheTTL  time.Duration `yaml:"cache_ttl"`
}
// Features represents Layer 2 optimization feature flags, grouped by
// rollout phase.
type Features struct {
	// Phase 1: Configuration tuning (switches Config's Get* accessors to
	// the ArbitrageOptimized values)
	UseArbitrumOptimizedTimeouts bool `yaml:"use_arbitrum_optimized_timeouts"`
	UseDynamicTTL                bool `yaml:"use_dynamic_ttl"`
	// Phase 2: Transaction filtering
	EnableDEXPrefilter bool `yaml:"enable_dex_prefilter"`
	// Phase 3: Sequencer optimization
	UseDirectSequencerFeed bool `yaml:"use_direct_sequencer_feed"`
	// Phase 4-5: Timeboost
	EnableTimeboost bool `yaml:"enable_timeboost"`
}
// ArbitrageOptimizedConfig represents Arbitrum-optimized arbitrage timing,
// tuned for ~250ms block times. The legacy fields preserve the previous
// values so the optimization can be rolled back without a config rewrite.
type ArbitrageOptimizedConfig struct {
	// OpportunityTTL is how long a detected opportunity stays actionable.
	OpportunityTTL time.Duration `yaml:"opportunity_ttl"`
	// MaxPathAge is the maximum age of a cached arbitrage path.
	MaxPathAge time.Duration `yaml:"max_path_age"`
	// ExecutionDeadline bounds how long an execution attempt may take.
	ExecutionDeadline time.Duration `yaml:"execution_deadline"`
	// LegacyOpportunityTTL and LegacyMaxPathAge retain the pre-optimization
	// values for rollback.
	LegacyOpportunityTTL time.Duration `yaml:"legacy_opportunity_ttl"`
	LegacyMaxPathAge     time.Duration `yaml:"legacy_max_path_age"`
	// DynamicTTL holds the dynamic TTL calculation settings.
	DynamicTTL DynamicTTLConfig `yaml:"dynamic_ttl"`
}
// DynamicTTLConfig represents dynamic TTL calculation settings, expressed in
// blocks rather than wall-clock time.
type DynamicTTLConfig struct {
	// MinTTLBlocks and MaxTTLBlocks bound the computed TTL in blocks.
	MinTTLBlocks int `yaml:"min_ttl_blocks"`
	MaxTTLBlocks int `yaml:"max_ttl_blocks"`
	// ProfitMultiplier scales the TTL with expected profit when enabled.
	// NOTE(review): semantics inferred from the name — confirm against the
	// TTL calculator that consumes this config.
	ProfitMultiplier bool `yaml:"profit_multiplier"`
	// VolatilityAdjustment adjusts the TTL for market volatility when enabled.
	// NOTE(review): semantics inferred from the name — confirm with consumer.
	VolatilityAdjustment bool `yaml:"volatility_adjustment"`
}
// GetOpportunityTTL returns the active opportunity TTL. When the Arbitrum
// optimization flag is set the optimized value is used verbatim (even if
// zero); otherwise the legacy config value wins when positive, and a 30s
// default applies as a last resort.
func (c *Config) GetOpportunityTTL() time.Duration {
	if c.Features.UseArbitrumOptimizedTimeouts {
		return c.ArbitrageOptimized.OpportunityTTL
	}
	if legacy := c.Arbitrage.OpportunityTTL; legacy > 0 {
		return legacy
	}
	return 30 * time.Second
}
// GetMaxPathAge returns the active maximum path age. The Arbitrum-optimized
// value takes precedence when its feature flag is set; otherwise a positive
// legacy value is used, falling back to a 60s default.
func (c *Config) GetMaxPathAge() time.Duration {
	if c.Features.UseArbitrumOptimizedTimeouts {
		return c.ArbitrageOptimized.MaxPathAge
	}
	if legacy := c.Arbitrage.MaxPathAge; legacy > 0 {
		return legacy
	}
	return 60 * time.Second
}
// GetExecutionDeadline returns the execution deadline. The optimized value is
// used only when the feature flag is set AND the value is positive; otherwise
// the Arbitrum default of 3s (12 blocks @ 250ms) applies.
func (c *Config) GetExecutionDeadline() time.Duration {
	if d := c.ArbitrageOptimized.ExecutionDeadline; c.Features.UseArbitrumOptimizedTimeouts && d > 0 {
		return d
	}
	return 3 * time.Second
}

View File

@@ -0,0 +1,139 @@
package config
import (
"os"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestLoad verifies that Load parses a YAML config file and expands ${VAR}
// environment references into the loaded Config.
//
// SECURITY FIX: the original test embedded a live Chainstack endpoint key in
// the repository; it has been replaced with an inert placeholder value. The
// leaked key should be rotated at the provider.
func TestLoad(t *testing.T) {
	// Create a temporary config file for testing.
	tmpFile, err := os.CreateTemp("", "config_test_*.yaml")
	require.NoError(t, err)
	defer os.Remove(tmpFile.Name())

	// Endpoints reference env vars so we can verify ${VAR} expansion.
	configContent := `
arbitrum:
  rpc_endpoint: "${ARBITRUM_RPC_ENDPOINT}"
  ws_endpoint: "${ARBITRUM_WS_ENDPOINT}"
  chain_id: 42161
  rate_limit:
    requests_per_second: 5
    max_concurrent: 3
    burst: 10
bot:
  enabled: true
  polling_interval: 3
  min_profit_threshold: 10.0
  gas_price_multiplier: 1.2
  max_workers: 3
  channel_buffer_size: 50
  rpc_timeout: 30
uniswap:
  factory_address: "0x1F98431c8aD98523631AE4a59f267346ea31F984"
  position_manager_address: "0xC36442b4a4522E871399CD717aBDD847Ab11FE88"
  fee_tiers:
    - 500
    - 3000
    - 10000
  cache:
    enabled: true
    expiration: 300
    max_size: 10000
log:
  level: "debug"
  format: "text"
  file: "logs/mev-bot.log"
database:
  file: "mev-bot.db"
  max_open_connections: 10
  max_idle_connections: 5
`
	_, err = tmpFile.Write([]byte(configContent))
	require.NoError(t, err)
	err = tmpFile.Close()
	require.NoError(t, err)

	// Placeholder endpoint — deliberately not a real credential.
	const testEndpoint = "wss://example.invalid/test-endpoint"
	os.Setenv("ARBITRUM_RPC_ENDPOINT", testEndpoint)
	os.Setenv("ARBITRUM_WS_ENDPOINT", testEndpoint)
	defer func() {
		os.Unsetenv("ARBITRUM_RPC_ENDPOINT")
		os.Unsetenv("ARBITRUM_WS_ENDPOINT")
	}()

	// Load and verify the config.
	cfg, err := Load(tmpFile.Name())
	require.NoError(t, err)

	assert.Equal(t, testEndpoint, cfg.Arbitrum.RPCEndpoint)
	assert.Equal(t, testEndpoint, cfg.Arbitrum.WSEndpoint)
	assert.Equal(t, int64(42161), cfg.Arbitrum.ChainID)
	assert.Equal(t, 5, cfg.Arbitrum.RateLimit.RequestsPerSecond)
	assert.True(t, cfg.Bot.Enabled)
	assert.Equal(t, 3, cfg.Bot.PollingInterval)
	assert.Equal(t, 10.0, cfg.Bot.MinProfitThreshold)
	assert.Equal(t, "0x1F98431c8aD98523631AE4a59f267346ea31F984", cfg.Uniswap.FactoryAddress)
	assert.Len(t, cfg.Uniswap.FeeTiers, 3)
	assert.Equal(t, true, cfg.Uniswap.Cache.Enabled)
	assert.Equal(t, "debug", cfg.Log.Level)
	assert.Equal(t, "logs/mev-bot.log", cfg.Log.File)
	assert.Equal(t, "mev-bot.db", cfg.Database.File)
}
// TestLoadWithInvalidFile checks that loading a path that does not exist
// surfaces an error instead of returning an empty config.
func TestLoadWithInvalidFile(t *testing.T) {
	_, err := Load("/non/existent/file.yaml")
	assert.Error(t, err, "loading a missing config file should fail")
}
// TestOverrideWithEnv verifies that environment variables take precedence
// over values read from the YAML config file.
func TestOverrideWithEnv(t *testing.T) {
	tmpFile, err := os.CreateTemp("", "config_test_*.yaml")
	require.NoError(t, err)
	defer os.Remove(tmpFile.Name())

	configContent := `
arbitrum:
  rpc_endpoint: "https://arb1.arbitrum.io/rpc"
  rate_limit:
    requests_per_second: 10
    max_concurrent: 5
bot:
  max_workers: 10
  channel_buffer_size: 100
`
	_, err = tmpFile.Write([]byte(configContent))
	require.NoError(t, err)
	require.NoError(t, tmpFile.Close())

	// t.Setenv restores the previous values automatically when the test ends.
	t.Setenv("ARBITRUM_RPC_ENDPOINT", "https://override.arbitrum.io/rpc")
	t.Setenv("RPC_REQUESTS_PER_SECOND", "20")
	t.Setenv("BOT_MAX_WORKERS", "20")

	cfg, err := Load(tmpFile.Name())
	require.NoError(t, err)

	// Every overridden field must reflect the environment, not the file.
	assert.Equal(t, "https://override.arbitrum.io/rpc", cfg.Arbitrum.RPCEndpoint)
	assert.Equal(t, 20, cfg.Arbitrum.RateLimit.RequestsPerSecond)
	assert.Equal(t, 20, cfg.Bot.MaxWorkers)
}

View File

@@ -0,0 +1,349 @@
package contracts
import (
"context"
"fmt"
"math/big"
"strings"
"time"
"github.com/ethereum/go-ethereum"
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/ethclient"
"github.com/fraktal/mev-beta/internal/logger"
)
// ContractType classifies the kind of contract (or account) found at an
// address during runtime detection.
type ContractType int

const (
	ContractTypeUnknown ContractType = iota
	ContractTypeERC20Token
	ContractTypeUniswapV2Pool
	ContractTypeUniswapV3Pool
	ContractTypeUniswapV2Router
	ContractTypeUniswapV3Router
	ContractTypeUniversalRouter
	ContractTypeFactory
	ContractTypeEOA // Externally Owned Account
)

// contractTypeNames maps each known ContractType to its display name.
var contractTypeNames = map[ContractType]string{
	ContractTypeERC20Token:      "ERC-20",
	ContractTypeUniswapV2Pool:   "UniswapV2Pool",
	ContractTypeUniswapV3Pool:   "UniswapV3Pool",
	ContractTypeUniswapV2Router: "UniswapV2Router",
	ContractTypeUniswapV3Router: "UniswapV3Router",
	ContractTypeUniversalRouter: "UniversalRouter",
	ContractTypeFactory:         "Factory",
	ContractTypeEOA:             "EOA",
}

// String returns the human-readable name for the contract type; unmapped
// values (including ContractTypeUnknown) render as "Unknown".
func (ct ContractType) String() string {
	if name, ok := contractTypeNames[ct]; ok {
		return name
	}
	return "Unknown"
}
// DetectionResult contains the result of contract type detection for a
// single address. A result is cached per address, so Error and Warnings
// reflect the first detection attempt.
type DetectionResult struct {
	// ContractType is the best-scoring detected type (Unknown if none).
	ContractType ContractType
	// IsContract is true when the address has deployed bytecode.
	IsContract bool
	// HasCode mirrors IsContract for addresses with bytecode.
	HasCode bool
	// SupportedFunctions lists the function names that responded to probes.
	SupportedFunctions []string
	// Confidence is the detection score in [0.0, 1.0].
	Confidence float64 // 0.0 to 1.0
	// Error is set when detection could not complete (e.g. RPC failure).
	Error error
	// Warnings holds non-fatal issues found during validation.
	Warnings []string
}
// ContractDetector provides runtime contract type detection by probing
// addresses over RPC and caching the results per address.
//
// NOTE(review): the cache map is read and written without a mutex; confirm
// the detector is only used from a single goroutine, or add locking.
type ContractDetector struct {
	// client is the RPC client used for CodeAt/CallContract probes.
	client *ethclient.Client
	logger *logger.Logger
	// cache memoizes detection results per address (including failures).
	cache map[common.Address]*DetectionResult
	// timeout bounds each full detection attempt.
	timeout time.Duration
}
// NewContractDetector creates a new contract detector
func NewContractDetector(client *ethclient.Client, logger *logger.Logger) *ContractDetector {
return &ContractDetector{
client: client,
logger: logger,
cache: make(map[common.Address]*DetectionResult),
timeout: 5 * time.Second,
}
}
// DetectContractType determines the type of contract at the given address.
// Results (including failures) are cached per address; subsequent calls
// return the cached result without touching the network.
//
// Detection flow: fetch bytecode; no code means EOA (confidence 1.0);
// otherwise score known function-signature sets and, for high-confidence
// matches (> 0.8), run type-specific validation which may add warnings and
// scale confidence down by 20%.
func (cd *ContractDetector) DetectContractType(ctx context.Context, address common.Address) *DetectionResult {
	// Check cache first — failed detections are cached too, so a transient
	// RPC error sticks until ClearCache is called.
	if result, exists := cd.cache[address]; exists {
		return result
	}
	result := &DetectionResult{
		ContractType: ContractTypeUnknown,
		IsContract: false,
		HasCode: false,
		SupportedFunctions: []string{},
		Confidence: 0.0,
		Warnings: []string{},
	}
	// Bound the whole detection (code fetch + probes) by cd.timeout.
	ctxWithTimeout, cancel := context.WithTimeout(ctx, cd.timeout)
	defer cancel()
	// Check if address has code (is a contract).
	code, err := cd.client.CodeAt(ctxWithTimeout, address, nil)
	if err != nil {
		result.Error = fmt.Errorf("failed to get code at address: %w", err)
		cd.cache[address] = result
		return result
	}
	// If no code, it's an EOA.
	if len(code) == 0 {
		result.ContractType = ContractTypeEOA
		result.IsContract = false
		result.Confidence = 1.0
		cd.cache[address] = result
		return result
	}
	result.IsContract = true
	result.HasCode = true
	// Detect contract type by testing function signatures via eth_call.
	contractType, confidence, supportedFunctions := cd.detectByFunctionSignatures(ctxWithTimeout, address)
	result.ContractType = contractType
	result.Confidence = confidence
	result.SupportedFunctions = supportedFunctions
	// Additional validation for high-confidence detection; a validation
	// failure is recorded as a warning rather than rejecting the result.
	if confidence > 0.8 {
		if err := cd.validateContractType(ctxWithTimeout, address, contractType); err != nil {
			result.Warnings = append(result.Warnings, fmt.Sprintf("validation warning: %v", err))
			result.Confidence *= 0.8 // Reduce confidence
		}
	}
	cd.cache[address] = result
	return result
}
// detectByFunctionSignatures probes the contract with several sets of known
// 4-byte function selectors (ERC-20, Uniswap V2/V3 pool, router) and returns
// the best-scoring type, its score, and the names of all functions that
// responded. A set only competes when its score clears a minimum threshold
// (0.6 for tokens/pools, 0.5 for routers).
//
// Note: token0()/token1() appear in both the V2 and V3 pool sets, so the
// V3-only selectors (fee, slot0, liquidity, tickSpacing) are what separate
// the two pool types.
func (cd *ContractDetector) detectByFunctionSignatures(ctx context.Context, address common.Address) (ContractType, float64, []string) {
	supportedFunctions := []string{}
	scores := make(map[ContractType]float64)
	// Test ERC-20 functions (selectors are keccak256 prefixes of the
	// canonical signatures, e.g. name() -> 0x06fdde03).
	erc20Functions := map[string][]byte{
		"name()": {0x06, 0xfd, 0xde, 0x03},
		"symbol()": {0x95, 0xd8, 0x9b, 0x41},
		"decimals()": {0x31, 0x3c, 0xe5, 0x67},
		"totalSupply()": {0x18, 0x16, 0x0d, 0xdd},
		"balanceOf()": {0x70, 0xa0, 0x82, 0x31},
	}
	erc20Score := cd.testFunctionSignatures(ctx, address, erc20Functions, &supportedFunctions)
	if erc20Score > 0.6 {
		scores[ContractTypeERC20Token] = erc20Score
	}
	// Test Uniswap V2 Pool functions.
	v2PoolFunctions := map[string][]byte{
		"token0()": {0x0d, 0xfe, 0x16, 0x81},
		"token1()": {0xd2, 0x12, 0x20, 0xa7},
		"getReserves()": {0x09, 0x02, 0xf1, 0xac},
		"price0CumulativeLast()": {0x54, 0x1c, 0x5c, 0xfa},
		"kLast()": {0x7d, 0xc0, 0xd1, 0xd0},
	}
	v2PoolScore := cd.testFunctionSignatures(ctx, address, v2PoolFunctions, &supportedFunctions)
	if v2PoolScore > 0.6 {
		scores[ContractTypeUniswapV2Pool] = v2PoolScore
	}
	// Test Uniswap V3 Pool functions.
	v3PoolFunctions := map[string][]byte{
		"token0()": {0x0d, 0xfe, 0x16, 0x81},
		"token1()": {0xd2, 0x12, 0x20, 0xa7},
		"fee()": {0xdd, 0xca, 0x3f, 0x43},
		"slot0()": {0x38, 0x50, 0xc7, 0xbd},
		"liquidity()": {0x1a, 0x68, 0x65, 0x0f},
		"tickSpacing()": {0xd0, 0xc9, 0x32, 0x07},
	}
	v3PoolScore := cd.testFunctionSignatures(ctx, address, v3PoolFunctions, &supportedFunctions)
	if v3PoolScore > 0.6 {
		scores[ContractTypeUniswapV3Pool] = v3PoolScore
	}
	// Test Router functions (lower threshold: routers share selectors with
	// other contract kinds, so the signal is weaker).
	routerFunctions := map[string][]byte{
		"WETH()": {0xad, 0x5c, 0x46, 0x48},
		"swapExactTokensForTokens()": {0x38, 0xed, 0x17, 0x39},
		"factory()": {0xc4, 0x5a, 0x01, 0x55},
	}
	routerScore := cd.testFunctionSignatures(ctx, address, routerFunctions, &supportedFunctions)
	if routerScore > 0.5 {
		scores[ContractTypeUniswapV2Router] = routerScore
	}
	// Find highest scoring type. Map iteration order is random, so exact
	// ties between types resolve nondeterministically.
	var bestType ContractType = ContractTypeUnknown
	var bestScore float64 = 0.0
	for contractType, score := range scores {
		if score > bestScore {
			bestScore = score
			bestType = contractType
		}
	}
	return bestType, bestScore, supportedFunctions
}
// testFunctionSignatures probes a contract with each selector in functions
// via eth_call and returns the fraction of probes that succeeded.
//
// Calls that fail with anything other than "execution reverted" (e.g.
// transport errors) are removed from the denominator so transient RPC
// trouble does not drag the score down. Names of successful probes are
// appended to supportedFunctions (mutated in place via pointer).
func (cd *ContractDetector) testFunctionSignatures(ctx context.Context, address common.Address, functions map[string][]byte, supportedFunctions *[]string) float64 {
	supported := 0
	total := len(functions)
	for funcName, signature := range functions {
		// A successful eth_call (no revert) counts as "function exists".
		_, err := cd.client.CallContract(ctx, ethereum.CallMsg{
			To: &address,
			Data: signature,
		}, nil)
		if err == nil {
			supported++
			*supportedFunctions = append(*supportedFunctions, funcName)
		} else if !strings.Contains(err.Error(), "execution reverted") {
			// Likely a transport/RPC failure rather than a revert: exclude
			// this probe from the denominator instead of counting it as
			// unsupported.
			total--
		}
	}
	// All probes failed with non-revert errors: no usable signal.
	if total == 0 {
		return 0.0
	}
	return float64(supported) / float64(total)
}
// validateContractType runs type-specific sanity checks after signature-based
// detection reports high confidence. Types without dedicated checks pass
// trivially.
func (cd *ContractDetector) validateContractType(ctx context.Context, address common.Address, contractType ContractType) error {
	switch contractType {
	case ContractTypeERC20Token:
		return cd.validateERC20(ctx, address)
	case ContractTypeUniswapV2Pool:
		return cd.validateUniswapV2Pool(ctx, address)
	case ContractTypeUniswapV3Pool:
		return cd.validateUniswapV3Pool(ctx, address)
	default:
		return nil // No additional validation for other types
	}
}
// validateERC20 validates that a contract plausibly is an ERC-20 token by
// calling decimals() and checking the value is in the realistic 0-18 range.
// A return payload that is not exactly 32 bytes is skipped (no error), since
// non-standard tokens may deviate.
//
// Fix: the previous code did SetBytes(result).Uint64(), which silently
// truncates values >= 2^64 — a 32-byte value like 2^64+5 would read as 5 and
// pass the <= 18 check. Compare as big.Int instead.
func (cd *ContractDetector) validateERC20(ctx context.Context, address common.Address) error {
	decimalsData := []byte{0x31, 0x3c, 0xe5, 0x67} // decimals() selector
	result, err := cd.client.CallContract(ctx, ethereum.CallMsg{
		To:   &address,
		Data: decimalsData,
	}, nil)
	if err != nil {
		return fmt.Errorf("decimals() call failed: %w", err)
	}
	if len(result) == 32 {
		decimals := new(big.Int).SetBytes(result)
		// Realistic ERC-20 decimals are 0-18; anything larger is a red flag.
		if decimals.Cmp(big.NewInt(18)) > 0 {
			return fmt.Errorf("unrealistic decimals value: %s", decimals.String())
		}
	}
	return nil
}
// validateUniswapV2Pool validates that a contract is actually a Uniswap V2
// pool by calling getReserves() and checking the ABI-encoded return is
// exactly three 32-byte words (reserve0, reserve1, blockTimestampLast).
func (cd *ContractDetector) validateUniswapV2Pool(ctx context.Context, address common.Address) error {
	getReservesData := []byte{0x09, 0x02, 0xf1, 0xac} // getReserves() selector
	result, err := cd.client.CallContract(ctx, ethereum.CallMsg{
		To: &address,
		Data: getReservesData,
	}, nil)
	if err != nil {
		return fmt.Errorf("getReserves() call failed: %w", err)
	}
	// Should return 3 words: two uint112 reserves plus a uint32 timestamp,
	// each ABI-padded to 32 bytes.
	if len(result) != 96 { // 3 * 32 bytes
		return fmt.Errorf("unexpected getReserves() return length: %d", len(result))
	}
	return nil
}
// validateUniswapV3Pool validates that a contract is actually a Uniswap V3
// pool by calling slot0() and checking the return carries at least one
// 32-byte word (sqrtPriceX96 leads the slot0 tuple).
func (cd *ContractDetector) validateUniswapV3Pool(ctx context.Context, address common.Address) error {
	slot0Data := []byte{0x38, 0x50, 0xc7, 0xbd} // slot0() selector
	result, err := cd.client.CallContract(ctx, ethereum.CallMsg{
		To: &address,
		Data: slot0Data,
	}, nil)
	if err != nil {
		return fmt.Errorf("slot0() call failed: %w", err)
	}
	// slot0 returns a multi-field tuple; require at least the first word.
	if len(result) < 32 {
		return fmt.Errorf("unexpected slot0() return length: %d", len(result))
	}
	return nil
}
// IsERC20Token reports whether the address was detected as an ERC-20 token
// with confidence above 0.7.
func (cd *ContractDetector) IsERC20Token(ctx context.Context, address common.Address) bool {
	detection := cd.DetectContractType(ctx, address)
	if detection.ContractType != ContractTypeERC20Token {
		return false
	}
	return detection.Confidence > 0.7
}
// IsUniswapPool reports whether the address was detected as a Uniswap pool
// (V2 or V3) with confidence above 0.7.
func (cd *ContractDetector) IsUniswapPool(ctx context.Context, address common.Address) bool {
	detection := cd.DetectContractType(ctx, address)
	switch detection.ContractType {
	case ContractTypeUniswapV2Pool, ContractTypeUniswapV3Pool:
		return detection.Confidence > 0.7
	default:
		return false
	}
}
// IsRouter reports whether the address was detected as a router contract
// (Uniswap V2/V3 or Universal) with confidence above 0.7.
func (cd *ContractDetector) IsRouter(ctx context.Context, address common.Address) bool {
	detection := cd.DetectContractType(ctx, address)
	switch detection.ContractType {
	case ContractTypeUniswapV2Router, ContractTypeUniswapV3Router, ContractTypeUniversalRouter:
		return detection.Confidence > 0.7
	default:
		return false
	}
}
// ClearCache clears the detection cache, forcing fresh detection (including
// retrying previously failed lookups) on the next call.
// NOTE(review): the cache is not mutex-guarded; confirm no detections are in
// flight when this is called.
func (cd *ContractDetector) ClearCache() {
	cd.cache = make(map[common.Address]*DetectionResult)
}
// GetCacheSize returns the number of cached detection results (successful
// and failed alike).
func (cd *ContractDetector) GetCacheSize() int {
	return len(cd.cache)
}

View File

@@ -0,0 +1,239 @@
package contracts
import (
"context"
"fmt"
"github.com/ethereum/go-ethereum/common"
"github.com/fraktal/mev-beta/internal/logger"
)
// FunctionSignature represents a known function signature: its canonical
// name, 4-byte selector, and the contract types it may legitimately be
// called on.
type FunctionSignature struct {
	// Name is the canonical signature, e.g. "token0()".
	Name string
	// Selector is the 4-byte keccak256 prefix of Name.
	Selector []byte
	// AllowedTypes lists contract types on which this call is valid.
	AllowedTypes []ContractType
}
// SignatureValidator validates function calls against detected contract
// types, preventing e.g. token0() being called on an ERC-20 token.
type SignatureValidator struct {
	// detector resolves an address to a ContractType before validation.
	detector *ContractDetector
	logger *logger.Logger
	// signatures maps canonical signature name -> known FunctionSignature.
	signatures map[string]*FunctionSignature
}
// NewSignatureValidator builds a validator around the given detector and
// preloads the table of known function signatures.
func NewSignatureValidator(detector *ContractDetector, logger *logger.Logger) *SignatureValidator {
	validator := &SignatureValidator{
		detector:   detector,
		logger:     logger,
		signatures: map[string]*FunctionSignature{},
	}
	validator.initializeSignatures()
	return validator
}
// initializeSignatures registers the known function signatures and the
// contract types each may be called on. The table below is equivalent to the
// previous one-assignment-per-signature form; selectors are the 4-byte
// keccak256 prefixes of the canonical signatures.
func (sv *SignatureValidator) initializeSignatures() {
	erc20Only := []ContractType{ContractTypeERC20Token}
	anyPool := []ContractType{ContractTypeUniswapV2Pool, ContractTypeUniswapV3Pool}
	v2PoolOnly := []ContractType{ContractTypeUniswapV2Pool}
	v3PoolOnly := []ContractType{ContractTypeUniswapV3Pool}
	routers := []ContractType{ContractTypeUniswapV2Router, ContractTypeUniswapV3Router}

	entries := []FunctionSignature{
		// ERC-20 token functions.
		{Name: "name()", Selector: []byte{0x06, 0xfd, 0xde, 0x03}, AllowedTypes: erc20Only},
		{Name: "symbol()", Selector: []byte{0x95, 0xd8, 0x9b, 0x41}, AllowedTypes: erc20Only},
		{Name: "decimals()", Selector: []byte{0x31, 0x3c, 0xe5, 0x67}, AllowedTypes: erc20Only},
		{Name: "totalSupply()", Selector: []byte{0x18, 0x16, 0x0d, 0xdd}, AllowedTypes: erc20Only},
		{Name: "balanceOf()", Selector: []byte{0x70, 0xa0, 0x82, 0x31}, AllowedTypes: erc20Only},
		// Pool functions shared by Uniswap V2 and V3.
		{Name: "token0()", Selector: []byte{0x0d, 0xfe, 0x16, 0x81}, AllowedTypes: anyPool},
		{Name: "token1()", Selector: []byte{0xd2, 0x12, 0x20, 0xa7}, AllowedTypes: anyPool},
		// Uniswap V2 Pool only.
		{Name: "getReserves()", Selector: []byte{0x09, 0x02, 0xf1, 0xac}, AllowedTypes: v2PoolOnly},
		// Uniswap V3 Pool only.
		{Name: "slot0()", Selector: []byte{0x38, 0x50, 0xc7, 0xbd}, AllowedTypes: v3PoolOnly},
		{Name: "fee()", Selector: []byte{0xdd, 0xca, 0x3f, 0x43}, AllowedTypes: v3PoolOnly},
		{Name: "liquidity()", Selector: []byte{0x1a, 0x68, 0x65, 0x0f}, AllowedTypes: v3PoolOnly},
		{Name: "tickSpacing()", Selector: []byte{0xd0, 0xc9, 0x32, 0x07}, AllowedTypes: v3PoolOnly},
		// Router functions.
		{Name: "WETH()", Selector: []byte{0xad, 0x5c, 0x46, 0x48}, AllowedTypes: routers},
		{Name: "factory()", Selector: []byte{0xc4, 0x5a, 0x01, 0x55}, AllowedTypes: routers},
	}

	for i := range entries {
		entry := entries[i] // copy so each map value gets its own address
		sv.signatures[entry.Name] = &entry
	}
}
// ValidationResult contains the result of function signature validation for
// one call against one contract address.
type ValidationResult struct {
	// IsValid is true when the call is allowed (including unknown selectors,
	// which are permitted with a warning).
	IsValid bool
	// FunctionName is the matched canonical signature, if any.
	FunctionName string
	// ContractType is the detected type of the target contract.
	ContractType ContractType
	// Error explains a rejection or a detection failure.
	Error error
	// Warnings holds non-fatal findings (unknown selector, low confidence).
	Warnings []string
}
// ValidateFunctionCall checks whether the function identified by
// functionSelector may be called on the contract at contractAddress, based
// on the detected contract type. Unknown selectors are allowed with a
// warning; known selectors are rejected when the detected type is not in
// their allow-list. Low detection confidence (< 0.7) adds a warning.
func (sv *SignatureValidator) ValidateFunctionCall(ctx context.Context, contractAddress common.Address, functionSelector []byte) *ValidationResult {
	out := &ValidationResult{
		IsValid:  false,
		Warnings: []string{},
	}

	// Resolve the contract type first; validation is meaningless without it.
	detection := sv.detector.DetectContractType(ctx, contractAddress)
	out.ContractType = detection.ContractType
	if detection.Error != nil {
		out.Error = fmt.Errorf("contract type detection failed: %w", detection.Error)
		return out
	}

	// Look for a known signature whose first four selector bytes match.
	var matched *FunctionSignature
	if len(functionSelector) >= 4 {
		want := string(functionSelector[:4])
		for _, sig := range sv.signatures {
			if len(sig.Selector) >= 4 && string(sig.Selector[:4]) == want {
				matched = sig
				out.FunctionName = sig.Name
				break
			}
		}
	}

	// Unknown selectors are permitted — we only warn.
	if matched == nil {
		out.IsValid = true
		out.Warnings = append(out.Warnings, fmt.Sprintf("unknown function selector: %x", functionSelector))
		return out
	}

	// Reject when the detected type is not in the signature's allow-list.
	typeAllowed := false
	for _, allowedType := range matched.AllowedTypes {
		if detection.ContractType == allowedType {
			typeAllowed = true
			break
		}
	}
	if !typeAllowed {
		out.Error = fmt.Errorf("function %s cannot be called on contract type %s (allowed types: %v)",
			matched.Name, detection.ContractType.String(), matched.AllowedTypes)
		return out
	}

	if detection.Confidence < 0.7 {
		out.Warnings = append(out.Warnings, fmt.Sprintf("low confidence in contract type detection: %.2f", detection.Confidence))
	}
	out.IsValid = true
	return out
}
// ValidateToken0Call validates a token0() call (selector 0x0dfe1681) against
// the contract at contractAddress.
func (sv *SignatureValidator) ValidateToken0Call(ctx context.Context, contractAddress common.Address) *ValidationResult {
	return sv.ValidateFunctionCall(ctx, contractAddress, []byte{0x0d, 0xfe, 0x16, 0x81})
}
// ValidateToken1Call validates a token1() call (selector 0xd21220a7) against
// the contract at contractAddress.
func (sv *SignatureValidator) ValidateToken1Call(ctx context.Context, contractAddress common.Address) *ValidationResult {
	return sv.ValidateFunctionCall(ctx, contractAddress, []byte{0xd2, 0x12, 0x20, 0xa7})
}
// ValidateGetReservesCall validates a getReserves() call (selector
// 0x0902f1ac) against the contract at contractAddress.
func (sv *SignatureValidator) ValidateGetReservesCall(ctx context.Context, contractAddress common.Address) *ValidationResult {
	return sv.ValidateFunctionCall(ctx, contractAddress, []byte{0x09, 0x02, 0xf1, 0xac})
}
// ValidateSlot0Call validates a slot0() call (selector 0x3850c7bd, Uniswap
// V3) against the contract at contractAddress.
func (sv *SignatureValidator) ValidateSlot0Call(ctx context.Context, contractAddress common.Address) *ValidationResult {
	return sv.ValidateFunctionCall(ctx, contractAddress, []byte{0x38, 0x50, 0xc7, 0xbd})
}
// GetSupportedFunctions returns the names of all registered signatures whose
// allow-list includes contractType. Ordering is unspecified (map iteration).
func (sv *SignatureValidator) GetSupportedFunctions(contractType ContractType) []string {
	var names []string
	for _, sig := range sv.signatures {
		for _, allowed := range sig.AllowedTypes {
			if allowed == contractType {
				names = append(names, sig.Name)
				break
			}
		}
	}
	return names
}

View File

@@ -0,0 +1,484 @@
package logger
import (
"fmt"
"log"
"os"
"path/filepath"
"strings"
"time"
pkgerrors "github.com/fraktal/mev-beta/pkg/errors"
)
// LogLevel represents different log levels, ordered by severity; a message is
// emitted when its level is >= the logger's configured level.
type LogLevel int

const (
	DEBUG LogLevel = iota
	INFO
	WARN
	ERROR
	// OPPORTUNITY is a special level for detected MEV opportunities; the
	// Opportunity methods always log regardless of the configured level.
	OPPORTUNITY // Special level for opportunities
)

// logLevelNames maps each level to the tag printed inside log lines.
var logLevelNames = map[LogLevel]string{
	DEBUG: "DEBUG",
	INFO: "INFO",
	WARN: "WARN",
	ERROR: "ERROR",
	OPPORTUNITY: "OPPORTUNITY",
}

// suppressedWarningSubstrings lists substrings whose presence in a warning
// message causes that warning to be dropped entirely (known-noisy sources).
var suppressedWarningSubstrings = []string{
	"extractTokensGeneric",
	"extractTokensFromMulticall",
}
// Logger represents a multi-file logger with separation of concerns: a main
// application log plus dedicated opportunity, error, performance and
// transaction streams. Messages may be written to more than one file (e.g.
// warnings go to both the main log and the error log).
type Logger struct {
	// Main application logger.
	logger *log.Logger
	// level is the minimum LogLevel written to the main log.
	level LogLevel
	// Specialized loggers for different concerns:
	opportunityLogger *log.Logger // MEV opportunities and arbitrage attempts
	errorLogger *log.Logger // Errors and warnings only
	performanceLogger *log.Logger // Performance metrics and RPC calls
	transactionLogger *log.Logger // Detailed transaction analysis
	// secureFilter redacts sensitive data before messages are written.
	secureFilter *SecureFilter
	// levelName is the original (unparsed) level string passed to New.
	levelName string
}
// parseLogLevel converts a case-insensitive level string ("debug", "info",
// "warn"/"warning", "error") to its LogLevel; unrecognized values default to
// INFO.
func parseLogLevel(level string) LogLevel {
	normalized := strings.ToLower(level)
	switch normalized {
	case "debug":
		return DEBUG
	case "warn", "warning":
		return WARN
	case "error":
		return ERROR
	case "info":
		return INFO
	}
	return INFO // unknown strings fall back to INFO
}
// createLogFile opens filename for appending, creating parent directories
// and rotating an oversized file first. Any failure falls back to os.Stdout
// so logging never blocks startup; the failure itself is reported via the
// standard log package.
func createLogFile(filename string) *os.File {
	if filename == "" {
		return os.Stdout
	}
	if err := os.MkdirAll(filepath.Dir(filename), 0o755); err != nil {
		log.Printf("Failed to create log directory for %s: %v, falling back to stdout", filename, err)
		return os.Stdout
	}
	// Check and rotate log file if needed (100MB max size).
	maxSize := int64(100 * 1024 * 1024) // 100 MB
	if err := rotateLogFile(filename, maxSize); err != nil {
		log.Printf("Failed to rotate log file %s: %v", filename, err)
		// Continue anyway, rotation failure shouldn't stop logging.
	}
	// 0600: log files may contain filtered-but-sensitive material.
	f, err := os.OpenFile(filename, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0600)
	if err != nil {
		log.Printf("Failed to create log file %s: %v, falling back to stdout", filename, err)
		return os.Stdout
	}
	return f
}
// New creates a new multi-file logger with separation of concerns
func New(level string, format string, file string) *Logger {
// Parse base filename for specialized logs
baseDir := "logs"
baseName := "mev_bot"
if file != "" {
// Extract directory and base filename
parts := strings.Split(file, "/")
if len(parts) > 1 {
baseDir = strings.Join(parts[:len(parts)-1], "/")
}
filename := parts[len(parts)-1]
if strings.Contains(filename, ".") {
baseName = strings.Split(filename, ".")[0]
}
}
// Create specialized log files
mainFile := createLogFile(file)
opportunityFile := createLogFile(fmt.Sprintf("%s/%s_opportunities.log", baseDir, baseName))
errorFile := createLogFile(fmt.Sprintf("%s/%s_errors.log", baseDir, baseName))
performanceFile := createLogFile(fmt.Sprintf("%s/%s_performance.log", baseDir, baseName))
transactionFile := createLogFile(fmt.Sprintf("%s/%s_transactions.log", baseDir, baseName))
// Create loggers with no prefixes (we format ourselves)
logLevel := parseLogLevel(level)
// Determine security level based on environment and log level
var securityLevel SecurityLevel
env := os.Getenv("GO_ENV")
switch {
case env == "production":
securityLevel = SecurityLevelProduction
case logLevel >= WARN:
securityLevel = SecurityLevelInfo
default:
securityLevel = SecurityLevelDebug
}
return &Logger{
logger: log.New(mainFile, "", 0),
opportunityLogger: log.New(opportunityFile, "", 0),
errorLogger: log.New(errorFile, "", 0),
performanceLogger: log.New(performanceFile, "", 0),
transactionLogger: log.New(transactionFile, "", 0),
level: logLevel,
secureFilter: NewSecureFilter(securityLevel),
levelName: level,
}
}
// shouldLog reports whether a message at the given level meets the logger's
// configured minimum level.
func (l *Logger) shouldLog(level LogLevel) bool {
	return level >= l.level
}
// formatMessage renders a log line as "<timestamp> [<LEVEL>] <message>",
// where the message is the key/value-aware rendering of v.
func (l *Logger) formatMessage(level LogLevel, v ...interface{}) string {
	return fmt.Sprintf("%s [%s] %s",
		time.Now().Format("2006/01/02 15:04:05"),
		logLevelNames[level],
		formatKVMessage(v...))
}
// formatKVMessage converts a variadic argument list into a structured log
// string. The first argument is rendered verbatim (legacy behavior); the
// rest are consumed pairwise when the first of a pair is a string key, so
// logger.Error("msg", "key", value) renders as "msg key=value". A trailing
// or non-string argument falls back to plain space-separated rendering.
func formatKVMessage(args ...interface{}) string {
	if len(args) == 0 {
		return ""
	}
	parts := make([]string, 0, len(args))
	parts = append(parts, fmt.Sprintf("%v", args[0]))
	i := 1
	for i < len(args) {
		// A string followed by at least one more argument forms a k=v pair.
		if key, ok := args[i].(string); ok && i+1 < len(args) {
			parts = append(parts, fmt.Sprintf("%s=%v", key, args[i+1]))
			i += 2
			continue
		}
		// Otherwise render the argument on its own.
		parts = append(parts, fmt.Sprintf("%v", args[i]))
		i++
	}
	return strings.Join(parts, " ")
}
// Debug logs a debug-level message to the main log.
func (l *Logger) Debug(v ...interface{}) {
	if !l.shouldLog(DEBUG) {
		return
	}
	l.logger.Println(l.formatMessage(DEBUG, v...))
}
// Info logs an info-level message to the main log.
func (l *Logger) Info(v ...interface{}) {
	if !l.shouldLog(INFO) {
		return
	}
	l.logger.Println(l.formatMessage(INFO, v...))
}
// Warn logs a warning to both the main log and the error log, unless the
// formatted message matches one of the suppressed substrings.
func (l *Logger) Warn(v ...interface{}) {
	if !l.shouldLog(WARN) {
		return
	}
	msg := l.formatMessage(WARN, v...)
	for _, needle := range suppressedWarningSubstrings {
		if strings.Contains(msg, needle) {
			return // known-noisy warning: drop entirely
		}
	}
	l.logger.Println(msg)
	l.errorLogger.Println(msg) // warnings also land in the error log
}
// Error logs an error-level message to both the main log and the error log.
func (l *Logger) Error(v ...interface{}) {
	if !l.shouldLog(ERROR) {
		return
	}
	msg := l.formatMessage(ERROR, v...)
	l.logger.Println(msg)
	l.errorLogger.Println(msg) // errors also land in the error log
}
// ErrorStructured logs a structured error: the compact rendering goes to the
// main log, the fully detailed rendering to the error log.
func (l *Logger) ErrorStructured(err *pkgerrors.StructuredError) {
	if !l.shouldLog(ERROR) {
		return
	}
	ts := time.Now().Format("2006/01/02 15:04:05")
	l.logger.Println(fmt.Sprintf("%s [ERROR] %s", ts, err.FormatCompact()))
	l.errorLogger.Println(fmt.Sprintf("%s [ERROR] %s", ts, err.FormatForLogging()))
}
// WarnStructured logs a structured warning (compact rendering) to both the
// main and error logs, unless the rendered message matches one of the
// suppressed substrings.
func (l *Logger) WarnStructured(err *pkgerrors.StructuredError) {
	if !l.shouldLog(WARN) {
		return
	}
	msg := fmt.Sprintf("%s [WARN] %s",
		time.Now().Format("2006/01/02 15:04:05"),
		err.FormatCompact())
	for _, needle := range suppressedWarningSubstrings {
		if strings.Contains(msg, needle) {
			return // known-noisy warning: drop entirely
		}
	}
	l.logger.Println(msg)
	l.errorLogger.Println(msg)
}
// Opportunity logs a detected arbitrage opportunity with full detail to the
// main log and the dedicated opportunity log. It always logs, regardless of
// the configured level, since opportunities are critical events. Additional
// data is sanitized and the final message passes through the security filter
// before being written.
func (l *Logger) Opportunity(txHash, from, to, method, protocol string, amountIn, amountOut, minOut, profitUSD float64, additionalData map[string]interface{}) {
	timestamp := time.Now().Format("2006/01/02 15:04:05")
	// Redact sensitive keys in the free-form payload before formatting.
	sanitizedData := l.secureFilter.SanitizeForProduction(additionalData)
	message := fmt.Sprintf(`%s [OPPORTUNITY] 🎯 ARBITRAGE OPPORTUNITY DETECTED
├── Transaction: %s
├── From: %s → To: %s
├── Method: %s (%s)
├── Amount In: %.6f tokens
├── Amount Out: %.6f tokens
├── Min Out: %.6f tokens
├── Estimated Profit: $%.2f USD
└── Additional Data: %v`,
		timestamp, txHash, from, to, method, protocol,
		amountIn, amountOut, minOut, profitUSD, sanitizedData)
	// Apply security filtering to the fully rendered message as well.
	filteredMessage := l.secureFilter.FilterMessage(message)
	l.logger.Println(filteredMessage)
	l.opportunityLogger.Println(filteredMessage) // Dedicated opportunity log
}
// OpportunitySimple logs a one-line opportunity message to the main and
// opportunity logs (kept for backwards compatibility; always logs).
func (l *Logger) OpportunitySimple(v ...interface{}) {
	msg := fmt.Sprintf("%s [OPPORTUNITY] %s",
		time.Now().Format("2006/01/02 15:04:05"),
		fmt.Sprint(v...))
	l.logger.Println(msg)
	l.opportunityLogger.Println(msg) // dedicated opportunity log
}
// Performance logs a timing measurement for one operation to the dedicated
// performance log only (never the main log). Standard fields (component,
// operation, duration in ms and ns, timestamp) are merged with the caller's
// metadata; caller keys override the standard ones on collision.
func (l *Logger) Performance(component, operation string, duration time.Duration, metadata map[string]interface{}) {
	timestamp := time.Now().Format("2006/01/02 15:04:05")
	// Add standard performance fields.
	data := map[string]interface{}{
		"component": component,
		"operation": operation,
		"duration_ms": duration.Milliseconds(),
		"duration_ns": duration.Nanoseconds(),
		"timestamp": timestamp,
	}
	// Merge with provided metadata (caller values win on key collision).
	for k, v := range metadata {
		data[k] = v
	}
	message := fmt.Sprintf(`%s [PERFORMANCE] 📊 %s.%s completed in %v - %v`,
		timestamp, component, operation, duration, data)
	l.performanceLogger.Println(message) // Dedicated performance log only
}
// Metrics logs a named business metric (value plus unit and tags) for
// offline analysis. Entries are routed to the performance log.
func (l *Logger) Metrics(name string, value float64, unit string, tags map[string]string) {
	l.performanceLogger.Println(fmt.Sprintf(`%s [METRICS] 📈 %s: %.6f %s %v`,
		time.Now().Format("2006/01/02 15:04:05"), name, value, unit, tags))
}
// Transaction logs detailed transaction information for MEV analysis.
// The status (SUCCESS/FAILED) intentionally appears twice: once in the
// header line and once on the Status row. Output goes to the dedicated
// transaction log only, after secure filtering.
func (l *Logger) Transaction(txHash, from, to, method, protocol string, gasUsed, gasPrice uint64, value float64, success bool, metadata map[string]interface{}) {
	timestamp := time.Now().Format("2006/01/02 15:04:05")
	status := "FAILED"
	if success {
		status = "SUCCESS"
	}
	// Sanitize metadata for production (redacts amount/address-like keys).
	sanitizedMetadata := l.secureFilter.SanitizeForProduction(metadata)
	message := fmt.Sprintf(`%s [TRANSACTION] 💳 %s
├── Hash: %s
├── From: %s → To: %s
├── Method: %s (%s)
├── Gas Used: %d (Price: %d wei)
├── Value: %.6f ETH
├── Status: %s
└── Metadata: %v`,
		timestamp, status, txHash, from, to, method, protocol,
		gasUsed, gasPrice, value, status, sanitizedMetadata)
	// Apply security filtering to the entire message.
	filteredMessage := l.secureFilter.FilterMessage(message)
	l.transactionLogger.Println(filteredMessage) // Dedicated transaction log only
}
// BlockProcessing logs per-block processing metrics (total and DEX-related
// transaction counts plus wall-clock processing time) for sequencer
// monitoring. Entries are routed to the performance log.
func (l *Logger) BlockProcessing(blockNumber uint64, txCount, dexTxCount int, processingTime time.Duration) {
	l.performanceLogger.Println(fmt.Sprintf(`%s [BLOCK_PROCESSING] 🧱 Block %d: %d txs (%d DEX) processed in %v`,
		time.Now().Format("2006/01/02 15:04:05"), blockNumber, txCount, dexTxCount, processingTime))
}
// ArbitrageAnalysis logs arbitrage opportunity analysis results.
// priceDiff is expected to be a fraction (e.g. 0.015 for 1.5%); it is
// multiplied by 100 for display. The status (VIABLE/REJECTED) appears in
// both the header and the Status row, and the filtered entry goes to the
// opportunity log.
func (l *Logger) ArbitrageAnalysis(poolA, poolB, tokenPair string, priceA, priceB, priceDiff, estimatedProfit float64, feasible bool) {
	timestamp := time.Now().Format("2006/01/02 15:04:05")
	status := "REJECTED"
	if feasible {
		status = "VIABLE"
	}
	message := fmt.Sprintf(`%s [ARBITRAGE_ANALYSIS] 🔍 %s %s
├── Pool A: %s (Price: %.6f)
├── Pool B: %s (Price: %.6f)
├── Price Difference: %.4f%%
├── Estimated Profit: $%.2f
└── Status: %s`,
		timestamp, status, tokenPair, poolA, priceA, poolB, priceB,
		priceDiff*100, estimatedProfit, status)
	// Apply security filtering to protect sensitive pricing data.
	filteredMessage := l.secureFilter.FilterMessage(message)
	l.opportunityLogger.Println(filteredMessage) // Arbitrage analysis goes to opportunity log
}
// RPC logs a single RPC call outcome (endpoint, method, latency) for
// endpoint optimization. A non-empty errorMsg is appended only when the
// call failed. Entries are routed to the performance log.
func (l *Logger) RPC(endpoint, method string, duration time.Duration, success bool, errorMsg string) {
	status := "FAILED"
	if success {
		status = "SUCCESS"
	}
	line := fmt.Sprintf(`%s [RPC] 🌐 %s %s.%s in %v`,
		time.Now().Format("2006/01/02 15:04:05"), status, endpoint, method, duration)
	if !success && errorMsg != "" {
		line += fmt.Sprintf(" - Error: %s", errorMsg)
	}
	l.performanceLogger.Println(line)
}
// SwapAnalysis logs swap event analysis with security filtering.
// Token symbols are repeated next to the amounts (amountIn is denominated
// in tokenIn, amountOut in tokenOut). The filtered entry goes to the
// transaction log.
func (l *Logger) SwapAnalysis(tokenIn, tokenOut string, amountIn, amountOut float64, protocol, poolAddr string, metadata map[string]interface{}) {
	timestamp := time.Now().Format("2006/01/02 15:04:05")
	// Sanitize metadata for production (redacts amount/address-like keys).
	sanitizedMetadata := l.secureFilter.SanitizeForProduction(metadata)
	message := fmt.Sprintf(`%s [SWAP_ANALYSIS] 🔄 %s → %s
├── Amount In: %.6f %s
├── Amount Out: %.6f %s
├── Protocol: %s
├── Pool: %s
└── Metadata: %v`,
		timestamp, tokenIn, tokenOut, amountIn, tokenIn, amountOut, tokenOut,
		protocol, poolAddr, sanitizedMetadata)
	// Apply security filtering to the entire message.
	filteredMessage := l.secureFilter.FilterMessage(message)
	l.transactionLogger.Println(filteredMessage) // Dedicated transaction log
}
// rotateLogFile archives filename into logs/archived/ (relative to the
// working directory) once it is at least maxSize bytes. The file is renamed
// — not copied — to logs/archived/<name>_<YYYYMMDD_HHMMSS><ext>.
//
// It returns nil when the file does not exist or is still under the limit.
//
// NOTE(review): rotation renames the path out from under any open handle;
// the logger must reopen its file after calling this, otherwise it keeps
// writing to the archived file. Confirm callers do so.
func rotateLogFile(filename string, maxSize int64) error {
	// A single Stat covers both the "missing file" and the size check
	// (the original double Stat was redundant).
	fileInfo, err := os.Stat(filename)
	if os.IsNotExist(err) {
		return nil // Nothing to rotate
	}
	if err != nil {
		return fmt.Errorf("failed to get file info: %w", err)
	}
	if fileInfo.Size() < maxSize {
		return nil // File is within size limits
	}
	// Create archive directory if it doesn't exist.
	archiveDir := "logs/archived"
	if err := os.MkdirAll(archiveDir, 0755); err != nil {
		return fmt.Errorf("failed to create archive directory: %w", err)
	}
	// Build logs/archived/<base>_<timestamp><ext>.
	timestamp := time.Now().Format("20060102_150405")
	baseName := filepath.Base(filename)
	ext := filepath.Ext(baseName)
	name := strings.TrimSuffix(baseName, ext)
	archiveFilename := filepath.Join(archiveDir, fmt.Sprintf("%s_%s%s", name, timestamp, ext))
	if err := os.Rename(filename, archiveFilename); err != nil {
		return fmt.Errorf("failed to rotate log file: %w", err)
	}
	return nil
}

View File

@@ -0,0 +1,243 @@
package logger
import (
"bytes"
"io"
"os"
"testing"
"github.com/stretchr/testify/assert"
)
// TestNewLogger verifies that a stdout-only logger is constructed with a
// non-nil internal logger and the requested level name.
func TestNewLogger(t *testing.T) {
	// Test creating a logger with stdout
	logger := New("info", "text", "")
	assert.NotNil(t, logger)
	assert.NotNil(t, logger.logger)
	assert.Equal(t, "info", logger.levelName)
}
// TestNewLoggerWithFile verifies construction when a log file path is
// supplied; the temp file is closed first so New owns the handle.
func TestNewLoggerWithFile(t *testing.T) {
	// Create a temporary file for testing
	tmpFile, err := os.CreateTemp("", "logger_test_*.log")
	assert.NoError(t, err)
	defer os.Remove(tmpFile.Name())
	err = tmpFile.Close()
	assert.NoError(t, err)
	// Test creating a logger with a file
	logger := New("info", "text", tmpFile.Name())
	assert.NotNil(t, logger)
	assert.Equal(t, "info", logger.levelName)
}
// TestDebug verifies that a debug-level logger emits Debug messages with the
// bracketed "[DEBUG]" prefix. Stdout is temporarily swapped for an os.Pipe
// to capture output; the message is far smaller than the pipe buffer, so
// the unread write cannot block.
func TestDebug(t *testing.T) {
	// Capture stdout
	old := os.Stdout
	r, w, _ := os.Pipe()
	os.Stdout = w
	// Create logger with debug level
	logger := New("debug", "text", "")
	// Log a debug message
	logger.Debug("test debug message")
	// Restore stdout
	w.Close()
	os.Stdout = old
	// Read the captured output
	var buf bytes.Buffer
	io.Copy(&buf, r)
	output := buf.String()
	// Check that the log message contains the expected content with brackets
	assert.Contains(t, output, "[DEBUG] test debug message")
}
// TestDebugWithInfoLevel verifies that an info-level logger suppresses Debug
// messages entirely. Stdout is swapped for an os.Pipe to capture output.
func TestDebugWithInfoLevel(t *testing.T) {
	// Capture stdout
	old := os.Stdout
	r, w, _ := os.Pipe()
	os.Stdout = w
	// Create logger with info level (should not log debug messages)
	logger := New("info", "text", "")
	// Log a debug message
	logger.Debug("test debug message")
	// Restore stdout
	w.Close()
	os.Stdout = old
	// Read the captured output
	var buf bytes.Buffer
	io.Copy(&buf, r)
	output := buf.String()
	// Verify the output does not contain the debug message.
	// The logger emits "[DEBUG]" (bracketed) — the old "DEBUG:" assertion
	// could never match and therefore never failed; check the real prefix.
	assert.NotContains(t, output, "[DEBUG]")
	assert.NotContains(t, output, "test debug message")
}
// TestInfo verifies that an info-level logger emits Info messages with the
// "[INFO]" prefix. Stdout is swapped for an os.Pipe to capture output.
func TestInfo(t *testing.T) {
	// Capture stdout
	old := os.Stdout
	r, w, _ := os.Pipe()
	os.Stdout = w
	// Create logger with info level
	logger := New("info", "text", "")
	// Log an info message
	logger.Info("test info message")
	// Restore stdout
	w.Close()
	os.Stdout = old
	// Read the captured output
	var buf bytes.Buffer
	io.Copy(&buf, r)
	output := buf.String()
	// Verify the output contains the info message
	assert.Contains(t, output, "[INFO]")
	assert.Contains(t, output, "test info message")
}
// TestInfoWithDebugLevel verifies that Info messages also pass through at
// the more verbose debug level.
func TestInfoWithDebugLevel(t *testing.T) {
	// Capture stdout
	old := os.Stdout
	r, w, _ := os.Pipe()
	os.Stdout = w
	// Create logger with debug level
	logger := New("debug", "text", "")
	// Log an info message
	logger.Info("test info message")
	// Restore stdout
	w.Close()
	os.Stdout = old
	// Read the captured output
	var buf bytes.Buffer
	io.Copy(&buf, r)
	output := buf.String()
	// Verify the output contains the info message
	assert.Contains(t, output, "[INFO]")
	assert.Contains(t, output, "test info message")
}
// TestWarn verifies that a warn-level logger emits Warn messages with the
// "[WARN]" prefix.
func TestWarn(t *testing.T) {
	// Capture stdout
	old := os.Stdout
	r, w, _ := os.Pipe()
	os.Stdout = w
	// Create logger with warn level
	logger := New("warn", "text", "")
	// Log a warning message
	logger.Warn("test warn message")
	// Restore stdout
	w.Close()
	os.Stdout = old
	// Read the captured output
	var buf bytes.Buffer
	io.Copy(&buf, r)
	output := buf.String()
	// Verify the output contains the warning message
	assert.Contains(t, output, "[WARN]")
	assert.Contains(t, output, "test warn message")
}
// TestWarnWithInfoLevel verifies that Warn messages pass through at the
// lower-severity info level.
func TestWarnWithInfoLevel(t *testing.T) {
	// Capture stdout
	old := os.Stdout
	r, w, _ := os.Pipe()
	os.Stdout = w
	// Create logger with info level (should log warnings)
	logger := New("info", "text", "")
	// Log a warning message
	logger.Warn("test warn message")
	// Restore stdout
	w.Close()
	os.Stdout = old
	// Read the captured output
	var buf bytes.Buffer
	io.Copy(&buf, r)
	output := buf.String()
	// Verify the output contains the warning message
	assert.Contains(t, output, "[WARN]")
	assert.Contains(t, output, "test warn message")
}
// TestError verifies that an error-level logger emits Error messages with
// the "[ERROR]" prefix.
func TestError(t *testing.T) {
	// Capture stdout
	old := os.Stdout
	r, w, _ := os.Pipe()
	os.Stdout = w
	// Create logger
	logger := New("error", "text", "")
	// Log an error message
	logger.Error("test error message")
	// Restore stdout
	w.Close()
	os.Stdout = old
	// Read the captured output
	var buf bytes.Buffer
	io.Copy(&buf, r)
	output := buf.String()
	// Verify the output contains the error message
	assert.Contains(t, output, "[ERROR]")
	assert.Contains(t, output, "test error message")
}
// TestErrorWithAllLevels verifies that Error messages are emitted at every
// configured level, since errors are never suppressed. Stdout capture is
// set up and torn down per iteration.
func TestErrorWithAllLevels(t *testing.T) {
	// Test that error messages are logged at all levels
	levels := []string{"debug", "info", "warn", "error"}
	for _, level := range levels {
		// Capture stdout
		old := os.Stdout
		r, w, _ := os.Pipe()
		os.Stdout = w
		// Create logger with current level
		logger := New(level, "text", "")
		// Log an error message
		logger.Error("test error message")
		// Restore stdout
		w.Close()
		os.Stdout = old
		// Read the captured output
		var buf bytes.Buffer
		io.Copy(&buf, r)
		output := buf.String()
		// Verify the output contains the error message
		assert.Contains(t, output, "[ERROR]")
		assert.Contains(t, output, "test error message")
	}
}

View File

@@ -0,0 +1,241 @@
package logger
import (
"crypto/aes"
"crypto/cipher"
"crypto/rand"
"crypto/sha256"
"encoding/hex"
"encoding/json"
"fmt"
"io"
"time"
)
// FilterMessageEnhanced behaves exactly like FilterMessage, additionally
// recording an audit event when sensitive data is present in the message
// and audit logging is enabled.
func (sf *SecureFilter) FilterMessageEnhanced(message string, context map[string]interface{}) string {
	filtered := sf.FilterMessage(message)
	if !sf.auditEnabled {
		return filtered
	}
	// Detection runs against the ORIGINAL message so the audit record
	// reflects what was actually present before redaction.
	if auditData := sf.detectSensitiveData(message, context); len(auditData) > 0 {
		sf.logAuditEvent(auditData)
	}
	return filtered
}
// detectSensitiveData identifies and catalogs sensitive data found in a
// message. It returns nil when nothing sensitive is found; otherwise a map
// with timestamp, severity (highest category matched wins because severity
// is only set once), the matched types, and length statistics.
//
// NOTE(review): the 40-hex-char address regex also matches a prefix of a
// 64-char transaction hash, so hash-bearing messages are additionally
// tagged "address" — checking hashes first keeps the severity correct, but
// the types list can contain a false positive. Confirm whether the address
// patterns should be anchored with word boundaries.
func (sf *SecureFilter) detectSensitiveData(message string, context map[string]interface{}) map[string]interface{} {
	detected := make(map[string]interface{})
	detected["timestamp"] = time.Now().UTC().Format(time.RFC3339)
	detected["security_level"] = sf.level
	if context != nil {
		detected["context"] = context
	}
	// Check for different types of sensitive data
	sensitiveTypes := []string{}
	// Check for private keys (CRITICAL)
	for _, pattern := range sf.privateKeyPatterns {
		if pattern.MatchString(message) {
			sensitiveTypes = append(sensitiveTypes, "private_key")
			detected["severity"] = "CRITICAL"
			break
		}
	}
	// Check for transaction hashes BEFORE addresses (64 chars vs 40 chars)
	for _, pattern := range sf.hashPatterns {
		if pattern.MatchString(message) {
			sensitiveTypes = append(sensitiveTypes, "transaction_hash")
			if detected["severity"] == nil {
				detected["severity"] = "LOW"
			}
			break
		}
	}
	// Check for addresses AFTER hashes
	for _, pattern := range sf.addressPatterns {
		if pattern.MatchString(message) {
			sensitiveTypes = append(sensitiveTypes, "address")
			if detected["severity"] == nil {
				detected["severity"] = "MEDIUM"
			}
			break
		}
	}
	// Check for amounts/values
	for _, pattern := range sf.amountPatterns {
		if pattern.MatchString(message) {
			sensitiveTypes = append(sensitiveTypes, "amount")
			if detected["severity"] == nil {
				detected["severity"] = "LOW"
			}
			break
		}
	}
	if len(sensitiveTypes) > 0 {
		detected["types"] = sensitiveTypes
		detected["message_length"] = len(message)
		// Filtered length is computed by running the filter a second time.
		detected["filtered_length"] = len(sf.FilterMessage(message))
		return detected
	}
	return nil
}
// logAuditEvent emits a sensitive-data-detection event. The entry is printed
// to stdout with an "AUDIT_LOG: " prefix so it can be extracted into a
// dedicated audit trail downstream; when audit encryption is configured the
// payload is AES-GCM encrypted first.
func (sf *SecureFilter) logAuditEvent(auditData map[string]interface{}) {
	// Create audit log entry from the detection record.
	auditEntry := map[string]interface{}{
		"event_type": "sensitive_data_detected",
		"timestamp":  auditData["timestamp"],
		"severity":   auditData["severity"],
		"types":      auditData["types"],
		"message_stats": map[string]interface{}{
			"original_length": auditData["message_length"],
			"filtered_length": auditData["filtered_length"],
		},
	}
	if auditData["context"] != nil {
		auditEntry["context"] = auditData["context"]
	}
	// Encrypt audit data if encryption is enabled; on encryption failure we
	// fall back to the plaintext entry rather than dropping the event.
	if sf.auditEncryption && len(sf.encryptionKey) > 0 {
		if encrypted, err := sf.encryptAuditData(auditEntry); err == nil {
			auditEntry = map[string]interface{}{
				"encrypted": true,
				"data":      encrypted,
				"timestamp": auditData["timestamp"],
			}
		}
	}
	// Log to the audit trail. Marshalling can fail for non-serializable
	// context values; emit a minimal record instead of silently losing the
	// event (the error was previously discarded).
	auditJSON, err := json.Marshal(auditEntry)
	if err != nil {
		fmt.Printf("AUDIT_LOG: {\"event_type\":\"sensitive_data_detected\",\"error\":%q}\n", err.Error())
		return
	}
	fmt.Printf("AUDIT_LOG: %s\n", string(auditJSON))
}
// encryptAuditData encrypts sensitive audit data with AES-256-GCM.
// The configured key is stretched to 32 bytes via SHA-256; a fresh random
// nonce is generated per call and prepended to the ciphertext, and the
// whole blob is hex-encoded. DecryptAuditData reverses this layout.
func (sf *SecureFilter) encryptAuditData(data map[string]interface{}) (string, error) {
	if len(sf.encryptionKey) == 0 {
		return "", fmt.Errorf("no encryption key available")
	}
	// Serialize data to JSON
	jsonData, err := json.Marshal(data)
	if err != nil {
		return "", fmt.Errorf("failed to marshal audit data: %w", err)
	}
	// Create AES-GCM cipher (AEAD - authenticated encryption)
	key := sha256.Sum256(sf.encryptionKey)
	block, err := aes.NewCipher(key[:])
	if err != nil {
		return "", fmt.Errorf("failed to create cipher: %w", err)
	}
	// Create GCM instance
	gcm, err := cipher.NewGCM(block)
	if err != nil {
		return "", fmt.Errorf("failed to create GCM: %w", err)
	}
	// Generate random nonce (unique per message, as GCM requires)
	nonce := make([]byte, gcm.NonceSize())
	if _, err := io.ReadFull(rand.Reader, nonce); err != nil {
		return "", fmt.Errorf("failed to generate nonce: %w", err)
	}
	// Encrypt and authenticate data; Seal with dst=nonce prepends the nonce
	// to the returned ciphertext so both travel together.
	encrypted := gcm.Seal(nonce, nonce, jsonData, nil)
	return hex.EncodeToString(encrypted), nil
}
// DecryptAuditData decrypts audit data produced by encryptAuditData (for
// authorized access): hex-decode, split the GCM nonce prefix from the
// ciphertext, authenticate/decrypt, and unmarshal the JSON payload.
// Note: JSON round-tripping means numeric values come back as float64.
func (sf *SecureFilter) DecryptAuditData(encryptedHex string) (map[string]interface{}, error) {
	if len(sf.encryptionKey) == 0 {
		return nil, fmt.Errorf("no encryption key available")
	}
	// Decode hex
	encryptedData, err := hex.DecodeString(encryptedHex)
	if err != nil {
		return nil, fmt.Errorf("failed to decode hex: %w", err)
	}
	// Create AES-GCM cipher (AEAD - authenticated encryption); the key is
	// derived exactly as in encryptAuditData (SHA-256 of the raw key).
	key := sha256.Sum256(sf.encryptionKey)
	block, err := aes.NewCipher(key[:])
	if err != nil {
		return nil, fmt.Errorf("failed to create cipher: %w", err)
	}
	// Create GCM instance
	gcm, err := cipher.NewGCM(block)
	if err != nil {
		return nil, fmt.Errorf("failed to create GCM: %w", err)
	}
	// Check minimum length (nonce + encrypted data + tag)
	if len(encryptedData) < gcm.NonceSize() {
		return nil, fmt.Errorf("encrypted data too short")
	}
	// Extract nonce and encrypted data
	nonce := encryptedData[:gcm.NonceSize()]
	ciphertext := encryptedData[gcm.NonceSize():]
	// Decrypt and authenticate data (fails on tampering or wrong key)
	decrypted, err := gcm.Open(nil, nonce, ciphertext, nil)
	if err != nil {
		return nil, fmt.Errorf("failed to decrypt data: %w", err)
	}
	// Unmarshal JSON
	var result map[string]interface{}
	err = json.Unmarshal(decrypted, &result)
	if err != nil {
		return nil, fmt.Errorf("failed to unmarshal decrypted data: %w", err)
	}
	return result, nil
}
// EnableAuditLogging turns on audit logging. Supplying a non-empty
// encryptionKey also stores it and enables encryption of audit entries;
// an empty key leaves any previous key/encryption settings untouched.
func (sf *SecureFilter) EnableAuditLogging(encryptionKey []byte) {
	sf.auditEnabled = true
	if len(encryptionKey) == 0 {
		return
	}
	sf.encryptionKey = encryptionKey
	sf.auditEncryption = true
}
// DisableAuditLogging turns off both audit logging and audit encryption.
// The stored encryption key is intentionally kept so auditing can be
// re-enabled later without re-supplying it.
func (sf *SecureFilter) DisableAuditLogging() {
	sf.auditEnabled, sf.auditEncryption = false, false
}
// SetSecurityLevel changes the security level dynamically.
// NOTE(review): there is no synchronization here — if filters are shared
// across goroutines this write races with FilterMessage reads; confirm
// single-goroutine use or add locking.
func (sf *SecureFilter) SetSecurityLevel(level SecurityLevel) {
	sf.level = level
}
// GetSecurityLevel returns the current security level.
func (sf *SecureFilter) GetSecurityLevel() SecurityLevel {
	return sf.level
}

View File

@@ -0,0 +1,301 @@
package logger
import (
	"math/big"
	"regexp"
	"strings"
	"unicode/utf8"

	"github.com/ethereum/go-ethereum/common"
)
// SecurityLevel defines the logging security level. Higher values redact
// more: comparisons in FilterMessage use >= on this ordering, so each level
// includes all filtering done by the levels below it.
type SecurityLevel int

const (
	SecurityLevelDebug      SecurityLevel = iota // Log everything (development only)
	SecurityLevelInfo                            // Log basic info, filter amounts and hashes
	SecurityLevelProduction                      // Log minimal info, additionally shorten addresses
)
// SecureFilter filters sensitive data from log messages according to a
// SecurityLevel, and can optionally emit (and encrypt) audit records when
// sensitive data is detected.
type SecureFilter struct {
	level SecurityLevel
	// Patterns to detect sensitive data, compiled once at construction.
	amountPatterns     []*regexp.Regexp // numeric amounts, balances, gas prices
	addressPatterns    []*regexp.Regexp // 20-byte hex addresses
	valuePatterns      []*regexp.Regexp // dollar values, profits, fees
	hashPatterns       []*regexp.Regexp // 32-byte hex tx/block hashes
	privateKeyPatterns []*regexp.Regexp // private keys, secrets, mnemonics
	encryptionKey      []byte           // raw key; stretched via SHA-256 before use
	auditEnabled       bool             // emit audit records on detection
	auditEncryption    bool             // AES-GCM encrypt audit records
}
// SecureFilterConfig contains configuration for the secure filter.
// The zero value yields a filter at SecurityLevelDebug with auditing off.
type SecureFilterConfig struct {
	Level           SecurityLevel // redaction level
	EncryptionKey   []byte        // optional key for audit encryption
	AuditEnabled    bool          // record sensitive-data detections
	AuditEncryption bool          // encrypt those records (needs EncryptionKey)
}
// NewSecureFilter creates a filter at the given security level with audit
// logging and audit encryption disabled — the common default.
func NewSecureFilter(level SecurityLevel) *SecureFilter {
	cfg := SecureFilterConfig{Level: level}
	return NewSecureFilterWithConfig(&cfg)
}
// NewSecureFilterWithConfig creates a new secure filter with full
// configuration. All detection regexes are compiled once here so FilterMessage
// never pays compilation cost on the hot path.
func NewSecureFilterWithConfig(config *SecureFilterConfig) *SecureFilter {
	return &SecureFilter{
		level:           config.Level,
		encryptionKey:   config.EncryptionKey,
		auditEnabled:    config.AuditEnabled,
		auditEncryption: config.AuditEncryption,
		// Numeric amounts: key=value forms, dollar figures, and bare very
		// large integers (likely wei).
		amountPatterns: []*regexp.Regexp{
			regexp.MustCompile(`(?i)amount[^=]*=\s*[0-9]+`),
			regexp.MustCompile(`(?i)value[^=]*=\s*[0-9]+`),
			regexp.MustCompile(`\$[0-9]+\.?[0-9]*`),
			regexp.MustCompile(`[0-9]+\.[0-9]+ USD`),
			regexp.MustCompile(`(?i)amountIn=[0-9]+`),
			regexp.MustCompile(`(?i)amountOut=[0-9]+`),
			regexp.MustCompile(`(?i)balance[^=]*=\s*[0-9]+`),
			regexp.MustCompile(`(?i)profit[^=]*=\s*[0-9]+`),
			regexp.MustCompile(`(?i)gas[Pp]rice[^=]*=\s*[0-9]+`),
			regexp.MustCompile(`\b[0-9]{15,}\b`), // Very large numbers likely to be wei amounts (but not hex addresses)
		},
		// 20-byte hex addresses, bare or key=value.
		addressPatterns: []*regexp.Regexp{
			regexp.MustCompile(`0x[a-fA-F0-9]{40}`),
			regexp.MustCompile(`(?i)address[^=]*=\s*0x[a-fA-F0-9]{40}`),
			regexp.MustCompile(`(?i)contract[^=]*=\s*0x[a-fA-F0-9]{40}`),
			regexp.MustCompile(`(?i)token[^=]*=\s*0x[a-fA-F0-9]{40}`),
		},
		// Dollar-denominated values, profits, fees.
		valuePatterns: []*regexp.Regexp{
			regexp.MustCompile(`(?i)value:\s*\$[0-9]+\.?[0-9]*`),
			regexp.MustCompile(`(?i)profit[^=]*=\s*\$?[0-9]+\.?[0-9]*`),
			regexp.MustCompile(`(?i)total:\s*\$[0-9]+\.?[0-9]*`),
			regexp.MustCompile(`(?i)revenue[^=]*=\s*\$?[0-9]+\.?[0-9]*`),
			regexp.MustCompile(`(?i)fee[^=]*=\s*\$?[0-9]+\.?[0-9]*`),
		},
		// 32-byte hex hashes (tx/block).
		hashPatterns: []*regexp.Regexp{
			regexp.MustCompile(`0x[a-fA-F0-9]{64}`), // Transaction hashes
			regexp.MustCompile(`(?i)txHash[^=]*=\s*0x[a-fA-F0-9]{64}`),
			regexp.MustCompile(`(?i)blockHash[^=]*=\s*0x[a-fA-F0-9]{64}`),
		},
		// Highest-severity material: keys, secrets, mnemonics, seeds.
		privateKeyPatterns: []*regexp.Regexp{
			regexp.MustCompile(`(?i)private[_\s]*key[^=]*=\s*[a-fA-F0-9]{64}`),
			regexp.MustCompile(`(?i)secret[^=]*=\s*[a-fA-F0-9]{64}`),
			regexp.MustCompile(`(?i)mnemonic[^=]*=\s*\"[^\"]*\"`),
			regexp.MustCompile(`(?i)seed[^=]*=\s*\"[^\"]*\"`),
		},
	}
}
// FilterMessage filters sensitive data from a log message with enhanced
// input sanitization. Filtering order is deliberate and must be preserved:
// private keys first (highest priority), then 64-char hashes, then 40-char
// addresses (hashes before addresses because the address regex would match
// a hash prefix), and amounts/values last so address digits are not
// misread as numbers. Debug level only sanitizes, never redacts.
func (sf *SecureFilter) FilterMessage(message string) string {
	if sf.level == SecurityLevelDebug {
		return sf.sanitizeInput(message) // Still sanitize for security even in debug mode
	}
	filtered := sf.sanitizeInput(message)
	// Filter private keys FIRST (highest security priority)
	for _, pattern := range sf.privateKeyPatterns {
		filtered = pattern.ReplaceAllString(filtered, "[PRIVATE_KEY_FILTERED]")
	}
	// Filter transaction hashes (Info and above)
	if sf.level >= SecurityLevelInfo {
		for _, pattern := range sf.hashPatterns {
			filtered = pattern.ReplaceAllStringFunc(filtered, func(hash string) string {
				if len(hash) == 66 { // Full transaction hash ("0x" + 64 hex chars)
					return hash[:10] + "..." + hash[62:] // Keep first 8 and last 4 hex chars
				}
				return "[HASH_FILTERED]"
			})
		}
	}
	// Filter addresses NEXT (before amounts) to prevent addresses from being treated as numbers
	if sf.level >= SecurityLevelProduction {
		for _, pattern := range sf.addressPatterns {
			filtered = pattern.ReplaceAllStringFunc(filtered, func(addr string) string {
				if len(addr) == 42 { // Full Ethereum address ("0x" + 40 hex chars)
					return addr[:6] + "..." + addr[38:] // Keep first 4 and last 4 hex chars
				}
				return "[ADDR_FILTERED]"
			})
		}
	}
	// Filter amounts LAST
	if sf.level >= SecurityLevelInfo {
		for _, pattern := range sf.amountPatterns {
			filtered = pattern.ReplaceAllString(filtered, "[AMOUNT_FILTERED]")
		}
		for _, pattern := range sf.valuePatterns {
			filtered = pattern.ReplaceAllString(filtered, "[VALUE_FILTERED]")
		}
	}
	return filtered
}
// sanitizeInput performs comprehensive input sanitization for log messages:
// strips NUL/CR and other control characters, collapses newlines/tabs to
// spaces (log-injection defense), removes ANSI escape sequences, escapes
// printf verbs, and bounds the message length.
//
// NOTE(review): the two regexes here are compiled on every call; hoisting
// them to package scope would avoid that cost on the logging hot path.
func (sf *SecureFilter) sanitizeInput(input string) string {
	// Remove null bytes and carriage returns outright.
	sanitized := strings.ReplaceAll(input, "\x00", "")
	sanitized = strings.ReplaceAll(sanitized, "\r", "")
	// Replace newlines/tabs with spaces so one logical event stays on one
	// log line (prevents forged entries via embedded line breaks).
	sanitized = strings.ReplaceAll(sanitized, "\n", " ")
	sanitized = strings.ReplaceAll(sanitized, "\t", " ")
	// Remove ANSI escape codes that could interfere with log parsing.
	ansiPattern := regexp.MustCompile(`\x1b\[[0-9;]*[a-zA-Z]`)
	sanitized = ansiPattern.ReplaceAllString(sanitized, "")
	// Remove remaining non-printable control characters.
	controlPattern := regexp.MustCompile(`[\x00-\x08\x0b\x0c\x0e-\x1f\x7f]`)
	sanitized = controlPattern.ReplaceAllString(sanitized, "")
	// Escape printf format specifiers in case the result is ever used as a
	// format string downstream.
	sanitized = strings.ReplaceAll(sanitized, "%", "%%")
	// Limit message length to prevent DoS via large log messages. Truncate
	// on a rune boundary so we never emit invalid UTF-8 (the previous byte
	// slice could split a multi-byte character).
	const maxLogMessageLength = 4096
	if len(sanitized) > maxLogMessageLength {
		cut := maxLogMessageLength - 3
		for cut > 0 && !utf8.RuneStart(sanitized[cut]) {
			cut--
		}
		sanitized = sanitized[:cut] + "..."
	}
	return sanitized
}
// FilterTransactionData provides enhanced filtering for transaction data
// logging. Detail decreases with the security level: Debug exposes raw
// fields, Info shortens addresses and buckets amounts, Production only
// reports presence booleans. The tx hash is always included (sanitized).
func (sf *SecureFilter) FilterTransactionData(txHash, from, to string, value, gasPrice *big.Int, data []byte) map[string]interface{} {
	result := map[string]interface{}{}
	// Always include sanitized transaction hash
	result["tx_hash"] = sf.sanitizeInput(txHash)
	switch sf.level {
	case SecurityLevelDebug:
		result["from"] = sf.sanitizeInput(from)
		result["to"] = sf.sanitizeInput(to)
		if value != nil {
			result["value"] = value.String()
		}
		if gasPrice != nil {
			result["gas_price"] = gasPrice.String()
		}
		result["data_size"] = len(data)
	case SecurityLevelInfo:
		// Addresses are parsed then shortened to "0x1234...abcd" form.
		result["from"] = sf.shortenAddress(common.HexToAddress(from))
		result["to"] = sf.shortenAddress(common.HexToAddress(to))
		if value != nil {
			result["value_range"] = sf.getAmountRange(value)
		}
		if gasPrice != nil {
			result["gas_price_range"] = sf.getAmountRange(gasPrice)
		}
		result["data_size"] = len(data)
	case SecurityLevelProduction:
		// Only booleans and sizes — no identifying or monetary detail.
		result["has_from"] = from != ""
		result["has_to"] = to != ""
		result["has_value"] = value != nil && value.Sign() > 0
		result["data_size"] = len(data)
	}
	return result
}
// FilterSwapData creates a safe representation of swap data for logging.
// Detail decreases with the security level: Debug exposes raw token
// addresses and amounts, Info shortens/buckets them, Production reports
// only presence information.
func (sf *SecureFilter) FilterSwapData(tokenIn, tokenOut common.Address, amountIn, amountOut *big.Int, protocol string) map[string]interface{} {
	data := map[string]interface{}{
		"protocol": protocol,
	}
	switch sf.level {
	case SecurityLevelDebug:
		data["tokenIn"] = tokenIn.Hex()
		data["tokenOut"] = tokenOut.Hex()
		// Guard against nil amounts: getAmountRange (Info) and the nil
		// check (Production) tolerate nil, so debug logging must not be
		// the one path that panics on amountIn.String().
		if amountIn != nil {
			data["amountIn"] = amountIn.String()
		}
		if amountOut != nil {
			data["amountOut"] = amountOut.String()
		}
	case SecurityLevelInfo:
		data["tokenInShort"] = sf.shortenAddress(tokenIn)
		data["tokenOutShort"] = sf.shortenAddress(tokenOut)
		data["amountRange"] = sf.getAmountRange(amountIn)
		data["amountOutRange"] = sf.getAmountRange(amountOut)
	case SecurityLevelProduction:
		data["tokenCount"] = 2
		data["hasAmounts"] = amountIn != nil && amountOut != nil
	}
	return data
}
// shortenAddress renders an address as "0x1234...abcd" — the 0x prefix plus
// the first four and last four hex digits. The fallback branch is defensive:
// common.Address.Hex always yields 42 characters.
func (sf *SecureFilter) shortenAddress(addr common.Address) string {
	h := addr.Hex()
	if len(h) < 10 {
		return "[ADDR]"
	}
	return h[:6] + "..." + h[len(h)-4:]
}
// getAmountRange returns a coarse size bucket for an amount so logs can
// convey magnitude without revealing exact values. The raw integer is
// scaled down by 1e18, i.e. converted to whole-token units assuming an
// 18-decimal token — the thresholds are therefore token counts, not USD
// (despite the bucket names suggesting dollar value; TODO confirm intent).
func (sf *SecureFilter) getAmountRange(amount *big.Int) string {
	if amount == nil {
		return "nil"
	}
	// Scale wei-style fixed-point (18 decimals) to whole units.
	usdValue := new(big.Float).Quo(new(big.Float).SetInt(amount), big.NewFloat(1e18))
	usdFloat, _ := usdValue.Float64()
	switch {
	case usdFloat < 1:
		return "micro"
	case usdFloat < 100:
		return "small"
	case usdFloat < 10000:
		return "medium"
	case usdFloat < 1000000:
		return "large"
	default:
		return "whale"
	}
}
// SanitizeForProduction redacts sensitive entries in a metadata map before
// logging: monetary keys are replaced with "[FILTERED]", address-like keys
// are shortened to "0x1234...abcd" form, and everything else passes through
// unchanged. Key matching is case-insensitive; the input map is not mutated.
func (sf *SecureFilter) SanitizeForProduction(data map[string]interface{}) map[string]interface{} {
	out := make(map[string]interface{}, len(data))
	for key, value := range data {
		switch strings.ToLower(key) {
		case "amount", "amountin", "amountout", "value", "profit", "usd", "price":
			out[key] = "[FILTERED]"
			continue
		case "address", "token", "tokenin", "tokenout", "pool", "contract":
			if addr, ok := value.(common.Address); ok {
				out[key] = sf.shortenAddress(addr)
				continue
			}
			if str, ok := value.(string); ok && strings.HasPrefix(str, "0x") && len(str) == 42 {
				out[key] = str[:6] + "..." + str[38:]
				continue
			}
			// Address-keyed but not an address-shaped value: keep as-is.
		}
		out[key] = value
	}
	return out
}

View File

@@ -0,0 +1,226 @@
package logger
import (
"crypto/rand"
"strings"
"testing"
"github.com/ethereum/go-ethereum/common"
)
// TestSecureFilterEnhanced drives detectSensitiveData and FilterMessage
// across representative messages, checking both the detected severity
// category and that sensitive messages come back modified.
func TestSecureFilterEnhanced(t *testing.T) {
	tests := []struct {
		name           string
		level          SecurityLevel
		message        string
		expectFiltered bool
		expectedLevel  string
	}{
		{
			name:           "Private key detection",
			level:          SecurityLevelProduction,
			message:        "private_key=1234567890abcdef1234567890abcdef1234567890abcdef1234567890abcdef",
			expectFiltered: true,
			expectedLevel:  "CRITICAL",
		},
		{
			name:           "Address detection",
			level:          SecurityLevelProduction,
			message:        "Swapping on address 0x1234567890123456789012345678901234567890",
			expectFiltered: true,
			expectedLevel:  "MEDIUM",
		},
		{
			name:           "Amount detection",
			level:          SecurityLevelInfo,
			message:        "Profit amount=1000000 wei detected",
			expectFiltered: true,
			expectedLevel:  "LOW",
		},
		{
			name:           "Transaction hash detection",
			level:          SecurityLevelInfo,
			message:        "txHash=0x1234567890abcdef1234567890abcdef1234567890abcdef1234567890abcdef",
			expectFiltered: true,
			expectedLevel:  "LOW",
		},
		{
			name:           "No sensitive data",
			level:          SecurityLevelProduction,
			message:        "Normal log message with no sensitive data",
			expectFiltered: false,
			expectedLevel:  "",
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			config := &SecureFilterConfig{
				Level:           tt.level,
				AuditEnabled:    true,
				AuditEncryption: false,
			}
			filter := NewSecureFilterWithConfig(config)
			// Test detection without filtering yet
			auditData := filter.detectSensitiveData(tt.message, nil)
			if tt.expectFiltered {
				if auditData == nil {
					t.Errorf("Expected sensitive data detection for: %s", tt.message)
					return
				}
				if auditData["severity"] != tt.expectedLevel {
					t.Errorf("Expected severity %s, got %v", tt.expectedLevel, auditData["severity"])
				}
			} else {
				if auditData != nil {
					t.Errorf("Unexpected sensitive data detection for: %s", tt.message)
				}
			}
			// Test the actual filtering
			filtered := filter.FilterMessage(tt.message)
			if tt.expectFiltered && filtered == tt.message {
				t.Errorf("Expected message to be filtered, but it wasn't: %s", tt.message)
			}
		})
	}
}
// TestSecureFilterWithEncryption round-trips a map through
// encryptAuditData/DecryptAuditData with a random 32-byte key. Only the
// string field is compared: JSON round-tripping turns "number" into a
// float64, so an equality check on it would need a type conversion.
func TestSecureFilterWithEncryption(t *testing.T) {
	// Generate a random encryption key
	key := make([]byte, 32)
	_, err := rand.Read(key)
	if err != nil {
		t.Fatalf("Failed to generate encryption key: %v", err)
	}
	config := &SecureFilterConfig{
		Level:           SecurityLevelProduction,
		EncryptionKey:   key,
		AuditEnabled:    true,
		AuditEncryption: true,
	}
	filter := NewSecureFilterWithConfig(config)
	testData := map[string]interface{}{
		"test_field": "test_value",
		"number":     123,
		"nested": map[string]interface{}{
			"inner": "value",
		},
	}
	// Test encryption and decryption
	encrypted, err := filter.encryptAuditData(testData)
	if err != nil {
		t.Fatalf("Failed to encrypt audit data: %v", err)
	}
	if encrypted == "" {
		t.Fatal("Encrypted data should not be empty")
	}
	// Test decryption
	decrypted, err := filter.DecryptAuditData(encrypted)
	if err != nil {
		t.Fatalf("Failed to decrypt audit data: %v", err)
	}
	// Verify data integrity
	if decrypted["test_field"] != testData["test_field"] {
		t.Errorf("Decrypted data doesn't match original")
	}
}
// TestSecureFilterAddressFiltering verifies that a production-level filter
// shortens a full 42-char address to the "0x1234...7890" form.
func TestSecureFilterAddressFiltering(t *testing.T) {
	filter := NewSecureFilter(SecurityLevelProduction)
	address := common.HexToAddress("0x1234567890123456789012345678901234567890")
	testMessage := "Processing transaction for address " + address.Hex()
	filtered := filter.FilterMessage(testMessage)
	// Should contain shortened address
	if !strings.Contains(filtered, "0x1234...7890") {
		t.Errorf("Expected shortened address in filtered message, got: %s", filtered)
	}
}
// TestSecureFilterAmountFiltering verifies that amount-bearing messages are
// redacted to "[AMOUNT_FILTERED]" at the Info level.
func TestSecureFilterAmountFiltering(t *testing.T) {
	filter := NewSecureFilter(SecurityLevelInfo)
	testCases := []struct {
		message  string
		contains string
	}{
		{"amount=1000000", "[AMOUNT_FILTERED]"},
		{"Profit $123.45 detected", "[AMOUNT_FILTERED]"},
		{"balance=999999999999999999", "[AMOUNT_FILTERED]"},
		{"gasPrice=20000000000", "[AMOUNT_FILTERED]"},
	}
	for _, tc := range testCases {
		filtered := filter.FilterMessage(tc.message)
		if !strings.Contains(filtered, tc.contains) {
			t.Errorf("Expected %s in filtered message for input '%s', got: %s", tc.contains, tc.message, filtered)
		}
	}
}
// TestSecureFilterConfiguration exercises the runtime configuration
// surface: level get/set and enabling/disabling audit logging (with a key,
// which also turns on audit encryption).
func TestSecureFilterConfiguration(t *testing.T) {
	filter := NewSecureFilter(SecurityLevelDebug)
	// Test initial level
	if filter.GetSecurityLevel() != SecurityLevelDebug {
		t.Errorf("Expected initial level to be Debug, got: %v", filter.GetSecurityLevel())
	}
	// Test level change
	filter.SetSecurityLevel(SecurityLevelProduction)
	if filter.GetSecurityLevel() != SecurityLevelProduction {
		t.Errorf("Expected level to be Production, got: %v", filter.GetSecurityLevel())
	}
	// Test audit enabling
	filter.EnableAuditLogging([]byte("test-key"))
	if !filter.auditEnabled {
		t.Error("Expected audit to be enabled")
	}
	if !filter.auditEncryption {
		t.Error("Expected audit encryption to be enabled")
	}
	// Test audit disabling
	filter.DisableAuditLogging()
	if filter.auditEnabled {
		t.Error("Expected audit to be disabled")
	}
}
// BenchmarkSecureFilter measures FilterMessage on a message containing a
// hash, an address, and an amount (worst case: every pattern group fires).
func BenchmarkSecureFilter(b *testing.B) {
	filter := NewSecureFilter(SecurityLevelProduction)
	testMessage := "Processing transaction 0x1234567890abcdef1234567890abcdef1234567890abcdef1234567890abcdef for address 0x1111111111111111111111111111111111111111 with amount=1000000"
	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		filter.FilterMessage(testMessage)
	}
}
// BenchmarkSecureFilterEnhanced measures FilterMessageEnhanced with audit
// detection enabled (unencrypted), isolating the audit-path overhead
// relative to BenchmarkSecureFilter.
func BenchmarkSecureFilterEnhanced(b *testing.B) {
	config := &SecureFilterConfig{
		Level:           SecurityLevelProduction,
		AuditEnabled:    true,
		AuditEncryption: false,
	}
	filter := NewSecureFilterWithConfig(config)
	testMessage := "Processing transaction 0x1234567890abcdef1234567890abcdef1234567890abcdef1234567890abcdef for address 0x1111111111111111111111111111111111111111 with amount=1000000"
	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		filter.FilterMessageEnhanced(testMessage, nil)
	}
}

View File

@@ -0,0 +1,439 @@
package logger
import (
"math/big"
"regexp"
"testing"
"github.com/ethereum/go-ethereum/common"
"github.com/stretchr/testify/assert"
)
// TestNewSecureFilter checks that construction stores the requested level and
// compiles all three pattern sets, for every supported security level.
func TestNewSecureFilter(t *testing.T) {
	levels := []struct {
		name  string
		level SecurityLevel
	}{
		{"Debug level", SecurityLevelDebug},
		{"Info level", SecurityLevelInfo},
		{"Production level", SecurityLevelProduction},
	}
	for _, tc := range levels {
		t.Run(tc.name, func(t *testing.T) {
			f := NewSecureFilter(tc.level)
			assert.NotNil(t, f)
			assert.Equal(t, tc.level, f.level)
			assert.NotNil(t, f.amountPatterns)
			assert.NotNil(t, f.addressPatterns)
			assert.NotNil(t, f.valuePatterns)
		})
	}
}
// TestFilterMessage_DebugLevel verifies that the debug level is a pure
// pass-through: amounts, addresses and dollar values survive untouched.
func TestFilterMessage_DebugLevel(t *testing.T) {
	f := NewSecureFilter(SecurityLevelDebug)

	cases := []struct {
		name string
		in   string
		want string
	}{
		{
			"Debug shows everything",
			"Amount: 1000.5 ETH, Address: 0x742d35Cc6AaB8f5d6649c8C4F7C6b2d123456789, Value: $5000.00",
			"Amount: 1000.5 ETH, Address: 0x742d35Cc6AaB8f5d6649c8C4F7C6b2d123456789, Value: $5000.00",
		},
		{
			"Large amounts shown in debug",
			"Profit: 999999.123456789 USDC",
			"Profit: 999999.123456789 USDC",
		},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			assert.Equal(t, tc.want, f.FilterMessage(tc.in))
		})
	}
}
// TestFilterMessage_InfoLevel verifies filtering at SecurityLevelInfo:
// dollar-denominated values are redacted while full addresses remain visible.
func TestFilterMessage_InfoLevel(t *testing.T) {
	filter := NewSecureFilter(SecurityLevelInfo)
	tests := []struct {
		name     string
		input    string
		expected string
	}{
		{
			// NOTE(review): "1000.5 ETH" survives here while "$..." values are
			// redacted below — confirm this asymmetry in amount-pattern
			// coverage is intentional at Info level.
			name:     "Info filters amounts but shows full addresses",
			input:    "Amount: 1000.5 ETH, Address: 0x742d35Cc6AaB8f5d6649c8C4F7C6b2d123456789",
			expected: "Amount: 1000.5 ETH, Address: 0x742d35Cc6AaB8f5d6649c8C4F7C6b2d123456789",
		},
		{
			name:     "USD values filtered",
			input:    "Profit: $5000.00 USD",
			expected: "Profit: [AMOUNT_FILTERED] USD",
		},
		{
			// NOTE(review): the expected output "[AMOUNT_FILTERED]C" shows the
			// pattern consuming "100.0 USD" out of "100.0 USDC", leaving a
			// stray "C" — presumably a known pattern limitation; confirm.
			name:     "Multiple amounts filtered",
			input:    "Swap 100.0 USDC for 0.05 ETH",
			expected: "Swap [AMOUNT_FILTERED]C for 0.05 ETH",
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			result := filter.FilterMessage(tt.input)
			assert.Equal(t, tt.expected, result)
		})
	}
}
// TestFilterMessage_ProductionLevel verifies the strictest filtering mode:
// dollar values are redacted and addresses are shortened to the
// "0xAAAA...BBBB" form, while bare token amounts in free text may survive.
func TestFilterMessage_ProductionLevel(t *testing.T) {
	f := NewSecureFilter(SecurityLevelProduction)

	cases := []struct {
		name string
		in   string
		want string
	}{
		{
			"Production filters everything sensitive",
			"Amount: 1000.5 ETH, Address: 0x742d35Cc6AaB8f5d6649c8C4F7C6b2d123456789, Value: $5000.00",
			"Amount: 1000.5 ETH, Address: 0x742d...6789, Value: [AMOUNT_FILTERED]",
		},
		{
			"Complex transaction filtered",
			"Swap 1500.789 USDC from 0xA0b86a33E6441f43E2e4A96439abFA2A69067ACD to 0xaf88d065e77c8cC2239327C5EDb3A432268e5831 for $1500.00 profit",
			"Swap [AMOUNT_FILTERED]C from 0xA0b8...7ACD to 0xaf88...5831 for [AMOUNT_FILTERED] profit",
		},
		{
			"Multiple addresses and amounts",
			"Transfer 500 tokens from 0x1234567890123456789012345678901234567890 to 0x0987654321098765432109876543210987654321 worth $1000",
			"Transfer 500 tokens from 0x1234...7890 to 0x0987...4321 worth [AMOUNT_FILTERED]",
		},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			assert.Equal(t, tc.want, f.FilterMessage(tc.in))
		})
	}
}
// TestShortenAddress verifies the "0xAAAA...BBBB" shortening, including the
// EIP-55 checksum casing common.Address applies (note "0x742D", "0xFFfF").
func TestShortenAddress(t *testing.T) {
	f := NewSecureFilter(SecurityLevelInfo)

	cases := []struct {
		name string
		addr common.Address
		want string
	}{
		{
			"Standard address",
			common.HexToAddress("0x742d35Cc6AaB8f5d6649c8C4F7C6b2d123456789"),
			"0x742D...6789",
		},
		{
			"Zero address",
			common.HexToAddress("0x0000000000000000000000000000000000000000"),
			"0x0000...0000",
		},
		{
			"Max address",
			common.HexToAddress("0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF"),
			"0xFFfF...FFfF",
		},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			assert.Equal(t, tc.want, f.shortenAddress(tc.addr))
		})
	}
}
// TestCategorizeAmount documents the intended bucketing of raw token amounts
// into size categories (nil/micro/small/medium/large/whale).
//
// categorizeAmount is unexported and has no public wrapper, so the
// expectations cannot be asserted from this test. Previously each case faked
// a pass with assert.True(t, true); the cases are now skipped instead, so the
// coverage gap is visible in test output until a public wrapper exists.
func TestCategorizeAmount(t *testing.T) {
	tests := []struct {
		name     string
		input    *big.Int
		expected string
	}{
		{
			name:     "Nil amount",
			input:    nil,
			expected: "nil",
		},
		{
			name:     "Micro amount (< $1)",
			input:    big.NewInt(500000000000000000), // 0.5 ETH assuming 18 decimals
			expected: "micro",
		},
		{
			name:     "Small amount ($1-$100)",
			input:    func() *big.Int { val, _ := new(big.Int).SetString("50000000000000000000", 10); return val }(), // 50 ETH
			expected: "small",
		},
		{
			name:     "Medium amount ($100-$10k)",
			input:    func() *big.Int { val, _ := new(big.Int).SetString("5000000000000000000000", 10); return val }(), // 5000 ETH
			expected: "medium",
		},
		{
			name:     "Large amount ($10k-$1M)",
			input:    func() *big.Int { val, _ := new(big.Int).SetString("500000000000000000000000", 10); return val }(), // 500k ETH
			expected: "large",
		},
		{
			name:     "Whale amount (>$1M)",
			input:    func() *big.Int { val, _ := new(big.Int).SetString("2000000000000000000000000", 10); return val }(), // 2M ETH
			expected: "whale",
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			t.Skipf("categorizeAmount is private - needs a public wrapper (input=%v, want=%q)", tt.input, tt.expected)
		})
	}
}
// TestSanitizeForProduction verifies structured-field sanitization: keys that
// carry monetary values are replaced with "[FILTERED]", address-like values
// (both common.Address and hex strings) are shortened, and all other fields
// pass through unchanged.
func TestSanitizeForProduction(t *testing.T) {
	filter := NewSecureFilter(SecurityLevelProduction)
	tests := []struct {
		name     string
		input    map[string]interface{}
		expected map[string]interface{}
	}{
		{
			// Value-bearing keys are redacted regardless of numeric type.
			name: "Amounts filtered",
			input: map[string]interface{}{
				"amount":    1000.5,
				"amountIn":  500,
				"amountOut": 750,
				"value":     999.99,
				"other":     "should remain",
			},
			expected: map[string]interface{}{
				"amount":    "[FILTERED]",
				"amountIn":  "[FILTERED]",
				"amountOut": "[FILTERED]",
				"value":     "[FILTERED]",
				"other":     "should remain",
			},
		},
		{
			// Address values are shortened, not removed; note the checksum
			// casing ("0x742D") applied to common.Address inputs.
			name: "Addresses filtered and shortened",
			input: map[string]interface{}{
				"address": common.HexToAddress("0x742d35Cc6AaB8f5d6649c8C4F7C6b2d123456789"),
				"token":   "0xA0b86a33E6441f43E2e4A96439abFA2A69067ACD",
				"pool":    "0xaf88d065e77c8cC2239327C5EDb3A432268e5831",
				"other":   "should remain",
			},
			expected: map[string]interface{}{
				"address": "0x742D...6789",
				"token":   "0xA0b8...7ACD",
				"pool":    "0xaf88...5831",
				"other":   "should remain",
			},
		},
		{
			// Non-sensitive numeric fields (fee, timestamp) stay intact.
			name: "Mixed data types",
			input: map[string]interface{}{
				"profit":    1000.0,
				"tokenOut":  common.HexToAddress("0x742d35Cc6AaB8f5d6649c8C4F7C6b2d123456789"),
				"fee":       30,
				"protocol":  "UniswapV3",
				"timestamp": 1640995200,
			},
			expected: map[string]interface{}{
				"profit":    "[FILTERED]",
				"tokenOut":  "0x742D...6789",
				"fee":       30,
				"protocol":  "UniswapV3",
				"timestamp": 1640995200,
			},
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			result := filter.SanitizeForProduction(tt.input)
			assert.Equal(t, tt.expected, result)
		})
	}
}
// TestFilterMessage_ComplexScenarios runs realistic MEV log lines through both
// Production and Info filters, pinning the difference between the two levels:
// Production additionally shortens addresses, while Info leaves them intact.
func TestFilterMessage_ComplexScenarios(t *testing.T) {
	productionFilter := NewSecureFilter(SecurityLevelProduction)
	infoFilter := NewSecureFilter(SecurityLevelInfo)
	tests := []struct {
		name       string
		input      string
		production string // expected output at SecurityLevelProduction
		info       string // expected output at SecurityLevelInfo
	}{
		{
			name:       "MEV opportunity log",
			input:      "🎯 ARBITRAGE OPPORTUNITY: Swap 1500.789 USDC via 0xA0b86a33E6441f43E2e4A96439abFA2A69067ACD for profit $250.50",
			production: "🎯 ARBITRAGE OPPORTUNITY: Swap [AMOUNT_FILTERED]C via 0xA0b8...7ACD for profit [AMOUNT_FILTERED]",
			info:       "🎯 ARBITRAGE OPPORTUNITY: Swap [AMOUNT_FILTERED]C via 0xA0b86a33E6441f43E2e4A96439abFA2A69067ACD for profit [AMOUNT_FILTERED]",
		},
		{
			name:       "Transaction log with multiple sensitive data",
			input:      "TX: 0x123...abc Amount: 999.123456 ETH → 1500000.5 USDC, Gas: 150000, Value: $2500000.75",
			production: "TX: 0x123...abc Amount: 999.123456 ETH → [AMOUNT_FILTERED]C, Gas: 150000, Value: [AMOUNT_FILTERED]",
			info:       "TX: 0x123...abc Amount: 999.123456 ETH → [AMOUNT_FILTERED]C, Gas: 150000, Value: [AMOUNT_FILTERED]",
		},
		{
			name:       "Pool creation event",
			input:      "Pool created: 0x742d35Cc6AaB8f5d6649c8C4F7C6b2d123456789 with 1000000.0 liquidity worth $5000000",
			production: "Pool created: 0x742d...6789 with 1000000.0 liquidity worth [AMOUNT_FILTERED]",
			info:       "Pool created: 0x742d35Cc6AaB8f5d6649c8C4F7C6b2d123456789 with 1000000.0 liquidity worth [AMOUNT_FILTERED]",
		},
	}
	for _, tt := range tests {
		t.Run(tt.name+" - Production", func(t *testing.T) {
			result := productionFilter.FilterMessage(tt.input)
			assert.Equal(t, tt.production, result)
		})
		t.Run(tt.name+" - Info", func(t *testing.T) {
			result := infoFilter.FilterMessage(tt.input)
			assert.Equal(t, tt.info, result)
		})
	}
}
// TestFilterMessage_EdgeCases pins inputs the production filter must leave
// untouched: empty strings, plain prose, bare integers (block numbers, gas),
// scientific notation, and decimals without currency markers.
func TestFilterMessage_EdgeCases(t *testing.T) {
	f := NewSecureFilter(SecurityLevelProduction)

	cases := []struct {
		name string
		in   string
		want string
	}{
		{
			"Empty string",
			"",
			"",
		},
		{
			"No sensitive data",
			"Simple log message with no sensitive information",
			"Simple log message with no sensitive information",
		},
		{
			"Only numbers (not amounts)",
			"Block number: 12345, Gas limit: 8000000",
			"Block number: 12345, Gas limit: 8000000",
		},
		{
			"Scientific notation",
			"Amount: 1.5e18 wei",
			"Amount: 1.5e18 wei",
		},
		{
			"Multiple decimal places",
			"Price: 1234.567890123456 tokens",
			"Price: 1234.567890123456 tokens",
		},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			assert.Equal(t, tc.want, f.FilterMessage(tc.in))
		})
	}
}
// Benchmark tests

// BenchmarkFilterMessage_Production measures production-level filtering on a
// representative arbitrage log line (amount + address + dollar value).
func BenchmarkFilterMessage_Production(b *testing.B) {
	f := NewSecureFilter(SecurityLevelProduction)
	line := "🎯 ARBITRAGE OPPORTUNITY: Swap 1500.789 USDC via 0xA0b86a33E6441f43E2e4A96439abFA2A69067ACD for profit $250.50"
	b.ResetTimer()
	for n := 0; n < b.N; n++ {
		f.FilterMessage(line)
	}
}
// BenchmarkFilterMessage_Info measures info-level filtering on a log line
// carrying an amount and two full addresses.
func BenchmarkFilterMessage_Info(b *testing.B) {
	f := NewSecureFilter(SecurityLevelInfo)
	line := "Transaction: 1000.5 ETH from 0x742d35Cc6AaB8f5d6649c8C4F7C6b2d123456789 to 0xA0b86a33E6441f43E2e4A96439abFA2A69067ACD"
	b.ResetTimer()
	for n := 0; n < b.N; n++ {
		f.FilterMessage(line)
	}
}
// BenchmarkSanitizeForProduction measures structured-field sanitization on a
// mixed map of amounts, addresses and plain metadata.
func BenchmarkSanitizeForProduction(b *testing.B) {
	f := NewSecureFilter(SecurityLevelProduction)
	fields := map[string]interface{}{
		"amount":   1000.5,
		"address":  common.HexToAddress("0x742d35Cc6AaB8f5d6649c8C4F7C6b2d123456789"),
		"profit":   250.75,
		"protocol": "UniswapV3",
		"token":    "0xA0b86a33E6441f43E2e4A96439abFA2A69067ACD",
	}
	b.ResetTimer()
	for n := 0; n < b.N; n++ {
		f.SanitizeForProduction(fields)
	}
}
func TestSecurityLevelConstants(t *testing.T) {
// Verify security level constants are defined correctly
assert.Equal(t, SecurityLevel(0), SecurityLevelDebug)
assert.Equal(t, SecurityLevel(1), SecurityLevelInfo)
assert.Equal(t, SecurityLevel(2), SecurityLevelProduction)
}
// TestRegexPatterns checks that the filter compiles non-empty pattern sets and
// that each set matches a representative sensitive string.
func TestRegexPatterns(t *testing.T) {
	f := NewSecureFilter(SecurityLevelProduction)

	assert.True(t, len(f.amountPatterns) > 0, "Should have amount patterns")
	assert.True(t, len(f.addressPatterns) > 0, "Should have address patterns")
	assert.True(t, len(f.valuePatterns) > 0, "Should have value patterns")

	cases := []struct {
		patterns []*regexp.Regexp
		input    string
		should   string
	}{
		{f.amountPatterns, "amount=123", "match amount patterns"},
		{f.addressPatterns, "Address: 0x742d35Cc6AaB8f5d6649c8C4F7C6b2d123456789", "match address patterns"},
		{f.valuePatterns, "profit=$1234.56", "match value patterns"},
	}
	for _, tc := range cases {
		matched := false
		for _, re := range tc.patterns {
			if re.MatchString(tc.input) {
				matched = true
				break
			}
		}
		assert.True(t, matched, tc.should)
	}
}

View File

@@ -0,0 +1,400 @@
package monitoring
import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"net/http"
	"os"
	"strings"
	"sync"
	"time"

	"github.com/fraktal/mev-beta/internal/logger"
)
// LogAlertHandler logs alerts to the application logger.
// It is the simplest AlertSubscriber: every alert becomes one structured
// log entry at a level derived from the alert severity.
type LogAlertHandler struct {
	logger *logger.Logger // destination for all alert log entries
}

// NewLogAlertHandler creates a new log-based alert handler.
func NewLogAlertHandler(logger *logger.Logger) *LogAlertHandler {
	return &LogAlertHandler{
		logger: logger,
	}
}
// HandleAlert logs the alert using structured logging.
//
// The severity selects the log level (Error for emergency/critical, Warn for
// warning, Info otherwise) and an emoji-tagged title; the structured fields
// are identical for every severity, so they are emitted from a single call
// instead of four duplicated branches.
func (lah *LogAlertHandler) HandleAlert(alert CorruptionAlert) error {
	// Defaults cover every severity not matched below.
	logFn := lah.logger.Info
	title := " INFO ALERT"
	switch alert.Severity {
	case AlertSeverityEmergency:
		logFn = lah.logger.Error
		title = "🚨 EMERGENCY ALERT"
	case AlertSeverityCritical:
		logFn = lah.logger.Error
		title = "🔴 CRITICAL ALERT"
	case AlertSeverityWarning:
		logFn = lah.logger.Warn
		title = "🟡 WARNING ALERT"
	}
	logFn(title,
		"message", alert.Message,
		"severity", alert.Severity.String(),
		"timestamp", alert.Timestamp,
		"context", alert.Context)
	return nil
}
// FileAlertHandler writes alerts to a file in JSON format
// (one JSON object per line, appended on each alert).
type FileAlertHandler struct {
	mu       sync.Mutex     // serializes writes so concurrent alerts do not interleave
	filePath string         // destination file; created on first write if missing
	logger   *logger.Logger // used for debug logging of successful writes
}

// NewFileAlertHandler creates a new file-based alert handler.
// The file at filePath is not opened until the first alert is handled.
func NewFileAlertHandler(filePath string, logger *logger.Logger) *FileAlertHandler {
	return &FileAlertHandler{
		filePath: filePath,
		logger:   logger,
	}
}
// HandleAlert appends the alert to the configured file as a single JSON line.
// The file is opened (and created if needed) per call and closed before
// returning; the mutex keeps concurrent appends from interleaving.
func (fah *FileAlertHandler) HandleAlert(alert CorruptionAlert) error {
	fah.mu.Lock()
	defer fah.mu.Unlock()

	// Flatten the alert into the on-disk record shape.
	record := map[string]interface{}{
		"timestamp":        alert.Timestamp.Format(time.RFC3339),
		"severity":         alert.Severity.String(),
		"message":          alert.Message,
		"address":          alert.Address.Hex(),
		"corruption_score": alert.CorruptionScore,
		"source":           alert.Source,
		"context":          alert.Context,
	}
	line, err := json.Marshal(record)
	if err != nil {
		return fmt.Errorf("failed to marshal alert: %w", err)
	}

	f, err := os.OpenFile(fah.filePath, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0644)
	if err != nil {
		return fmt.Errorf("failed to open alert file: %w", err)
	}
	defer f.Close()

	if _, err := f.Write(append(line, '\n')); err != nil {
		return fmt.Errorf("failed to write alert to file: %w", err)
	}

	fah.logger.Debug("Alert written to file",
		"file", fah.filePath,
		"severity", alert.Severity.String())
	return nil
}
// HTTPAlertHandler sends alerts to an HTTP endpoint (e.g., Slack, Discord, PagerDuty)
// as JSON webhook payloads, with retries.
type HTTPAlertHandler struct {
	mu         sync.Mutex // NOTE(review): not taken in the visible methods — confirm it is still needed
	webhookURL string     // destination webhook; payload shape is inferred from the URL host
	client     *http.Client
	logger     *logger.Logger
	retryCount int // delivery attempts per alert before giving up
}

// NewHTTPAlertHandler creates a new HTTP-based alert handler with a 10-second
// client timeout and 3 delivery attempts per alert.
func NewHTTPAlertHandler(webhookURL string, logger *logger.Logger) *HTTPAlertHandler {
	return &HTTPAlertHandler{
		webhookURL: webhookURL,
		client: &http.Client{
			Timeout: 10 * time.Second,
		},
		logger:     logger,
		retryCount: 3,
	}
}
// HandleAlert sends the alert to the configured HTTP endpoint, retrying up to
// hah.retryCount times with linear backoff (attempt × 1s) between attempts.
//
// Bug fixes versus the previous version:
//   - the per-attempt context was cancelled immediately after building the
//     request — i.e. BEFORE client.Do — which cancelled the outgoing request;
//     cancel now runs only after the attempt completes.
//   - response bodies were closed with defer inside the retry loop, keeping
//     every attempt's body open until the function returned; they are now
//     closed per attempt.
func (hah *HTTPAlertHandler) HandleAlert(alert CorruptionAlert) error {
	if hah.webhookURL == "" {
		return fmt.Errorf("webhook URL not configured")
	}

	// Build the payload once; it is identical for every attempt.
	payload := hah.createPayload(alert)
	payloadJSON, err := json.Marshal(payload)
	if err != nil {
		return fmt.Errorf("failed to marshal webhook payload: %w", err)
	}

	for attempt := 1; attempt <= hah.retryCount; attempt++ {
		ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
		req, err := http.NewRequestWithContext(ctx, "POST", hah.webhookURL, strings.NewReader(string(payloadJSON)))
		if err != nil {
			cancel()
			// Request construction failure is not transient; do not retry.
			return fmt.Errorf("failed to create HTTP request: %w", err)
		}
		req.Header.Set("Content-Type", "application/json")
		req.Header.Set("User-Agent", "MEV-Bot-AlertHandler/1.0")

		resp, err := hah.client.Do(req)
		if err != nil {
			cancel()
			hah.logger.Warn("Failed to send alert to webhook",
				"attempt", attempt,
				"error", err)
			if attempt == hah.retryCount {
				return fmt.Errorf("failed to send alert after %d attempts: %w", hah.retryCount, err)
			}
			time.Sleep(time.Duration(attempt) * time.Second)
			continue
		}

		// Drain and close the body per attempt so the connection can be
		// reused; keep the body for debugging output.
		body, _ := io.ReadAll(resp.Body)
		resp.Body.Close()
		cancel()

		if resp.StatusCode >= 200 && resp.StatusCode < 300 {
			hah.logger.Debug("Alert sent successfully",
				"webhook_url", hah.webhookURL,
				"status_code", resp.StatusCode,
				"response", string(body))
			return nil
		}

		hah.logger.Warn("Webhook returned error status",
			"attempt", attempt,
			"status_code", resp.StatusCode,
			"response", string(body))
		if attempt == hah.retryCount {
			return fmt.Errorf("webhook returned status %d after %d attempts", resp.StatusCode, hah.retryCount)
		}
		time.Sleep(time.Duration(attempt) * time.Second)
	}
	return nil
}
// createPayload builds the webhook body, picking a Slack- or Discord-specific
// shape based on the webhook host and falling back to a flat generic object.
func (hah *HTTPAlertHandler) createPayload(alert CorruptionAlert) map[string]interface{} {
	switch {
	case strings.Contains(hah.webhookURL, "slack.com"):
		return hah.createSlackPayload(alert)
	case strings.Contains(hah.webhookURL, "discord.com"):
		return hah.createDiscordPayload(alert)
	default:
		// Generic webhook payload mirrors the alert fields one-to-one.
		return map[string]interface{}{
			"timestamp":        alert.Timestamp.Format(time.RFC3339),
			"severity":         alert.Severity.String(),
			"message":          alert.Message,
			"address":          alert.Address.Hex(),
			"corruption_score": alert.CorruptionScore,
			"source":           alert.Source,
			"context":          alert.Context,
		}
	}
}
// createSlackPayload creates a Slack-compatible webhook payload: one
// attachment whose colour encodes the severity, with the alert details laid
// out as short fields.
func (hah *HTTPAlertHandler) createSlackPayload(alert CorruptionAlert) map[string]interface{} {
	// Default "good" (green) covers info-level severities.
	color := "good"
	switch alert.Severity {
	case AlertSeverityWarning:
		color = "warning"
	case AlertSeverityCritical:
		color = "danger"
	case AlertSeverityEmergency:
		color = "#FF0000" // Bright red for emergency
	}
	attachment := map[string]interface{}{
		"color":     color,
		"title":     fmt.Sprintf("%s Alert - MEV Bot", alert.Severity.String()),
		"text":      alert.Message,
		"timestamp": alert.Timestamp.Unix(),
		"fields": []map[string]interface{}{
			{
				"title": "Address",
				"value": alert.Address.Hex(),
				"short": true,
			},
			{
				"title": "Corruption Score",
				"value": fmt.Sprintf("%d", alert.CorruptionScore),
				"short": true,
			},
			{
				"title": "Source",
				"value": alert.Source,
				"short": true,
			},
		},
	}
	return map[string]interface{}{
		"text":        fmt.Sprintf("MEV Bot Alert: %s", alert.Severity.String()),
		"attachments": []map[string]interface{}{attachment},
	}
}
// createDiscordPayload creates a Discord-compatible webhook payload: a single
// embed whose colour (an RGB integer, as Discord requires) encodes severity,
// with inline fields for the alert details.
func (hah *HTTPAlertHandler) createDiscordPayload(alert CorruptionAlert) map[string]interface{} {
	color := 0x00FF00 // Green (default for info-level severities)
	switch alert.Severity {
	case AlertSeverityWarning:
		color = 0xFFFF00 // Yellow
	case AlertSeverityCritical:
		color = 0xFF8000 // Orange
	case AlertSeverityEmergency:
		color = 0xFF0000 // Red
	}
	embed := map[string]interface{}{
		"title":       fmt.Sprintf("%s Alert - MEV Bot", alert.Severity.String()),
		"description": alert.Message,
		"color":       color,
		"timestamp":   alert.Timestamp.Format(time.RFC3339),
		"fields": []map[string]interface{}{
			{
				"name":   "Address",
				"value":  alert.Address.Hex(),
				"inline": true,
			},
			{
				"name":   "Corruption Score",
				"value":  fmt.Sprintf("%d", alert.CorruptionScore),
				"inline": true,
			},
			{
				"name":   "Source",
				"value":  alert.Source,
				"inline": true,
			},
		},
		"footer": map[string]interface{}{
			"text": "MEV Bot Integrity Monitor",
		},
	}
	return map[string]interface{}{
		"embeds": []map[string]interface{}{embed},
	}
}
// MetricsAlertHandler integrates with metrics systems (Prometheus, etc.)
// by maintaining in-memory counters that can be scraped via GetCounters.
type MetricsAlertHandler struct {
	mu       sync.Mutex       // guards counters
	logger   *logger.Logger   // used for debug logging of counter updates
	counters map[string]int64 // alert counts keyed by "total_alerts", "alerts_<severity>", "corruption_alerts"
}

// NewMetricsAlertHandler creates a new metrics-based alert handler with an
// empty counter set.
func NewMetricsAlertHandler(logger *logger.Logger) *MetricsAlertHandler {
	return &MetricsAlertHandler{
		logger:   logger,
		counters: make(map[string]int64),
	}
}
// HandleAlert bumps the in-memory counters for this alert: the global total,
// the per-severity bucket, and (when a corruption score is present) the
// corruption-specific counter.
func (mah *MetricsAlertHandler) HandleAlert(alert CorruptionAlert) error {
	mah.mu.Lock()
	defer mah.mu.Unlock()

	severityKey := fmt.Sprintf("alerts_%s", strings.ToLower(alert.Severity.String()))
	mah.counters["total_alerts"]++
	mah.counters[severityKey]++
	if alert.CorruptionScore > 0 {
		mah.counters["corruption_alerts"]++
	}

	mah.logger.Debug("Metrics updated for alert",
		"severity", alert.Severity.String(),
		"total_alerts", mah.counters["total_alerts"])
	return nil
}
// GetCounters returns a point-in-time copy of the alert counters, so callers
// can read them without racing ongoing HandleAlert updates.
func (mah *MetricsAlertHandler) GetCounters() map[string]int64 {
	mah.mu.Lock()
	defer mah.mu.Unlock()

	snapshot := make(map[string]int64, len(mah.counters))
	for name, value := range mah.counters {
		snapshot[name] = value
	}
	return snapshot
}
// CompositeAlertHandler combines multiple alert handlers, fanning each alert
// out to every registered AlertSubscriber.
type CompositeAlertHandler struct {
	handlers []AlertSubscriber // fan-out targets, invoked in registration order
	logger   *logger.Logger    // used to log individual handler failures
}

// NewCompositeAlertHandler creates a composite alert handler over the given
// handlers (more can be added later via AddHandler).
func NewCompositeAlertHandler(logger *logger.Logger, handlers ...AlertSubscriber) *CompositeAlertHandler {
	return &CompositeAlertHandler{
		handlers: handlers,
		logger:   logger,
	}
}
// HandleAlert fans the alert out to every configured handler.
//
// All handlers are attempted even when earlier ones fail; each failure is
// logged individually and the failures are aggregated into a single returned
// error. Aggregation uses errors.Join (instead of the previous "%v" of an
// error slice) so each wrapped error remains inspectable via errors.Is/As.
func (cah *CompositeAlertHandler) HandleAlert(alert CorruptionAlert) error {
	var errs []error
	for i, handler := range cah.handlers {
		if err := handler.HandleAlert(alert); err != nil {
			cah.logger.Error("Alert handler failed",
				"handler_index", i,
				"handler_type", fmt.Sprintf("%T", handler),
				"error", err)
			errs = append(errs, fmt.Errorf("handler %d (%T): %w", i, handler, err))
		}
	}
	if len(errs) > 0 {
		return fmt.Errorf("alert handler errors: %w", errors.Join(errs...))
	}
	return nil
}
// AddHandler adds a new handler to the composite.
// NOTE(review): the slice is appended without locking — confirm all handlers
// are registered before alerts start flowing from other goroutines.
func (cah *CompositeAlertHandler) AddHandler(handler AlertSubscriber) {
	cah.handlers = append(cah.handlers, handler)
}

View File

@@ -0,0 +1,549 @@
package monitoring
import (
"encoding/json"
"fmt"
"html/template"
"net/http"
"strconv"
"time"
"github.com/fraktal/mev-beta/internal/logger"
)
// DashboardServer provides a web-based monitoring dashboard: an HTML overview
// page plus JSON API endpoints for health, metrics, history and alerts.
type DashboardServer struct {
	logger           *logger.Logger     // request/render logging
	integrityMonitor *IntegrityMonitor  // source of metrics, health summary and alerts
	healthChecker    *HealthCheckRunner // source of health snapshot history
	port             int                // TCP port the HTTP server listens on
	server           *http.Server       // set by Start; nil until then
}
// NewDashboardServer wires up a dashboard server for the given data sources
// and port. The HTTP server itself is not created until Start is called.
func NewDashboardServer(logger *logger.Logger, integrityMonitor *IntegrityMonitor, healthChecker *HealthCheckRunner, port int) *DashboardServer {
	ds := &DashboardServer{
		logger:           logger,
		integrityMonitor: integrityMonitor,
		healthChecker:    healthChecker,
		port:             port,
	}
	return ds
}
// Start registers the dashboard routes and serves HTTP on ds.port, blocking
// until the server stops (it returns http.ErrServerClosed after Stop).
//
// Improvement: the http.Server now carries read/write timeouts so a slow or
// stalled client cannot pin a connection indefinitely; the dashboard serves
// only small payloads, so generous bounds are safe.
func (ds *DashboardServer) Start() error {
	mux := http.NewServeMux()
	// Register endpoints
	mux.HandleFunc("/", ds.handleDashboard)
	mux.HandleFunc("/api/health", ds.handleAPIHealth)
	mux.HandleFunc("/api/metrics", ds.handleAPIMetrics)
	mux.HandleFunc("/api/history", ds.handleAPIHistory)
	mux.HandleFunc("/api/alerts", ds.handleAPIAlerts)
	mux.HandleFunc("/static/", ds.handleStatic)
	ds.server = &http.Server{
		Addr:              fmt.Sprintf(":%d", ds.port),
		Handler:           mux,
		ReadHeaderTimeout: 10 * time.Second,
		ReadTimeout:       30 * time.Second,
		WriteTimeout:      30 * time.Second,
	}
	ds.logger.Info("Starting monitoring dashboard",
		"port", ds.port,
		"url", fmt.Sprintf("http://localhost:%d", ds.port))
	return ds.server.ListenAndServe()
}
// Stop closes the dashboard server immediately.
// It is a no-op (returning nil) when Start has never been called.
func (ds *DashboardServer) Stop() error {
	if ds.server == nil {
		return nil
	}
	return ds.server.Close()
}
// handleDashboard serves the main dashboard HTML page.
//
// NOTE(review): the template is re-parsed on every request via
// getDashboardTemplate(); consider parsing it once at construction time.
func (ds *DashboardServer) handleDashboard(w http.ResponseWriter, r *http.Request) {
	w.Header().Set("Content-Type", "text/html")
	// Get current metrics and health data (last 20 health snapshots).
	metrics := ds.integrityMonitor.GetMetrics()
	healthSummary := ds.integrityMonitor.GetHealthSummary()
	healthHistory := ds.healthChecker.GetRecentSnapshots(20)
	// Render dashboard template with an anonymous view-model struct.
	tmpl := ds.getDashboardTemplate()
	data := struct {
		Metrics       MetricsSnapshot
		HealthSummary map[string]interface{}
		HealthHistory []HealthSnapshot
		Timestamp     time.Time
	}{
		Metrics:       metrics,
		HealthSummary: healthSummary,
		HealthHistory: healthHistory,
		Timestamp:     time.Now(),
	}
	if err := tmpl.Execute(w, data); err != nil {
		ds.logger.Error("Failed to render dashboard", "error", err)
		http.Error(w, "Internal Server Error", http.StatusInternalServerError)
		return
	}
}
// handleAPIHealth serves a combined JSON health summary drawn from both the
// integrity monitor and the health checker.
func (ds *DashboardServer) handleAPIHealth(w http.ResponseWriter, r *http.Request) {
	w.Header().Set("Content-Type", "application/json")

	response := map[string]interface{}{
		"integrity_monitor": ds.integrityMonitor.GetHealthSummary(),
		"health_checker":    ds.healthChecker.GetHealthSummary(),
		"timestamp":         time.Now(),
	}
	if err := json.NewEncoder(w).Encode(response); err != nil {
		ds.logger.Error("Failed to encode health response", "error", err)
		http.Error(w, "Internal Server Error", http.StatusInternalServerError)
	}
}
// handleAPIMetrics serves the current integrity-monitor metrics as JSON.
func (ds *DashboardServer) handleAPIMetrics(w http.ResponseWriter, r *http.Request) {
	w.Header().Set("Content-Type", "application/json")

	if err := json.NewEncoder(w).Encode(ds.integrityMonitor.GetMetrics()); err != nil {
		ds.logger.Error("Failed to encode metrics response", "error", err)
		http.Error(w, "Internal Server Error", http.StatusInternalServerError)
	}
}
// handleAPIHistory serves recent health snapshots as JSON. The optional
// "count" query parameter (1..100, default 20) bounds the history length;
// out-of-range or malformed values silently fall back to the default.
func (ds *DashboardServer) handleAPIHistory(w http.ResponseWriter, r *http.Request) {
	w.Header().Set("Content-Type", "application/json")

	count := 20
	if raw := r.URL.Query().Get("count"); raw != "" {
		if parsed, err := strconv.Atoi(raw); err == nil && parsed > 0 && parsed <= 100 {
			count = parsed
		}
	}

	if err := json.NewEncoder(w).Encode(ds.healthChecker.GetRecentSnapshots(count)); err != nil {
		ds.logger.Error("Failed to encode history response", "error", err)
		http.Error(w, "Internal Server Error", http.StatusInternalServerError)
	}
}
// handleAPIAlerts serves recent integrity alerts as JSON. The optional
// "limit" query parameter (1..200, default 20) bounds the number returned.
func (ds *DashboardServer) handleAPIAlerts(w http.ResponseWriter, r *http.Request) {
	w.Header().Set("Content-Type", "application/json")

	limit := 20
	if raw := r.URL.Query().Get("limit"); raw != "" {
		if parsed, err := strconv.Atoi(raw); err == nil && parsed > 0 && parsed <= 200 {
			limit = parsed
		}
	}

	alerts := ds.integrityMonitor.GetRecentAlerts(limit)
	payload := map[string]interface{}{
		"alerts":    alerts,
		"count":     len(alerts),
		"timestamp": time.Now(),
	}
	if err := json.NewEncoder(w).Encode(payload); err != nil {
		ds.logger.Error("Failed to encode alerts response", "error", err)
		http.Error(w, "Internal Server Error", http.StatusInternalServerError)
	}
}
// handleStatic serves static assets (CSS, JS).
// For simplicity, CSS and JS are inlined in the HTML template, so this
// endpoint deliberately answers 404 for everything under /static/.
// In a production system you'd serve actual static files here.
func (ds *DashboardServer) handleStatic(w http.ResponseWriter, r *http.Request) {
	http.NotFound(w, r)
}
// getDashboardTemplate returns the HTML template for the dashboard.
//
// NOTE(review): this re-parses the template on every call (handleDashboard
// calls it per request); parsing once and caching would avoid the repeated
// work. Also note two rendering hazards flagged inline below: the template
// uses `printf "%,d"`, which is NOT a valid Go fmt verb (fmt has no
// thousands-separator flag, so those expressions render as "%!,(...)"), and
// the funcMap helpers do unchecked `.(float64)` assertions that panic if
// health_score arrives as any other numeric type.
func (ds *DashboardServer) getDashboardTemplate() *template.Template {
	// Raw-string HTML; its bytes (including the {{...}} actions) are the
	// exact page served, so it must not be reformatted casually.
	htmlTemplate := `
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>MEV Bot - Data Integrity Monitor</title>
<style>
* {
margin: 0;
padding: 0;
box-sizing: border-box;
}
body {
font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
background-color: #f5f5f5;
color: #333;
line-height: 1.6;
}
.header {
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
color: white;
padding: 1rem 0;
box-shadow: 0 2px 4px rgba(0,0,0,0.1);
}
.container {
max-width: 1200px;
margin: 0 auto;
padding: 0 1rem;
}
.header h1 {
font-size: 2rem;
font-weight: 300;
}
.header .subtitle {
opacity: 0.9;
margin-top: 0.5rem;
}
.dashboard {
padding: 2rem 0;
}
.grid {
display: grid;
grid-template-columns: repeat(auto-fit, minmax(300px, 1fr));
gap: 1.5rem;
margin-bottom: 2rem;
}
.card {
background: white;
border-radius: 8px;
padding: 1.5rem;
box-shadow: 0 2px 8px rgba(0,0,0,0.1);
border-left: 4px solid #667eea;
}
.card h3 {
color: #333;
margin-bottom: 1rem;
font-size: 1.25rem;
}
.metric {
display: flex;
justify-content: space-between;
align-items: center;
padding: 0.5rem 0;
border-bottom: 1px solid #eee;
}
.metric:last-child {
border-bottom: none;
}
.metric-label {
font-weight: 500;
color: #666;
}
.metric-value {
font-weight: 600;
color: #333;
}
.health-score {
font-size: 2rem;
font-weight: bold;
text-align: center;
padding: 1rem;
border-radius: 50%;
width: 100px;
height: 100px;
display: flex;
align-items: center;
justify-content: center;
margin: 0 auto 1rem;
}
.health-excellent { background: #4CAF50; color: white; }
.health-good { background: #8BC34A; color: white; }
.health-fair { background: #FF9800; color: white; }
.health-poor { background: #F44336; color: white; }
.status-indicator {
display: inline-block;
width: 12px;
height: 12px;
border-radius: 50%;
margin-right: 8px;
}
.status-healthy { background: #4CAF50; }
.status-warning { background: #FF9800; }
.status-critical { background: #F44336; }
.chart-container {
background: white;
border-radius: 8px;
padding: 1.5rem;
box-shadow: 0 2px 8px rgba(0,0,0,0.1);
margin-top: 1.5rem;
}
.refresh-indicator {
position: fixed;
top: 20px;
right: 20px;
background: #667eea;
color: white;
padding: 0.5rem 1rem;
border-radius: 4px;
font-size: 0.875rem;
}
.timestamp {
text-align: center;
color: #666;
font-size: 0.875rem;
margin-top: 2rem;
}
.recovery-actions {
display: grid;
grid-template-columns: repeat(auto-fit, minmax(200px, 1fr));
gap: 1rem;
margin-top: 1rem;
}
.recovery-action {
background: #f8f9fa;
padding: 0.75rem;
border-radius: 4px;
text-align: center;
}
.recovery-action-count {
font-size: 1.5rem;
font-weight: bold;
color: #667eea;
}
.recovery-action-label {
font-size: 0.875rem;
color: #666;
text-transform: uppercase;
}
</style>
</head>
<body>
<div class="header">
<div class="container">
<h1>🤖 MEV Bot - Data Integrity Monitor</h1>
<p class="subtitle">Real-time monitoring of corruption detection and recovery systems</p>
</div>
</div>
<div class="dashboard">
<div class="container">
<div class="grid">
<!-- Health Score Card -->
<div class="card">
<h3>System Health</h3>
<div class="health-score {{.HealthSummary.health_score | healthClass}}">
{{.HealthSummary.health_score | printf "%.1f"}}
</div>
<div class="metric">
<span class="metric-label">Status</span>
<span class="metric-value">
<span class="status-indicator {{.HealthSummary.health_score | statusClass}}"></span>
{{.HealthSummary.health_score | healthStatus}}
</span>
</div>
<div class="metric">
<span class="metric-label">Monitor Enabled</span>
<span class="metric-value">{{if .HealthSummary.enabled}}✅ Yes{{else}}❌ No{{end}}</span>
</div>
</div>
<!-- Processing Statistics -->
<div class="card">
<h3>Processing Statistics</h3>
<div class="metric">
<span class="metric-label">Total Addresses</span>
<span class="metric-value">{{.Metrics.TotalAddressesProcessed | printf "%,d"}}</span>
</div>
<div class="metric">
<span class="metric-label">Corruption Detected</span>
<span class="metric-value">{{.Metrics.CorruptAddressesDetected | printf "%,d"}}</span>
</div>
<div class="metric">
<span class="metric-label">Corruption Rate</span>
<span class="metric-value">{{.HealthSummary.corruption_rate | printf "%.4f%%"}}</span>
</div>
<div class="metric">
<span class="metric-label">Avg Corruption Score</span>
<span class="metric-value">{{.Metrics.AverageCorruptionScore | printf "%.1f"}}</span>
</div>
<div class="metric">
<span class="metric-label">Max Corruption Score</span>
<span class="metric-value">{{.Metrics.MaxCorruptionScore}}</span>
</div>
</div>
<!-- Validation Results -->
<div class="card">
<h3>Validation Results</h3>
<div class="metric">
<span class="metric-label">Validation Passed</span>
<span class="metric-value">{{.Metrics.AddressValidationPassed | printf "%,d"}}</span>
</div>
<div class="metric">
<span class="metric-label">Validation Failed</span>
<span class="metric-value">{{.Metrics.AddressValidationFailed | printf "%,d"}}</span>
</div>
<div class="metric">
<span class="metric-label">Success Rate</span>
<span class="metric-value">{{.HealthSummary.validation_success_rate | printf "%.2f%%"}}</span>
</div>
</div>
<!-- Contract Calls -->
<div class="card">
<h3>Contract Calls</h3>
<div class="metric">
<span class="metric-label">Successful Calls</span>
<span class="metric-value">{{.Metrics.ContractCallsSucceeded | printf "%,d"}}</span>
</div>
<div class="metric">
<span class="metric-label">Failed Calls</span>
<span class="metric-value">{{.Metrics.ContractCallsFailed | printf "%,d"}}</span>
</div>
<div class="metric">
<span class="metric-label">Success Rate</span>
<span class="metric-value">{{.HealthSummary.contract_call_success_rate | printf "%.2f%%"}}</span>
</div>
</div>
</div>
<!-- Recovery Actions -->
<div class="chart-container">
<h3>Recovery System Activity</h3>
<div class="recovery-actions">
<div class="recovery-action">
<div class="recovery-action-count">{{.Metrics.RetryOperationsTriggered}}</div>
<div class="recovery-action-label">Retry Operations</div>
</div>
<div class="recovery-action">
<div class="recovery-action-count">{{.Metrics.FallbackOperationsUsed}}</div>
<div class="recovery-action-label">Fallback Used</div>
</div>
<div class="recovery-action">
<div class="recovery-action-count">{{.Metrics.CircuitBreakersTripped}}</div>
<div class="recovery-action-label">Circuit Breakers</div>
</div>
</div>
</div>
<div class="timestamp">
Last updated: {{.Timestamp.Format "2006-01-02 15:04:05 UTC"}}
<br>
Auto-refresh every 30 seconds
</div>
</div>
</div>
<div class="refresh-indicator">🔄 Live</div>
<script>
// Auto-refresh every 30 seconds
setInterval(function() {
window.location.reload();
}, 30000);
// Add smooth transitions
document.addEventListener('DOMContentLoaded', function() {
const cards = document.querySelectorAll('.card');
cards.forEach((card, index) => {
card.style.animationDelay = (index * 0.1) + 's';
card.style.animation = 'fadeInUp 0.6s ease forwards';
});
});
</script>
<style>
@keyframes fadeInUp {
from {
opacity: 0;
transform: translateY(20px);
}
to {
opacity: 1;
transform: translateY(0);
}
}
</style>
</body>
</html>
`
	// Create template with custom functions that map a numeric health score
	// (expected in [0,1]) onto CSS classes and a human-readable label.
	// NOTE(review): each helper asserts score.(float64) without the ok form
	// and will panic if the summary map carries a different numeric type.
	funcMap := template.FuncMap{
		"healthClass": func(score interface{}) string {
			s := score.(float64)
			if s >= 0.9 {
				return "health-excellent"
			} else if s >= 0.7 {
				return "health-good"
			} else if s >= 0.5 {
				return "health-fair"
			}
			return "health-poor"
		},
		"statusClass": func(score interface{}) string {
			s := score.(float64)
			if s >= 0.7 {
				return "status-healthy"
			} else if s >= 0.5 {
				return "status-warning"
			}
			return "status-critical"
		},
		"healthStatus": func(score interface{}) string {
			s := score.(float64)
			if s >= 0.9 {
				return "Excellent"
			} else if s >= 0.7 {
				return "Good"
			} else if s >= 0.5 {
				return "Fair"
			}
			return "Poor"
		},
	}
	// template.Must panics on a parse error, which is acceptable here because
	// the template is a compile-time constant.
	return template.Must(template.New("dashboard").Funcs(funcMap).Parse(htmlTemplate))
}
// GetDashboardURL returns the local HTTP address where the dashboard is served.
func (ds *DashboardServer) GetDashboardURL() string {
	url := fmt.Sprintf("http://localhost:%d", ds.port)
	return url
}

View File

@@ -0,0 +1,447 @@
package monitoring
import (
"context"
"fmt"
"sync"
"time"
"github.com/fraktal/mev-beta/internal/logger"
)
// HealthCheckRunner performs periodic health checks and monitoring
type HealthCheckRunner struct {
	mu                    sync.RWMutex      // guards all mutable fields below
	logger                *logger.Logger    // structured logger for status output
	integrityMonitor      *IntegrityMonitor // source of metrics and destination for alerts
	checkInterval         time.Duration     // time between checks (default 30s)
	running               bool              // true while the check-loop goroutine is active
	stopChan              chan struct{}     // closed by Stop to terminate the loop
	lastHealthCheck       time.Time         // start time of the most recent check
	healthHistory         []HealthSnapshot  // rolling window of recent snapshots
	maxHistorySize        int               // cap on healthHistory length
	warmupSamples         int               // snapshots required before alerts may fire
	minAddressesForAlerts int64             // min processed addresses before alerts may fire
}
// HealthSnapshot represents a point-in-time health snapshot
type HealthSnapshot struct {
	Timestamp           time.Time   // when the check started
	HealthScore         float64     // overall score copied from the integrity monitor (0..1)
	CorruptionRate      float64     // corrupt / total processed addresses (0 when none processed)
	ValidationSuccess   float64     // passed / total validations (0 when none recorded)
	ContractCallSuccess float64     // succeeded / total contract calls (0 when none recorded)
	ActiveAlerts        int         // currently always 0; see performHealthCheck
	Trend               HealthTrend // direction derived from recent scores
}
// HealthTrend indicates the direction of health metrics over recent snapshots.
type HealthTrend int

const (
	HealthTrendUnknown HealthTrend = iota
	HealthTrendImproving
	HealthTrendStable
	HealthTrendDeclining
	HealthTrendCritical
)

// String returns the upper-case label for the trend; any value without a
// defined label (including HealthTrendUnknown) reads as "UNKNOWN".
func (t HealthTrend) String() string {
	labels := map[HealthTrend]string{
		HealthTrendImproving: "IMPROVING",
		HealthTrendStable:    "STABLE",
		HealthTrendDeclining: "DECLINING",
		HealthTrendCritical:  "CRITICAL",
	}
	if label, ok := labels[t]; ok {
		return label
	}
	return "UNKNOWN"
}
// NewHealthCheckRunner creates a new health check runner with defaults:
// a 30-second check interval, a 100-snapshot history window (~50 minutes
// at the default interval), and an alert warm-up of 3 samples plus a
// minimum of 25 processed addresses.
func NewHealthCheckRunner(logger *logger.Logger, integrityMonitor *IntegrityMonitor) *HealthCheckRunner {
	runner := &HealthCheckRunner{
		logger:                logger,
		integrityMonitor:      integrityMonitor,
		checkInterval:         30 * time.Second,
		stopChan:              make(chan struct{}),
		healthHistory:         make([]HealthSnapshot, 0),
		maxHistorySize:        100,
		warmupSamples:         3,
		minAddressesForAlerts: 25,
	}
	return runner
}
// Start begins the periodic health checking routine in its own goroutine.
// It is a no-op when the runner is already active. The loop runs until the
// supplied context is cancelled or Stop is called.
func (hcr *HealthCheckRunner) Start(ctx context.Context) {
	hcr.mu.Lock()
	if hcr.running {
		hcr.mu.Unlock()
		return
	}
	hcr.running = true
	// Recreate the stop channel: Stop closes it, so a Start after a previous
	// Stop would otherwise launch a loop that exits immediately on the
	// already-closed channel.
	hcr.stopChan = make(chan struct{})
	hcr.mu.Unlock()
	hcr.logger.Info("Starting health check runner",
		"interval", hcr.checkInterval)
	go hcr.healthCheckLoop(ctx)
}
// Stop terminates the health checking routine if it is running; calling it
// on an already-stopped runner is a no-op.
func (hcr *HealthCheckRunner) Stop() {
	hcr.mu.Lock()
	defer hcr.mu.Unlock()
	if hcr.running {
		hcr.running = false
		close(hcr.stopChan)
		hcr.logger.Info("Health check runner stopped")
	}
}
// healthCheckLoop runs the periodic health checking until the context is
// cancelled or stopChan is closed. An initial check fires immediately so a
// fresh runner does not wait a full interval for its first snapshot.
func (hcr *HealthCheckRunner) healthCheckLoop(ctx context.Context) {
	ticker := time.NewTicker(hcr.checkInterval)
	defer ticker.Stop()
	// Perform initial health check
	hcr.performHealthCheck()
	for {
		select {
		case <-ctx.Done():
			hcr.logger.Info("Health check runner stopped due to context cancellation")
			return
		case <-hcr.stopChan:
			hcr.logger.Info("Health check runner stopped")
			return
		case <-ticker.C:
			hcr.performHealthCheck()
		}
	}
}
// performHealthCheck executes a comprehensive health check: it samples the
// integrity monitor's metrics, derives corruption/success rates, appends a
// HealthSnapshot to the rolling history, evaluates alert thresholds, and
// logs the outcome.
func (hcr *HealthCheckRunner) performHealthCheck() {
	start := time.Now()
	// lastHealthCheck is read by GetHealthSummary under the mutex, so this
	// write must be guarded too (it previously raced).
	hcr.mu.Lock()
	hcr.lastHealthCheck = start
	hcr.mu.Unlock()
	if !hcr.integrityMonitor.IsEnabled() {
		hcr.logger.Debug("Skipping health check - integrity monitor disabled")
		return
	}
	// Get current metrics
	metrics := hcr.integrityMonitor.GetMetrics()
	healthSummary := hcr.integrityMonitor.GetHealthSummary()
	// Calculate rates, guarding each against divide-by-zero when no samples exist yet.
	corruptionRate := 0.0
	if metrics.TotalAddressesProcessed > 0 {
		corruptionRate = float64(metrics.CorruptAddressesDetected) / float64(metrics.TotalAddressesProcessed)
	}
	validationSuccessRate := 0.0
	totalValidations := metrics.AddressValidationPassed + metrics.AddressValidationFailed
	if totalValidations > 0 {
		validationSuccessRate = float64(metrics.AddressValidationPassed) / float64(totalValidations)
	}
	contractCallSuccessRate := 0.0
	totalCalls := metrics.ContractCallsSucceeded + metrics.ContractCallsFailed
	if totalCalls > 0 {
		contractCallSuccessRate = float64(metrics.ContractCallsSucceeded) / float64(totalCalls)
	}
	// Create health snapshot
	snapshot := HealthSnapshot{
		Timestamp:           start,
		HealthScore:         metrics.HealthScore,
		CorruptionRate:      corruptionRate,
		ValidationSuccess:   validationSuccessRate,
		ContractCallSuccess: contractCallSuccessRate,
		ActiveAlerts:        0, // Will be calculated based on current conditions
		Trend:               hcr.calculateHealthTrend(metrics.HealthScore),
	}
	// Add to history
	hcr.addHealthSnapshot(snapshot)
	// Check for threshold violations and generate alerts
	hcr.checkThresholds(healthSummary, snapshot)
	// Log health status periodically
	hcr.logHealthStatus(snapshot, time.Since(start))
}
// addHealthSnapshot appends a snapshot to the health history, evicting the
// oldest entries so the history never exceeds maxHistorySize.
func (hcr *HealthCheckRunner) addHealthSnapshot(snapshot HealthSnapshot) {
	hcr.mu.Lock()
	defer hcr.mu.Unlock()
	hcr.healthHistory = append(hcr.healthHistory, snapshot)
	if excess := len(hcr.healthHistory) - hcr.maxHistorySize; excess > 0 {
		hcr.healthHistory = hcr.healthHistory[excess:]
	}
}
// calculateHealthTrend analyzes recent health scores to determine trend.
// It compares the oldest and newest of up to the last 5 history scores plus
// the current score: a rise of more than 0.05 is IMPROVING, a drop of more
// than 0.05 is DECLINING, otherwise STABLE. A current score below 0.5 is
// always CRITICAL, and fewer than 3 history entries yields UNKNOWN.
func (hcr *HealthCheckRunner) calculateHealthTrend(currentScore float64) HealthTrend {
	hcr.mu.RLock()
	defer hcr.mu.RUnlock()
	if len(hcr.healthHistory) < 3 {
		return HealthTrendUnknown
	}
	// Get last few scores for trend analysis
	// (cap 5 is only a size hint; appending currentScore below may grow it to 6)
	recentScores := make([]float64, 0, 5)
	start := len(hcr.healthHistory) - 5
	if start < 0 {
		start = 0
	}
	for i := start; i < len(hcr.healthHistory); i++ {
		recentScores = append(recentScores, hcr.healthHistory[i].HealthScore)
	}
	recentScores = append(recentScores, currentScore)
	// Calculate trend: a critically low current score overrides the slope check.
	if currentScore < 0.5 {
		return HealthTrendCritical
	}
	// Simple linear trend calculation (first vs. last sample only)
	if len(recentScores) >= 3 {
		first := recentScores[0]
		last := recentScores[len(recentScores)-1]
		diff := last - first
		if diff > 0.05 {
			return HealthTrendImproving
		} else if diff < -0.05 {
			return HealthTrendDeclining
		} else {
			return HealthTrendStable
		}
	}
	return HealthTrendUnknown
}
// checkThresholds checks for threshold violations and generates alerts.
// Alerts are suppressed until the warm-up criteria in readyForAlerts are met.
// Three independent conditions are evaluated: health score below 0.5
// (emergency), corruption rate above 10% (critical), and a declining or
// critical trend (warning).
func (hcr *HealthCheckRunner) checkThresholds(healthSummary map[string]interface{}, snapshot HealthSnapshot) {
	if !hcr.readyForAlerts(healthSummary, snapshot) {
		hcr.logger.Debug("Health alerts suppressed during warm-up",
			"health_score", snapshot.HealthScore,
			"total_addresses_processed", safeNumericLookup(healthSummary, "total_addresses_processed"),
			"history_size", hcr.historySize())
		return
	}
	// Critical health score alert
	if snapshot.HealthScore < 0.5 {
		alert := CorruptionAlert{
			Timestamp: time.Now(),
			Severity:  AlertSeverityEmergency,
			Message:   fmt.Sprintf("CRITICAL: System health score is %.2f (below 0.5)", snapshot.HealthScore),
			Context: map[string]interface{}{
				"health_score":          snapshot.HealthScore,
				"corruption_rate":       snapshot.CorruptionRate,
				"validation_success":    snapshot.ValidationSuccess,
				"contract_call_success": snapshot.ContractCallSuccess,
				"trend":                 snapshot.Trend.String(),
			},
		}
		hcr.integrityMonitor.sendAlert(alert)
	}
	// High corruption rate alert
	if snapshot.CorruptionRate > 0.10 { // 10% corruption rate
		alert := CorruptionAlert{
			Timestamp: time.Now(),
			Severity:  AlertSeverityCritical,
			Message:   fmt.Sprintf("High corruption rate detected: %.2f%%", snapshot.CorruptionRate*100),
			Context: map[string]interface{}{
				"corruption_rate": snapshot.CorruptionRate,
				"threshold":       0.10,
				// NOTE(review): this value is the corruption *rate*, not a
				// count of affected addresses — the key looks mislabeled;
				// confirm with alert consumers before renaming.
				"addresses_affected": snapshot.CorruptionRate,
			},
		}
		hcr.integrityMonitor.sendAlert(alert)
	}
	// Declining trend alert
	if snapshot.Trend == HealthTrendDeclining || snapshot.Trend == HealthTrendCritical {
		alert := CorruptionAlert{
			Timestamp: time.Now(),
			Severity:  AlertSeverityWarning,
			Message:   fmt.Sprintf("System health trend is %s (current score: %.2f)", snapshot.Trend.String(), snapshot.HealthScore),
			Context: map[string]interface{}{
				"trend":            snapshot.Trend.String(),
				"health_score":     snapshot.HealthScore,
				"recent_snapshots": hcr.getRecentSnapshots(5),
			},
		}
		hcr.integrityMonitor.sendAlert(alert)
	}
}
// readyForAlerts reports whether enough data has accumulated to emit health
// alerts. It requires warmupSamples history entries, at least
// minAddressesForAlerts processed addresses (when the summary exposes that
// counter; -1 means "unknown" and is not treated as too few), and some
// observed activity.
func (hcr *HealthCheckRunner) readyForAlerts(healthSummary map[string]interface{}, snapshot HealthSnapshot) bool {
	hcr.mu.RLock()
	historyLen := len(hcr.healthHistory)
	hcr.mu.RUnlock()
	if historyLen < hcr.warmupSamples {
		return false
	}
	totalProcessed := safeNumericLookup(healthSummary, "total_addresses_processed")
	if totalProcessed >= 0 && totalProcessed < float64(hcr.minAddressesForAlerts) {
		return false
	}
	// Require at least one validation or contract call attempt before alarming.
	if snapshot.ValidationSuccess == 0 && snapshot.ContractCallSuccess == 0 && totalProcessed == 0 {
		return false
	}
	return true
}
// safeNumericLookup extracts a numeric value from a summary map, converting
// any of Go's common integer and float types to float64. It returns -1 when
// the key is absent or the value is not numeric. A nil map needs no special
// guard: indexing a nil map is safe in Go and simply reports the key as
// missing, so the previous explicit nil check was redundant.
func safeNumericLookup(summary map[string]interface{}, key string) float64 {
	value, ok := summary[key]
	if !ok {
		return -1
	}
	switch v := value.(type) {
	case int:
		return float64(v)
	case int32:
		return float64(v)
	case int64:
		return float64(v)
	case uint:
		return float64(v)
	case uint32:
		return float64(v)
	case uint64:
		return float64(v)
	case float32:
		return float64(v)
	case float64:
		return v
	default:
		return -1
	}
}
// historySize returns the current number of stored health snapshots.
func (hcr *HealthCheckRunner) historySize() int {
	hcr.mu.RLock()
	size := len(hcr.healthHistory)
	hcr.mu.RUnlock()
	return size
}
// logHealthStatus logs periodic health status information. Every 10th check
// (~5 minutes at the default 30s interval) a detailed Info record is
// emitted; all other checks log a brief Debug record.
func (hcr *HealthCheckRunner) logHealthStatus(snapshot HealthSnapshot, duration time.Duration) {
	// Read the history length via historySize() so the access is guarded by
	// the mutex; the previous direct read of hcr.healthHistory raced with
	// addHealthSnapshot.
	if hcr.historySize()%10 == 0 {
		hcr.logger.Info("System health status",
			"health_score", snapshot.HealthScore,
			"corruption_rate", fmt.Sprintf("%.4f", snapshot.CorruptionRate),
			"validation_success", fmt.Sprintf("%.4f", snapshot.ValidationSuccess),
			"contract_call_success", fmt.Sprintf("%.4f", snapshot.ContractCallSuccess),
			"trend", snapshot.Trend.String(),
			"check_duration", duration)
	} else {
		// Brief status for regular checks
		hcr.logger.Debug("Health check completed",
			"health_score", snapshot.HealthScore,
			"trend", snapshot.Trend.String(),
			"duration", duration)
	}
}
// GetRecentSnapshots returns the most recent health snapshots, newest last.
// It is the exported wrapper around getRecentSnapshots and returns a copy,
// so callers cannot mutate the internal history.
func (hcr *HealthCheckRunner) GetRecentSnapshots(count int) []HealthSnapshot {
	return hcr.getRecentSnapshots(count)
}
// getRecentSnapshots returns up to count of the most recent snapshots,
// oldest first, as a defensive copy. A count larger than the history returns
// everything; a count <= 0 returns an empty slice (previously a negative
// count computed start > len(history) and panicked with a slice-bounds
// error).
func (hcr *HealthCheckRunner) getRecentSnapshots(count int) []HealthSnapshot {
	hcr.mu.RLock()
	defer hcr.mu.RUnlock()
	if count < 0 {
		count = 0
	}
	if len(hcr.healthHistory) == 0 {
		return []HealthSnapshot{}
	}
	start := len(hcr.healthHistory) - count
	if start < 0 {
		start = 0
	}
	// Create a copy to avoid external modification
	snapshots := make([]HealthSnapshot, len(hcr.healthHistory[start:]))
	copy(snapshots, hcr.healthHistory[start:])
	return snapshots
}
// GetHealthSummary returns a comprehensive health summary of the runner.
//
// The recent-snapshot copy is performed inline while the read lock is held
// instead of calling getRecentSnapshots: that helper re-acquires mu.RLock,
// and a recursive RLock deadlocks when a writer (addHealthSnapshot) is
// queued between the two acquisitions — sync.RWMutex is not reentrant.
func (hcr *HealthCheckRunner) GetHealthSummary() map[string]interface{} {
	hcr.mu.RLock()
	defer hcr.mu.RUnlock()
	if len(hcr.healthHistory) == 0 {
		return map[string]interface{}{
			"running":        hcr.running,
			"check_interval": hcr.checkInterval.String(),
			"history_size":   0,
			"last_check":     nil,
		}
	}
	lastSnapshot := hcr.healthHistory[len(hcr.healthHistory)-1]
	// Copy the last 10 snapshots under the lock we already hold.
	start := len(hcr.healthHistory) - 10
	if start < 0 {
		start = 0
	}
	recent := make([]HealthSnapshot, len(hcr.healthHistory[start:]))
	copy(recent, hcr.healthHistory[start:])
	return map[string]interface{}{
		"running":               hcr.running,
		"check_interval":        hcr.checkInterval.String(),
		"history_size":          len(hcr.healthHistory),
		"last_check":            hcr.lastHealthCheck,
		"current_health_score":  lastSnapshot.HealthScore,
		"current_trend":         lastSnapshot.Trend.String(),
		"corruption_rate":       lastSnapshot.CorruptionRate,
		"validation_success":    lastSnapshot.ValidationSuccess,
		"contract_call_success": lastSnapshot.ContractCallSuccess,
		"recent_snapshots":      recent,
	}
}
// SetCheckInterval sets the health check interval.
//
// NOTE(review): the running loop's ticker is created once in
// healthCheckLoop, so a new interval only takes effect after the runner is
// stopped and started again — confirm whether live reconfiguration is
// expected by callers.
func (hcr *HealthCheckRunner) SetCheckInterval(interval time.Duration) {
	hcr.mu.Lock()
	defer hcr.mu.Unlock()
	hcr.checkInterval = interval
	hcr.logger.Info("Health check interval updated", "interval", interval)
}
// IsRunning reports whether the health-check loop is currently active.
func (hcr *HealthCheckRunner) IsRunning() bool {
	hcr.mu.RLock()
	running := hcr.running
	hcr.mu.RUnlock()
	return running
}

View File

@@ -0,0 +1,533 @@
package monitoring
import (
"context"
"fmt"
"sync"
"time"
"github.com/ethereum/go-ethereum/common"
"github.com/fraktal/mev-beta/internal/logger"
"github.com/fraktal/mev-beta/internal/recovery"
)
// IntegrityMetrics tracks data integrity statistics. All fields are guarded
// by the embedded mutex; external readers receive a MetricsSnapshot copy
// via IntegrityMonitor.GetMetrics instead of this struct.
type IntegrityMetrics struct {
	mu                       sync.RWMutex
	TotalAddressesProcessed  int64
	CorruptAddressesDetected int64
	AddressValidationPassed  int64
	AddressValidationFailed  int64
	ContractCallsSucceeded   int64
	ContractCallsFailed      int64
	RetryOperationsTriggered int64
	FallbackOperationsUsed   int64
	CircuitBreakersTripped   int64
	LastCorruptionDetection  time.Time
	AverageCorruptionScore   float64 // rolling mean over all detections
	MaxCorruptionScore       int     // highest single corruption score seen
	HealthScore              float64 // weighted composite, 0..1 (see updateHealthScore)
	HighScore                float64 // best HealthScore observed so far
	RecoveryActions          map[recovery.RecoveryAction]int64
	ErrorsByType             map[recovery.ErrorType]int64
}
// MetricsSnapshot represents a copy of metrics without mutex for safe external access.
// The maps are fresh copies, but note they are shared with whoever holds the
// snapshot — concurrent mutation by callers is their own responsibility.
type MetricsSnapshot struct {
	TotalAddressesProcessed  int64                               `json:"total_addresses_processed"`
	CorruptAddressesDetected int64                               `json:"corrupt_addresses_detected"`
	AddressValidationPassed  int64                               `json:"address_validation_passed"`
	AddressValidationFailed  int64                               `json:"address_validation_failed"`
	ContractCallsSucceeded   int64                               `json:"contract_calls_succeeded"`
	ContractCallsFailed      int64                               `json:"contract_calls_failed"`
	RetryOperationsTriggered int64                               `json:"retry_operations_triggered"`
	FallbackOperationsUsed   int64                               `json:"fallback_operations_used"`
	CircuitBreakersTripped   int64                               `json:"circuit_breakers_tripped"`
	LastCorruptionDetection  time.Time                           `json:"last_corruption_detection"`
	AverageCorruptionScore   float64                             `json:"average_corruption_score"`
	MaxCorruptionScore       int                                 `json:"max_corruption_score"`
	HealthScore              float64                             `json:"health_score"`
	HighScore                float64                             `json:"high_score"`
	RecoveryActions          map[recovery.RecoveryAction]int64   `json:"recovery_actions"`
	ErrorsByType             map[recovery.ErrorType]int64        `json:"errors_by_type"`
}
// CorruptionAlert represents a corruption detection alert delivered to
// AlertSubscriber implementations and retained in the monitor's alert buffer.
type CorruptionAlert struct {
	Timestamp       time.Time
	Address         common.Address // zero value for system-level (non-address) alerts
	CorruptionScore int            // 0 for alerts not tied to a specific detection
	Source          string         // component that reported the corruption
	Severity        AlertSeverity
	Message         string
	Context         map[string]interface{} // free-form supporting data
}
// AlertSeverity defines alert severity levels, ordered from least to most severe.
type AlertSeverity int

const (
	AlertSeverityInfo AlertSeverity = iota
	AlertSeverityWarning
	AlertSeverityCritical
	AlertSeverityEmergency
)

// String returns the upper-case label for the severity; values outside the
// defined range read as "UNKNOWN".
func (s AlertSeverity) String() string {
	labels := [...]string{"INFO", "WARNING", "CRITICAL", "EMERGENCY"}
	if s >= 0 && int(s) < len(labels) {
		return labels[s]
	}
	return "UNKNOWN"
}
// IntegrityMonitor monitors and tracks data integrity metrics.
// Locking layout: mu guards enabled, alertThresholds and alertSubscribers;
// metrics has its own internal mutex; alerts is guarded by alertsMutex.
type IntegrityMonitor struct {
	mu                sync.RWMutex
	logger            *logger.Logger
	metrics           *IntegrityMetrics
	alertThresholds   map[string]float64 // named thresholds, see setDefaultThresholds
	alertSubscribers  []AlertSubscriber
	healthCheckRunner *HealthCheckRunner
	enabled           bool
	alerts            []CorruptionAlert // rolling buffer of recent alerts
	alertsMutex       sync.RWMutex
}
// AlertSubscriber defines the interface for alert handlers. HandleAlert is
// invoked synchronously from sendAlert; a returned error is logged and does
// not stop delivery to other subscribers.
type AlertSubscriber interface {
	HandleAlert(alert CorruptionAlert) error
}
// NewIntegrityMonitor creates a new integrity monitoring system, enabled by
// default, with a perfect initial health score, default alert thresholds,
// and an attached (not yet started) HealthCheckRunner.
func NewIntegrityMonitor(logger *logger.Logger) *IntegrityMonitor {
	monitor := &IntegrityMonitor{
		logger: logger,
		metrics: &IntegrityMetrics{
			RecoveryActions: make(map[recovery.RecoveryAction]int64),
			ErrorsByType:    make(map[recovery.ErrorType]int64),
			HealthScore:     1.0,
			HighScore:       1.0,
		},
		alertThresholds: make(map[string]float64),
		enabled:         true,
		alerts:          make([]CorruptionAlert, 0, 256),
	}
	// Set default thresholds
	monitor.setDefaultThresholds()
	// Initialize health check runner
	monitor.healthCheckRunner = NewHealthCheckRunner(logger, monitor)
	return monitor
}
// setDefaultThresholds configures the default alert thresholds:
// corruption_rate 5%, failure_rate 10%, health_score_min 80%,
// max_corruption_score 70, circuit_breaker_rate 2%.
func (im *IntegrityMonitor) setDefaultThresholds() {
	defaults := map[string]float64{
		"corruption_rate":      0.05,
		"failure_rate":         0.10,
		"health_score_min":     0.80,
		"max_corruption_score": 70.0,
		"circuit_breaker_rate": 0.02,
	}
	for name, value := range defaults {
		im.alertThresholds[name] = value
	}
}
// RecordAddressProcessed increments the counter for processed addresses and
// refreshes the overall health score.
// NOTE(review): im.enabled is read here without im.mu (same pattern as the
// other Record* methods) — racy against Enable/Disable; confirm whether a
// stale read is acceptable.
func (im *IntegrityMonitor) RecordAddressProcessed() {
	if !im.enabled {
		return
	}
	im.metrics.mu.Lock()
	im.metrics.TotalAddressesProcessed++
	im.metrics.mu.Unlock()
	im.updateHealthScore()
}
// RecordCorruptionDetected records a corruption detection event: it bumps
// the corruption counters and rolling statistics, emits an alert whose
// severity is derived from the corruption score (see getCorruptionSeverity),
// refreshes the health score, and logs a warning.
func (im *IntegrityMonitor) RecordCorruptionDetected(address common.Address, corruptionScore int, source string) {
	if !im.enabled {
		return
	}
	im.metrics.mu.Lock()
	im.metrics.CorruptAddressesDetected++
	im.metrics.LastCorruptionDetection = time.Now()
	// Update corruption statistics
	if corruptionScore > im.metrics.MaxCorruptionScore {
		im.metrics.MaxCorruptionScore = corruptionScore
	}
	// Calculate rolling average corruption score
	// (incremental mean: avg_n = (avg_{n-1}*(n-1) + x_n) / n)
	total := float64(im.metrics.CorruptAddressesDetected)
	im.metrics.AverageCorruptionScore = ((im.metrics.AverageCorruptionScore * (total - 1)) + float64(corruptionScore)) / total
	im.metrics.mu.Unlock()
	// Generate alert based on corruption score
	severity := im.getCorruptionSeverity(corruptionScore)
	alert := CorruptionAlert{
		Timestamp:       time.Now(),
		Address:         address,
		CorruptionScore: corruptionScore,
		Source:          source,
		Severity:        severity,
		Message:         fmt.Sprintf("Corruption detected: address %s, score %d, source %s", address.Hex(), corruptionScore, source),
		Context: map[string]interface{}{
			"address":          address.Hex(),
			"corruption_score": corruptionScore,
			"source":           source,
			"timestamp":        time.Now().Unix(),
		},
	}
	im.sendAlert(alert)
	im.updateHealthScore()
	im.logger.Warn("Corruption detected",
		"address", address.Hex(),
		"corruption_score", corruptionScore,
		"source", source,
		"severity", severity.String())
}
// RecordValidationResult records a single address-validation outcome and
// refreshes the overall health score.
func (im *IntegrityMonitor) RecordValidationResult(passed bool) {
	if !im.enabled {
		return
	}
	im.metrics.mu.Lock()
	switch {
	case passed:
		im.metrics.AddressValidationPassed++
	default:
		im.metrics.AddressValidationFailed++
	}
	im.metrics.mu.Unlock()
	im.updateHealthScore()
}
// RecordContractCallResult records contract call success/failure and
// refreshes the overall health score.
func (im *IntegrityMonitor) RecordContractCallResult(succeeded bool) {
	if !im.enabled {
		return
	}
	im.metrics.mu.Lock()
	if succeeded {
		im.metrics.ContractCallsSucceeded++
	} else {
		im.metrics.ContractCallsFailed++
	}
	im.metrics.mu.Unlock()
	im.updateHealthScore()
}
// RecordRecoveryAction records recovery action usage in the per-action map
// and, for the three tracked action kinds, in their dedicated counters,
// then refreshes the health score.
func (im *IntegrityMonitor) RecordRecoveryAction(action recovery.RecoveryAction) {
	if !im.enabled {
		return
	}
	im.metrics.mu.Lock()
	im.metrics.RecoveryActions[action]++
	// Track specific metrics
	switch action {
	case recovery.ActionRetryWithBackoff:
		im.metrics.RetryOperationsTriggered++
	case recovery.ActionUseFallbackData:
		im.metrics.FallbackOperationsUsed++
	case recovery.ActionCircuitBreaker:
		im.metrics.CircuitBreakersTripped++
	}
	im.metrics.mu.Unlock()
	im.updateHealthScore()
}
// RecordErrorType increments the per-type error counter. Unlike the other
// Record* methods it does not touch the health score.
func (im *IntegrityMonitor) RecordErrorType(errorType recovery.ErrorType) {
	if !im.enabled {
		return
	}
	im.metrics.mu.Lock()
	defer im.metrics.mu.Unlock()
	im.metrics.ErrorsByType[errorType]++
}
// getCorruptionSeverity maps a corruption score to an alert severity:
// >=90 emergency, >=70 critical, >=40 warning, otherwise info.
func (im *IntegrityMonitor) getCorruptionSeverity(corruptionScore int) AlertSeverity {
	switch {
	case corruptionScore >= 90:
		return AlertSeverityEmergency
	case corruptionScore >= 70:
		return AlertSeverityCritical
	case corruptionScore >= 40:
		return AlertSeverityWarning
	default:
		return AlertSeverityInfo
	}
}
// updateHealthScore recomputes the weighted overall health score from the
// corruption, validation and contract-call counters.
//
// Weighting: 40% corruption prevention, 30% validation success, 30%
// contract-call success; components with no samples count as perfect (1.0),
// and with no processed addresses the score is reset to 1.0.
//
// The threshold alert is emitted AFTER metrics.mu is released: sendAlert
// invokes subscriber callbacks synchronously, and a subscriber that calls
// back into GetMetrics would otherwise deadlock on metrics.mu. The
// threshold is read under im.mu to avoid racing with SetThreshold.
func (im *IntegrityMonitor) updateHealthScore() {
	im.mu.RLock()
	minHealth := im.alertThresholds["health_score_min"]
	im.mu.RUnlock()

	im.metrics.mu.Lock()
	if im.metrics.TotalAddressesProcessed == 0 {
		im.metrics.HealthScore = 1.0
		im.metrics.mu.Unlock()
		return
	}
	// Calculate component scores
	corruptionRate := float64(im.metrics.CorruptAddressesDetected) / float64(im.metrics.TotalAddressesProcessed)
	var validationSuccessRate float64 = 1.0
	validationTotal := im.metrics.AddressValidationPassed + im.metrics.AddressValidationFailed
	if validationTotal > 0 {
		validationSuccessRate = float64(im.metrics.AddressValidationPassed) / float64(validationTotal)
	}
	var contractCallSuccessRate float64 = 1.0
	contractTotal := im.metrics.ContractCallsSucceeded + im.metrics.ContractCallsFailed
	if contractTotal > 0 {
		contractCallSuccessRate = float64(im.metrics.ContractCallsSucceeded) / float64(contractTotal)
	}
	// Weighted health score calculation
	healthScore := 0.0
	healthScore += (1.0 - corruptionRate) * 0.4 // 40% weight on corruption prevention
	healthScore += validationSuccessRate * 0.3  // 30% weight on validation success
	healthScore += contractCallSuccessRate * 0.3 // 30% weight on contract call success
	// Cap at 1.0 and handle edge cases
	if healthScore > 1.0 {
		healthScore = 1.0
	} else if healthScore < 0.0 {
		healthScore = 0.0
	}
	im.metrics.HealthScore = healthScore
	if healthScore > im.metrics.HighScore {
		im.metrics.HighScore = healthScore
	}
	im.metrics.mu.Unlock()

	// Check for health score threshold alerts (outside the metrics lock).
	if healthScore < minHealth {
		im.sendAlert(CorruptionAlert{
			Timestamp: time.Now(),
			Severity:  AlertSeverityCritical,
			Message:   fmt.Sprintf("System health score dropped to %.2f (threshold: %.2f)", healthScore, minHealth),
			Context: map[string]interface{}{
				"health_score":          healthScore,
				"threshold":             minHealth,
				"corruption_rate":       corruptionRate,
				"validation_success":    validationSuccessRate,
				"contract_call_success": contractCallSuccessRate,
			},
		})
	}
}
// sendAlert records the alert in the rolling buffer (capped at the last
// 1000 entries) and then delivers it synchronously to every subscriber.
// The subscriber list is snapshotted under im.mu before iteration: the
// previous unguarded read raced with AddAlertSubscriber's locked append.
// A subscriber error is logged and does not stop delivery to the others.
func (im *IntegrityMonitor) sendAlert(alert CorruptionAlert) {
	im.alertsMutex.Lock()
	im.alerts = append(im.alerts, alert)
	if len(im.alerts) > 1000 {
		trimmed := make([]CorruptionAlert, 1000)
		copy(trimmed, im.alerts[len(im.alerts)-1000:])
		im.alerts = trimmed
	}
	im.alertsMutex.Unlock()
	im.mu.RLock()
	subscribers := make([]AlertSubscriber, len(im.alertSubscribers))
	copy(subscribers, im.alertSubscribers)
	im.mu.RUnlock()
	for _, subscriber := range subscribers {
		if err := subscriber.HandleAlert(alert); err != nil {
			im.logger.Error("Failed to send alert",
				"subscriber", fmt.Sprintf("%T", subscriber),
				"error", err)
		}
	}
}
// AddAlertSubscriber registers a subscriber that will receive every future
// alert via HandleAlert. Safe for concurrent use.
func (im *IntegrityMonitor) AddAlertSubscriber(subscriber AlertSubscriber) {
	im.mu.Lock()
	defer im.mu.Unlock()
	im.alertSubscribers = append(im.alertSubscribers, subscriber)
}
// GetMetrics returns a mutex-free copy of the current metrics. The map
// fields are deep-copied so callers cannot mutate the monitor's state.
func (im *IntegrityMonitor) GetMetrics() MetricsSnapshot {
	im.metrics.mu.RLock()
	defer im.metrics.mu.RUnlock()
	actions := make(map[recovery.RecoveryAction]int64, len(im.metrics.RecoveryActions))
	for action, count := range im.metrics.RecoveryActions {
		actions[action] = count
	}
	errorCounts := make(map[recovery.ErrorType]int64, len(im.metrics.ErrorsByType))
	for errType, count := range im.metrics.ErrorsByType {
		errorCounts[errType] = count
	}
	return MetricsSnapshot{
		TotalAddressesProcessed:  im.metrics.TotalAddressesProcessed,
		CorruptAddressesDetected: im.metrics.CorruptAddressesDetected,
		AddressValidationPassed:  im.metrics.AddressValidationPassed,
		AddressValidationFailed:  im.metrics.AddressValidationFailed,
		ContractCallsSucceeded:   im.metrics.ContractCallsSucceeded,
		ContractCallsFailed:      im.metrics.ContractCallsFailed,
		RetryOperationsTriggered: im.metrics.RetryOperationsTriggered,
		FallbackOperationsUsed:   im.metrics.FallbackOperationsUsed,
		CircuitBreakersTripped:   im.metrics.CircuitBreakersTripped,
		LastCorruptionDetection:  im.metrics.LastCorruptionDetection,
		AverageCorruptionScore:   im.metrics.AverageCorruptionScore,
		MaxCorruptionScore:       im.metrics.MaxCorruptionScore,
		HealthScore:              im.metrics.HealthScore,
		HighScore:                im.metrics.HighScore,
		RecoveryActions:          actions,
		ErrorsByType:             errorCounts,
	}
}
// GetHealthSummary returns a comprehensive health summary of the monitor.
//
// enabled, alertThresholds and alertSubscribers are read under im.mu: the
// previous unguarded reads raced with SetThreshold / Enable / Disable /
// AddAlertSubscriber (a concurrent map read+write is fatal in Go), and the
// live threshold map was exposed to callers by reference — it is now copied.
func (im *IntegrityMonitor) GetHealthSummary() map[string]interface{} {
	metrics := im.GetMetrics()
	corruptionRate := 0.0
	if metrics.TotalAddressesProcessed > 0 {
		corruptionRate = float64(metrics.CorruptAddressesDetected) / float64(metrics.TotalAddressesProcessed)
	}
	validationSuccessRate := 0.0
	totalValidations := metrics.AddressValidationPassed + metrics.AddressValidationFailed
	if totalValidations > 0 {
		validationSuccessRate = float64(metrics.AddressValidationPassed) / float64(totalValidations)
	}
	contractCallSuccessRate := 0.0
	totalCalls := metrics.ContractCallsSucceeded + metrics.ContractCallsFailed
	if totalCalls > 0 {
		contractCallSuccessRate = float64(metrics.ContractCallsSucceeded) / float64(totalCalls)
	}
	im.mu.RLock()
	enabled := im.enabled
	thresholds := make(map[string]float64, len(im.alertThresholds))
	for name, value := range im.alertThresholds {
		thresholds[name] = value
	}
	subscriberCount := len(im.alertSubscribers)
	im.mu.RUnlock()
	return map[string]interface{}{
		"enabled":                    enabled,
		"health_score":               metrics.HealthScore,
		"total_addresses_processed":  metrics.TotalAddressesProcessed,
		"corruption_detections":      metrics.CorruptAddressesDetected,
		"corruption_rate":            corruptionRate,
		"validation_success_rate":    validationSuccessRate,
		"contract_call_success_rate": contractCallSuccessRate,
		"average_corruption_score":   metrics.AverageCorruptionScore,
		"max_corruption_score":       metrics.MaxCorruptionScore,
		"retry_operations":           metrics.RetryOperationsTriggered,
		"fallback_operations":        metrics.FallbackOperationsUsed,
		"circuit_breakers_tripped":   metrics.CircuitBreakersTripped,
		"last_corruption":            metrics.LastCorruptionDetection,
		"recovery_actions":           metrics.RecoveryActions,
		"errors_by_type":             metrics.ErrorsByType,
		"alert_thresholds":           thresholds,
		"alert_subscribers":          subscriberCount,
	}
}
// GetRecentAlerts returns the most recent corruption alerts up to the specified limit.
// A limit <= 0 (or larger than the retained history) returns every stored
// alert; the internal buffer holds at most the last 1000 alerts. The result
// is a copy, so callers cannot mutate the buffer.
func (im *IntegrityMonitor) GetRecentAlerts(limit int) []CorruptionAlert {
	im.alertsMutex.RLock()
	defer im.alertsMutex.RUnlock()
	if limit <= 0 || limit > len(im.alerts) {
		limit = len(im.alerts)
	}
	if limit == 0 {
		return []CorruptionAlert{}
	}
	start := len(im.alerts) - limit
	alertsCopy := make([]CorruptionAlert, limit)
	copy(alertsCopy, im.alerts[start:])
	return alertsCopy
}
// SetThreshold sets an alert threshold by name (see setDefaultThresholds
// for the recognized keys). Safe for concurrent use.
func (im *IntegrityMonitor) SetThreshold(name string, value float64) {
	im.mu.Lock()
	defer im.mu.Unlock()
	im.alertThresholds[name] = value
}
// Enable turns the integrity monitor on; subsequent Record* calls are processed.
func (im *IntegrityMonitor) Enable() {
	im.mu.Lock()
	defer im.mu.Unlock()
	im.enabled = true
	im.logger.Info("Integrity monitor enabled")
}
// Disable turns the integrity monitor off; subsequent Record* calls become no-ops.
func (im *IntegrityMonitor) Disable() {
	im.mu.Lock()
	defer im.mu.Unlock()
	im.enabled = false
	im.logger.Info("Integrity monitor disabled")
}
// IsEnabled reports whether the monitor is currently enabled.
func (im *IntegrityMonitor) IsEnabled() bool {
	im.mu.RLock()
	enabled := im.enabled
	im.mu.RUnlock()
	return enabled
}
// StartHealthCheckRunner starts the periodic health check routine, if a
// runner is attached (NewIntegrityMonitor always attaches one).
func (im *IntegrityMonitor) StartHealthCheckRunner(ctx context.Context) {
	if im.healthCheckRunner != nil {
		im.healthCheckRunner.Start(ctx)
	}
}
// StopHealthCheckRunner stops the periodic health check routine, if a
// runner is attached.
func (im *IntegrityMonitor) StopHealthCheckRunner() {
	if im.healthCheckRunner != nil {
		im.healthCheckRunner.Stop()
	}
}
// GetHealthCheckRunner returns the monitor's health check runner (never nil
// when the monitor was built via NewIntegrityMonitor).
func (im *IntegrityMonitor) GetHealthCheckRunner() *HealthCheckRunner {
	return im.healthCheckRunner
}

View File

@@ -0,0 +1,391 @@
package monitoring
import (
"fmt"
"testing"
"time"
"github.com/ethereum/go-ethereum/common"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/fraktal/mev-beta/internal/logger"
"github.com/fraktal/mev-beta/internal/recovery"
)
// MockAlertSubscriber is a test double that records every alert it receives.
// Not safe for concurrent use; sufficient for these single-goroutine tests.
type MockAlertSubscriber struct {
	alerts []CorruptionAlert
}

// HandleAlert appends the alert to the captured list and never fails.
func (m *MockAlertSubscriber) HandleAlert(alert CorruptionAlert) error {
	m.alerts = append(m.alerts, alert)
	return nil
}

// GetAlerts returns the alerts captured so far (the live slice, not a copy).
func (m *MockAlertSubscriber) GetAlerts() []CorruptionAlert {
	return m.alerts
}

// Reset discards all captured alerts.
func (m *MockAlertSubscriber) Reset() {
	m.alerts = nil
}
// TestIntegrityMonitor_RecordCorruptionDetected verifies that each corruption
// report updates the counters and produces exactly one alert whose severity
// matches the score bands in getCorruptionSeverity (info <40, warning >=40,
// critical >=70, emergency >=90).
func TestIntegrityMonitor_RecordCorruptionDetected(t *testing.T) {
	log := logger.New("error", "text", "")
	monitor := NewIntegrityMonitor(log)
	mockSubscriber := &MockAlertSubscriber{}
	monitor.AddAlertSubscriber(mockSubscriber)
	// Test various corruption scenarios
	testCases := []struct {
		name             string
		address          string
		corruptionScore  int
		source           string
		expectedSeverity AlertSeverity
	}{
		{
			name:             "Low corruption",
			address:          "0x1234567890123456789012345678901234567890",
			corruptionScore:  30,
			source:           "test_source",
			expectedSeverity: AlertSeverityInfo,
		},
		{
			name:             "Medium corruption",
			address:          "0x1234000000000000000000000000000000000000",
			corruptionScore:  50,
			source:           "token_extraction",
			expectedSeverity: AlertSeverityWarning,
		},
		{
			name:             "High corruption",
			address:          "0x0000001000000000000000000000000000000000",
			corruptionScore:  80,
			source:           "abi_decoder",
			expectedSeverity: AlertSeverityCritical,
		},
		{
			name:             "Critical corruption - TOKEN_0x000000",
			address:          "0x0000000300000000000000000000000000000000",
			corruptionScore:  100,
			source:           "generic_extraction",
			expectedSeverity: AlertSeverityEmergency,
		},
	}
	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			// Reset the mock so each subtest asserts on exactly one alert;
			// monitor state (counters) deliberately accumulates across cases.
			mockSubscriber.Reset()
			addr := common.HexToAddress(tc.address)
			monitor.RecordCorruptionDetected(addr, tc.corruptionScore, tc.source)
			// Verify metrics were updated
			metrics := monitor.GetMetrics()
			assert.Greater(t, metrics.CorruptAddressesDetected, int64(0))
			assert.GreaterOrEqual(t, metrics.MaxCorruptionScore, tc.corruptionScore)
			// Verify alert was generated
			alerts := mockSubscriber.GetAlerts()
			require.Len(t, alerts, 1)
			alert := alerts[0]
			assert.Equal(t, tc.expectedSeverity, alert.Severity)
			assert.Equal(t, addr, alert.Address)
			assert.Equal(t, tc.corruptionScore, alert.CorruptionScore)
			assert.Equal(t, tc.source, alert.Source)
			assert.Contains(t, alert.Message, "Corruption detected")
		})
	}
}
// TestIntegrityMonitor_HealthScoreCalculation verifies the health score
// starts perfect, stays perfect under all-success activity, and decreases
// monotonically as corruption and validation failures are recorded.
func TestIntegrityMonitor_HealthScoreCalculation(t *testing.T) {
	log := logger.New("error", "text", "")
	monitor := NewIntegrityMonitor(log)
	// Test initial health score
	metrics := monitor.GetMetrics()
	assert.Equal(t, 1.0, metrics.HealthScore) // Perfect health initially
	// Record some activity
	monitor.RecordAddressProcessed()
	monitor.RecordAddressProcessed()
	monitor.RecordValidationResult(true)
	monitor.RecordValidationResult(true)
	monitor.RecordContractCallResult(true)
	monitor.RecordContractCallResult(true)
	// Health should still be perfect
	metrics = monitor.GetMetrics()
	assert.Equal(t, 1.0, metrics.HealthScore)
	// Introduce some corruption
	addr := common.HexToAddress("0x0000000300000000000000000000000000000000")
	monitor.RecordCorruptionDetected(addr, 80, "test")
	// Health score should decrease
	metrics = monitor.GetMetrics()
	assert.Less(t, metrics.HealthScore, 1.0)
	assert.Greater(t, metrics.HealthScore, 0.0)
	// Add validation failures
	monitor.RecordValidationResult(false)
	monitor.RecordValidationResult(false)
	// Health should decrease further
	newMetrics := monitor.GetMetrics()
	assert.Less(t, newMetrics.HealthScore, metrics.HealthScore)
}
// TestIntegrityMonitor_RecoveryActionTracking feeds a known mix of recovery
// actions and checks both the per-action map and the dedicated aggregate
// counters.
func TestIntegrityMonitor_RecoveryActionTracking(t *testing.T) {
	log := logger.New("error", "text", "")
	monitor := NewIntegrityMonitor(log)

	// Two retries, one fallback, one circuit break.
	for _, action := range []recovery.RecoveryAction{
		recovery.ActionRetryWithBackoff,
		recovery.ActionRetryWithBackoff,
		recovery.ActionUseFallbackData,
		recovery.ActionCircuitBreaker,
	} {
		monitor.RecordRecoveryAction(action)
	}

	m := monitor.GetMetrics()

	// Per-action counters.
	assert.Equal(t, int64(2), m.RecoveryActions[recovery.ActionRetryWithBackoff])
	assert.Equal(t, int64(1), m.RecoveryActions[recovery.ActionUseFallbackData])
	assert.Equal(t, int64(1), m.RecoveryActions[recovery.ActionCircuitBreaker])

	// Dedicated counters mirror the per-action map.
	assert.Equal(t, int64(2), m.RetryOperationsTriggered)
	assert.Equal(t, int64(1), m.FallbackOperationsUsed)
	assert.Equal(t, int64(1), m.CircuitBreakersTripped)
}
// TestIntegrityMonitor_ErrorTypeTracking records a sequence of error types
// (with one duplicate) and verifies the per-type counters.
func TestIntegrityMonitor_ErrorTypeTracking(t *testing.T) {
	log := logger.New("error", "text", "")
	monitor := NewIntegrityMonitor(log)

	for _, et := range []recovery.ErrorType{
		recovery.ErrorTypeAddressCorruption,
		recovery.ErrorTypeContractCallFailed,
		recovery.ErrorTypeRPCConnectionFailed,
		recovery.ErrorTypeDataParsingFailed,
		recovery.ErrorTypeValidationFailed,
		recovery.ErrorTypeAddressCorruption, // deliberate duplicate
	} {
		monitor.RecordErrorType(et)
	}

	m := monitor.GetMetrics()

	// The duplicated type counts twice, every other type exactly once.
	expected := map[recovery.ErrorType]int64{
		recovery.ErrorTypeAddressCorruption:   2,
		recovery.ErrorTypeContractCallFailed:  1,
		recovery.ErrorTypeRPCConnectionFailed: 1,
		recovery.ErrorTypeDataParsingFailed:   1,
		recovery.ErrorTypeValidationFailed:    1,
	}
	for et, want := range expected {
		assert.Equal(t, want, m.ErrorsByType[et])
	}
}
// TestIntegrityMonitor_GetHealthSummary feeds a known mix of activity (10%
// corruption, 95% validation success, 90% contract-call success) and checks
// the summary map reflects those rates.
func TestIntegrityMonitor_GetHealthSummary(t *testing.T) {
	log := logger.New("error", "text", "")
	monitor := NewIntegrityMonitor(log)

	for i := 0; i < 100; i++ {
		monitor.RecordAddressProcessed()
		// Every 10th address is corrupt -> 10% corruption rate.
		if i%10 == 0 {
			monitor.RecordCorruptionDetected(common.HexToAddress(fmt.Sprintf("0x%040d", i)), 50, "test")
		}
		monitor.RecordValidationResult(i%20 != 0)   // 95% success rate
		monitor.RecordContractCallResult(i%10 != 0) // 90% success rate
	}

	summary := monitor.GetHealthSummary()

	// Structure and rates.
	assert.True(t, summary["enabled"].(bool))
	assert.Equal(t, int64(100), summary["total_addresses_processed"].(int64))
	assert.Equal(t, int64(10), summary["corruption_detections"].(int64))
	assert.InDelta(t, 0.1, summary["corruption_rate"].(float64), 0.01)
	assert.InDelta(t, 0.95, summary["validation_success_rate"].(float64), 0.01)
	assert.InDelta(t, 0.9, summary["contract_call_success_rate"].(float64), 0.01)

	// Health should be degraded by the corruption but still reasonable.
	health := summary["health_score"].(float64)
	assert.Greater(t, health, 0.7)
	assert.Less(t, health, 1.0)
}
// TestIntegrityMonitor_AlertThresholds drives the health score below a
// configured minimum and verifies that critical health-score alerts fire.
func TestIntegrityMonitor_AlertThresholds(t *testing.T) {
	log := logger.New("error", "text", "")
	monitor := NewIntegrityMonitor(log)
	subscriber := &MockAlertSubscriber{}
	monitor.AddAlertSubscriber(subscriber)

	monitor.SetThreshold("health_score_min", 0.8)

	// Corrupt every processed address to tank the health score.
	for i := 0; i < 50; i++ {
		monitor.RecordAddressProcessed()
		monitor.RecordCorruptionDetected(common.HexToAddress(fmt.Sprintf("0x%040d", i)), 80, "test")
	}

	// Count critical alerts that carry a health_score context value.
	healthAlerts := 0
	for _, alert := range subscriber.GetAlerts() {
		if alert.Severity == AlertSeverityCritical &&
			alert.Context != nil &&
			alert.Context["health_score"] != nil {
			healthAlerts++
		}
	}
	assert.Greater(t, healthAlerts, 0, "Should have triggered health score alerts")
}
// TestIntegrityMonitor_ConcurrentAccess hammers the monitor from many
// goroutines at once and checks that the final counters add up, exercising
// the monitor's internal synchronization.
func TestIntegrityMonitor_ConcurrentAccess(t *testing.T) {
	log := logger.New("error", "text", "")
	monitor := NewIntegrityMonitor(log)

	const (
		numGoroutines          = 50
		operationsPerGoroutine = 100
	)

	done := make(chan bool, numGoroutines)
	for g := 0; g < numGoroutines; g++ {
		go func(id int) {
			defer func() { done <- true }()
			for op := 0; op < operationsPerGoroutine; op++ {
				monitor.RecordAddressProcessed()
				monitor.RecordValidationResult(op%10 != 0)
				monitor.RecordContractCallResult(op%5 != 0)
				// Occasional corruption with a goroutine-unique address.
				if op%20 == 0 {
					addr := common.HexToAddress(fmt.Sprintf("0x%020d%020d", id, op))
					monitor.RecordCorruptionDetected(addr, 60, fmt.Sprintf("goroutine_%d", id))
				}
				// Recovery actions and error types at their own cadences.
				if op%15 == 0 {
					monitor.RecordRecoveryAction(recovery.ActionRetryWithBackoff)
				}
				if op%25 == 0 {
					monitor.RecordErrorType(recovery.ErrorTypeAddressCorruption)
				}
			}
		}(g)
	}

	// Wait for every worker, bounding each wait with a timeout.
	for i := 0; i < numGoroutines; i++ {
		select {
		case <-done:
		case <-time.After(10 * time.Second):
			t.Fatal("Concurrent test timed out")
		}
	}

	// Final counters must be exact despite concurrent updates.
	metrics := monitor.GetMetrics()
	assert.Equal(t, int64(numGoroutines*operationsPerGoroutine), metrics.TotalAddressesProcessed)
	assert.Greater(t, metrics.CorruptAddressesDetected, int64(0))
	assert.Greater(t, metrics.RetryOperationsTriggered, int64(0))
	assert.GreaterOrEqual(t, metrics.HealthScore, 0.0)
	assert.LessOrEqual(t, metrics.HealthScore, 1.0)

	t.Logf("Final metrics: Processed=%d, Corrupted=%d, Health=%.3f",
		metrics.TotalAddressesProcessed,
		metrics.CorruptAddressesDetected,
		metrics.HealthScore)
}
// TestIntegrityMonitor_DisableEnable checks that recording is a no-op while
// the monitor is disabled and resumes after re-enabling.
func TestIntegrityMonitor_DisableEnable(t *testing.T) {
	log := logger.New("error", "text", "")
	monitor := NewIntegrityMonitor(log)

	// Enabled by default.
	assert.True(t, monitor.IsEnabled())

	monitor.RecordAddressProcessed()
	monitor.RecordValidationResult(true)
	before := monitor.GetMetrics()
	assert.Greater(t, before.TotalAddressesProcessed, int64(0))

	// While disabled, recording calls must not change the counters.
	monitor.Disable()
	assert.False(t, monitor.IsEnabled())
	monitor.RecordAddressProcessed()
	monitor.RecordValidationResult(true)
	during := monitor.GetMetrics()
	assert.Equal(t, before.TotalAddressesProcessed, during.TotalAddressesProcessed)

	// Re-enabling resumes recording.
	monitor.Enable()
	assert.True(t, monitor.IsEnabled())
	monitor.RecordAddressProcessed()
	after := monitor.GetMetrics()
	assert.Greater(t, after.TotalAddressesProcessed, during.TotalAddressesProcessed)
}
// TestIntegrityMonitor_Performance measures the amortized cost of the
// recording APIs over 10k iterations and fails if it exceeds 500µs/op, then
// verifies the counters are exact.
func TestIntegrityMonitor_Performance(t *testing.T) {
	log := logger.New("error", "text", "")
	monitor := NewIntegrityMonitor(log)

	const iterations = 10000

	start := time.Now()
	for i := 0; i < iterations; i++ {
		monitor.RecordAddressProcessed()
		monitor.RecordValidationResult(i%10 != 0)
		monitor.RecordContractCallResult(i%5 != 0)
		if i%100 == 0 {
			monitor.RecordCorruptionDetected(common.HexToAddress(fmt.Sprintf("0x%040d", i)), 50, "benchmark")
		}
	}
	elapsed := time.Since(start)
	perOp := elapsed / iterations

	t.Logf("Performance: %d operations in %v (avg: %v per operation)",
		iterations, elapsed, perOp)

	// 500µs per composite operation is a generous ceiling.
	limit := 500 * time.Microsecond
	assert.Less(t, perOp.Nanoseconds(), limit.Nanoseconds(),
		"Recording should be faster than %v per operation (got %v)", limit, perOp)

	// Counter accuracy: one address per iteration, every 100th corrupt.
	metrics := monitor.GetMetrics()
	assert.Equal(t, int64(iterations), metrics.TotalAddressesProcessed)
	assert.Equal(t, int64(100), metrics.CorruptAddressesDetected)
}

View File

@@ -0,0 +1,494 @@
package ratelimit
import (
"context"
"fmt"
"sync"
"sync/atomic"
"time"
"golang.org/x/time/rate"
"github.com/fraktal/mev-beta/internal/config"
"github.com/fraktal/mev-beta/internal/logger"
)
// AdaptiveRateLimiter implements adaptive rate limiting that adjusts to endpoint capacity
type AdaptiveRateLimiter struct {
	endpoints      map[string]*AdaptiveEndpoint // keyed by endpoint URL
	mu             sync.RWMutex                 // guards the endpoints map
	logger         *logger.Logger
	defaultConfig  config.RateLimitConfig // baseline used to bound adaptive adjustments
	adjustInterval time.Duration          // cadence of adjustmentLoop and minimum gap between per-endpoint adjustments
	stopChan       chan struct{}          // closed by Stop to terminate adjustmentLoop
}
// AdaptiveEndpoint represents an endpoint with adaptive rate limiting
type AdaptiveEndpoint struct {
	URL            string
	limiter        *rate.Limiter          // token bucket; its limit is re-tuned by adjustEndpointRateLimit
	config         config.RateLimitConfig // configuration the endpoint was created with
	circuitBreaker *CircuitBreaker
	metrics        *EndpointMetrics
	healthChecker  *HealthChecker
	lastAdjustment time.Time // last time the limit was changed (read/written under arl.mu in the adjust path)
	// consecutiveErrors / consecutiveSuccess are updated with sync/atomic in
	// RecordResult; each resets the other on the opposite outcome.
	consecutiveErrors  int64
	consecutiveSuccess int64
}
// EndpointMetrics tracks performance metrics for an endpoint
// All fields must be 64-bit aligned for atomic access
type EndpointMetrics struct {
	TotalRequests      int64 // updated with sync/atomic
	SuccessfulRequests int64 // updated with sync/atomic
	FailedRequests     int64 // updated with sync/atomic
	TotalLatency       int64 // nanoseconds, updated with sync/atomic
	LastRequestTime    int64 // unix timestamp, updated with sync/atomic
	// Non-atomic fields - must be protected by mutex when accessed
	mu             sync.RWMutex
	SuccessRate    float64 // derived: SuccessfulRequests / TotalRequests
	AverageLatency float64 // milliseconds, derived: TotalLatency / TotalRequests
}
// CircuitBreaker implements circuit breaker pattern for failed endpoints
// All counters and the state word are accessed with sync/atomic.
type CircuitBreaker struct {
	state        int32 // 0: Closed, 1: Open, 2: HalfOpen (see Circuit* constants)
	failureCount int64 // failures recorded so far; trips the breaker at threshold
	lastFailTime int64 // unix timestamp of the most recent failure
	threshold    int64 // failure count at which the breaker opens
	timeout      time.Duration // How long to wait before trying again
	testRequests int64         // Number of test requests in half-open state
}
// Circuit breaker states
const (
	CircuitClosed   = 0 // normal operation, requests flow through
	CircuitOpen     = 1 // failures exceeded the threshold; requests rejected until the timeout elapses
	CircuitHalfOpen = 2 // probing: a limited number of test requests are allowed
)
// HealthChecker monitors endpoint health
type HealthChecker struct {
	endpoint  string
	interval  time.Duration // how often checkHealth runs
	timeout   time.Duration // per-check context timeout
	isHealthy int64         // atomic bool (1 healthy, 0 unhealthy)
	lastCheck int64         // unix timestamp of the last completed check
	stopChan  chan struct{} // closed by AdaptiveRateLimiter.Stop to end the loop
}
// NewAdaptiveRateLimiter creates a new adaptive rate limiter covering the
// primary RPC endpoint plus every configured reading and execution endpoint,
// and starts the background adjustment loop (stopped via Stop).
func NewAdaptiveRateLimiter(cfg *config.ArbitrumConfig, logger *logger.Logger) *AdaptiveRateLimiter {
	arl := &AdaptiveRateLimiter{
		endpoints:      make(map[string]*AdaptiveEndpoint),
		logger:         logger,
		defaultConfig:  cfg.RateLimit,
		adjustInterval: 30 * time.Second,
		stopChan:       make(chan struct{}),
	}

	// Register the primary endpoint, then each reading/execution endpoint.
	arl.addEndpoint(cfg.RPCEndpoint, cfg.RateLimit)
	for _, ep := range cfg.ReadingEndpoints {
		arl.addEndpoint(ep.URL, ep.RateLimit)
	}
	for _, ep := range cfg.ExecutionEndpoints {
		arl.addEndpoint(ep.URL, ep.RateLimit)
	}

	// Background goroutine that periodically re-tunes per-endpoint limits.
	go arl.adjustmentLoop()

	return arl
}
// addEndpoint registers a new adaptive endpoint (limiter, circuit breaker,
// metrics, health checker) and starts its health-check goroutine.
//
// Fix: the configuration parameter is renamed cfg so it no longer shadows
// the imported config package inside this function.
func (arl *AdaptiveRateLimiter) addEndpoint(url string, cfg config.RateLimitConfig) {
	endpoint := &AdaptiveEndpoint{
		URL:     url,
		limiter: rate.NewLimiter(rate.Limit(cfg.RequestsPerSecond), cfg.Burst),
		config:  cfg,
		circuitBreaker: &CircuitBreaker{
			threshold: 10, // Break after 10 consecutive failures
			timeout:   60 * time.Second,
		},
		metrics: &EndpointMetrics{},
		healthChecker: &HealthChecker{
			endpoint:  url,
			interval:  30 * time.Second,
			timeout:   5 * time.Second,
			isHealthy: 1, // Start assuming healthy
			stopChan:  make(chan struct{}),
		},
	}

	arl.mu.Lock()
	arl.endpoints[url] = endpoint
	arl.mu.Unlock()

	// Per-endpoint health loop; terminated by Stop via its stopChan.
	go endpoint.healthChecker.start()

	arl.logger.Info(fmt.Sprintf("Added adaptive rate limiter for endpoint: %s", url))
}
// WaitForBestEndpoint picks the best-scoring healthy endpoint, blocks on its
// rate limiter, and returns its URL. It returns an error when no endpoint is
// healthy, the chosen endpoint's circuit breaker rejects the request, or the
// context ends while waiting for a token.
func (arl *AdaptiveRateLimiter) WaitForBestEndpoint(ctx context.Context) (string, error) {
	best := arl.getBestEndpoint()
	if best == "" {
		return "", fmt.Errorf("no healthy endpoints available")
	}

	arl.mu.RLock()
	endpoint := arl.endpoints[best]
	arl.mu.RUnlock()
	if endpoint == nil {
		return "", fmt.Errorf("endpoint not found: %s", best)
	}

	// The breaker may have opened between selection and this check.
	if !endpoint.circuitBreaker.canExecute() {
		return "", fmt.Errorf("circuit breaker open for endpoint: %s", best)
	}

	// Block until the endpoint's token bucket grants a slot (or ctx ends).
	if err := endpoint.limiter.Wait(ctx); err != nil {
		return "", err
	}
	return best, nil
}
// RecordResult records the outcome of a request against an endpoint so the
// adaptive adjustment loop and the circuit breaker can react to it. Unknown
// endpoints are ignored.
//
// Fix: derived metrics are refreshed via updateDerivedMetrics, which takes
// the metrics mutex, instead of updateCalculatedMetrics, whose unguarded
// field writes raced with concurrent readers (getCalculatedMetrics).
func (arl *AdaptiveRateLimiter) RecordResult(endpointURL string, success bool, latency time.Duration) {
	arl.mu.RLock()
	endpoint, exists := arl.endpoints[endpointURL]
	arl.mu.RUnlock()
	if !exists {
		return
	}

	// Raw counters are updated atomically; the scoring path reads them the same way.
	atomic.AddInt64(&endpoint.metrics.TotalRequests, 1)
	atomic.AddInt64(&endpoint.metrics.TotalLatency, latency.Nanoseconds())
	atomic.StoreInt64(&endpoint.metrics.LastRequestTime, time.Now().Unix())

	if success {
		atomic.AddInt64(&endpoint.metrics.SuccessfulRequests, 1)
		atomic.AddInt64(&endpoint.consecutiveSuccess, 1)
		atomic.StoreInt64(&endpoint.consecutiveErrors, 0)
		endpoint.circuitBreaker.recordSuccess()
	} else {
		atomic.AddInt64(&endpoint.metrics.FailedRequests, 1)
		atomic.AddInt64(&endpoint.consecutiveErrors, 1)
		atomic.StoreInt64(&endpoint.consecutiveSuccess, 0)
		endpoint.circuitBreaker.recordFailure()
	}

	// Refresh SuccessRate / AverageLatency under the metrics mutex.
	endpoint.metrics.updateDerivedMetrics()
}
// updateCalculatedMetrics updates derived metrics for an endpoint.
//
// Fix: the previous implementation wrote SuccessRate and AverageLatency
// without holding the metrics mutex, which the EndpointMetrics struct
// explicitly requires and which raced with getCalculatedMetrics. It now
// delegates to updateDerivedMetrics, which performs the same computation
// under the lock. Kept as a thin wrapper for existing callers.
func (arl *AdaptiveRateLimiter) updateCalculatedMetrics(endpoint *AdaptiveEndpoint) {
	endpoint.metrics.updateDerivedMetrics()
}
// getBestEndpoint selects the best available endpoint: it skips endpoints
// flagged unhealthy or blocked by their circuit breaker, then returns the URL
// with the highest composite score ("" when none qualify).
func (arl *AdaptiveRateLimiter) getBestEndpoint() string {
	arl.mu.RLock()
	defer arl.mu.RUnlock()

	var (
		bestURL   string
		bestScore = -1.0
	)
	for url, ep := range arl.endpoints {
		// Health checker marked it down.
		if atomic.LoadInt64(&ep.healthChecker.isHealthy) == 0 {
			continue
		}
		// Breaker currently rejects traffic.
		if !ep.circuitBreaker.canExecute() {
			continue
		}
		if score := arl.calculateEndpointScore(ep); score > bestScore {
			bestScore = score
			bestURL = url
		}
	}
	return bestURL
}
// updateDerivedMetrics recomputes SuccessRate and AverageLatency from the
// atomic counters, writing the results under the metrics mutex.
func (em *EndpointMetrics) updateDerivedMetrics() {
	total := atomic.LoadInt64(&em.TotalRequests)
	success := atomic.LoadInt64(&em.SuccessfulRequests)
	latencySum := atomic.LoadInt64(&em.TotalLatency)

	em.mu.Lock()
	defer em.mu.Unlock()

	if total == 0 {
		// No traffic yet: both derived values are zero.
		em.SuccessRate = 0.0
		em.AverageLatency = 0.0
		return
	}
	em.SuccessRate = float64(success) / float64(total)
	em.AverageLatency = float64(latencySum) / float64(total) / 1e6 // ns to ms
}
// getCalculatedMetrics safely returns the derived metrics: success rate in
// [0,1] and average latency in milliseconds.
func (em *EndpointMetrics) getCalculatedMetrics() (successRate, avgLatencyMs float64) {
	em.mu.RLock()
	successRate, avgLatencyMs = em.SuccessRate, em.AverageLatency
	em.mu.RUnlock()
	return successRate, avgLatencyMs
}
// calculateEndpointScore computes a selection score for an endpoint as a
// weighted blend of success rate (0.6), inverted latency (0.3), and a load
// placeholder (0.1).
func (arl *AdaptiveRateLimiter) calculateEndpointScore(endpoint *AdaptiveEndpoint) float64 {
	const (
		successWeight = 0.6
		latencyWeight = 0.3
		loadWeight    = 0.1
	)

	// Refresh the derived metrics, then read them under their lock.
	endpoint.metrics.updateDerivedMetrics()
	successScore, avgLatency := endpoint.metrics.getCalculatedMetrics()

	// Map latency onto [0,1]: 0ms -> 1.0, anything >= 1000ms -> 0.0.
	latencyScore := 1.0
	if avgLatency > 0 {
		latencyScore = 1.0 - avgLatency/1000.0
		if latencyScore < 0 {
			latencyScore = 0
		}
	}

	// Load score is a constant placeholder; the limiter's current token
	// count could be folded in here later.
	loadScore := 1.0

	return successScore*successWeight + latencyScore*latencyWeight + loadScore*loadWeight
}
// adjustmentLoop periodically re-tunes per-endpoint rate limits until Stop
// closes stopChan.
func (arl *AdaptiveRateLimiter) adjustmentLoop() {
	ticker := time.NewTicker(arl.adjustInterval)
	defer ticker.Stop()

	for {
		select {
		case <-arl.stopChan:
			return
		case <-ticker.C:
			arl.adjustRateLimits()
		}
	}
}
// adjustRateLimits sweeps every endpoint and adjusts its rate limit based on
// observed performance. The write lock is held for the whole sweep.
func (arl *AdaptiveRateLimiter) adjustRateLimits() {
	arl.mu.Lock()
	defer arl.mu.Unlock()

	for url, ep := range arl.endpoints {
		arl.adjustEndpointRateLimit(url, ep)
	}
}
// adjustEndpointRateLimit re-tunes one endpoint's limit: +10% when it is
// performing well (>95% success and <500ms average latency), -10% when it is
// struggling (<80% success or >2s latency), clamped to 10%..300% of the
// default configured rate. Adjustments smaller than 5% are skipped, as are
// adjustments within adjustInterval of the previous one.
//
// Fix: derived metrics are read via getCalculatedMetrics (under the metrics
// mutex) instead of directly from the struct fields, which raced with
// updateDerivedMetrics.
func (arl *AdaptiveRateLimiter) adjustEndpointRateLimit(url string, endpoint *AdaptiveEndpoint) {
	// Don't adjust too frequently.
	if time.Since(endpoint.lastAdjustment) < arl.adjustInterval {
		return
	}

	successRate, avgLatency := endpoint.metrics.getCalculatedMetrics()
	currentLimit := float64(endpoint.limiter.Limit())

	newLimit := currentLimit
	adjustmentFactor := 0.1 // 10% adjustment
	if successRate > 0.95 && avgLatency < 500 { // 95% success, < 500ms latency
		newLimit = currentLimit * (1.0 + adjustmentFactor)
	} else if successRate < 0.8 || avgLatency > 2000 { // < 80% success or > 2s latency
		newLimit = currentLimit * (1.0 - adjustmentFactor)
	}

	// Clamp to bounds derived from the default configuration.
	minLimit := float64(arl.defaultConfig.RequestsPerSecond) * 0.1 // 10% of default minimum
	maxLimit := float64(arl.defaultConfig.RequestsPerSecond) * 3.0 // 300% of default maximum
	if newLimit < minLimit {
		newLimit = minLimit
	}
	if newLimit > maxLimit {
		newLimit = maxLimit
	}

	// Apply only changes above a 5% threshold to avoid churn.
	if abs(newLimit-currentLimit)/currentLimit > 0.05 {
		endpoint.limiter.SetLimit(rate.Limit(newLimit))
		endpoint.lastAdjustment = time.Now()
		arl.logger.Info(fmt.Sprintf("Adjusted rate limit for %s: %.2f -> %.2f (success: %.2f%%, latency: %.2fms)",
			url, currentLimit, newLimit, successRate*100, avgLatency))
	}
}
// abs returns the absolute value of x.
func abs(x float64) float64 {
	if x >= 0 {
		return x
	}
	return -x
}
// canExecute reports whether the breaker currently permits a request.
// Closed always allows; Open allows nothing until the cool-down timeout has
// elapsed (then a single CAS winner flips to HalfOpen and may probe);
// HalfOpen allows up to 3 probe requests.
func (cb *CircuitBreaker) canExecute() bool {
	switch atomic.LoadInt32(&cb.state) {
	case CircuitClosed:
		return true

	case CircuitOpen:
		// After the cool-down period, exactly one caller wins the CAS,
		// transitions the breaker to half-open, and is allowed through.
		elapsed := time.Now().Unix() - atomic.LoadInt64(&cb.lastFailTime)
		if elapsed > int64(cb.timeout.Seconds()) {
			if atomic.CompareAndSwapInt32(&cb.state, CircuitOpen, CircuitHalfOpen) {
				atomic.StoreInt64(&cb.testRequests, 0)
				return true
			}
		}
		return false

	case CircuitHalfOpen:
		// Bounded number of probe requests while half-open.
		if atomic.LoadInt64(&cb.testRequests) < 3 {
			atomic.AddInt64(&cb.testRequests, 1)
			return true
		}
		return false
	}
	return false
}
// recordSuccess records a successful request: a success while half-open
// closes the circuit again, and any success resets the failure streak.
//
// Fix: the breaker is documented (and configured in addEndpoint) to trip on
// *consecutive* failures, but failureCount was previously only reset on the
// half-open -> closed transition, so occasional failures accumulated forever
// and eventually tripped the breaker even on a mostly-healthy endpoint.
// Resetting the counter on every success makes the threshold genuinely
// consecutive.
func (cb *CircuitBreaker) recordSuccess() {
	// A successful probe while half-open closes the circuit.
	atomic.CompareAndSwapInt32(&cb.state, CircuitHalfOpen, CircuitClosed)
	// Any success breaks the failure streak.
	atomic.StoreInt64(&cb.failureCount, 0)
}
// recordFailure notes a failed request; once the accumulated failures reach
// the threshold the breaker trips open.
func (cb *CircuitBreaker) recordFailure() {
	atomic.StoreInt64(&cb.lastFailTime, time.Now().Unix())
	if atomic.AddInt64(&cb.failureCount, 1) >= cb.threshold {
		atomic.StoreInt32(&cb.state, CircuitOpen)
	}
}
// start runs the periodic health-check loop until stopChan is closed.
func (hc *HealthChecker) start() {
	ticker := time.NewTicker(hc.interval)
	defer ticker.Stop()

	for {
		select {
		case <-hc.stopChan:
			return
		case <-ticker.C:
			hc.checkHealth()
		}
	}
}
// checkHealth runs one health probe with a bounded timeout and publishes the
// outcome (and the check timestamp) via the atomic flags.
func (hc *HealthChecker) checkHealth() {
	ctx, cancel := context.WithTimeout(context.Background(), hc.timeout)
	defer cancel()

	// Probe the endpoint and store the boolean result as 0/1.
	var flag int64
	if hc.performHealthCheck(ctx) {
		flag = 1
	}
	atomic.StoreInt64(&hc.isHealthy, flag)
	atomic.StoreInt64(&hc.lastCheck, time.Now().Unix())
}
// performHealthCheck performs the actual health check
//
// NOTE(review): this is a stub — it ignores ctx and unconditionally reports
// healthy, so every endpoint is always considered up. A production
// implementation would issue a lightweight RPC call bounded by ctx.
func (hc *HealthChecker) performHealthCheck(ctx context.Context) bool {
	// Simplified health check - in production would make actual RPC call
	// For now, just simulate based on endpoint availability
	return true // Assume healthy for demo
}
// Stop terminates the background adjustment loop and every per-endpoint
// health checker. It must be called at most once: a second call would close
// already-closed channels and panic.
func (arl *AdaptiveRateLimiter) Stop() {
	close(arl.stopChan)

	// Signal each health-check goroutine to exit.
	arl.mu.RLock()
	defer arl.mu.RUnlock()
	for _, ep := range arl.endpoints {
		close(ep.healthChecker.stopChan)
	}
}
// GetMetrics returns a point-in-time snapshot of metrics for all endpoints,
// keyed by endpoint URL.
//
// Fix: the previous version handed out pointers to the live EndpointMetrics
// structs (including their internal mutex), so callers could race with the
// recording paths and accidentally copy the mutex. Each map entry is now an
// independent copy read under the appropriate atomics/locks.
func (arl *AdaptiveRateLimiter) GetMetrics() map[string]*EndpointMetrics {
	arl.mu.RLock()
	defer arl.mu.RUnlock()

	metrics := make(map[string]*EndpointMetrics, len(arl.endpoints))
	for url, endpoint := range arl.endpoints {
		// Refresh the derived values, then copy everything out.
		endpoint.metrics.updateDerivedMetrics()
		successRate, avgLatency := endpoint.metrics.getCalculatedMetrics()
		metrics[url] = &EndpointMetrics{
			TotalRequests:      atomic.LoadInt64(&endpoint.metrics.TotalRequests),
			SuccessfulRequests: atomic.LoadInt64(&endpoint.metrics.SuccessfulRequests),
			FailedRequests:     atomic.LoadInt64(&endpoint.metrics.FailedRequests),
			TotalLatency:       atomic.LoadInt64(&endpoint.metrics.TotalLatency),
			LastRequestTime:    atomic.LoadInt64(&endpoint.metrics.LastRequestTime),
			SuccessRate:        successRate,
			AverageLatency:     avgLatency,
		}
	}
	return metrics
}

View File

@@ -0,0 +1,139 @@
package ratelimit
import (
"context"
"fmt"
"sync"
"golang.org/x/time/rate"
"github.com/fraktal/mev-beta/internal/config"
)
// LimiterManager manages rate limiters for multiple endpoints
type LimiterManager struct {
	limiters map[string]*EndpointLimiter // keyed by endpoint URL
	mu       sync.RWMutex                // guards the limiters map
}
// EndpointLimiter represents a rate limiter for a specific endpoint
type EndpointLimiter struct {
	URL     string
	Limiter *rate.Limiter          // token bucket built from Config
	Config  config.RateLimitConfig // configuration the limiter was built from
}
// NewLimiterManager creates a LimiterManager with one rate limiter per
// configured endpoint: the primary RPC endpoint plus every reading and
// execution endpoint.
//
// Improvement: the identical construction code, previously repeated three
// times, is factored into a local helper.
func NewLimiterManager(cfg *config.ArbitrumConfig) *LimiterManager {
	lm := &LimiterManager{
		limiters: make(map[string]*EndpointLimiter),
	}

	// add registers a limiter for one endpoint URL. Safe without the mutex
	// here: the manager is not yet visible to any other goroutine.
	add := func(url string, rl config.RateLimitConfig) {
		lm.limiters[url] = &EndpointLimiter{
			URL:     url,
			Limiter: createLimiter(rl),
			Config:  rl,
		}
	}

	add(cfg.RPCEndpoint, cfg.RateLimit)
	for _, endpoint := range cfg.ReadingEndpoints {
		add(endpoint.URL, endpoint.RateLimit)
	}
	for _, endpoint := range cfg.ExecutionEndpoints {
		add(endpoint.URL, endpoint.RateLimit)
	}

	return lm
}
// createLimiter builds a token-bucket limiter from the rate-limit
// configuration (requests per second and burst size).
func createLimiter(cfg config.RateLimitConfig) *rate.Limiter {
	return rate.NewLimiter(rate.Limit(cfg.RequestsPerSecond), cfg.Burst)
}
// WaitForLimit blocks until the endpoint's rate limiter permits a request or
// the context ends. It returns an error for unknown endpoints.
func (lm *LimiterManager) WaitForLimit(ctx context.Context, endpointURL string) error {
	lm.mu.RLock()
	el, ok := lm.limiters[endpointURL]
	lm.mu.RUnlock()

	if !ok {
		return fmt.Errorf("no rate limiter found for endpoint: %s", endpointURL)
	}
	return el.Limiter.Wait(ctx)
}
// TryWaitForLimit reports, without blocking, whether a request to the
// endpoint is currently permitted: nil on success, an error for unknown
// endpoints, an exhausted rate limit, or a finished context.
//
// Fix: the ctx parameter was previously accepted but ignored; cancellation
// is now honored so a canceled caller is not reported as a success.
func (lm *LimiterManager) TryWaitForLimit(ctx context.Context, endpointURL string) error {
	if err := ctx.Err(); err != nil {
		return err
	}

	lm.mu.RLock()
	limiter, exists := lm.limiters[endpointURL]
	lm.mu.RUnlock()
	if !exists {
		return fmt.Errorf("no rate limiter found for endpoint: %s", endpointURL)
	}

	// Allow reports immediately whether a token is available.
	if !limiter.Limiter.Allow() {
		return fmt.Errorf("rate limit exceeded for endpoint: %s", endpointURL)
	}
	return nil
}
// GetLimiter returns the rate limiter registered for the endpoint, or an
// error when the endpoint is unknown.
func (lm *LimiterManager) GetLimiter(endpointURL string) (*rate.Limiter, error) {
	lm.mu.RLock()
	el, ok := lm.limiters[endpointURL]
	lm.mu.RUnlock()

	if !ok {
		return nil, fmt.Errorf("no rate limiter found for endpoint: %s", endpointURL)
	}
	return el.Limiter, nil
}
// UpdateLimiter replaces the endpoint's limiter with a fresh one built from
// cfg (registering the endpoint if it was previously unknown).
func (lm *LimiterManager) UpdateLimiter(endpointURL string, cfg config.RateLimitConfig) {
	lm.mu.Lock()
	defer lm.mu.Unlock()

	lm.limiters[endpointURL] = &EndpointLimiter{
		URL:     endpointURL,
		Limiter: createLimiter(cfg),
		Config:  cfg,
	}
}
// GetEndpoints returns the URLs of all registered endpoints. The order is
// unspecified (map iteration order).
func (lm *LimiterManager) GetEndpoints() []string {
	lm.mu.RLock()
	defer lm.mu.RUnlock()

	urls := make([]string, 0, len(lm.limiters))
	for u := range lm.limiters {
		urls = append(urls, u)
	}
	return urls
}

View File

@@ -0,0 +1,243 @@
package ratelimit
import (
"context"
"testing"
"time"
"github.com/stretchr/testify/assert"
"golang.org/x/time/rate"
"github.com/fraktal/mev-beta/internal/config"
)
// TestNewLimiterManager verifies that a limiter is registered for the primary
// endpoint and for every reading and execution endpoint.
//
// Fixes: the old comment claimed "Primary + 1 fallback" for a length-3 map
// (it is primary + reading + execution), and the execution endpoint's
// limiter was constructed but never asserted.
func TestNewLimiterManager(t *testing.T) {
	cfg := &config.ArbitrumConfig{
		RPCEndpoint: "https://arb1.arbitrum.io/rpc",
		RateLimit: config.RateLimitConfig{
			RequestsPerSecond: 10,
			Burst:             20,
		},
		ReadingEndpoints: []config.EndpointConfig{
			{
				URL: "https://read.arbitrum.io/rpc",
				RateLimit: config.RateLimitConfig{
					RequestsPerSecond: 5,
					Burst:             10,
				},
			},
		},
		ExecutionEndpoints: []config.EndpointConfig{
			{
				URL: "https://exec.arbitrum.io/rpc",
				RateLimit: config.RateLimitConfig{
					RequestsPerSecond: 3,
					Burst:             6,
				},
			},
		},
	}

	lm := NewLimiterManager(cfg)

	assert.NotNil(t, lm)
	assert.NotNil(t, lm.limiters)
	assert.Len(t, lm.limiters, 3) // primary + reading + execution

	// Primary endpoint limiter.
	primary, exists := lm.limiters[cfg.RPCEndpoint]
	assert.True(t, exists)
	assert.Equal(t, cfg.RPCEndpoint, primary.URL)
	assert.Equal(t, cfg.RateLimit, primary.Config)
	assert.NotNil(t, primary.Limiter)

	// Reading endpoint limiter.
	reading, exists := lm.limiters[cfg.ReadingEndpoints[0].URL]
	assert.True(t, exists)
	assert.Equal(t, cfg.ReadingEndpoints[0].URL, reading.URL)
	assert.Equal(t, cfg.ReadingEndpoints[0].RateLimit, reading.Config)
	assert.NotNil(t, reading.Limiter)

	// Execution endpoint limiter (previously unchecked).
	exec, exists := lm.limiters[cfg.ExecutionEndpoints[0].URL]
	assert.True(t, exists)
	assert.Equal(t, cfg.ExecutionEndpoints[0].URL, exec.URL)
	assert.Equal(t, cfg.ExecutionEndpoints[0].RateLimit, exec.Config)
	assert.NotNil(t, exec.Limiter)
}
// TestWaitForLimit covers the happy path (known endpoint with burst capacity)
// and the error path (unknown endpoint).
func TestWaitForLimit(t *testing.T) {
	cfg := &config.ArbitrumConfig{
		RPCEndpoint: "https://arb1.arbitrum.io/rpc",
		RateLimit: config.RateLimitConfig{
			RequestsPerSecond: 10,
			Burst:             20,
		},
	}
	lm := NewLimiterManager(cfg)
	ctx := context.Background()

	// Known endpoint: succeeds thanks to burst capacity.
	assert.NoError(t, lm.WaitForLimit(ctx, cfg.RPCEndpoint))

	// Unknown endpoint: reported as an error.
	err := lm.WaitForLimit(ctx, "https://nonexistent.com")
	assert.Error(t, err)
	assert.Contains(t, err.Error(), "no rate limiter found for endpoint")
}
// TestTryWaitForLimit covers the non-blocking happy path (burst token
// available) and the unknown-endpoint error path.
func TestTryWaitForLimit(t *testing.T) {
	cfg := &config.ArbitrumConfig{
		RPCEndpoint: "https://arb1.arbitrum.io/rpc",
		RateLimit: config.RateLimitConfig{
			RequestsPerSecond: 10,
			Burst:             20,
		},
	}
	lm := NewLimiterManager(cfg)
	ctx := context.Background()

	// Burst capacity is available, so the non-blocking check passes.
	assert.NoError(t, lm.TryWaitForLimit(ctx, cfg.RPCEndpoint))

	// Unknown endpoint: reported as an error.
	err := lm.TryWaitForLimit(ctx, "https://nonexistent.com")
	assert.Error(t, err)
	assert.Contains(t, err.Error(), "no rate limiter found for endpoint")
}
// TestGetLimiter checks limiter lookup for a registered endpoint and the
// nil-plus-error result for an unknown one.
func TestGetLimiter(t *testing.T) {
	cfg := &config.ArbitrumConfig{
		RPCEndpoint: "https://arb1.arbitrum.io/rpc",
		RateLimit: config.RateLimitConfig{
			RequestsPerSecond: 10,
			Burst:             20,
		},
	}
	lm := NewLimiterManager(cfg)

	// Registered endpoint yields a *rate.Limiter.
	limiter, err := lm.GetLimiter(cfg.RPCEndpoint)
	assert.NoError(t, err)
	assert.NotNil(t, limiter)
	assert.IsType(t, &rate.Limiter{}, limiter)

	// Unknown endpoint yields nil and a descriptive error.
	missing, err := lm.GetLimiter("https://nonexistent.com")
	assert.Error(t, err)
	assert.Contains(t, err.Error(), "no rate limiter found for endpoint")
	assert.Nil(t, missing)
}
// TestUpdateLimiter verifies that updating an endpoint's configuration
// installs a fresh limiter instance and stores the new config.
func TestUpdateLimiter(t *testing.T) {
	cfg := &config.ArbitrumConfig{
		RPCEndpoint: "https://arb1.arbitrum.io/rpc",
		RateLimit: config.RateLimitConfig{
			RequestsPerSecond: 10,
			Burst:             20,
		},
	}
	lm := NewLimiterManager(cfg)

	before, err := lm.GetLimiter(cfg.RPCEndpoint)
	assert.NoError(t, err)
	assert.NotNil(t, before)

	updated := config.RateLimitConfig{
		RequestsPerSecond: 20,
		Burst:             40,
	}
	lm.UpdateLimiter(cfg.RPCEndpoint, updated)

	after, err := lm.GetLimiter(cfg.RPCEndpoint)
	assert.NoError(t, err)
	assert.NotNil(t, after)

	// Replacement installs a fresh limiter instance...
	assert.NotEqual(t, before, after)
	// ...and records the new configuration.
	assert.Equal(t, updated, lm.limiters[cfg.RPCEndpoint].Config)
}
// TestGetEndpoints checks that every configured endpoint URL (primary plus
// two reading fallbacks) is reported by GetEndpoints.
func TestGetEndpoints(t *testing.T) {
	cfg := &config.ArbitrumConfig{
		RPCEndpoint: "https://arb1.arbitrum.io/rpc",
		RateLimit: config.RateLimitConfig{
			RequestsPerSecond: 10,
			Burst:             20,
		},
		ReadingEndpoints: []config.EndpointConfig{
			{
				URL: "https://fallback1.arbitrum.io/rpc",
				RateLimit: config.RateLimitConfig{
					RequestsPerSecond: 5,
					Burst:             10,
				},
			},
			{
				URL: "https://fallback2.arbitrum.io/rpc",
				RateLimit: config.RateLimitConfig{
					RequestsPerSecond: 3,
					Burst:             6,
				},
			},
		},
	}
	lm := NewLimiterManager(cfg)

	endpoints := lm.GetEndpoints()

	assert.Len(t, endpoints, 3) // primary + 2 fallbacks
	for _, want := range []string{
		cfg.RPCEndpoint,
		cfg.ReadingEndpoints[0].URL,
		cfg.ReadingEndpoints[1].URL,
	} {
		assert.Contains(t, endpoints, want)
	}
}
// TestRateLimiting checks actual pacing: with 1 rps and burst 1 the first
// request passes immediately and the second waits roughly a second for the
// bucket to refill.
func TestRateLimiting(t *testing.T) {
	cfg := &config.ArbitrumConfig{
		RPCEndpoint: "https://arb1.arbitrum.io/rpc",
		RateLimit: config.RateLimitConfig{
			RequestsPerSecond: 1, // 1 request per second
			Burst:             1, // No burst
		},
	}
	lm := NewLimiterManager(cfg)
	ctx := context.Background()

	// First request consumes the single burst token without waiting.
	start := time.Now()
	assert.NoError(t, lm.WaitForLimit(ctx, cfg.RPCEndpoint))
	assert.True(t, time.Since(start) < 100*time.Millisecond, "First request should be fast")

	// Second request must wait for the bucket to refill.
	start = time.Now()
	assert.NoError(t, lm.WaitForLimit(ctx, cfg.RPCEndpoint))
	assert.True(t, time.Since(start) >= time.Second, "Second request should be delayed by rate limiter")
}

View File

@@ -0,0 +1,621 @@
package recovery
import (
"context"
"sync"
"time"
"github.com/ethereum/go-ethereum/common"
"github.com/fraktal/mev-beta/internal/logger"
)
// ErrorSeverity represents the severity level of an error
type ErrorSeverity int

const (
	SeverityLow ErrorSeverity = iota
	SeverityMedium
	SeverityHigh
	SeverityCritical
)

// String returns the log-friendly name of the severity, or "UNKNOWN" for
// values outside the declared range.
func (s ErrorSeverity) String() string {
	names := [...]string{
		SeverityLow:      "LOW",
		SeverityMedium:   "MEDIUM",
		SeverityHigh:     "HIGH",
		SeverityCritical: "CRITICAL",
	}
	if s < 0 || int(s) >= len(names) {
		return "UNKNOWN"
	}
	return names[s]
}
// ErrorType categorizes different types of errors
type ErrorType int

const (
	ErrorTypeAddressCorruption ErrorType = iota
	ErrorTypeContractCallFailed
	ErrorTypeRPCConnectionFailed
	ErrorTypeDataParsingFailed
	ErrorTypeValidationFailed
	ErrorTypeTimeoutError
)

// String returns the log-friendly name of the error type, or "UNKNOWN_ERROR"
// for values outside the declared range.
func (e ErrorType) String() string {
	names := [...]string{
		ErrorTypeAddressCorruption:   "ADDRESS_CORRUPTION",
		ErrorTypeContractCallFailed:  "CONTRACT_CALL_FAILED",
		ErrorTypeRPCConnectionFailed: "RPC_CONNECTION_FAILED",
		ErrorTypeDataParsingFailed:   "DATA_PARSING_FAILED",
		ErrorTypeValidationFailed:    "VALIDATION_FAILED",
		ErrorTypeTimeoutError:        "TIMEOUT_ERROR",
	}
	if e < 0 || int(e) >= len(names) {
		return "UNKNOWN_ERROR"
	}
	return names[e]
}
// RecoveryAction represents an action to take when an error occurs,
// ordered roughly from least to most drastic.
type RecoveryAction int

const (
	ActionSkipAndContinue RecoveryAction = iota
	ActionRetryWithBackoff
	ActionUseFallbackData
	ActionCircuitBreaker
	ActionEmergencyStop
)

// String returns the UPPER_SNAKE label for the recovery action, or
// "UNKNOWN_ACTION" for out-of-range values.
func (a RecoveryAction) String() string {
	labels := [...]string{
		ActionSkipAndContinue:  "SKIP_AND_CONTINUE",
		ActionRetryWithBackoff: "RETRY_WITH_BACKOFF",
		ActionUseFallbackData:  "USE_FALLBACK_DATA",
		ActionCircuitBreaker:   "CIRCUIT_BREAKER",
		ActionEmergencyStop:    "EMERGENCY_STOP",
	}
	if a >= 0 && int(a) < len(labels) {
		return labels[a]
	}
	return "UNKNOWN_ACTION"
}
// ErrorEvent represents a specific error occurrence recorded in the
// handler's bounded history buffer.
type ErrorEvent struct {
	Timestamp    time.Time              // when the error was recorded
	Type         ErrorType              // failure category
	Severity     ErrorSeverity          // how serious the failure is
	Component    string                 // subsystem that reported the error
	Address      common.Address         // on-chain address involved, if any
	Message      string                 // human-readable description
	Context      map[string]interface{} // arbitrary key/value details, matched by RecoveryRule.ContextMatchers
	AttemptCount int                    // number of handling attempts so far
	LastAttempt  time.Time              // time of the most recent attempt
	Resolved     bool                   // whether the error was later resolved
	ResolvedAt   time.Time              // when it was resolved (zero if not)
}
// RecoveryRule defines how to handle specific error patterns.
// A rule matches an error when the error's type equals ErrorType, its
// severity is at most MaxSeverity, and every ContextMatchers entry is
// present with an equal value in the error's context (see findRecoveryRule).
type RecoveryRule struct {
	ErrorType               ErrorType              // error category this rule applies to
	MaxSeverity             ErrorSeverity          // highest severity the rule still matches
	Action                  RecoveryAction         // what to do when the rule matches
	MaxRetries              int                    // advisory retry budget for the action
	BackoffInterval         time.Duration          // advisory delay between retries
	CircuitBreakerThreshold int                    // advisory failure count before tripping
	ContextMatchers         map[string]interface{} // optional extra match conditions
}
// ErrorHandler provides comprehensive error handling and recovery capabilities.
// It records every error event, tracks per-component health, trips circuit
// breakers for repeatedly failing components, and maps errors to recovery
// actions via an ordered rule list (first match wins).
type ErrorHandler struct {
	mu               sync.RWMutex // guards all fields below
	logger           *logger.Logger
	errorHistory     []ErrorEvent               // recent events, capped at maxHistorySize
	componentStats   map[string]*ComponentStats // per-component error/success counters
	circuitBreakers  map[string]*CircuitBreaker // breakers keyed by component name
	recoveryRules    []RecoveryRule             // ordered rules; first match wins
	fallbackProvider FallbackDataProvider       // optional source of last-resort data
	maxHistorySize   int                        // cap on errorHistory length
	alertThresholds  map[ErrorType]int          // hourly count that triggers an alert log
	enabled          bool                       // when false, HandleError is a no-op
}
// ComponentStats tracks error statistics for components.
// Each instance carries its own mutex so it can be updated independently
// of the handler-wide lock.
type ComponentStats struct {
	mu                  sync.RWMutex
	Component           string                // component name
	TotalErrors         int                   // lifetime error count
	ErrorsByType        map[ErrorType]int     // error count per category
	ErrorsBySeverity    map[ErrorSeverity]int // error count per severity
	LastError           time.Time             // timestamp of the most recent error
	ConsecutiveFailures int                   // reset to zero by RecordSuccess
	SuccessCount        int                   // lifetime success count
	IsHealthy           bool                  // cleared after >10 consecutive failures
	LastHealthCheck     time.Time             // set by RecordSuccess
}
// CircuitBreaker implements circuit breaker pattern for failing components.
type CircuitBreaker struct {
	mu              sync.RWMutex
	Name            string        // component the breaker protects
	State           CircuitState  // current position (closed/open/half-open)
	FailureCount    int           // failures observed by this breaker
	Threshold       int           // failures tolerated before opening
	Timeout         time.Duration // how long the breaker stays open
	LastFailure     time.Time     // start of the current open period
	LastSuccess     time.Time     // set when the breaker is reset
	HalfOpenAllowed bool          // whether a probe request may pass
}

// CircuitState is the breaker position: closed (normal operation), open
// (requests blocked), or half-open (probing after the timeout elapsed).
type CircuitState int

const (
	CircuitClosed CircuitState = iota
	CircuitOpen
	CircuitHalfOpen
)

// String returns the upper-case label for the circuit state.
func (s CircuitState) String() string {
	switch s {
	case CircuitClosed:
		return "CLOSED"
	case CircuitOpen:
		return "OPEN"
	case CircuitHalfOpen:
		return "HALF_OPEN"
	default:
		return "UNKNOWN"
	}
}
// FallbackDataProvider interface for providing fallback data when primary
// sources fail. Implementations should return best-effort data with a
// Confidence score rather than failing outright where possible.
type FallbackDataProvider interface {
	GetFallbackTokenInfo(ctx context.Context, address common.Address) (*FallbackTokenInfo, error)
	GetFallbackPoolInfo(ctx context.Context, address common.Address) (*FallbackPoolInfo, error)
	GetFallbackContractType(ctx context.Context, address common.Address) (string, error)
}

// FallbackTokenInfo is last-resort ERC20 metadata for a token address.
type FallbackTokenInfo struct {
	Address    common.Address
	Symbol     string
	Name       string
	Decimals   uint8
	IsVerified bool    // whether the data comes from a verified source
	Source     string  // where the data originated (e.g. "static_fallback")
	Confidence float64 // 0..1 trust score for the data
}

// FallbackPoolInfo is last-resort metadata for a DEX pool address.
type FallbackPoolInfo struct {
	Address    common.Address
	Token0     common.Address
	Token1     common.Address
	Protocol   string // e.g. "UniswapV3"
	Fee        uint32 // pool fee in protocol units (e.g. 500 = 0.05%)
	IsVerified bool
	Source     string
	Confidence float64
}
// NewErrorHandler creates a new error handler with default configuration:
// a 1000-entry history buffer, the built-in recovery rules, and default
// per-type alert thresholds. The handler starts enabled.
func NewErrorHandler(logger *logger.Logger) *ErrorHandler {
	eh := &ErrorHandler{
		logger:          logger,
		errorHistory:    []ErrorEvent{},
		componentStats:  map[string]*ComponentStats{},
		circuitBreakers: map[string]*CircuitBreaker{},
		maxHistorySize:  1000,
		alertThresholds: map[ErrorType]int{},
		enabled:         true,
	}
	eh.initializeDefaultRules()
	eh.initializeAlertThresholds()
	return eh
}
// initializeDefaultRules sets up default recovery rules for common error
// scenarios. Order matters: findRecoveryRule returns the FIRST rule whose
// type matches and whose MaxSeverity is >= the error's severity, so the
// gentler low-severity rules are listed before their high-severity siblings.
func (eh *ErrorHandler) initializeDefaultRules() {
	eh.recoveryRules = []RecoveryRule{
		// Address corruption: retry quickly when mild, go straight to
		// fallback data when critical.
		{
			ErrorType:       ErrorTypeAddressCorruption,
			MaxSeverity:     SeverityMedium,
			Action:          ActionRetryWithBackoff,
			MaxRetries:      2,
			BackoffInterval: 500 * time.Millisecond,
		},
		{
			ErrorType:       ErrorTypeAddressCorruption,
			MaxSeverity:     SeverityCritical,
			Action:          ActionUseFallbackData,
			MaxRetries:      0,
			BackoffInterval: 0,
		},
		// Contract calls are transient failures: retry with backoff.
		{
			ErrorType:       ErrorTypeContractCallFailed,
			MaxSeverity:     SeverityMedium,
			Action:          ActionRetryWithBackoff,
			MaxRetries:      3,
			BackoffInterval: 2 * time.Second,
		},
		// RPC outages affect everything; trip the breaker instead of
		// hammering a dead endpoint.
		{
			ErrorType:               ErrorTypeRPCConnectionFailed,
			MaxSeverity:             SeverityHigh,
			Action:                  ActionCircuitBreaker,
			MaxRetries:              5,
			BackoffInterval:         5 * time.Second,
			CircuitBreakerThreshold: 10,
		},
		// Parsing failures: prefer known-good fallback data.
		{
			ErrorType:       ErrorTypeDataParsingFailed,
			MaxSeverity:     SeverityMedium,
			Action:          ActionUseFallbackData,
			MaxRetries:      2,
			BackoffInterval: 1 * time.Second,
		},
		// Validation: skip low-severity items, retry once when severe.
		{
			ErrorType:       ErrorTypeValidationFailed,
			MaxSeverity:     SeverityLow,
			Action:          ActionSkipAndContinue,
			MaxRetries:      0,
			BackoffInterval: 0,
		},
		{
			ErrorType:       ErrorTypeValidationFailed,
			MaxSeverity:     SeverityHigh,
			Action:          ActionRetryWithBackoff,
			MaxRetries:      1,
			BackoffInterval: 500 * time.Millisecond,
		},
		// Timeouts: retry patiently.
		{
			ErrorType:       ErrorTypeTimeoutError,
			MaxSeverity:     SeverityMedium,
			Action:          ActionRetryWithBackoff,
			MaxRetries:      3,
			BackoffInterval: 3 * time.Second,
		},
	}
}
// initializeAlertThresholds sets up alert thresholds for different error
// types: once the hourly count of an error type reaches its threshold,
// checkAlertThresholds emits an alert log.
func (eh *ErrorHandler) initializeAlertThresholds() {
	defaults := map[ErrorType]int{
		ErrorTypeAddressCorruption:   5,
		ErrorTypeContractCallFailed:  20,
		ErrorTypeRPCConnectionFailed: 10,
		ErrorTypeDataParsingFailed:   15,
		ErrorTypeValidationFailed:    25,
		ErrorTypeTimeoutError:        30,
	}
	for errType, limit := range defaults {
		eh.alertThresholds[errType] = limit
	}
}
// HandleError processes an error and determines the appropriate recovery
// action. It records the event in the bounded history, updates per-component
// statistics, may trip a circuit breaker, and returns the action prescribed
// by the first matching recovery rule (ActionSkipAndContinue when no rule
// matches or handling is disabled).
//
// NOTE: the final parameter was previously named "context", which shadowed
// the imported context package inside the function body; it is renamed to
// errCtx (positional callers are unaffected).
func (eh *ErrorHandler) HandleError(ctx context.Context, errorType ErrorType, severity ErrorSeverity, component string, address common.Address, message string, errCtx map[string]interface{}) RecoveryAction {
	if !eh.enabled {
		return ActionSkipAndContinue
	}

	eh.mu.Lock()
	defer eh.mu.Unlock()

	// Record the error event (Timestamp and LastAttempt share one reading
	// of the clock for consistency).
	now := time.Now()
	event := ErrorEvent{
		Timestamp:    now,
		Type:         errorType,
		Severity:     severity,
		Component:    component,
		Address:      address,
		Message:      message,
		Context:      errCtx,
		AttemptCount: 1,
		LastAttempt:  now,
	}
	eh.addToHistory(event)

	// Keep per-component failure counters up to date.
	eh.updateComponentStats(component, errorType, severity)

	// Trip the breaker before consulting rules: repeated failures on a
	// component override any gentler per-error policy.
	if eh.shouldTriggerCircuitBreaker(component, errorType) {
		eh.triggerCircuitBreaker(component)
		return ActionCircuitBreaker
	}

	rule := eh.findRecoveryRule(errorType, severity, errCtx)
	if rule == nil {
		// No rule matched: skip the item rather than guessing a recovery.
		return ActionSkipAndContinue
	}

	eh.logger.Error("Error handled by recovery system",
		"type", errorType.String(),
		"severity", severity.String(),
		"component", component,
		"address", address.Hex(),
		"message", message,
		"action", rule.Action.String())

	// Fire an alert log if this error type crossed its hourly threshold.
	eh.checkAlertThresholds(errorType)

	return rule.Action
}
// addToHistory appends an error event to the history buffer, dropping the
// oldest entries once the buffer exceeds maxHistorySize.
// Caller must hold eh.mu.
func (eh *ErrorHandler) addToHistory(event ErrorEvent) {
	history := append(eh.errorHistory, event)
	if overflow := len(history) - eh.maxHistorySize; overflow > 0 {
		history = history[overflow:]
	}
	eh.errorHistory = history
}
// updateComponentStats records one more error against a component,
// creating its stats record on first use. Caller must hold eh.mu.
func (eh *ErrorHandler) updateComponentStats(component string, errorType ErrorType, severity ErrorSeverity) {
	stats := eh.componentStats[component]
	if stats == nil {
		stats = &ComponentStats{
			Component:        component,
			ErrorsByType:     make(map[ErrorType]int),
			ErrorsBySeverity: make(map[ErrorSeverity]int),
			IsHealthy:        true,
		}
		eh.componentStats[component] = stats
	}

	stats.mu.Lock()
	defer stats.mu.Unlock()

	stats.TotalErrors++
	stats.ErrorsByType[errorType]++
	stats.ErrorsBySeverity[severity]++
	stats.LastError = time.Now()
	stats.ConsecutiveFailures++
	// More than 10 failures in a row marks the component unhealthy until
	// RecordSuccess clears the streak.
	if stats.ConsecutiveFailures > 10 {
		stats.IsHealthy = false
	}
}
// findRecoveryRule returns the first recovery rule matching the error's
// type, severity, and context, or nil when none applies. A copy of the
// matched rule is returned so callers cannot mutate the rule list.
func (eh *ErrorHandler) findRecoveryRule(errorType ErrorType, severity ErrorSeverity, errContext map[string]interface{}) *RecoveryRule {
	for i := range eh.recoveryRules {
		candidate := eh.recoveryRules[i]
		if candidate.ErrorType != errorType || severity > candidate.MaxSeverity {
			continue
		}
		if len(candidate.ContextMatchers) > 0 && !eh.matchesContext(errContext, candidate.ContextMatchers) {
			continue
		}
		return &candidate
	}
	return nil
}
// matchesContext reports whether every matcher key is present in the error
// context with an equal value (simple == comparison on interface values).
func (eh *ErrorHandler) matchesContext(errorContext, ruleMatchers map[string]interface{}) bool {
	for key, want := range ruleMatchers {
		got, ok := errorContext[key]
		if !ok || got != want {
			return false
		}
	}
	return true
}
// shouldTriggerCircuitBreaker reports whether a component's consecutive
// failure streak warrants opening a circuit breaker. Thresholds differ by
// error class: RPC outages tolerate a few more failures than address
// corruption before tripping.
func (eh *ErrorHandler) shouldTriggerCircuitBreaker(component string, errorType ErrorType) bool {
	stats := eh.componentStats[component]
	if stats == nil {
		return false
	}

	stats.mu.RLock()
	streak := stats.ConsecutiveFailures
	stats.mu.RUnlock()

	switch errorType {
	case ErrorTypeRPCConnectionFailed:
		return streak >= 5
	case ErrorTypeAddressCorruption:
		return streak >= 3
	default:
		return false
	}
}
// triggerCircuitBreaker installs a fresh, open circuit breaker for the
// component with a 30s timeout. Caller must hold eh.mu.
func (eh *ErrorHandler) triggerCircuitBreaker(component string) {
	breaker := &CircuitBreaker{
		Name:        component,
		State:       CircuitOpen,
		Threshold:   5,
		Timeout:     30 * time.Second,
		LastFailure: time.Now(),
	}
	eh.circuitBreakers[component] = breaker

	eh.logger.Warn("Circuit breaker triggered",
		"component", component,
		"timeout", breaker.Timeout)
}
// checkAlertThresholds emits a warning log when the count of a given error
// type over the past hour reaches its configured alert threshold.
func (eh *ErrorHandler) checkAlertThresholds(errorType ErrorType) {
	threshold, ok := eh.alertThresholds[errorType]
	if !ok {
		return
	}

	// Count occurrences of this error type inside the one-hour window.
	windowStart := time.Now().Add(-1 * time.Hour)
	recent := 0
	for _, event := range eh.errorHistory {
		if event.Type == errorType && event.Timestamp.After(windowStart) {
			recent++
		}
	}

	if recent < threshold {
		return
	}
	eh.logger.Warn("Error threshold reached - alert triggered",
		"error_type", errorType.String(),
		"count", recent,
		"threshold", threshold)
	// Hook point for an external alerting system.
}
// GetComponentHealth returns a deep copy of the health statistics for every
// tracked component, safe for the caller to inspect without locking.
//
// Fix: each per-component snapshot is now taken under that ComponentStats'
// own mutex. The previous version read the counters with no per-stats lock,
// racing with updateComponentStats/RecordSuccess (which mutate them under
// stats.mu, not eh.mu alone).
func (eh *ErrorHandler) GetComponentHealth() map[string]*ComponentStats {
	eh.mu.RLock()
	defer eh.mu.RUnlock()

	result := make(map[string]*ComponentStats, len(eh.componentStats))
	for name, stats := range eh.componentStats {
		stats.mu.RLock()
		snapshot := &ComponentStats{
			Component:           stats.Component,
			TotalErrors:         stats.TotalErrors,
			ErrorsByType:        make(map[ErrorType]int, len(stats.ErrorsByType)),
			ErrorsBySeverity:    make(map[ErrorSeverity]int, len(stats.ErrorsBySeverity)),
			LastError:           stats.LastError,
			ConsecutiveFailures: stats.ConsecutiveFailures,
			SuccessCount:        stats.SuccessCount,
			IsHealthy:           stats.IsHealthy,
			LastHealthCheck:     stats.LastHealthCheck,
		}
		// Copy the maps so callers cannot mutate live state.
		for k, v := range stats.ErrorsByType {
			snapshot.ErrorsByType[k] = v
		}
		for k, v := range stats.ErrorsBySeverity {
			snapshot.ErrorsBySeverity[k] = v
		}
		stats.mu.RUnlock()
		result[name] = snapshot
	}
	return result
}
// RecordSuccess records a successful operation for a component: it resets
// the consecutive-failure streak, marks the component healthy, and closes
// any open circuit breaker for it.
func (eh *ErrorHandler) RecordSuccess(component string) {
	eh.mu.Lock()
	defer eh.mu.Unlock()

	stats := eh.componentStats[component]
	if stats == nil {
		stats = &ComponentStats{
			Component:        component,
			ErrorsByType:     make(map[ErrorType]int),
			ErrorsBySeverity: make(map[ErrorSeverity]int),
			IsHealthy:        true,
		}
		eh.componentStats[component] = stats
	}

	stats.mu.Lock()
	stats.SuccessCount++
	stats.ConsecutiveFailures = 0
	stats.IsHealthy = true
	stats.LastHealthCheck = time.Now()
	stats.mu.Unlock()

	// A success also resets the component's circuit breaker, if any.
	if breaker := eh.circuitBreakers[component]; breaker != nil {
		breaker.mu.Lock()
		breaker.State = CircuitClosed
		breaker.FailureCount = 0
		breaker.LastSuccess = time.Now()
		breaker.mu.Unlock()
	}
}
// IsCircuitOpen reports whether a circuit breaker is currently blocking a
// component. When an open breaker's timeout has elapsed it transitions to
// half-open (allowing a probe request) and this returns false.
//
// Fix: the state transition to half-open mutates the breaker, so it is now
// done under breaker.mu.Lock(); the previous version wrote State and
// HalfOpenAllowed while holding only read locks, which is a data race.
func (eh *ErrorHandler) IsCircuitOpen(component string) bool {
	eh.mu.RLock()
	breaker, exists := eh.circuitBreakers[component]
	eh.mu.RUnlock()
	if !exists {
		return false
	}

	breaker.mu.Lock()
	defer breaker.mu.Unlock()

	if breaker.State != CircuitOpen {
		return false
	}
	if time.Since(breaker.LastFailure) > breaker.Timeout {
		// Timeout elapsed: move to half-open and let a probe through.
		breaker.State = CircuitHalfOpen
		breaker.HalfOpenAllowed = true
		return false
	}
	return true
}
// SetFallbackProvider installs the fallback data provider used when
// primary data sources fail.
func (eh *ErrorHandler) SetFallbackProvider(provider FallbackDataProvider) {
	eh.mu.Lock()
	eh.fallbackProvider = provider
	eh.mu.Unlock()
}
// GetErrorSummary aggregates the error history over the trailing window
// and returns counts in total, by type, by severity, and by component.
func (eh *ErrorHandler) GetErrorSummary(duration time.Duration) map[string]interface{} {
	eh.mu.RLock()
	defer eh.mu.RUnlock()

	cutoff := time.Now().Add(-duration)
	total := 0
	byType := make(map[string]int)
	bySeverity := make(map[string]int)
	byComponent := make(map[string]int)

	for _, event := range eh.errorHistory {
		if !event.Timestamp.After(cutoff) {
			continue
		}
		total++
		byType[event.Type.String()]++
		bySeverity[event.Severity.String()]++
		byComponent[event.Component]++
	}

	return map[string]interface{}{
		"total_errors":        total,
		"errors_by_type":      byType,
		"errors_by_severity":  bySeverity,
		"errors_by_component": byComponent,
		"time_range":          duration.String(),
	}
}

View File

@@ -0,0 +1,384 @@
package recovery
import (
"context"
"fmt"
"sync"
"time"
"github.com/ethereum/go-ethereum/common"
"github.com/fraktal/mev-beta/internal/logger"
"github.com/fraktal/mev-beta/internal/registry"
)
// DefaultFallbackProvider implements FallbackDataProvider with multiple data
// sources, consulted in order: an in-memory static table of verified
// Arbitrum contracts, then an optional contract registry, then (for tokens
// only) a generated low-confidence placeholder.
type DefaultFallbackProvider struct {
	mu               sync.RWMutex // guards the static maps and enabled flag
	logger           *logger.Logger
	contractRegistry *registry.ContractRegistry // optional secondary source; may be nil
	staticTokenData  map[common.Address]*FallbackTokenInfo
	staticPoolData   map[common.Address]*FallbackPoolInfo
	cacheTimeout     time.Duration // reported in GetStats; not otherwise used in this file
	enabled          bool          // when false, all Get* calls return errors
}
// NewDefaultFallbackProvider creates a fallback data provider seeded with
// verified static data for major Arbitrum tokens and pools. The provider
// starts enabled; contractRegistry may be nil.
func NewDefaultFallbackProvider(logger *logger.Logger, contractRegistry *registry.ContractRegistry) *DefaultFallbackProvider {
	fp := &DefaultFallbackProvider{
		logger:           logger,
		contractRegistry: contractRegistry,
		staticTokenData:  map[common.Address]*FallbackTokenInfo{},
		staticPoolData:   map[common.Address]*FallbackPoolInfo{},
		cacheTimeout:     5 * time.Minute,
		enabled:          true,
	}
	fp.initializeStaticData()
	return fp
}
// initializeStaticData populates the provider with known good data for
// critical Arbitrum contracts so lookups can succeed even when every live
// source fails. All entries are verified, sourced "static_fallback", and
// carry full confidence; that common boilerplate is applied in one place
// instead of being repeated per literal as before.
func (fp *DefaultFallbackProvider) initializeStaticData() {
	fp.mu.Lock()
	defer fp.mu.Unlock()

	// Token addresses reused below for the pool pairs.
	weth := common.HexToAddress("0x82aF49447D8a07e3bd95BD0d56f35241523fBab1")
	usdc := common.HexToAddress("0xaf88d065e77c8cC2239327C5EDb3A432268e5831")
	usdt := common.HexToAddress("0xFd086bC7CD5C481DCC9C85ebE478A1C0b69FCbb9")
	wbtc := common.HexToAddress("0x2f2a2543B76A4166549F7aaB2e75Bef0aefC5B0f")
	arb := common.HexToAddress("0x912CE59144191C1204E64559FE8253a0e49E6548")

	// Major Arbitrum tokens with verified addresses.
	tokens := []FallbackTokenInfo{
		{Address: weth, Symbol: "WETH", Name: "Wrapped Ether", Decimals: 18},
		{Address: usdc, Symbol: "USDC", Name: "USD Coin", Decimals: 6},
		{Address: usdt, Symbol: "USDT", Name: "Tether USD", Decimals: 6},
		{Address: wbtc, Symbol: "WBTC", Name: "Wrapped BTC", Decimals: 8},
		{Address: arb, Symbol: "ARB", Name: "Arbitrum", Decimals: 18},
	}
	for i := range tokens {
		token := tokens[i] // copy so each map entry gets its own struct
		token.IsVerified = true
		token.Source = "static_fallback"
		token.Confidence = 1.0
		fp.staticTokenData[token.Address] = &token
	}

	// High-volume Uniswap V3 pools with verified addresses and token pairs.
	pools := []FallbackPoolInfo{
		{Address: common.HexToAddress("0xC6962004f452bE9203591991D15f6b388e09E8D0"), Token0: weth, Token1: usdc, Protocol: "UniswapV3", Fee: 500},  // WETH/USDC 0.05%
		{Address: common.HexToAddress("0x641C00A822e8b671738d32a431a4Fb6074E5c79d"), Token0: weth, Token1: usdc, Protocol: "UniswapV3", Fee: 3000}, // WETH/USDC 0.3%
		{Address: common.HexToAddress("0x17c14D2c404D167802b16C450d3c99F88F2c4F4d"), Token0: usdc, Token1: usdt, Protocol: "UniswapV3", Fee: 100},  // USDC/USDT 0.01%
		{Address: common.HexToAddress("0x2f5e87C032bc4F8526F320c012A4e678F1fa6cAB"), Token0: wbtc, Token1: weth, Protocol: "UniswapV3", Fee: 500},  // WBTC/WETH 0.05%
	}
	for i := range pools {
		pool := pools[i] // copy so each map entry gets its own struct
		pool.IsVerified = true
		pool.Source = "static_fallback"
		pool.Confidence = 1.0
		fp.staticPoolData[pool.Address] = &pool
	}

	fp.logger.Info("Initialized fallback provider with static data",
		"tokens", len(fp.staticTokenData),
		"pools", len(fp.staticPoolData))
}
// GetFallbackTokenInfo provides fallback token information, consulting
// sources in order of trust: the static table, then the contract registry,
// then a generated low-confidence placeholder (never an error for unknown
// tokens while the provider is enabled).
func (fp *DefaultFallbackProvider) GetFallbackTokenInfo(ctx context.Context, address common.Address) (*FallbackTokenInfo, error) {
	if !fp.enabled {
		return nil, fmt.Errorf("fallback provider disabled")
	}

	fp.mu.RLock()
	defer fp.mu.RUnlock()

	// 1) Static table of verified tokens.
	if info, ok := fp.staticTokenData[address]; ok {
		fp.logger.Debug("Fallback token info from static data",
			"address", address.Hex(),
			"symbol", info.Symbol,
			"source", info.Source)
		return info, nil
	}

	// 2) Contract registry, when one is configured.
	if fp.contractRegistry != nil {
		contractInfo, err := fp.contractRegistry.GetContractInfo(ctx, address)
		if err == nil && contractInfo != nil {
			info := &FallbackTokenInfo{
				Address:    address,
				Symbol:     contractInfo.Symbol,
				Name:       contractInfo.Name,
				Decimals:   contractInfo.Decimals,
				IsVerified: contractInfo.IsVerified,
				Source:     "contract_registry",
				Confidence: contractInfo.Confidence,
			}
			fp.logger.Debug("Fallback token info from registry",
				"address", address.Hex(),
				"symbol", info.Symbol,
				"confidence", info.Confidence)
			return info, nil
		}
	}

	// 3) Minimal safe placeholder for a completely unknown token.
	generated := &FallbackTokenInfo{
		Address:    address,
		Symbol:     fmt.Sprintf("UNK_%s", address.Hex()[:8]),
		Name:       "Unknown Token",
		Decimals:   18, // common EVM default
		IsVerified: false,
		Source:     "generated_fallback",
		Confidence: 0.1,
	}
	fp.logger.Warn("Using generated fallback token info",
		"address", address.Hex(),
		"symbol", generated.Symbol)
	return generated, nil
}
// GetFallbackPoolInfo provides fallback pool information from the static
// table or the contract registry. Unlike tokens, pools have no safe
// synthetic fallback, so unknown pools return an error.
func (fp *DefaultFallbackProvider) GetFallbackPoolInfo(ctx context.Context, address common.Address) (*FallbackPoolInfo, error) {
	if !fp.enabled {
		return nil, fmt.Errorf("fallback provider disabled")
	}

	fp.mu.RLock()
	defer fp.mu.RUnlock()

	// Static table takes precedence over the registry.
	if staticInfo, ok := fp.staticPoolData[address]; ok {
		fp.logger.Debug("Fallback pool info from static data",
			"address", address.Hex(),
			"protocol", staticInfo.Protocol,
			"token0", staticInfo.Token0.Hex(),
			"token1", staticInfo.Token1.Hex())
		return staticInfo, nil
	}

	if fp.contractRegistry != nil {
		if registryInfo := fp.contractRegistry.GetPoolInfo(address); registryInfo != nil {
			converted := &FallbackPoolInfo{
				Address:    address,
				Token0:     registryInfo.Token0,
				Token1:     registryInfo.Token1,
				Protocol:   registryInfo.Protocol,
				Fee:        registryInfo.Fee,
				IsVerified: registryInfo.IsVerified,
				Source:     "contract_registry",
				Confidence: registryInfo.Confidence,
			}
			fp.logger.Debug("Fallback pool info from registry",
				"address", address.Hex(),
				"protocol", converted.Protocol,
				"confidence", converted.Confidence)
			return converted, nil
		}
	}

	return nil, fmt.Errorf("no fallback data available for pool %s", address.Hex())
}
// GetFallbackContractType classifies an address as "ERC20", "Pool", a
// registry-reported type, or "Unknown" when no source recognizes it.
func (fp *DefaultFallbackProvider) GetFallbackContractType(ctx context.Context, address common.Address) (string, error) {
	if !fp.enabled {
		return "", fmt.Errorf("fallback provider disabled")
	}

	fp.mu.RLock()
	defer fp.mu.RUnlock()

	// Static tables first: they are authoritative for known addresses.
	if _, isToken := fp.staticTokenData[address]; isToken {
		return "ERC20", nil
	}
	if _, isPool := fp.staticPoolData[address]; isPool {
		return "Pool", nil
	}

	// Fall back to the registry's classification, if available.
	if fp.contractRegistry != nil {
		contractInfo, err := fp.contractRegistry.GetContractInfo(ctx, address)
		if err == nil && contractInfo != nil {
			return contractInfo.Type.String(), nil
		}
	}

	return "Unknown", nil
}
// AddStaticTokenData registers (or replaces) static token data for
// fallback use at the given address.
func (fp *DefaultFallbackProvider) AddStaticTokenData(address common.Address, info *FallbackTokenInfo) {
	fp.mu.Lock()
	fp.staticTokenData[address] = info
	fp.mu.Unlock()

	fp.logger.Debug("Added static token data",
		"address", address.Hex(),
		"symbol", info.Symbol)
}
// AddStaticPoolData registers (or replaces) static pool data for fallback
// use at the given address.
func (fp *DefaultFallbackProvider) AddStaticPoolData(address common.Address, info *FallbackPoolInfo) {
	fp.mu.Lock()
	fp.staticPoolData[address] = info
	fp.mu.Unlock()

	fp.logger.Debug("Added static pool data",
		"address", address.Hex(),
		"protocol", info.Protocol)
}
// IsAddressKnown reports whether the address appears in either static
// fallback table (token or pool).
func (fp *DefaultFallbackProvider) IsAddressKnown(address common.Address) bool {
	fp.mu.RLock()
	defer fp.mu.RUnlock()

	if _, ok := fp.staticTokenData[address]; ok {
		return true
	}
	_, ok := fp.staticPoolData[address]
	return ok
}
// GetKnownAddresses returns every address in the static fallback tables,
// split into token addresses and pool addresses (order is unspecified).
func (fp *DefaultFallbackProvider) GetKnownAddresses() (tokens []common.Address, pools []common.Address) {
	fp.mu.RLock()
	defer fp.mu.RUnlock()

	for tokenAddr := range fp.staticTokenData {
		tokens = append(tokens, tokenAddr)
	}
	for poolAddr := range fp.staticPoolData {
		pools = append(pools, poolAddr)
	}
	return tokens, pools
}
// ValidateAddressWithFallback validates an address against the static
// fallback data. Known addresses of the expected type pass with full
// confidence; a type mismatch fails; unknown addresses pass with low
// (0.3) confidence since nothing contradicts them.
func (fp *DefaultFallbackProvider) ValidateAddressWithFallback(ctx context.Context, address common.Address, expectedType string) (bool, float64, error) {
	if !fp.enabled {
		return false, 0.0, fmt.Errorf("fallback provider disabled")
	}

	if !fp.IsAddressKnown(address) {
		// Unknown address: allow it, but flag the low confidence.
		return true, 0.3, nil
	}

	actualType, err := fp.GetFallbackContractType(ctx, address)
	if err != nil {
		return false, 0.0, err
	}
	if actualType != expectedType {
		return false, 0.0, fmt.Errorf("type mismatch: expected %s, got %s", expectedType, actualType)
	}
	return true, 1.0, nil
}
// GetStats returns a snapshot of the fallback provider's configuration
// and table sizes.
func (fp *DefaultFallbackProvider) GetStats() map[string]interface{} {
	fp.mu.RLock()
	defer fp.mu.RUnlock()

	stats := make(map[string]interface{}, 5)
	stats["enabled"] = fp.enabled
	stats["static_tokens_count"] = len(fp.staticTokenData)
	stats["static_pools_count"] = len(fp.staticPoolData)
	stats["cache_timeout"] = fp.cacheTimeout.String()
	stats["has_registry"] = fp.contractRegistry != nil
	return stats
}
// Enable turns the fallback provider on; subsequent Get*/Validate* calls
// will serve data again.
func (fp *DefaultFallbackProvider) Enable() {
	fp.mu.Lock()
	defer fp.mu.Unlock()
	fp.enabled = true
	fp.logger.Info("Fallback provider enabled")
}
// Disable turns the fallback provider off; subsequent Get*/Validate* calls
// return a "fallback provider disabled" error until Enable is called.
func (fp *DefaultFallbackProvider) Disable() {
	fp.mu.Lock()
	defer fp.mu.Unlock()
	fp.enabled = false
	fp.logger.Info("Fallback provider disabled")
}

View File

@@ -0,0 +1,446 @@
package recovery
import (
"context"
"math"
"sync"
"time"
"github.com/fraktal/mev-beta/internal/logger"
)
// RetryConfig defines retry behavior configuration for one operation type.
type RetryConfig struct {
	MaxAttempts       int           // total tries, including the first
	InitialDelay      time.Duration // backoff before the second attempt
	MaxDelay          time.Duration // upper bound on any single backoff
	BackoffFactor     float64       // multiplier applied per attempt
	JitterEnabled     bool          // randomize each delay by +/-10%
	TimeoutPerAttempt time.Duration // per-attempt context deadline
}

// DefaultRetryConfig returns a sensible default retry configuration: three
// attempts, 1s initial delay doubling up to a 30s cap, jitter enabled, and
// a 10s timeout per attempt.
func DefaultRetryConfig() RetryConfig {
	cfg := RetryConfig{
		MaxAttempts:       3,
		InitialDelay:      time.Second,
		MaxDelay:          30 * time.Second,
		BackoffFactor:     2.0,
		TimeoutPerAttempt: 10 * time.Second,
	}
	cfg.JitterEnabled = true
	return cfg
}
// RetryableOperation represents an operation that can be retried.
// attempt is 1-based; ctx carries the per-attempt timeout.
type RetryableOperation func(ctx context.Context, attempt int) error

// RetryHandler provides exponential backoff retry capabilities with
// per-operation-type configuration and statistics.
type RetryHandler struct {
	mu      sync.RWMutex // guards configs and the stats map
	logger  *logger.Logger
	configs map[string]RetryConfig // per-operation-type retry settings
	stats   map[string]*RetryStats // per-operation-type outcome counters
	enabled bool                   // when false, operations run exactly once
}
// RetryStats tracks retry statistics for one operation type.
// Each instance carries its own mutex so it can be updated without holding
// the handler-wide lock.
type RetryStats struct {
	mu                sync.RWMutex
	OperationType     string    // key this record belongs to
	TotalAttempts     int       // every individual attempt, success or not
	SuccessfulRetries int       // operations that eventually succeeded
	FailedRetries     int       // operations that exhausted their attempts
	AverageAttempts   float64   // TotalAttempts / (successes + failures)
	LastAttempt       time.Time // time of the most recent attempt
	LastSuccess       time.Time // time of the most recent success
	LastFailure       time.Time // time of the most recent final failure
}

// RetryResult contains the outcome of one ExecuteWithRetry call.
type RetryResult struct {
	Success       bool          // whether any attempt succeeded
	Attempts      int           // attempts actually made
	TotalDuration time.Duration // wall time spent across all attempts
	LastError     error         // error from the final attempt (nil on success)
	LastAttemptAt time.Time     // when the result was produced
}
// NewRetryHandler creates a retry handler pre-loaded with default
// configurations for common operation types. It starts enabled.
func NewRetryHandler(logger *logger.Logger) *RetryHandler {
	rh := &RetryHandler{
		logger:  logger,
		configs: map[string]RetryConfig{},
		stats:   map[string]*RetryStats{},
		enabled: true,
	}
	rh.initializeDefaultConfigs()
	return rh
}
// initializeDefaultConfigs sets up default retry configurations keyed by
// operation type. Operation types not listed here fall back to
// DefaultRetryConfig via getConfig.
func (rh *RetryHandler) initializeDefaultConfigs() {
	// Contract call retries - moderate backoff
	rh.configs["contract_call"] = RetryConfig{
		MaxAttempts:       3,
		InitialDelay:      500 * time.Millisecond,
		MaxDelay:          5 * time.Second,
		BackoffFactor:     2.0,
		JitterEnabled:     true,
		TimeoutPerAttempt: 10 * time.Second,
	}
	// RPC connection retries - aggressive backoff
	rh.configs["rpc_connection"] = RetryConfig{
		MaxAttempts:       5,
		InitialDelay:      1 * time.Second,
		MaxDelay:          30 * time.Second,
		BackoffFactor:     2.5,
		JitterEnabled:     true,
		TimeoutPerAttempt: 15 * time.Second,
	}
	// Data parsing retries - quick retries (deterministic, no jitter)
	rh.configs["data_parsing"] = RetryConfig{
		MaxAttempts:       2,
		InitialDelay:      100 * time.Millisecond,
		MaxDelay:          1 * time.Second,
		BackoffFactor:     2.0,
		JitterEnabled:     false,
		TimeoutPerAttempt: 5 * time.Second,
	}
	// Block processing retries - conservative
	rh.configs["block_processing"] = RetryConfig{
		MaxAttempts:       3,
		InitialDelay:      2 * time.Second,
		MaxDelay:          10 * time.Second,
		BackoffFactor:     2.0,
		JitterEnabled:     true,
		TimeoutPerAttempt: 30 * time.Second,
	}
	// Token metadata retries - patient backoff
	rh.configs["token_metadata"] = RetryConfig{
		MaxAttempts:       4,
		InitialDelay:      1 * time.Second,
		MaxDelay:          20 * time.Second,
		BackoffFactor:     2.0,
		JitterEnabled:     true,
		TimeoutPerAttempt: 15 * time.Second,
	}
}
// ExecuteWithRetry executes an operation with retry logic. The operation is
// invoked up to MaxAttempts times (per the config registered for
// operationType), each attempt under its own timeout context, with
// exponential backoff between attempts. Per-type statistics are updated as
// it goes. If retries are disabled the operation runs exactly once.
//
// Fixes over the previous version:
//   - a parent-context cancellation during the backoff wait now actually
//     stops the retry loop (the old `break` inside `select` only exited
//     the select statement, permitting one extra attempt);
//   - the failure result reports the number of attempts actually made,
//     instead of always MaxAttempts even when the loop exited early.
func (rh *RetryHandler) ExecuteWithRetry(ctx context.Context, operationType string, operation RetryableOperation) *RetryResult {
	if !rh.enabled {
		// Retries disabled: run once and report the raw outcome.
		err := operation(ctx, 1)
		return &RetryResult{
			Success:       err == nil,
			Attempts:      1,
			TotalDuration: 0,
			LastError:     err,
			LastAttemptAt: time.Now(),
		}
	}

	config := rh.getConfig(operationType)
	start := time.Now()
	var lastError error

	// Lazily create the stats record for this operation type.
	rh.mu.Lock()
	stats, exists := rh.stats[operationType]
	if !exists {
		stats = &RetryStats{OperationType: operationType}
		rh.stats[operationType] = stats
	}
	rh.mu.Unlock()

	attemptsMade := 0

retryLoop:
	for attempt := 1; attempt <= config.MaxAttempts; attempt++ {
		attemptsMade = attempt

		// Each attempt gets its own timeout derived from the caller's ctx.
		attemptCtx, cancel := context.WithTimeout(ctx, config.TimeoutPerAttempt)
		rh.logger.Debug("Attempting operation with retry",
			"operation", operationType,
			"attempt", attempt,
			"max_attempts", config.MaxAttempts)
		err := operation(attemptCtx, attempt)
		cancel()

		stats.mu.Lock()
		stats.TotalAttempts++
		stats.LastAttempt = time.Now()
		stats.mu.Unlock()

		if err == nil {
			duration := time.Since(start)
			stats.mu.Lock()
			stats.SuccessfulRetries++
			stats.LastSuccess = time.Now()
			if denominator := stats.SuccessfulRetries + stats.FailedRetries; denominator > 0 {
				stats.AverageAttempts = float64(stats.TotalAttempts) / float64(denominator)
			}
			stats.mu.Unlock()

			rh.logger.Debug("Operation succeeded",
				"operation", operationType,
				"attempt", attempt,
				"duration", duration)
			return &RetryResult{
				Success:       true,
				Attempts:      attempt,
				TotalDuration: duration,
				LastError:     nil,
				LastAttemptAt: time.Now(),
			}
		}
		lastError = err

		// Stop retrying once the parent context is cancelled.
		if ctx.Err() != nil {
			rh.logger.Debug("Operation cancelled by context",
				"operation", operationType,
				"attempt", attempt,
				"error", ctx.Err())
			break
		}

		if attempt >= config.MaxAttempts {
			rh.logger.Warn("Operation failed after all retries",
				"operation", operationType,
				"attempts", attempt,
				"error", err)
			break
		}

		delay := rh.calculateDelay(config, attempt)
		rh.logger.Debug("Operation failed, retrying",
			"operation", operationType,
			"attempt", attempt,
			"error", err,
			"delay", delay)
		select {
		case <-time.After(delay):
			// Backoff elapsed; proceed to the next attempt.
		case <-ctx.Done():
			// Cancelled while waiting: abandon the remaining attempts.
			break retryLoop
		}
	}

	// All attempts failed (or the loop was cancelled).
	duration := time.Since(start)
	stats.mu.Lock()
	stats.FailedRetries++
	stats.LastFailure = time.Now()
	if denominator := stats.SuccessfulRetries + stats.FailedRetries; denominator > 0 {
		stats.AverageAttempts = float64(stats.TotalAttempts) / float64(denominator)
	}
	stats.mu.Unlock()

	return &RetryResult{
		Success:       false,
		Attempts:      attemptsMade,
		TotalDuration: duration,
		LastError:     lastError,
		LastAttemptAt: time.Now(),
	}
}
// calculateDelay computes the backoff before the next retry: exponential in
// the attempt number, capped at MaxDelay, optionally jittered by +/-10% to
// de-synchronize concurrent retriers. Never returns a negative duration.
func (rh *RetryHandler) calculateDelay(config RetryConfig, attempt int) time.Duration {
	backoff := float64(config.InitialDelay) * math.Pow(config.BackoffFactor, float64(attempt-1))
	if capped := float64(config.MaxDelay); backoff > capped {
		backoff = capped
	}
	delay := time.Duration(backoff)

	if config.JitterEnabled {
		delay += time.Duration(float64(delay) * 0.1 * (2*rh.randomFloat() - 1))
	}

	// Guard against a pathological negative result.
	if delay < 0 {
		delay = config.InitialDelay
	}
	return delay
}
// randomFloat returns a pseudo-random float64 in [0, 1).
//
// It derives its value from the wall-clock nanosecond counter. The previous
// implementation used UnixNano()%1000, which produces only 1000 distinct
// values and is strongly correlated for calls made close together in time
// (back-to-back jitter calculations got nearly identical jitter). Mixing the
// full 64-bit counter with an xorshift step decorrelates nearby timestamps.
//
// NOTE(review): this is NOT cryptographically secure; it is only suitable
// for retry jitter. Use crypto/rand for anything security sensitive.
func (rh *RetryHandler) randomFloat() float64 {
	x := uint64(time.Now().UnixNano())
	// xorshift64 mixing step (Marsaglia) to spread consecutive timestamps.
	x ^= x << 13
	x ^= x >> 7
	x ^= x << 17
	// Keep the top 53 bits so the result is uniform over [0, 1).
	return float64(x>>11) / float64(1<<53)
}
// getConfig looks up the retry configuration registered for operationType,
// falling back to the package default when none has been set.
func (rh *RetryHandler) getConfig(operationType string) RetryConfig {
	rh.mu.RLock()
	defer rh.mu.RUnlock()
	cfg, ok := rh.configs[operationType]
	if !ok {
		return DefaultRetryConfig()
	}
	return cfg
}
// SetConfig registers a custom retry configuration for the given operation
// type, replacing any previously registered configuration.
func (rh *RetryHandler) SetConfig(operationType string, config RetryConfig) {
	rh.mu.Lock()
	rh.configs[operationType] = config
	rh.mu.Unlock()
	rh.logger.Debug("Set retry config",
		"operation", operationType,
		"max_attempts", config.MaxAttempts,
		"initial_delay", config.InitialDelay,
		"max_delay", config.MaxDelay)
}
// GetStats returns a snapshot of retry statistics for every operation type.
// Each entry is a field-by-field copy, so callers cannot mutate the
// handler's internal state through the result.
func (rh *RetryHandler) GetStats() map[string]*RetryStats {
	rh.mu.RLock()
	defer rh.mu.RUnlock()
	snapshot := make(map[string]*RetryStats, len(rh.stats))
	for name, s := range rh.stats {
		s.mu.RLock()
		clone := RetryStats{
			OperationType:     s.OperationType,
			TotalAttempts:     s.TotalAttempts,
			SuccessfulRetries: s.SuccessfulRetries,
			FailedRetries:     s.FailedRetries,
			AverageAttempts:   s.AverageAttempts,
			LastAttempt:       s.LastAttempt,
			LastSuccess:       s.LastSuccess,
			LastFailure:       s.LastFailure,
		}
		s.mu.RUnlock()
		snapshot[name] = &clone
	}
	return snapshot
}
// GetOperationStats returns a copy of the statistics for a single operation
// type, or nil when that operation has never been executed.
func (rh *RetryHandler) GetOperationStats(operationType string) *RetryStats {
	rh.mu.RLock()
	defer rh.mu.RUnlock()
	s, ok := rh.stats[operationType]
	if !ok {
		return nil
	}
	s.mu.RLock()
	defer s.mu.RUnlock()
	snap := RetryStats{
		OperationType:     s.OperationType,
		TotalAttempts:     s.TotalAttempts,
		SuccessfulRetries: s.SuccessfulRetries,
		FailedRetries:     s.FailedRetries,
		AverageAttempts:   s.AverageAttempts,
		LastAttempt:       s.LastAttempt,
		LastSuccess:       s.LastSuccess,
		LastFailure:       s.LastFailure,
	}
	return &snap
}
// ResetStats discards all accumulated retry statistics.
func (rh *RetryHandler) ResetStats() {
	rh.mu.Lock()
	rh.stats = make(map[string]*RetryStats)
	rh.mu.Unlock()
	rh.logger.Info("Reset retry statistics")
}
// Enable turns the retry handler on; subsequent operations will be retried.
func (rh *RetryHandler) Enable() {
	rh.mu.Lock()
	rh.enabled = true
	rh.mu.Unlock()
	rh.logger.Info("Retry handler enabled")
}
// Disable turns the retry handler off; operations will run exactly once.
func (rh *RetryHandler) Disable() {
	rh.mu.Lock()
	rh.enabled = false
	rh.mu.Unlock()
	rh.logger.Info("Retry handler disabled")
}
// IsEnabled reports whether retries are currently active.
func (rh *RetryHandler) IsEnabled() bool {
	rh.mu.RLock()
	enabled := rh.enabled
	rh.mu.RUnlock()
	return enabled
}
// GetHealthSummary aggregates retry statistics into a coarse health report.
//
// An operation type counts as healthy when at least 90% of its executions
// ultimately succeeded and it needed on average no more than two attempts.
//
// The returned map contains:
//   - "enabled":              whether the retry handler is active
//   - "total_operations":     number of distinct operation types observed
//   - "healthy_operations":   count of healthy operation types
//   - "unhealthy_operations": count of unhealthy operation types
//   - "operation_details":    per-operation map with success_rate,
//     average_attempts, total_operations, is_healthy, last_success,
//     last_failure
func (rh *RetryHandler) GetHealthSummary() map[string]interface{} {
	stats := rh.GetStats()
	healthy := 0
	unhealthy := 0
	details := make(map[string]interface{}, len(stats))
	for opType, opStats := range stats {
		total := opStats.SuccessfulRetries + opStats.FailedRetries
		successRate := 0.0
		if total > 0 {
			successRate = float64(opStats.SuccessfulRetries) / float64(total)
		}
		isHealthy := successRate >= 0.9 && opStats.AverageAttempts <= 2.0
		if isHealthy {
			healthy++
		} else {
			unhealthy++
		}
		details[opType] = map[string]interface{}{
			"success_rate":     successRate,
			"average_attempts": opStats.AverageAttempts,
			"total_operations": total,
			"is_healthy":       isHealthy,
			"last_success":     opStats.LastSuccess,
			"last_failure":     opStats.LastFailure,
		}
	}
	return map[string]interface{}{
		// Read enabled under the handler lock (via IsEnabled) rather than
		// touching rh.enabled directly, which raced with Enable/Disable.
		"enabled":              rh.IsEnabled(),
		"total_operations":     len(stats),
		"healthy_operations":   healthy,
		"unhealthy_operations": unhealthy,
		"operation_details":    details,
	}
}

View File

@@ -0,0 +1,362 @@
package recovery
import (
"context"
"errors"
"fmt"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/fraktal/mev-beta/internal/logger"
)
// TestRetryHandler_ExecuteWithRetry_Success verifies that an operation which
// fails once and then succeeds is reported successful after two attempts.
func TestRetryHandler_ExecuteWithRetry_Success(t *testing.T) {
	log := logger.New("debug", "text", "")
	handler := NewRetryHandler(log)
	calls := 0
	op := func(ctx context.Context, attempt int) error {
		calls++
		if calls < 2 {
			return errors.New("temporary failure")
		}
		return nil // success on the second call
	}
	result := handler.ExecuteWithRetry(context.Background(), "test_operation", op)
	assert.True(t, result.Success)
	assert.Equal(t, 2, result.Attempts)
	assert.Nil(t, result.LastError)
	assert.Equal(t, 2, calls)
}
// TestRetryHandler_ExecuteWithRetry_MaxAttemptsReached verifies that a
// permanently failing operation exhausts the default attempt budget and
// surfaces the last error.
func TestRetryHandler_ExecuteWithRetry_MaxAttemptsReached(t *testing.T) {
	log := logger.New("debug", "text", "")
	handler := NewRetryHandler(log)
	calls := 0
	alwaysFail := func(ctx context.Context, attempt int) error {
		calls++
		return errors.New("persistent failure")
	}
	result := handler.ExecuteWithRetry(context.Background(), "test_operation", alwaysFail)
	assert.False(t, result.Success)
	assert.Equal(t, 3, result.Attempts) // default max attempts
	assert.NotNil(t, result.LastError)
	assert.Equal(t, "persistent failure", result.LastError.Error())
	assert.Equal(t, 3, calls)
}
// TestRetryHandler_ExecuteWithRetry_ContextCanceled verifies that cancelling
// the context mid-retry stops the loop early and reports failure.
func TestRetryHandler_ExecuteWithRetry_ContextCanceled(t *testing.T) {
	log := logger.New("debug", "text", "")
	handler := NewRetryHandler(log)
	ctx, cancel := context.WithCancel(context.Background())
	calls := 0
	op := func(ctx context.Context, attempt int) error {
		calls++
		if calls == 2 {
			cancel() // abort during the second attempt
		}
		return errors.New("failure")
	}
	result := handler.ExecuteWithRetry(ctx, "test_operation", op)
	assert.False(t, result.Success)
	assert.LessOrEqual(t, result.Attempts, 3)
	assert.NotNil(t, result.LastError)
}
// TestRetryHandler_ExecuteWithRetry_CustomConfig verifies that a
// per-operation configuration overrides the defaults for both attempt count
// and backoff timing.
func TestRetryHandler_ExecuteWithRetry_CustomConfig(t *testing.T) {
	log := logger.New("debug", "text", "")
	handler := NewRetryHandler(log)
	handler.SetConfig("custom_operation", RetryConfig{
		MaxAttempts:       5,
		InitialDelay:      10 * time.Millisecond,
		MaxDelay:          100 * time.Millisecond,
		BackoffFactor:     2.0,
		JitterEnabled:     false,
		TimeoutPerAttempt: 1 * time.Second,
	})
	calls := 0
	alwaysFail := func(ctx context.Context, attempt int) error {
		calls++
		return errors.New("persistent failure")
	}
	started := time.Now()
	result := handler.ExecuteWithRetry(context.Background(), "custom_operation", alwaysFail)
	elapsed := time.Since(started)
	assert.False(t, result.Success)
	assert.Equal(t, 5, result.Attempts) // honours the custom MaxAttempts
	assert.Equal(t, 5, calls)
	// Four inter-attempt waits: 10ms + 20ms + 40ms + 80ms = 150ms minimum.
	minDelay := 10*time.Millisecond + 20*time.Millisecond + 40*time.Millisecond + 80*time.Millisecond
	assert.GreaterOrEqual(t, elapsed, minDelay)
}
// TestRetryHandler_ExecuteWithRetry_Disabled verifies that a disabled
// handler runs the operation exactly once with no retries.
func TestRetryHandler_ExecuteWithRetry_Disabled(t *testing.T) {
	log := logger.New("debug", "text", "")
	handler := NewRetryHandler(log)
	handler.Disable()
	calls := 0
	failing := func(ctx context.Context, attempt int) error {
		calls++
		return errors.New("failure")
	}
	result := handler.ExecuteWithRetry(context.Background(), "test_operation", failing)
	assert.False(t, result.Success)
	assert.Equal(t, 1, result.Attempts) // only one attempt when disabled
	assert.Equal(t, 1, calls)
}
// TestRetryHandler_CalculateDelay checks the exponential backoff schedule
// without jitter, including the MaxDelay cap.
func TestRetryHandler_CalculateDelay(t *testing.T) {
	log := logger.New("debug", "text", "")
	handler := NewRetryHandler(log)
	config := RetryConfig{
		InitialDelay:  100 * time.Millisecond,
		MaxDelay:      1 * time.Second,
		BackoffFactor: 2.0,
		JitterEnabled: false,
	}
	cases := []struct {
		attempt  int
		min, max time.Duration
	}{
		{1, 100 * time.Millisecond, 100 * time.Millisecond},
		{2, 200 * time.Millisecond, 200 * time.Millisecond},
		{3, 400 * time.Millisecond, 400 * time.Millisecond},
		{4, 800 * time.Millisecond, 800 * time.Millisecond},
		{5, 1 * time.Second, 1 * time.Second}, // capped at MaxDelay
	}
	for _, tc := range cases {
		t.Run(fmt.Sprintf("attempt_%d", tc.attempt), func(t *testing.T) {
			got := handler.calculateDelay(config, tc.attempt)
			assert.GreaterOrEqual(t, got, tc.min)
			assert.LessOrEqual(t, got, tc.max)
		})
	}
}
// TestRetryHandler_CalculateDelay_WithJitter checks that jitter produces
// variation across calls while keeping every delay within ±10% of the base.
func TestRetryHandler_CalculateDelay_WithJitter(t *testing.T) {
	log := logger.New("debug", "text", "")
	handler := NewRetryHandler(log)
	config := RetryConfig{
		InitialDelay:  100 * time.Millisecond,
		MaxDelay:      1 * time.Second,
		BackoffFactor: 2.0,
		JitterEnabled: true,
	}
	const samples = 10
	delays := make([]time.Duration, samples)
	for i := range delays {
		delays[i] = handler.calculateDelay(config, 2) // 200ms before jitter
	}
	varied := false
	for _, d := range delays[1:] {
		if d != delays[0] {
			varied = true
			break
		}
	}
	assert.True(t, varied, "Jitter should cause variation in delays")
	base := 200 * time.Millisecond
	for _, d := range delays {
		assert.GreaterOrEqual(t, d, base*9/10) // 10% below
		assert.LessOrEqual(t, d, base*11/10)   // 10% above
	}
}
// TestRetryHandler_GetStats verifies per-operation success and failure
// counters after a mix of succeeding and failing operations.
func TestRetryHandler_GetStats(t *testing.T) {
	log := logger.New("debug", "text", "")
	handler := NewRetryHandler(log)
	succeed := func(ctx context.Context, attempt int) error { return nil }
	fail := func(ctx context.Context, attempt int) error { return errors.New("failure") }
	handler.ExecuteWithRetry(context.Background(), "test_success", succeed)
	handler.ExecuteWithRetry(context.Background(), "test_success", succeed)
	handler.ExecuteWithRetry(context.Background(), "test_fail", fail)
	stats := handler.GetStats()
	okStats := stats["test_success"]
	require.NotNil(t, okStats)
	assert.Equal(t, 2, okStats.TotalAttempts)
	assert.Equal(t, 2, okStats.SuccessfulRetries)
	assert.Equal(t, 0, okStats.FailedRetries)
	badStats := stats["test_fail"]
	require.NotNil(t, badStats)
	assert.Equal(t, 3, badStats.TotalAttempts) // default max attempts
	assert.Equal(t, 0, badStats.SuccessfulRetries)
	assert.Equal(t, 1, badStats.FailedRetries)
}
// TestRetryHandler_GetHealthSummary verifies the aggregate health report:
// operations that succeed (immediately or after one retry) count as healthy.
func TestRetryHandler_GetHealthSummary(t *testing.T) {
	log := logger.New("debug", "text", "")
	handler := NewRetryHandler(log)
	immediate := func(ctx context.Context, attempt int) error { return nil }
	secondTry := func(ctx context.Context, attempt int) error {
		if attempt < 2 {
			return errors.New("temporary failure")
		}
		return nil
	}
	// Two immediate successes, then one success that needed a retry.
	handler.ExecuteWithRetry(context.Background(), "immediate_success", immediate)
	handler.ExecuteWithRetry(context.Background(), "immediate_success", immediate)
	handler.ExecuteWithRetry(context.Background(), "retry_success", secondTry)
	summary := handler.GetHealthSummary()
	assert.True(t, summary["enabled"].(bool))
	assert.Equal(t, 2, summary["total_operations"].(int))
	assert.Equal(t, 2, summary["healthy_operations"].(int))
	assert.Equal(t, 0, summary["unhealthy_operations"].(int))
	details := summary["operation_details"].(map[string]interface{})
	fast := details["immediate_success"].(map[string]interface{})
	assert.Equal(t, 1.0, fast["success_rate"].(float64))
	assert.Equal(t, 1.0, fast["average_attempts"].(float64))
	assert.True(t, fast["is_healthy"].(bool))
	slow := details["retry_success"].(map[string]interface{})
	assert.Equal(t, 1.0, slow["success_rate"].(float64))
	assert.Equal(t, 2.0, slow["average_attempts"].(float64))
	assert.True(t, slow["is_healthy"].(bool)) // still healthy despite retries
}
// TestRetryHandler_ConcurrentExecution hammers the handler from many
// goroutines and checks both the aggregate success rate and that statistics
// were recorded without corruption.
func TestRetryHandler_ConcurrentExecution(t *testing.T) {
	log := logger.New("debug", "text", "")
	handler := NewRetryHandler(log)
	const (
		workers      = 50
		opsPerWorker = 20
	)
	done := make(chan bool, workers)
	successes := make(chan int, workers)
	op := func(ctx context.Context, attempt int) error {
		// Roughly 80% of first attempts succeed...
		if attempt <= 1 && time.Now().UnixNano()%5 != 0 {
			return nil
		}
		if attempt == 2 {
			return nil // ...and every second attempt succeeds.
		}
		return errors.New("failure")
	}
	for w := 0; w < workers; w++ {
		go func(id int) {
			defer func() { done <- true }()
			count := 0
			for i := 0; i < opsPerWorker; i++ {
				res := handler.ExecuteWithRetry(context.Background(),
					fmt.Sprintf("concurrent_op_%d", id), op)
				if res.Success {
					count++
				}
			}
			successes <- count
		}(w)
	}
	total := 0
	for w := 0; w < workers; w++ {
		select {
		case <-done:
			total += <-successes
		case <-time.After(30 * time.Second):
			t.Fatal("Concurrent retry test timed out")
		}
	}
	allOps := workers * opsPerWorker
	rate := float64(total) / float64(allOps)
	t.Logf("Concurrent execution: %d/%d operations succeeded (%.2f%%)",
		total, allOps, rate*100)
	// Retries should push the overall success rate high.
	assert.GreaterOrEqual(t, rate, 0.8, "Success rate should be at least 80%")
	stats := handler.GetStats()
	assert.NotEmpty(t, stats, "Should have recorded stats")
}
// TestRetryHandler_EdgeCases covers degenerate inputs: a nil operation and
// unusual operation-type strings.
func TestRetryHandler_EdgeCases(t *testing.T) {
	log := logger.New("debug", "text", "")
	handler := NewRetryHandler(log)
	t.Run("nil operation", func(t *testing.T) {
		assert.Panics(t, func() {
			handler.ExecuteWithRetry(context.Background(), "nil_op", nil)
		})
	})
	t.Run("empty operation type", func(t *testing.T) {
		noop := func(ctx context.Context, attempt int) error { return nil }
		res := handler.ExecuteWithRetry(context.Background(), "", noop)
		assert.True(t, res.Success)
	})
	t.Run("very long operation type", func(t *testing.T) {
		name := string(make([]byte, 1000)) // 1000 NUL bytes
		noop := func(ctx context.Context, attempt int) error { return nil }
		res := handler.ExecuteWithRetry(context.Background(), name, noop)
		assert.True(t, res.Success)
	})
}

View File

@@ -0,0 +1,493 @@
package registry
import (
	"bytes"
	"context"
	"fmt"
	"sync"
	"time"

	"github.com/ethereum/go-ethereum/common"
	"github.com/fraktal/mev-beta/internal/contracts"
	"github.com/fraktal/mev-beta/internal/logger"
)
// ContractInfo contains comprehensive information about a contract known to
// the registry: its detected category, token metadata, and (for pools) the
// constituent tokens and fee tier.
type ContractInfo struct {
	Address  common.Address         // the contract's on-chain address
	Type     contracts.ContractType // detected/declared contract category
	Name     string                 // human-readable name, e.g. "Wrapped Ether"
	Symbol   string                 // token symbol (tokens only)
	Decimals uint8                  // token decimals (tokens only)
	Token0   common.Address         // For pools: first pool token
	Token1   common.Address         // For pools: second pool token
	// Fee applies to V3 pools. Values like 500 correspond to 0.05% in the
	// registry's seed data, i.e. Uniswap's hundredths-of-a-bip units —
	// not whole basis points as the old comment claimed.
	Fee         uint32
	Factory     common.Address // factory that deployed the pool (pools only)
	Protocol    string         // e.g. "UniswapV3", "SushiSwap", "Arbitrum"
	IsVerified  bool           // true only for curated, pre-verified entries
	LastUpdated time.Time      // when this entry was created or last refreshed
	// Confidence in [0, 1]: 1.0 for curated entries, 0.8 for pools discovered
	// at runtime, otherwise the detector-reported confidence.
	Confidence float64
}
// ContractRegistry provides an authoritative, concurrency-safe mapping of
// contract addresses to their types and metadata. It is seeded with known
// Arbitrum contracts and enriched at runtime via the contract detector.
type ContractRegistry struct {
	mu               sync.RWMutex                     // guards all maps and timestamps below
	contracts        map[common.Address]*ContractInfo // primary index: address -> info
	tokensBySymbol   map[string]common.Address        // e.g. "WETH" -> token address
	poolsByTokenPair map[string][]common.Address      // "token0:token1" -> pool addresses
	detector         *contracts.ContractDetector      // on-chain type detection fallback
	logger           *logger.Logger
	updateInterval   time.Duration // intended refresh cadence (set to 24h)
	// lastFullUpdate records when a full refresh last ran; zero until then.
	// NOTE(review): no code in view actually performs a periodic refresh
	// using updateInterval/lastFullUpdate — confirm whether it exists elsewhere.
	lastFullUpdate time.Time
}
// NewContractRegistry builds a registry seeded with well-known Arbitrum
// contracts and backed by the given detector for unknown addresses.
func NewContractRegistry(detector *contracts.ContractDetector, logger *logger.Logger) *ContractRegistry {
	cr := &ContractRegistry{
		contracts:        make(map[common.Address]*ContractInfo),
		tokensBySymbol:   make(map[string]common.Address),
		poolsByTokenPair: make(map[string][]common.Address),
		detector:         detector,
		logger:           logger,
		updateInterval:   24 * time.Hour, // refresh cadence
		// lastFullUpdate stays at the zero time: no full update has run yet.
	}
	// Seed the indexes with curated Arbitrum contracts.
	cr.initializeKnownContracts()
	return cr
}
// initializeKnownContracts populates the registry with a curated set of
// well-known Arbitrum contracts (major ERC-20 tokens, Uniswap V3 pools, and
// routers) so that common lookups never need on-chain detection. All seeded
// entries are marked verified with confidence 1.0.
//
// NOTE(review): the addresses below are hard-coded and should be verified
// against official deployment lists before being relied on. In particular,
// the "USDC" address used here does not match the commonly published
// Arbitrum USDC deployments — confirm every address.
func (cr *ContractRegistry) initializeKnownContracts() {
	cr.mu.Lock()
	defer cr.mu.Unlock()
	now := time.Now()
	// Major ERC-20 tokens on Arbitrum
	knownTokens := map[common.Address]*ContractInfo{
		common.HexToAddress("0x82aF49447D8a07e3bd95BD0d56f35241523fBab1"): {
			Address:     common.HexToAddress("0x82aF49447D8a07e3bd95BD0d56f35241523fBab1"),
			Type:        contracts.ContractTypeERC20Token,
			Name:        "Wrapped Ether",
			Symbol:      "WETH",
			Decimals:    18,
			Protocol:    "Arbitrum",
			IsVerified:  true,
			LastUpdated: now,
			Confidence:  1.0,
		},
		// NOTE(review): this "USDC" address looks suspect — verify it.
		common.HexToAddress("0xA0b86a33E6D8E4BBa6Fd6bD5BA0e2FF8A1e8B8D4"): {
			Address:     common.HexToAddress("0xA0b86a33E6D8E4BBa6Fd6bD5BA0e2FF8A1e8B8D4"),
			Type:        contracts.ContractTypeERC20Token,
			Name:        "USD Coin",
			Symbol:      "USDC",
			Decimals:    6,
			Protocol:    "Arbitrum",
			IsVerified:  true,
			LastUpdated: now,
			Confidence:  1.0,
		},
		common.HexToAddress("0xFd086bC7CD5C481DCC9C85ebE478A1C0b69FCbb9"): {
			Address:     common.HexToAddress("0xFd086bC7CD5C481DCC9C85ebE478A1C0b69FCbb9"),
			Type:        contracts.ContractTypeERC20Token,
			Name:        "Tether USD",
			Symbol:      "USDT",
			Decimals:    6,
			Protocol:    "Arbitrum",
			IsVerified:  true,
			LastUpdated: now,
			Confidence:  1.0,
		},
		common.HexToAddress("0x2f2a2543B76A4166549F7aaB2e75Bef0aefC5B0f"): {
			Address:     common.HexToAddress("0x2f2a2543B76A4166549F7aaB2e75Bef0aefC5B0f"),
			Type:        contracts.ContractTypeERC20Token,
			Name:        "Wrapped BTC",
			Symbol:      "WBTC",
			Decimals:    8,
			Protocol:    "Arbitrum",
			IsVerified:  true,
			LastUpdated: now,
			Confidence:  1.0,
		},
		common.HexToAddress("0x912CE59144191C1204E64559FE8253a0e49E6548"): {
			Address:     common.HexToAddress("0x912CE59144191C1204E64559FE8253a0e49E6548"),
			Type:        contracts.ContractTypeERC20Token,
			Name:        "Arbitrum",
			Symbol:      "ARB",
			Decimals:    18,
			Protocol:    "Arbitrum",
			IsVerified:  true,
			LastUpdated: now,
			Confidence:  1.0,
		},
	}
	// Major Uniswap V3 pools on Arbitrum. Token0/Token1 reference the token
	// addresses above; Fee is in Uniswap's fee units (500 -> 0.05%).
	knownPools := map[common.Address]*ContractInfo{
		common.HexToAddress("0xC6962004f452bE9203591991D15f6b388e09E8D0"): {
			Address:     common.HexToAddress("0xC6962004f452bE9203591991D15f6b388e09E8D0"),
			Type:        contracts.ContractTypeUniswapV3Pool,
			Name:        "USDC/WETH 0.05%",
			Token0:      common.HexToAddress("0xA0b86a33E6D8E4BBa6Fd6bD5BA0e2FF8A1e8B8D4"), // USDC
			Token1:      common.HexToAddress("0x82aF49447D8a07e3bd95BD0d56f35241523fBab1"), // WETH
			Fee:         500, // 0.05%
			Factory:     common.HexToAddress("0x1F98431c8aD98523631AE4a59f267346ea31F984"),
			Protocol:    "UniswapV3",
			IsVerified:  true,
			LastUpdated: now,
			Confidence:  1.0,
		},
		common.HexToAddress("0x17c14D2c404D167802b16C450d3c99F88F2c4F4d"): {
			Address:     common.HexToAddress("0x17c14D2c404D167802b16C450d3c99F88F2c4F4d"),
			Type:        contracts.ContractTypeUniswapV3Pool,
			Name:        "USDC/WETH 0.3%",
			Token0:      common.HexToAddress("0xA0b86a33E6D8E4BBa6Fd6bD5BA0e2FF8A1e8B8D4"), // USDC
			Token1:      common.HexToAddress("0x82aF49447D8a07e3bd95BD0d56f35241523fBab1"), // WETH
			Fee:         3000, // 0.3%
			Factory:     common.HexToAddress("0x1F98431c8aD98523631AE4a59f267346ea31F984"),
			Protocol:    "UniswapV3",
			IsVerified:  true,
			LastUpdated: now,
			Confidence:  1.0,
		},
		common.HexToAddress("0x2f5e87C9312fa29aed5c179E456625D79015299c"): {
			Address:     common.HexToAddress("0x2f5e87C9312fa29aed5c179E456625D79015299c"),
			Type:        contracts.ContractTypeUniswapV3Pool,
			Name:        "WBTC/WETH 0.05%",
			Token0:      common.HexToAddress("0x2f2a2543B76A4166549F7aaB2e75Bef0aefC5B0f"), // WBTC
			Token1:      common.HexToAddress("0x82aF49447D8a07e3bd95BD0d56f35241523fBab1"), // WETH
			Fee:         500, // 0.05%
			Factory:     common.HexToAddress("0x1F98431c8aD98523631AE4a59f267346ea31F984"),
			Protocol:    "UniswapV3",
			IsVerified:  true,
			LastUpdated: now,
			Confidence:  1.0,
		},
		common.HexToAddress("0x641C00A822e8b671738d32a431a4Fb6074E5c79d"): {
			Address:     common.HexToAddress("0x641C00A822e8b671738d32a431a4Fb6074E5c79d"),
			Type:        contracts.ContractTypeUniswapV3Pool,
			Name:        "USDT/WETH 0.05%",
			Token0:      common.HexToAddress("0xFd086bC7CD5C481DCC9C85ebE478A1C0b69FCbb9"), // USDT
			Token1:      common.HexToAddress("0x82aF49447D8a07e3bd95BD0d56f35241523fBab1"), // WETH
			Fee:         500, // 0.05%
			Factory:     common.HexToAddress("0x1F98431c8aD98523631AE4a59f267346ea31F984"),
			Protocol:    "UniswapV3",
			IsVerified:  true,
			LastUpdated: now,
			Confidence:  1.0,
		},
		common.HexToAddress("0xFe7D6a84287235C7b4b57C4fEb9a44d4C6Ed3BB8"): {
			Address:     common.HexToAddress("0xFe7D6a84287235C7b4b57C4fEb9a44d4C6Ed3BB8"),
			Type:        contracts.ContractTypeUniswapV3Pool,
			Name:        "ARB/WETH 0.05%",
			Token0:      common.HexToAddress("0x912CE59144191C1204E64559FE8253a0e49E6548"), // ARB
			Token1:      common.HexToAddress("0x82aF49447D8a07e3bd95BD0d56f35241523fBab1"), // WETH
			Fee:         500, // 0.05%
			Factory:     common.HexToAddress("0x1F98431c8aD98523631AE4a59f267346ea31F984"),
			Protocol:    "UniswapV3",
			IsVerified:  true,
			LastUpdated: now,
			Confidence:  1.0,
		},
	}
	// Major routers on Arbitrum
	knownRouters := map[common.Address]*ContractInfo{
		common.HexToAddress("0xE592427A0AEce92De3Edee1F18E0157C05861564"): {
			Address:     common.HexToAddress("0xE592427A0AEce92De3Edee1F18E0157C05861564"),
			Type:        contracts.ContractTypeUniswapV3Router,
			Name:        "Uniswap V3 Router",
			Protocol:    "UniswapV3",
			IsVerified:  true,
			LastUpdated: now,
			Confidence:  1.0,
		},
		common.HexToAddress("0x68b3465833fb72A70ecDF485E0e4C7bD8665Fc45"): {
			Address:     common.HexToAddress("0x68b3465833fb72A70ecDF485E0e4C7bD8665Fc45"),
			Type:        contracts.ContractTypeUniswapV3Router,
			Name:        "Uniswap V3 Router 2",
			Protocol:    "UniswapV3",
			IsVerified:  true,
			LastUpdated: now,
			Confidence:  1.0,
		},
		common.HexToAddress("0x1b02dA8Cb0d097eB8D57A175b88c7D8b47997506"): {
			Address:     common.HexToAddress("0x1b02dA8Cb0d097eB8D57A175b88c7D8b47997506"),
			Type:        contracts.ContractTypeUniswapV2Router,
			Name:        "SushiSwap Router",
			Protocol:    "SushiSwap",
			IsVerified:  true,
			LastUpdated: now,
			Confidence:  1.0,
		},
	}
	// Add all known contracts to the primary and secondary indexes.
	for addr, info := range knownTokens {
		cr.contracts[addr] = info
		cr.tokensBySymbol[info.Symbol] = addr
	}
	for addr, info := range knownPools {
		cr.contracts[addr] = info
		// Index pools by token pair (only when both tokens are set).
		if info.Token0 != (common.Address{}) && info.Token1 != (common.Address{}) {
			pairKey := cr.makeTokenPairKey(info.Token0, info.Token1)
			cr.poolsByTokenPair[pairKey] = append(cr.poolsByTokenPair[pairKey], addr)
		}
	}
	for addr, info := range knownRouters {
		cr.contracts[addr] = info
	}
	cr.logger.Info("Contract registry initialized",
		"tokens", len(knownTokens),
		"pools", len(knownPools),
		"routers", len(knownRouters))
}
// GetContractInfo resolves contract metadata for address, consulting the
// in-memory cache first and falling back to on-chain type detection.
// Detection results are cached so subsequent lookups are cheap.
func (cr *ContractRegistry) GetContractInfo(ctx context.Context, address common.Address) (*ContractInfo, error) {
	if cached := cr.getCachedInfo(address); cached != nil {
		return cached, nil
	}
	// Cache miss: ask the detector to classify the contract on-chain.
	detection := cr.detector.DetectContractType(ctx, address)
	if detection.Error != nil {
		return nil, fmt.Errorf("contract detection failed: %w", detection.Error)
	}
	info := &ContractInfo{
		Address:     address,
		Type:        detection.ContractType,
		Name:        fmt.Sprintf("Unknown %s", detection.ContractType.String()),
		Protocol:    "Unknown",
		IsVerified:  false,
		LastUpdated: time.Now(),
		Confidence:  detection.Confidence,
	}
	cr.cacheContractInfo(info)
	return info, nil
}
// getCachedInfo returns the cached entry for address, or nil when absent.
func (cr *ContractRegistry) getCachedInfo(address common.Address) *ContractInfo {
	cr.mu.RLock()
	info := cr.contracts[address]
	cr.mu.RUnlock()
	return info
}
// cacheContractInfo stores info in the primary index under its address.
func (cr *ContractRegistry) cacheContractInfo(info *ContractInfo) {
	cr.mu.Lock()
	cr.contracts[info.Address] = info
	cr.mu.Unlock()
}
// IsKnownToken reports whether address is cached as an ERC-20 token.
func (cr *ContractRegistry) IsKnownToken(address common.Address) bool {
	cr.mu.RLock()
	defer cr.mu.RUnlock()
	info, ok := cr.contracts[address]
	return ok && info.Type == contracts.ContractTypeERC20Token
}
// IsKnownPool reports whether address is cached as a Uniswap V2 or V3 pool.
func (cr *ContractRegistry) IsKnownPool(address common.Address) bool {
	cr.mu.RLock()
	defer cr.mu.RUnlock()
	info, ok := cr.contracts[address]
	if !ok {
		return false
	}
	switch info.Type {
	case contracts.ContractTypeUniswapV2Pool, contracts.ContractTypeUniswapV3Pool:
		return true
	}
	return false
}
// IsKnownRouter reports whether address is cached as a swap router
// (Uniswap V2, Uniswap V3, or Universal Router).
func (cr *ContractRegistry) IsKnownRouter(address common.Address) bool {
	cr.mu.RLock()
	defer cr.mu.RUnlock()
	info, ok := cr.contracts[address]
	if !ok {
		return false
	}
	switch info.Type {
	case contracts.ContractTypeUniswapV2Router,
		contracts.ContractTypeUniswapV3Router,
		contracts.ContractTypeUniversalRouter:
		return true
	}
	return false
}
// GetTokenBySymbol resolves a token symbol (e.g. "WETH") to its address.
// The boolean result is false when the symbol is unknown.
func (cr *ContractRegistry) GetTokenBySymbol(symbol string) (common.Address, bool) {
	cr.mu.RLock()
	addr, ok := cr.tokensBySymbol[symbol]
	cr.mu.RUnlock()
	return addr, ok
}
// GetPoolsForTokenPair returns the known pool addresses for the given token
// pair (accepted in either order).
//
// The result is a fresh slice. The previous implementation handed out the
// registry's internal slice, so a caller appending to (or mutating) the
// result could corrupt registry state outside the lock.
func (cr *ContractRegistry) GetPoolsForTokenPair(token0, token1 common.Address) []common.Address {
	cr.mu.RLock()
	defer cr.mu.RUnlock()
	pools := cr.poolsByTokenPair[cr.makeTokenPairKey(token0, token1)]
	if len(pools) == 0 {
		return nil
	}
	out := make([]common.Address, len(pools))
	copy(out, pools)
	return out
}
// makeTokenPairKey builds a canonical "addr0:addr1" key for a token pair,
// ordering the two addresses so that (A, B) and (B, A) map to the same key.
//
// Addresses are compared as big-endian byte strings via bytes.Compare,
// which yields exactly the same ordering as the previous
// token0.Big().Cmp(token1.Big()) but avoids two big.Int allocations per
// call; plain concatenation likewise replaces fmt.Sprintf on this hot path.
func (cr *ContractRegistry) makeTokenPairKey(token0, token1 common.Address) string {
	if bytes.Compare(token0.Bytes(), token1.Bytes()) > 0 {
		token0, token1 = token1, token0
	}
	return token0.Hex() + ":" + token1.Hex()
}
// GetKnownContracts returns a snapshot of every cached contract entry.
//
// NOTE(review): the map itself is a fresh copy, but the *ContractInfo
// values are shared with the registry — callers should treat them as
// read-only.
func (cr *ContractRegistry) GetKnownContracts() map[common.Address]*ContractInfo {
	cr.mu.RLock()
	defer cr.mu.RUnlock()
	snapshot := make(map[common.Address]*ContractInfo, len(cr.contracts))
	for addr, info := range cr.contracts {
		snapshot[addr] = info
	}
	return snapshot
}
// UpdateContractInfo inserts or replaces a contract entry (used for dynamic
// discovery) and keeps the symbol and token-pair indexes in sync.
func (cr *ContractRegistry) UpdateContractInfo(info *ContractInfo) {
	cr.mu.Lock()
	defer cr.mu.Unlock()
	info.LastUpdated = time.Now()
	cr.contracts[info.Address] = info
	// Keep the symbol index current for named tokens.
	if info.Type == contracts.ContractTypeERC20Token && info.Symbol != "" {
		cr.tokensBySymbol[info.Symbol] = info.Address
	}
	// Keep the token-pair index current for pools with both tokens known.
	isPool := info.Type == contracts.ContractTypeUniswapV2Pool ||
		info.Type == contracts.ContractTypeUniswapV3Pool
	if !isPool || info.Token0 == (common.Address{}) || info.Token1 == (common.Address{}) {
		return
	}
	key := cr.makeTokenPairKey(info.Token0, info.Token1)
	for _, existing := range cr.poolsByTokenPair[key] {
		if existing == info.Address {
			return // already indexed for this pair
		}
	}
	cr.poolsByTokenPair[key] = append(cr.poolsByTokenPair[key], info.Address)
}
// GetCacheStats summarizes the cache contents by contract category and
// returns counts for totals, tokens, V2/V3 pools, routers, and verified
// entries.
func (cr *ContractRegistry) GetCacheStats() map[string]int {
	cr.mu.RLock()
	defer cr.mu.RUnlock()
	var tokens, v2Pools, v3Pools, routers, verified int
	for _, info := range cr.contracts {
		switch info.Type {
		case contracts.ContractTypeERC20Token:
			tokens++
		case contracts.ContractTypeUniswapV2Pool:
			v2Pools++
		case contracts.ContractTypeUniswapV3Pool:
			v3Pools++
		case contracts.ContractTypeUniswapV2Router, contracts.ContractTypeUniswapV3Router:
			routers++
		}
		if info.IsVerified {
			verified++
		}
	}
	return map[string]int{
		"total":    len(cr.contracts),
		"tokens":   tokens,
		"v2_pools": v2Pools,
		"v3_pools": v3Pools,
		"routers":  routers,
		"verified": verified,
	}
}
// GetPoolInfo returns the cached entry for address when it is a V2 or V3
// pool, and nil otherwise (including when the address is unknown).
func (cr *ContractRegistry) GetPoolInfo(address common.Address) *ContractInfo {
	cr.mu.RLock()
	defer cr.mu.RUnlock()
	info, ok := cr.contracts[address]
	if !ok {
		return nil
	}
	switch info.Type {
	case contracts.ContractTypeUniswapV2Pool, contracts.ContractTypeUniswapV3Pool:
		return info
	default:
		return nil
	}
}
// AddPool registers a runtime-discovered pool and indexes it by token pair.
// protocol selects the pool type; unrecognized protocols default to V2.
func (cr *ContractRegistry) AddPool(poolAddress, token0, token1 common.Address, protocol string) {
	cr.mu.Lock()
	defer cr.mu.Unlock()
	poolType := contracts.ContractTypeUniswapV2Pool // default for unknown protocols
	if protocol == "UniswapV3" {
		poolType = contracts.ContractTypeUniswapV3Pool
	}
	cr.contracts[poolAddress] = &ContractInfo{
		Address:     poolAddress,
		Type:        poolType,
		Token0:      token0,
		Token1:      token1,
		Protocol:    protocol,
		IsVerified:  false, // runtime-discovered pools are not pre-verified
		LastUpdated: time.Now(),
		Confidence:  0.8, // high, but below the 1.0 of curated entries
	}
	// Index by token pair, skipping duplicates.
	key := cr.makeTokenPairKey(token0, token1)
	for _, existing := range cr.poolsByTokenPair[key] {
		if existing == poolAddress {
			return // already indexed for this pair
		}
	}
	cr.poolsByTokenPair[key] = append(cr.poolsByTokenPair[key], poolAddress)
}

View File

@@ -0,0 +1,292 @@
package secure
import (
"crypto/aes"
"crypto/cipher"
"crypto/rand"
"crypto/sha256"
"encoding/base64"
"errors"
"fmt"
"io"
"os"
"strings"
"github.com/fraktal/mev-beta/internal/logger"
)
// ConfigManager handles secure configuration management.
//
// It encrypts and decrypts individual configuration values with AES-256-GCM,
// deriving the key from the MEV_BOT_CONFIG_KEY environment variable (see
// NewConfigManager).
type ConfigManager struct {
	logger *logger.Logger // used for warnings while loading configuration
	aesGCM cipher.AEAD    // AES-256-GCM instance used by Encrypt/DecryptValue
	// key holds the derived 32-byte AES key.
	// NOTE(review): the key is retained for the manager's lifetime but is not
	// referenced after construction in the code in view — confirm whether
	// keeping it resident in memory is required.
	key []byte
}
// NewConfigManager creates a secure configuration manager whose AES-256-GCM
// key is derived from the MEV_BOT_CONFIG_KEY environment variable. It fails
// when the variable is unset or the cipher cannot be constructed.
//
// NOTE(review): the key is a single unsalted SHA-256 of the passphrase; a
// memory-hard KDF (scrypt/argon2) would resist brute force better. Flagged
// rather than changed, since switching would break existing ciphertexts.
func NewConfigManager(logger *logger.Logger) (*ConfigManager, error) {
	passphrase := os.Getenv("MEV_BOT_CONFIG_KEY")
	if passphrase == "" {
		return nil, errors.New("MEV_BOT_CONFIG_KEY environment variable not set")
	}
	// Derive a 32-byte key so the cipher is AES-256.
	derived := sha256.Sum256([]byte(passphrase))
	block, err := aes.NewCipher(derived[:])
	if err != nil {
		return nil, fmt.Errorf("failed to create AES cipher: %w", err)
	}
	gcm, err := cipher.NewGCM(block)
	if err != nil {
		return nil, fmt.Errorf("failed to create GCM mode: %w", err)
	}
	return &ConfigManager{
		logger: logger,
		aesGCM: gcm,
		key:    derived[:],
	}, nil
}
// EncryptValue encrypts plaintext with AES-GCM and returns base64 text in
// which the random nonce is prepended to the ciphertext.
func (cm *ConfigManager) EncryptValue(plaintext string) (string, error) {
	nonce := make([]byte, cm.aesGCM.NonceSize())
	if _, err := io.ReadFull(rand.Reader, nonce); err != nil {
		return "", fmt.Errorf("failed to generate nonce: %w", err)
	}
	// Seal appends ciphertext (and auth tag) after the nonce prefix.
	sealed := cm.aesGCM.Seal(nonce, nonce, []byte(plaintext), nil)
	return base64.StdEncoding.EncodeToString(sealed), nil
}
// DecryptValue reverses EncryptValue: it base64-decodes the input, splits
// off the nonce prefix, and authenticates/decrypts the remainder.
func (cm *ConfigManager) DecryptValue(ciphertext string) (string, error) {
	raw, err := base64.StdEncoding.DecodeString(ciphertext)
	if err != nil {
		return "", fmt.Errorf("failed to decode base64: %w", err)
	}
	nonceSize := cm.aesGCM.NonceSize()
	if len(raw) < nonceSize {
		// Too short to even contain the nonce prefix.
		return "", errors.New("ciphertext too short")
	}
	nonce := raw[:nonceSize]
	sealed := raw[nonceSize:]
	plaintext, err := cm.aesGCM.Open(nil, nonce, sealed, nil)
	if err != nil {
		return "", fmt.Errorf("failed to decrypt: %w", err)
	}
	return string(plaintext), nil
}
// GetSecureValue resolves key from the environment, preferring a plaintext
// variable named key and falling back to an encrypted variable named
// key+"_ENCRYPTED" (decrypted with this manager's key).
func (cm *ConfigManager) GetSecureValue(key string) (string, error) {
	if plain := os.Getenv(key); plain != "" {
		return plain, nil
	}
	if enc := os.Getenv(key + "_ENCRYPTED"); enc != "" {
		return cm.DecryptValue(enc)
	}
	return "", fmt.Errorf("secure value not found for key: %s", key)
}
// SecureConfig holds encrypted configuration values.
//
// Values are kept AES-GCM-encrypted while in memory and only decrypted on
// access via Get/GetRequired/GetWithDefault.
type SecureConfig struct {
	manager *ConfigManager    // performs the actual encryption/decryption
	values  map[string]string // key -> base64(nonce || ciphertext)
}
// NewSecureConfig creates an empty secure configuration store backed by the
// given manager.
func NewSecureConfig(manager *ConfigManager) *SecureConfig {
	sc := &SecureConfig{manager: manager}
	sc.values = make(map[string]string)
	return sc
}
// Set encrypts value and stores the ciphertext in memory under key.
func (sc *SecureConfig) Set(key, value string) error {
	enc, err := sc.manager.EncryptValue(value)
	if err != nil {
		return fmt.Errorf("failed to encrypt value for key %s: %w", key, err)
	}
	sc.values[key] = enc
	return nil
}
// Get returns the decrypted value for key. An in-memory encrypted entry
// takes precedence; otherwise the manager's secure environment lookup is
// consulted.
func (sc *SecureConfig) Get(key string) (string, error) {
	sealed, ok := sc.values[key]
	if !ok {
		return sc.manager.GetSecureValue(key)
	}
	return sc.manager.DecryptValue(sealed)
}
// GetRequired retrieves a required value, returning an error if the value
// is missing or blank.
//
// Fix: the previous version discarded the underlying lookup error; it is
// now wrapped with %w so callers can inspect the root cause (e.g. a
// decryption failure vs. an unset variable) via errors.Is/errors.As.
func (sc *SecureConfig) GetRequired(key string) (string, error) {
	value, err := sc.Get(key)
	if err != nil {
		return "", fmt.Errorf("required configuration value missing: %s: %w", key, err)
	}
	// Whitespace-only values are treated the same as empty.
	if strings.TrimSpace(value) == "" {
		return "", fmt.Errorf("required configuration value empty: %s", key)
	}
	return value, nil
}
// GetWithDefault retrieves a value, returning defaultValue when the lookup
// fails. Note that a successfully retrieved empty string is returned as-is.
func (sc *SecureConfig) GetWithDefault(key, defaultValue string) string {
	if value, err := sc.Get(key); err == nil {
		return value
	}
	return defaultValue
}
// LoadFromEnvironment resolves each key via the manager's secure lookup and
// caches it (re-encrypted) in memory. Keys that cannot be resolved are
// logged and skipped; a storage failure aborts the load.
func (sc *SecureConfig) LoadFromEnvironment(keys []string) error {
	for _, key := range keys {
		resolved, lookupErr := sc.manager.GetSecureValue(key)
		if lookupErr != nil {
			// Best-effort: missing keys are a warning, not a hard error.
			sc.manager.logger.Warn(fmt.Sprintf("Could not load secure config for %s: %v", key, lookupErr))
			continue
		}
		if storeErr := sc.Set(key, resolved); storeErr != nil {
			return fmt.Errorf("failed to store secure config for %s: %w", key, storeErr)
		}
	}
	return nil
}
// Clear removes all stored values from memory.
//
// Fix: the previous implementation overwrote each map entry with "0"
// characters before deleting it, claiming to "zero out" the data. Go
// strings are immutable, so that only replaced the map's string header —
// the original ciphertext bytes remained reachable until garbage
// collection and were never scrubbed. Since the stored values are
// ciphertext (not plaintext secrets), simply dropping the references is
// equivalent, and this version is honest about what it can guarantee.
func (sc *SecureConfig) Clear() {
	// Deleting during range is well-defined in Go.
	for key := range sc.values {
		delete(sc.values, key)
	}
}
// Validate checks that every key in requiredKeys resolves to a non-empty
// value, returning a single error listing all missing keys.
func (sc *SecureConfig) Validate(requiredKeys []string) error {
	missing := make([]string, 0, len(requiredKeys))
	for _, key := range requiredKeys {
		if _, err := sc.GetRequired(key); err != nil {
			missing = append(missing, key)
		}
	}
	if len(missing) == 0 {
		return nil
	}
	return fmt.Errorf("missing required configuration keys: %s", strings.Join(missing, ", "))
}
// GenerateConfigKey generates a fresh random 256-bit encryption key and
// returns it base64-encoded.
func GenerateConfigKey() (string, error) {
	var key [32]byte // 256-bit key
	if _, err := rand.Read(key[:]); err != nil {
		return "", fmt.Errorf("failed to generate random key: %w", err)
	}
	return base64.StdEncoding.EncodeToString(key[:]), nil
}
// ConfigValidator provides stateless validation utilities for URLs, API
// keys, and Ethereum addresses used in configuration.
type ConfigValidator struct {
	logger *logger.Logger // currently unused by the validation methods below
}
// NewConfigValidator creates a new configuration validator.
// The parameter is named l to avoid shadowing the logger package.
func NewConfigValidator(l *logger.Logger) *ConfigValidator {
	return &ConfigValidator{logger: l}
}
// ValidateURL validates that a URL is non-empty and uses an encrypted
// transport (https:// or wss://). No reachability check is performed.
func (cv *ConfigValidator) ValidateURL(url string) error {
	switch {
	case url == "":
		return errors.New("URL cannot be empty")
	case strings.HasPrefix(url, "https://"), strings.HasPrefix(url, "wss://"):
		return nil
	default:
		return errors.New("URL must use HTTPS or WSS protocol")
	}
}
// ValidateAPIKey checks that an API key is present, at least 32 characters
// long, and not composed of a single repeated character.
func (cv *ConfigValidator) ValidateAPIKey(key string) error {
	switch {
	case key == "":
		return errors.New("API key cannot be empty")
	case len(key) < 32:
		return errors.New("API key must be at least 32 characters")
	}
	// Trivial entropy check: reject keys that are one character repeated.
	if strings.Count(key, string(key[0])) == len(key) {
		return errors.New("API key lacks sufficient entropy")
	}
	return nil
}
// ValidateAddress checks that a string is a syntactically valid Ethereum
// address: "0x" prefix, 42 characters total, all hex digits after the
// prefix. No checksum (EIP-55) verification is performed.
func (cv *ConfigValidator) ValidateAddress(address string) error {
	switch {
	case address == "":
		return errors.New("address cannot be empty")
	case !strings.HasPrefix(address, "0x"):
		return errors.New("address must start with 0x")
	case len(address) != 42: // "0x" + 40 hex digits
		return errors.New("address must be 42 characters long")
	}
	for idx, r := range address[2:] {
		switch {
		case r >= '0' && r <= '9', r >= 'a' && r <= 'f', r >= 'A' && r <= 'F':
			// valid hex digit
		default:
			return fmt.Errorf("invalid hex character at position %d: %c", idx+2, r)
		}
	}
	return nil
}

View File

@@ -0,0 +1,116 @@
package tokens
import (
"github.com/ethereum/go-ethereum/common"
)
// ArbitrumTokens contains the addresses of popular tokens on Arbitrum,
// grouped by liquidity tier. See GetArbitrumTokens for the concrete
// Arbitrum One mainnet addresses.
type ArbitrumTokens struct {
	// Tier 1 - Major Assets
	WETH common.Address
	USDC common.Address // native USDC (not the bridged USDCe below)
	USDT common.Address
	ARB  common.Address
	WBTC common.Address
	DAI  common.Address
	LINK common.Address
	UNI  common.Address
	GMX  common.Address
	GRT  common.Address
	// Tier 2 - DeFi Blue Chips
	USDCe  common.Address // Bridged USDC
	PENDLE common.Address
	RDNT   common.Address
	MAGIC  common.Address
	GRAIL  common.Address
	// Tier 3 - Additional High Volume
	AAVE common.Address
	CRV  common.Address
	BAL  common.Address
	COMP common.Address
	MKR  common.Address
}
// GetArbitrumTokens returns the addresses of popular tokens on Arbitrum.
// NOTE(review): these are hard-coded Arbitrum One mainnet addresses;
// verify any change against a block explorer before editing — a single
// wrong character silently redirects funds/lookups to another contract.
func GetArbitrumTokens() *ArbitrumTokens {
	return &ArbitrumTokens{
		// Tier 1 - Major Assets
		WETH: common.HexToAddress("0x82aF49447D8a07e3bd95BD0d56f35241523fBab1"), // Wrapped Ether
		USDC: common.HexToAddress("0xaf88d065e77c8cC2239327C5EDb3A432268e5831"), // USD Coin (Native)
		USDT: common.HexToAddress("0xFd086bC7CD5C481DCC9C85ebE478A1C0b69FCbb9"), // Tether USD
		ARB: common.HexToAddress("0x912CE59144191C1204E64559FE8253a0e49E6548"), // Arbitrum Token
		WBTC: common.HexToAddress("0x2f2a2543B76A4166549F7aaB2e75Bef0aefC5B0f"), // Wrapped Bitcoin
		DAI: common.HexToAddress("0xDA10009cBd5D07dd0CeCc66161FC93D7c9000da1"), // Dai Stablecoin
		LINK: common.HexToAddress("0xf97f4df75117a78c1A5a0DBb814Af92458539FB4"), // ChainLink Token
		UNI: common.HexToAddress("0xFa7F8980b0f1E64A2062791cc3b0871572f1F7f0"), // Uniswap
		GMX: common.HexToAddress("0xfc5A1A6EB076a2C7aD06eD22C90d7E710E35ad0a"), // GMX
		GRT: common.HexToAddress("0x9623063377AD1B27544C965cCd7342f7EA7e88C7"), // The Graph
		// Tier 2 - DeFi Blue Chips
		USDCe: common.HexToAddress("0xFF970A61A04b1cA14834A43f5dE4533eBDDB5CC8"), // USD Coin (Bridged)
		PENDLE: common.HexToAddress("0x0c880f6761F1af8d9Aa9C466984b80DAb9a8c9e8"), // Pendle
		RDNT: common.HexToAddress("0x3082CC23568eA640225c2467653dB90e9250AaA0"), // Radiant Capital
		MAGIC: common.HexToAddress("0x539bdE0d7Dbd336b79148AA742883198BBF60342"), // Magic
		GRAIL: common.HexToAddress("0x3d9907F9a368ad0a51Be60f7Da3b97cf940982D8"), // Camelot (GRAIL)
		// Tier 3 - Additional High Volume
		AAVE: common.HexToAddress("0xba5DdD1f9d7F570dc94a51479a000E3BCE967196"), // Aave
		CRV: common.HexToAddress("0x11cDb42B0EB46D95f990BeDD4695A6e3fA034978"), // Curve
		BAL: common.HexToAddress("0x040d1EdC9569d4Bab2D15287Dc5A4F10F56a56B8"), // Balancer
		COMP: common.HexToAddress("0x354A6dA3fcde098F8389cad84b0182725c6C91dE"), // Compound
		MKR: common.HexToAddress("0x2e9a6Df78E42a30712c10a9Dc4b1C8656f8F2879"), // Maker
	}
}
// GetTriangularPaths returns common triangular arbitrage paths on Arbitrum.
// Each path lists three tokens; the cycle implicitly returns to the first.
func GetTriangularPaths() []TriangularPath {
	t := GetArbitrumTokens()
	// Small constructor keeps the path table compact and uniform.
	path := func(name string, a, b, c common.Address) TriangularPath {
		return TriangularPath{Name: name, Tokens: []common.Address{a, b, c}}
	}
	return []TriangularPath{
		path("USDC-WETH-WBTC-USDC", t.USDC, t.WETH, t.WBTC),
		path("USDC-WETH-ARB-USDC", t.USDC, t.WETH, t.ARB),
		path("WETH-USDC-USDT-WETH", t.WETH, t.USDC, t.USDT),
		path("USDC-DAI-USDT-USDC", t.USDC, t.DAI, t.USDT),
		path("WETH-ARB-GMX-WETH", t.WETH, t.ARB, t.GMX),
		path("USDC-LINK-WETH-USDC", t.USDC, t.LINK, t.WETH),
	}
}
// TriangularPath represents a triangular arbitrage path.
type TriangularPath struct {
	Name   string           // human-readable cycle label, e.g. "USDC-WETH-WBTC-USDC"
	Tokens []common.Address // the three tokens of the cycle, in trade order
}
// GetMostLiquidTokens returns the most liquid tokens for market scanning,
// in priority order (Tier 1 majors only).
func GetMostLiquidTokens() []common.Address {
	t := GetArbitrumTokens()
	return []common.Address{t.WETH, t.USDC, t.USDT, t.ARB, t.WBTC, t.DAI}
}

View File

@@ -0,0 +1,139 @@
package utils
import (
"fmt"
"strings"
"github.com/ethereum/go-ethereum/common"
"github.com/fraktal/mev-beta/internal/validation"
)
// SafeAddressConverter provides validated address conversion functions
// that run every input through the validation package before producing a
// common.Address.
type SafeAddressConverter struct {
	validator *validation.AddressValidator // performs format/corruption checks
}
// NewSafeAddressConverter creates a converter backed by a fresh
// AddressValidator.
func NewSafeAddressConverter() *SafeAddressConverter {
	converter := &SafeAddressConverter{}
	converter.validator = validation.NewAddressValidator()
	return converter
}
// SafeConversionResult contains the result of a safe address conversion.
type SafeConversionResult struct {
	Address  common.Address // parsed address; zero value unless IsValid is true
	IsValid  bool           // true only when all validation layers passed
	Error    error          // reason for rejection when IsValid is false
	Warnings []string       // non-fatal findings (e.g. moderate corruption score)
}
// SafeHexToAddress safely converts a hex string to an address with validation.
// Pipeline: normalize (trim, lowercase, ensure 0x prefix) -> length check ->
// full validation via the validation package -> corruption-score gate.
// Scores above 50 are rejected outright; scores above 10 pass with a warning.
// NOTE(review): lowercasing discards EIP-55 checksum casing before
// validation — assumes the downstream validator does not rely on checksum
// case; confirm.
func (c *SafeAddressConverter) SafeHexToAddress(hexStr string) *SafeConversionResult {
	result := &SafeConversionResult{
		Address:  common.Address{},
		IsValid:  false,
		Warnings: []string{},
	}
	// Basic input validation
	if hexStr == "" {
		result.Error = fmt.Errorf("empty hex string")
		return result
	}
	// Normalize hex string
	normalizedHex := strings.ToLower(strings.TrimSpace(hexStr))
	if !strings.HasPrefix(normalizedHex, "0x") {
		normalizedHex = "0x" + normalizedHex
	}
	// Length validation
	if len(normalizedHex) != 42 { // 0x + 40 hex chars
		result.Error = fmt.Errorf("invalid hex string length: %d, expected 42", len(normalizedHex))
		return result
	}
	// Use comprehensive address validation
	validationResult := c.validator.ValidateAddress(normalizedHex)
	if !validationResult.IsValid {
		result.Error = fmt.Errorf("address validation failed: %v", validationResult.ErrorMessages)
		return result
	}
	// Check corruption score (stricter than the validator's own threshold)
	if validationResult.CorruptionScore > 50 {
		result.Error = fmt.Errorf("high corruption score (%d), refusing conversion", validationResult.CorruptionScore)
		return result
	}
	// Add warnings for moderate corruption
	if validationResult.CorruptionScore > 10 {
		result.Warnings = append(result.Warnings, fmt.Sprintf("moderate corruption score: %d", validationResult.CorruptionScore))
	}
	// Propagate non-fatal findings from validation
	if len(validationResult.WarningMessages) > 0 {
		result.Warnings = append(result.Warnings, validationResult.WarningMessages...)
	}
	// Convert to address
	result.Address = common.HexToAddress(normalizedHex)
	result.IsValid = true
	return result
}
// SafeHexToAddressSlice converts multiple hex strings, collecting valid
// addresses and one error per rejected input (indexed by original position).
func (c *SafeAddressConverter) SafeHexToAddressSlice(hexStrings []string) ([]common.Address, []error) {
	converted := make([]common.Address, 0, len(hexStrings))
	failures := make([]error, 0)
	for i, raw := range hexStrings {
		res := c.SafeHexToAddress(raw)
		if res.IsValid {
			converted = append(converted, res.Address)
			continue
		}
		failures = append(failures, fmt.Errorf("address %d (%s): %v", i, raw, res.Error))
	}
	return converted, failures
}
// MustSafeHexToAddress converts a hex string to an address and panics on
// failure. Only use for hardcoded, known-good addresses (package init,
// tests) — never on runtime input.
func (c *SafeAddressConverter) MustSafeHexToAddress(hexStr string) common.Address {
	res := c.SafeHexToAddress(hexStr)
	if res.IsValid {
		return res.Address
	}
	panic(fmt.Sprintf("invalid address conversion: %s -> %v", hexStr, res.Error))
}
// TryHexToAddress attempts conversion; on failure it returns the zero
// address and false.
func (c *SafeAddressConverter) TryHexToAddress(hexStr string) (common.Address, bool) {
	res := c.SafeHexToAddress(hexStr)
	return res.Address, res.IsValid
}
// globalConverter backs the package-level convenience wrappers below so
// callers don't have to construct a SafeAddressConverter themselves.
var globalConverter = NewSafeAddressConverter()

// SafeHexToAddress is a global convenience wrapper around
// SafeAddressConverter.SafeHexToAddress.
func SafeHexToAddress(hexStr string) *SafeConversionResult {
	return globalConverter.SafeHexToAddress(hexStr)
}

// TryHexToAddress is a global convenience wrapper that returns the zero
// address and false on failure.
func TryHexToAddress(hexStr string) (common.Address, bool) {
	return globalConverter.TryHexToAddress(hexStr)
}

// MustSafeHexToAddress is a global convenience wrapper for known-good
// addresses; it panics on invalid input.
func MustSafeHexToAddress(hexStr string) common.Address {
	return globalConverter.MustSafeHexToAddress(hexStr)
}

View File

@@ -0,0 +1,38 @@
package utils
import (
"math"
"math/big"
"time"
)
// FormatWeiToEther formats a wei amount to ether
func FormatWeiToEther(wei *big.Int) *big.Float {
ether := new(big.Float).SetInt(wei)
ether.Quo(ether, big.NewFloat(1e18))
return ether
}
// FormatTime formats a Unix timestamp (seconds) as a readable local-time
// string. Values that would overflow int64 yield an error string instead.
func FormatTime(timestamp uint64) string {
	const layout = "2006-01-02 15:04:05"
	if timestamp > math.MaxInt64 {
		return "Invalid timestamp: exceeds maximum value"
	}
	t := time.Unix(int64(timestamp), 0)
	return t.Format(layout)
}
// Min returns the smaller of two integers.
func Min(a, b int) int {
	if b < a {
		return b
	}
	return a
}
// Max returns the larger of two integers.
func Max(a, b int) int {
	if b > a {
		return b
	}
	return a
}

View File

@@ -0,0 +1,704 @@
// Package validation provides comprehensive Ethereum address validation and corruption detection.
// This package is critical for the MEV bot's security and reliability, preventing costly errors
// from malformed or corrupted addresses that could cause transaction failures or security issues.
//
// Key features:
// - Multi-layer address validation (format, length, corruption detection)
// - Contract type classification and prevention of ERC-20/pool confusion
// - Corruption scoring system to identify suspicious addresses
// - Known contract registry for instant validation of major protocols
package validation
import (
"fmt"
"regexp"
"strings"
"github.com/ethereum/go-ethereum/common"
)
// AddressValidator provides comprehensive Ethereum address validation with
// corruption detection and contract type classification, to prevent
// malformed addresses from reaching contract calls in MEV operations.
//
// Validation layers:
//  1. Basic format validation (hex format, length, prefix)
//  2. Corruption pattern detection (zero runs, suspicious patterns)
//  3. Contract type classification to prevent misuse
//  4. Known contract registry for instant validation
//
// NOTE(review): the maps are populated via InitializeKnownContracts and
// read afterwards without synchronization — assumes init-once-then-read
// usage; confirm before mutating concurrently.
type AddressValidator struct {
	// Known corrupted patterns that should be immediately rejected.
	// Kept in raw string form for reporting; derived from observed
	// production corruption incidents.
	corruptedPatterns []string
	// Precompiled regex equivalents (parallel to corruptedPatterns) for
	// efficient matching.
	corruptedPatternRegex []*regexp.Regexp
	// Registry of known contract addresses and their verified types,
	// enabling instant classification without RPC calls.
	knownContracts map[common.Address]ContractType
	// Enhanced known contract registry for detailed validation.
	knownContractsRegistry *KnownContractRegistry
}
// ContractType classifies an Ethereum contract so callers can refuse to
// use one kind of contract (e.g. an ERC-20 token) where another kind
// (e.g. a Uniswap pool) is required — the ERC-20/pool mix-up previously
// caused heavy log spam and failed transactions.
type ContractType int

const (
	// ContractTypeUnknown is the default for addresses not present in the
	// known-contracts registry.
	ContractTypeUnknown ContractType = iota
	// ContractTypeERC20Token marks a standard ERC-20 token contract; never
	// valid in pool-specific operations.
	ContractTypeERC20Token
	// ContractTypeUniswapV2Pool marks a Uniswap V2 compatible pool
	// (token0()/token1()/getReserves()).
	ContractTypeUniswapV2Pool
	// ContractTypeUniswapV3Pool marks a Uniswap V3 compatible pool
	// (token0()/token1()/slot0()).
	ContractTypeUniswapV3Pool
	// ContractTypeRouter marks a DEX router contract; not a token or pool.
	ContractTypeRouter
	// ContractTypeFactory marks a pool factory contract; creates pools but
	// is not itself a pool.
	ContractTypeFactory
)

// String returns a human-readable representation of the contract type.
func (ct ContractType) String() string {
	switch ct {
	case ContractTypeERC20Token:
		return "ERC20_TOKEN"
	case ContractTypeUniswapV2Pool:
		return "UNISWAP_V2_POOL"
	case ContractTypeUniswapV3Pool:
		return "UNISWAP_V3_POOL"
	case ContractTypeRouter:
		return "ROUTER"
	case ContractTypeFactory:
		return "FACTORY"
	}
	// ContractTypeUnknown and any out-of-range value.
	return "UNKNOWN"
}
// ValidationError represents a structured validation failure for an address.
type ValidationError struct {
	Code    string                 // short machine-readable failure code
	Message string                 // human-readable description
	Context map[string]interface{} // optional extra details
}

// Error implements the error interface. A nil receiver yields the empty
// string; the code, when present, prefixes the message as "CODE: message".
func (e *ValidationError) Error() string {
	if e == nil {
		return ""
	}
	if e.Code == "" {
		return e.Message
	}
	return fmt.Sprintf("%s: %s", e.Code, e.Message)
}
// AddressValidationResult contains comprehensive validation results and
// metadata for an Ethereum address.
type AddressValidationResult struct {
	// IsValid indicates whether the address passed all validation checks.
	// A false value means the address should not be used in transactions.
	IsValid bool
	// Address contains the parsed Ethereum address (only meaningful when
	// IsValid is true).
	Address common.Address
	// ContractType indicates the classification of the contract (if known).
	ContractType ContractType
	// ErrorMessages contains detailed descriptions of validation failures;
	// any entry here prevents the address from being used.
	ErrorMessages []string
	// WarningMessages contains non-critical issues or concerns. These don't
	// prevent usage but should be logged for monitoring.
	WarningMessages []string
	// CorruptionScore is a 0-100 metric indicating likelihood of corruption
	// (0=clean, 100=definitely corrupted). ValidateAddress marks a result
	// valid only when the score is below 30, i.e. scores of 30 or higher
	// are rejected.
	CorruptionScore int
}
// NewAddressValidator builds a validator with the corruption patterns
// compiled once up front (so per-address matching is cheap) and an empty
// known-contracts map; call InitializeKnownContracts to populate it.
func NewAddressValidator() *AddressValidator {
	patterns := []string{
		// Zero runs of (almost) full address length — clear corruption.
		"0000000000000000000000000000000000000000", // All zeros
		"000000000000000000000000000000000000000", // Missing one zero
		"00000000000000000000000000000000000000000", // Extra zero
		// Trailing zero runs indicating truncated data.
		"00000000000000000000000000$",
		"000000000000000000000000$",
		"0000000000000000000000$",
		// A non-hex character immediately after the 0x prefix.
		"^0x[^0-9a-fA-F]",
	}
	compiled := make([]*regexp.Regexp, len(patterns))
	for i, p := range patterns {
		compiled[i] = regexp.MustCompile(p)
	}
	return &AddressValidator{
		corruptedPatterns:      patterns,
		corruptedPatternRegex:  compiled,
		knownContracts:         map[common.Address]ContractType{},
		knownContractsRegistry: NewKnownContractRegistry(),
	}
}
// InitializeKnownContracts populates the validator with known Arbitrum
// contracts (tokens, routers, factories, and high-volume pools).
// NOTE(review): addresses are hard-coded — verify any change against a
// block explorer before editing.
func (av *AddressValidator) InitializeKnownContracts() {
	// Known Arbitrum tokens
	av.knownContracts[common.HexToAddress("0xFF970A61A04b1cA14834A43f5dE4533eBDDB5CC8")] = ContractTypeERC20Token // USDC
	av.knownContracts[common.HexToAddress("0x82aF49447D8a07e3bd95BD0d56f35241523fBab1")] = ContractTypeERC20Token // WETH
	av.knownContracts[common.HexToAddress("0xFd086bC7CD5C481DCC9C85ebE478A1C0b69FCbb9")] = ContractTypeERC20Token // USDT
	av.knownContracts[common.HexToAddress("0x2f2a2543B76A4166549F7aaB2e75Bef0aefC5B0f")] = ContractTypeERC20Token // WBTC
	av.knownContracts[common.HexToAddress("0x912CE59144191C1204E64559FE8253a0e49E6548")] = ContractTypeERC20Token // ARB
	// Known Arbitrum routers
	av.knownContracts[common.HexToAddress("0xE592427A0AEce92De3Edee1F18E0157C05861564")] = ContractTypeRouter // Uniswap V3 Router
	av.knownContracts[common.HexToAddress("0x4752ba5dbc23f44d87826276bf6fd6b1c372ad24")] = ContractTypeRouter // Uniswap V2 Router02
	av.knownContracts[common.HexToAddress("0xA51afAFe0263b40EdaEf0Df8781eA9aa03E381a3")] = ContractTypeRouter // Universal Router
	av.knownContracts[common.HexToAddress("0xC36442b4a4522E871399CD717aBDD847Ab11FE88")] = ContractTypeRouter // Position Manager
	// Known Arbitrum factories
	av.knownContracts[common.HexToAddress("0x1F98431c8aD98523631AE4a59f267346ea31F984")] = ContractTypeFactory // Uniswap V3 Factory
	av.knownContracts[common.HexToAddress("0xf1D7CC64Fb4452F05c498126312eBE29f30Fbcf9")] = ContractTypeFactory // Uniswap V2 Factory
	// Known high-volume pools
	av.knownContracts[common.HexToAddress("0xC6962004f452bE9203591991D15f6b388e09E8D0")] = ContractTypeUniswapV3Pool // USDC/WETH 0.05%
	av.knownContracts[common.HexToAddress("0x17c14D2c404D167802b16C450d3c99F88F2c4F4d")] = ContractTypeUniswapV3Pool // USDC/WETH 0.3%
	av.knownContracts[common.HexToAddress("0x2f5e87C9312fa29aed5c179E456625D79015299c")] = ContractTypeUniswapV3Pool // WBTC/WETH 0.05%
}
// ValidateAddress performs comprehensive validation of an Ethereum address
// string. This is the primary validation function; it applies the layers
// below in order, returning early (with IsValid=false) on the first hard
// failure:
//
//  1. Basic format validation (0x prefix, hex digits only)
//  2. Length validation (exactly 42 characters)
//  3. Corruption pattern detection (precompiled regexes)
//  4. go-ethereum address parse and zero-address check
//  5. Contract type classification from the known-contracts registry
//  6. Heuristic corruption scoring; the result is valid only if the
//     accumulated score stays below 30
//
// Parameters:
//   - addressStr: The address string to validate (should include 0x prefix)
//
// Returns:
//   - *AddressValidationResult: Comprehensive validation results
func (av *AddressValidator) ValidateAddress(addressStr string) *AddressValidationResult {
	// Initialize the result structure with safe defaults.
	result := &AddressValidationResult{
		IsValid:         false, // Default to invalid until all checks pass
		ErrorMessages:   make([]string, 0),
		WarningMessages: make([]string, 0),
		CorruptionScore: 0, // Start with zero corruption score
	}
	// Layer 1: basic format validation.
	if !av.isValidHexFormat(addressStr) {
		result.ErrorMessages = append(result.ErrorMessages, "invalid hex format")
		result.CorruptionScore += 50
		return result
	}
	// Layer 2: length validation.
	if !av.isValidLength(addressStr) {
		result.ErrorMessages = append(result.ErrorMessages, "invalid address length")
		result.CorruptionScore += 50
		return result
	}
	// Layer 3: corruption pattern detection.
	corruptionDetected, patterns := av.detectCorruption(addressStr)
	if corruptionDetected {
		result.ErrorMessages = append(result.ErrorMessages, fmt.Sprintf("corruption detected: %v", patterns))
		result.CorruptionScore += 70
		return result
	}
	// Layer 4: convert to common.Address for further validation.
	if !common.IsHexAddress(addressStr) {
		result.ErrorMessages = append(result.ErrorMessages, "not a valid Ethereum address")
		result.CorruptionScore += 30
		return result
	}
	address := common.HexToAddress(addressStr)
	result.Address = address
	// Reject the zero address outright.
	if address == (common.Address{}) {
		result.ErrorMessages = append(result.ErrorMessages, "zero address")
		result.CorruptionScore += 40
		return result
	}
	// Layer 5: classify via the known-contracts registry; unknown contracts
	// pass but carry a warning for monitoring.
	if contractType, exists := av.knownContracts[address]; exists {
		result.ContractType = contractType
	} else {
		result.ContractType = ContractTypeUnknown
		result.WarningMessages = append(result.WarningMessages, "unknown contract type")
	}
	// Layer 6: heuristic pattern-based corruption scoring.
	result.CorruptionScore += av.calculateCorruptionScore(addressStr)
	// Valid only when the accumulated score is below 30 (>=30 is rejected).
	if result.CorruptionScore < 30 {
		result.IsValid = true
	}
	return result
}
// isValidHexFormat reports whether the string is "0x" (or "0X") followed by
// at least one character, all of which are hex digits.
func (av *AddressValidator) isValidHexFormat(addressStr string) bool {
	// Shortest acceptable input is "0x" plus one hex digit.
	if len(addressStr) < 3 {
		return false
	}
	if prefix := addressStr[:2]; prefix != "0x" && prefix != "0X" {
		return false
	}
	for _, c := range []byte(addressStr[2:]) {
		isDigit := c >= '0' && c <= '9'
		isLower := c >= 'a' && c <= 'f'
		isUpper := c >= 'A' && c <= 'F'
		if !isDigit && !isLower && !isUpper {
			return false
		}
	}
	return true
}
// isValidLength reports whether the address string has the canonical
// Ethereum address length.
func (av *AddressValidator) isValidLength(addressStr string) bool {
	const expected = 42 // "0x" prefix + 40 hex digits
	return len(addressStr) == expected
}
// detectCorruption matches the address against every precompiled corruption
// regex and returns whether any matched, plus the raw patterns that did.
func (av *AddressValidator) detectCorruption(addressStr string) (bool, []string) {
	var matched []string
	for i := range av.corruptedPatternRegex {
		if av.corruptedPatternRegex[i].MatchString(addressStr) {
			// Report the original pattern string, not the compiled regex.
			matched = append(matched, av.corruptedPatterns[i])
		}
	}
	return len(matched) > 0, matched
}
// calculateCorruptionScore calculates a heuristic score indicating the
// likelihood that an address has been corrupted or malformed, based on
// patterns observed in production corruption incidents.
//
// Scoring factors (as implemented below):
//   - More than 10 trailing zeros (truncation): +1 per trailing zero
//   - More than 8 zeros anywhere after the first non-zero digit: +half
//     that zero count (integer division)
//   - Repetitive patterns (generation errors): +10
//   - "0000" prefix with >70% overall zero density: +30
//   - "000000" prefix: +20
//
// Parameters:
//   - addressStr: The address string to analyze (with 0x prefix)
//
// Returns:
//   - int: Corruption score (0=clean; higher is more suspicious)
func (av *AddressValidator) calculateCorruptionScore(addressStr string) int {
	score := 0
	// Extract the hex part (remove 0x prefix).
	hexPart := addressStr[2:]
	// Count trailing zeros (sign of truncation).
	trailingZeros := 0
	for i := len(hexPart) - 1; i >= 0; i-- {
		if hexPart[i] == '0' {
			trailingZeros++
		} else {
			break
		}
	}
	// More than 10 trailing zeros is suspicious.
	if trailingZeros > 10 {
		score += trailingZeros
	}
	// Despite the name, leadingZeros counts EVERY zero that appears after
	// the first non-zero digit (interior and trailing alike), so a long
	// trailing run also contributes here on top of the check above.
	leadingZeros := 0
	foundNonZero := false
	for _, char := range hexPart {
		if char != '0' {
			foundNonZero = true
		} else if foundNonZero {
			leadingZeros++
		}
	}
	// Large blocks of zeros past the first non-zero digit are suspicious.
	if leadingZeros > 8 {
		score += leadingZeros / 2
	}
	// Check for repetitive patterns.
	if av.hasRepetitivePattern(hexPart) {
		score += 10
	}
	// Overall zero density check for leading-zero patterns (common corruption).
	zeroCount := strings.Count(hexPart, "0")
	if strings.HasPrefix(hexPart, "0000") && float64(zeroCount)/float64(len(hexPart)) > 0.7 {
		score += 30
	}
	// Leading zero prefix (common corruption pattern from truncated data).
	if strings.HasPrefix(hexPart, "000000") {
		score += 20
	}
	return score
}
// hasRepetitivePattern reports whether the hex string (without 0x prefix)
// contains repetitive runs that are extremely rare in legitimate Ethereum
// addresses and typically indicate corruption or artificial generation:
// long same-character sequences, or the whole string being one repeated
// character.
func (av *AddressValidator) hasRepetitivePattern(hexStr string) bool {
	// Same-character runs that effectively never occur in real addresses.
	suspicious := []string{
		"000000000000", "ffffffffffff", "aaaaaaaaaaaa", "bbbbbbbbbbbb",
		"1111111111", "2222222222", "3333333333", "4444444444", "5555555555",
		"6666666666", "7777777777", "8888888888", "9999999999",
	}
	for _, run := range suspicious {
		if strings.Contains(hexStr, run) {
			return true
		}
	}
	// A sufficiently long string made of a single repeated byte is also
	// treated as corrupt.
	if len(hexStr) >= 10 && strings.Count(hexStr, hexStr[0:1]) == len(hexStr) {
		return true
	}
	return false
}
// IsValidPoolAddress reports whether an address is acceptable for pool
// operations: it must pass full validation and must not be classified as a
// token, router, or factory.
func (av *AddressValidator) IsValidPoolAddress(address common.Address) bool {
	res := av.ValidateAddress(address.Hex())
	if !res.IsValid {
		return false
	}
	switch res.ContractType {
	case ContractTypeUniswapV2Pool, ContractTypeUniswapV3Pool:
		return true
	case ContractTypeUnknown:
		// Tolerate unclassified contracts only when they look clean.
		return res.CorruptionScore < 20
	default:
		// Tokens, routers, and factories are never valid pools.
		return false
	}
}
// IsValidTokenAddress reports whether an address is acceptable for token
// operations: known ERC-20 tokens pass; pools, routers, and factories are
// rejected; unclassified contracts pass only under a strict corruption
// threshold.
func (av *AddressValidator) IsValidTokenAddress(address common.Address) bool {
	res := av.ValidateAddress(address.Hex())
	if !res.IsValid {
		return false
	}
	switch res.ContractType {
	case ContractTypeERC20Token:
		return true
	case ContractTypeUnknown:
		// Stricter threshold than the pool check above.
		return res.CorruptionScore < 15
	default:
		// Pools, routers, and factories are never valid tokens.
		return false
	}
}
// GetContractType returns the registered classification for an address, or
// ContractTypeUnknown when the address is not in the registry.
func (av *AddressValidator) GetContractType(address common.Address) ContractType {
	if t, ok := av.knownContracts[address]; ok {
		return t
	}
	return ContractTypeUnknown
}
// ValidateContractTypeConsistency validates that the given token and pool
// address lists do not conflict: no address appears in both lists, every
// token address passes token validation (and is not registered as a pool),
// and every pool address passes pool validation (and is not registered as
// a token). Returns the first violation found, or nil.
func (av *AddressValidator) ValidateContractTypeConsistency(tokenAddresses []common.Address, poolAddresses []common.Address) error {
	// Perf: build a set of pool addresses so the overlap check runs in
	// O(tokens+pools) instead of the previous O(tokens*pools) nested scan.
	// Tokens are still scanned in order, so the first conflicting token
	// reported is the same as before.
	poolSet := make(map[common.Address]struct{}, len(poolAddresses))
	for _, pool := range poolAddresses {
		poolSet[pool] = struct{}{}
	}
	// CRITICAL: ensure no address appears in both token and pool lists.
	for _, token := range tokenAddresses {
		if _, dup := poolSet[token]; dup {
			return fmt.Errorf("address %s cannot be both a token and a pool", token.Hex())
		}
	}
	// CRITICAL: validate each token address is actually a token.
	for _, token := range tokenAddresses {
		if !av.IsValidTokenAddress(token) {
			return fmt.Errorf("address %s is not a valid token address", token.Hex())
		}
		// Additional check: ensure it's not marked as a pool in known contracts.
		if contractType := av.GetContractType(token); contractType == ContractTypeUniswapV2Pool || contractType == ContractTypeUniswapV3Pool {
			return fmt.Errorf("address %s is marked as a pool but being used as a token", token.Hex())
		}
	}
	// CRITICAL: validate each pool address is actually a pool.
	for _, pool := range poolAddresses {
		if !av.IsValidPoolAddress(pool) {
			return fmt.Errorf("address %s is not a valid pool address", pool.Hex())
		}
		// Additional check: ensure it's not marked as a token in known contracts.
		if contractType := av.GetContractType(pool); contractType == ContractTypeERC20Token {
			return fmt.Errorf("address %s is marked as a token but being used as a pool", pool.Hex())
		}
	}
	return nil
}
// PreventERC20PoolConfusion is a safety gate that stops ERC-20 token
// contracts from being used as pools (and vice versa) — the root cause of
// the 535K+ log spam incident in production.
//
// Decision logic:
//   - If the address is in the known-contracts registry, its registered
//     type must match expectedType exactly.
//   - Otherwise the address must pass full validation and carry a
//     corruption score of 25 or less.
//
// Parameters:
//   - address: the contract address to validate
//   - expectedType: the contract type expected by the calling code
//
// Returns nil when the address is safe to use as expectedType; otherwise an
// error describing the mismatch.
func (av *AddressValidator) PreventERC20PoolConfusion(address common.Address, expectedType ContractType) error {
	// Prior knowledge wins over heuristics.
	knownType := av.GetContractType(address)
	if knownType != ContractTypeUnknown {
		if knownType == expectedType {
			return nil
		}
		return fmt.Errorf("contract type mismatch for %s: expected %s but known as %s",
			address.Hex(), contractTypeToString(expectedType), contractTypeToString(knownType))
	}
	// Unknown contract: fall back to format/corruption validation.
	result := av.ValidateAddress(address.Hex())
	if !result.IsValid {
		return fmt.Errorf("invalid address %s: %v", address.Hex(), result.ErrorMessages)
	}
	// A high corruption score suggests potential misclassification.
	if result.CorruptionScore > 25 {
		return fmt.Errorf("high corruption score (%d) for address %s, refusing to use as %s",
			result.CorruptionScore, address.Hex(), contractTypeToString(expectedType))
	}
	return nil
}
// contractTypeToString converts ContractType to string representation.
// Unrecognized values map to "Unknown".
func contractTypeToString(ct ContractType) string {
	name := "Unknown"
	switch ct {
	case ContractTypeERC20Token:
		name = "ERC-20 Token"
	case ContractTypeUniswapV2Pool:
		name = "Uniswap V2 Pool"
	case ContractTypeUniswapV3Pool:
		name = "Uniswap V3 Pool"
	case ContractTypeRouter:
		name = "Router"
	case ContractTypeFactory:
		name = "Factory"
	}
	return name
}
// IsKnownContract reports whether the validator holds prior knowledge about
// the given contract address.
func (av *AddressValidator) IsKnownContract(address common.Address) bool {
	if _, ok := av.knownContracts[address]; ok {
		return true
	}
	return false
}
// GetValidationStats returns statistics about validation results: the number
// of registered known contracts and corruption patterns.
func (av *AddressValidator) GetValidationStats() map[string]interface{} {
	stats := make(map[string]interface{}, 2)
	stats["known_contracts"] = len(av.knownContracts)
	stats["corruption_patterns"] = len(av.corruptedPatterns)
	return stats
}
// SanitizeAddress attempts to clean up a potentially corrupted address string.
// It trims whitespace, repairs a missing 0x prefix when the remainder is a
// plausible 40-digit hex body, lowercases the result, and re-validates it.
// Returns the cleaned address or an error when sanitization is impossible.
func (av *AddressValidator) SanitizeAddress(addressStr string) (string, error) {
	candidate := strings.TrimSpace(addressStr)
	prefixed := strings.HasPrefix(candidate, "0x") || strings.HasPrefix(candidate, "0X")
	if !prefixed {
		// A bare 40-character hex body can be repaired by prepending the
		// prefix; anything else is unrecoverable.
		if len(candidate) != 40 || !av.isValidHexFormat("0x"+candidate) {
			return "", fmt.Errorf("cannot sanitize address without 0x prefix")
		}
		candidate = "0x" + candidate
	}
	// Normalize casing so downstream comparisons are uniform.
	candidate = strings.ToLower(candidate)
	if result := av.ValidateAddress(candidate); !result.IsValid {
		return "", fmt.Errorf("sanitized address is still invalid: %v", result.ErrorMessages)
	}
	return candidate, nil
}
// ValidatePoolOperation validates if a pool-specific operation should be
// allowed on an address. This prevents the critical error where ERC-20
// tokens are treated as pools.
func (av *AddressValidator) ValidatePoolOperation(address common.Address, operation string) error {
	// The known-contract registry gets first say: pool calls against a known
	// ERC-20 token are rejected outright.
	if av.knownContractsRegistry != nil {
		if err := av.knownContractsRegistry.ValidatePoolCall(address, operation); err != nil {
			return err
		}
	}
	addressStr := address.Hex()
	// Zero address is a common corruption artifact.
	if address == (common.Address{}) {
		return fmt.Errorf("pool operation '%s' attempted on zero address (likely corruption)", operation)
	}
	// Excessive zeros in the hex form also point to corrupted extraction.
	if strings.Count(addressStr, "0") > 30 {
		return fmt.Errorf("pool operation '%s' attempted on suspicious address %s (excessive zeros)", operation, addressStr)
	}
	// Fall back to the general validator.
	result := av.ValidateAddress(addressStr)
	switch {
	case !result.IsValid:
		return fmt.Errorf("pool operation '%s' attempted on invalid address %s: %v", operation, addressStr, result.ErrorMessages)
	case result.CorruptionScore > 25:
		return fmt.Errorf("pool operation '%s' blocked due to corruption score %d on address %s", operation, result.CorruptionScore, addressStr)
	}
	return nil
}
// GetDetailedAddressAnalysis provides comprehensive analysis of an address,
// combining validator output, registry knowledge (when available), and raw
// address characteristics in one map.
func (av *AddressValidator) GetDetailedAddressAnalysis(address common.Address) map[string]interface{} {
	hexAddr := address.Hex()
	result := av.ValidateAddress(hexAddr)
	report := map[string]interface{}{
		// Basic validation output.
		"is_valid":         result.IsValid,
		"corruption_score": result.CorruptionScore,
		"contract_type":    result.ContractType.String(),
		"error_messages":   result.ErrorMessages,
		// Raw address characteristics.
		"is_zero_address": address == (common.Address{}),
		"zero_count":      strings.Count(hexAddr, "0"),
		"hex_length":      len(hexAddr),
	}
	// Enrich with known-contract information when a registry is attached.
	if reg := av.knownContractsRegistry; reg != nil {
		contractType, name := reg.GetContractType(address)
		report["known_contract_type"] = contractType.String()
		report["known_contract_name"] = name
		isERC20, tokenName := reg.IsKnownERC20(address)
		report["is_known_erc20"] = isERC20
		if isERC20 {
			report["token_name"] = tokenName
		}
		report["corruption_pattern"] = reg.GetCorruptionPattern(address)
	}
	return report
}

View File

@@ -0,0 +1,351 @@
package validation
import (
	"strings"
	"testing"
	"time"

	"github.com/ethereum/go-ethereum/common"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)
// TestAddressValidator_ValidateAddress is a table-driven test covering the
// main validation outcomes: well-formed token addresses, known corruption
// patterns, bad lengths, bad hex characters, and checksum handling.
func TestAddressValidator_ValidateAddress(t *testing.T) {
	validator := NewAddressValidator()
	tests := []struct {
		name                string
		address             string
		expectedValid       bool
		expectedScore       int
		expectedType        ContractType
		shouldContainErrors []string
	}{
		{
			name:          "Valid WETH address",
			address:       "0x82aF49447D8a07e3bd95BD0d56f35241523fBab1",
			expectedValid: true,
			expectedScore: 0,
			expectedType:  ContractTypeUnknown, // Will be unknown without RPC
		},
		{
			name:          "Valid USDC address",
			address:       "0xaf88d065e77c8cC2239327C5EDb3A432268e5831",
			expectedValid: true,
			expectedScore: 0,
			expectedType:  ContractTypeUnknown,
		},
		{
			name:                "Critical corruption - TOKEN_0x000000 pattern",
			address:             "0x0000000300000000000000000000000000000000",
			expectedValid:       false,
			expectedScore:       70, // Detected by corruption patterns
			shouldContainErrors: []string{"corruption detected"},
		},
		{
			name:                "High corruption - mixed zero pattern",
			address:             "0x0000001200000000000000000000000000000000",
			expectedValid:       false,
			expectedScore:       70,
			shouldContainErrors: []string{"corruption detected"},
		},
		{
			// 41 hex digits, so length validation fires before pattern checks.
			name:                "Medium corruption - trailing zeros",
			address:             "0x123456780000000000000000000000000000000",
			expectedValid:       false,
			expectedScore:       50,
			shouldContainErrors: []string{"invalid address length"},
		},
		{
			name:                "Low corruption - some zeros",
			address:             "0x1234567800000000000000000000000000000001",
			expectedValid:       true, // Valid format with moderate corruption
			expectedScore:       25,
			shouldContainErrors: []string{}, // No errors for valid format
		},
		{
			name:                "Invalid length - too short",
			address:             "0x123456",
			expectedValid:       false,
			expectedScore:       50,
			shouldContainErrors: []string{"invalid address length"},
		},
		{
			name:                "Invalid length - too long",
			address:             "0x82aF49447D8a07e3bd95BD0d56f35241523fBab12345",
			expectedValid:       false,
			expectedScore:       50,
			shouldContainErrors: []string{"invalid address length"},
		},
		{
			name:                "Invalid hex characters",
			address:             "0x82aF49447D8a07e3bd95BD0d56f35241523fBaZ1",
			expectedValid:       false,
			expectedScore:       50,
			shouldContainErrors: []string{"invalid hex format"},
		},
		{
			name:                "Missing 0x prefix",
			address:             "82aF49447D8a07e3bd95BD0d56f35241523fBab1",
			expectedValid:       false,
			expectedScore:       50,
			shouldContainErrors: []string{"invalid hex format"},
		},
		{
			name:                "All zeros address",
			address:             "0x0000000000000000000000000000000000000000",
			expectedValid:       false,
			expectedScore:       70, // Detected by corruption patterns
			shouldContainErrors: []string{"corruption detected"},
		},
		{
			name:                "Invalid checksum",
			address:             "0x82af49447d8a07e3bd95bd0d56f35241523fbab1", // lowercase
			expectedValid:       true,                                         // Checksum validation not enforced in current implementation
			expectedScore:       0,
			shouldContainErrors: []string{}, // No errors for valid format
		},
		{
			name:          "Valid checksummed address",
			address:       common.HexToAddress("0x82aF49447D8a07e3bd95BD0d56f35241523fBab1").Hex(),
			expectedValid: true,
			expectedScore: 0,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			result := validator.ValidateAddress(tt.address)
			assert.Equal(t, tt.expectedValid, result.IsValid, "IsValid mismatch")
			assert.Equal(t, tt.expectedScore, result.CorruptionScore, "CorruptionScore mismatch")
			// ContractType is only pinned when a case expects a specific type;
			// without an RPC backend most lookups resolve to Unknown.
			if tt.expectedType != ContractTypeUnknown {
				assert.Equal(t, tt.expectedType, result.ContractType, "ContractType mismatch")
			}
			// Substring matching (via the contains helper) keeps the cases
			// robust against small wording changes in error messages.
			for _, expectedError := range tt.shouldContainErrors {
				found := false
				for _, errMsg := range result.ErrorMessages {
					if contains(errMsg, expectedError) {
						found = true
						break
					}
				}
				assert.True(t, found, "Expected error message containing '%s' not found in %v", expectedError, result.ErrorMessages)
			}
			t.Logf("Address: %s, Valid: %v, Score: %d, Errors: %v",
				tt.address, result.IsValid, result.CorruptionScore, result.ErrorMessages)
		})
	}
}
// TestAddressValidator_CorruptionPatterns exercises the corruption scorer on
// representative corrupted addresses. Scores are asserted as a lower bound
// with a 10-point tolerance band rather than exact equality.
func TestAddressValidator_CorruptionPatterns(t *testing.T) {
	validator := NewAddressValidator()
	corruptionTests := []struct {
		name          string
		address       string
		expectedScore int
		description   string
	}{
		{
			name:          "TOKEN_0x000000 exact pattern",
			address:       "0x0000000300000000000000000000000000000000",
			expectedScore: 70, // Caught by pattern detection
			description:   "Exact TOKEN_0x000000 corruption pattern",
		},
		{
			name:          "Similar corruption pattern",
			address:       "0x0000000100000000000000000000000000000000",
			expectedScore: 70, // Caught by pattern detection
			description:   "Similar zero-heavy corruption",
		},
		{
			name:          "Partial corruption",
			address:       "0x1234000000000000000000000000000000000000",
			expectedScore: 70, // Caught by pattern detection
			description:   "Partial zero corruption",
		},
		{
			// 39 hex digits: length validation fires, not pattern detection.
			name:          "Trailing corruption",
			address:       "0x123456789abcdef000000000000000000000000",
			expectedScore: 50, // Invalid length
			description:   "Trailing zero corruption",
		},
		{
			name:          "Valid but zero-heavy",
			address:       "0x000000000000000000000000000000000000dead",
			expectedScore: 10, // Valid format, minimal corruption
			description:   "Valid format but suspicious zeros",
		},
	}
	for _, tt := range corruptionTests {
		t.Run(tt.name, func(t *testing.T) {
			result := validator.ValidateAddress(tt.address)
			assert.GreaterOrEqual(t, result.CorruptionScore, tt.expectedScore-10,
				"Corruption score should be at least %d-10 for %s", tt.expectedScore, tt.description)
			assert.LessOrEqual(t, result.CorruptionScore, 100,
				"Corruption score should not exceed 100")
			// Check validity based on expected behavior rather than fixed threshold
			if tt.expectedScore >= 50 && tt.address != "0x000000000000000000000000000000000000dead" {
				assert.False(t, result.IsValid, "High corruption addresses should be invalid")
			}
			t.Logf("%s: Score=%d, Valid=%v, Description=%s",
				tt.address, result.CorruptionScore, result.IsValid, tt.description)
		})
	}
}
// TestAddressValidator_EdgeCases feeds deliberately hostile or malformed
// inputs to ValidateAddress. The primary requirement is "never panic"; the
// secondary one is the expected validity verdict for each input.
func TestAddressValidator_EdgeCases(t *testing.T) {
	validator := NewAddressValidator()
	edgeCases := []struct {
		name          string
		address       string
		shouldBeValid bool
	}{
		{"Empty string", "", false},
		{"Only 0x", "0x", false},
		{"Just prefix", "0x0", false},
		{"Uppercase hex", "0x82AF49447D8A07E3BD95BD0D56F35241523FBAB1", true},      // Valid - case doesn't matter
		{"Mixed case invalid", "0x82aF49447D8a07e3BD95BD0d56f35241523fBaB1", true}, // Valid - case doesn't matter
		{"Unicode characters", "0x82aF49447D8a07e3bd95BD0d56f35241523fBaβ1", false},
		{"SQL injection attempt", "0x'; DROP TABLE addresses; --", false},
		// 1000 NUL bytes after the prefix; exercises the length guard on
		// pathologically long input.
		{"Buffer overflow attempt", "0x" + string(make([]byte, 1000)), false},
	}
	for _, tt := range edgeCases {
		t.Run(tt.name, func(t *testing.T) {
			// Should not panic
			result := validator.ValidateAddress(tt.address)
			// Check validity based on expectation
			assert.Equal(t, tt.shouldBeValid, result.IsValid, "Edge case validity mismatch: %s", tt.address)
			if !tt.shouldBeValid {
				assert.Greater(t, result.CorruptionScore, 0, "Invalid edge case should have corruption score > 0")
				assert.NotEmpty(t, result.ErrorMessages, "Invalid edge case should have error messages")
			} else {
				// Valid addresses can have low corruption scores
				assert.Empty(t, result.ErrorMessages, "Valid edge case should not have error messages")
			}
			t.Logf("Edge case '%s': Valid=%v, Score=%d, Errors=%v",
				tt.name, result.IsValid, result.CorruptionScore, result.ErrorMessages)
		})
	}
}
// TestAddressValidator_Performance smoke-tests validation throughput on a mix
// of valid and corrupted addresses. It is a coarse regression guard, not a
// real benchmark — use `go test -bench` for actual measurements.
func TestAddressValidator_Performance(t *testing.T) {
	validator := NewAddressValidator()
	// Test addresses for performance benchmark
	addresses := []string{
		"0x82aF49447D8a07e3bd95BD0d56f35241523fBab1", // Valid WETH
		"0x0000000300000000000000000000000000000000", // Corrupted
		"0xaf88d065e77c8cC2239327C5EDb3A432268e5831", // Valid USDC
		"0x0000000000000000000000000000000000000000", // Zero address
		"0x123456789abcdef0000000000000000000000000", // Partial corruption
	}
	// Warm up so timing reflects steady state rather than first-call costs.
	for _, addr := range addresses {
		validator.ValidateAddress(addr)
	}
	// Benchmark validation performance
	const iterations = 10000
	start := time.Now()
	for i := 0; i < iterations; i++ {
		addr := addresses[i%len(addresses)]
		result := validator.ValidateAddress(addr)
		require.NotNil(t, result)
	}
	duration := time.Since(start)
	avgTime := duration / iterations
	t.Logf("Performance: %d validations in %v (avg: %v per validation)",
		iterations, duration, avgTime)
	// Should validate at least 500 addresses per second (2ms budget each).
	maxTime := time.Millisecond * 2 // 2ms per validation = 500/sec (reasonable for complex validation)
	assert.Less(t, avgTime.Nanoseconds(), maxTime.Nanoseconds(),
		"Validation should be faster than %v per address (got %v)", maxTime, avgTime)
}
// TestAddressValidator_ConcurrentAccess verifies that ValidateAddress is safe
// and deterministic when called from many goroutines at once.
//
// Fix: the previous version called testify's require/assert from worker
// goroutines. require (and anything using t.FailNow) must only run on the
// test goroutine per the testing package documentation, and assert races on
// the shared *testing.T. Workers now report problems over a buffered channel
// and the main goroutine does the asserting.
func TestAddressValidator_ConcurrentAccess(t *testing.T) {
	validator := NewAddressValidator()
	addresses := []string{
		"0x82aF49447D8a07e3bd95BD0d56f35241523fBab1",
		"0x0000000300000000000000000000000000000000",
		"0xaf88d065e77c8cC2239327C5EDb3A432268e5831",
	}
	const numGoroutines = 100
	const validationsPerGoroutine = 100
	done := make(chan bool, numGoroutines)
	// Each worker sends at most one failure message before returning, so the
	// buffered channel can never block a sender.
	failures := make(chan string, numGoroutines)
	// Launch concurrent validators
	for i := 0; i < numGoroutines; i++ {
		go func(id int) {
			defer func() { done <- true }()
			for j := 0; j < validationsPerGoroutine; j++ {
				addr := addresses[j%len(addresses)]
				result := validator.ValidateAddress(addr)
				// Verify consistent results for the known-good and known-bad fixtures.
				if addr == "0x82aF49447D8a07e3bd95BD0d56f35241523fBab1" {
					if !result.IsValid || result.CorruptionScore != 0 {
						failures <- "inconsistent result for valid WETH address"
						return
					}
				}
				if addr == "0x0000000300000000000000000000000000000000" {
					// Score 70 matches the pattern-detection validation logic.
					if result.IsValid || result.CorruptionScore != 70 {
						failures <- "inconsistent result for corrupted address"
						return
					}
				}
			}
		}(i)
	}
	// Wait for all goroutines to complete
	for i := 0; i < numGoroutines; i++ {
		select {
		case <-done:
			// Success
		case <-time.After(10 * time.Second):
			t.Fatal("Concurrent validation test timed out")
		}
	}
	// All workers are done; drain and report any recorded failures here, on
	// the test goroutine.
	close(failures)
	for msg := range failures {
		t.Error(msg)
	}
	t.Logf("Successfully completed %d concurrent validations",
		numGoroutines*validationsPerGoroutine)
}
// contains reports whether substr occurs anywhere in str, ignoring case.
//
// Fix: the previous implementation was documented as case-insensitive but
// compared bytes exactly, and re-implemented substring search by hand.
// Delegating to the strings package makes the behavior match the contract
// and is the idiomatic form.
func contains(str, substr string) bool {
	return strings.Contains(strings.ToLower(str), strings.ToLower(substr))
}
// indexOf returns the byte index of the first occurrence of substr in str,
// or -1 if substr is not present.
//
// Fix: the previous version hand-rolled a quadratic scan that duplicates
// strings.Index; the stdlib routine has identical semantics (including
// returning 0 for an empty substr) and an optimized implementation.
func indexOf(str, substr string) int {
	return strings.Index(str, substr)
}

View File

@@ -0,0 +1,158 @@
package validation
import (
"fmt"
"strings"
"github.com/ethereum/go-ethereum/common"
)
// KnownContractRegistry maintains a registry of known contract addresses and their types
// This prevents misclassification of well-known contracts (like major ERC-20 tokens)
type KnownContractRegistry struct {
	erc20Tokens map[common.Address]string // token address -> symbol (e.g. "WETH")
	pools       map[common.Address]string // pool address -> human-readable pair/fee label
	routers     map[common.Address]string // router address -> router name
}
// NewKnownContractRegistry creates a new registry populated with known
// Arbitrum contracts (ERC-20 tokens, pools, and routers).
func NewKnownContractRegistry() *KnownContractRegistry {
	r := &KnownContractRegistry{
		erc20Tokens: map[common.Address]string{},
		pools:       map[common.Address]string{},
		routers:     map[common.Address]string{},
	}
	// Seed every category up front so lookups never see an empty registry.
	// The ERC-20 seeding is the critical one: those addresses must NEVER be
	// treated as pools.
	r.addKnownERC20Tokens()
	r.addKnownPools()
	r.addKnownRouters()
	return r
}
// addKnownERC20Tokens adds all major ERC-20 tokens on Arbitrum
func (r *KnownContractRegistry) addKnownERC20Tokens() {
	// Major ERC-20 tokens that were being misclassified as pools
	r.erc20Tokens[common.HexToAddress("0x82aF49447D8a07e3bd95BD0d56f35241523fBab1")] = "WETH"   // Wrapped Ether
	r.erc20Tokens[common.HexToAddress("0xaf88d065e77c8cC2239327C5EDb3A432268e5831")] = "USDC"   // USD Coin
	r.erc20Tokens[common.HexToAddress("0xFd086bC7CD5C481DCC9C85ebE478A1C0b69FCbb9")] = "USDT"   // Tether USD
	r.erc20Tokens[common.HexToAddress("0xFF970A61A04b1cA14834A43f5dE4533eBDDB5CC8")] = "USDC.e" // Bridged USDC
	r.erc20Tokens[common.HexToAddress("0x912CE59144191C1204E64559FE8253a0e49E6548")] = "ARB"    // Arbitrum Token
	r.erc20Tokens[common.HexToAddress("0x2f2a2543B76A4166549F7aaB2e75Bef0aefC5B0f")] = "WBTC"   // Wrapped Bitcoin
	r.erc20Tokens[common.HexToAddress("0xDA10009cBd5D07dd0CeCc66161FC93D7c9000da1")] = "DAI"    // Dai Stablecoin
	r.erc20Tokens[common.HexToAddress("0x17FC002b466eEc40DaE837Fc4bE5c67993ddBd6F")] = "FRAX"   // Frax
	r.erc20Tokens[common.HexToAddress("0x11cDb42B0EB46D95f990BeDD4695A6e3fA034978")] = "CRV"    // Curve DAO Token
	r.erc20Tokens[common.HexToAddress("0x539bdE0d7Dbd336b79148AA742883198BBF60342")] = "MAGIC"  // MAGIC
	r.erc20Tokens[common.HexToAddress("0xf97f4df75117a78c1A5a0DBb814Af92458539FB4")] = "LINK"   // Chainlink Token
	r.erc20Tokens[common.HexToAddress("0xfc5A1A6EB076a2C7aD06eD22C90d7E710E35ad0a")] = "GMX"    // GMX
	r.erc20Tokens[common.HexToAddress("0x6C2C06790b3E3E3c38e12Ee22F8183b37a13EE55")] = "DPX"    // Dopex Governance Token
	r.erc20Tokens[common.HexToAddress("0x10393c20975cF177a3513071bC110f7962CD67da")] = "JONES"  // JonesDAO
	r.erc20Tokens[common.HexToAddress("0x4e352cf164e64adcbad318c3a1e222e9eba4ce42")] = "MCB"    // MCDEX Token
	// NOTE(review): the literal below has only 39 hex digits, so HexToAddress
	// left-pads it with a zero nibble — this looks like a typo'd address.
	// Verify the intended RDNT contract; note "RDNT" is also registered again
	// a few lines down under a different address.
	r.erc20Tokens[common.HexToAddress("0x23A941036Ae778Ac51Ab04CEa08Ed6e2AF33b49")] = "RDNT"    // Radiant Capital
	r.erc20Tokens[common.HexToAddress("0x6694340fc020c5E6B96567843da2df01b2CE1eb6")] = "STG"    // Stargate Finance
	r.erc20Tokens[common.HexToAddress("0x3082CC23568eA640225c2467653dB90e9250AaA0")] = "RDNT"   // Radiant (alternative)
	r.erc20Tokens[common.HexToAddress("0x51fC0f6660482Ea73330E414eFd7808811a57Fa2")] = "PREMIA" // Premia
	r.erc20Tokens[common.HexToAddress("0x69Eb4FA4a2fbd498C257C57Ea8b7655a2559A581")] = "DODO"   // DODO
}
// addKnownPools adds known liquidity pools
// (label format: "TOKEN0/TOKEN1-fee"). These are classified as Uniswap V3
// pools by GetContractType.
func (r *KnownContractRegistry) addKnownPools() {
	// Major Uniswap V3 pools on Arbitrum
	r.pools[common.HexToAddress("0xC31E54c7a869B9FcBEcc14363CF510d1c41fa443")] = "USDC/WETH-0.05%"
	r.pools[common.HexToAddress("0x17c14D2c404D167802b16C450d3c99F88F2c4F4d")] = "USDC/WETH-0.3%"
	r.pools[common.HexToAddress("0x641C00A822e8b671738d32a431a4Fb6074E5c79d")] = "WETH/ARB-0.3%"
	r.pools[common.HexToAddress("0xdE64C63e6BaD1Ff18f4F1bdc9d1e7Bbfb5E0B6FD")] = "USDT/USDC-0.01%"
	r.pools[common.HexToAddress("0x2f5e87C9312fa29aed5c179E456625D79015299c")] = "WBTC/WETH-0.05%"
}
// addKnownRouters adds known router contracts
// so they are classified as ContractTypeRouter rather than tokens or pools.
func (r *KnownContractRegistry) addKnownRouters() {
	// Uniswap V3 SwapRouter
	r.routers[common.HexToAddress("0xE592427A0AEce92De3Edee1F18E0157C05861564")] = "UniswapV3Router"
	// Uniswap V3 SwapRouter02
	r.routers[common.HexToAddress("0x68b3465833fb72A70ecDF485E0e4C7bD8665Fc45")] = "UniswapV3Router02"
	// SushiSwap Router
	r.routers[common.HexToAddress("0x1b02dA8Cb0d097eB8D57A175b88c7D8b47997506")] = "SushiSwapRouter"
	// Camelot Router
	r.routers[common.HexToAddress("0xc873fEcbd354f5A56E00E710B90EF4201db2448d")] = "CamelotRouter"
}
// IsKnownERC20 checks if an address is a known ERC-20 token, returning the
// token symbol on a hit and "" otherwise.
func (r *KnownContractRegistry) IsKnownERC20(address common.Address) (bool, string) {
	if symbol, ok := r.erc20Tokens[address]; ok {
		return true, symbol
	}
	return false, ""
}
// IsKnownPool checks if an address is a known liquidity pool, returning the
// pool label on a hit and "" otherwise.
func (r *KnownContractRegistry) IsKnownPool(address common.Address) (bool, string) {
	if label, ok := r.pools[address]; ok {
		return true, label
	}
	return false, ""
}
// IsKnownRouter checks if an address is a known router contract, returning
// the router name on a hit and "" otherwise.
func (r *KnownContractRegistry) IsKnownRouter(address common.Address) (bool, string) {
	if routerName, ok := r.routers[address]; ok {
		return true, routerName
	}
	return false, ""
}
// GetContractType returns the type and name of a known contract. Categories
// are consulted in priority order — ERC-20 tokens first, since token/pool
// confusion is the failure mode this registry exists to prevent — and
// (ContractTypeUnknown, "") is returned when no category matches.
func (r *KnownContractRegistry) GetContractType(address common.Address) (ContractType, string) {
	if name, ok := r.erc20Tokens[address]; ok {
		return ContractTypeERC20Token, name
	}
	if name, ok := r.pools[address]; ok {
		return ContractTypeUniswapV3Pool, name
	}
	if name, ok := r.routers[address]; ok {
		return ContractTypeRouter, name
	}
	return ContractTypeUnknown, ""
}
// ValidatePoolCall validates if a pool-specific operation should be allowed
// on an address. Known ERC-20 tokens are rejected with a structured
// *ValidationError; all other addresses pass.
func (r *KnownContractRegistry) ValidatePoolCall(address common.Address, operation string) error {
	isERC20, name := r.IsKnownERC20(address)
	if !isERC20 {
		return nil
	}
	// A plain ERC-20 contract has no pool entry points, so reject the call
	// before it is issued.
	return &ValidationError{
		Code:    "INVALID_POOL_OPERATION",
		Message: fmt.Sprintf("Attempted pool operation '%s' on known ERC-20 token %s (%s)", operation, name, address.Hex()),
		Context: map[string]interface{}{
			"address":    address.Hex(),
			"token_name": name,
			"operation":  operation,
			"issue":      "ERC-20 tokens do not have pool-specific functions like slot0()",
		},
	}
}
// GetCorruptionPattern analyzes an address for known corruption patterns and
// returns a short pattern tag. Checks run in priority order; the first match
// wins. Note that a valid, uncorrupted address still falls through to
// "OTHER_CORRUPTION" — callers must gate on validity separately.
func (r *KnownContractRegistry) GetCorruptionPattern(address common.Address) string {
	hexAddr := address.Hex()
	// Check for zero address
	if address == (common.Address{}) {
		return "ZERO_ADDRESS"
	}
	// Check for mostly zeros with small values
	// NOTE(review): this suffix appears to be 40 zeros — the full 40-digit
	// address body — so only the zero address could match, and that case
	// already returned above. This branch looks unreachable; confirm the
	// intended suffix length.
	if strings.HasSuffix(hexAddr, "0000000000000000000000000000000000000000") {
		return "TRAILING_ZEROS"
	}
	// Check for leading zeros with small values
	if strings.HasPrefix(hexAddr, "0x00000000") {
		return "LEADING_ZEROS_CORRUPTION"
	}
	// Check for embedded WETH/USDC patterns (indicates address extraction issues)
	// NOTE(review): Contains also matches the genuine WETH/USDC addresses
	// themselves, which will be tagged as "embedded" — verify that is intended.
	if strings.Contains(strings.ToLower(hexAddr), "82af49447d8a07e3bd95bd0d56f35241523fbab1") {
		return "EMBEDDED_WETH_PATTERN"
	}
	if strings.Contains(strings.ToLower(hexAddr), "af88d065e77c8cc2239327c5edb3a432268e5831") {
		return "EMBEDDED_USDC_PATTERN"
	}
	return "OTHER_CORRUPTION"
}