// Package security provides key management, encryption-key handling for
// configuration, RPC endpoint bookkeeping, hashing helpers and an address
// allow-list for the MEV bot.
package security

import (
	"crypto/rand"
	"crypto/sha256"
	"encoding/hex"
	"errors"
	"fmt"
	"os"
	"path/filepath"
	"strings"

	"github.com/ethereum/go-ethereum/accounts"
	"github.com/ethereum/go-ethereum/accounts/keystore"
	"github.com/ethereum/go-ethereum/common"
	"github.com/ethereum/go-ethereum/crypto"
	"github.com/fraktal/mev-beta/internal/logger"
)

// KeyManager handles secure key management for the MEV bot. It wraps a
// keystore rooted at keyDir and logs every operation through logger.
type KeyManager struct {
	keystore *keystore.KeyStore
	logger   *logger.Logger
	keyDir   string
}

// NewKeyManager creates a new secure key manager backed by a keystore in
// keyDir.
//
// The directory is created with mode 0700 when it does not exist; os.MkdirAll
// leaves the permissions of an already existing directory untouched. If the
// directory cannot be created the failure is logged and nil is returned, so
// callers must check the result for nil.
func NewKeyManager(keyDir string, logger *logger.Logger) *KeyManager {
	if err := os.MkdirAll(keyDir, 0700); err != nil {
		logger.Error(fmt.Sprintf("Failed to create key directory: %v", err))
		return nil
	}

	// Create keystore with the standard scrypt parameters.
	ks := keystore.NewKeyStore(keyDir, keystore.StandardScryptN, keystore.StandardScryptP)

	return &KeyManager{
		keystore: ks,
		logger:   logger,
		keyDir:   keyDir,
	}
}

// CreateAccount creates a new keystore account protected by password.
//
// Passwords shorter than 12 bytes are rejected before the keystore is touched.
// On success the new account's address is logged and the account returned; on
// failure the error is logged and returned together with a zero Account.
func (km *KeyManager) CreateAccount(password string) (accounts.Account, error) {
	if len(password) < 12 {
		return accounts.Account{}, fmt.Errorf("password must be at least 12 characters")
	}

	account, err := km.keystore.NewAccount(password)
	if err != nil {
		km.logger.Error(fmt.Sprintf("Failed to create account: %v", err))
		return accounts.Account{}, err
	}

	km.logger.Info(fmt.Sprintf("Created new account: %s", account.Address.Hex()))
	return account, nil
}

// UnlockAccount unlocks the keystore account for address using password so it
// can sign transactions. The outcome is logged either way and any keystore
// error is returned unchanged.
//
// NOTE(review): keystore.Unlock takes no timeout here — confirm whether a
// time-limited unlock would be more appropriate for this bot.
func (km *KeyManager) UnlockAccount(address common.Address, password string) error {
	account := accounts.Account{Address: address}

	if err := km.keystore.Unlock(account, password); err != nil {
		km.logger.Error(fmt.Sprintf("Failed to unlock account %s: %v", address.Hex(), err))
		return err
	}

	km.logger.Info(fmt.Sprintf("Unlocked account: %s", address.Hex()))
	return nil
}

// GetSignerFunction returns a signing function bound to address.
//
// An error is returned immediately when the keystore does not hold the
// address. The returned closure passes the given hash to the keystore's
// SignHash for that account, logging and returning any signing error.
func (km *KeyManager) GetSignerFunction(address common.Address) (func([]byte) ([]byte, error), error) {
	if !km.keystore.HasAddress(address) {
		return nil, fmt.Errorf("account %s not found in keystore", address.Hex())
	}

	account := accounts.Account{Address: address}

	return func(hash []byte) ([]byte, error) {
		signature, err := km.keystore.SignHash(account, hash)
		if err != nil {
			km.logger.Error(fmt.Sprintf("Failed to sign hash: %v", err))
			return nil, err
		}
		return signature, nil
	}, nil
}

// SecureConfig handles secure configuration management. It carries the path of
// the configuration file and the 32-byte key loaded or generated for it.
type SecureConfig struct {
	logger        *logger.Logger
	configPath    string
	encryptionKey [32]byte
}

