package security

import (
	"context"
	"crypto/aes"
	"crypto/cipher"
	"crypto/ecdsa"
	"crypto/rand"
	"crypto/sha256"
	"crypto/subtle"
	"encoding/hex"
	"encoding/json"
	"fmt"
	"io"
	"log"
	"math/big"
	"os"
	"path/filepath"
	"runtime"
	"strings"
	"sync"
	"sync/atomic"
	"time"

	"github.com/ethereum/go-ethereum/accounts/keystore"
	"github.com/ethereum/go-ethereum/common"
	"github.com/ethereum/go-ethereum/core/types"
	"github.com/ethereum/go-ethereum/crypto"
	"golang.org/x/crypto/scrypt"

	"github.com/fraktal/mev-beta/internal/logger"
)

// AuthenticationContext contains authentication information for key access.
// AuthenticateUser builds one per session; the *WithAuth entry points consult
// SessionID and Permissions.
type AuthenticationContext struct {
	SessionID   string    `json:"session_id"`
	UserID      string    `json:"user_id"`
	IPAddress   string    `json:"ip_address"`
	UserAgent   string    `json:"user_agent"`
	AuthMethod  string    `json:"auth_method"` // "password", "mfa", "hardware_token"
	AuthTime    time.Time `json:"auth_time"`
	ExpiresAt   time.Time `json:"expires_at"` // ValidateSession rejects the session after this instant
	Permissions []string  `json:"permissions"` // e.g. "key_access", "transaction_signing"
	RiskScore   int       `json:"risk_score"`
}

// AuthenticationSession tracks active authentication sessions.
// Sessions are stored in KeyManager.activeSessions keyed by ID.
type AuthenticationSession struct {
	ID            string                 `json:"id"`
	Context       *AuthenticationContext `json:"context"`
	CreatedAt     time.Time              `json:"created_at"`
	LastActivity  time.Time              `json:"last_activity"` // refreshed by ValidateSession
	IsActive      bool                   `json:"is_active"`     // cleared by ValidateSession once the session has expired
	LoginAttempts int                    `json:"login_attempts"`
}

// KeyAccessEvent represents an access event to a private key
type KeyAccessEvent struct {
	Timestamp   time.Time              `json:"timestamp"`
	KeyAddress  common.Address         `json:"key_address"`
	Operation   string                 `json:"operation"` // "access", "sign", "rotate", "fail"
	Success     bool                   `json:"success"`
	Source      string                 `json:"source"`
	IPAddress   string                 `json:"ip_address,omitempty"`
	UserAgent   string                 `json:"user_agent,omitempty"`
	ErrorMsg    string                 `json:"error_msg,omitempty"`
	AuthContext *AuthenticationContext `json:"auth_context,omitempty"`
	RiskLevel   string                 `json:"risk_level"`
}

// SecureKey represents an encrypted private key with metadata.
//
// LastUsedUnix and UsageCount are read and written with sync/atomic (use the
// accessor methods below); the remaining fields are plain fields.
type SecureKey struct {
	Address         common.Address `json:"address"`
	EncryptedKey    []byte         `json:"encrypted_key"` // output of encryptPrivateKey: GCM nonce followed by ciphertext
	CreatedAt       time.Time      `json:"created_at"`
	LastUsedUnix    int64          `json:"last_used_unix"` // Atomic access to Unix timestamp
	UsageCount      int64          `json:"usage_count"`    // Atomic access to usage counter
	MaxUsage        int            `json:"max_usage"`      // SignTransaction enforces this only when > 0
	ExpiresAt       *time.Time     `json:"expires_at,omitempty"`
	KeyVersion      int            `json:"key_version"`
	Salt            []byte         `json:"salt"`
	Nonce           []byte         `json:"nonce"`
	KeyType         string         `json:"key_type"` // GenerateKey gives "emergency" keys a 30-day expiry
	Permissions     KeyPermissions `json:"permissions"`
	IsActive        bool           `json:"is_active"`
	BackupLocations []string       `json:"backup_locations,omitempty"` // backup file paths appended by createKeyBackup

	// Mutex for non-atomic fields
	mu sync.RWMutex `json:"-"`
}

// GetLastUsed returns the last used time in a thread-safe manner.
// A stored value of 0 is reported as the zero time.Time; the result has
// one-second resolution because only the Unix timestamp is stored.
func (sk *SecureKey) GetLastUsed() time.Time {
	lastUsedUnix := atomic.LoadInt64(&sk.LastUsedUnix)
	if lastUsedUnix == 0 {
		return time.Time{}
	}
	return time.Unix(lastUsedUnix, 0)
}

// GetUsageCount returns the usage count in a thread-safe manner
func (sk *SecureKey) GetUsageCount() int64 {
	return atomic.LoadInt64(&sk.UsageCount)
}

// SetLastUsed sets the last used time in a thread-safe manner.
// Only t.Unix() is stored, so sub-second precision is dropped.
func (sk *SecureKey) SetLastUsed(t time.Time) {
	atomic.StoreInt64(&sk.LastUsedUnix, t.Unix())
}

// IncrementUsageCount increments and returns the new usage count
func (sk *SecureKey) IncrementUsageCount() int64 {
	return atomic.AddInt64(&sk.UsageCount, 1)
}

// SigningRateTracker tracks signing rates per key.
// The fallback path of checkRateLimit uses only StartTime and Count.
type SigningRateTracker struct {
	LastReset    time.Time
	StartTime    time.Time // start of the current one-minute window
	Count        int       // signing attempts counted in the current window
	MaxPerMinute int
	MaxPerHour   int
	HourlyCount  int
}

// KeyManagerConfig provides configuration for the key manager
type KeyManagerConfig struct {
	KeyDir               string        `json:"key_dir"`
	KeystorePath         string        `json:"keystore_path"`  // directory created (mode 0700) when the manager is constructed
	EncryptionKey        string        `json:"encryption_key"` // master secret passed to deriveEncryptionKey
	BackupPath           string        `json:"backup_path"`    // empty disables key backups
	RotationInterval     time.Duration `json:"rotation_interval"`
	MaxKeyAge            time.Duration `json:"max_key_age"`
	MaxFailedAttempts    int           `json:"max_failed_attempts"` // failed logins before AuthenticateUser locks the account
	LockoutDuration      time.Duration `json:"lockout_duration"`
	EnableAuditLogging   bool          `json:"enable_audit_logging"`
	MaxSigningsPerMinute int           `json:"max_signings_per_minute"`
	MaxSigningsPerHour   int           `json:"max_signings_per_hour"`
	RequireHSM           bool          `json:"require_hsm"`
	BackupEnabled        bool          `json:"backup_enabled"`
	BackupLocation       string        `json:"backup_location"`
	MaxSigningRate       int           `json:"max_signing_rate"`  // <= 0 disables checkRateLimit
	AuditLogPath         string        `json:"audit_log_path"`    // auditLog emits entries only when non-empty
	KeyRotationDays      int           `json:"key_rotation_days"` // > 0 makes performMaintenance warn about older keys
	RequireHardware      bool          `json:"require_hardware"`
	SessionTimeout       time.Duration `json:"session_timeout"` // lifetime given to sessions created by AuthenticateUser

	// Authentication and Authorization Configuration
	RequireAuthentication bool          `json:"require_authentication"` // gates the checks in the *WithAuth methods
	EnableIPWhitelist     bool          `json:"enable_ip_whitelist"`
	WhitelistedIPs        []string      `json:"whitelisted_ips"`
	MaxConcurrentSessions int           `json:"max_concurrent_sessions"`
	RequireMFA            bool          `json:"require_mfa"`
	PasswordHashRounds    int           `json:"password_hash_rounds"`
	MaxSessionAge         time.Duration `json:"max_session_age"`
	EnableRateLimiting    bool          `json:"enable_rate_limiting"`
	MaxAuthAttempts       int           `json:"max_auth_attempts"`
	AuthLockoutDuration   time.Duration `json:"auth_lockout_duration"`
}

// KeyManager provides secure private key management and transaction signing
type KeyManager struct {
	logger        *logger.Logger
	keystore      *keystore.KeyStore
	encryptionKey []byte // AES key used by encryptPrivateKey/decryptPrivateKey

	// Enhanced security features
	mu                   sync.RWMutex
	activeKeyRotation    bool
	lastKeyRotation      time.Time
	keyRotationInterval  time.Duration
	maxKeyAge            time.Duration
	failedAccessAttempts map[string]int       // failed login count, keyed by user ID
	accessLockouts       map[string]time.Time // lockout end time, keyed by user ID
	maxFailedAttempts    int
	lockoutDuration      time.Duration

	// Authentication and Authorization
	activeSessions        map[string]*AuthenticationSession // keyed by session ID
	sessionsMutex         sync.RWMutex
	whitelistedIPs        map[string]bool
	ipWhitelistMutex      sync.RWMutex
	authMutex             sync.Mutex // held for the whole of AuthenticateUser
	sessionTimeout        time.Duration
	maxConcurrentSessions int

	// Audit logging
	accessLog     []KeyAccessEvent
	maxLogEntries int

	// Key derivation settings
	scryptN      int
	scryptR      int
	scryptP      int
	scryptKeyLen int

	keys           map[common.Address]*SecureKey // guarded by keysMutex
	keysMutex      sync.RWMutex
	config         *KeyManagerConfig
	signingRates   map[string]*SigningRateTracker // fallback limiter state, keyed by address hex; guarded by rateLimitMutex
	rateLimitMutex sync.Mutex

	// MEDIUM-001 ENHANCEMENT: Enhanced rate limiting
	enhancedRateLimiter *RateLimiter

	// CHAIN ID VALIDATION ENHANCEMENT: Enhanced chain security
	chainValidator  *ChainIDValidator
	expectedChainID *big.Int
}

