package security

import (
	"bytes"
	"context"
	"crypto/tls"
	"encoding/json"
	"errors"
	"fmt"
	"math/rand"
	"net/http"
	"os"
	"reflect"
	"sync"
	"sync/atomic"
	"time"

	"golang.org/x/time/rate"

	"github.com/fraktal/mev-beta/internal/logger"
)

// SecurityManager provides centralized security management for the MEV bot.
//
// It bundles key management, input validation, rate limiting, monitoring,
// circuit breakers and a TLS-configured HTTP client for RPC delegation.
type SecurityManager struct {
	keyManager     *KeyManager
	inputValidator *InputValidator
	rateLimiter    *RateLimiter
	monitor        *SecurityMonitor
	config         *SecurityConfig
	logger         *logger.Logger

	// Circuit breakers for different components.
	rpcCircuitBreaker       *CircuitBreaker
	arbitrageCircuitBreaker *CircuitBreaker

	// TLS configuration used by rpcHTTPClient.
	tlsConfig *tls.Config

	// Rate limiters for different operations.
	transactionLimiter *rate.Limiter
	rpcLimiter         *rate.Limiter

	// Security state. emergencyMode is guarded by emergencyMutex because it
	// is written by the background monitoring goroutine and read on the
	// request path. securityAlerts is guarded by alertsMutex.
	emergencyMode  bool
	emergencyMutex sync.RWMutex
	securityAlerts []SecurityAlert
	alertsMutex    sync.RWMutex

	// Metrics. Counters updated in this file are modified with sync/atomic.
	managerMetrics *ManagerMetrics

	rpcHTTPClient *http.Client

	// stopCh is closed by Shutdown to terminate the monitoring goroutine.
	// stopOnce makes closing it idempotent.
	stopCh   chan struct{}
	stopOnce sync.Once
}

// SecurityConfig contains all security-related configuration.
type SecurityConfig struct {
	// Key management
	KeyStoreDir       string `yaml:"keystore_dir"`
	EncryptionEnabled bool   `yaml:"encryption_enabled"`

	// Rate limiting
	TransactionRPS int `yaml:"transaction_rps"`
	RPCRPS         int `yaml:"rpc_rps"`
	MaxBurstSize   int `yaml:"max_burst_size"`

	// Circuit breaker settings
	FailureThreshold int           `yaml:"failure_threshold"`
	RecoveryTimeout  time.Duration `yaml:"recovery_timeout"`

	// TLS settings. TLSMinVersion is honoured only when it equals
	// tls.VersionTLS13; any other value keeps the TLS 1.2 floor.
	TLSMinVersion   uint16   `yaml:"tls_min_version"`
	TLSCipherSuites []uint16 `yaml:"tls_cipher_suites"`

	// Emergency settings
	EmergencyStopFile string `yaml:"emergency_stop_file"`
	MaxGasPrice       string `yaml:"max_gas_price"`

	// Monitoring
	AlertWebhookURL string `yaml:"alert_webhook_url"`
	LogLevel        string `yaml:"log_level"`

	// RPC endpoint used by secure RPC call delegation
	RPCURL string `yaml:"rpc_url"`
}

// ManagerMetrics holds additional security metrics for SecurityManager.
//
// Within this file only CircuitBreakerTrips and EmergencyStops are updated,
// and both are incremented atomically.
type ManagerMetrics struct {
	AuthenticationAttempts int64 `json:"authentication_attempts"`
	FailedAuthentications  int64 `json:"failed_authentications"`
	CircuitBreakerTrips    int64 `json:"circuit_breaker_trips"`
	EmergencyStops         int64 `json:"emergency_stops"`
	TLSHandshakeFailures   int64 `json:"tls_handshake_failures"`
}

// CircuitBreaker implements the circuit breaker pattern for fault tolerance.
// Its mutable fields (failureCount, lastFailureTime, state) are guarded by
// mutex.
type CircuitBreaker struct {
	name            string
	failureCount    int
	lastFailureTime time.Time
	state           CircuitBreakerState
	config          CircuitBreakerConfig
	mutex           sync.RWMutex
}

