Files
mev-beta/orig/pkg/security/security_manager.go
Administrator 803de231ba feat: create v2-prep branch with comprehensive planning
Restructured project for V2 refactor:

**Structure Changes:**
- Moved all V1 code to orig/ folder (preserved with git mv)
- Created docs/planning/ directory
- Added orig/README_V1.md explaining V1 preservation

**Planning Documents:**
- 00_V2_MASTER_PLAN.md: Complete architecture overview
  - Executive summary of critical V1 issues
  - High-level component architecture diagrams
  - 5-phase implementation roadmap
  - Success metrics and risk mitigation

- 07_TASK_BREAKDOWN.md: Atomic task breakdown
  - 99+ hours of detailed tasks
  - Every task < 2 hours (atomic)
  - Clear dependencies and success criteria
  - Organized by implementation phase

**V2 Key Improvements:**
- Per-exchange parsers (factory pattern)
- Multi-layer strict validation
- Multi-index pool cache
- Background validation pipeline
- Comprehensive observability

**Critical Issues Addressed:**
- Zero address tokens (strict validation + cache enrichment)
- Parsing accuracy (protocol-specific parsers)
- No audit trail (background validation channel)
- Inefficient lookups (multi-index cache)
- Stats disconnection (event-driven metrics)

Next Steps:
1. Review planning documents
2. Begin Phase 1: Foundation (P1-001 through P1-010)
3. Implement parsers in Phase 2
4. Build cache system in Phase 3
5. Add validation pipeline in Phase 4
6. Migrate and test in Phase 5

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
2025-11-10 10:14:26 +01:00

624 lines
17 KiB
Go