// KeyPermissions defines what operations a key can perform
type KeyPermissions struct {
	CanSign          bool     `json:"can_sign"`
	CanTransfer      bool     `json:"can_transfer"`                // required when the transaction value is > 0
	MaxTransferWei   *big.Int `json:"max_transfer_wei,omitempty"`  // nil means no per-transaction value cap
	AllowedContracts []string `json:"allowed_contracts,omitempty"` // empty means any destination is allowed
	RequireConfirm   bool     `json:"require_confirmation"`
}

// SigningRequest represents a request to sign a transaction
type SigningRequest struct {
	Transaction  *types.Transaction
	ChainID      *big.Int
	From         common.Address // address of the managed key that should sign
	Purpose      string         // Description of what this transaction does
	UrgencyLevel int            // 1-5, with 5 being emergency
}

// SigningResult contains the result of a signing operation
type SigningResult struct {
	SignedTx  *types.Transaction
	Signature []byte // 65 bytes: R (32) || S (32) || recovery ID (1)
	SignedAt  time.Time
	KeyUsed   common.Address
	AuditID   string
	Warnings  []string
}

// AuditEntry represents a security audit log entry
type AuditEntry struct {
	Timestamp  time.Time      `json:"timestamp"`
	Operation  string         `json:"operation"`
	KeyAddress common.Address `json:"key_address"`
	Success    bool           `json:"success"`
	Details    string         `json:"details"`
	IPAddress  string         `json:"ip_address,omitempty"`
	UserAgent  string         `json:"user_agent,omitempty"`
	RiskScore  int            `json:"risk_score"` // 1-10
}

// NewKeyManager creates a new secure key manager.
// It delegates to NewKeyManagerWithChainID with chain ID 42161.
func NewKeyManager(config *KeyManagerConfig, logger *logger.Logger) (*KeyManager, error) {
	// Default to Arbitrum mainnet chain ID (42161)
	return NewKeyManagerWithChainID(config, logger, big.NewInt(42161))
}

// NewKeyManagerWithChainID creates a key manager with specified chain ID for enhanced validation
func NewKeyManagerWithChainID(config *KeyManagerConfig, logger *logger.Logger,
chainID *big.Int) (*KeyManager, error) {
	// Skip production validation in development/test environments
	validateProduction := os.Getenv("GO_ENV") != "development" &&
		os.Getenv("NODE_ENV") != "development" &&
		os.Getenv("NODE_ENV") != "test"
	return newKeyManagerInternal(config, logger, chainID, validateProduction)
}

// newKeyManagerForTesting creates a key manager without production validation (test only)
func newKeyManagerForTesting(config *KeyManagerConfig, logger *logger.Logger) (*KeyManager, error) {
	return newKeyManagerInternal(config, logger, big.NewInt(42161), false)
}

// newKeyManagerInternal builds a KeyManager.
//
// A nil config is replaced by getDefaultConfig(). When validateProduction is
// true the config must also pass validateProductionConfig. The keystore and
// (optional) backup directories are created with mode 0700, the AES key is
// derived from config.EncryptionKey, and a background maintenance goroutine
// is started.
//
// It returns an error when chainID is nil, when validation fails, when a
// directory cannot be created, or when key derivation fails.
func newKeyManagerInternal(config *KeyManagerConfig, logger *logger.Logger, chainID *big.Int, validateProduction bool) (*KeyManager, error) {
	if config == nil {
		config = getDefaultConfig()
	}

	// A nil chain ID would only surface later as a nil-pointer panic while
	// signing, so reject it before touching the filesystem.
	if chainID == nil {
		return nil, fmt.Errorf("invalid configuration: chain ID must not be nil")
	}
	// Keep a private copy so later mutation of the caller's big.Int cannot
	// silently change which chain this manager signs for.
	expectedChainID := new(big.Int).Set(chainID)

	// Critical Security Fix: Validate production encryption key (skip for tests)
	if validateProduction {
		if err := validateProductionConfig(config); err != nil {
			return nil, fmt.Errorf("production configuration validation failed: %w", err)
		}
	}

	// Validate configuration
	if err := validateConfig(config); err != nil {
		return nil, fmt.Errorf("invalid configuration: %w", err)
	}

	// Create keystore directory if it doesn't exist
	if err := os.MkdirAll(config.KeystorePath, 0700); err != nil {
		return nil, fmt.Errorf("failed to create keystore directory: %w", err)
	}

	// Create backup directory if specified
	if config.BackupPath != "" {
		if err := os.MkdirAll(config.BackupPath, 0700); err != nil {
			return nil, fmt.Errorf("failed to create backup directory: %w", err)
		}
	}

	// Initialize keystore
	// PERFORMANCE FIX: Use LightScrypt instead of StandardScrypt for faster key operations
	// Light parameters still provide good security but significantly faster performance
	// StandardScryptN=262144 can take 10+ seconds per key operation
	// LightScryptN=4096 takes < 1 second while still being secure for local keystores
	ks := keystore.NewKeyStore(config.KeystorePath, keystore.LightScryptN, keystore.LightScryptP)

	// Derive encryption key from master key
	encryptionKey, err := deriveEncryptionKey(config.EncryptionKey)
	if err != nil {
		return nil, fmt.Errorf("failed to derive encryption key: %w", err)
	}

	// MEDIUM-001 ENHANCEMENT: Initialize enhanced rate limiter.
	// Every limit is scaled from config.MaxSigningRate.
	enhancedRateLimiterConfig := &RateLimiterConfig{
		IPRequestsPerSecond:     config.MaxSigningRate,
		IPBurstSize:             config.MaxSigningRate * 2,
		UserRequestsPerSecond:   config.MaxSigningRate * 10,
		UserBurstSize:           config.MaxSigningRate * 20,
		GlobalRequestsPerSecond: config.MaxSigningRate * 100,
		GlobalBurstSize:         config.MaxSigningRate * 200,
		SlidingWindowEnabled:    true,
		SlidingWindowSize:       time.Minute,
		SlidingWindowPrecision:  time.Second,
		AdaptiveEnabled:         true,
		SystemLoadThreshold:     80.0,
		AdaptiveAdjustInterval:  30 * time.Second,
		AdaptiveMinRate:         0.1,
		AdaptiveMaxRate:         5.0,
		BypassDetectionEnabled:  true,
		BypassThreshold:         config.MaxSigningRate / 2,
		BypassDetectionWindow:   time.Hour,
		BypassAlertCooldown:     10 * time.Minute,
		CleanupInterval:         5 * time.Minute,
		BucketTTL:               time.Hour,
	}

	km := &KeyManager{
		logger:                logger,
		keystore:              ks,
		encryptionKey:         encryptionKey,
		keys:                  make(map[common.Address]*SecureKey),
		config:                config,
		activeSessions:        make(map[string]*AuthenticationSession),
		whitelistedIPs:        make(map[string]bool),
		failedAccessAttempts:  make(map[string]int),
		accessLockouts:        make(map[string]time.Time),
		maxFailedAttempts:     config.MaxFailedAttempts,
		lockoutDuration:       config.LockoutDuration,
		sessionTimeout:        config.SessionTimeout,
		maxConcurrentSessions: config.MaxConcurrentSessions,
		// MEDIUM-001 ENHANCEMENT: Enhanced rate limiting
		enhancedRateLimiter: NewEnhancedRateLimiter(enhancedRateLimiterConfig),
		// CHAIN ID VALIDATION ENHANCEMENT: Initialize chain security
		expectedChainID: expectedChainID,
		chainValidator:  NewChainIDValidator(logger, expectedChainID),
	}

	// Initialize IP whitelist
	if config.EnableIPWhitelist {
		for _, ip := range config.WhitelistedIPs {
			km.whitelistedIPs[ip] = true
		}
	}

	// Load existing keys. Failure is logged, not fatal: the manager can still
	// generate or import keys.
	if err := km.loadExistingKeys(); err != nil {
		logger.Warn(fmt.Sprintf("Failed to load existing keys: %v", err))
	}

	// Start background tasks.
	// NOTE(review): this goroutine has no shutdown path; it lives as long as
	// the process. Adding one needs a stop channel/context on KeyManager.
	go km.backgroundTasks()

	logger.Info("Secure key manager initialized with enhanced rate limiting")

	return km, nil
}