// NewSecureConfig creates a new secure configuration manager for configPath.
//
// The encryption key is stored in a file named ".encryption.key" located in
// the same directory as configPath; it is loaded if present and generated
// otherwise. Any failure doing so is wrapped and returned.
func NewSecureConfig(configPath string, logger *logger.Logger) (*SecureConfig, error) {
	keyPath := filepath.Join(filepath.Dir(configPath), ".encryption.key")

	key, err := loadOrGenerateKey(keyPath)
	if err != nil {
		return nil, fmt.Errorf("failed to setup encryption key: %w", err)
	}

	return &SecureConfig{
		logger:        logger,
		configPath:    configPath,
		encryptionKey: key,
	}, nil
}

// loadOrGenerateKey loads the hex-encoded 32-byte key stored at keyPath, or
// generates and persists a new one when no key exists yet.
//
// A new key is generated only when the file is missing or empty. A file that
// exists but cannot be read, or whose content is not exactly 32 hex-encoded
// bytes (surrounding whitespace is ignored), yields an error instead of being
// overwritten: replacing an existing key would make anything encrypted with it
// unrecoverable. On every error path the returned key is the zero value.
func loadOrGenerateKey(keyPath string) ([32]byte, error) {
	var key [32]byte

	keyData, err := os.ReadFile(keyPath)
	switch {
	case err == nil:
		encoded := strings.TrimSpace(string(keyData))
		if encoded != "" {
			decoded, decErr := hex.DecodeString(encoded)
			// The decode error is deliberately not included in the message
			// because it can echo bytes of the key material.
			if decErr != nil || len(decoded) != len(key) {
				return [32]byte{}, fmt.Errorf("encryption key file %s is malformed; refusing to overwrite it", keyPath)
			}
			copy(key[:], decoded)
			return key, nil
		}
		// An empty file holds no key, so nothing is lost by generating one.
	case !errors.Is(err, os.ErrNotExist):
		return [32]byte{}, fmt.Errorf("reading encryption key file %s: %w", keyPath, err)
	}

	// Generate new key.
	if _, err := rand.Read(key[:]); err != nil {
		return [32]byte{}, err
	}

	// Persist the key hex-encoded, readable and writable by the owner only.
	keyHex := hex.EncodeToString(key[:])
	if err := os.WriteFile(keyPath, []byte(keyHex), 0600); err != nil {
		return [32]byte{}, err
	}

	return key, nil
}

// ValidatePrivateKey validates that privateKeyHex is a well-formed private
// key.
//
// An optional "0x" prefix is stripped first. The remainder must be valid hex
// decoding to exactly 32 bytes that crypto.ToECDSA accepts, and the resulting
// scalar must be non-zero. The first failed check is returned as an error; nil
// means the key passed all of them.
func (km *KeyManager) ValidatePrivateKey(privateKeyHex string) error {
	// Strip the prefix before measuring, so the length check applies to the
	// hex digits only.
	privateKeyHex = strings.TrimPrefix(privateKeyHex, "0x")

	if len(privateKeyHex) < 64 {
		return fmt.Errorf("private key too short")
	}

	privateKeyBytes, err := hex.DecodeString(privateKeyHex)
	if err != nil {
		return fmt.Errorf("invalid hex encoding: %w", err)
	}

	if len(privateKeyBytes) != 32 {
		return fmt.Errorf("private key must be 32 bytes")
	}

	privateKey, err := crypto.ToECDSA(privateKeyBytes)
	if err != nil {
		return fmt.Errorf("invalid private key: %w", err)
	}

	// Defensive check kept from the original implementation.
	if privateKey.D.Sign() == 0 {
		return fmt.Errorf("private key cannot be zero")
	}

	return nil
}

// SecureEndpoint represents a secure RPC endpoint configuration.
type SecureEndpoint struct {
	URL       string
	APIKey    string
	TLSConfig *TLSConfig
}

// TLSConfig represents TLS configuration for secure connections.
type TLSConfig struct {
	InsecureSkipVerify bool
	CertFile           string
	KeyFile            string
	CAFile             string
}

// ConnectionManager manages secure connections to RPC endpoints, keyed by
// name. It performs no locking of its own.
type ConnectionManager struct {
	endpoints map[string]*SecureEndpoint
	logger    *logger.Logger
}

// NewConnectionManager creates a new secure connection manager with an empty
// endpoint table.
func NewConnectionManager(logger *logger.Logger) *ConnectionManager {
	return &ConnectionManager{
		endpoints: make(map[string]*SecureEndpoint),
		logger:    logger,
	}
}

