Files
mev-beta/pkg/security/config.go
2025-10-04 09:31:02 -05:00

404 lines
12 KiB
Go

package security
import (
	"crypto/aes"
	"crypto/cipher"
	"crypto/rand"
	"crypto/sha256"
	"encoding/base64"
	"fmt"
	"io"
	"net/url"
	"os"
	"regexp"
	"strconv"
	"strings"
	"time"
)
// SecureConfig manages all security-sensitive configuration.
//
// NewSecureConfig populates it from environment variables. The exported fields
// hold the validated settings. The AES key is kept in the unexported
// encryptionKey field, which in this file is read only by Encrypt and Decrypt.
type SecureConfig struct {
	// Network endpoints - never hardcoded

	// RPCEndpoints: primary RPC URLs from ARBITRUM_RPC_ENDPOINTS (required,
	// comma-separated, https:// or wss://).
	RPCEndpoints []string
	// WSEndpoints: optional wss:// URLs from ARBITRUM_WS_ENDPOINTS.
	WSEndpoints []string
	// BackupRPCs: optional fallback URLs from BACKUP_RPC_ENDPOINTS.
	BackupRPCs []string

	// Security settings

	// MaxGasPriceGwei: MAX_GAS_PRICE_GWEI, 1..100000 (default 1000).
	MaxGasPriceGwei int64
	// MaxTransactionValue: MAX_TRANSACTION_VALUE_ETH as a decimal string
	// (default "100").
	MaxTransactionValue string // In ETH
	// MaxSlippageBps: MAX_SLIPPAGE_BPS in basis points, 0..10000 (default 500).
	MaxSlippageBps uint64
	// MinProfitThreshold: MIN_PROFIT_THRESHOLD_ETH as a decimal string
	// (default "0.01").
	MinProfitThreshold string // In ETH

	// Rate limiting

	// MaxRequestsPerSecond: MAX_REQUESTS_PER_SECOND, 1..10000 (default 100).
	MaxRequestsPerSecond int
	// BurstSize: RATE_LIMIT_BURST_SIZE, 1..20000 (default 200).
	BurstSize int

	// Timeouts (configured in whole seconds)

	// RPCTimeout: RPC_TIMEOUT_SECONDS, 1..300 (default 30s).
	RPCTimeout time.Duration
	// WebSocketTimeout: WEBSOCKET_TIMEOUT_SECONDS, 1..600 (default 60s).
	WebSocketTimeout time.Duration
	// TransactionTimeout: TRANSACTION_TIMEOUT_SECONDS, 1..3600 (default 300s).
	TransactionTimeout time.Duration

	// Encryption

	// encryptionKey is the 32-byte AES-256 key decoded (standard base64) from
	// MEV_BOT_ENCRYPTION_KEY. A nil value makes Encrypt/Decrypt fail.
	encryptionKey []byte
}
// SecurityLimits defines operational security limits.
//
// NOTE(review): nothing in this file constructs or reads SecurityLimits; the
// units below come from the field comments only. The ETH amounts are strings,
// presumably in the same decimal format validateETHAmount accepts — confirm
// with the callers.
type SecurityLimits struct {
	MaxGasPrice         int64  // Gwei
	MaxTransactionValue string // ETH
	MaxDailyVolume      string // ETH
	MaxSlippage         uint64 // basis points
	MinProfit           string // ETH
	MaxOrderSize        string // ETH
}
// EndpointConfig stores RPC endpoint configuration securely.
//
// NOTE(review): this file never constructs EndpointConfig. The meaning of
// Priority (whether lower or higher wins) and the unit semantics of the
// fields are defined by its users — confirm there.
type EndpointConfig struct {
	URL            string
	Priority       int
	Timeout        time.Duration
	MaxConnections int
	HealthCheckURL string
	RequiresAuth   bool
	// AuthToken is described as encrypted when stored. Nothing here enforces
	// that; presumably callers use SecureConfig.Encrypt/Decrypt — verify.
	AuthToken string // Encrypted when stored
}
// NewSecureConfig builds a SecureConfig from environment variables.
//
// MEV_BOT_ENCRYPTION_KEY is required. It must be standard base64 that decodes
// to exactly 32 bytes, which becomes the AES-256 key for Encrypt/Decrypt.
// The loaders then run in a fixed order: endpoints, security limits, rate
// limits, timeouts. The first failure aborts construction and is returned
// wrapped with the name of the failing step. On error the returned config is
// always nil.
func NewSecureConfig() (*SecureConfig, error) {
	encoded := os.Getenv("MEV_BOT_ENCRYPTION_KEY")
	if encoded == "" {
		return nil, fmt.Errorf("MEV_BOT_ENCRYPTION_KEY environment variable is required")
	}
	rawKey, decodeErr := base64.StdEncoding.DecodeString(encoded)
	if decodeErr != nil {
		return nil, fmt.Errorf("invalid encryption key format: %w", decodeErr)
	}
	if len(rawKey) != 32 {
		return nil, fmt.Errorf("encryption key must be 32 bytes (256 bits)")
	}

	cfg := &SecureConfig{encryptionKey: rawKey}

	// Each step's description is spliced into "failed to <what>: <err>".
	steps := []struct {
		what string
		run  func() error
	}{
		{"load RPC endpoints", cfg.loadRPCEndpoints},
		{"load security limits", cfg.loadSecurityLimits},
		{"load rate limits", cfg.loadRateLimits},
		{"load timeouts", cfg.loadTimeouts},
	}
	for _, step := range steps {
		if stepErr := step.run(); stepErr != nil {
			return nil, fmt.Errorf("failed to %s: %w", step.what, stepErr)
		}
	}
	return cfg, nil
}
// loadRPCEndpoints loads and validates the endpoint lists from the environment.
//
//   - ARBITRUM_RPC_ENDPOINTS (required): primary RPC endpoints, each checked
//     by validateEndpoint.
//   - ARBITRUM_WS_ENDPOINTS (optional): WebSocket endpoints, each checked by
//     validateWebSocketEndpoint.
//   - BACKUP_RPC_ENDPOINTS (optional): backup RPC endpoints, each checked by
//     validateEndpoint.
//
// Each variable is a comma-separated list. Entries are whitespace-trimmed, and
// an empty entry (e.g. from a trailing comma) is rejected. A field is assigned
// only after its whole list has validated.
//
// Validation errors name a bad entry by its zero-based index, never by value.
// Endpoint URLs routinely embed provider API keys, and echoing them into an
// error would leak those secrets into logs.
func (sc *SecureConfig) loadRPCEndpoints() error {
	rawPrimary := os.Getenv("ARBITRUM_RPC_ENDPOINTS")
	if rawPrimary == "" {
		return fmt.Errorf("ARBITRUM_RPC_ENDPOINTS environment variable is required")
	}
	primary, err := splitAndValidateEndpoints(rawPrimary, "RPC endpoint", validateEndpoint)
	if err != nil {
		return err
	}
	sc.RPCEndpoints = primary

	if rawWS := os.Getenv("ARBITRUM_WS_ENDPOINTS"); rawWS != "" {
		ws, wsErr := splitAndValidateEndpoints(rawWS, "WebSocket endpoint", validateWebSocketEndpoint)
		if wsErr != nil {
			return wsErr
		}
		sc.WSEndpoints = ws
	}

	if rawBackup := os.Getenv("BACKUP_RPC_ENDPOINTS"); rawBackup != "" {
		backups, backupErr := splitAndValidateEndpoints(rawBackup, "backup RPC endpoint", validateEndpoint)
		if backupErr != nil {
			return backupErr
		}
		sc.BackupRPCs = backups
	}

	return nil
}