// GenerateKey creates a new private key with specified permissions.
//
// The key is encrypted with the manager's AES key, registered under its
// address, backed up (best effort; a backup failure is only logged) and
// audited. Keys of type "emergency" expire 30 days after creation.
// It returns the address of the new key.
func (km *KeyManager) GenerateKey(keyType string, permissions KeyPermissions) (common.Address, error) {
	// Generate new private key
	privateKey, err := crypto.GenerateKey()
	if err != nil {
		return common.Address{}, fmt.Errorf("failed to generate private key: %w", err)
	}
	// The plaintext key is only needed until it has been encrypted; wipe it on
	// every return path, as SignTransaction does for decrypted keys.
	defer clearPrivateKey(privateKey)

	address := crypto.PubkeyToAddress(privateKey.PublicKey)

	// Encrypt the private key
	encryptedKey, err := km.encryptPrivateKey(privateKey)
	if err != nil {
		return common.Address{}, fmt.Errorf("failed to encrypt private key: %w", err)
	}

	// Create secure key object
	now := time.Now()
	secureKey := &SecureKey{
		Address:      address,
		EncryptedKey: encryptedKey,
		CreatedAt:    now,
		LastUsedUnix: now.Unix(),
		UsageCount:   0,
		KeyType:      keyType,
		Permissions:  permissions,
		IsActive:     true, // Mark as active by default
	}

	// Set expiration for certain key types
	if keyType == "emergency" {
		expiresAt := now.Add(30 * 24 * time.Hour) // 30 days
		secureKey.ExpiresAt = &expiresAt
	}

	// Store the key
	km.keysMutex.Lock()
	km.keys[address] = secureKey
	km.keysMutex.Unlock()

	// Create backup
	if err := km.createKeyBackup(secureKey); err != nil {
		km.logger.Warn(fmt.Sprintf("Failed to create backup for key %s: %v", address.Hex(), err))
	}

	// Audit log
	km.auditLog("KEY_GENERATED", address, true, fmt.Sprintf("Generated %s key", keyType))

	km.logger.Info(fmt.Sprintf("Generated new %s key: %s", keyType, address.Hex()))

	return address, nil
}

// ImportKey imports an existing private key.
//
// privateKeyHex is parsed with crypto.HexToECDSA. The call fails when the
// hex is invalid, when a key with the same address is already managed, or
// when encryption fails. A backup failure is only logged.
func (km *KeyManager) ImportKey(privateKeyHex string, keyType string, permissions KeyPermissions) (common.Address, error) {
	// Parse private key
	privateKey, err := crypto.HexToECDSA(privateKeyHex)
	if err != nil {
		return common.Address{}, fmt.Errorf("invalid private key: %w", err)
	}
	// Wipe the plaintext key once it has been encrypted (or on failure).
	defer clearPrivateKey(privateKey)

	address := crypto.PubkeyToAddress(privateKey.PublicKey)

	// Cheap early rejection before doing any encryption work.
	km.keysMutex.RLock()
	_, exists := km.keys[address]
	km.keysMutex.RUnlock()
	if exists {
		return common.Address{}, fmt.Errorf("key already exists: %s", address.Hex())
	}

	// Encrypt the private key
	encryptedKey, err := km.encryptPrivateKey(privateKey)
	if err != nil {
		return common.Address{}, fmt.Errorf("failed to encrypt private key: %w", err)
	}

	// Create secure key object
	now := time.Now()
	secureKey := &SecureKey{
		Address:      address,
		EncryptedKey: encryptedKey,
		CreatedAt:    now,
		LastUsedUnix: now.Unix(),
		UsageCount:   0,
		KeyType:      keyType,
		Permissions:  permissions,
		IsActive:     true, // Mark as active by default
	}

	// Store the key. Existence is re-checked under the write lock so two
	// concurrent imports of the same key cannot both succeed and overwrite
	// each other's metadata.
	km.keysMutex.Lock()
	if _, exists := km.keys[address]; exists {
		km.keysMutex.Unlock()
		return common.Address{}, fmt.Errorf("key already exists: %s", address.Hex())
	}
	km.keys[address] = secureKey
	km.keysMutex.Unlock()

	// Create backup
	if err := km.createKeyBackup(secureKey); err != nil {
		km.logger.Warn(fmt.Sprintf("Failed to create backup for key %s: %v", address.Hex(), err))
	}

	// Audit log
	km.auditLog("KEY_IMPORTED", address, true, fmt.Sprintf("Imported %s key", keyType))

	km.logger.Info(fmt.Sprintf("Imported %s key: %s", keyType, address.Hex()))

	return address, nil
}

// SignTransactionWithAuth signs a transaction with authentication and comprehensive security checks.
//
// When config.RequireAuthentication is set, authContext must reference a
// valid session and the "transaction_signing" permission must be present both
// in the server-side session and in the supplied context. When it is not
// set, the call is equivalent to SignTransaction.
func (km *KeyManager) SignTransactionWithAuth(request *SigningRequest, authContext *AuthenticationContext) (*SigningResult, error) {
	if request == nil {
		return nil, fmt.Errorf("signing request is required")
	}

	// Validate authentication if required
	if km.config.RequireAuthentication {
		if authContext == nil {
			return nil, fmt.Errorf("authentication required")
		}

		// Validate session
		sessionCtx, err := km.ValidateSession(authContext.SessionID)
		if err != nil {
			return nil, fmt.Errorf("invalid session: %w", err)
		}

		// Check permissions. The caller-supplied context is untrusted input:
		// the permission must be granted by the validated session. It must
		// also be present on the supplied context so a caller that narrowed
		// its own permissions stays restricted.
		const permission = "transaction_signing"
		if sessionCtx == nil ||
			!contains(sessionCtx.Permissions, permission) ||
			!contains(authContext.Permissions, permission) {
			return nil, fmt.Errorf("insufficient permissions for transaction signing")
		}

		// Enhanced audit logging with auth context
		km.auditLogWithAuth("SIGN_ATTEMPT", request.From, true,
			fmt.Sprintf("Transaction signing attempted: %s", request.Purpose), authContext)
	}

	return km.SignTransaction(request)
}

