306 lines
8.1 KiB
Go
306 lines
8.1 KiB
Go
package security
|
|
|
|
import (
	"crypto/rand"
	"crypto/sha256"
	"encoding/hex"
	"errors"
	"fmt"
	"os"
	"path/filepath"
	"strings"
	"unicode/utf8"

	"github.com/ethereum/go-ethereum/accounts"
	"github.com/ethereum/go-ethereum/accounts/keystore"
	"github.com/ethereum/go-ethereum/common"
	"github.com/ethereum/go-ethereum/crypto"
	"github.com/fraktal/mev-beta/internal/logger"
)
|
|
|
|
// KeyManager handles secure key management for the MEV bot
|
|
type KeyManager struct {
|
|
keystore *keystore.KeyStore
|
|
logger *logger.Logger
|
|
keyDir string
|
|
}
|
|
|
|
// NewKeyManager creates a new secure key manager
|
|
func NewKeyManager(keyDir string, logger *logger.Logger) *KeyManager {
|
|
// Ensure key directory exists and has proper permissions
|
|
if err := os.MkdirAll(keyDir, 0700); err != nil {
|
|
logger.Error(fmt.Sprintf("Failed to create key directory: %v", err))
|
|
return nil
|
|
}
|
|
|
|
// Create keystore with scrypt parameters for security
|
|
ks := keystore.NewKeyStore(keyDir, keystore.StandardScryptN, keystore.StandardScryptP)
|
|
|
|
return &KeyManager{
|
|
keystore: ks,
|
|
logger: logger,
|
|
keyDir: keyDir,
|
|
}
|
|
}
|
|
|
|
// CreateAccount creates a new account with a secure random key
|
|
func (km *KeyManager) CreateAccount(password string) (accounts.Account, error) {
|
|
if len(password) < 12 {
|
|
return accounts.Account{}, fmt.Errorf("password must be at least 12 characters")
|
|
}
|
|
|
|
// Generate account
|
|
account, err := km.keystore.NewAccount(password)
|
|
if err != nil {
|
|
km.logger.Error(fmt.Sprintf("Failed to create account: %v", err))
|
|
return accounts.Account{}, err
|
|
}
|
|
|
|
km.logger.Info(fmt.Sprintf("Created new account: %s", account.Address.Hex()))
|
|
return account, nil
|
|
}
|
|
|
|
// UnlockAccount unlocks an account for signing transactions
|
|
func (km *KeyManager) UnlockAccount(address common.Address, password string) error {
|
|
account := accounts.Account{Address: address}
|
|
|
|
err := km.keystore.Unlock(account, password)
|
|
if err != nil {
|
|
km.logger.Error(fmt.Sprintf("Failed to unlock account %s: %v", address.Hex(), err))
|
|
return err
|
|
}
|
|
|
|
km.logger.Info(fmt.Sprintf("Unlocked account: %s", address.Hex()))
|
|
return nil
|
|
}
|
|
|
|
// GetSignerFunction returns a signing function for the given address
|
|
func (km *KeyManager) GetSignerFunction(address common.Address) (func([]byte) ([]byte, error), error) {
|
|
account := accounts.Account{Address: address}
|
|
|
|
// Find the account in keystore
|
|
if !km.keystore.HasAddress(address) {
|
|
return nil, fmt.Errorf("account %s not found in keystore", address.Hex())
|
|
}
|
|
|
|
return func(hash []byte) ([]byte, error) {
|
|
signature, err := km.keystore.SignHash(account, hash)
|
|
if err != nil {
|
|
km.logger.Error(fmt.Sprintf("Failed to sign hash: %v", err))
|
|
return nil, err
|
|
}
|
|
return signature, nil
|
|
}, nil
|
|
}
|
|
|
|
// SecureConfig handles secure configuration management
|
|
type SecureConfig struct {
|
|
logger *logger.Logger
|
|
configPath string
|
|
encryptionKey [32]byte
|
|
}
|
|
|
|
// NewSecureConfig creates a new secure configuration manager
|
|
func NewSecureConfig(configPath string, logger *logger.Logger) (*SecureConfig, error) {
|
|
// Generate or load encryption key
|
|
keyPath := filepath.Join(filepath.Dir(configPath), ".encryption.key")
|
|
key, err := loadOrGenerateKey(keyPath)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to setup encryption key: %v", err)
|
|
}
|
|
|
|
return &SecureConfig{
|
|
logger: logger,
|
|
configPath: configPath,
|
|
encryptionKey: key,
|
|
}, nil
|
|
}
|
|
|
|
// loadOrGenerateKey returns the 32-byte encryption key stored hex-encoded in
// the file at keyPath. If that file does not exist, it generates a fresh key
// with crypto/rand and creates the file with mode 0600.
//
// An existing key file is never overwritten. An error is returned if the file
// cannot be read, or if its contents are not exactly 64 hex digits. Surrounding
// whitespace, such as a trailing newline added by an editor, is ignored.
// Silently replacing the file would make anything encrypted under the old key
// unrecoverable.
//
// On error the returned key is the zero value.
func loadOrGenerateKey(keyPath string) ([32]byte, error) {
	var key [32]byte

	keyData, err := os.ReadFile(keyPath)
	switch {
	case err == nil:
		decoded, decodeErr := hex.DecodeString(strings.TrimSpace(string(keyData)))
		if decodeErr != nil || len(decoded) != len(key) {
			return [32]byte{}, fmt.Errorf("encryption key file %q is malformed: want %d hex characters", keyPath, 2*len(key))
		}
		copy(key[:], decoded)
		return key, nil
	case !errors.Is(err, os.ErrNotExist):
		// Unreadable (e.g. permission denied) is not the same as absent; don't
		// fall through and try to replace it.
		return [32]byte{}, fmt.Errorf("reading encryption key file %q: %w", keyPath, err)
	}

	// No key file yet: generate a new key.
	if _, err := rand.Read(key[:]); err != nil {
		return [32]byte{}, fmt.Errorf("generating encryption key: %w", err)
	}

	// O_EXCL makes creation fail instead of clobbering a key file that was
	// created between the read above and now; 0600 restricts it to the owner.
	f, err := os.OpenFile(keyPath, os.O_WRONLY|os.O_CREATE|os.O_EXCL, 0600)
	if err != nil {
		return [32]byte{}, fmt.Errorf("creating encryption key file %q: %w", keyPath, err)
	}
	if _, err := f.WriteString(hex.EncodeToString(key[:])); err != nil {
		// Best effort cleanup: a truncated file would be rejected as malformed
		// on the next start, so remove it. The write error is what matters.
		_ = f.Close()
		_ = os.Remove(keyPath)
		return [32]byte{}, fmt.Errorf("writing encryption key file %q: %w", keyPath, err)
	}
	if err := f.Close(); err != nil {
		_ = os.Remove(keyPath) // best effort, see above
		return [32]byte{}, fmt.Errorf("writing encryption key file %q: %w", keyPath, err)
	}

	return key, nil
}
|
|
|
|
// ValidatePrivateKey validates that a private key is secure
|
|
func (km *KeyManager) ValidatePrivateKey(privateKeyHex string) error {
|
|
if len(privateKeyHex) < 64 {
|
|
return fmt.Errorf("private key too short")
|
|
}
|
|
|
|
// Remove 0x prefix if present
|
|
if len(privateKeyHex) >= 2 && privateKeyHex[:2] == "0x" {
|
|
privateKeyHex = privateKeyHex[2:]
|
|
}
|
|
|
|
// Validate hex encoding
|
|
privateKeyBytes, err := hex.DecodeString(privateKeyHex)
|
|
if err != nil {
|
|
return fmt.Errorf("invalid hex encoding: %v", err)
|
|
}
|
|
|
|
if len(privateKeyBytes) != 32 {
|
|
return fmt.Errorf("private key must be 32 bytes")
|
|
}
|
|
|
|
// Validate that it's not a weak key
|
|
privateKey, err := crypto.ToECDSA(privateKeyBytes)
|
|
if err != nil {
|
|
return fmt.Errorf("invalid private key: %v", err)
|
|
}
|
|
|
|
// Check if key is not zero
|
|
if privateKey.D.Sign() == 0 {
|
|
return fmt.Errorf("private key cannot be zero")
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// SecureEndpoint represents a secure RPC endpoint configuration
|
|
type SecureEndpoint struct {
|
|
URL string
|
|
APIKey string
|
|
TLSConfig *TLSConfig
|
|
}
|
|
|
|
// TLSConfig holds TLS settings for a connection to an RPC endpoint.
// NOTE(review): the field meanings below are inferred from their names; the
// code that consumes them is not in view, so confirm there.
type TLSConfig struct {
	// InsecureSkipVerify presumably disables server certificate verification,
	// which would leave the connection open to man-in-the-middle attacks.
	InsecureSkipVerify bool

	CertFile string // presumably the client certificate file path
	KeyFile  string // presumably the client certificate's private key file path
	CAFile   string // presumably the CA bundle used to verify the server
}
|
|
|
|
// ConnectionManager manages secure connections to RPC endpoints
|
|
type ConnectionManager struct {
|
|
endpoints map[string]*SecureEndpoint
|
|
logger *logger.Logger
|
|
}
|
|
|
|
// NewConnectionManager creates a new secure connection manager
|
|
func NewConnectionManager(logger *logger.Logger) *ConnectionManager {
|
|
return &ConnectionManager{
|
|
endpoints: make(map[string]*SecureEndpoint),
|
|
logger: logger,
|
|
}
|
|
}
|
|
|
|
// AddEndpoint adds a secure endpoint configuration
|
|
func (cm *ConnectionManager) AddEndpoint(name string, endpoint *SecureEndpoint) {
|
|
// Validate endpoint URL
|
|
if !isSecureURL(endpoint.URL) {
|
|
cm.logger.Warn(fmt.Sprintf("Endpoint %s is not using HTTPS/WSS", name))
|
|
}
|
|
|
|
cm.endpoints[name] = endpoint
|
|
cm.logger.Info(fmt.Sprintf("Added secure endpoint: %s", name))
|
|
}
|
|
|
|
// isSecureURL reports whether rawURL uses an encrypted transport, i.e. its
// scheme is https or wss.
//
// Schemes are case-insensitive (RFC 3986), so "HTTPS://..." is accepted. The
// full "scheme://" prefix is required, so strings that merely start with the
// letters "https" or "wss" (e.g. "httpsfoo") are rejected. The parameter is
// named rawURL to avoid shadowing the net/url package.
func isSecureURL(rawURL string) bool {
	lower := strings.ToLower(rawURL)
	return strings.HasPrefix(lower, "https://") || strings.HasPrefix(lower, "wss://")
}
|
|
|
|
// ValidateAPIKey validates API key format and strength
|
|
func ValidateAPIKey(apiKey string) error {
|
|
if len(apiKey) < 32 {
|
|
return fmt.Errorf("API key too short, minimum 32 characters required")
|
|
}
|
|
|
|
// Check for obvious patterns
|
|
if isWeakAPIKey(apiKey) {
|
|
return fmt.Errorf("API key appears to be weak or default")
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// isWeakAPIKey reports whether apiKey looks like a placeholder or default
// value rather than a real credential.
//
// The check is case-insensitive and flags any key that contains one of the
// known placeholder words. Exact matching (the previous behaviour) could never
// fire for keys that had already passed ValidateAPIKey's 32-character minimum,
// because every pattern is shorter than that. A padded placeholder such as
// "YOUR_API_KEY_HERE_000..." is the realistic case to catch.
//
// A genuinely random key can contain one of the short words ("test", "demo")
// by chance. The probability is small but non-zero.
func isWeakAPIKey(apiKey string) bool {
	weakPatterns := []string{
		"test",
		"demo",
		"sample",
		"your_api_key",
		"replace_me",
		"changeme",
	}

	apiKeyLower := strings.ToLower(apiKey)
	for _, pattern := range weakPatterns {
		if strings.Contains(apiKeyLower, pattern) {
			return true
		}
	}

	return false
}
|
|
|
|
// SecureHasher computes SHA-256 digests. It has no state, so the zero value is
// ready to use.
type SecureHasher struct{}

// Hash returns the SHA-256 digest of data.
func (sh *SecureHasher) Hash(data []byte) [32]byte {
	digest := sha256.Sum256(data)
	return digest
}

// HashString returns the SHA-256 digest of the bytes of data, encoded as a
// 64-character lowercase hex string.
func (sh *SecureHasher) HashString(data string) string {
	digest := sha256.Sum256([]byte(data))
	return hex.EncodeToString(digest[:])
}
|
|
|
|
// AccessControl manages access control for the MEV bot
|
|
type AccessControl struct {
|
|
allowedAddresses map[common.Address]bool
|
|
logger *logger.Logger
|
|
}
|
|
|
|
// NewAccessControl creates a new access control manager
|
|
func NewAccessControl(logger *logger.Logger) *AccessControl {
|
|
return &AccessControl{
|
|
allowedAddresses: make(map[common.Address]bool),
|
|
logger: logger,
|
|
}
|
|
}
|
|
|
|
// AddAllowedAddress adds an address to the allowed list
|
|
func (ac *AccessControl) AddAllowedAddress(address common.Address) {
|
|
ac.allowedAddresses[address] = true
|
|
ac.logger.Info(fmt.Sprintf("Added allowed address: %s", address.Hex()))
|
|
}
|
|
|
|
// IsAllowed checks if an address is allowed
|
|
func (ac *AccessControl) IsAllowed(address common.Address) bool {
|
|
return ac.allowedAddresses[address]
|
|
}
|
|
|
|
// RemoveAllowedAddress removes an address from the allowed list
|
|
func (ac *AccessControl) RemoveAllowedAddress(address common.Address) {
|
|
delete(ac.allowedAddresses, address)
|
|
ac.logger.Info(fmt.Sprintf("Removed allowed address: %s", address.Hex()))
|
|
}
|