// splitAndValidateEndpoints splits a comma-separated list, trims each entry
// and runs validate on it. On failure it returns an error of the form
// "invalid <kind> at index <i>: <validator error>", which never includes the
// entry's text.
func splitAndValidateEndpoints(raw, kind string, validate func(string) error) ([]string, error) {
	endpoints := strings.Split(raw, ",")
	for i := range endpoints {
		endpoints[i] = strings.TrimSpace(endpoints[i])
		if err := validate(endpoints[i]); err != nil {
			return nil, fmt.Errorf("invalid %s at index %d: %w", kind, i, err)
		}
	}
	return endpoints, nil
}
// loadSecurityLimits reads the trading safety limits from the environment.
// A variable that is unset or empty falls back to its default:
//
//	MAX_GAS_PRICE_GWEI        1000   (must be 1..100000)
//	MAX_TRANSACTION_VALUE_ETH "100"  (decimal ETH amount)
//	MAX_SLIPPAGE_BPS          500    (must be 0..10000)
//	MIN_PROFIT_THRESHOLD_ETH  "0.01" (decimal ETH amount)
//
// Fields are assigned as they are read, so an error leaves the earlier ones set.
func (sc *SecureConfig) loadSecurityLimits() error {
	gasGwei, parseErr := strconv.ParseInt(getEnvWithDefault("MAX_GAS_PRICE_GWEI", "1000"), 10, 64)
	if parseErr != nil || gasGwei < 1 || gasGwei > 100000 {
		return fmt.Errorf("invalid MAX_GAS_PRICE_GWEI: must be between 1 and 100000")
	}
	sc.MaxGasPriceGwei = gasGwei

	sc.MaxTransactionValue = getEnvWithDefault("MAX_TRANSACTION_VALUE_ETH", "100")
	if amountErr := validateETHAmount(sc.MaxTransactionValue); amountErr != nil {
		return fmt.Errorf("invalid MAX_TRANSACTION_VALUE_ETH: %w", amountErr)
	}

	slippageBps, parseErr := strconv.ParseUint(getEnvWithDefault("MAX_SLIPPAGE_BPS", "500"), 10, 64)
	if parseErr != nil || slippageBps > 10000 {
		return fmt.Errorf("invalid MAX_SLIPPAGE_BPS: must be between 0 and 10000")
	}
	sc.MaxSlippageBps = slippageBps

	sc.MinProfitThreshold = getEnvWithDefault("MIN_PROFIT_THRESHOLD_ETH", "0.01")
	if amountErr := validateETHAmount(sc.MinProfitThreshold); amountErr != nil {
		return fmt.Errorf("invalid MIN_PROFIT_THRESHOLD_ETH: %w", amountErr)
	}

	return nil
}
// loadRateLimits reads the request-rate settings from the environment:
// MAX_REQUESTS_PER_SECOND (default 100, range 1..10000) and
// RATE_LIMIT_BURST_SIZE (default 200, range 1..20000).
func (sc *SecureConfig) loadRateLimits() error {
	// positiveUpTo parses key (or its fallback) as an int in [1, upper].
	positiveUpTo := func(key, fallback string, upper int) (int, error) {
		n, err := strconv.Atoi(getEnvWithDefault(key, fallback))
		if err != nil || n < 1 || n > upper {
			return 0, fmt.Errorf("invalid %s: must be between 1 and %d", key, upper)
		}
		return n, nil
	}

	rps, err := positiveUpTo("MAX_REQUESTS_PER_SECOND", "100", 10000)
	if err != nil {
		return err
	}
	sc.MaxRequestsPerSecond = rps

	burst, err := positiveUpTo("RATE_LIMIT_BURST_SIZE", "200", 20000)
	if err != nil {
		return err
	}
	sc.BurstSize = burst

	return nil
}
// loadTimeouts reads the three timeouts from the environment. Each value is a
// whole number of seconds and must be at least 1:
//
//	RPC_TIMEOUT_SECONDS          default 30,  max 300
//	WEBSOCKET_TIMEOUT_SECONDS    default 60,  max 600
//	TRANSACTION_TIMEOUT_SECONDS  default 300, max 3600
//
// They are processed in that order, and each is stored as soon as it is valid.
func (sc *SecureConfig) loadTimeouts() error {
	timeouts := []struct {
		env      string
		fallback string
		maxSecs  int
		dst      *time.Duration
	}{
		{"RPC_TIMEOUT_SECONDS", "30", 300, &sc.RPCTimeout},
		{"WEBSOCKET_TIMEOUT_SECONDS", "60", 600, &sc.WebSocketTimeout},
		{"TRANSACTION_TIMEOUT_SECONDS", "300", 3600, &sc.TransactionTimeout},
	}
	for _, t := range timeouts {
		secs, err := strconv.Atoi(getEnvWithDefault(t.env, t.fallback))
		if err != nil || secs < 1 || secs > t.maxSecs {
			return fmt.Errorf("invalid %s: must be between 1 and %d", t.env, t.maxSecs)
		}
		*t.dst = time.Duration(secs) * time.Second
	}
	return nil
}
// GetPrimaryRPCEndpoint returns the first configured primary RPC endpoint, or
// "" when none are configured. No health check is performed; the "primary"
// endpoint is simply the first entry of RPCEndpoints.
func (sc *SecureConfig) GetPrimaryRPCEndpoint() string {
	if len(sc.RPCEndpoints) > 0 {
		return sc.RPCEndpoints[0]
	}
	return ""
}
// GetAllRPCEndpoints returns all configured RPC endpoints: the primary
// endpoints followed by the backup endpoints, in configuration order. It
// returns nil when neither list has entries.
//
// The result is always a freshly allocated slice, so callers may modify or
// append to it without touching sc.RPCEndpoints or sc.BackupRPCs. A bare
// append(sc.RPCEndpoints, sc.BackupRPCs...) would return sc.RPCEndpoints
// itself when there are no backups. It would also write into that slice's
// backing array whenever the array had spare capacity.
func (sc *SecureConfig) GetAllRPCEndpoints() []string {
	total := len(sc.RPCEndpoints) + len(sc.BackupRPCs)
	if total == 0 {
		return nil
	}
	all := make([]string, 0, total)
	all = append(all, sc.RPCEndpoints...)
	return append(all, sc.BackupRPCs...)
}
// Encrypt seals plaintext with AES-GCM under the configured 32-byte key
// (AES-256). A fresh random nonce is drawn from crypto/rand for every call.
// The result is the standard-base64 encoding of the nonce followed by the GCM
// output, which is the layout Decrypt expects.
func (sc *SecureConfig) Encrypt(plaintext string) (string, error) {
	if sc.encryptionKey == nil {
		return "", fmt.Errorf("encryption key not initialized")
	}

	blockCipher, err := aes.NewCipher(sc.encryptionKey)
	if err != nil {
		return "", fmt.Errorf("failed to create cipher: %w", err)
	}
	aead, err := cipher.NewGCM(blockCipher)
	if err != nil {
		return "", fmt.Errorf("failed to create GCM: %w", err)
	}

	iv := make([]byte, aead.NonceSize())
	if _, err = io.ReadFull(rand.Reader, iv); err != nil {
		return "", fmt.Errorf("failed to generate nonce: %w", err)
	}

	// Seal appends to its first argument, so the nonce ends up as the prefix.
	sealed := aead.Seal(iv, iv, []byte(plaintext), nil)
	return base64.StdEncoding.EncodeToString(sealed), nil
}
// Decrypt reverses Encrypt. It base64-decodes the input, splits off the
// leading nonce (of the GCM nonce size) and opens the remainder with the
// configured key. Errors are returned for:
//   - a missing key;
//   - bad base64;
//   - input shorter than a nonce;
//   - authentication failure (wrong key or tampered data).
func (sc *SecureConfig) Decrypt(ciphertext string) (string, error) {
	if sc.encryptionKey == nil {
		return "", fmt.Errorf("encryption key not initialized")
	}

	raw, err := base64.StdEncoding.DecodeString(ciphertext)
	if err != nil {
		return "", fmt.Errorf("failed to decode ciphertext: %w", err)
	}

	blockCipher, err := aes.NewCipher(sc.encryptionKey)
	if err != nil {
		return "", fmt.Errorf("failed to create cipher: %w", err)
	}
	aead, err := cipher.NewGCM(blockCipher)
	if err != nil {
		return "", fmt.Errorf("failed to create GCM: %w", err)
	}

	ivLen := aead.NonceSize()
	if len(raw) < ivLen {
		return "", fmt.Errorf("ciphertext too short")
	}

	opened, err := aead.Open(nil, raw[:ivLen], raw[ivLen:], nil)
	if err != nil {
		return "", fmt.Errorf("failed to decrypt: %w", err)
	}
	return string(opened), nil
}
// GenerateEncryptionKey returns 32 bytes (256 bits) from crypto/rand, encoded
// with standard base64. The result is suitable as MEV_BOT_ENCRYPTION_KEY.
func GenerateEncryptionKey() (string, error) {
	var secret [32]byte
	if _, err := rand.Read(secret[:]); err != nil {
		return "", fmt.Errorf("failed to generate encryption key: %w", err)
	}
	return base64.StdEncoding.EncodeToString(secret[:]), nil
}
// validateEndpoint validates an RPC endpoint URL. It rejects, in order:
//   - an empty string;
//   - anything not starting with "https://" or "wss://";
//   - strings that do not parse as a URL, or that have no host (e.g.
//     "https://"), which the prefix check alone would accept;
//   - URLs containing, case-insensitively and anywhere (including path and
//     query), "localhost", "127.0.0.1", "demo", "test" or "example".
//
// Error messages never include the endpoint itself, because endpoint URLs often
// embed provider API keys. For the same reason the url.Parse error, whose text
// quotes the input, is deliberately not wrapped.
func validateEndpoint(endpoint string) error {
	if endpoint == "" {
		return fmt.Errorf("endpoint cannot be empty")
	}

	// Check for required protocols
	if !strings.HasPrefix(endpoint, "https://") && !strings.HasPrefix(endpoint, "wss://") {
		return fmt.Errorf("endpoint must use HTTPS or WSS protocol")
	}

	parsed, err := url.Parse(endpoint)
	if err != nil {
		return fmt.Errorf("endpoint is not a valid URL")
	}
	if parsed.Hostname() == "" {
		return fmt.Errorf("endpoint must include a host")
	}

	// Check for suspicious patterns that might indicate local, placeholder or
	// hardcoded endpoints.
	suspiciousPatterns := []string{
		"localhost",
		"127.0.0.1",
		"demo",
		"test",
		"example",
	}
	lowerEndpoint := strings.ToLower(endpoint)
	for _, pattern := range suspiciousPatterns {
		if strings.Contains(lowerEndpoint, pattern) {
			return fmt.Errorf("endpoint contains suspicious pattern: %s", pattern)
		}
	}
	return nil
}
// validateWebSocketEndpoint accepts only "wss://" URLs. Those must also pass
// every general check in validateEndpoint.
func validateWebSocketEndpoint(endpoint string) error {
	if strings.HasPrefix(endpoint, "wss://") {
		return validateEndpoint(endpoint)
	}
	return fmt.Errorf("WebSocket endpoint must use WSS protocol")
}
// ethAmountPattern matches a non-negative decimal amount. It accepts digits
// with an optional fractional part ("1", "1.", "1.5") or a bare fraction
// (".5"). Signs, exponents, whitespace and thousands separators are not
// accepted. It is compiled once at package initialisation instead of on every
// validation.
var ethAmountPattern = regexp.MustCompile(`^(\d+\.?\d*|\.\d+)$`)