//
SignTransaction signs a transaction with comprehensive security checks func (km *KeyManager) SignTransaction(request *SigningRequest) (*SigningResult, error) { // Get the key km.keysMutex.RLock() secureKey, exists := km.keys[request.From] km.keysMutex.RUnlock() if !exists { km.auditLog("SIGN_FAILED", request.From, false, "Key not found") return nil, fmt.Errorf("key not found: %s", request.From.Hex()) } // Security checks warnings := make([]string, 0) // Check permissions if !secureKey.Permissions.CanSign { km.auditLog("SIGN_FAILED", request.From, false, "Key not permitted to sign") return nil, fmt.Errorf("key %s not permitted to sign transactions", request.From.Hex()) } // Check expiration if secureKey.ExpiresAt != nil && time.Now().After(*secureKey.ExpiresAt) { km.auditLog("SIGN_FAILED", request.From, false, "Key expired") return nil, fmt.Errorf("key %s has expired", request.From.Hex()) } // Check usage limits (using atomic load for thread safety) currentUsageCount := atomic.LoadInt64(&secureKey.UsageCount) if secureKey.MaxUsage > 0 && currentUsageCount >= int64(secureKey.MaxUsage) { km.auditLog("SIGN_FAILED", request.From, false, "Usage limit exceeded") return nil, fmt.Errorf("key %s usage limit exceeded", request.From.Hex()) } // Check transfer permissions and limits if request.Transaction.Value().Sign() > 0 { if !secureKey.Permissions.CanTransfer { km.auditLog("SIGN_FAILED", request.From, false, "Transfer not permitted") return nil, fmt.Errorf("key %s not permitted to transfer value", request.From.Hex()) } if secureKey.Permissions.MaxTransferWei != nil && request.Transaction.Value().Cmp(secureKey.Permissions.MaxTransferWei) > 0 { km.auditLog("SIGN_FAILED", request.From, false, "Transfer amount exceeds limit") return nil, fmt.Errorf("transfer amount exceeds limit for key %s", request.From.Hex()) } } // Check contract interaction permissions if request.Transaction.To() != nil { contractAddr := request.Transaction.To().Hex() if 
len(secureKey.Permissions.AllowedContracts) > 0 { allowed := false for _, allowedContract := range secureKey.Permissions.AllowedContracts { if contractAddr == allowedContract { allowed = true break } } if !allowed { km.auditLog("SIGN_FAILED", request.From, false, "Contract interaction not permitted") return nil, fmt.Errorf("key %s not permitted to interact with contract %s", request.From.Hex(), contractAddr) } } } // Rate limiting check if err := km.checkRateLimit(request.From); err != nil { km.auditLog("SIGN_FAILED", request.From, false, "Rate limit exceeded") return nil, fmt.Errorf("rate limit exceeded: %w", err) } // Warning checks using atomic operations for thread safety lastUsedUnix := atomic.LoadInt64(&secureKey.LastUsedUnix) if lastUsedUnix > 0 && time.Since(time.Unix(lastUsedUnix, 0)) > 24*time.Hour { warnings = append(warnings, "Key has not been used in over 24 hours") } usageCount := atomic.LoadInt64(&secureKey.UsageCount) if usageCount > 1000 { warnings = append(warnings, "Key has high usage count - consider rotation") } // CHAIN ID VALIDATION ENHANCEMENT: Comprehensive chain ID validation before signing chainValidationResult := km.chainValidator.ValidateChainID(request.Transaction, request.From, request.ChainID) if !chainValidationResult.Valid { km.auditLog("SIGN_FAILED", request.From, false, fmt.Sprintf("Chain ID validation failed: %v", chainValidationResult.Errors)) return nil, fmt.Errorf("chain ID validation failed: %v", chainValidationResult.Errors) } // Log security warnings from chain validation for _, warning := range chainValidationResult.Warnings { warnings = append(warnings, warning) km.logger.Warn(fmt.Sprintf("Chain validation warning for %s: %s", request.From.Hex(), warning)) } // CRITICAL: Check for high replay risk if chainValidationResult.ReplayRisk == "CRITICAL" { km.auditLog("SIGN_FAILED", request.From, false, "Critical replay attack risk detected") return nil, fmt.Errorf("transaction rejected due to critical replay attack risk") } // 
Decrypt private key privateKey, err := km.decryptPrivateKey(secureKey.EncryptedKey) if err != nil { km.auditLog("SIGN_FAILED", request.From, false, "Failed to decrypt private key") return nil, fmt.Errorf("failed to decrypt private key: %w", err) } defer func() { // Clear private key from memory if privateKey != nil { clearPrivateKey(privateKey) } }() // CHAIN ID VALIDATION ENHANCEMENT: Verify chain ID matches transaction before signing if request.ChainID.Uint64() != km.expectedChainID.Uint64() { km.auditLog("SIGN_FAILED", request.From, false, fmt.Sprintf("Request chain ID %d doesn't match expected %d", request.ChainID.Uint64(), km.expectedChainID.Uint64())) return nil, fmt.Errorf("request chain ID %d doesn't match expected %d", request.ChainID.Uint64(), km.expectedChainID.Uint64()) } // Sign the transaction with appropriate signer based on transaction type var signer types.Signer switch request.Transaction.Type() { case types.LegacyTxType: signer = types.NewEIP155Signer(request.ChainID) case types.DynamicFeeTxType: signer = types.NewLondonSigner(request.ChainID) default: km.auditLog("SIGN_FAILED", request.From, false, fmt.Sprintf("Unsupported transaction type: %d", request.Transaction.Type())) return nil, fmt.Errorf("unsupported transaction type: %d", request.Transaction.Type()) } signedTx, err := types.SignTx(request.Transaction, signer, privateKey) if err != nil { km.auditLog("SIGN_FAILED", request.From, false, "Transaction signing failed") return nil, fmt.Errorf("failed to sign transaction: %w", err) } // CHAIN ID VALIDATION ENHANCEMENT: Verify signature integrity after signing if err := km.chainValidator.ValidateSignerMatchesChain(signedTx, request.From); err != nil { km.auditLog("SIGN_FAILED", request.From, false, fmt.Sprintf("Post-signing validation failed: %v", err)) return nil, fmt.Errorf("post-signing validation failed: %w", err) } // Extract signature v, r, s := signedTx.RawSignatureValues() signature := make([]byte, 65) r.FillBytes(signature[0:32]) 
s.FillBytes(signature[32:64]) signature[64] = byte(v.Uint64() - 35 - 2*request.ChainID.Uint64()) // Convert to recovery ID // Update key usage with atomic operations for thread safety now := time.Now() atomic.StoreInt64(&secureKey.LastUsedUnix, now.Unix()) atomic.AddInt64(&secureKey.UsageCount, 1) // Generate audit ID auditID := generateAuditID() // Create result result := &SigningResult{ SignedTx: signedTx, Signature: signature, SignedAt: time.Now(), KeyUsed: request.From, AuditID: auditID, Warnings: warnings, } // Audit log km.auditLog("TRANSACTION_SIGNED", request.From, true, fmt.Sprintf("Signed transaction %s for %s (audit: %s)", signedTx.Hash().Hex(), request.Purpose, auditID)) return result, nil } // CHAIN ID VALIDATION ENHANCEMENT: Chain security management methods // GetChainValidationStats returns chain validation statistics func (km *KeyManager) GetChainValidationStats() map[string]interface{} { return km.chainValidator.GetValidationStats() } // AddAllowedChainID adds a chain ID to the allowed list func (km *KeyManager) AddAllowedChainID(chainID uint64) { km.chainValidator.AddAllowedChainID(chainID) km.auditLog("CHAIN_ID_ADDED", common.Address{}, true, fmt.Sprintf("Added chain ID %d to allowed list", chainID)) } // RemoveAllowedChainID removes a chain ID from the allowed list func (km *KeyManager) RemoveAllowedChainID(chainID uint64) { km.chainValidator.RemoveAllowedChainID(chainID) km.auditLog("CHAIN_ID_REMOVED", common.Address{}, true, fmt.Sprintf("Removed chain ID %d from allowed list", chainID)) } // ValidateTransactionChain validates a transaction's chain ID without signing func (km *KeyManager) ValidateTransactionChain(tx *types.Transaction, signerAddr common.Address) (*ChainValidationResult, error) { return km.chainValidator.ValidateChainID(tx, signerAddr, nil), nil } // GetExpectedChainID returns the expected chain ID for this key manager func (km *KeyManager) GetExpectedChainID() *big.Int { return new(big.Int).Set(km.expectedChainID) } // 
GetKeyInfo returns information about a key (without sensitive data) func (km *KeyManager) GetKeyInfo(address common.Address) (*SecureKey, error) { km.keysMutex.RLock() defer km.keysMutex.RUnlock() secureKey, exists := km.keys[address] if !exists { return nil, fmt.Errorf("key not found: %s", address.Hex()) } // Return a copy without the encrypted key info := SecureKey{ Address: secureKey.Address, EncryptedKey: nil, // Intentionally exclude encrypted key CreatedAt: secureKey.CreatedAt, LastUsedUnix: secureKey.LastUsedUnix, UsageCount: secureKey.UsageCount, MaxUsage: secureKey.MaxUsage, ExpiresAt: secureKey.ExpiresAt, KeyVersion: secureKey.KeyVersion, Salt: secureKey.Salt, Nonce: secureKey.Nonce, KeyType: secureKey.KeyType, Permissions: secureKey.Permissions, IsActive: secureKey.IsActive, BackupLocations: secureKey.BackupLocations, // Do not copy mutex field } return &info, nil } // ListKeys returns addresses of all managed keys func (km *KeyManager) ListKeys() []common.Address { km.keysMutex.RLock() defer km.keysMutex.RUnlock() addresses := make([]common.Address, 0, len(km.keys)) for address := range km.keys { addresses = append(addresses, address) } return addresses } // RotateKey creates a new key to replace an existing one func (km *KeyManager) RotateKey(oldAddress common.Address) (common.Address, error) { km.keysMutex.RLock() oldKey, exists := km.keys[oldAddress] km.keysMutex.RUnlock() if !exists { return common.Address{}, fmt.Errorf("key not found: %s", oldAddress.Hex()) } // Generate new key with same permissions newAddress, err := km.GenerateKey(oldKey.KeyType, oldKey.Permissions) if err != nil { return common.Address{}, fmt.Errorf("failed to generate new key: %w", err) } // Mark old key as rotated (don't delete immediately for audit purposes) km.keysMutex.Lock() oldKey.Permissions.CanSign = false oldKey.Permissions.CanTransfer = false km.keysMutex.Unlock() // Audit log km.auditLog("KEY_ROTATED", oldAddress, true, fmt.Sprintf("Rotated to new key %s", 
newAddress.Hex())) km.logger.Info(fmt.Sprintf("Rotated key %s to %s", oldAddress.Hex(), newAddress.Hex())) return newAddress, nil } // encryptPrivateKey encrypts a private key using AES-GCM func (km *KeyManager) encryptPrivateKey(privateKey *ecdsa.PrivateKey) ([]byte, error) { // Convert private key to bytes keyBytes := crypto.FromECDSA(privateKey) // Create AES cipher block, err := aes.NewCipher(km.encryptionKey) if err != nil { return nil, fmt.Errorf("failed to create cipher: %w", err) } // Create GCM gcm, err := cipher.NewGCM(block) if err != nil { return nil, fmt.Errorf("failed to create GCM: %w", err) } // Generate nonce nonce := make([]byte, gcm.NonceSize()) if _, err := io.ReadFull(rand.Reader, nonce); err != nil { return nil, fmt.Errorf("failed to generate nonce: %w", err) } // Encrypt ciphertext := gcm.Seal(nonce, nonce, keyBytes, nil) // Clear original key bytes for i := range keyBytes { keyBytes[i] = 0 } return ciphertext, nil } // decryptPrivateKey decrypts an encrypted private key func (km *KeyManager) decryptPrivateKey(encryptedKey []byte) (*ecdsa.PrivateKey, error) { // Create AES cipher block, err := aes.NewCipher(km.encryptionKey) if err != nil { return nil, fmt.Errorf("failed to create cipher: %w", err) } // Create GCM gcm, err := cipher.NewGCM(block) if err != nil { return nil, fmt.Errorf("failed to create GCM: %w", err) } // Extract nonce nonceSize := gcm.NonceSize() if len(encryptedKey) < nonceSize { return nil, fmt.Errorf("encrypted key too short") } nonce, ciphertext := encryptedKey[:nonceSize], encryptedKey[nonceSize:] // Decrypt keyBytes, err := gcm.Open(nil, nonce, ciphertext, nil) if err != nil { return nil, fmt.Errorf("decryption failed: %w", err) } defer func() { // Clear decrypted bytes for i := range keyBytes { keyBytes[i] = 0 } }() // Convert to private key privateKey, err := crypto.ToECDSA(keyBytes) if err != nil { return nil, fmt.Errorf("failed to parse private key: %w", err) } return privateKey, nil } // createKeyBackup creates an 
encrypted backup of a key func (km *KeyManager) createKeyBackup(secureKey *SecureKey) error { if km.config.BackupPath == "" { return nil // Backups not configured } backupFile := filepath.Join(km.config.BackupPath, fmt.Sprintf("key_%s_%d.backup", secureKey.Address.Hex(), time.Now().Unix())) // Create backup data backupData := struct { Address string `json:"address"` EncryptedKey []byte `json:"encrypted_key"` CreatedAt time.Time `json:"created_at"` KeyType string `json:"key_type"` }{ Address: secureKey.Address.Hex(), EncryptedKey: secureKey.EncryptedKey, CreatedAt: secureKey.CreatedAt, KeyType: secureKey.KeyType, } // Additional encryption for backup backupBytes, err := encryptBackupData(backupData, km.encryptionKey) if err != nil { return fmt.Errorf("failed to encrypt backup: %w", err) } // Write to file if err := os.WriteFile(backupFile, backupBytes, 0600); err != nil { return fmt.Errorf("failed to write backup file: %w", err) } // Update backup locations secureKey.BackupLocations = append(secureKey.BackupLocations, backupFile) return nil } // checkRateLimit checks if signing rate limit is exceeded using enhanced rate limiting func (km *KeyManager) checkRateLimit(address common.Address) error { if km.config.MaxSigningRate <= 0 { return nil // Rate limiting disabled } // Use enhanced rate limiter if available if km.enhancedRateLimiter != nil { ctx := context.Background() result := km.enhancedRateLimiter.CheckRateLimitEnhanced( ctx, "127.0.0.1", // IP for local signing address.Hex(), // User ID "MEVBot/1.0", // User agent "signing", // Endpoint make(map[string]string), // Headers ) if !result.Allowed { km.logger.Warn(fmt.Sprintf("Enhanced rate limit exceeded for key %s: %s (reason: %s, score: %d)", address.Hex(), result.Message, result.ReasonCode, result.SuspiciousScore)) return fmt.Errorf("enhanced rate limit exceeded: %s", result.Message) } // Log metrics for monitoring if result.SuspiciousScore > 50 { km.logger.Warn(fmt.Sprintf("Suspicious signing activity 
detected for key %s: score %d", address.Hex(), result.SuspiciousScore))
		}
		return nil
	}

	// Fallback to simple rate limiting
	km.rateLimitMutex.Lock()
	defer km.rateLimitMutex.Unlock()

	now := time.Now()
	key := address.Hex()

	// Initialize rate limit tracking for this key if needed
	if km.signingRates == nil {
		km.signingRates = make(map[string]*SigningRateTracker)
	}

	if _, exists := km.signingRates[key]; !exists {
		km.signingRates[key] = &SigningRateTracker{
			Count:     0,
			StartTime: now,
		}
	}

	tracker := km.signingRates[key]

	// Reset counter if more than a minute has passed
	if now.Sub(tracker.StartTime) > time.Minute {
		tracker.Count = 0
		tracker.StartTime = now
	}

	// Increment counter
	tracker.Count++

	// Check if we've exceeded the rate limit
	if tracker.Count > km.config.MaxSigningRate {
		return fmt.Errorf("signing rate limit exceeded for key %s: %d/%d per minute", address.Hex(), tracker.Count, km.config.MaxSigningRate)
	}

	return nil
}

