Files
mev-beta/pkg/security/keymanager.go
2025-09-14 06:21:10 -05:00

305 lines
8.1 KiB
Go

package security
import (
	"crypto/rand"
	"crypto/sha256"
	"encoding/hex"
	"fmt"
	"os"
	"path/filepath"
	"strings"

	"github.com/ethereum/go-ethereum/accounts"
	"github.com/ethereum/go-ethereum/accounts/keystore"
	"github.com/ethereum/go-ethereum/common"
	"github.com/ethereum/go-ethereum/crypto"
	"github.com/fraktal/mev-beta/internal/logger"
)
// KeyManager handles secure key management for the MEV bot.
//
// It wraps a go-ethereum encrypted keystore and logs account lifecycle events
// (creation, unlocking, signing failures) through logger.
type KeyManager struct {
	// keystore holds the encrypted account keys; created by NewKeyManager with
	// the standard scrypt parameters.
	keystore *keystore.KeyStore
	// logger receives informational and error messages for key operations.
	logger *logger.Logger
	// keyDir is the directory the keystore reads and writes key files in.
	keyDir string
}
// NewKeyManager creates a key manager backed by an encrypted go-ethereum
// keystore rooted at keyDir, using keystore.StandardScryptN and
// keystore.StandardScryptP for key derivation.
//
// keyDir is created with mode 0700 if it does not exist. If it already exists
// with group/other permission bits set, they are tightened to 0700; failing to
// tighten them is logged as a warning but is not fatal.
//
// If the directory cannot be created the error is logged through logger and
// nil is returned, so callers must nil-check the result.
func NewKeyManager(keyDir string, logger *logger.Logger) *KeyManager {
	if err := os.MkdirAll(keyDir, 0700); err != nil {
		logger.Error(fmt.Sprintf("Failed to create key directory: %v", err))
		return nil
	}

	// os.MkdirAll only applies 0700 to directories it creates; a pre-existing
	// key directory keeps whatever (possibly world-readable) mode it had.
	if info, err := os.Stat(keyDir); err != nil {
		logger.Warn(fmt.Sprintf("Failed to stat key directory %s: %v", keyDir, err))
	} else if info.Mode().Perm()&0077 != 0 {
		if err := os.Chmod(keyDir, 0700); err != nil {
			logger.Warn(fmt.Sprintf("Key directory %s has permissions %v and could not be restricted to 0700: %v", keyDir, info.Mode().Perm(), err))
		}
	}

	// Create keystore with scrypt parameters for security.
	ks := keystore.NewKeyStore(keyDir, keystore.StandardScryptN, keystore.StandardScryptP)
	return &KeyManager{
		keystore: ks,
		logger:   logger,
		keyDir:   keyDir,
	}
}
// CreateAccount creates a new account with a secure random key, stored in the
// keystore encrypted under password, and returns it.
//
// password must be at least 12 characters long. Length is measured in Unicode
// code points rather than bytes, so the check agrees with its error message for
// non-ASCII passwords. Keystore failures are logged and returned unchanged.
func (km *KeyManager) CreateAccount(password string) (accounts.Account, error) {
	// len(password) counts bytes: four 3-byte UTF-8 characters would satisfy a
	// byte-based "12 characters" rule. The compiler lowers len([]rune(s)) to a
	// rune count without allocating.
	if len([]rune(password)) < 12 {
		return accounts.Account{}, fmt.Errorf("password must be at least 12 characters")
	}

	account, err := km.keystore.NewAccount(password)
	if err != nil {
		km.logger.Error(fmt.Sprintf("Failed to create account: %v", err))
		return accounts.Account{}, err
	}

	km.logger.Info(fmt.Sprintf("Created new account: %s", account.Address.Hex()))
	return account, nil
}
// UnlockAccount decrypts the key for address with password and keeps it
// unlocked in the keystore so it can be used for signing. The keystore error,
// if any, is logged and returned as-is.
func (km *KeyManager) UnlockAccount(address common.Address, password string) error {
	addrHex := address.Hex()
	if unlockErr := km.keystore.Unlock(accounts.Account{Address: address}, password); unlockErr != nil {
		km.logger.Error(fmt.Sprintf("Failed to unlock account %s: %v", addrHex, unlockErr))
		return unlockErr
	}
	km.logger.Info(fmt.Sprintf("Unlocked account: %s", addrHex))
	return nil
}
// GetSignerFunction returns a closure that signs a hash with the key of
// address, or an error if the keystore does not contain that address.
//
// The closure delegates to keystore.SignHash, which in go-ethereum expects a
// 32-byte digest and only signs for unlocked accounts (see UnlockAccount).
// Signing errors are logged and returned.
func (km *KeyManager) GetSignerFunction(address common.Address) (func([]byte) ([]byte, error), error) {
	if !km.keystore.HasAddress(address) {
		return nil, fmt.Errorf("account %s not found in keystore", address.Hex())
	}

	signer := accounts.Account{Address: address}
	sign := func(digest []byte) ([]byte, error) {
		sig, signErr := km.keystore.SignHash(signer, digest)
		if signErr != nil {
			km.logger.Error(fmt.Sprintf("Failed to sign hash: %v", signErr))
			return nil, signErr
		}
		return sig, nil
	}
	return sign, nil
}
// SecureConfig handles secure configuration management.
//
// Instances are built by NewSecureConfig, which loads or generates the
// encryption key. NOTE(review): nothing in this chunk reads encryptionKey or
// configPath after construction; confirm how they are used elsewhere.
type SecureConfig struct {
	// logger receives diagnostic messages.
	logger *logger.Logger
	// configPath is the configuration file path; its directory also holds the
	// ".encryption.key" file (see NewSecureConfig).
	configPath string
	// encryptionKey is the 32-byte key loaded from or generated into that file.
	encryptionKey [32]byte
}
// NewSecureConfig creates a secure configuration manager for the file at
// configPath.
//
// The 32-byte encryption key is loaded from, or on first use generated into, a
// ".encryption.key" file in the same directory as configPath. Key-setup errors
// are wrapped with %w so callers can inspect them with errors.Is/errors.As.
func NewSecureConfig(configPath string, logger *logger.Logger) (*SecureConfig, error) {
	// NOTE(review): the key sits in the same directory as the config it
	// protects, so anyone able to read that directory can read both.
	keyPath := filepath.Join(filepath.Dir(configPath), ".encryption.key")
	key, err := loadOrGenerateKey(keyPath)
	if err != nil {
		return nil, fmt.Errorf("failed to setup encryption key: %w", err)
	}

	return &SecureConfig{
		logger:        logger,
		configPath:    configPath,
		encryptionKey: key,
	}, nil
}
// loadOrGenerateKey loads the hex-encoded 32-byte encryption key stored at
// keyPath. If no file exists there, it generates a new key from crypto/rand and
// persists it as 64 hex characters with mode 0600.
//
// Surrounding whitespace (for example a trailing newline added by an editor)
// is ignored when loading. An existing file that cannot be read or does not
// hold a valid key is reported as an error instead of being overwritten, since
// replacing it would make data encrypted with the old key unrecoverable.
// On error the returned key is the zero value.
func loadOrGenerateKey(keyPath string) ([32]byte, error) {
	var key [32]byte

	keyData, err := os.ReadFile(keyPath)
	if err == nil {
		decoded, decodeErr := hex.DecodeString(strings.TrimSpace(string(keyData)))
		if decodeErr != nil || len(decoded) != len(key) {
			return [32]byte{}, fmt.Errorf("encryption key file %s is malformed (want %d hex-encoded bytes); refusing to overwrite it", keyPath, len(key))
		}
		copy(key[:], decoded)
		return key, nil
	}
	if !os.IsNotExist(err) {
		return [32]byte{}, fmt.Errorf("reading encryption key %s: %w", keyPath, err)
	}

	// No key yet: generate one from the OS CSPRNG.
	if _, err := rand.Read(key[:]); err != nil {
		return [32]byte{}, fmt.Errorf("generating encryption key: %w", err)
	}

	// O_EXCL guarantees we never clobber a key file that appeared after the
	// read above (e.g. created concurrently by another process).
	f, err := os.OpenFile(keyPath, os.O_WRONLY|os.O_CREATE|os.O_EXCL, 0600)
	if err != nil {
		return [32]byte{}, fmt.Errorf("creating encryption key file %s: %w", keyPath, err)
	}
	_, writeErr := f.WriteString(hex.EncodeToString(key[:]))
	closeErr := f.Close()
	if writeErr == nil {
		writeErr = closeErr
	}
	if writeErr != nil {
		// Don't leave a truncated key behind; it would be rejected as
		// malformed on the next start.
		_ = os.Remove(keyPath)
		return [32]byte{}, fmt.Errorf("writing encryption key file %s: %w", keyPath, writeErr)
	}
	return key, nil
}
// ValidatePrivateKey checks that privateKeyHex is a usable secp256k1 private
// key: an optional "0x"/"0X" prefix followed by exactly 32 hex-encoded bytes,
// non-zero, and accepted by crypto.ToECDSA (which in go-ethereum also rejects
// scalars at or above the curve order). It returns nil when the key is valid.
func (km *KeyManager) ValidatePrivateKey(privateKeyHex string) error {
	// Strip the prefix first so the length check applies to the hex digits only.
	if len(privateKeyHex) >= 2 && privateKeyHex[0] == '0' && (privateKeyHex[1] == 'x' || privateKeyHex[1] == 'X') {
		privateKeyHex = privateKeyHex[2:]
	}
	if len(privateKeyHex) < 64 {
		return fmt.Errorf("private key too short")
	}

	privateKeyBytes, err := hex.DecodeString(privateKeyHex)
	if err != nil {
		return fmt.Errorf("invalid hex encoding: %w", err)
	}
	if len(privateKeyBytes) != 32 {
		return fmt.Errorf("private key must be 32 bytes")
	}

	// Check for zero before crypto.ToECDSA: ToECDSA rejects zero itself with a
	// generic message, which made a post-conversion zero check unreachable.
	isZero := true
	for _, b := range privateKeyBytes {
		if b != 0 {
			isZero = false
			break
		}
	}
	if isZero {
		return fmt.Errorf("private key cannot be zero")
	}

	if _, err := crypto.ToECDSA(privateKeyBytes); err != nil {
		return fmt.Errorf("invalid private key: %w", err)
	}
	return nil
}
// SecureEndpoint represents a secure RPC endpoint configuration.
type SecureEndpoint struct {
	// URL is the endpoint address; AddEndpoint warns unless isSecureURL
	// accepts it.
	URL string
	// APIKey is held here as a plain string. NOTE(review): ValidateAPIKey is
	// not called on it in this chunk; confirm callers validate it.
	APIKey string
	// TLSConfig holds optional TLS settings; nil presumably means defaults
	// (not consumed in this chunk; verify).
	TLSConfig *TLSConfig
}
// TLSConfig represents TLS configuration for secure connections.
//
// NOTE(review): none of these fields are read in this chunk; the comments
// below describe their presumed meaning and should be confirmed at the use
// site.
type TLSConfig struct {
	// InsecureSkipVerify presumably mirrors crypto/tls.Config.InsecureSkipVerify
	// (disables certificate verification); should stay false in production.
	InsecureSkipVerify bool
	// CertFile is presumably the path to a PEM client certificate.
	CertFile string
	// KeyFile is presumably the path to the matching PEM private key.
	KeyFile string
	// CAFile is presumably the path to a PEM CA bundle used to verify servers.
	CAFile string
}
// ConnectionManager manages secure connections to RPC endpoints.
//
// The endpoint map is unguarded (no mutex in this struct), so concurrent calls
// to AddEndpoint would race; synchronize externally if needed.
type ConnectionManager struct {
	// endpoints maps the name passed to AddEndpoint to its configuration.
	endpoints map[string]*SecureEndpoint
	// logger receives endpoint registration messages and warnings.
	logger *logger.Logger
}
// NewConnectionManager returns a ConnectionManager with no registered
// endpoints that reports through logger.
func NewConnectionManager(logger *logger.Logger) *ConnectionManager {
	cm := new(ConnectionManager)
	cm.endpoints = map[string]*SecureEndpoint{}
	cm.logger = logger
	return cm
}
// AddEndpoint registers endpoint under name, replacing any endpoint already
// stored under that name.
//
// A URL that isSecureURL rejects only produces a warning; the endpoint is
// still registered (and the info log still calls it "secure"). endpoint must be
// non-nil because its URL is read unconditionally.
func (cm *ConnectionManager) AddEndpoint(name string, endpoint *SecureEndpoint) {
	target := endpoint.URL
	if secure := isSecureURL(target); !secure {
		cm.logger.Warn(fmt.Sprintf("Endpoint %s is not using HTTPS/WSS", name))
	}

	cm.endpoints[name] = endpoint
	cm.logger.Info(fmt.Sprintf("Added secure endpoint: %s", name))
}
// isSecureURL reports whether rawURL uses the https or wss scheme.
//
// The scheme must be followed by "://" and is compared case-insensitively
// (URL schemes are case-insensitive per RFC 3986), so "HTTPS://host" is
// accepted while "wssx://host" and a bare "https" are not.
func isSecureURL(rawURL string) bool {
	i := strings.Index(rawURL, "://")
	if i < 0 {
		return false
	}
	scheme := rawURL[:i]
	return strings.EqualFold(scheme, "https") || strings.EqualFold(scheme, "wss")
}
// ValidateAPIKey returns an error if apiKey is shorter than 32 bytes or if
// isWeakAPIKey flags it as a weak/default value; otherwise it returns nil.
func ValidateAPIKey(apiKey string) error {
	switch {
	case len(apiKey) < 32:
		return fmt.Errorf("API key too short, minimum 32 characters required")
	case isWeakAPIKey(apiKey):
		return fmt.Errorf("API key appears to be weak or default")
	default:
		return nil
	}
}
// isWeakAPIKey reports whether apiKey contains a common placeholder or demo
// pattern, compared case-insensitively.
//
// Patterns are matched anywhere in the key rather than exactly. ValidateAPIKey
// first requires at least 32 characters, which is longer than every pattern, so
// an exact match could never fire there. Padded placeholders such as
// "YOUR_API_KEY_0000000000000000000000" are now caught. Hex-encoded keys can
// never match, because every pattern contains a non-hex letter.
func isWeakAPIKey(apiKey string) bool {
	weakPatterns := [...]string{
		"test",
		"demo",
		"sample",
		"your_api_key",
		"replace_me",
		"changeme",
	}
	apiKeyLower := strings.ToLower(apiKey)
	for _, pattern := range weakPatterns {
		if strings.Contains(apiKeyLower, pattern) {
			return true
		}
	}
	return false
}
// SecureHasher computes SHA-256 digests. It carries no state, so the zero
// value (or a nil *SecureHasher) is ready to use.
type SecureHasher struct{}