// isOpen reports, under the read lock, whether the breaker currently rejects
// calls.
func (cb *CircuitBreaker) isOpen() bool {
	cb.mutex.RLock()
	defer cb.mutex.RUnlock()
	return cb.state == CircuitBreakerOpen
}

// CircuitBreakerState enumerates the states of a CircuitBreaker.
type CircuitBreakerState int

const (
	// CircuitBreakerClosed lets calls through and counts failures.
	CircuitBreakerClosed CircuitBreakerState = iota
	// CircuitBreakerOpen rejects calls until the recovery timeout elapses.
	CircuitBreakerOpen
	// CircuitBreakerHalfOpen lets calls through on trial: a success closes
	// the breaker, a failure re-opens it.
	CircuitBreakerHalfOpen
)

// CircuitBreakerConfig holds the tunables of a CircuitBreaker.
type CircuitBreakerConfig struct {
	FailureThreshold int
	RecoveryTimeout  time.Duration
	MaxRetries       int // not read anywhere in this file
}

// NewSecurityManager creates a new security manager with comprehensive
// protection.
//
// It returns an error when config is nil, when the MEV_BOT_ENCRYPTION_KEY
// environment variable is unset, or when the key manager cannot be created.
// On success a background goroutine is started that performs periodic
// security checks; call Shutdown to stop it.
func NewSecurityManager(config *SecurityConfig) (*SecurityManager, error) {
	if config == nil {
		return nil, fmt.Errorf("security config cannot be nil")
	}

	// Initialize key manager - get encryption key from environment or fail.
	encryptionKey := os.Getenv("MEV_BOT_ENCRYPTION_KEY")
	if encryptionKey == "" {
		return nil, fmt.Errorf("MEV_BOT_ENCRYPTION_KEY environment variable is required")
	}

	keyManagerConfig := &KeyManagerConfig{
		KeyDir:            config.KeyStoreDir,
		KeystorePath:      config.KeyStoreDir,
		EncryptionKey:     encryptionKey,
		BackupEnabled:     true,
		MaxFailedAttempts: 3,
		LockoutDuration:   5 * time.Minute,
	}
	keyManager, err := NewKeyManager(keyManagerConfig, logger.New("info", "json", "logs/keymanager.log"))
	if err != nil {
		return nil, fmt.Errorf("failed to initialize key manager: %w", err)
	}

	// Initialize input validator.
	inputValidator := NewInputValidator(1) // Default chain ID

	// Initialize rate limiter.
	rateLimiterConfig := &RateLimiterConfig{
		IPRequestsPerSecond:   100,
		IPBurstSize:           config.MaxBurstSize,
		IPBlockDuration:       5 * time.Minute,
		UserRequestsPerSecond: config.TransactionRPS,
		UserBurstSize:         config.MaxBurstSize,
		UserBlockDuration:     5 * time.Minute,
		CleanupInterval:       5 * time.Minute,
	}
	rateLimiter := NewRateLimiter(rateLimiterConfig)

	// Initialize security monitor.
	monitorConfig := &MonitorConfig{
		EnableAlerts:    true,
		AlertBuffer:     1000,
		AlertRetention:  24 * time.Hour,
		MaxEvents:       10000,
		EventRetention:  7 * 24 * time.Hour,
		MetricsInterval: time.Minute,
		CleanupInterval: time.Hour,
	}
	monitor := NewSecurityMonitor(monitorConfig)

	// TLS 1.2 is the floor. The configured minimum is applied only when it
	// raises the floor to TLS 1.3, so a missing or unexpected value can
	// neither weaken the policy nor make every handshake fail.
	minTLSVersion := uint16(tls.VersionTLS12)
	if config.TLSMinVersion == tls.VersionTLS13 {
		minTLSVersion = tls.VersionTLS13
	}

	// Create TLS configuration with security best practices.
	tlsConfig := &tls.Config{
		MinVersion:         minTLSVersion,
		CipherSuites:       config.TLSCipherSuites,
		InsecureSkipVerify: false,
		// Kept for older toolchains; crypto/tls documents this field as
		// deprecated and ignored in current Go releases.
		PreferServerCipherSuites: true,
	}

	// Initialize circuit breakers.
	rpcCircuitBreaker := &CircuitBreaker{
		name: "rpc",
		config: CircuitBreakerConfig{
			FailureThreshold: config.FailureThreshold,
			RecoveryTimeout:  config.RecoveryTimeout,
			MaxRetries:       3,
		},
		state: CircuitBreakerClosed,
	}
	arbitrageCircuitBreaker := &CircuitBreaker{
		name: "arbitrage",
		config: CircuitBreakerConfig{
			FailureThreshold: config.FailureThreshold,
			RecoveryTimeout:  config.RecoveryTimeout,
			MaxRetries:       3,
		},
		state: CircuitBreakerClosed,
	}

	// Initialize rate limiters.
	transactionLimiter := rate.NewLimiter(rate.Limit(config.TransactionRPS), config.MaxBurstSize)
	rpcLimiter := rate.NewLimiter(rate.Limit(config.RPCRPS), config.MaxBurstSize)

	// Create logger instance.
	securityLogger := logger.New("info", "json", "logs/security.log")

	httpTransport := &http.Transport{
		TLSClientConfig: tlsConfig,
	}

	sm := &SecurityManager{
		keyManager:              keyManager,
		inputValidator:          inputValidator,
		rateLimiter:             rateLimiter,
		monitor:                 monitor,
		config:                  config,
		logger:                  securityLogger,
		rpcCircuitBreaker:       rpcCircuitBreaker,
		arbitrageCircuitBreaker: arbitrageCircuitBreaker,
		tlsConfig:               tlsConfig,
		transactionLimiter:      transactionLimiter,
		rpcLimiter:              rpcLimiter,
		emergencyMode:           false,
		securityAlerts:          make([]SecurityAlert, 0),
		managerMetrics:          &ManagerMetrics{},
		rpcHTTPClient: &http.Client{
			Timeout:   30 * time.Second,
			Transport: httpTransport,
		},
		stopCh: make(chan struct{}),
	}

	// Start security monitoring; the goroutine exits when Shutdown is called.
	go sm.startSecurityMonitoring()

	sm.logger.Info("Security manager initialized successfully")
	return sm, nil
}