// auditLog records a key-management event in the audit trail.
//
// An AuditEntry is built from the arguments, with its RiskScore taken from
// calculateRiskScore(operation, success). Nothing is emitted when
// km.config.AuditLogPath is empty. When it is set, the entry is currently
// written through km.logger; no file at AuditLogPath is touched here.
func (km *KeyManager) auditLog(operation string, keyAddress common.Address, success bool, details string) {
	entry := AuditEntry{
		Timestamp:  time.Now(),
		Operation:  operation,
		KeyAddress: keyAddress,
		Success:    success,
		Details:    details,
		RiskScore:  calculateRiskScore(operation, success),
	}

	if km.config.AuditLogPath == "" {
		return
	}

	// Implementation would write to audit log file
	km.logger.Info(fmt.Sprintf("AUDIT: %s %s %v - %s (Risk: %.2f)",
		entry.Operation, entry.KeyAddress.Hex(), entry.Success, entry.Details, float64(entry.RiskScore)))
}

// loadExistingKeys is a placeholder for restoring keys from the keystore.
// It only logs that loading was requested and always returns nil.
func (km *KeyManager) loadExistingKeys() error {
	// Implementation would load existing keys from storage
	// For now, just log that we're loading
	km.logger.Info("Loading existing keys from keystore")
	return nil
}

// backgroundTasks runs performMaintenance once per hour.
//
// The loop has no exit condition, so this method never returns; it is meant
// to be started on its own goroutine.
func (km *KeyManager) backgroundTasks() {
	ticker := time.NewTicker(1 * time.Hour)
	defer ticker.Stop()

	for range ticker.C {
		km.performMaintenance()
	}
}

// performMaintenance scans every stored key under the keys read lock and logs
// a warning for each key that has passed its ExpiresAt time, and for each key
// older than km.config.KeyRotationDays days (when that setting is positive).
// It only reports; no key is modified, rotated or removed.
func (km *KeyManager) performMaintenance() {
	km.keysMutex.RLock()
	defer km.keysMutex.RUnlock()

	now := time.Now()
	rotationAge := time.Duration(km.config.KeyRotationDays) * 24 * time.Hour

	for address, key := range km.keys {
		// Check for expired keys
		if key.ExpiresAt != nil && now.After(*key.ExpiresAt) {
			km.logger.Warn(fmt.Sprintf("Key %s has expired", address.Hex()))
		}

		// Check for keys that should be rotated
		if km.config.KeyRotationDays <= 0 {
			continue
		}
		if age := now.Sub(key.CreatedAt); age > rotationAge {
			km.logger.Warn(fmt.Sprintf("Key %s should be rotated (age: %v)", address.Hex(), age))
		}
	}
}

// AuthenticateUser verifies a user's credentials and, on success, registers
// and returns a new AuthenticationSession.
//
// Checks are applied in this order, each failure returning a nil session and
// an error:
//  1. IP whitelist (only when km.config.EnableIPWhitelist is set).
//  2. Account lockout; an expired lockout is cleared together with the
//     failed-attempt counter.
//  3. Credentials, via validateCredentials. Each failure increments the
//     user's counter, and reaching km.maxFailedAttempts locks the account for
//     km.lockoutDuration.
//  4. Concurrent-session limit (km.maxConcurrentSessions). Only sessions that
//     are active and not yet past their expiry are counted, so sessions that
//     expired without being re-validated no longer block new logins.
//
// The whole call is serialized by km.authMutex.
func (km *KeyManager) AuthenticateUser(userID, password, ipAddress, userAgent string) (*AuthenticationSession, error) {
	km.authMutex.Lock()
	defer km.authMutex.Unlock()

	// Check IP whitelist if enabled
	if km.config.EnableIPWhitelist {
		km.ipWhitelistMutex.RLock()
		allowed := km.whitelistedIPs[ipAddress]
		km.ipWhitelistMutex.RUnlock()

		if !allowed {
			km.auditLog("AUTH_FAILED", common.Address{}, false,
				fmt.Sprintf("IP not whitelisted: %s", ipAddress))
			return nil, fmt.Errorf("access denied: IP address not whitelisted")
		}
	}

	// Check for lockout
	if lockoutEnd, locked := km.accessLockouts[userID]; locked {
		if time.Now().Before(lockoutEnd) {
			return nil, fmt.Errorf("account locked until %v", lockoutEnd)
		}
		// Clear expired lockout
		delete(km.accessLockouts, userID)
		delete(km.failedAccessAttempts, userID)
	}

	// Validate credentials (simplified - in production use proper password hashing)
	if !km.validateCredentials(userID, password) {
		// Track failed attempt
		km.failedAccessAttempts[userID]++

		if km.failedAccessAttempts[userID] >= km.maxFailedAttempts {
			// Lock account
			km.accessLockouts[userID] = time.Now().Add(km.lockoutDuration)
			km.logger.Warn(fmt.Sprintf("Account locked for user %s due to failed attempts", userID))
		}

		km.auditLog("AUTH_FAILED", common.Address{}, false,
			fmt.Sprintf("Invalid credentials for user %s", userID))
		return nil, fmt.Errorf("invalid credentials")
	}

	// Clear failed attempts on successful login
	delete(km.failedAccessAttempts, userID)

	now := time.Now()

	// Check concurrent session limit
	km.sessionsMutex.Lock()
	userSessions := 0
	for _, session := range km.activeSessions {
		// A session past ExpiresAt is treated as gone even if nobody has
		// called ValidateSession on it yet to flip IsActive.
		if session.Context.UserID == userID && session.IsActive && !now.After(session.Context.ExpiresAt) {
			userSessions++
		}
	}

	if userSessions >= km.maxConcurrentSessions {
		km.sessionsMutex.Unlock()
		return nil, fmt.Errorf("maximum concurrent sessions exceeded")
	}

	// Create new session. The local is deliberately not called "context" so
	// it does not shadow the imported context package.
	sessionID := generateSessionID()
	authCtx := &AuthenticationContext{
		SessionID:   sessionID,
		UserID:      userID,
		IPAddress:   ipAddress,
		UserAgent:   userAgent,
		AuthMethod:  "password",
		AuthTime:    now,
		ExpiresAt:   now.Add(km.sessionTimeout),
		Permissions: []string{"key_access", "transaction_signing"},
		RiskScore:   calculateAuthRiskScore(ipAddress, userAgent),
	}

	session := &AuthenticationSession{
		ID:           sessionID,
		Context:      authCtx,
		CreatedAt:    now,
		LastActivity: now,
		IsActive:     true,
	}

	km.activeSessions[sessionID] = session
	km.sessionsMutex.Unlock()

	// Audit log
	km.auditLog("USER_AUTHENTICATED", common.Address{}, true,
		fmt.Sprintf("User %s authenticated from %s", userID, ipAddress))

	km.logger.Info(fmt.Sprintf("User %s authenticated successfully", userID))
	return session, nil
}