package security
import (
"bytes"
"context"
"crypto/tls"
"encoding/json"
"errors"
"fmt"
"math/rand"
"net/http"
"os"
"reflect"
"sync"
"time"
"golang.org/x/time/rate"
"github.com/fraktal/mev-beta/internal/logger"
)
// SecurityManager provides centralized security management for the MEV bot.
// It bundles key management, input validation, rate limiting, circuit
// breaking, TLS configuration and security-alert bookkeeping behind one type.
type SecurityManager struct {
	// Core security components and configuration.
	keyManager     *KeyManager
	inputValidator *InputValidator
	rateLimiter    *RateLimiter
	monitor        *SecurityMonitor
	config         *SecurityConfig
	logger         *logger.Logger

	// Per-component circuit breakers, selected by name ("rpc" / "arbitrage")
	// in RecordFailure and RecordSuccess.
	rpcCircuitBreaker       *CircuitBreaker
	arbitrageCircuitBreaker *CircuitBreaker

	// tlsConfig is the client TLS configuration installed on the transport of
	// rpcHTTPClient by NewSecurityManager.
	tlsConfig *tls.Config

	// Limiters consulted by ValidateTransaction and SecureRPCCall respectively.
	transactionLimiter *rate.Limiter
	rpcLimiter         *rate.Limiter

	// emergencyMode, once set by TriggerEmergencyStop, makes
	// ValidateTransaction reject every transaction.
	// NOTE(review): read and written without synchronization.
	emergencyMode bool
	// securityAlerts holds the most recent alerts (capped at 1000 by
	// addSecurityAlert); guarded by alertsMutex.
	securityAlerts []SecurityAlert
	alertsMutex    sync.RWMutex

	// managerMetrics holds counters such as circuit-breaker trips and
	// emergency stops.
	managerMetrics *ManagerMetrics
	// rpcHTTPClient carries the JSON-RPC requests issued by SecureRPCCall.
	rpcHTTPClient *http.Client
}
// SecurityConfig contains all security-related configuration consumed by
// NewSecurityManager. Field tags map each option to its YAML key.
type SecurityConfig struct {
	// Key management.
	KeyStoreDir       string `yaml:"keystore_dir"`       // directory holding encrypted keys
	EncryptionEnabled bool   `yaml:"encryption_enabled"` // whether key encryption is requested

	// Rate limiting.
	TransactionRPS int `yaml:"transaction_rps"` // transaction events per second
	RPCRPS         int `yaml:"rpc_rps"`         // RPC calls per second
	MaxBurstSize   int `yaml:"max_burst_size"`  // burst size shared by the limiters

	// Circuit breaker settings.
	FailureThreshold int           `yaml:"failure_threshold"` // failures before a breaker opens
	RecoveryTimeout  time.Duration `yaml:"recovery_timeout"`  // wait before an open breaker goes half-open

	// TLS settings.
	TLSMinVersion   uint16   `yaml:"tls_min_version"`
	TLSCipherSuites []uint16 `yaml:"tls_cipher_suites"`

	// Emergency settings.
	EmergencyStopFile string `yaml:"emergency_stop_file"` // presence of this file triggers an emergency stop
	MaxGasPrice       string `yaml:"max_gas_price"`

	// Monitoring.
	AlertWebhookURL string `yaml:"alert_webhook_url"`
	LogLevel        string `yaml:"log_level"`

	// RPCURL is the JSON-RPC endpoint used by SecureRPCCall.
	RPCURL string `yaml:"rpc_url"`
}
// ManagerMetrics holds additional security counters maintained by
// SecurityManager, serialized with snake_case JSON keys.
type ManagerMetrics struct {
	AuthenticationAttempts int64 `json:"authentication_attempts"`
	FailedAuthentications  int64 `json:"failed_authentications"`
	CircuitBreakerTrips    int64 `json:"circuit_breaker_trips"` // incremented by RecordFailure when a breaker opens
	EmergencyStops         int64 `json:"emergency_stops"`       // incremented by TriggerEmergencyStop
	TLSHandshakeFailures   int64 `json:"tls_handshake_failures"`
}
// CircuitBreaker implements the circuit breaker pattern for fault tolerance.
// RecordFailure, RecordSuccess and checkCircuitBreakerRecovery mutate its
// fields while holding mutex.
type CircuitBreaker struct {
	name            string               // component label used in log messages, e.g. "rpc"
	failureCount    int                  // failures recorded since the last reset
	lastFailureTime time.Time            // time of the most recent recorded failure
	state           CircuitBreakerState  // current position in the breaker state machine
	config          CircuitBreakerConfig // thresholds for this breaker
	mutex           sync.RWMutex
}

// CircuitBreakerState is the position of a CircuitBreaker in its
// closed -> open -> half-open cycle.
type CircuitBreakerState int

// Circuit breaker states. The zero value is CircuitBreakerClosed.
const (
	// CircuitBreakerClosed lets traffic through while counting failures.
	CircuitBreakerClosed CircuitBreakerState = iota
	// CircuitBreakerOpen rejects traffic until the recovery timeout elapses.
	CircuitBreakerOpen
	// CircuitBreakerHalfOpen lets traffic through to probe for recovery.
	CircuitBreakerHalfOpen
)