// Hash returns the SHA-256 digest of data.
func (sh *SecureHasher) Hash(data []byte) [32]byte {
	digest := sha256.Sum256(data)
	return digest
}

// HashString returns the SHA-256 digest of data as a lowercase hex string.
func (sh *SecureHasher) HashString(data string) string {
	digest := sha256.Sum256([]byte(data))
	return hex.EncodeToString(digest[:])
}
// AccessControl manages access control for the MEV bot as an address
// allow-list.
//
// The map is unguarded (no mutex in this struct), so concurrent mutation and
// lookup would race; synchronize externally if used from multiple goroutines.
type AccessControl struct {
	// allowedAddresses holds true for every address added via
	// AddAllowedAddress; absent addresses are not allowed.
	allowedAddresses map[common.Address]bool
	// logger receives add/remove audit messages.
	logger *logger.Logger
}
// NewAccessControl returns an AccessControl with an empty allow-list that
// logs changes through logger.
func NewAccessControl(logger *logger.Logger) *AccessControl {
	ac := &AccessControl{logger: logger}
	ac.allowedAddresses = map[common.Address]bool{}
	return ac
}
// AddAllowedAddress puts address on the allow-list (idempotently) and logs the
// change.
func (ac *AccessControl) AddAllowedAddress(address common.Address) {
	ac.allowedAddresses[address] = true
	msg := "Added allowed address: " + address.Hex()
	ac.logger.Info(msg)
}
// IsAllowed reports whether address is currently on the allow-list.
func (ac *AccessControl) IsAllowed(address common.Address) bool {
	allowed, found := ac.allowedAddresses[address]
	return found && allowed
}
// RemoveAllowedAddress takes address off the allow-list and logs the change.
// Removing an address that was never added is a no-op apart from the log line.
func (ac *AccessControl) RemoveAllowedAddress(address common.Address) {
	delete(ac.allowedAddresses, address)
	msg := "Removed allowed address: " + address.Hex()
	ac.logger.Info(msg)
}