// validateETHAmount reports whether amount is a syntactically valid ETH
// amount string (see ethAmountPattern). It checks format only, not range.
func validateETHAmount(amount string) error {
	if !ethAmountPattern.MatchString(amount) {
		return fmt.Errorf("invalid ETH amount format")
	}
	return nil
}
// getEnvWithDefault returns the value of the environment variable key. It
// returns defaultValue instead when the variable is unset or set to the empty
// string.
func getEnvWithDefault(key, defaultValue string) string {
	value := os.Getenv(key)
	if value == "" {
		return defaultValue
	}
	return value
}
// CreateConfigHash returns the hex-encoded SHA-256 digest of the configuration,
// for integrity checking. Any change to a setting changes the digest.
//
// Every value is written as a length-prefixed record "<len>:<value>;", and
// each list is preceded by its element count. Adjacent values therefore cannot
// run together. Under the previous plain concatenation, a gas price of 10 with
// a max transaction value of "05" hashed the same as 100 with "5".
//
// All endpoint lists, limits, rate limits and timeouts are covered. The
// encryption key is deliberately excluded, because the digest is exposed via
// SecurityProfile.
//
// NOTE: digests differ from those produced by earlier versions of this
// function; any stored hashes must be regenerated.
func (sc *SecureConfig) CreateConfigHash() string {
	hasher := sha256.New()
	writeField := func(value string) {
		// Writes to a hash.Hash never return an error.
		fmt.Fprintf(hasher, "%d:%s;", len(value), value)
	}
	writeList := func(values []string) {
		writeField(strconv.Itoa(len(values)))
		for _, v := range values {
			writeField(v)
		}
	}

	writeList(sc.RPCEndpoints)
	writeList(sc.WSEndpoints)
	writeList(sc.BackupRPCs)
	writeField(strconv.FormatInt(sc.MaxGasPriceGwei, 10))
	writeField(sc.MaxTransactionValue)
	writeField(strconv.FormatUint(sc.MaxSlippageBps, 10))
	writeField(sc.MinProfitThreshold)
	writeField(strconv.Itoa(sc.MaxRequestsPerSecond))
	writeField(strconv.Itoa(sc.BurstSize))
	writeField(sc.RPCTimeout.String())
	writeField(sc.WebSocketTimeout.String())
	writeField(sc.TransactionTimeout.String())

	return fmt.Sprintf("%x", hasher.Sum(nil))
}
// SecurityProfile returns a summary of the current security configuration. It
// contains the limits, both rate-limit settings, the timeouts (as
// Duration.String values), the number of configured endpoints of each kind,
// and the config hash.
//
// Endpoint URLs and the encryption key are intentionally omitted; only
// endpoint counts are reported. "burst_size" and "ws_endpoints_count" were
// added so that every loaded setting is represented.
func (sc *SecureConfig) SecurityProfile() map[string]interface{} {
	return map[string]interface{}{
		"max_gas_price_gwei":      sc.MaxGasPriceGwei,
		"max_transaction_value":   sc.MaxTransactionValue,
		"max_slippage_bps":        sc.MaxSlippageBps,
		"min_profit_threshold":    sc.MinProfitThreshold,
		"max_requests_per_second": sc.MaxRequestsPerSecond,
		"burst_size":              sc.BurstSize,
		"rpc_timeout":             sc.RPCTimeout.String(),
		"websocket_timeout":       sc.WebSocketTimeout.String(),
		"transaction_timeout":     sc.TransactionTimeout.String(),
		"rpc_endpoints_count":     len(sc.RPCEndpoints),
		"ws_endpoints_count":      len(sc.WSEndpoints),
		"backup_rpcs_count":       len(sc.BackupRPCs),
		"config_hash":             sc.CreateConfigHash(),
	}
}