// CircuitBreakerConfig holds the tuning parameters of a CircuitBreaker.
type CircuitBreakerConfig struct {
	FailureThreshold int           // failures needed to open a closed breaker
	RecoveryTimeout  time.Duration // time after the last failure before an open breaker goes half-open
	MaxRetries       int           // NOTE(review): not read anywhere in this file
}
// NewSecurityManager builds a SecurityManager from config and starts its
// background monitoring goroutine.
//
// It wires up:
//   - the key manager;
//   - the input validator;
//   - the per-IP/per-user rate limiter;
//   - the security monitor;
//   - the "rpc" and "arbitrage" circuit breakers;
//   - the transaction and RPC limiters;
//   - an HTTP client whose transport uses a TLS configuration with a TLS 1.2 floor.
//
// The MEV_BOT_ENCRYPTION_KEY environment variable must be set; it is passed to
// the key manager as its encryption key. An error is returned when config is
// nil, the variable is unset, or the key manager cannot be created.
//
// NOTE(review): the monitoring goroutine has no stop signal; Shutdown does not
// stop it.
func NewSecurityManager(config *SecurityConfig) (*SecurityManager, error) {
	if config == nil {
		return nil, fmt.Errorf("security config cannot be nil")
	}

	// Initialize key manager - get encryption key from environment or fail.
	encryptionKey := os.Getenv("MEV_BOT_ENCRYPTION_KEY")
	if encryptionKey == "" {
		return nil, fmt.Errorf("MEV_BOT_ENCRYPTION_KEY environment variable is required")
	}
	keyManagerConfig := &KeyManagerConfig{
		KeyDir:            config.KeyStoreDir,
		KeystorePath:      config.KeyStoreDir,
		EncryptionKey:     encryptionKey,
		BackupEnabled:     true,
		MaxFailedAttempts: 3,
		LockoutDuration:   5 * time.Minute,
	}
	keyManager, err := NewKeyManager(keyManagerConfig, logger.New("info", "json", "logs/keymanager.log"))
	if err != nil {
		return nil, fmt.Errorf("failed to initialize key manager: %w", err)
	}

	inputValidator := NewInputValidator(1) // Default chain ID

	rateLimiter := NewRateLimiter(&RateLimiterConfig{
		IPRequestsPerSecond:   100,
		IPBurstSize:           config.MaxBurstSize,
		IPBlockDuration:       5 * time.Minute,
		UserRequestsPerSecond: config.TransactionRPS,
		UserBurstSize:         config.MaxBurstSize,
		UserBlockDuration:     5 * time.Minute,
		CleanupInterval:       5 * time.Minute,
	})

	monitor := NewSecurityMonitor(&MonitorConfig{
		EnableAlerts:    true,
		AlertBuffer:     1000,
		AlertRetention:  24 * time.Hour,
		MaxEvents:       10000,
		EventRetention:  7 * 24 * time.Hour,
		MetricsInterval: time.Minute,
		CleanupInterval: time.Hour,
	})

	// TLS 1.2 is the floor. config.TLSMinVersion is honored only when it asks
	// for something stricter, so a zero or legacy value can never weaken it.
	// (PreferServerCipherSuites is deliberately not set: it is deprecated and
	// has no effect on a client-side config.)
	minTLSVersion := uint16(tls.VersionTLS12)
	if config.TLSMinVersion > minTLSVersion {
		minTLSVersion = config.TLSMinVersion
	}
	tlsConfig := &tls.Config{
		MinVersion:         minTLSVersion,
		CipherSuites:       config.TLSCipherSuites,
		InsecureSkipVerify: false,
	}

	// Both breakers share the same thresholds; only the name differs.
	newBreaker := func(name string) *CircuitBreaker {
		return &CircuitBreaker{
			name: name,
			config: CircuitBreakerConfig{
				FailureThreshold: config.FailureThreshold,
				RecoveryTimeout:  config.RecoveryTimeout,
				MaxRetries:       3,
			},
			state: CircuitBreakerClosed,
		}
	}
	rpcCircuitBreaker := newBreaker("rpc")
	arbitrageCircuitBreaker := newBreaker("arbitrage")

	// With golang.org/x/time/rate, a burst of zero (or less) admits no events
	// at all unless the limit is Inf, which would silently block every
	// transaction and RPC call. Clamp to a minimum burst of one.
	burst := config.MaxBurstSize
	if burst < 1 {
		burst = 1
	}
	transactionLimiter := rate.NewLimiter(rate.Limit(config.TransactionRPS), burst)
	rpcLimiter := rate.NewLimiter(rate.Limit(config.RPCRPS), burst)

	securityLogger := logger.New("info", "json", "logs/security.log")

	httpTransport := &http.Transport{
		TLSClientConfig: tlsConfig,
	}

	sm := &SecurityManager{
		keyManager:              keyManager,
		inputValidator:          inputValidator,
		rateLimiter:             rateLimiter,
		monitor:                 monitor,
		config:                  config,
		logger:                  securityLogger,
		rpcCircuitBreaker:       rpcCircuitBreaker,
		arbitrageCircuitBreaker: arbitrageCircuitBreaker,
		tlsConfig:               tlsConfig,
		transactionLimiter:      transactionLimiter,
		rpcLimiter:              rpcLimiter,
		emergencyMode:           false,
		securityAlerts:          make([]SecurityAlert, 0),
		managerMetrics:          &ManagerMetrics{},
		rpcHTTPClient: &http.Client{
			Timeout:   30 * time.Second,
			Transport: httpTransport,
		},
	}

	// Start security monitoring.
	go sm.startSecurityMonitoring()

	sm.logger.Info("Security manager initialized successfully")
	return sm, nil
}
// GetKeyManager returns the KeyManager owned by this SecurityManager.
//
// SECURITY: callers should obtain the KeyManager here rather than constructing
// their own, so that a single instance (and a single encryption key) is used
// everywhere; separate instances with different keys would derive mismatched
// keys.
func (sm *SecurityManager) GetKeyManager() *KeyManager {
	km := sm.keyManager
	return km
}
// ValidateTransaction runs the pre-submission security gate for a transaction.
//
// Checks, in order:
//  1. the transaction rate limiter (every call consumes a token);
//  2. emergency mode;
//  3. that txParams and its To and Value fields are present;
//  4. the arbitrage circuit breaker.
//
// It returns nil when every check passes, otherwise an error describing the
// first failed check. ctx is currently unused.
func (sm *SecurityManager) ValidateTransaction(ctx context.Context, txParams *TransactionParams) error {
	// Check rate limiting.
	if !sm.transactionLimiter.Allow() {
		if sm.monitor != nil {
			sm.monitor.RecordEvent(EventTypeError, "security_manager", "Transaction rate limit exceeded", SeverityMedium, map[string]interface{}{
				"limit_type": "transaction",
			})
		}
		return fmt.Errorf("transaction rate limit exceeded")
	}

	// Check emergency mode.
	// NOTE(review): emergencyMode is written by TriggerEmergencyStop without
	// synchronization.
	if sm.emergencyMode {
		return fmt.Errorf("system in emergency mode - transactions disabled")
	}

	// Validate input parameters (simplified validation). A nil params pointer
	// is reported as an error instead of panicking on field access.
	if txParams == nil {
		return fmt.Errorf("transaction validation failed: missing transaction parameters")
	}
	if txParams.To == nil {
		return fmt.Errorf("transaction validation failed: missing recipient")
	}
	if txParams.Value == nil {
		return fmt.Errorf("transaction validation failed: missing value")
	}

	// The breaker state is mutated under cb.mutex by RecordFailure,
	// RecordSuccess and the monitoring goroutine, so read it under the same
	// lock to avoid a data race.
	cb := sm.arbitrageCircuitBreaker
	cb.mutex.RLock()
	open := cb.state == CircuitBreakerOpen
	cb.mutex.RUnlock()
	if open {
		return fmt.Errorf("arbitrage circuit breaker is open")
	}

	return nil
}
// SecureRPCCall sends a JSON-RPC 2.0 request for method to config.RPCURL,
// subject to the RPC rate limiter and the "rpc" circuit breaker.
//
// params is converted to a positional list with normalizeRPCParams. On success
// the decoded "result" value is returned; if it cannot be decoded, the raw
// JSON text is returned instead as a string, and an absent result yields nil.
//
// These are recorded as "rpc" failures and returned as errors:
//   - a missing endpoint;
//   - transport errors;
//   - HTTP status >= 400;
//   - undecodable responses;
//   - JSON-RPC error objects.
//
// Rate-limit and open-breaker rejections are returned without recording a
// failure.
func (sm *SecurityManager) SecureRPCCall(ctx context.Context, method string, params interface{}) (interface{}, error) {
	// Check rate limiting.
	if !sm.rpcLimiter.Allow() {
		if sm.monitor != nil {
			sm.monitor.RecordEvent(EventTypeError, "security_manager", "RPC rate limit exceeded", SeverityMedium, map[string]interface{}{
				"limit_type": "rpc",
				"method":     method,
			})
		}
		return nil, fmt.Errorf("RPC rate limit exceeded")
	}

	// Check circuit breaker. Its state is mutated under cb.mutex by
	// RecordFailure, RecordSuccess and the monitoring goroutine, so read it
	// under the same lock to avoid a data race.
	cb := sm.rpcCircuitBreaker
	cb.mutex.RLock()
	open := cb.state == CircuitBreakerOpen
	cb.mutex.RUnlock()
	if open {
		return nil, fmt.Errorf("RPC circuit breaker is open")
	}

	if sm.config.RPCURL == "" {
		err := errors.New("RPC endpoint not configured in security manager")
		sm.RecordFailure("rpc", err)
		return nil, err
	}

	paramList, err := normalizeRPCParams(params)
	if err != nil {
		sm.RecordFailure("rpc", err)
		return nil, err
	}

	requestPayload := jsonRPCRequest{
		JSONRPC: "2.0",
		Method:  method,
		Params:  paramList,
		ID:      fmt.Sprintf("sm-%d", rand.Int63()),
	}
	body, err := json.Marshal(requestPayload)
	if err != nil {
		sm.RecordFailure("rpc", err)
		return nil, fmt.Errorf("failed to marshal RPC request: %w", err)
	}

	req, err := http.NewRequestWithContext(ctx, http.MethodPost, sm.config.RPCURL, bytes.NewReader(body))
	if err != nil {
		sm.RecordFailure("rpc", err)
		return nil, fmt.Errorf("failed to create RPC request: %w", err)
	}
	req.Header.Set("Content-Type", "application/json")

	resp, err := sm.rpcHTTPClient.Do(req)
	if err != nil {
		sm.RecordFailure("rpc", err)
		return nil, fmt.Errorf("RPC call failed: %w", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode >= 400 {
		statusErr := fmt.Errorf("rpc endpoint returned status %d", resp.StatusCode)
		sm.RecordFailure("rpc", statusErr)
		return nil, statusErr
	}

	var rpcResp jsonRPCResponse
	if err := json.NewDecoder(resp.Body).Decode(&rpcResp); err != nil {
		sm.RecordFailure("rpc", err)
		return nil, fmt.Errorf("failed to decode RPC response: %w", err)
	}

	if rpcResp.Error != nil {
		err := fmt.Errorf("rpc error %d: %s", rpcResp.Error.Code, rpcResp.Error.Message)
		sm.RecordFailure("rpc", err)
		return nil, err
	}

	var result interface{}
	if len(rpcResp.Result) > 0 {
		if err := json.Unmarshal(rpcResp.Result, &result); err != nil {
			// If we cannot unmarshal into interface{}, return raw JSON.
			result = string(rpcResp.Result)
		}
	}

	sm.RecordSuccess("rpc")
	return result, nil
}
// TriggerEmergencyStop activates emergency mode.
//
// After this call, ValidateTransaction rejects all transactions. The method
// increments the EmergencyStops metric, records a critical alert carrying
// reason, and logs the event. It currently always returns nil.
//
// NOTE(review): emergencyMode and EmergencyStops are written without
// synchronization while other goroutines may read them.
func (sm *SecurityManager) TriggerEmergencyStop(reason string) error {
	sm.emergencyMode = true
	sm.managerMetrics.EmergencyStops++

	// Take a single clock reading so the alert ID and Timestamp always agree
	// (two separate time.Now calls can straddle a second boundary).
	now := time.Now()
	alert := SecurityAlert{
		ID:          fmt.Sprintf("emergency-%d", now.Unix()),
		Timestamp:   now,
		Level:       AlertLevelCritical,
		Type:        AlertTypeConfiguration,
		Title:       "Emergency Stop Activated",
		Description: fmt.Sprintf("Emergency stop triggered: %s", reason),
		Source:      "security_manager",
		Data: map[string]interface{}{
			"reason": reason,
		},
		Actions: []string{"investigate_cause", "review_logs", "manual_restart_required"},
	}

	sm.addSecurityAlert(alert)
	sm.logger.Error("Emergency stop triggered: " + reason)
	return nil
}
// RecordFailure records a failure against the named component's circuit
// breaker.
//
// component must be "rpc" or "arbitrage"; any other value is ignored.
//
// A closed breaker opens once its failure count reaches the configured
// FailureThreshold. A failure while the breaker is half-open means the
// recovery probe failed, so the breaker opens again immediately; because
// lastFailureTime is refreshed, the recovery timeout restarts from this
// failure.
//
// Every trip increments CircuitBreakerTrips, raises an error-level alert and
// logs a warning. err may be nil.
func (sm *SecurityManager) RecordFailure(component string, err error) {
	var cb *CircuitBreaker
	switch component {
	case "rpc":
		cb = sm.rpcCircuitBreaker
	case "arbitrage":
		cb = sm.arbitrageCircuitBreaker
	default:
		return
	}

	cb.mutex.Lock()
	defer cb.mutex.Unlock()

	cb.failureCount++
	cb.lastFailureTime = time.Now()

	trip := false
	switch cb.state {
	case CircuitBreakerClosed:
		trip = cb.failureCount >= cb.config.FailureThreshold
	case CircuitBreakerHalfOpen:
		// The probe failed: reopen so traffic is blocked again until the next
		// recovery window, instead of staying half-open and letting every
		// request through.
		trip = true
	}
	if !trip {
		return
	}

	cb.state = CircuitBreakerOpen
	sm.managerMetrics.CircuitBreakerTrips++

	// Guard against a nil error so recording a failure can never panic.
	errMsg := "<nil>"
	if err != nil {
		errMsg = err.Error()
	}

	alert := SecurityAlert{
		ID:          fmt.Sprintf("circuit-breaker-%s-%d", component, time.Now().Unix()),
		Timestamp:   time.Now(),
		Level:       AlertLevelError,
		Type:        AlertTypePerformance,
		Title:       "Circuit Breaker Opened",
		Description: fmt.Sprintf("Circuit breaker opened for component: %s", component),
		Source:      "security_manager",
		Data: map[string]interface{}{
			"component":     component,
			"failure_count": cb.failureCount,
			"error":         errMsg,
		},
		Actions: []string{"investigate_failures", "check_component_health", "manual_intervention_required"},
	}
	sm.addSecurityAlert(alert)
	sm.logger.Warn(fmt.Sprintf("Circuit breaker opened for component: %s, failure count: %d", component, cb.failureCount))
}
// RecordSuccess records a success against the named component's circuit
// breaker.
//
// component must be "rpc" or "arbitrage"; any other value is ignored.
//   - Half-open: a success closes the breaker and clears its failure count.
//   - Closed: a success clears the failure count, so FailureThreshold counts
//     consecutive failures rather than failures over the whole process
//     lifetime.
//   - Open: a success (e.g. a call that was already in flight when the
//     breaker tripped) is ignored.
func (sm *SecurityManager) RecordSuccess(component string) {
	var cb *CircuitBreaker
	switch component {
	case "rpc":
		cb = sm.rpcCircuitBreaker
	case "arbitrage":
		cb = sm.arbitrageCircuitBreaker
	default:
		return
	}

	cb.mutex.Lock()
	defer cb.mutex.Unlock()

	switch cb.state {
	case CircuitBreakerHalfOpen:
		cb.state = CircuitBreakerClosed
		cb.failureCount = 0
		sm.logger.Info(fmt.Sprintf("Circuit breaker closed for component: %s", component))
	case CircuitBreakerClosed:
		// Without this reset, isolated failures spread over days would
		// accumulate and eventually trip a healthy component.
		cb.failureCount = 0
	}
}
// addSecurityAlert appends alert to the in-memory alert history, keeping only
// the most recent 1000 entries, and forwards it to the security monitor when
// one is configured.
func (sm *SecurityManager) addSecurityAlert(alert SecurityAlert) {
	const maxRetainedAlerts = 1000

	sm.alertsMutex.Lock()
	sm.securityAlerts = append(sm.securityAlerts, alert)
	if len(sm.securityAlerts) > maxRetainedAlerts {
		sm.securityAlerts = sm.securityAlerts[len(sm.securityAlerts)-maxRetainedAlerts:]
	}
	sm.alertsMutex.Unlock()

	// Forward outside alertsMutex: the monitor's behavior and locking are
	// outside this type's control. Holding our lock across the call would
	// block GetSecurityAlerts readers and risk lock-order inversion.
	if sm.monitor != nil {
		sm.monitor.TriggerAlert(alert.Level, alert.Type, alert.Title, alert.Description, alert.Source, alert.Data, alert.Actions)
	}
}
// startSecurityMonitoring runs performSecurityChecks once per minute and never
// returns.
//
// NOTE(review): there is no stop signal, so the goroutine launched by
// NewSecurityManager lives until the process exits.
func (sm *SecurityManager) startSecurityMonitoring() {
	ticker := time.NewTicker(time.Minute)
	defer ticker.Stop()

	// ticker.C is never closed, so this loop runs for the process lifetime.
	for range ticker.C {
		sm.performSecurityChecks()
	}
}
// performSecurityChecks runs the periodic security checks.
//
// It moves circuit breakers whose recovery timeout has elapsed to half-open.
// It also triggers an emergency stop when the configured emergency stop file
// exists and the manager is not already in emergency mode.
func (sm *SecurityManager) performSecurityChecks() {
	// Check circuit breakers for recovery.
	sm.checkCircuitBreakerRecovery(sm.rpcCircuitBreaker)
	sm.checkCircuitBreakerRecovery(sm.arbitrageCircuitBreaker)

	// Once in emergency mode the stop file has done its job. Re-triggering on
	// every monitoring tick would inflate EmergencyStops and emit a duplicate
	// critical alert each time.
	if sm.config.EmergencyStopFile == "" || sm.emergencyMode {
		return
	}
	if _, err := os.Stat(sm.config.EmergencyStopFile); err == nil {
		if stopErr := sm.TriggerEmergencyStop("emergency stop file detected"); stopErr != nil {
			sm.logger.Error("failed to trigger emergency stop: " + stopErr.Error())
		}
	}
}
// checkCircuitBreakerRecovery moves cb from open to half-open once more than
// its RecoveryTimeout has passed since the last recorded failure. Breakers in
// any other state are left untouched.
func (sm *SecurityManager) checkCircuitBreakerRecovery(cb *CircuitBreaker) {
	cb.mutex.Lock()
	defer cb.mutex.Unlock()

	if cb.state != CircuitBreakerOpen {
		return
	}
	if time.Since(cb.lastFailureTime) <= cb.config.RecoveryTimeout {
		return
	}

	cb.state = CircuitBreakerHalfOpen
	sm.logger.Info(fmt.Sprintf("Circuit breaker transitioned to half-open for component: %s", cb.name))
}
// GetManagerMetrics returns the manager's metrics.
//
// The returned pointer refers to the live counters (not a copy), so later
// updates are visible through it.
// NOTE(review): the counters are updated without synchronization.
func (sm *SecurityManager) GetManagerMetrics() *ManagerMetrics {
	live := sm.managerMetrics
	return live
}
// GetSecurityMetrics returns the security monitor's current metrics, or a
// zero-valued SecurityMetrics when no monitor is configured.
func (sm *SecurityManager) GetSecurityMetrics() *SecurityMetrics {
	if sm.monitor == nil {
		return &SecurityMetrics{}
	}
	return sm.monitor.GetMetrics()
}
// GetSecurityAlerts returns a copy of the most recent alerts, oldest first.
//
// limit caps the number returned; a limit <= 0 or larger than the history
// returns every retained alert. The result is always non-nil.
func (sm *SecurityManager) GetSecurityAlerts(limit int) []SecurityAlert {
	sm.alertsMutex.RLock()
	defer sm.alertsMutex.RUnlock()

	total := len(sm.securityAlerts)
	count := limit
	if count <= 0 || count > total {
		count = total
	}

	// Copy so callers never alias the internal history slice.
	recent := make([]SecurityAlert, count)
	copy(recent, sm.securityAlerts[total-count:])
	return recent
}
// Shutdown logs the shutdown of the security manager and its components and
// closes idle RPC connections. It always returns nil; ctx is currently unused.
//
// NOTE(review): the "stopped" messages are informational only. No component is
// actually shut down here, and the monitoring goroutine started by
// NewSecurityManager keeps running.
func (sm *SecurityManager) Shutdown(ctx context.Context) error {
	sm.logger.Info("Shutting down security manager")

	components := []struct {
		present bool
		message string
	}{
		{sm.keyManager != nil, "Key manager stopped"},
		{sm.rateLimiter != nil, "Rate limiter stopped"},
		{sm.monitor != nil, "Security monitor stopped"},
	}
	for _, c := range components {
		if c.present {
			sm.logger.Info(c.message)
		}
	}

	if client := sm.rpcHTTPClient; client != nil {
		client.CloseIdleConnections()
	}
	return nil
}
// JSON-RPC 2.0 wire types used by SecureRPCCall.
type (
	// jsonRPCRequest is the request envelope posted to the RPC endpoint.
	jsonRPCRequest struct {
		JSONRPC string        `json:"jsonrpc"`
		Method  string        `json:"method"`
		Params  []interface{} `json:"params"`
		ID      string        `json:"id"`
	}

	// jsonRPCError is the "error" object of a failed JSON-RPC response.
	jsonRPCError struct {
		Code    int    `json:"code"`
		Message string `json:"message"`
	}

	// jsonRPCResponse is the response envelope. Result is kept raw so its
	// decoding can be deferred.
	jsonRPCResponse struct {
		JSONRPC string          `json:"jsonrpc"`
		Result  json.RawMessage `json:"result"`
		Error   *jsonRPCError   `json:"error,omitempty"`
		ID      string          `json:"id"`
	}
)
// normalizeRPCParams converts an arbitrary params value into the positional
// parameter list of a JSON-RPC request:
//   - nil yields an empty, non-nil list;
//   - a []interface{} is returned as-is (not copied);
//   - any other slice or array is expanded element by element;
//   - any other value becomes a single-element list.
//
// The error result is currently always nil.
func normalizeRPCParams(params interface{}) ([]interface{}, error) {
	switch typed := params.(type) {
	case nil:
		return []interface{}{}, nil
	case []interface{}:
		return typed, nil
	case []string:
		boxed := make([]interface{}, 0, len(typed))
		for _, s := range typed {
			boxed = append(boxed, s)
		}
		return boxed, nil
	case []int:
		boxed := make([]interface{}, 0, len(typed))
		for _, n := range typed {
			boxed = append(boxed, n)
		}
		return boxed, nil
	}

	// Fall back to reflection for every other slice/array element type.
	rv := reflect.ValueOf(params)
	if kind := rv.Kind(); kind != reflect.Slice && kind != reflect.Array {
		return []interface{}{params}, nil
	}
	boxed := make([]interface{}, 0, rv.Len())
	for i := 0; i < rv.Len(); i++ {
		boxed = append(boxed, rv.Index(i).Interface())
	}
	return boxed, nil
}