// inEmergencyMode reports, under the read lock, whether emergency mode is
// active.
func (sm *SecurityManager) inEmergencyMode() bool {
	sm.emergencyMutex.RLock()
	defer sm.emergencyMutex.RUnlock()
	return sm.emergencyMode
}

// circuitBreakerFor maps a component name ("rpc" or "arbitrage") to its
// circuit breaker. It returns nil for any other name.
func (sm *SecurityManager) circuitBreakerFor(component string) *CircuitBreaker {
	switch component {
	case "rpc":
		return sm.rpcCircuitBreaker
	case "arbitrage":
		return sm.arbitrageCircuitBreaker
	default:
		return nil
	}
}

// ValidateTransaction performs comprehensive transaction validation.
//
// Checks run in this order: transaction rate limit, emergency mode, presence
// of the parameters and their To/Value fields, and finally the arbitrage
// circuit breaker. The first failing check is returned as an error; nil means
// the transaction may proceed. ctx is currently unused.
func (sm *SecurityManager) ValidateTransaction(ctx context.Context, txParams *TransactionParams) error {
	// Check rate limiting.
	if !sm.transactionLimiter.Allow() {
		if sm.monitor != nil {
			sm.monitor.RecordEvent(EventTypeError, "security_manager", "Transaction rate limit exceeded", SeverityMedium, map[string]interface{}{
				"limit_type": "transaction",
			})
		}
		return fmt.Errorf("transaction rate limit exceeded")
	}

	// Check emergency mode.
	if sm.inEmergencyMode() {
		return fmt.Errorf("system in emergency mode - transactions disabled")
	}

	// Validate input parameters (simplified validation). A nil parameter set
	// is rejected explicitly rather than panicking on the field access below.
	if txParams == nil {
		return fmt.Errorf("transaction validation failed: missing transaction parameters")
	}
	if txParams.To == nil {
		return fmt.Errorf("transaction validation failed: missing recipient")
	}
	if txParams.Value == nil {
		return fmt.Errorf("transaction validation failed: missing value")
	}

	// Check circuit breaker state under its lock.
	if sm.arbitrageCircuitBreaker.isOpen() {
		return fmt.Errorf("arbitrage circuit breaker is open")
	}

	return nil
}