// ValidateSession looks up sessionID and returns its AuthenticationContext if
// the session exists, is active and has not expired.
//
// An expired session is marked inactive and a SESSION_EXPIRED audit entry is
// written. A valid session has its LastActivity refreshed. All reads and
// writes of the session happen under a single hold of km.sessionsMutex, so a
// concurrent Logout or ValidateSession cannot interleave with the checks.
func (km *KeyManager) ValidateSession(sessionID string) (*AuthenticationContext, error) {
	km.sessionsMutex.Lock()
	defer km.sessionsMutex.Unlock()

	session, exists := km.activeSessions[sessionID]
	if !exists {
		return nil, fmt.Errorf("session not found")
	}

	if !session.IsActive {
		return nil, fmt.Errorf("session is inactive")
	}

	now := time.Now()
	if now.After(session.Context.ExpiresAt) {
		// Session expired, deactivate it
		session.IsActive = false

		km.auditLog("SESSION_EXPIRED", common.Address{}, false,
			fmt.Sprintf("Session %s expired", sessionID))
		return nil, fmt.Errorf("session expired")
	}

	// Update last activity
	session.LastActivity = now

	return session.Context, nil
}

// GetActivePrivateKeyWithAuth returns the active private key for transaction
// signing, enforcing authentication when km.config.RequireAuthentication is
// set.
//
// With authentication required, authContext must be non-nil, its SessionID
// must pass ValidateSession, and the "key_access" permission must be present
// both on the supplied context and on the context stored with the session.
// Checking the stored context means a caller cannot grant itself the
// permission by editing the struct it passes in.
func (km *KeyManager) GetActivePrivateKeyWithAuth(authContext *AuthenticationContext) (*ecdsa.PrivateKey, error) {
	// Validate authentication if required
	if km.config.RequireAuthentication {
		if authContext == nil {
			return nil, fmt.Errorf("authentication required")
		}

		// Validate session
		sessionCtx, err := km.ValidateSession(authContext.SessionID)
		if err != nil {
			return nil, fmt.Errorf("invalid session: %w", err)
		}

		// Check permissions
		if !contains(authContext.Permissions, "key_access") || !contains(sessionCtx.Permissions, "key_access") {
			return nil, fmt.Errorf("insufficient permissions for key access")
		}
	}

	return km.GetActivePrivateKey()
}

// GetActivePrivateKey returns the decrypted private key of a stored key whose
// IsActive flag is set (the first one met while ranging over the key map).
//
// If no key is stored at all, a default "trading" key is generated through
// GenerateKey and returned. If keys exist but none is active, an error is
// returned. Every outcome that touches a key is recorded via auditLog.
func (km *KeyManager) GetActivePrivateKey() (*ecdsa.PrivateKey, error) {
	// First, check for existing active keys
	km.keysMutex.RLock()
	for address, secureKey := range km.keys {
		if !secureKey.IsActive {
			continue
		}

		// Decrypt the private key, then release the lock before auditing.
		privateKey, err := km.decryptPrivateKey(secureKey.EncryptedKey)
		km.keysMutex.RUnlock()
		if err != nil {
			km.auditLog("KEY_DECRYPTION_FAILED", address, false,
				fmt.Sprintf("Failed to decrypt key: %v", err))
			return nil, fmt.Errorf("failed to decrypt active key: %w", err)
		}

		km.auditLog("KEY_ACCESSED", address, true, "Active private key retrieved")
		return privateKey, nil
	}

	// Check if we need to generate a new key (no keys exist)
	needsNewKey := len(km.keys) == 0
	km.keysMutex.RUnlock()

	// If no active key found and no keys exist, generate a default one
	if needsNewKey {
		km.logger.Info("No keys found, generating default trading key...")

		// Generate a new key pair with default permissions
		defaultPermissions := KeyPermissions{
			CanSign:          true,
			CanTransfer:      true,
			MaxTransferWei:   big.NewInt(1000000000000000000), // 1 ETH max per transaction
			AllowedContracts: []string{},                      // Will be populated with contract addresses
			RequireConfirm:   false,
		}

		km.logger.Info("Calling GenerateKey...")
		address, err := km.GenerateKey("trading", defaultPermissions)
		if err != nil {
			km.logger.Error(fmt.Sprintf("Failed to generate default key: %v", err))
			return nil, fmt.Errorf("failed to generate default key: %w", err)
		}
		km.logger.Info(fmt.Sprintf("Default key generated: %s", address.Hex()))

		// Retrieve the newly generated key
		km.keysMutex.RLock()
		if secureKey, exists := km.keys[address]; exists {
			privateKey, err := km.decryptPrivateKey(secureKey.EncryptedKey)
			km.keysMutex.RUnlock()
			if err != nil {
				return nil, fmt.Errorf("failed to decrypt newly generated key: %w", err)
			}

			km.auditLog("KEY_AUTO_GENERATED", address, true, "Auto-generated active key")
			return privateKey, nil
		}
		km.keysMutex.RUnlock()
	}

	return nil, fmt.Errorf("no active private key available")
}

// Helper functions

// getDefaultConfig returns a KeyManagerConfig populated with the package
// defaults. EncryptionKey is left empty and must be supplied by the caller
// before the config will pass validateConfig.
func getDefaultConfig() *KeyManagerConfig {
	return &KeyManagerConfig{
		KeystorePath:          "./keystore",
		EncryptionKey:         "", // Will be set later or generated
		KeyRotationDays:       90,
		MaxSigningRate:        60, // 60 signings per minute
		RequireHardware:       false,
		BackupPath:            "./backups",
		AuditLogPath:          "./audit.log",
		SessionTimeout:        15 * time.Minute,
		RequireAuthentication: true,
		EnableIPWhitelist:     true,
		WhitelistedIPs:        []string{"127.0.0.1", "::1"}, // localhost only by default
		MaxConcurrentSessions: 3,
		RequireMFA:            false,
		PasswordHashRounds:    12,
		MaxSessionAge:         24 * time.Hour,
		EnableRateLimiting:    true,
		MaxAuthAttempts:       5,
		AuthLockoutDuration:   30 * time.Minute,
		MaxFailedAttempts:     3,
		LockoutDuration:       15 * time.Minute,
	}
}

// validateConfig checks the minimum requirements of a KeyManagerConfig: a
// non-empty KeystorePath and an EncryptionKey of at least 32 characters
// (measured in bytes by len).
func validateConfig(config *KeyManagerConfig) error {
	if config.KeystorePath == "" {
		return fmt.Errorf("keystore path cannot be empty")
	}

	if config.EncryptionKey == "" {
		return fmt.Errorf("encryption key cannot be empty")
	}

	if len(config.EncryptionKey) < 32 {
		return fmt.Errorf("encryption key must be at least 32 characters")
	}

	return nil
}

// SECURITY FIX: Persistent salt for stable key derivation across restarts
const saltFilename = ".salt"