// AddEndpoint registers endpoint under name, replacing any previous entry with
// the same name.
//
// A nil endpoint is rejected with an error log and nothing is stored. An
// endpoint whose URL is not https:// or wss://, or whose TLS configuration
// disables certificate verification, is still stored but a warning is logged.
func (cm *ConnectionManager) AddEndpoint(name string, endpoint *SecureEndpoint) {
	if endpoint == nil {
		cm.logger.Error(fmt.Sprintf("Refusing to add nil endpoint: %s", name))
		return
	}

	if !isSecureURL(endpoint.URL) {
		cm.logger.Warn(fmt.Sprintf("Endpoint %s is not using HTTPS/WSS", name))
	}
	if endpoint.TLSConfig != nil && endpoint.TLSConfig.InsecureSkipVerify {
		cm.logger.Warn(fmt.Sprintf("Endpoint %s disables TLS certificate verification", name))
	}

	cm.endpoints[name] = endpoint
	cm.logger.Info(fmt.Sprintf("Added secure endpoint: %s", name))
}

// isSecureURL reports whether url starts with the https:// or wss:// scheme.
// The comparison is case-insensitive and requires the full "scheme://" prefix,
// so strings that merely begin with the letters "https" or "wss" do not count.
func isSecureURL(url string) bool {
	lower := strings.ToLower(url)
	return strings.HasPrefix(lower, "https://") || strings.HasPrefix(lower, "wss://")
}

// ValidateAPIKey validates API key format and strength. It returns an error
// when apiKey is shorter than 32 bytes or contains a known placeholder pattern
// (see isWeakAPIKey), and nil otherwise.
func ValidateAPIKey(apiKey string) error {
	if len(apiKey) < 32 {
		return fmt.Errorf("API key too short, minimum 32 characters required")
	}

	if isWeakAPIKey(apiKey) {
		return fmt.Errorf("API key appears to be weak or default")
	}

	return nil
}

// isWeakAPIKey reports whether apiKey contains, case-insensitively, one of a
// small set of placeholder words.
//
// Substring matching is used rather than equality because callers only reach
// this check with keys of at least 32 characters, which can never equal one of
// the short patterns.
func isWeakAPIKey(apiKey string) bool {
	weakPatterns := []string{
		"test",
		"demo",
		"sample",
		"your_api_key",
		"replace_me",
		"changeme",
	}

	apiKeyLower := strings.ToLower(apiKey)
	for _, pattern := range weakPatterns {
		if strings.Contains(apiKeyLower, pattern) {
			return true
		}
	}
	return false
}

// SecureHasher provides secure hashing functionality based on SHA-256.
type SecureHasher struct{}

// Hash returns the SHA-256 digest of data.
func (sh *SecureHasher) Hash(data []byte) [32]byte {
	return sha256.Sum256(data)
}

// HashString returns the SHA-256 digest of data as a lowercase hex string.
func (sh *SecureHasher) HashString(data string) string {
	hash := sh.Hash([]byte(data))
	return hex.EncodeToString(hash[:])
}

// AccessControl manages access control for the MEV bot as a set of allowed
// addresses. It performs no locking of its own.
type AccessControl struct {
	allowedAddresses map[common.Address]bool
	logger           *logger.Logger
}

// NewAccessControl creates a new access control manager with an empty allow
// list.
func NewAccessControl(logger *logger.Logger) *AccessControl {
	return &AccessControl{
		allowedAddresses: make(map[common.Address]bool),
		logger:           logger,
	}
}

// AddAllowedAddress adds address to the allowed list and logs the addition.
func (ac *AccessControl) AddAllowedAddress(address common.Address) {
	ac.allowedAddresses[address] = true
	ac.logger.Info(fmt.Sprintf("Added allowed address: %s", address.Hex()))
}

// IsAllowed reports whether address is on the allowed list.
func (ac *AccessControl) IsAllowed(address common.Address) bool {
	return ac.allowedAddresses[address]
}

// RemoveAllowedAddress removes address from the allowed list and logs the
// removal. Removing an address that is not listed is a no-op apart from the
// log entry.
func (ac *AccessControl) RemoveAllowedAddress(address common.Address) {
	delete(ac.allowedAddresses, address)
	ac.logger.Info(fmt.Sprintf("Removed allowed address: %s", address.Hex()))
}