Files
mev-beta/pkg/security/keymanager.go
Krypto Kajun 850223a953 fix(multicall): resolve critical multicall parsing corruption issues
- Added comprehensive bounds checking to prevent buffer overruns in multicall parsing
- Implemented graduated validation system (Strict/Moderate/Permissive) to reduce false positives
- Added LRU caching system for address validation with 10-minute TTL
- Enhanced ABI decoder with missing Universal Router and Arbitrum-specific DEX signatures
- Fixed duplicate function declarations and import conflicts across multiple files
- Added error recovery mechanisms with multiple fallback strategies
- Updated tests to handle new validation behavior for suspicious addresses
- Fixed parser test expectations for improved validation system
- Applied gofmt formatting fixes to ensure code style compliance
- Fixed mutex copying issues in monitoring package by introducing MetricsSnapshot
- Resolved critical security vulnerabilities in heuristic address extraction
- Progress: Updated TODO audit from 10% to 35% complete

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
2025-10-17 00:12:55 -05:00

1422 lines
44 KiB
Go

package security
import (
"crypto/aes"
"crypto/cipher"
"crypto/ecdsa"
"crypto/rand"
"crypto/sha256"
"crypto/subtle"
"encoding/hex"
"encoding/json"
"fmt"
"io"
"math/big"
"os"
"path/filepath"
"runtime"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/ethereum/go-ethereum/accounts/keystore"
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/core/types"
"github.com/ethereum/go-ethereum/crypto"
"golang.org/x/crypto/scrypt"
"github.com/fraktal/mev-beta/internal/logger"
)
// AuthenticationContext contains authentication information for key access.
//
// KeyManager.AuthenticateUser creates one per login and stores it on the
// resulting AuthenticationSession; ValidateSession returns that stored copy.
// NOTE(review): a context handed in by a caller is caller-controlled data; only
// the copy held in the server-side session should be trusted for permissions.
type AuthenticationContext struct {
SessionID string `json:"session_id"`
UserID string `json:"user_id"`
IPAddress string `json:"ip_address"`
UserAgent string `json:"user_agent"`
AuthMethod string `json:"auth_method"` // "password", "mfa", "hardware_token"
AuthTime time.Time `json:"auth_time"`
// ExpiresAt is AuthTime plus the manager's session timeout when created by
// AuthenticateUser; ValidateSession rejects the session after this instant.
ExpiresAt time.Time `json:"expires_at"`
// Permissions granted to the session; AuthenticateUser grants
// "key_access" and "transaction_signing".
Permissions []string `json:"permissions"`
// RiskScore is filled by calculateAuthRiskScore(ipAddress, userAgent) at login.
RiskScore int `json:"risk_score"`
}
// AuthenticationSession tracks active authentication sessions.
//
// Sessions live in KeyManager.activeSessions keyed by ID. ValidateSession
// flips IsActive to false once Context.ExpiresAt has passed and refreshes
// LastActivity on each successful validation.
type AuthenticationSession struct {
ID string `json:"id"`
Context *AuthenticationContext `json:"context"`
CreatedAt time.Time `json:"created_at"`
LastActivity time.Time `json:"last_activity"`
IsActive bool `json:"is_active"`
// LoginAttempts is not set by the code in this chunk — verify its use elsewhere.
LoginAttempts int `json:"login_attempts"`
}
// KeyAccessEvent represents an access event to a private key.
//
// KeyManager.accessLog is declared as a slice of these, but nothing in this
// chunk appends to it (auditLog builds AuditEntry values instead) — TODO
// confirm where, if anywhere, these events are recorded.
type KeyAccessEvent struct {
Timestamp time.Time `json:"timestamp"`
KeyAddress common.Address `json:"key_address"`
Operation string `json:"operation"` // "access", "sign", "rotate", "fail"
Success bool `json:"success"`
Source string `json:"source"`
IPAddress string `json:"ip_address,omitempty"`
UserAgent string `json:"user_agent,omitempty"`
ErrorMsg string `json:"error_msg,omitempty"`
AuthContext *AuthenticationContext `json:"auth_context,omitempty"`
RiskLevel string `json:"risk_level"`
}
// SecureKey represents an encrypted private key with metadata.
//
// EncryptedKey holds nonce||AES-GCM ciphertext as produced by
// KeyManager.encryptPrivateKey; the plaintext key is never stored here.
// LastUsedUnix and UsageCount are accessed through sync/atomic (see the
// accessor methods below); all other fields are read and written under
// KeyManager.keysMutex by the manager code visible in this chunk.
//
// NOTE(review): sync/atomic requires 64-bit words used with its 64-bit
// functions to be 8-byte aligned, which on 32-bit platforms is only
// guaranteed for the first word of an allocated struct. LastUsedUnix and
// UsageCount are not first here — verify if 32-bit targets matter.
type SecureKey struct {
Address common.Address `json:"address"`
EncryptedKey []byte `json:"encrypted_key"`
CreatedAt time.Time `json:"created_at"`
LastUsedUnix int64 `json:"last_used_unix"` // Atomic access to Unix timestamp
UsageCount int64 `json:"usage_count"` // Atomic access to usage counter
// MaxUsage caps UsageCount for signing; 0 (or negative) means unlimited.
MaxUsage int `json:"max_usage"`
// ExpiresAt, when non-nil, is the instant after which signing is refused.
ExpiresAt *time.Time `json:"expires_at,omitempty"`
KeyVersion int `json:"key_version"`
Salt []byte `json:"salt"`
Nonce []byte `json:"nonce"`
KeyType string `json:"key_type"`
Permissions KeyPermissions `json:"permissions"`
IsActive bool `json:"is_active"`
// BackupLocations lists backup file paths written by createKeyBackup.
BackupLocations []string `json:"backup_locations,omitempty"`
// Mutex for non-atomic fields
// NOTE(review): not locked by any code in this chunk.
mu sync.RWMutex `json:"-"`
}
// GetLastUsed reports when the key was last used, loaded atomically from
// LastUsedUnix. A stored value of 0 means "never" and yields the zero
// time.Time; otherwise the result has one-second resolution.
func (sk *SecureKey) GetLastUsed() time.Time {
	if ts := atomic.LoadInt64(&sk.LastUsedUnix); ts != 0 {
		return time.Unix(ts, 0)
	}
	return time.Time{}
}
// GetUsageCount returns how many uses have been recorded for the key. The
// counter is loaded atomically, so it is safe to call concurrently with
// IncrementUsageCount.
func (sk *SecureKey) GetUsageCount() int64 {
	count := atomic.LoadInt64(&sk.UsageCount)
	return count
}
// SetLastUsed atomically records t, truncated to whole Unix seconds, as the
// key's last-use time.
func (sk *SecureKey) SetLastUsed(t time.Time) {
	seconds := t.Unix()
	atomic.StoreInt64(&sk.LastUsedUnix, seconds)
}
// IncrementUsageCount atomically adds one recorded use to the key and
// returns the updated total.
func (sk *SecureKey) IncrementUsageCount() int64 {
	updated := atomic.AddInt64(&sk.UsageCount, 1)
	return updated
}
// SigningRateTracker tracks signing rates per key.
//
// checkRateLimit uses it as a one-minute fixed-window counter: StartTime is
// the start of the current window and Count the signings seen in it.
// NOTE(review): LastReset, MaxPerMinute, MaxPerHour and HourlyCount are not
// read or written by the code in this chunk — verify before relying on them.
type SigningRateTracker struct {
LastReset time.Time
StartTime time.Time
Count int
MaxPerMinute int
MaxPerHour int
HourlyCount int
}
// KeyManagerConfig provides configuration for the key manager.
//
// Several fields overlap (e.g. KeyDir/KeystorePath, BackupPath/BackupLocation,
// MaxSigningRate/MaxSigningsPerMinute, KeyRotationDays/RotationInterval,
// MaxFailedAttempts/MaxAuthAttempts, LockoutDuration/AuthLockoutDuration).
// The comments below note which one the code in this chunk actually reads;
// fields marked "not read here" may be consumed elsewhere — verify before
// relying on them.
type KeyManagerConfig struct {
// KeyDir: not read here (KeystorePath is used instead).
KeyDir string `json:"key_dir"`
// KeystorePath is created with mode 0700 and opened as a go-ethereum keystore.
KeystorePath string `json:"keystore_path"`
// EncryptionKey is the master secret passed to deriveEncryptionKey.
EncryptionKey string `json:"encryption_key"`
// BackupPath, when non-empty, is created at startup and receives key backups.
BackupPath string `json:"backup_path"`
// RotationInterval, MaxKeyAge: not read here.
RotationInterval time.Duration `json:"rotation_interval"`
MaxKeyAge time.Duration `json:"max_key_age"`
// MaxFailedAttempts and LockoutDuration drive AuthenticateUser's account lockout.
MaxFailedAttempts int `json:"max_failed_attempts"`
LockoutDuration time.Duration `json:"lockout_duration"`
// EnableAuditLogging, MaxSigningsPerMinute, MaxSigningsPerHour, RequireHSM,
// BackupEnabled, BackupLocation: not read here.
EnableAuditLogging bool `json:"enable_audit_logging"`
MaxSigningsPerMinute int `json:"max_signings_per_minute"`
MaxSigningsPerHour int `json:"max_signings_per_hour"`
RequireHSM bool `json:"require_hsm"`
BackupEnabled bool `json:"backup_enabled"`
BackupLocation string `json:"backup_location"`
// MaxSigningRate is the per-key signings-per-minute limit; <= 0 disables it.
MaxSigningRate int `json:"max_signing_rate"`
// AuditLogPath: audit entries are only emitted (to the app logger) when non-empty.
AuditLogPath string `json:"audit_log_path"`
// KeyRotationDays: maintenance warns about keys older than this; <= 0 disables.
KeyRotationDays int `json:"key_rotation_days"`
// RequireHardware: not read here.
RequireHardware bool `json:"require_hardware"`
// SessionTimeout is the lifetime of sessions created by AuthenticateUser.
SessionTimeout time.Duration `json:"session_timeout"`
// Authentication and Authorization Configuration
// RequireAuthentication gates the session checks in the *WithAuth methods.
RequireAuthentication bool `json:"require_authentication"`
// EnableIPWhitelist restricts AuthenticateUser to WhitelistedIPs (exact string match).
EnableIPWhitelist bool `json:"enable_ip_whitelist"`
WhitelistedIPs []string `json:"whitelisted_ips"`
// MaxConcurrentSessions caps active sessions per user at login.
MaxConcurrentSessions int `json:"max_concurrent_sessions"`
// RequireMFA, PasswordHashRounds, MaxSessionAge, EnableRateLimiting,
// MaxAuthAttempts, AuthLockoutDuration: not read here.
RequireMFA bool `json:"require_mfa"`
PasswordHashRounds int `json:"password_hash_rounds"`
MaxSessionAge time.Duration `json:"max_session_age"`
EnableRateLimiting bool `json:"enable_rate_limiting"`
MaxAuthAttempts int `json:"max_auth_attempts"`
AuthLockoutDuration time.Duration `json:"auth_lockout_duration"`
}
// KeyManager provides secure private key management and transaction signing.
//
// Construct it with NewKeyManager. The zero value is not usable because its
// maps are nil, and methods such as GenerateKey write to them.
//
// Locking as practised by the methods in this chunk:
//   - keys: keysMutex
//   - activeSessions: sessionsMutex
//   - whitelistedIPs: ipWhitelistMutex
//   - failedAccessAttempts, accessLockouts: touched while AuthenticateUser holds authMutex
//   - signingRates: rateLimitMutex (allocated lazily by checkRateLimit)
//
// NOTE(review): mu, activeKeyRotation, lastKeyRotation, keyRotationInterval,
// maxKeyAge, accessLog, maxLogEntries and the scrypt* settings are neither
// set by newKeyManagerInternal nor read by the code visible here — confirm
// whether they are used elsewhere.
type KeyManager struct {
logger *logger.Logger
keystore *keystore.KeyStore
// encryptionKey is the AES key derived from config.EncryptionKey.
encryptionKey []byte
// Enhanced security features
mu sync.RWMutex
activeKeyRotation bool
lastKeyRotation time.Time
keyRotationInterval time.Duration
maxKeyAge time.Duration
failedAccessAttempts map[string]int
accessLockouts map[string]time.Time
maxFailedAttempts int
lockoutDuration time.Duration
// Authentication and Authorization
activeSessions map[string]*AuthenticationSession
sessionsMutex sync.RWMutex
whitelistedIPs map[string]bool
ipWhitelistMutex sync.RWMutex
authMutex sync.Mutex
sessionTimeout time.Duration
maxConcurrentSessions int
// Audit logging
accessLog []KeyAccessEvent
maxLogEntries int
// Key derivation settings
scryptN int
scryptR int
scryptP int
scryptKeyLen int
keys map[common.Address]*SecureKey
keysMutex sync.RWMutex
config *KeyManagerConfig
// signingRates is keyed by the key address in Hex() form.
signingRates map[string]*SigningRateTracker
rateLimitMutex sync.Mutex
}
// KeyPermissions defines what operations a key can perform.
//
// SignTransaction enforces CanSign, CanTransfer (for any transaction with a
// positive value), MaxTransferWei and AllowedContracts.
type KeyPermissions struct {
CanSign bool `json:"can_sign"`
CanTransfer bool `json:"can_transfer"`
// MaxTransferWei caps the transaction value; nil means no cap.
MaxTransferWei *big.Int `json:"max_transfer_wei,omitempty"`
// AllowedContracts restricts transaction destinations when non-empty;
// entries are matched against common.Address.Hex() output, so EIP-55
// checksummed hex is the reliable form. An empty list means no restriction.
AllowedContracts []string `json:"allowed_contracts,omitempty"`
// RequireConfirm is not read by the code in this chunk.
RequireConfirm bool `json:"require_confirmation"`
}
// SigningRequest represents a request to sign a transaction.
//
// From selects the managed key; ChainID is used to build the transaction
// signer and must be non-nil.
type SigningRequest struct {
Transaction *types.Transaction
ChainID *big.Int
From common.Address
Purpose string // Description of what this transaction does (included in audit messages)
UrgencyLevel int // 1-5, with 5 being emergency (not read by the code in this chunk)
}
// SigningResult contains the result of a signing operation.
type SigningResult struct {
SignedTx *types.Transaction
// Signature is 65 bytes: r (32) || s (32) || recovery ID (1).
Signature []byte
SignedAt time.Time
KeyUsed common.Address
// AuditID comes from generateAuditID and is echoed in the audit log message.
AuditID string
// Warnings are advisory (e.g. long-idle or heavily used key); signing still succeeded.
Warnings []string
}
// AuditEntry represents a security audit log entry.
//
// KeyManager.auditLog fills Timestamp, Operation, KeyAddress, Success,
// Details and RiskScore (via calculateRiskScore); it leaves IPAddress and
// UserAgent empty.
type AuditEntry struct {
Timestamp time.Time `json:"timestamp"`
Operation string `json:"operation"`
KeyAddress common.Address `json:"key_address"`
Success bool `json:"success"`
Details string `json:"details"`
IPAddress string `json:"ip_address,omitempty"`
UserAgent string `json:"user_agent,omitempty"`
RiskScore int `json:"risk_score"` // 1-10
}
// NewKeyManager builds a production KeyManager from config, falling back to
// the default configuration when config is nil. Unlike the test-only
// constructor, it additionally runs validateProductionConfig.
func NewKeyManager(config *KeyManagerConfig, logger *logger.Logger) (*KeyManager, error) {
	const enforceProductionChecks = true
	return newKeyManagerInternal(config, logger, enforceProductionChecks)
}
// newKeyManagerForTesting builds a KeyManager exactly like NewKeyManager but
// skips validateProductionConfig (the production encryption-key check).
// For tests only.
func newKeyManagerForTesting(config *KeyManagerConfig, logger *logger.Logger) (*KeyManager, error) {
	const enforceProductionChecks = false
	return newKeyManagerInternal(config, logger, enforceProductionChecks)
}
// newKeyManagerInternal is the shared constructor behind NewKeyManager and
// newKeyManagerForTesting.
//
// In order, it:
//   - falls back to getDefaultConfig() when config is nil;
//   - optionally runs validateProductionConfig, then always runs validateConfig;
//   - creates the keystore directory, plus the backup directory when BackupPath
//     is set, both with mode 0700;
//   - opens a go-ethereum keystore at KeystorePath;
//   - derives the AES key from config.EncryptionKey;
//   - builds the KeyManager and seeds the IP whitelist;
//   - tries to load persisted keys, logging any failure without aborting;
//   - starts the hourly maintenance goroutine.
//
// logger is used directly (logger.Warn/Info); presumably it must be non-nil.
func newKeyManagerInternal(config *KeyManagerConfig, logger *logger.Logger, validateProduction bool) (*KeyManager, error) {
	cfg := config
	if cfg == nil {
		cfg = getDefaultConfig()
	}

	// Production-only guard on the encryption key; tests opt out.
	if validateProduction {
		if err := validateProductionConfig(cfg); err != nil {
			return nil, fmt.Errorf("production configuration validation failed: %w", err)
		}
	}
	if err := validateConfig(cfg); err != nil {
		return nil, fmt.Errorf("invalid configuration: %w", err)
	}

	if err := os.MkdirAll(cfg.KeystorePath, 0o700); err != nil {
		return nil, fmt.Errorf("failed to create keystore directory: %w", err)
	}
	if dir := cfg.BackupPath; dir != "" {
		if err := os.MkdirAll(dir, 0o700); err != nil {
			return nil, fmt.Errorf("failed to create backup directory: %w", err)
		}
	}

	store := keystore.NewKeyStore(cfg.KeystorePath, keystore.StandardScryptN, keystore.StandardScryptP)

	aesKey, err := deriveEncryptionKey(cfg.EncryptionKey)
	if err != nil {
		return nil, fmt.Errorf("failed to derive encryption key: %w", err)
	}

	km := &KeyManager{
		logger:                logger,
		keystore:              store,
		encryptionKey:         aesKey,
		config:                cfg,
		keys:                  make(map[common.Address]*SecureKey),
		activeSessions:        make(map[string]*AuthenticationSession),
		whitelistedIPs:        make(map[string]bool),
		failedAccessAttempts:  make(map[string]int),
		accessLockouts:        make(map[string]time.Time),
		maxFailedAttempts:     cfg.MaxFailedAttempts,
		lockoutDuration:       cfg.LockoutDuration,
		sessionTimeout:        cfg.SessionTimeout,
		maxConcurrentSessions: cfg.MaxConcurrentSessions,
	}

	if cfg.EnableIPWhitelist {
		for _, ip := range cfg.WhitelistedIPs {
			km.whitelistedIPs[ip] = true
		}
	}

	if loadErr := km.loadExistingKeys(); loadErr != nil {
		logger.Warn(fmt.Sprintf("Failed to load existing keys: %v", loadErr))
	}

	// Hourly maintenance loop; backgroundTasks has no stop mechanism.
	go km.backgroundTasks()

	logger.Info("Secure key manager initialized")
	return km, nil
}
// GenerateKey creates a new secp256k1 private key with the given permissions
// and returns its address.
//
// The key is stored only in encrypted form. When BackupPath is configured,
// an encrypted backup is written first; a backup failure is logged but does
// not fail generation. keyType is free-form metadata, except that
// "emergency" keys also get a 30-day expiry. The plaintext key is wiped
// from memory before returning.
func (km *KeyManager) GenerateKey(keyType string, permissions KeyPermissions) (common.Address, error) {
	privateKey, err := crypto.GenerateKey()
	if err != nil {
		return common.Address{}, fmt.Errorf("failed to generate private key: %w", err)
	}
	// Only the encrypted form is retained; wipe the plaintext when done.
	defer clearPrivateKey(privateKey)

	address := crypto.PubkeyToAddress(privateKey.PublicKey)

	encryptedKey, err := km.encryptPrivateKey(privateKey)
	if err != nil {
		return common.Address{}, fmt.Errorf("failed to encrypt private key: %w", err)
	}

	now := time.Now()
	secureKey := &SecureKey{
		Address:      address,
		EncryptedKey: encryptedKey,
		CreatedAt:    now,
		LastUsedUnix: now.Unix(),
		UsageCount:   0,
		KeyType:      keyType,
		Permissions:  permissions,
		IsActive:     true, // Mark as active by default
	}

	if keyType == "emergency" {
		expiresAt := now.Add(30 * 24 * time.Hour) // 30 days
		secureKey.ExpiresAt = &expiresAt
	}

	// Back up BEFORE publishing the key. createKeyBackup appends to
	// secureKey.BackupLocations without locking, which would race with
	// readers such as GetKeyInfo once the key is visible in km.keys.
	if err := km.createKeyBackup(secureKey); err != nil {
		km.logger.Warn(fmt.Sprintf("Failed to create backup for key %s: %v", address.Hex(), err))
	}

	km.keysMutex.Lock()
	km.keys[address] = secureKey
	km.keysMutex.Unlock()

	km.auditLog("KEY_GENERATED", address, true, fmt.Sprintf("Generated %s key", keyType))
	km.logger.Info(fmt.Sprintf("Generated new %s key: %s", keyType, address.Hex()))
	return address, nil
}
// ImportKey imports an existing hex-encoded secp256k1 private key and
// returns its address.
//
// privateKeyHex may carry surrounding whitespace and an optional "0x"/"0X"
// prefix; both are stripped before crypto.HexToECDSA, which rejects them.
// The key is stored only in encrypted form. When BackupPath is configured,
// an encrypted backup is written; a backup failure is logged but does not
// fail the import. The plaintext key is wiped before returning.
//
// Returns an error if the key cannot be parsed or encrypted, or if the same
// key is already managed. The duplicate check is repeated under the write
// lock, so concurrent imports cannot both succeed.
func (km *KeyManager) ImportKey(privateKeyHex string, keyType string, permissions KeyPermissions) (common.Address, error) {
	keyHex := strings.TrimSpace(privateKeyHex)
	if strings.HasPrefix(keyHex, "0x") || strings.HasPrefix(keyHex, "0X") {
		keyHex = keyHex[2:]
	}

	privateKey, err := crypto.HexToECDSA(keyHex)
	if err != nil {
		return common.Address{}, fmt.Errorf("invalid private key: %w", err)
	}
	// Only the encrypted form is retained; wipe the plaintext when done.
	defer clearPrivateKey(privateKey)

	address := crypto.PubkeyToAddress(privateKey.PublicKey)

	// Early rejection before encrypting; re-checked under the write lock below.
	km.keysMutex.RLock()
	_, exists := km.keys[address]
	km.keysMutex.RUnlock()
	if exists {
		return common.Address{}, fmt.Errorf("key already exists: %s", address.Hex())
	}

	encryptedKey, err := km.encryptPrivateKey(privateKey)
	if err != nil {
		return common.Address{}, fmt.Errorf("failed to encrypt private key: %w", err)
	}

	now := time.Now()
	secureKey := &SecureKey{
		Address:      address,
		EncryptedKey: encryptedKey,
		CreatedAt:    now,
		LastUsedUnix: now.Unix(),
		UsageCount:   0,
		KeyType:      keyType,
		Permissions:  permissions,
		IsActive:     true, // Mark as active by default
	}

	// Back up BEFORE publishing: createKeyBackup appends to
	// secureKey.BackupLocations without locking.
	if err := km.createKeyBackup(secureKey); err != nil {
		km.logger.Warn(fmt.Sprintf("Failed to create backup for key %s: %v", address.Hex(), err))
	}

	// A concurrent import of the same key may have won since the check above.
	km.keysMutex.Lock()
	if _, dup := km.keys[address]; dup {
		km.keysMutex.Unlock()
		return common.Address{}, fmt.Errorf("key already exists: %s", address.Hex())
	}
	km.keys[address] = secureKey
	km.keysMutex.Unlock()

	km.auditLog("KEY_IMPORTED", address, true, fmt.Sprintf("Imported %s key", keyType))
	km.logger.Info(fmt.Sprintf("Imported %s key: %s", keyType, address.Hex()))
	return address, nil
}
// SignTransactionWithAuth is SignTransaction gated by session authentication
// when config.RequireAuthentication is set.
//
// With authentication required:
//   - authContext must be non-nil;
//   - authContext.SessionID must name a live session (see ValidateSession);
//   - that session must hold the "transaction_signing" permission.
//
// Permissions are taken from the server-side context returned by
// ValidateSession, NOT from authContext.Permissions: the latter is
// caller-supplied and could otherwise be forged alongside a valid session
// ID. A SIGN_ATTEMPT audit entry is written before delegating to
// SignTransaction. A nil request is rejected with an error.
func (km *KeyManager) SignTransactionWithAuth(request *SigningRequest, authContext *AuthenticationContext) (*SigningResult, error) {
	if request == nil {
		return nil, fmt.Errorf("signing request is required")
	}

	if km.config.RequireAuthentication {
		if authContext == nil {
			return nil, fmt.Errorf("authentication required")
		}

		sessionCtx, err := km.ValidateSession(authContext.SessionID)
		if err != nil {
			return nil, fmt.Errorf("invalid session: %w", err)
		}

		if sessionCtx == nil || !contains(sessionCtx.Permissions, "transaction_signing") {
			return nil, fmt.Errorf("insufficient permissions for transaction signing")
		}

		km.auditLogWithAuth("SIGN_ATTEMPT", request.From, true,
			fmt.Sprintf("Transaction signing attempted: %s", request.Purpose), sessionCtx)
	}

	return km.SignTransaction(request)
}
// SignTransaction signs request.Transaction with the managed key for
// request.From after enforcing that key's policy.
//
// Checks, in order:
//   - request, its Transaction and its ChainID are non-nil;
//   - the key exists and may sign;
//   - the key has not expired;
//   - its usage limit (MaxUsage, 0 = unlimited) is not reached;
//   - value transfers are allowed and within MaxTransferWei;
//   - the destination is in AllowedContracts when that list is non-empty;
//   - the per-key rate limit allows it.
//
// AllowedContracts matching is case-insensitive, so checksummed and
// lowercase hex entries both work. Every policy rejection is audit-logged as
// SIGN_FAILED.
//
// The signer is types.LatestSignerForChainID, which signs legacy
// transactions with EIP-155 exactly as NewEIP155Signer did. It also accepts
// typed (EIP-2930/EIP-1559) transactions, which NewEIP155Signer rejected.
//
// SigningResult.Signature is r || s || recovery ID (65 bytes). The decrypted
// key is wiped before returning.
func (km *KeyManager) SignTransaction(request *SigningRequest) (*SigningResult, error) {
	if request == nil || request.Transaction == nil {
		return nil, fmt.Errorf("invalid signing request: transaction is required")
	}
	if request.ChainID == nil {
		km.auditLog("SIGN_FAILED", request.From, false, "Chain ID missing")
		return nil, fmt.Errorf("invalid signing request: chain ID is required")
	}
	tx := request.Transaction

	// Snapshot the policy fields under the lock: RotateKey mutates
	// Permissions while holding keysMutex.
	km.keysMutex.RLock()
	secureKey, exists := km.keys[request.From]
	var (
		permissions  KeyPermissions
		expiresAt    *time.Time
		maxUsage     int
		encryptedKey []byte
	)
	if exists {
		permissions = secureKey.Permissions
		expiresAt = secureKey.ExpiresAt
		maxUsage = secureKey.MaxUsage
		encryptedKey = secureKey.EncryptedKey
	}
	km.keysMutex.RUnlock()

	if !exists {
		km.auditLog("SIGN_FAILED", request.From, false, "Key not found")
		return nil, fmt.Errorf("key not found: %s", request.From.Hex())
	}

	warnings := make([]string, 0)

	if !permissions.CanSign {
		km.auditLog("SIGN_FAILED", request.From, false, "Key not permitted to sign")
		return nil, fmt.Errorf("key %s not permitted to sign transactions", request.From.Hex())
	}

	if expiresAt != nil && time.Now().After(*expiresAt) {
		km.auditLog("SIGN_FAILED", request.From, false, "Key expired")
		return nil, fmt.Errorf("key %s has expired", request.From.Hex())
	}

	// NOTE(review): check-then-increment is not atomic as a whole, so
	// concurrent signers can overshoot MaxUsage slightly.
	if maxUsage > 0 && atomic.LoadInt64(&secureKey.UsageCount) >= int64(maxUsage) {
		km.auditLog("SIGN_FAILED", request.From, false, "Usage limit exceeded")
		return nil, fmt.Errorf("key %s usage limit exceeded", request.From.Hex())
	}

	if value := tx.Value(); value.Sign() > 0 {
		if !permissions.CanTransfer {
			km.auditLog("SIGN_FAILED", request.From, false, "Transfer not permitted")
			return nil, fmt.Errorf("key %s not permitted to transfer value", request.From.Hex())
		}
		if permissions.MaxTransferWei != nil && value.Cmp(permissions.MaxTransferWei) > 0 {
			km.auditLog("SIGN_FAILED", request.From, false, "Transfer amount exceeds limit")
			return nil, fmt.Errorf("transfer amount exceeds limit for key %s", request.From.Hex())
		}
	}

	// NOTE(review): contract-creation transactions (To() == nil) bypass the
	// AllowedContracts restriction, as in the original implementation.
	if to := tx.To(); to != nil && len(permissions.AllowedContracts) > 0 {
		contractAddr := to.Hex()
		allowed := false
		for _, allowedContract := range permissions.AllowedContracts {
			// Hex() is EIP-55 checksummed; compare case-insensitively so
			// lowercase/uppercase configuration entries still match.
			if strings.EqualFold(strings.TrimSpace(allowedContract), contractAddr) {
				allowed = true
				break
			}
		}
		if !allowed {
			km.auditLog("SIGN_FAILED", request.From, false, "Contract interaction not permitted")
			return nil, fmt.Errorf("key %s not permitted to interact with contract %s", request.From.Hex(), contractAddr)
		}
	}

	if err := km.checkRateLimit(request.From); err != nil {
		km.auditLog("SIGN_FAILED", request.From, false, "Rate limit exceeded")
		return nil, fmt.Errorf("rate limit exceeded: %w", err)
	}

	if lastUsedUnix := atomic.LoadInt64(&secureKey.LastUsedUnix); lastUsedUnix > 0 &&
		time.Since(time.Unix(lastUsedUnix, 0)) > 24*time.Hour {
		warnings = append(warnings, "Key has not been used in over 24 hours")
	}
	if atomic.LoadInt64(&secureKey.UsageCount) > 1000 {
		warnings = append(warnings, "Key has high usage count - consider rotation")
	}

	privateKey, err := km.decryptPrivateKey(encryptedKey)
	if err != nil {
		km.auditLog("SIGN_FAILED", request.From, false, "Failed to decrypt private key")
		return nil, fmt.Errorf("failed to decrypt private key: %w", err)
	}
	defer func() {
		// Clear private key from memory
		if privateKey != nil {
			clearPrivateKey(privateKey)
		}
	}()

	signer := types.LatestSignerForChainID(request.ChainID)
	signedTx, err := types.SignTx(tx, signer, privateKey)
	if err != nil {
		km.auditLog("SIGN_FAILED", request.From, false, "Transaction signing failed")
		return nil, fmt.Errorf("failed to sign transaction: %w", err)
	}

	// Typed transactions carry the y-parity (0/1) directly in V; legacy
	// EIP-155 transactions encode V = recoveryID + 35 + 2*chainID.
	v, r, s := signedTx.RawSignatureValues()
	recoveryID := v.Uint64()
	if signedTx.Type() == types.LegacyTxType {
		recoveryID -= 35 + 2*request.ChainID.Uint64()
	}
	signature := make([]byte, 65)
	r.FillBytes(signature[0:32])
	s.FillBytes(signature[32:64])
	signature[64] = byte(recoveryID)

	now := time.Now()
	atomic.StoreInt64(&secureKey.LastUsedUnix, now.Unix())
	atomic.AddInt64(&secureKey.UsageCount, 1)

	auditID := generateAuditID()
	result := &SigningResult{
		SignedTx:  signedTx,
		Signature: signature,
		SignedAt:  now,
		KeyUsed:   request.From,
		AuditID:   auditID,
		Warnings:  warnings,
	}

	km.auditLog("TRANSACTION_SIGNED", request.From, true,
		fmt.Sprintf("Signed transaction %s for %s (audit: %s)",
			signedTx.Hash().Hex(), request.Purpose, auditID))
	return result, nil
}
// GetKeyInfo returns a detached snapshot of the metadata for the key at
// address. EncryptedKey is always nil in the result.
//
// The atomic counters (LastUsedUnix, UsageCount) are read with atomic loads,
// so the read does not race with concurrent signing. Slices, MaxTransferWei
// and ExpiresAt are deep-copied, so callers cannot mutate the manager's
// internal state through the result.
// Returns an error if no key is managed at address.
func (km *KeyManager) GetKeyInfo(address common.Address) (*SecureKey, error) {
	km.keysMutex.RLock()
	defer km.keysMutex.RUnlock()

	secureKey, exists := km.keys[address]
	if !exists {
		return nil, fmt.Errorf("key not found: %s", address.Hex())
	}

	// Preserve nil vs. empty so JSON output of the snapshot is unchanged.
	cloneBytes := func(src []byte) []byte {
		if src == nil {
			return nil
		}
		return append(make([]byte, 0, len(src)), src...)
	}
	cloneStrings := func(src []string) []string {
		if src == nil {
			return nil
		}
		return append(make([]string, 0, len(src)), src...)
	}

	permissions := secureKey.Permissions
	if permissions.MaxTransferWei != nil {
		permissions.MaxTransferWei = new(big.Int).Set(permissions.MaxTransferWei)
	}
	permissions.AllowedContracts = cloneStrings(permissions.AllowedContracts)

	var expiresAt *time.Time
	if secureKey.ExpiresAt != nil {
		t := *secureKey.ExpiresAt
		expiresAt = &t
	}

	// Build a fresh value field by field (never copy the struct: it holds a mutex).
	return &SecureKey{
		Address:         secureKey.Address,
		EncryptedKey:    nil, // Intentionally exclude encrypted key
		CreatedAt:       secureKey.CreatedAt,
		LastUsedUnix:    atomic.LoadInt64(&secureKey.LastUsedUnix),
		UsageCount:      atomic.LoadInt64(&secureKey.UsageCount),
		MaxUsage:        secureKey.MaxUsage,
		ExpiresAt:       expiresAt,
		KeyVersion:      secureKey.KeyVersion,
		Salt:            cloneBytes(secureKey.Salt),
		Nonce:           cloneBytes(secureKey.Nonce),
		KeyType:         secureKey.KeyType,
		Permissions:     permissions,
		IsActive:        secureKey.IsActive,
		BackupLocations: cloneStrings(secureKey.BackupLocations),
	}, nil
}
// ListKeys returns the address of every managed key, including rotated
// (inactive) ones. Order is unspecified because it follows map iteration.
// The result is never nil.
func (km *KeyManager) ListKeys() []common.Address {
	km.keysMutex.RLock()
	defer km.keysMutex.RUnlock()

	out := make([]common.Address, len(km.keys))
	i := 0
	for addr := range km.keys {
		out[i] = addr
		i++
	}
	return out
}
// RotateKey replaces the key at oldAddress with a freshly generated key of
// the same type and permissions, and returns the new key's address.
//
// The old key remains in the manager for audit purposes but is retired:
//   - CanSign and CanTransfer are revoked;
//   - it is marked inactive, so GetActivePrivateKey no longer selects it.
//
// The new key receives its own copies of MaxTransferWei and AllowedContracts
// rather than sharing the old key's. Returns an error if oldAddress is not
// managed or key generation fails.
func (km *KeyManager) RotateKey(oldAddress common.Address) (common.Address, error) {
	// Read the fields we copy under the lock; they are mutated under keysMutex.
	km.keysMutex.RLock()
	oldKey, exists := km.keys[oldAddress]
	var (
		keyType     string
		permissions KeyPermissions
	)
	if exists {
		keyType = oldKey.KeyType
		permissions = oldKey.Permissions
	}
	km.keysMutex.RUnlock()

	if !exists {
		return common.Address{}, fmt.Errorf("key not found: %s", oldAddress.Hex())
	}

	if permissions.MaxTransferWei != nil {
		permissions.MaxTransferWei = new(big.Int).Set(permissions.MaxTransferWei)
	}
	if permissions.AllowedContracts != nil {
		permissions.AllowedContracts = append(
			make([]string, 0, len(permissions.AllowedContracts)),
			permissions.AllowedContracts...)
	}

	newAddress, err := km.GenerateKey(keyType, permissions)
	if err != nil {
		return common.Address{}, fmt.Errorf("failed to generate new key: %w", err)
	}

	// Retire (don't delete) the old key so it stays available for audit.
	km.keysMutex.Lock()
	oldKey.Permissions.CanSign = false
	oldKey.Permissions.CanTransfer = false
	oldKey.IsActive = false
	km.keysMutex.Unlock()

	km.auditLog("KEY_ROTATED", oldAddress, true,
		fmt.Sprintf("Rotated to new key %s", newAddress.Hex()))
	km.logger.Info(fmt.Sprintf("Rotated key %s to %s", oldAddress.Hex(), newAddress.Hex()))
	return newAddress, nil
}
// encryptPrivateKey serialises privateKey and seals it with AES-GCM under
// km.encryptionKey. It uses a fresh random nonce per call.
//
// The result is nonce || ciphertext || tag, the layout decryptPrivateKey
// expects. The serialised plaintext key bytes are zeroed on every return
// path, including errors.
func (km *KeyManager) encryptPrivateKey(privateKey *ecdsa.PrivateKey) ([]byte, error) {
	keyBytes := crypto.FromECDSA(privateKey)
	defer func() {
		for i := range keyBytes {
			keyBytes[i] = 0
		}
	}()

	block, err := aes.NewCipher(km.encryptionKey)
	if err != nil {
		return nil, fmt.Errorf("failed to create cipher: %w", err)
	}

	gcm, err := cipher.NewGCM(block)
	if err != nil {
		return nil, fmt.Errorf("failed to create GCM: %w", err)
	}

	nonce := make([]byte, gcm.NonceSize())
	if _, err := io.ReadFull(rand.Reader, nonce); err != nil {
		return nil, fmt.Errorf("failed to generate nonce: %w", err)
	}

	// Seal appends to its first argument, yielding nonce||ciphertext||tag.
	// The return value is computed before the deferred wipe runs.
	return gcm.Seal(nonce, nonce, keyBytes, nil), nil
}
// decryptPrivateKey reverses encryptPrivateKey. It expects
// nonce || ciphertext || tag, opens it with AES-GCM under km.encryptionKey
// and parses the result as a secp256k1 key.
//
// The intermediate plaintext bytes are zeroed once parsing has been
// attempted. It fails if the input is shorter than the nonce, if GCM
// authentication fails (wrong key or tampered data), or if the plaintext is
// not a valid private key.
func (km *KeyManager) decryptPrivateKey(encryptedKey []byte) (*ecdsa.PrivateKey, error) {
	block, err := aes.NewCipher(km.encryptionKey)
	if err != nil {
		return nil, fmt.Errorf("failed to create cipher: %w", err)
	}

	aead, err := cipher.NewGCM(block)
	if err != nil {
		return nil, fmt.Errorf("failed to create GCM: %w", err)
	}

	ns := aead.NonceSize()
	if len(encryptedKey) < ns {
		return nil, fmt.Errorf("encrypted key too short")
	}

	plaintext, err := aead.Open(nil, encryptedKey[:ns], encryptedKey[ns:], nil)
	if err != nil {
		return nil, fmt.Errorf("decryption failed: %w", err)
	}
	defer func() {
		for i := range plaintext {
			plaintext[i] = 0
		}
	}()

	key, err := crypto.ToECDSA(plaintext)
	if err != nil {
		return nil, fmt.Errorf("failed to parse private key: %w", err)
	}
	return key, nil
}
// createKeyBackup writes a backup of secureKey to
// <BackupPath>/key_<address>_<unix-seconds>.backup. The file is created with
// mode 0600, and its path is appended to secureKey.BackupLocations.
//
// The payload holds the already-encrypted key plus metadata and is wrapped
// again by encryptBackupData. It returns nil without doing anything when
// BackupPath is empty.
//
// NOTE(review): the outer layer is given the same km.encryptionKey as the
// inner one. BackupLocations is appended without any lock, so callers should
// not let other goroutines see secureKey concurrently.
func (km *KeyManager) createKeyBackup(secureKey *SecureKey) error {
	backupDir := km.config.BackupPath
	if backupDir == "" {
		return nil // Backups not configured
	}

	fileName := fmt.Sprintf("key_%s_%d.backup", secureKey.Address.Hex(), time.Now().Unix())
	backupFile := filepath.Join(backupDir, fileName)

	// The JSON tags suggest encryptBackupData marshals this to JSON; keep the
	// field names and tags stable so existing backups stay readable.
	payload := struct {
		Address      string    `json:"address"`
		EncryptedKey []byte    `json:"encrypted_key"`
		CreatedAt    time.Time `json:"created_at"`
		KeyType      string    `json:"key_type"`
	}{
		Address:      secureKey.Address.Hex(),
		EncryptedKey: secureKey.EncryptedKey,
		CreatedAt:    secureKey.CreatedAt,
		KeyType:      secureKey.KeyType,
	}

	sealed, err := encryptBackupData(payload, km.encryptionKey)
	if err != nil {
		return fmt.Errorf("failed to encrypt backup: %w", err)
	}

	if err := os.WriteFile(backupFile, sealed, 0o600); err != nil {
		return fmt.Errorf("failed to write backup file: %w", err)
	}

	secureKey.BackupLocations = append(secureKey.BackupLocations, backupFile)
	return nil
}
// checkRateLimit enforces config.MaxSigningRate signings per key per
// one-minute fixed window. Values <= 0 disable the limit.
//
// Every call counts against the window, including calls that end up
// rejected. The window restarts once more than a minute has passed since it
// began. Trackers are allocated lazily and kept in km.signingRates under
// rateLimitMutex.
func (km *KeyManager) checkRateLimit(address common.Address) error {
	limit := km.config.MaxSigningRate
	if limit <= 0 {
		return nil // Rate limiting disabled
	}

	km.rateLimitMutex.Lock()
	defer km.rateLimitMutex.Unlock()

	if km.signingRates == nil {
		km.signingRates = make(map[string]*SigningRateTracker)
	}

	now := time.Now()
	id := address.Hex()

	tracker, seen := km.signingRates[id]
	switch {
	case !seen:
		tracker = &SigningRateTracker{Count: 0, StartTime: now}
		km.signingRates[id] = tracker
	case now.Sub(tracker.StartTime) > time.Minute:
		tracker.Count = 0
		tracker.StartTime = now
	}

	tracker.Count++
	if tracker.Count <= limit {
		return nil
	}
	return fmt.Errorf("signing rate limit exceeded for key %s: %d/%d per minute",
		id, tracker.Count, limit)
}
// auditLog builds an AuditEntry for the operation and, when
// config.AuditLogPath is set, emits it as an "AUDIT:" line through the
// application logger.
//
// Writing to the file at AuditLogPath is not implemented here. When the path
// is empty the entry is discarded.
func (km *KeyManager) auditLog(operation string, keyAddress common.Address, success bool, details string) {
	entry := AuditEntry{
		Timestamp:  time.Now(),
		Operation:  operation,
		KeyAddress: keyAddress,
		Success:    success,
		Details:    details,
		RiskScore:  calculateRiskScore(operation, success),
	}

	if km.config.AuditLogPath == "" {
		return
	}

	line := fmt.Sprintf("AUDIT: %s %s %v - %s (Risk: %.2f)",
		entry.Operation, entry.KeyAddress.Hex(), entry.Success, entry.Details, float64(entry.RiskScore))
	km.logger.Info(line)
}
// loadExistingKeys is a placeholder for restoring persisted keys at startup.
// It currently only logs the attempt and always returns nil, so the
// in-memory key set starts empty after every restart.
func (km *KeyManager) loadExistingKeys() error {
	const loadingMsg = "Loading existing keys from keystore"
	km.logger.Info(loadingMsg)
	return nil
}
// backgroundTasks runs performMaintenance once an hour and never returns.
// newKeyManagerInternal starts it as a goroutine.
//
// NOTE(review): there is no stop signal, so each KeyManager leaks this
// goroutine and its ticker for the life of the process. Passing a context or
// done channel through KeyManager would allow a clean shutdown.
func (km *KeyManager) backgroundTasks() {
	ticker := time.NewTicker(1 * time.Hour)
	defer ticker.Stop()

	// A single-case select inside an infinite loop is just a receive
	// (staticcheck S1000); ranging over the ticker channel is equivalent.
	for range ticker.C {
		km.performMaintenance()
	}
}
// performMaintenance scans all keys under a read lock and logs warnings. It
// warns about keys past ExpiresAt and, when KeyRotationDays > 0, about keys
// older than that many days.
//
// It only reports. Nothing is deactivated or rotated here.
func (km *KeyManager) performMaintenance() {
	km.keysMutex.RLock()
	defer km.keysMutex.RUnlock()

	now := time.Now()
	checkRotation := km.config.KeyRotationDays > 0
	rotateAfter := time.Duration(km.config.KeyRotationDays) * 24 * time.Hour

	for addr, k := range km.keys {
		if exp := k.ExpiresAt; exp != nil && now.After(*exp) {
			km.logger.Warn(fmt.Sprintf("Key %s has expired", addr.Hex()))
		}

		if !checkRotation {
			continue
		}
		if age := now.Sub(k.CreatedAt); age > rotateAfter {
			km.logger.Warn(fmt.Sprintf("Key %s should be rotated (age: %v)", addr.Hex(), age))
		}
	}
}
// AuthenticateUser verifies a user's credentials and opens a new session.
//
// Steps:
//  1. When the IP whitelist is enabled, ipAddress must be whitelisted.
//  2. Users under an unexpired lockout are rejected.
//  3. Credentials are checked via validateCredentials. Failures are counted,
//     and reaching maxFailedAttempts locks the account for lockoutDuration.
//  4. The per-user concurrent-session limit is enforced.
//
// For step 4, sessions whose ExpiresAt has passed are marked inactive and are
// not counted. Otherwise sessions that expired without being validated again
// would hold their slot forever and lock the user out.
//
// The new session grants "key_access" and "transaction_signing" and expires
// after sessionTimeout. Failures are audit-logged as AUTH_FAILED.
// authMutex serialises calls.
func (km *KeyManager) AuthenticateUser(userID, password, ipAddress, userAgent string) (*AuthenticationSession, error) {
	km.authMutex.Lock()
	defer km.authMutex.Unlock()

	if km.config.EnableIPWhitelist {
		km.ipWhitelistMutex.RLock()
		allowed := km.whitelistedIPs[ipAddress]
		km.ipWhitelistMutex.RUnlock()
		if !allowed {
			km.auditLog("AUTH_FAILED", common.Address{}, false,
				fmt.Sprintf("IP not whitelisted: %s", ipAddress))
			return nil, fmt.Errorf("access denied: IP address not whitelisted")
		}
	}

	if lockoutEnd, locked := km.accessLockouts[userID]; locked {
		if time.Now().Before(lockoutEnd) {
			return nil, fmt.Errorf("account locked until %v", lockoutEnd)
		}
		// Clear expired lockout
		delete(km.accessLockouts, userID)
		delete(km.failedAccessAttempts, userID)
	}

	// Validate credentials (simplified - in production use proper password hashing)
	if !km.validateCredentials(userID, password) {
		km.failedAccessAttempts[userID]++
		if km.failedAccessAttempts[userID] >= km.maxFailedAttempts {
			km.accessLockouts[userID] = time.Now().Add(km.lockoutDuration)
			km.logger.Warn(fmt.Sprintf("Account locked for user %s due to failed attempts", userID))
		}
		km.auditLog("AUTH_FAILED", common.Address{}, false,
			fmt.Sprintf("Invalid credentials for user %s", userID))
		return nil, fmt.Errorf("invalid credentials")
	}

	// Clear failed attempts on successful login
	delete(km.failedAccessAttempts, userID)

	now := time.Now()

	km.sessionsMutex.Lock()
	userSessions := 0
	for _, session := range km.activeSessions {
		if !session.IsActive || session.Context == nil || session.Context.UserID != userID {
			continue
		}
		// Retire expired sessions the same way ValidateSession does, so they
		// stop counting toward the limit.
		if now.After(session.Context.ExpiresAt) {
			session.IsActive = false
			continue
		}
		userSessions++
	}
	if userSessions >= km.maxConcurrentSessions {
		km.sessionsMutex.Unlock()
		km.auditLog("AUTH_FAILED", common.Address{}, false,
			fmt.Sprintf("Concurrent session limit reached for user %s", userID))
		return nil, fmt.Errorf("maximum concurrent sessions exceeded")
	}

	sessionID := generateSessionID()
	authCtx := &AuthenticationContext{
		SessionID:   sessionID,
		UserID:      userID,
		IPAddress:   ipAddress,
		UserAgent:   userAgent,
		AuthMethod:  "password",
		AuthTime:    now,
		ExpiresAt:   now.Add(km.sessionTimeout),
		Permissions: []string{"key_access", "transaction_signing"},
		RiskScore:   calculateAuthRiskScore(ipAddress, userAgent),
	}
	session := &AuthenticationSession{
		ID:           sessionID,
		Context:      authCtx,
		CreatedAt:    now,
		LastActivity: now,
		IsActive:     true,
	}
	km.activeSessions[sessionID] = session
	km.sessionsMutex.Unlock()

	km.auditLog("USER_AUTHENTICATED", common.Address{}, true,
		fmt.Sprintf("User %s authenticated from %s", userID, ipAddress))
	km.logger.Info(fmt.Sprintf("User %s authenticated successfully", userID))
	return session, nil
}
// ValidateSession checks sessionID and, on success, returns the session's
// server-side AuthenticationContext.
//
// A session passes when it exists, is active, has a context and has not
// passed Context.ExpiresAt. A passing session has its LastActivity refreshed.
// An expired session is marked inactive and a SESSION_EXPIRED audit entry is
// written.
//
// All reads and writes of the session happen in one sessionsMutex critical
// section, so concurrent validations and logins cannot race on IsActive or
// LastActivity. The audit write happens after the lock is released.
func (km *KeyManager) ValidateSession(sessionID string) (*AuthenticationContext, error) {
	authCtx, expired, err := func() (*AuthenticationContext, bool, error) {
		km.sessionsMutex.Lock()
		defer km.sessionsMutex.Unlock()

		session, exists := km.activeSessions[sessionID]
		if !exists {
			return nil, false, fmt.Errorf("session not found")
		}
		if !session.IsActive {
			return nil, false, fmt.Errorf("session is inactive")
		}
		if session.Context == nil {
			return nil, false, fmt.Errorf("session has no authentication context")
		}

		now := time.Now()
		if now.After(session.Context.ExpiresAt) {
			session.IsActive = false
			return nil, true, fmt.Errorf("session expired")
		}

		session.LastActivity = now
		return session.Context, false, nil
	}()

	if expired {
		km.auditLog("SESSION_EXPIRED", common.Address{}, false,
			fmt.Sprintf("Session %s expired", sessionID))
	}
	if err != nil {
		return nil, err
	}
	return authCtx, nil
}
// GetActivePrivateKeyWithAuth is GetActivePrivateKey gated by session
// authentication when config.RequireAuthentication is set.
//
// With authentication required, authContext must be non-nil and its
// SessionID must name a live session holding the "key_access" permission.
// The permission is read from the server-side context returned by
// ValidateSession, NOT from authContext.Permissions, which is
// caller-supplied and could be forged alongside a valid session ID.
func (km *KeyManager) GetActivePrivateKeyWithAuth(authContext *AuthenticationContext) (*ecdsa.PrivateKey, error) {
	if km.config.RequireAuthentication {
		if authContext == nil {
			return nil, fmt.Errorf("authentication required")
		}

		sessionCtx, err := km.ValidateSession(authContext.SessionID)
		if err != nil {
			return nil, fmt.Errorf("invalid session: %w", err)
		}

		if sessionCtx == nil || !contains(sessionCtx.Permissions, "key_access") {
			return nil, fmt.Errorf("insufficient permissions for key access")
		}
	}

	return km.GetActivePrivateKey()
}
// GetActivePrivateKey returns a decrypted private key for transaction
// signing.
//
// Eligible keys are active, still permitted to sign and not past ExpiresAt.
// Among them the most recently created key wins, with ties broken by the
// lowest address hex. Repeated calls therefore return the same key instead
// of whichever one random map iteration yields. Rotated keys, which have
// CanSign revoked, and expired keys are never returned.
//
// If the manager holds no keys at all, it generates a default "trading" key
// and returns it. That key may sign and transfer, up to 1 ETH per
// transaction. If keys exist but none is eligible, an error is returned.
//
// The caller owns the returned key and should wipe it when finished.
// NOTE(review): two concurrent first calls may each auto-generate a key.
func (km *KeyManager) GetActivePrivateKey() (*ecdsa.PrivateKey, error) {
	now := time.Now()

	km.keysMutex.RLock()
	var (
		selected     *SecureKey
		selectedAddr common.Address
	)
	for address, candidate := range km.keys {
		if !candidate.IsActive || !candidate.Permissions.CanSign {
			continue
		}
		if candidate.ExpiresAt != nil && now.After(*candidate.ExpiresAt) {
			continue
		}
		if selected == nil ||
			candidate.CreatedAt.After(selected.CreatedAt) ||
			(candidate.CreatedAt.Equal(selected.CreatedAt) && address.Hex() < selectedAddr.Hex()) {
			selected, selectedAddr = candidate, address
		}
	}
	var encryptedKey []byte
	if selected != nil {
		encryptedKey = selected.EncryptedKey
	}
	needsNewKey := len(km.keys) == 0
	km.keysMutex.RUnlock()

	// Decrypt outside the lock; EncryptedKey was captured above.
	if selected != nil {
		privateKey, err := km.decryptPrivateKey(encryptedKey)
		if err != nil {
			km.auditLog("KEY_DECRYPTION_FAILED", selectedAddr, false,
				fmt.Sprintf("Failed to decrypt key: %v", err))
			return nil, fmt.Errorf("failed to decrypt active key: %w", err)
		}
		km.auditLog("KEY_ACCESSED", selectedAddr, true, "Active private key retrieved")
		return privateKey, nil
	}

	if !needsNewKey {
		return nil, fmt.Errorf("no active private key available")
	}

	km.logger.Info("No keys found, generating default trading key...")
	defaultPermissions := KeyPermissions{
		CanSign:          true,
		CanTransfer:      true,
		MaxTransferWei:   big.NewInt(1000000000000000000), // 1 ETH max per transaction
		AllowedContracts: []string{},                      // Will be populated with contract addresses
		RequireConfirm:   false,
	}

	km.logger.Info("Calling GenerateKey...")
	address, err := km.GenerateKey("trading", defaultPermissions)
	if err != nil {
		km.logger.Error(fmt.Sprintf("Failed to generate default key: %v", err))
		return nil, fmt.Errorf("failed to generate default key: %w", err)
	}
	km.logger.Info(fmt.Sprintf("Default key generated: %s", address.Hex()))

	km.keysMutex.RLock()
	generated, exists := km.keys[address]
	var generatedCipher []byte
	if exists {
		generatedCipher = generated.EncryptedKey
	}
	km.keysMutex.RUnlock()

	if !exists {
		return nil, fmt.Errorf("no active private key available")
	}

	privateKey, err := km.decryptPrivateKey(generatedCipher)
	if err != nil {
		return nil, fmt.Errorf("failed to decrypt newly generated key: %w", err)
	}
	km.auditLog("KEY_AUTO_GENERATED", address, true, "Auto-generated active key")
	return privateKey, nil
}
// Helper functions