// deriveEncryptionKey derives a 32-byte key from masterKey with scrypt
// (N=16384, r=8, p=1) and a 32-byte salt persisted at keystore/.salt,
// resolved relative to the process working directory.
//
// Salt handling:
//   - a readable 32-byte file is used as-is, unless it is all zeros, which is
//     reported as corruption;
//   - a missing file, or a file of any other length, is replaced by a freshly
//     generated salt;
//   - any other read error (for example permission denied) is returned
//     instead of being treated as "no salt", because overwriting a salt that
//     merely could not be read would make previously encrypted keys
//     underivable.
func deriveEncryptionKey(masterKey string) ([]byte, error) {
	if masterKey == "" {
		return nil, fmt.Errorf("master key cannot be empty")
	}

	// NOTE: In production, this should be in a secure location like keystore/
	saltPath := filepath.Join("keystore", saltFilename)

	salt, err := os.ReadFile(saltPath)
	switch {
	case err != nil && !os.IsNotExist(err):
		return nil, fmt.Errorf("failed to read salt file %s: %w", saltPath, err)

	case err == nil && len(salt) == 32:
		// Validate salt is not all zeros (corrupted)
		allZero := true
		for _, b := range salt {
			if b != 0 {
				allZero = false
				break
			}
		}
		if allZero {
			return nil, fmt.Errorf("corrupted salt file detected (all zeros): %s", saltPath)
		}

	default:
		// Generate new salt only if none exists or is invalid
		salt = make([]byte, 32)
		if _, err := rand.Read(salt); err != nil {
			return nil, fmt.Errorf("failed to generate random salt: %w", err)
		}

		// Ensure keystore directory exists
		if err := os.MkdirAll("keystore", 0700); err != nil {
			return nil, fmt.Errorf("failed to create keystore directory: %w", err)
		}

		// Persist salt for future restarts
		if err := os.WriteFile(saltPath, salt, 0600); err != nil {
			return nil, fmt.Errorf("failed to persist salt: %w", err)
		}
	}

	// PERFORMANCE FIX: Reduced scrypt N from 32768 to 16384 for faster startup
	// This provides a good balance between security and performance for server applications
	// N=16384 still requires significant computation (prevents brute force attacks)
	// but allows bot to start in reasonable time (<10 seconds instead of 2+ minutes)
	key, err := scrypt.Key([]byte(masterKey), salt, 16384, 8, 1, 32)
	if err != nil {
		return nil, fmt.Errorf("key derivation failed: %w", err)
	}
	return key, nil
}

// clearPrivateKey wipes the key material held by privateKey: D, X and Y are
// zeroed through secureClearBigInt and set to nil, and the curve reference is
// dropped. A nil privateKey is a no-op.
//
// A garbage collection is forced afterwards, and a warning is logged if the
// whole operation took longer than 100ms. The struct is unusable after this
// call.
func clearPrivateKey(privateKey *ecdsa.PrivateKey) {
	if privateKey == nil {
		return
	}

	startTime := time.Now()

	// Clear D parameter (private key scalar) - MOST CRITICAL
	if privateKey.D != nil {
		secureClearBigInt(privateKey.D)
		privateKey.D = nil
	}

	// Clear PublicKey components for complete cleanup
	if privateKey.PublicKey.X != nil {
		secureClearBigInt(privateKey.PublicKey.X)
		privateKey.PublicKey.X = nil
	}
	if privateKey.PublicKey.Y != nil {
		secureClearBigInt(privateKey.PublicKey.Y)
		privateKey.PublicKey.Y = nil
	}

	// Clear the curve reference
	privateKey.PublicKey.Curve = nil

	runtime.KeepAlive(privateKey)
	runtime.GC() // Force garbage collection to clear any remaining references

	clearingTime := time.Since(startTime)
	if clearingTime > 100*time.Millisecond {
		// Log if clearing takes unusually long (potential security concern)
		log.Printf("WARNING: Private key clearing took %v (longer than expected)", clearingTime)
	}
}

// withMemoryProtection runs operation with a forced garbage collection
// immediately before and immediately after it, and returns operation's error
// unchanged.
func withMemoryProtection(operation func() error) error {
	// Force garbage collection before sensitive operation
	runtime.GC()

	// Execute the operation
	err := operation()

	// Force garbage collection after sensitive operation
	runtime.GC()

	return err
}

// KeyMemoryMetrics is a snapshot of memory statistics reported alongside key
// operations. See GetMemoryMetrics for how each field is currently filled.
type KeyMemoryMetrics struct {
	ActiveKeys       int           `json:"active_keys"`
	MemoryUsageBytes int64         `json:"memory_usage_bytes"`
	GCPauseTime      time.Duration `json:"gc_pause_time"`
	LastClearingTime time.Duration `json:"last_clearing_time"`
	ClearingCount    int64         `json:"clearing_count"`
	LastGCTime       time.Time     `json:"last_gc_time"`
}

// GetMemoryMetrics returns a snapshot built from runtime.ReadMemStats.
//
// ActiveKeys is the number of entries in the key map (regardless of each
// key's IsActive flag). MemoryUsageBytes is MemStats.Alloc, GCPauseTime is
// the cumulative MemStats.PauseTotalNs, and LastGCTime is the time of the
// last completed collection (zero if none has run yet). ClearingCount and
// LastClearingTime are not tracked and are always zero.
func (km *KeyManager) GetMemoryMetrics() *KeyMemoryMetrics {
	var memStats runtime.MemStats
	runtime.ReadMemStats(&memStats)

	km.keysMutex.RLock()
	activeKeys := len(km.keys)
	km.keysMutex.RUnlock()

	// MemStats.LastGC is nanoseconds since the Unix epoch, 0 if no GC yet.
	var lastGC time.Time
	if memStats.LastGC != 0 {
		lastGC = time.Unix(0, int64(memStats.LastGC))
	}

	return &KeyMemoryMetrics{
		ActiveKeys:       activeKeys,
		MemoryUsageBytes: int64(memStats.Alloc),
		GCPauseTime:      time.Duration(memStats.PauseTotalNs),
		LastGCTime:       lastGC,
		ClearingCount:    0, // Would need proper tracking
		LastClearingTime: 0, // Would need proper tracking
	}
}

// secureClearBigInt zeroes the words backing bi and leaves bi equal to 0.
// A nil bi is a no-op.
//
// The backing slice is wiped up to its capacity, not just its length, so
// stale words left behind by earlier, larger values are cleared too. A single
// zeroing pass is used: repeated overwrites add nothing for RAM, and the
// previous "random overwrite" pass never actually wrote its random data.
func secureClearBigInt(bi *big.Int) {
	if bi == nil {
		return
	}

	words := bi.Bits()
	words = words[:cap(words)]
	for i := range words {
		words[i] = 0
	}

	// Normalize the now all-zero representation to the canonical zero value.
	bi.SetInt64(0)

	// Keep bi reachable until after the wipe so the stores are not elided.
	runtime.KeepAlive(bi)
}

// secureClearBytes overwrites every byte of data with zero. An empty or nil
// slice is a no-op.
func secureClearBytes(data []byte) {
	if len(data) == 0 {
		return
	}

	for i := range data {
		data[i] = 0
	}

	// Keep data reachable until after the wipe so the stores are not elided.
	runtime.KeepAlive(data)
}

// generateAuditID returns a 32-character hex string built from 16 bytes of
// crypto/rand output. If the random source fails, it falls back to the
// current Unix time in nanoseconds rendered as hex.
func generateAuditID() string {
	bytes := make([]byte, 16)
	if _, err := io.ReadFull(rand.Reader, bytes); err != nil {
		// Fallback to current time if crypto/rand fails (shouldn't happen)
		return fmt.Sprintf("%x", time.Now().UnixNano())
	}
	return hex.EncodeToString(bytes)
}

// calculateRiskScore maps an audited operation to a risk score. Any failed
// operation scores 8; successful operations score by type, with 2 as the
// default for operations not listed.
func calculateRiskScore(operation string, success bool) int {
	if !success {
		return 8 // Failed operations are high risk
	}

	switch operation {
	case "TRANSACTION_SIGNED":
		return 3
	case "KEY_GENERATED", "KEY_IMPORTED":
		return 5
	case "KEY_ROTATED":
		return 4
	default:
		return 2
	}
}

// Logout marks the session identified by sessionID as inactive and writes a
// USER_LOGOUT audit entry. The session stays in the session map. An unknown
// sessionID returns an error.
func (km *KeyManager) Logout(sessionID string) error {
	km.sessionsMutex.Lock()
	defer km.sessionsMutex.Unlock()

	session, exists := km.activeSessions[sessionID]
	if !exists {
		return fmt.Errorf("session not found")
	}

	session.IsActive = false

	km.auditLog("USER_LOGOUT", common.Address{}, true,
		fmt.Sprintf("User %s logged out", session.Context.UserID))

	return nil
}

// validateCredentials reports whether password matches the stored hash for
// userID. Both sides are hashed with hashPassword and compared in constant
// time.
func (km *KeyManager) validateCredentials(userID, password string) bool {
	// In production, this should use proper password hashing (bcrypt, scrypt, etc.)
	// For now, we'll use a simple hash comparison
	suppliedHash := hashPassword(password)
	storedHash := km.getStoredPasswordHash(userID)

	return subtle.ConstantTimeCompare([]byte(suppliedHash), []byte(storedHash)) == 1
}

// getStoredPasswordHash returns the password hash on record for userID.
//
// SECURITY NOTE(review): this is a development stub. The "admin" password is
// hard-coded, and every other userID, including ones that do not exist,
// shares a single hard-coded default password. It must be replaced by a real
// credential store before this code guards anything of value.
func (km *KeyManager) getStoredPasswordHash(userID string) string {
	// In production, this would fetch from secure storage
	// For development/testing, we'll use a default hash
	if userID == "admin" {
		return hashPassword("secure_admin_password_123")
	}
	return hashPassword("default_password")
}

// hashPassword returns the hex-encoded SHA-256 digest of password. The hash
// is unsalted and fast, so it is not suitable for storing real passwords.
func hashPassword(password string) string {
	hash := sha256.Sum256([]byte(password))
	return hex.EncodeToString(hash[:])
}