// SecureRPCCall performs RPC calls with security controls.
//
// The call is subject to the RPC rate limiter and the RPC circuit breaker.
// params is normalized into a JSON-RPC positional parameter list and posted
// to config.RPCURL using the manager's TLS-configured HTTP client. Every
// failure after the rate-limit and breaker checks is recorded against the
// "rpc" circuit breaker; a successful call records a success.
//
// The decoded "result" member is returned. If it cannot be unmarshalled, the
// raw JSON text is returned as a string instead.
func (sm *SecurityManager) SecureRPCCall(ctx context.Context, method string, params interface{}) (interface{}, error) {
	// Check rate limiting.
	if !sm.rpcLimiter.Allow() {
		if sm.monitor != nil {
			sm.monitor.RecordEvent(EventTypeError, "security_manager", "RPC rate limit exceeded", SeverityMedium, map[string]interface{}{
				"limit_type": "rpc",
				"method":     method,
			})
		}
		return nil, fmt.Errorf("RPC rate limit exceeded")
	}

	// Check circuit breaker under its lock.
	if sm.rpcCircuitBreaker.isOpen() {
		return nil, fmt.Errorf("RPC circuit breaker is open")
	}

	if sm.config.RPCURL == "" {
		err := errors.New("RPC endpoint not configured in security manager")
		sm.RecordFailure("rpc", err)
		return nil, err
	}

	paramList, err := normalizeRPCParams(params)
	if err != nil {
		sm.RecordFailure("rpc", err)
		return nil, err
	}

	requestPayload := jsonRPCRequest{
		JSONRPC: "2.0",
		Method:  method,
		Params:  paramList,
		ID:      fmt.Sprintf("sm-%d", rand.Int63()),
	}

	body, err := json.Marshal(requestPayload)
	if err != nil {
		sm.RecordFailure("rpc", err)
		return nil, fmt.Errorf("failed to marshal RPC request: %w", err)
	}

	req, err := http.NewRequestWithContext(ctx, http.MethodPost, sm.config.RPCURL, bytes.NewReader(body))
	if err != nil {
		sm.RecordFailure("rpc", err)
		return nil, fmt.Errorf("failed to create RPC request: %w", err)
	}
	req.Header.Set("Content-Type", "application/json")

	resp, err := sm.rpcHTTPClient.Do(req)
	if err != nil {
		sm.RecordFailure("rpc", err)
		return nil, fmt.Errorf("RPC call failed: %w", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode >= 400 {
		statusErr := fmt.Errorf("rpc endpoint returned status %d", resp.StatusCode)
		sm.RecordFailure("rpc", statusErr)
		return nil, statusErr
	}

	var rpcResp jsonRPCResponse
	if err := json.NewDecoder(resp.Body).Decode(&rpcResp); err != nil {
		sm.RecordFailure("rpc", err)
		return nil, fmt.Errorf("failed to decode RPC response: %w", err)
	}

	if rpcResp.Error != nil {
		err := fmt.Errorf("rpc error %d: %s", rpcResp.Error.Code, rpcResp.Error.Message)
		sm.RecordFailure("rpc", err)
		return nil, err
	}

	var result interface{}
	if len(rpcResp.Result) > 0 {
		if err := json.Unmarshal(rpcResp.Result, &result); err != nil {
			// If we cannot unmarshal into interface{}, return raw JSON.
			result = string(rpcResp.Result)
		}
	}

	sm.RecordSuccess("rpc")
	return result, nil
}

// TriggerEmergencyStop activates emergency mode, bumps the EmergencyStops
// metric and raises a critical alert carrying reason. It always returns nil.
func (sm *SecurityManager) TriggerEmergencyStop(reason string) error {
	sm.emergencyMutex.Lock()
	sm.emergencyMode = true
	sm.emergencyMutex.Unlock()

	atomic.AddInt64(&sm.managerMetrics.EmergencyStops, 1)

	now := time.Now()
	alert := SecurityAlert{
		ID:          fmt.Sprintf("emergency-%d", now.Unix()),
		Timestamp:   now,
		Level:       AlertLevelCritical,
		Type:        AlertTypeConfiguration,
		Title:       "Emergency Stop Activated",
		Description: fmt.Sprintf("Emergency stop triggered: %s", reason),
		Source:      "security_manager",
		Data: map[string]interface{}{
			"reason": reason,
		},
		Actions: []string{"investigate_cause", "review_logs", "manual_restart_required"},
	}

	sm.addSecurityAlert(alert)
	sm.logger.Error("Emergency stop triggered: " + reason)

	return nil
}

// RecordFailure records a failure for circuit breaker logic.
//
// component selects the breaker ("rpc" or "arbitrage"); unknown names are
// ignored. The breaker opens when the failure count reaches the configured
// threshold while closed, or on any failure while half-open (a failed trial
// call). Opening the breaker increments CircuitBreakerTrips and raises an
// alert. err may be nil.
//
// NOTE(review): the failure count is reset only when a half-open breaker
// closes, so it counts total rather than consecutive failures while closed.
// Confirm that this is intended.
func (sm *SecurityManager) RecordFailure(component string, err error) {
	cb := sm.circuitBreakerFor(component)
	if cb == nil {
		return
	}

	// Update breaker state under its lock; the alert is emitted afterwards so
	// the breaker lock is not held while calling into the monitor and logger.
	cb.mutex.Lock()
	cb.failureCount++
	cb.lastFailureTime = time.Now()
	failures := cb.failureCount

	tripped := false
	switch cb.state {
	case CircuitBreakerHalfOpen:
		tripped = true
	case CircuitBreakerClosed:
		tripped = failures >= cb.config.FailureThreshold
	}
	if tripped {
		cb.state = CircuitBreakerOpen
	}
	cb.mutex.Unlock()

	if !tripped {
		return
	}

	atomic.AddInt64(&sm.managerMetrics.CircuitBreakerTrips, 1)

	errText := "unknown error"
	if err != nil {
		errText = err.Error()
	}

	now := time.Now()
	alert := SecurityAlert{
		ID:          fmt.Sprintf("circuit-breaker-%s-%d", component, now.Unix()),
		Timestamp:   now,
		Level:       AlertLevelError,
		Type:        AlertTypePerformance,
		Title:       "Circuit Breaker Opened",
		Description: fmt.Sprintf("Circuit breaker opened for component: %s", component),
		Source:      "security_manager",
		Data: map[string]interface{}{
			"component":     component,
			"failure_count": failures,
			"error":         errText,
		},
		Actions: []string{"investigate_failures", "check_component_health", "manual_intervention_required"},
	}

	sm.addSecurityAlert(alert)
	sm.logger.Warn(fmt.Sprintf("Circuit breaker opened for component: %s, failure count: %d", component, failures))
}

// RecordSuccess records a success for circuit breaker logic. A success while
// the breaker is half-open closes it and resets its failure count; in any
// other state it has no effect. Unknown component names are ignored.
func (sm *SecurityManager) RecordSuccess(component string) {
	cb := sm.circuitBreakerFor(component)
	if cb == nil {
		return
	}

	cb.mutex.Lock()
	defer cb.mutex.Unlock()

	if cb.state == CircuitBreakerHalfOpen {
		cb.state = CircuitBreakerClosed
		cb.failureCount = 0
		sm.logger.Info(fmt.Sprintf("Circuit breaker closed for component: %s", component))
	}
}

// addSecurityAlert appends alert to the in-memory alert list, forwards it to
// the monitor when one is configured, and trims the list to the most recent
// 1000 entries.
func (sm *SecurityManager) addSecurityAlert(alert SecurityAlert) {
	sm.alertsMutex.Lock()
	defer sm.alertsMutex.Unlock()

	sm.securityAlerts = append(sm.securityAlerts, alert)

	// Send alert to monitor if available.
	if sm.monitor != nil {
		sm.monitor.TriggerAlert(alert.Level, alert.Type, alert.Title, alert.Description, alert.Source, alert.Data, alert.Actions)
	}

	// Keep only last 1000 alerts.
	if len(sm.securityAlerts) > 1000 {
		sm.securityAlerts = sm.securityAlerts[len(sm.securityAlerts)-1000:]
	}
}

// startSecurityMonitoring runs performSecurityChecks once per minute until
// Shutdown closes stopCh. If stopCh is nil (a manager built without
// NewSecurityManager) the loop runs until the process exits.
func (sm *SecurityManager) startSecurityMonitoring() {
	ticker := time.NewTicker(1 * time.Minute)
	defer ticker.Stop()

	for {
		select {
		case <-ticker.C:
			sm.performSecurityChecks()
		case <-sm.stopCh:
			return
		}
	}
}

// performSecurityChecks performs periodic security checks: it moves open
// circuit breakers to half-open once their recovery timeout has elapsed and
// triggers an emergency stop when the configured stop file exists.
func (sm *SecurityManager) performSecurityChecks() {
	// Check circuit breakers for recovery.
	sm.checkCircuitBreakerRecovery(sm.rpcCircuitBreaker)
	sm.checkCircuitBreakerRecovery(sm.arbitrageCircuitBreaker)

	// Check for emergency stop file. Trigger only on the transition into
	// emergency mode; otherwise a lingering file would raise a new critical
	// alert and inflate the EmergencyStops metric every minute.
	if sm.config.EmergencyStopFile == "" || sm.inEmergencyMode() {
		return
	}
	if _, err := os.Stat(sm.config.EmergencyStopFile); err == nil {
		if stopErr := sm.TriggerEmergencyStop("emergency stop file detected"); stopErr != nil {
			sm.logger.Error("Failed to trigger emergency stop: " + stopErr.Error())
		}
	}
}

// checkCircuitBreakerRecovery checks if circuit breakers can transition to
// half-open: an open breaker whose last failure is older than its recovery
// timeout is moved to half-open.
func (sm *SecurityManager) checkCircuitBreakerRecovery(cb *CircuitBreaker) {
	cb.mutex.Lock()
	defer cb.mutex.Unlock()

	if cb.state == CircuitBreakerOpen && time.Since(cb.lastFailureTime) > cb.config.RecoveryTimeout {
		cb.state = CircuitBreakerHalfOpen
		sm.logger.Info(fmt.Sprintf("Circuit breaker transitioned to half-open for component: %s", cb.name))
	}
}

// GetManagerMetrics returns current manager metrics. The returned pointer
// refers to the manager's live metrics struct, not a snapshot.
func (sm *SecurityManager) GetManagerMetrics() *ManagerMetrics {
	return sm.managerMetrics
}

// GetSecurityMetrics returns current security metrics from monitor, or an
// empty SecurityMetrics when no monitor is configured.
func (sm *SecurityManager) GetSecurityMetrics() *SecurityMetrics {
	if sm.monitor != nil {
		return sm.monitor.GetMetrics()
	}
	return &SecurityMetrics{}
}

// GetSecurityAlerts returns a copy of the most recent security alerts, oldest
// first. A limit that is non-positive or larger than the number of stored
// alerts returns all of them.
func (sm *SecurityManager) GetSecurityAlerts(limit int) []SecurityAlert {
	sm.alertsMutex.RLock()
	defer sm.alertsMutex.RUnlock()

	total := len(sm.securityAlerts)
	if limit <= 0 || limit > total {
		limit = total
	}

	alerts := make([]SecurityAlert, limit)
	copy(alerts, sm.securityAlerts[total-limit:])
	return alerts
}

// Shutdown gracefully shuts down the security manager: it stops the
// background monitoring goroutine and closes idle RPC connections. It is safe
// to call more than once and always returns nil. ctx is currently unused.
func (sm *SecurityManager) Shutdown(ctx context.Context) error {
	sm.logger.Info("Shutting down security manager")

	// Stop the monitoring goroutine. The nil check covers managers that were
	// not built by NewSecurityManager.
	if sm.stopCh != nil {
		sm.stopOnce.Do(func() { close(sm.stopCh) })
	}

	// Shutdown components.
	if sm.keyManager != nil {
		// Key manager shutdown - simplified (no shutdown method needed).
		sm.logger.Info("Key manager stopped")
	}

	if sm.rateLimiter != nil {
		// Rate limiter shutdown - simplified.
		sm.logger.Info("Rate limiter stopped")
	}

	if sm.monitor != nil {
		// Monitor shutdown - simplified.
		sm.logger.Info("Security monitor stopped")
	}

	if sm.rpcHTTPClient != nil {
		sm.rpcHTTPClient.CloseIdleConnections()
	}

	return nil
}

// jsonRPCRequest is the JSON-RPC 2.0 request envelope sent by SecureRPCCall.
type jsonRPCRequest struct {
	JSONRPC string        `json:"jsonrpc"`
	Method  string        `json:"method"`
	Params  []interface{} `json:"params"`
	ID      string        `json:"id"`
}

// jsonRPCError is the "error" member of a JSON-RPC response.
type jsonRPCError struct {
	Code    int    `json:"code"`
	Message string `json:"message"`
}

// jsonRPCResponse is the JSON-RPC response envelope decoded by SecureRPCCall.
// Result is kept raw so the caller decides how to decode it.
type jsonRPCResponse struct {
	JSONRPC string          `json:"jsonrpc"`
	Result  json.RawMessage `json:"result"`
	Error   *jsonRPCError   `json:"error,omitempty"`
	ID      string          `json:"id"`
}

// normalizeRPCParams converts params into a positional JSON-RPC parameter
// list.
//
// nil yields an empty, non-nil list so it encodes as [] rather than null.
// []interface{} is returned as is. Any other slice or array is copied element
// by element. Any other value is wrapped as a single parameter. The error
// result is always nil.
func normalizeRPCParams(params interface{}) ([]interface{}, error) {
	if params == nil {
		return []interface{}{}, nil
	}

	// Fast paths for the common slice types avoid reflection.
	switch v := params.(type) {
	case []interface{}:
		return v, nil
	case []string:
		result := make([]interface{}, len(v))
		for i := range v {
			result[i] = v[i]
		}
		return result, nil
	case []int:
		result := make([]interface{}, len(v))
		for i := range v {
			result[i] = v[i]
		}
		return result, nil
	}

	val := reflect.ValueOf(params)
	if val.Kind() == reflect.Slice || val.Kind() == reflect.Array {
		length := val.Len()
		result := make([]interface{}, length)
		for i := 0; i < length; i++ {
			result[i] = val.Index(i).Interface()
		}
		return result, nil
	}

	return []interface{}{params}, nil
}