// getDefaultConfig returns a freshly allocated KeyManagerConfig holding the
// built-in defaults. EncryptionKey is left empty, so the result does not pass
// validateConfig until the caller supplies a key.
func getDefaultConfig() *KeyManagerConfig {
	cfg := new(KeyManagerConfig)

	// Filesystem locations (relative to the working directory).
	cfg.KeystorePath = "./keystore"
	cfg.BackupPath = "./backups"
	cfg.AuditLogPath = "./audit.log"

	// Key material and rotation; the encryption key is set or generated later.
	cfg.EncryptionKey = ""
	cfg.KeyRotationDays = 90
	cfg.RequireHardware = false

	// 60 signings per minute.
	cfg.MaxSigningRate = 60

	// Sessions and authentication.
	cfg.SessionTimeout = 15 * time.Minute
	cfg.MaxSessionAge = 24 * time.Hour
	cfg.MaxConcurrentSessions = 3
	cfg.RequireAuthentication = true
	cfg.RequireMFA = false
	cfg.PasswordHashRounds = 12

	// Only localhost (IPv4 and IPv6) is whitelisted by default.
	cfg.EnableIPWhitelist = true
	cfg.WhitelistedIPs = []string{"127.0.0.1", "::1"}

	// Attempt limits and lockouts.
	cfg.EnableRateLimiting = true
	cfg.MaxAuthAttempts = 5
	cfg.AuthLockoutDuration = 30 * time.Minute
	cfg.MaxFailedAttempts = 3
	cfg.LockoutDuration = 15 * time.Minute

	return cfg
}
// validateConfig performs the baseline sanity checks on a KeyManagerConfig.
// It returns the first failure found, checking in order: a non-empty keystore
// path, a non-empty encryption key, and an encryption key of at least 32
// bytes. It returns nil when all checks pass.
func validateConfig(config *KeyManagerConfig) error {
	switch encKey := config.EncryptionKey; {
	case config.KeystorePath == "":
		return fmt.Errorf("keystore path cannot be empty")
	case encKey == "":
		return fmt.Errorf("encryption key cannot be empty")
	case len(encKey) < 32:
		// len counts bytes, not characters.
		return fmt.Errorf("encryption key must be at least 32 characters")
	default:
		return nil
	}
}
func deriveEncryptionKey(masterKey string) ([]byte, error) {
if masterKey == "" {
return nil, fmt.Errorf("master key cannot be empty")
}
// Generate secure random salt
salt := make([]byte, 32)
if _, err := rand.Read(salt); err != nil {
return nil, fmt.Errorf("failed to generate random salt: %w", err)
}
key, err := scrypt.Key([]byte(masterKey), salt, 32768, 8, 1, 32)
if err != nil {
return nil, fmt.Errorf("key derivation failed: %w", err)
}
return key, nil
}
// clearPrivateKey wipes the key material held by privateKey.
//
// The private scalar D and the public coordinates X and Y are each passed to
// secureClearBigInt and then set to nil. The curve reference is dropped too.
// A nil privateKey is a no-op. The struct is unusable afterwards.
func clearPrivateKey(privateKey *ecdsa.PrivateKey) {
	if privateKey == nil {
		return
	}

	// Wipe in order: private scalar first, then the public point.
	scalars := []**big.Int{
		&privateKey.D,
		&privateKey.PublicKey.X,
		&privateKey.PublicKey.Y,
	}
	for _, slot := range scalars {
		if *slot == nil {
			continue
		}
		secureClearBigInt(*slot)
		*slot = nil
	}

	privateKey.PublicKey.Curve = nil
}
// secureClearBigInt overwrites the words backing bi with zeros and leaves bi
// equal to 0. A nil bi is a no-op.
//
// The whole backing array is zeroed up to its capacity, not only the words
// currently in use. math/big reuses storage across operations, so words past
// len(bi.Bits()) can still hold stale data from earlier, larger values, for
// example after a right shift or a modular reduction performed in place.
//
// Copies of the value made elsewhere (via Set, Bytes, String, ...) are not
// reachable from here and are not cleared.
func secureClearBigInt(bi *big.Int) {
	if bi == nil {
		return
	}

	// Bits shares its backing array with bi, so re-slicing to full capacity
	// reaches every word bi currently owns.
	words := bi.Bits()
	words = words[:cap(words)]
	for i := range words {
		words[i] = 0
	}

	// Normalise to the canonical zero (empty magnitude, non-negative sign).
	bi.SetInt64(0)
}
// secureClearBytes overwrites every byte of data with zero in place. A nil or
// empty slice is a no-op.
func secureClearBytes(data []byte) {
	for idx := 0; idx < len(data); idx++ {
		data[idx] = 0
	}
	// Keep data reachable until after the zeroing loop, as a guard against
	// the stores being treated as dead.
	runtime.KeepAlive(data)
}
func generateAuditID() string {
bytes := make([]byte, 16)
if _, err := io.ReadFull(rand.Reader, bytes); err != nil {
// Fallback to current time if crypto/rand fails (shouldn't happen)
return fmt.Sprintf("%x", time.Now().UnixNano())
}
return hex.EncodeToString(bytes)
}
// calculateRiskScore assigns a risk score to an audited operation.
//
// Any failed operation scores 8. Successful operations score 5 for
// KEY_GENERATED and KEY_IMPORTED, 4 for KEY_ROTATED, 3 for
// TRANSACTION_SIGNED, and 2 for anything else.
func calculateRiskScore(operation string, success bool) int {
	if !success {
		return 8
	}

	score := 2
	switch operation {
	case "KEY_GENERATED", "KEY_IMPORTED":
		score = 5
	case "KEY_ROTATED":
		score = 4
	case "TRANSACTION_SIGNED":
		score = 3
	}
	return score
}
// Logout marks the session identified by sessionID as inactive and writes a
// USER_LOGOUT audit entry. The entry stays in activeSessions; it is only
// flagged inactive, not removed. The audit entry is written while
// sessionsMutex is still held.
//
// It returns an error if no session with that ID exists.
func (km *KeyManager) Logout(sessionID string) error {
	km.sessionsMutex.Lock()
	defer km.sessionsMutex.Unlock()

	if session, found := km.activeSessions[sessionID]; found {
		session.IsActive = false
		msg := fmt.Sprintf("User %s logged out", session.Context.UserID)
		km.auditLog("USER_LOGOUT", common.Address{}, true, msg)
		return nil
	}
	return fmt.Errorf("session not found")
}
// validateCredentials reports whether password matches the stored credential
// for userID. The two hashes are compared in constant time.
//
// NOTE: hashPassword is a single unsalted SHA-256, which is not a password
// hash. Production use should switch to bcrypt, scrypt or argon2 with
// per-user salts.
func (km *KeyManager) validateCredentials(userID, password string) bool {
	candidate := []byte(hashPassword(password))
	stored := []byte(km.getStoredPasswordHash(userID))
	return subtle.ConstantTimeCompare(candidate, stored) == 1
}
// getStoredPasswordHash returns the expected password hash for userID.
//
// This is a development/testing stub with hard-coded credentials: "admin"
// maps to one fixed password, and every other userID, including unknown
// ones, maps to the shared password "default_password". A real deployment
// must replace this with a lookup in secure storage.
func (km *KeyManager) getStoredPasswordHash(userID string) string {
	switch userID {
	case "admin":
		return hashPassword("secure_admin_password_123")
	default:
		return hashPassword("default_password")
	}
}
// hashPassword returns the lowercase hex SHA-256 digest of password: 64
// characters, unsalted.
func hashPassword(password string) string {
	h := sha256.New()
	// hash.Hash.Write never returns an error.
	_, _ = h.Write([]byte(password))
	return hex.EncodeToString(h.Sum(nil))
}
// generateSessionID generates a secure session ID
func generateSessionID() string {
bytes := make([]byte, 32)
if _, err := io.ReadFull(rand.Reader, bytes); err != nil {
return fmt.Sprintf("%x", time.Now().UnixNano())
}
return hex.EncodeToString(bytes)
}
// calculateAuthRiskScore returns a heuristic risk score for an authentication
// attempt. The score starts at 1 and is adjusted as follows:
//   - +3 when ipAddress is not a loopback or private address;
//   - +2 when userAgent is shorter than 10 bytes or does not contain "Mozilla".
//
// Addresses are classified by prefix on their textual form, so IPv4
// "host:port" strings work as well. Recognised internal ranges are
// 127.0.0.0/8, 10.0.0.0/8, 172.16.0.0/12, 192.168.0.0/16 and the IPv6
// loopback ::1, either bare or bracketed as in "[::1]:8080". This matches the
// "::1" entry in the default IP whitelist. Anything else, including an empty
// or unparseable address, counts as external.
func calculateAuthRiskScore(ipAddress, userAgent string) int {
	riskScore := 1 // Base risk

	internalPrefixes := []string{"127.", "10.", "192.168."}
	// 172.16.0.0/12 covers second octets 16 through 31.
	for octet := 16; octet <= 31; octet++ {
		internalPrefixes = append(internalPrefixes, fmt.Sprintf("172.%d.", octet))
	}

	internal := ipAddress == "::1" || strings.HasPrefix(ipAddress, "[::1]")
	for _, prefix := range internalPrefixes {
		if strings.HasPrefix(ipAddress, prefix) {
			internal = true
			break
		}
	}

	// Increase risk for external IPs
	if !internal {
		riskScore += 3
	}

	// Increase risk for unknown user agents
	if len(userAgent) < 10 || !strings.Contains(userAgent, "Mozilla") {
		riskScore += 2
	}

	return riskScore
}
// contains reports whether item appears in slice, using exact string
// equality. A nil or empty slice never contains anything.
func contains(slice []string, item string) bool {
	found := false
	for i := 0; i < len(slice) && !found; i++ {
		found = slice[i] == item
	}
	return found
}
// auditLogWithAuth writes an audit entry with authentication context
func (km *KeyManager) auditLogWithAuth(operation string, keyAddress common.Address, success bool, details string, authContext *AuthenticationContext) {
entry := AuditEntry{
Timestamp: time.Now(),
Operation: operation,
KeyAddress: keyAddress,
Success: success,
Details: details,
RiskScore: calculateRiskScore(operation, success),
}
if authContext != nil {
entry.IPAddress = authContext.IPAddress
entry.UserAgent = authContext.UserAgent
}
// Write to audit log
if km.config.AuditLogPath != "" {
km.logger.Info(fmt.Sprintf("AUDIT: %s %s %v - %s (Risk: %d) [User: %v]",
entry.Operation, entry.KeyAddress.Hex(), entry.Success, entry.Details, entry.RiskScore,
map[string]interface{}{"user_id": authContext.UserID, "session_id": authContext.SessionID}))
}
}
func encryptBackupData(data interface{}, key []byte) ([]byte, error) {
// Convert data to JSON bytes
jsonData, err := json.Marshal(data)
if err != nil {
return nil, fmt.Errorf("failed to marshal backup data: %w", err)
}
// Create AES cipher
block, err := aes.NewCipher(key)
if err != nil {
return nil, fmt.Errorf("failed to create AES cipher: %w", err)
}
// Create GCM mode for authenticated encryption
gcm, err := cipher.NewGCM(block)
if err != nil {
return nil, fmt.Errorf("failed to create GCM mode: %w", err)
}
// Generate random nonce
nonce := make([]byte, gcm.NonceSize())
if _, err := io.ReadFull(rand.Reader, nonce); err != nil {
return nil, fmt.Errorf("failed to generate nonce: %w", err)
}
// Encrypt and authenticate the data
ciphertext := gcm.Seal(nonce, nonce, jsonData, nil)
return ciphertext, nil
}
// validateProductionConfig enforces the stricter security requirements that
// apply to production deployments. It returns the first violation found:
//   - the encryption key must be present;
//   - it must not contain "test", "default" or "example" (case-insensitive);
//   - it must be at least 32 bytes long;
//   - it must not consist of a single repeated byte (e.g. 32 × "a");
//   - the keystore path must not be, or lie beneath, /tmp, /var/tmp,
//     /home/public or /usr/tmp (compared case-insensitively after
//     normalisation);
//   - the backup path, if set, must not resolve to the keystore path.
//
// Paths are normalised before comparison (made absolute, cleaned, and
// converted to forward slashes). Spellings such as "/var/./tmp/keys" or a
// relative path into /tmp are therefore caught, while siblings that merely
// share a prefix, such as "/tmpdata", are not rejected.
func validateProductionConfig(config *KeyManagerConfig) error {
	encKey := config.EncryptionKey

	// Check for encryption key presence
	if encKey == "" {
		return fmt.Errorf("MEV_BOT_ENCRYPTION_KEY environment variable is required for production")
	}

	// Check for test/default encryption keys
	lowerKey := strings.ToLower(encKey)
	for _, marker := range []string{"test", "default", "example"} {
		if strings.Contains(lowerKey, marker) {
			return fmt.Errorf("production deployment cannot use test/default encryption keys")
		}
	}

	// Validate encryption key strength
	if len(encKey) < 32 {
		return fmt.Errorf("encryption key must be at least 32 characters for production use")
	}

	// Well-known weak keys such as "test123", "password" or
	// "123456789012345678901234567890" are all shorter than 32 bytes and are
	// already rejected above. A long key made of one repeated byte passes the
	// length check but carries no entropy.
	if strings.Repeat(encKey[:1], len(encKey)) == encKey {
		return fmt.Errorf("encryption key is too weak for production use")
	}

	// normalizePath makes path comparisons independent of spelling. It falls
	// back to a purely lexical clean if the working directory is unavailable.
	normalizePath := func(p string) string {
		if abs, err := filepath.Abs(p); err == nil {
			return filepath.ToSlash(abs)
		}
		return filepath.ToSlash(filepath.Clean(p))
	}

	// Validate keystore path security
	if config.KeystorePath != "" {
		keystore := strings.ToLower(normalizePath(config.KeystorePath))
		for _, publicPath := range []string{"/tmp", "/var/tmp", "/home/public", "/usr/tmp"} {
			// Match the directory itself or anything beneath it, not mere
			// string prefixes like "/tmpdata".
			if keystore == publicPath || strings.HasPrefix(keystore, publicPath+"/") {
				return fmt.Errorf("keystore path '%s' is in a publicly accessible location", config.KeystorePath)
			}
		}
	}

	// Validate backup path if specified
	if config.BackupPath != "" && config.KeystorePath != "" &&
		normalizePath(config.BackupPath) == normalizePath(config.KeystorePath) {
		return fmt.Errorf("backup path cannot be the same as keystore path")
	}

	return nil
}