// generateSessionID returns a 64-character hex string built from 32 bytes of
// crypto/rand output.
//
// NOTE(review): if the random source fails, the fallback is the current Unix
// time in nanoseconds, which is guessable; the signature leaves no way to
// report the failure to the caller.
func generateSessionID() string {
	bytes := make([]byte, 32)
	if _, err := io.ReadFull(rand.Reader, bytes); err != nil {
		return fmt.Sprintf("%x", time.Now().UnixNano())
	}
	return hex.EncodeToString(bytes)
}

//
calculateAuthRiskScore calculates risk score for authentication func calculateAuthRiskScore(ipAddress, userAgent string) int { riskScore := 1 // Base risk // Increase risk for external IPs if !strings.HasPrefix(ipAddress, "127.") && !strings.HasPrefix(ipAddress, "192.168.") && !strings.HasPrefix(ipAddress, "10.") { riskScore += 3 } // Increase risk for unknown user agents if len(userAgent) < 10 || !strings.Contains(userAgent, "Mozilla") { riskScore += 2 } return riskScore } // contains checks if a slice contains a string func contains(slice []string, item string) bool { for _, s := range slice { if s == item { return true } } return false } // auditLogWithAuth writes an audit entry with authentication context func (km *KeyManager) auditLogWithAuth(operation string, keyAddress common.Address, success bool, details string, authContext *AuthenticationContext) { entry := AuditEntry{ Timestamp: time.Now(), Operation: operation, KeyAddress: keyAddress, Success: success, Details: details, RiskScore: calculateRiskScore(operation, success), } if authContext != nil { entry.IPAddress = authContext.IPAddress entry.UserAgent = authContext.UserAgent } // Write to audit log if km.config.AuditLogPath != "" { km.logger.Info(fmt.Sprintf("AUDIT: %s %s %v - %s (Risk: %d) [User: %v]", entry.Operation, entry.KeyAddress.Hex(), entry.Success, entry.Details, entry.RiskScore, map[string]interface{}{"user_id": authContext.UserID, "session_id": authContext.SessionID})) } } func encryptBackupData(data interface{}, key []byte) ([]byte, error) { // Convert data to JSON bytes jsonData, err := json.Marshal(data) if err != nil { return nil, fmt.Errorf("failed to marshal backup data: %w", err) } // Create AES cipher block, err := aes.NewCipher(key) if err != nil { return nil, fmt.Errorf("failed to create AES cipher: %w", err) } // Create GCM mode for authenticated encryption gcm, err := cipher.NewGCM(block) if err != nil { return nil, fmt.Errorf("failed to create GCM mode: %w", err) } // Generate 
random nonce nonce := make([]byte, gcm.NonceSize()) if _, err := io.ReadFull(rand.Reader, nonce); err != nil { return nil, fmt.Errorf("failed to generate nonce: %w", err) } // Encrypt and authenticate the data ciphertext := gcm.Seal(nonce, nonce, jsonData, nil) return ciphertext, nil } // validateProductionConfig validates production-specific security requirements func validateProductionConfig(config *KeyManagerConfig) error { // Check for encryption key presence if config.EncryptionKey == "" { return fmt.Errorf("MEV_BOT_ENCRYPTION_KEY environment variable is required for production") } // Check for test/default encryption keys if strings.Contains(strings.ToLower(config.EncryptionKey), "test") || strings.Contains(strings.ToLower(config.EncryptionKey), "default") || strings.Contains(strings.ToLower(config.EncryptionKey), "example") { return fmt.Errorf("production deployment cannot use test/default encryption keys") } // Validate encryption key strength if len(config.EncryptionKey) < 32 { return fmt.Errorf("encryption key must be at least 32 characters for production use") } // Check for weak encryption keys if config.EncryptionKey == "test123" || config.EncryptionKey == "password" || config.EncryptionKey == "123456789012345678901234567890" || strings.Repeat("a", len(config.EncryptionKey)) == config.EncryptionKey { return fmt.Errorf("encryption key is too weak for production use") } // Validate keystore path security if config.KeystorePath != "" { // Check that keystore path is not in a publicly accessible location publicPaths := []string{"/tmp", "/var/tmp", "/home/public", "/usr/tmp"} keystoreLower := strings.ToLower(config.KeystorePath) for _, publicPath := range publicPaths { if strings.HasPrefix(keystoreLower, publicPath) { return fmt.Errorf("keystore path '%s' is in a publicly accessible location", config.KeystorePath) } } } // Validate backup path if specified if config.BackupPath != "" { if config.BackupPath == config.KeystorePath { return fmt.Errorf("backup 
path cannot be the same as keystore path") } } return nil } // MEDIUM-001 ENHANCEMENT: Enhanced Rate Limiting Methods // Shutdown properly shuts down the KeyManager and its enhanced rate limiter func (km *KeyManager) Shutdown() { km.logger.Info("Shutting down KeyManager") // Stop enhanced rate limiter if km.enhancedRateLimiter != nil { km.enhancedRateLimiter.Stop() km.logger.Info("Enhanced rate limiter stopped") } // Clear all keys from memory (simplified for safety) km.keysMutex.Lock() km.keys = make(map[common.Address]*SecureKey) km.keysMutex.Unlock() // Clear all sessions km.sessionsMutex.Lock() km.activeSessions = make(map[string]*AuthenticationSession) km.sessionsMutex.Unlock() km.logger.Info("KeyManager shutdown complete") } // GetRateLimitMetrics returns current rate limiting metrics func (km *KeyManager) GetRateLimitMetrics() map[string]interface{} { if km.enhancedRateLimiter != nil { return km.enhancedRateLimiter.GetEnhancedMetrics() } // Fallback to simple metrics km.rateLimitMutex.Lock() defer km.rateLimitMutex.Unlock() totalTrackers := 0 activeTrackers := 0 now := time.Now() if km.signingRates != nil { totalTrackers = len(km.signingRates) for _, tracker := range km.signingRates { if now.Sub(tracker.StartTime) <= time.Minute && tracker.Count > 0 { activeTrackers++ } } } return map[string]interface{}{ "rate_limiting_enabled": km.config.MaxSigningRate > 0, "max_signing_rate": km.config.MaxSigningRate, "total_rate_trackers": totalTrackers, "active_rate_trackers": activeTrackers, "enhanced_rate_limiter": km.enhancedRateLimiter != nil, } } // SetRateLimitConfig allows dynamic configuration of rate limiting func (km *KeyManager) SetRateLimitConfig(maxSigningRate int, adaptiveEnabled bool) error { if maxSigningRate < 0 { return fmt.Errorf("maxSigningRate cannot be negative") } // Update basic config km.config.MaxSigningRate = maxSigningRate // Update enhanced rate limiter if available if km.enhancedRateLimiter != nil { // Create new enhanced rate limiter with 
updated configuration enhancedRateLimiterConfig := &RateLimiterConfig{ IPRequestsPerSecond: maxSigningRate, IPBurstSize: maxSigningRate * 2, UserRequestsPerSecond: maxSigningRate * 10, UserBurstSize: maxSigningRate * 20, GlobalRequestsPerSecond: maxSigningRate * 100, GlobalBurstSize: maxSigningRate * 200, SlidingWindowEnabled: true, SlidingWindowSize: time.Minute, SlidingWindowPrecision: time.Second, AdaptiveEnabled: adaptiveEnabled, SystemLoadThreshold: 80.0, AdaptiveAdjustInterval: 30 * time.Second, AdaptiveMinRate: 0.1, AdaptiveMaxRate: 5.0, BypassDetectionEnabled: true, BypassThreshold: maxSigningRate / 2, BypassDetectionWindow: time.Hour, BypassAlertCooldown: 10 * time.Minute, CleanupInterval: 5 * time.Minute, BucketTTL: time.Hour, } // Stop current rate limiter km.enhancedRateLimiter.Stop() // Create new enhanced rate limiter km.enhancedRateLimiter = NewEnhancedRateLimiter(enhancedRateLimiterConfig) km.logger.Info(fmt.Sprintf("Enhanced rate limiter reconfigured: maxSigningRate=%d, adaptive=%t", maxSigningRate, adaptiveEnabled)) } km.logger.Info(fmt.Sprintf("Rate limiting configuration updated: maxSigningRate=%d", maxSigningRate)) return nil } // GetRateLimitStatus returns current rate limiting status for monitoring func (km *KeyManager) GetRateLimitStatus() map[string]interface{} { status := map[string]interface{}{ "enabled": km.config.MaxSigningRate > 0, "max_signing_rate": km.config.MaxSigningRate, "enhanced_limiter": km.enhancedRateLimiter != nil, } if km.enhancedRateLimiter != nil { enhancedMetrics := km.enhancedRateLimiter.GetEnhancedMetrics() status["sliding_window_enabled"] = enhancedMetrics["sliding_window_enabled"] status["adaptive_enabled"] = enhancedMetrics["adaptive_enabled"] status["bypass_detection_enabled"] = enhancedMetrics["bypass_detection_enabled"] status["system_load"] = enhancedMetrics["system_load_average"] status["bypass_alerts"] = enhancedMetrics["bypass_alerts_active"] status["blocked_ips"] = enhancedMetrics["blocked_ips"] } return 
status }