saving in place

This commit is contained in:
Krypto Kajun
2025-10-04 09:31:02 -05:00
parent 76c1b5cee1
commit f358f49aa9
295 changed files with 72071 additions and 17209 deletions

View File

@@ -30,11 +30,11 @@ type SecureConfig struct {
// Rate limiting
MaxRequestsPerSecond int
BurstSize int
BurstSize int
// Timeouts
RPCTimeout time.Duration
WebSocketTimeout time.Duration
RPCTimeout time.Duration
WebSocketTimeout time.Duration
TransactionTimeout time.Duration
// Encryption
@@ -53,13 +53,13 @@ type SecurityLimits struct {
// EndpointConfig stores RPC endpoint configuration securely
type EndpointConfig struct {
URL string
Priority int
Timeout time.Duration
MaxConnections int
HealthCheckURL string
RequiresAuth bool
AuthToken string // Encrypted when stored
URL string
Priority int
Timeout time.Duration
MaxConnections int
HealthCheckURL string
RequiresAuth bool
AuthToken string // Encrypted when stored
}
// NewSecureConfig creates a new secure configuration from environment
@@ -390,14 +390,14 @@ func (sc *SecureConfig) SecurityProfile() map[string]interface{} {
return map[string]interface{}{
"max_gas_price_gwei": sc.MaxGasPriceGwei,
"max_transaction_value": sc.MaxTransactionValue,
"max_slippage_bps": sc.MaxSlippageBps,
"min_profit_threshold": sc.MinProfitThreshold,
"max_slippage_bps": sc.MaxSlippageBps,
"min_profit_threshold": sc.MinProfitThreshold,
"max_requests_per_second": sc.MaxRequestsPerSecond,
"rpc_timeout": sc.RPCTimeout.String(),
"websocket_timeout": sc.WebSocketTimeout.String(),
"transaction_timeout": sc.TransactionTimeout.String(),
"rpc_endpoints_count": len(sc.RPCEndpoints),
"backup_rpcs_count": len(sc.BackupRPCs),
"config_hash": sc.CreateConfigHash(),
"rpc_timeout": sc.RPCTimeout.String(),
"websocket_timeout": sc.WebSocketTimeout.String(),
"transaction_timeout": sc.TransactionTimeout.String(),
"rpc_endpoints_count": len(sc.RPCEndpoints),
"backup_rpcs_count": len(sc.BackupRPCs),
"config_hash": sc.CreateConfigHash(),
}
}
}

View File

@@ -0,0 +1,563 @@
package security
import (
	"context"
	"encoding/hex"
	"fmt"
	"math/big"
	"strings"
	"sync"
	"time"

	"github.com/ethereum/go-ethereum/common"
	"github.com/ethereum/go-ethereum/core/types"
	"github.com/ethereum/go-ethereum/crypto"
	"github.com/ethereum/go-ethereum/ethclient"
	"github.com/fraktal/mev-beta/internal/logger"
)
// ContractInfo represents information about a verified contract
type ContractInfo struct {
Address common.Address `json:"address"`
BytecodeHash string `json:"bytecode_hash"` // Keccak-256 hash of the deployed bytecode, hex-encoded
Name string `json:"name"`
Version string `json:"version"`
DeployedAt *big.Int `json:"deployed_at"` // deployment block number; 0 when deployment info is unavailable
Deployer common.Address `json:"deployer"` // zero address when deployment info is unavailable
VerifiedAt time.Time `json:"verified_at"` // also serves as the cache-entry timestamp for expiry checks
IsWhitelisted bool `json:"is_whitelisted"`
RiskLevel RiskLevel `json:"risk_level"`
Permissions ContractPermissions `json:"permissions"`
ABIHash string `json:"abi_hash,omitempty"`
SourceCodeHash string `json:"source_code_hash,omitempty"`
}
// ContractPermissions defines what operations are allowed with a contract
type ContractPermissions struct {
CanInteract bool `json:"can_interact"`
CanSendValue bool `json:"can_send_value"` // when false, transactions carrying value are rejected
MaxValueWei *big.Int `json:"max_value_wei,omitempty"` // per-transaction value cap; nil means no cap
AllowedMethods []string `json:"allowed_methods,omitempty"`
RequireConfirm bool `json:"require_confirmation"`
DailyLimit *big.Int `json:"daily_limit,omitempty"` // cumulative 24h value cap enforced by checkDailyLimit; nil means no cap
}
// RiskLevel represents the risk assessment of a contract
type RiskLevel int

const (
	RiskLevelLow RiskLevel = iota
	RiskLevelMedium
	RiskLevelHigh
	RiskLevelCritical
	RiskLevelBlocked
)

// String returns the human-readable name of the risk level; values outside
// the defined range render as "Unknown".
func (r RiskLevel) String() string {
	names := [...]string{
		RiskLevelLow:      "Low",
		RiskLevelMedium:   "Medium",
		RiskLevelHigh:     "High",
		RiskLevelCritical: "Critical",
		RiskLevelBlocked:  "Blocked",
	}
	if r < 0 || int(r) >= len(names) {
		return "Unknown"
	}
	return names[r]
}
// ContractValidationResult contains the result of contract validation
type ContractValidationResult struct {
IsValid bool `json:"is_valid"`
ContractInfo *ContractInfo `json:"contract_info"`
ValidationError string `json:"validation_error,omitempty"`
Warnings []string `json:"warnings"` // non-fatal findings; each warning also adds one point to the risk score
ChecksPerformed []ValidationCheck `json:"checks_performed"`
RiskScore int `json:"risk_score"` // 1-10
}
// ValidationCheck represents a single validation check
type ValidationCheck struct {
Name string `json:"name"`
Passed bool `json:"passed"`
Description string `json:"description"`
Error string `json:"error,omitempty"` // populated only when Passed is false
Timestamp time.Time `json:"timestamp"`
}
// ContractValidator provides secure contract validation and verification
// of on-chain contracts before transactions are sent to them.
type ContractValidator struct {
client *ethclient.Client
logger *logger.Logger
trustedContracts map[common.Address]*ContractInfo // operator-approved whitelist; guarded by cacheMutex
contractCache map[common.Address]*ContractInfo // validation results; entries older than config.CacheTimeout are revalidated
cacheMutex sync.RWMutex // guards trustedContracts and contractCache
config *ContractValidatorConfig
// Security tracking
interactionCounts map[common.Address]int64 // NOTE(review): initialized but never read or written in this file — confirm intended use
dailyLimits map[common.Address]*big.Int // cumulative value sent per contract in the current 24h window
lastResetTime time.Time // start of the current daily-limit window
limitsMutex sync.RWMutex // guards interactionCounts, dailyLimits, lastResetTime
}
// ContractValidatorConfig provides configuration for the contract validator
type ContractValidatorConfig struct {
EnableBytecodeVerification bool `json:"enable_bytecode_verification"`
EnableABIValidation bool `json:"enable_abi_validation"`
RequireWhitelist bool `json:"require_whitelist"` // when true, non-whitelisted contracts fail validation outright
MaxBytecodeSize int `json:"max_bytecode_size"` // bytes; exceeding it only produces a warning, not a failure
CacheTimeout time.Duration `json:"cache_timeout"`
MaxRiskScore int `json:"max_risk_score"` // validation fails when the computed 1-10 score exceeds this
BlockUnverifiedContracts bool `json:"block_unverified_contracts"`
RequireSourceCode bool `json:"require_source_code"`
EnableRealTimeValidation bool `json:"enable_realtime_validation"`
}
// NewContractValidator creates a new contract validator.
// A nil config falls back to getDefaultValidatorConfig().
func NewContractValidator(client *ethclient.Client, logger *logger.Logger, config *ContractValidatorConfig) *ContractValidator {
	cfg := config
	if cfg == nil {
		cfg = getDefaultValidatorConfig()
	}
	return &ContractValidator{
		client:            client,
		logger:            logger,
		config:            cfg,
		lastResetTime:     time.Now(),
		trustedContracts:  map[common.Address]*ContractInfo{},
		contractCache:     map[common.Address]*ContractInfo{},
		interactionCounts: map[common.Address]int64{},
		dailyLimits:       map[common.Address]*big.Int{},
	}
}
// AddTrustedContract adds a contract to the trusted list
// The contract is marked whitelisted, stamped with the current time, and
// inserted into both the trusted map and the validation cache so subsequent
// ValidateContract calls short-circuit. Returns an error when the address
// is the zero address or the bytecode hash is empty.
func (cv *ContractValidator) AddTrustedContract(info *ContractInfo) error {
cv.cacheMutex.Lock()
defer cv.cacheMutex.Unlock()
// Validate the contract info
if info.Address == (common.Address{}) {
return fmt.Errorf("invalid contract address")
}
if info.BytecodeHash == "" {
return fmt.Errorf("bytecode hash is required")
}
// Mark as whitelisted and set low risk
info.IsWhitelisted = true
if info.RiskLevel == 0 {
// NOTE: RiskLevelLow is itself 0, so an unset level and an explicit
// Low are indistinguishable here; this branch only documents intent.
info.RiskLevel = RiskLevelLow
}
info.VerifiedAt = time.Now()
cv.trustedContracts[info.Address] = info
cv.contractCache[info.Address] = info
cv.logger.Info(fmt.Sprintf("Added trusted contract: %s (%s)", info.Address.Hex(), info.Name))
return nil
}
// ValidateContract performs comprehensive contract validation
// Check order: trusted whitelist first, then the validation cache (entries
// younger than config.CacheTimeout), then a full on-chain validation whose
// result is cached. Returns a non-nil error whenever validation fails; the
// returned result carries per-check details and warnings either way.
func (cv *ContractValidator) ValidateContract(ctx context.Context, address common.Address) (*ContractValidationResult, error) {
result := &ContractValidationResult{
IsValid: false,
Warnings: make([]string, 0),
ChecksPerformed: make([]ValidationCheck, 0),
}
// Check if contract is in trusted list first
cv.cacheMutex.RLock()
if trusted, exists := cv.trustedContracts[address]; exists {
// Hit: release the read lock before returning.
cv.cacheMutex.RUnlock()
result.IsValid = true
result.ContractInfo = trusted
result.ChecksPerformed = append(result.ChecksPerformed, ValidationCheck{
Name: "Trusted Contract Check",
Passed: true,
Description: "Contract found in trusted whitelist",
Timestamp: time.Now(),
})
return result, nil
}
// Check cache
// (still holding the read lock taken above)
if cached, exists := cv.contractCache[address]; exists {
if time.Since(cached.VerifiedAt) < cv.config.CacheTimeout {
cv.cacheMutex.RUnlock()
result.IsValid = true
result.ContractInfo = cached
result.ChecksPerformed = append(result.ChecksPerformed, ValidationCheck{
Name: "Cache Check",
Passed: true,
Description: "Contract found in validation cache",
Timestamp: time.Now(),
})
return result, nil
}
// Stale entry: fall through to revalidation; the entry is overwritten below.
}
// Release the read lock before the slow on-chain path; every early return
// above has already released it.
cv.cacheMutex.RUnlock()
// Perform real-time validation
contractInfo, err := cv.validateContractOnChain(ctx, address, result)
if err != nil {
result.ValidationError = err.Error()
return result, err
}
result.ContractInfo = contractInfo
result.RiskScore = cv.calculateRiskScore(contractInfo, result)
// Check if contract meets security requirements
if cv.config.RequireWhitelist && !contractInfo.IsWhitelisted {
result.ValidationError = "Contract not whitelisted"
result.ChecksPerformed = append(result.ChecksPerformed, ValidationCheck{
Name: "Whitelist Check",
Passed: false,
Description: "Contract not found in whitelist",
Error: "Contract not whitelisted",
Timestamp: time.Now(),
})
return result, fmt.Errorf("contract not whitelisted: %s", address.Hex())
}
if result.RiskScore > cv.config.MaxRiskScore {
result.ValidationError = fmt.Sprintf("Risk score too high: %d > %d", result.RiskScore, cv.config.MaxRiskScore)
return result, fmt.Errorf("contract risk score too high: %d", result.RiskScore)
}
// Cache the validation result
cv.cacheMutex.Lock()
cv.contractCache[address] = contractInfo
cv.cacheMutex.Unlock()
result.IsValid = true
return result, nil
}
// validateContractOnChain performs on-chain validation of a contract
// Steps: fetch bytecode (nil block number = latest state), reject empty
// bytecode (EOA or non-existent), hash the bytecode, best-effort deployment
// lookup, heuristic risk assessment, and optional known-contract matching.
// Every step appends a ValidationCheck to result for auditability.
func (cv *ContractValidator) validateContractOnChain(ctx context.Context, address common.Address, result *ContractValidationResult) (*ContractInfo, error) {
// Check if address is a contract
bytecode, err := cv.client.CodeAt(ctx, address, nil)
if err != nil {
result.ChecksPerformed = append(result.ChecksPerformed, ValidationCheck{
Name: "Bytecode Retrieval",
Passed: false,
Description: "Failed to retrieve contract bytecode",
Error: err.Error(),
Timestamp: time.Now(),
})
return nil, fmt.Errorf("failed to get contract bytecode: %w", err)
}
if len(bytecode) == 0 {
result.ChecksPerformed = append(result.ChecksPerformed, ValidationCheck{
Name: "Contract Existence",
Passed: false,
Description: "Address is not a contract (no bytecode)",
Error: "No bytecode found",
Timestamp: time.Now(),
})
return nil, fmt.Errorf("address is not a contract: %s", address.Hex())
}
result.ChecksPerformed = append(result.ChecksPerformed, ValidationCheck{
Name: "Contract Existence",
Passed: true,
Description: fmt.Sprintf("Contract bytecode found (%d bytes)", len(bytecode)),
Timestamp: time.Now(),
})
// Validate bytecode size
// Oversize bytecode is a warning only, never a hard failure.
if cv.config.MaxBytecodeSize > 0 && len(bytecode) > cv.config.MaxBytecodeSize {
result.Warnings = append(result.Warnings, fmt.Sprintf("Large bytecode size: %d bytes", len(bytecode)))
}
// Create bytecode hash
bytecodeHash := crypto.Keccak256Hash(bytecode).Hex()
// Get deployment transaction info
// Failure here is non-fatal: zero values are substituted and a warning logged.
deployedAt, deployer, err := cv.getDeploymentInfo(ctx, address)
if err != nil {
cv.logger.Warn(fmt.Sprintf("Could not retrieve deployment info for %s: %v", address.Hex(), err))
deployedAt = big.NewInt(0)
deployer = common.Address{}
}
// Create contract info
contractInfo := &ContractInfo{
Address: address,
BytecodeHash: bytecodeHash,
Name: "Unknown Contract",
Version: "unknown",
DeployedAt: deployedAt,
Deployer: deployer,
VerifiedAt: time.Now(),
IsWhitelisted: false,
RiskLevel: cv.assessRiskLevel(bytecode, result),
Permissions: cv.getDefaultPermissions(),
}
// Verify bytecode against known contracts if enabled
// May upgrade the contract to whitelisted/low-risk when it matches a known address.
if cv.config.EnableBytecodeVerification {
cv.verifyBytecodeSignature(bytecode, contractInfo, result)
}
result.ChecksPerformed = append(result.ChecksPerformed, ValidationCheck{
Name: "Bytecode Validation",
Passed: true,
Description: "Bytecode hash calculated and verified",
Timestamp: time.Now(),
})
return contractInfo, nil
}
// getDeploymentInfo retrieves deployment information for a contract
// (deployment block number and deployer address).
// This is a simplified implementation
// In production, you would need to scan blocks or use an indexer
// It currently always returns zero values plus an error; the caller treats
// the error as non-fatal and substitutes the zero values.
func (cv *ContractValidator) getDeploymentInfo(ctx context.Context, address common.Address) (*big.Int, common.Address, error) {
return big.NewInt(0), common.Address{}, fmt.Errorf("deployment info not available")
}
// assessRiskLevel assesses the risk level of a contract based on its bytecode.
// It counts coarse risk factors — presence of dangerous opcode bytes and
// unusually large bytecode — and maps the count onto a RiskLevel. Matching
// raw byte values in the hex dump can also hit PUSH data, so this is a
// heuristic, not a disassembly.
func (cv *ContractValidator) assessRiskLevel(bytecode []byte, result *ContractValidationResult) RiskLevel {
	riskFactors := 0

	// Hex-encode so opcode bytes can be matched as substrings.
	bytecodeStr := hex.EncodeToString(bytecode)

	// Look for dangerous opcodes
	dangerousOpcodes := []string{
		"ff", // SELFDESTRUCT
		"f4", // DELEGATECALL
		"3d", // RETURNDATASIZE (often used in proxy patterns)
	}
	for _, opcode := range dangerousOpcodes {
		// BUG FIX: the previous inline closure only compared the prefix and
		// suffix of the hex string, so opcodes anywhere in the middle of the
		// bytecode were never detected. Use a real substring search.
		if strings.Contains(bytecodeStr, opcode) {
			riskFactors++
		}
	}

	// Check bytecode size (larger contracts may be more complex/risky)
	if len(bytecode) > 20000 { // 20KB
		riskFactors++
		result.Warnings = append(result.Warnings, "Large contract size detected")
	}

	// Assess risk level based on factors
	switch {
	case riskFactors == 0:
		return RiskLevelLow
	case riskFactors <= 2:
		return RiskLevelMedium
	case riskFactors <= 4:
		return RiskLevelHigh
	default:
		return RiskLevelCritical
	}
}
// verifyBytecodeSignature verifies bytecode against known contract signatures.
// Currently matches by deployment address rather than by bytecode hash; a
// match upgrades the contract to whitelisted with RiskLevelLow and records
// a passed check. The bytecode parameter is reserved for future hash-based
// verification.
func (cv *ContractValidator) verifyBytecodeSignature(bytecode []byte, info *ContractInfo, result *ContractValidationResult) {
	// Known contract addresses (lowercase hex) for common contracts.
	knownContracts := map[string]string{
		// Uniswap V3 Factory
		"0x1f98431c8ad98523631ae4a59f267346ea31f984": "uniswap_v3_factory",
		// Uniswap V3 Router
		"0xe592427a0aece92de3edee1f18e0157c05861564": "uniswap_v3_router",
		// Add more known contracts...
	}
	// BUG FIX: common.Address.Hex() returns an EIP-55 checksummed
	// (mixed-case) string while the map keys above are all lowercase, so the
	// lookup could never match. Normalize to lowercase before comparing.
	addressStr := strings.ToLower(info.Address.Hex())
	if name, exists := knownContracts[addressStr]; exists {
		info.Name = name
		info.IsWhitelisted = true
		info.RiskLevel = RiskLevelLow
		result.ChecksPerformed = append(result.ChecksPerformed, ValidationCheck{
			Name:        "Known Contract Verification",
			Passed:      true,
			Description: fmt.Sprintf("Verified as known contract: %s", name),
			Timestamp:   time.Now(),
		})
	}
}
// calculateRiskScore calculates a numerical risk score (1-10).
// The base score of 1 is raised by the assessed risk level, a penalty for
// non-whitelisted contracts, and one point per warning, then capped at 10.
// A blocked contract always scores the maximum.
func (cv *ContractValidator) calculateRiskScore(info *ContractInfo, result *ContractValidationResult) int {
	// Blocked contracts are pinned to the cap regardless of other factors.
	if info.RiskLevel == RiskLevelBlocked {
		return 10
	}
	score := 1 // base score
	switch info.RiskLevel {
	case RiskLevelMedium:
		score += 2
	case RiskLevelHigh:
		score += 5
	case RiskLevelCritical:
		score += 8
	}
	// Non-whitelisted contracts carry an extra penalty.
	if !info.IsWhitelisted {
		score += 2
	}
	// Every validation warning adds a point.
	score += len(result.Warnings)
	if score > 10 {
		score = 10
	}
	return score
}
// getDefaultPermissions returns default permissions for unverified contracts
// Defaults are conservative: interaction allowed, value transfers disabled
// with a zero per-transaction cap, confirmation required, and a 1 ETH daily
// cumulative cap.
func (cv *ContractValidator) getDefaultPermissions() ContractPermissions {
return ContractPermissions{
CanInteract: true,
CanSendValue: false,
MaxValueWei: big.NewInt(0),
// NOTE(review): method-level filtering is not enforced anywhere in this
// file — confirm "empty means all allowed" is the intended semantics.
AllowedMethods: []string{}, // Empty means all methods allowed
RequireConfirm: true,
DailyLimit: big.NewInt(1000000000000000000), // 1 ETH
}
}
// ValidateTransaction validates a transaction against contract permissions
// Contract-creation transactions (nil To) are allowed unconditionally.
// Checks run in order: contract validation, value-transfer permission,
// per-transaction value cap, then the rolling daily value limit. The first
// failing check aborts with a descriptive error.
func (cv *ContractValidator) ValidateTransaction(ctx context.Context, tx *types.Transaction) error {
if tx.To() == nil {
return nil // Contract creation, allow
}
// Validate the contract
result, err := cv.ValidateContract(ctx, *tx.To())
if err != nil {
return fmt.Errorf("contract validation failed: %w", err)
}
if !result.IsValid {
return fmt.Errorf("transaction to invalid contract: %s", tx.To().Hex())
}
// Check permissions
permissions := result.ContractInfo.Permissions
// Check value transfer permission
if tx.Value().Sign() > 0 && !permissions.CanSendValue {
return fmt.Errorf("contract does not allow value transfers: %s", tx.To().Hex())
}
// Check value limits
if permissions.MaxValueWei != nil && tx.Value().Cmp(permissions.MaxValueWei) > 0 {
return fmt.Errorf("transaction value exceeds limit: %s > %s",
tx.Value().String(), permissions.MaxValueWei.String())
}
// Check daily limits
if err := cv.checkDailyLimit(*tx.To(), tx.Value()); err != nil {
return err
}
cv.logger.Debug(fmt.Sprintf("Transaction validated for contract %s", tx.To().Hex()))
return nil
}
// checkDailyLimit checks if transaction exceeds daily interaction limit
// Accumulated per-contract value is tracked in cv.dailyLimits and reset 24h
// after lastResetTime. No limit is enforced when the contract is absent from
// the cache or has no DailyLimit configured. On success the accumulated
// usage is updated; on failure it is left unchanged.
// NOTE: limitsMutex is held across the cacheMutex.RLock below — keep this
// ordering (limits before cache) consistent to avoid deadlocks.
func (cv *ContractValidator) checkDailyLimit(contractAddr common.Address, value *big.Int) error {
cv.limitsMutex.Lock()
defer cv.limitsMutex.Unlock()
// Reset daily counters if needed
if time.Since(cv.lastResetTime) > 24*time.Hour {
cv.dailyLimits = make(map[common.Address]*big.Int)
cv.lastResetTime = time.Now()
}
// Get current daily usage
currentUsage, exists := cv.dailyLimits[contractAddr]
if !exists {
currentUsage = big.NewInt(0)
cv.dailyLimits[contractAddr] = currentUsage
}
// Get contract info for daily limit
cv.cacheMutex.RLock()
contractInfo, exists := cv.contractCache[contractAddr]
cv.cacheMutex.RUnlock()
if !exists {
return nil // No limit if contract not cached
}
if contractInfo.Permissions.DailyLimit == nil {
return nil // No daily limit set
}
// Check if adding this transaction would exceed limit
newUsage := new(big.Int).Add(currentUsage, value)
if newUsage.Cmp(contractInfo.Permissions.DailyLimit) > 0 {
return fmt.Errorf("daily limit exceeded for contract %s: %s + %s > %s",
contractAddr.Hex(),
currentUsage.String(),
value.String(),
contractInfo.Permissions.DailyLimit.String())
}
// Update usage
cv.dailyLimits[contractAddr] = newUsage
return nil
}
// getDefaultValidatorConfig returns the default configuration used when
// NewContractValidator receives a nil config.
func getDefaultValidatorConfig() *ContractValidatorConfig {
	return &ContractValidatorConfig{
		// Verification toggles: bytecode checks on; ABI/source checks off
		// until the supporting infrastructure exists.
		EnableBytecodeVerification: true,
		EnableABIValidation:        false,
		RequireSourceCode:          false,
		EnableRealTimeValidation:   true,

		// Policy: start permissive — no whitelist requirement, unverified
		// contracts allowed, medium-high risk tolerated. Can be tightened.
		RequireWhitelist:         false,
		BlockUnverifiedContracts: false,
		MaxRiskScore:             7,

		// Resource bounds.
		MaxBytecodeSize: 50000, // 50KB
		CacheTimeout:    time.Hour,
	}
}
// GetContractInfo returns information about a validated contract, along with
// whether the address was present in the validation cache.
func (cv *ContractValidator) GetContractInfo(address common.Address) (*ContractInfo, bool) {
	cv.cacheMutex.RLock()
	defer cv.cacheMutex.RUnlock()
	info, exists := cv.contractCache[address]
	return info, exists
}
// ListTrustedContracts returns all trusted contracts as a freshly allocated
// map so callers cannot race on the validator's internal map. The *ContractInfo
// values themselves are shared, not deep-copied.
func (cv *ContractValidator) ListTrustedContracts() map[common.Address]*ContractInfo {
	cv.cacheMutex.RLock()
	defer cv.cacheMutex.RUnlock()

	out := make(map[common.Address]*ContractInfo, len(cv.trustedContracts))
	for addr, info := range cv.trustedContracts {
		out[addr] = info
	}
	return out
}

View File

@@ -0,0 +1,348 @@
package security
import (
"context"
"fmt"
"regexp"
"runtime"
"strings"
"time"
"github.com/fraktal/mev-beta/internal/logger"
)
// SecureError represents a security-aware error with context
// It wraps an underlying error with a stable code, category, severity,
// sanitized context, and an optional captured call stack.
type SecureError struct {
Code string `json:"code"`
Message string `json:"message"`
Timestamp time.Time `json:"timestamp"`
Context map[string]interface{} `json:"context,omitempty"` // sensitive keys are stored as "[REDACTED]"
Stack []StackFrame `json:"stack,omitempty"` // populated only when the handler has stack traces enabled
Wrapped error `json:"-"` // underlying cause; excluded from JSON serialization
Sensitive bool `json:"sensitive"` // true when the message, cause, or context appears to contain secrets
Category ErrorCategory `json:"category"`
Severity ErrorSeverity `json:"severity"`
}
// StackFrame represents a single frame in the call stack
type StackFrame struct {
Function string `json:"function"`
File string `json:"file"`
Line int `json:"line"`
}
// ErrorCategory defines categories of errors
type ErrorCategory string
const (
ErrorCategoryAuthentication ErrorCategory = "authentication"
ErrorCategoryAuthorization ErrorCategory = "authorization"
ErrorCategoryValidation ErrorCategory = "validation"
ErrorCategoryRateLimit ErrorCategory = "rate_limit"
ErrorCategoryCircuitBreaker ErrorCategory = "circuit_breaker"
ErrorCategoryEncryption ErrorCategory = "encryption"
ErrorCategoryNetwork ErrorCategory = "network"
ErrorCategoryTransaction ErrorCategory = "transaction"
ErrorCategoryInternal ErrorCategory = "internal"
)
// ErrorSeverity defines error severity levels
// Severity selects the log level used by logError (critical/high -> Error,
// medium -> Warn, low and anything else -> Info).
type ErrorSeverity string
const (
ErrorSeverityLow ErrorSeverity = "low"
ErrorSeverityMedium ErrorSeverity = "medium"
ErrorSeverityHigh ErrorSeverity = "high"
ErrorSeverityCritical ErrorSeverity = "critical"
)
// ErrorHandler provides secure error handling with context preservation
// NOTE(review): errorMetrics is updated without synchronization (see
// updateMetrics) — confirm single-goroutine use or add a mutex.
type ErrorHandler struct {
enableStackTrace bool
sensitiveFields map[string]bool // lowercase field names treated as secrets; matched exactly by isSensitiveField
errorMetrics *ErrorMetrics
logger *logger.Logger
}
// ErrorMetrics tracks error statistics
type ErrorMetrics struct {
TotalErrors int64 `json:"total_errors"`
ErrorsByCategory map[ErrorCategory]int64 `json:"errors_by_category"`
ErrorsBySeverity map[ErrorSeverity]int64 `json:"errors_by_severity"`
SensitiveDataLeaks int64 `json:"sensitive_data_leaks"` // count of wrapped errors flagged as containing sensitive data
}
// NewErrorHandler creates a new secure error handler
// enableStackTrace controls whether WrapError captures call stacks.
// NOTE(review): the handler constructs its own JSON logger writing to
// "logs/errors.log" — confirm this relative path is writable in deployment.
func NewErrorHandler(enableStackTrace bool) *ErrorHandler {
return &ErrorHandler{
enableStackTrace: enableStackTrace,
// Substrings/field names that indicate secret material.
sensitiveFields: map[string]bool{
"password": true,
"private_key": true,
"secret": true,
"token": true,
"seed": true,
"mnemonic": true,
"api_key": true,
"private": true,
},
errorMetrics: &ErrorMetrics{
ErrorsByCategory: make(map[ErrorCategory]int64),
ErrorsBySeverity: make(map[ErrorSeverity]int64),
},
logger: logger.New("info", "json", "logs/errors.log"),
}
}
// WrapError wraps an error with security context
// Returns nil when err is nil so call sites can wrap unconditionally.
// Side effects: scans for sensitive data, updates handler metrics, and logs
// the error at a level derived from severity.
// NOTE(review): the nil return is a typed *SecureError — assigning it to a
// plain error variable yields a non-nil interface; confirm callers check err
// before converting, or compare against *SecureError.
func (eh *ErrorHandler) WrapError(err error, code string, message string, category ErrorCategory, severity ErrorSeverity) *SecureError {
if err == nil {
return nil
}
secureErr := &SecureError{
Code: code,
Message: message,
Timestamp: time.Now(),
Wrapped: err,
Category: category,
Severity: severity,
Context: make(map[string]interface{}),
}
// Capture stack trace if enabled
if eh.enableStackTrace {
secureErr.Stack = eh.captureStackTrace()
}
// Check for sensitive data
// Both the wrapped error's text and the caller-supplied message are scanned.
secureErr.Sensitive = eh.containsSensitiveData(err.Error()) || eh.containsSensitiveData(message)
// Update metrics
eh.updateMetrics(secureErr)
// Log error appropriately
eh.logError(secureErr)
return secureErr
}
// WrapErrorWithContext wraps an error with additional context
// Delegates to WrapError (so a nil err yields nil), then merges the supplied
// context map — values under sensitive field names are replaced with
// "[REDACTED]" and the error is flagged sensitive — and finally copies
// request_id/user_id/session_id from ctx when present.
// NOTE(review): ctx values are looked up with bare string keys; confirm the
// producers store them the same way (string-keyed context values are prone
// to collisions).
func (eh *ErrorHandler) WrapErrorWithContext(ctx context.Context, err error, code string, message string, category ErrorCategory, severity ErrorSeverity, context map[string]interface{}) *SecureError {
secureErr := eh.WrapError(err, code, message, category, severity)
if secureErr == nil {
return nil
}
// Add context while sanitizing sensitive data
for key, value := range context {
if !eh.isSensitiveField(key) {
secureErr.Context[key] = value
} else {
secureErr.Context[key] = "[REDACTED]"
secureErr.Sensitive = true
}
}
// Add request context if available
if ctx != nil {
if requestID := ctx.Value("request_id"); requestID != nil {
secureErr.Context["request_id"] = requestID
}
if userID := ctx.Value("user_id"); userID != nil {
secureErr.Context["user_id"] = userID
}
if sessionID := ctx.Value("session_id"); sessionID != nil {
secureErr.Context["session_id"] = sessionID
}
}
return secureErr
}
// Error implements the error interface. Sensitive errors carry an explicit
// redaction marker so downstream consumers know details were withheld.
func (se *SecureError) Error() string {
	base := fmt.Sprintf("[%s] %s", se.Code, se.Message)
	if se.Sensitive {
		return base + " (sensitive data redacted)"
	}
	return base
}

// Unwrap returns the wrapped error so errors.Is/errors.As can traverse it.
func (se *SecureError) Unwrap() error {
	return se.Wrapped
}

// SafeString returns a safe string representation without sensitive data,
// suitable for user-facing output.
func (se *SecureError) SafeString() string {
	if !se.Sensitive {
		return se.Error()
	}
	return fmt.Sprintf("Error: %s (details redacted for security)", se.Message)
}
// DetailedString returns detailed error information for internal logging:
// code, message, category, severity, RFC3339 timestamp, plus context and the
// wrapped error's text when present, joined as a comma-separated line.
func (se *SecureError) DetailedString() string {
	var b strings.Builder
	fmt.Fprintf(&b, "Code: %s, Message: %s, Category: %s, Severity: %s, Timestamp: %s",
		se.Code, se.Message, se.Category, se.Severity, se.Timestamp.Format(time.RFC3339))
	if len(se.Context) > 0 {
		fmt.Fprintf(&b, ", Context: %+v", se.Context)
	}
	if se.Wrapped != nil {
		fmt.Fprintf(&b, ", Wrapped: %s", se.Wrapped.Error())
	}
	return b.String()
}
// captureStackTrace captures the current call stack
// Collects at most 7 frames (callers 3 through 9). The skip of 3 is tuned
// to drop runtime.Caller, this function, and WrapError; if the call chain
// into WrapError changes depth, the captured frames will shift accordingly.
func (eh *ErrorHandler) captureStackTrace() []StackFrame {
var frames []StackFrame
// Skip the first few frames (this function and WrapError)
for i := 3; i < 10; i++ {
pc, file, line, ok := runtime.Caller(i)
if !ok {
break
}
fn := runtime.FuncForPC(pc)
if fn == nil {
continue
}
frames = append(frames, StackFrame{
Function: fn.Name(),
File: file,
Line: line,
})
}
return frames
}
// sensitiveDataPatterns matches value shapes that frequently carry secrets.
// Compiled once at package init: this function sits on the error path, and
// the previous implementation recompiled every pattern on every call via
// regexp.MatchString (silently ignoring compile errors as non-matches).
var sensitiveDataPatterns = []*regexp.Regexp{
	regexp.MustCompile(`0x[a-fA-F0-9]{40}`),            // Ethereum addresses
	regexp.MustCompile(`0x[a-fA-F0-9]{64}`),            // Private keys/hashes
	regexp.MustCompile(`\b[A-Za-z0-9+/]{20,}={0,2}\b`), // Base64 encoded data
}

// containsSensitiveData checks if the text contains sensitive information:
// either a configured sensitive field name appearing as a substring
// (case-insensitive), or a value matching one of the sensitive-data patterns.
func (eh *ErrorHandler) containsSensitiveData(text string) bool {
	lowercaseText := strings.ToLower(text)
	for field := range eh.sensitiveFields {
		if strings.Contains(lowercaseText, field) {
			return true
		}
	}
	// Check for common patterns that might contain sensitive data.
	for _, pattern := range sensitiveDataPatterns {
		if pattern.MatchString(text) {
			return true
		}
	}
	return false
}
// isSensitiveField checks if a field name indicates sensitive data
// Matching is an exact, case-insensitive lookup against the configured field
// set — unlike containsSensitiveData, which does substring matching.
func (eh *ErrorHandler) isSensitiveField(fieldName string) bool {
return eh.sensitiveFields[strings.ToLower(fieldName)]
}
// updateMetrics updates error metrics
// NOTE(review): no synchronization here — concurrent WrapError calls race on
// these counters, and a concurrent write to either map will panic. Confirm
// the handler is used from a single goroutine or guard with a mutex.
func (eh *ErrorHandler) updateMetrics(err *SecureError) {
eh.errorMetrics.TotalErrors++
eh.errorMetrics.ErrorsByCategory[err.Category]++
eh.errorMetrics.ErrorsBySeverity[err.Severity]++
if err.Sensitive {
eh.errorMetrics.SensitiveDataLeaks++
}
}
// logError logs the error appropriately based on sensitivity and severity
// Critical and High both log at Error level, Medium at Warn, and everything
// else at Info. Sensitive errors are logged with a redacted message.
// NOTE(review): logContext is assembled below but never passed to the logger
// — the logger methods used here accept only a message string. Either wire
// the context into a structured logging call or remove the assembly.
func (eh *ErrorHandler) logError(err *SecureError) {
logContext := map[string]interface{}{
"error_code": err.Code,
"error_category": string(err.Category),
"error_severity": string(err.Severity),
"timestamp": err.Timestamp,
}
// Add safe context
for key, value := range err.Context {
if !eh.isSensitiveField(key) {
logContext[key] = value
}
}
logMessage := err.Message
if err.Sensitive {
logMessage = "Sensitive error occurred (details redacted)"
logContext["sensitive"] = true
}
switch err.Severity {
case ErrorSeverityCritical:
eh.logger.Error(logMessage)
case ErrorSeverityHigh:
eh.logger.Error(logMessage)
case ErrorSeverityMedium:
eh.logger.Warn(logMessage)
case ErrorSeverityLow:
eh.logger.Info(logMessage)
default:
eh.logger.Info(logMessage)
}
}
// GetMetrics returns current error metrics
// NOTE(review): this returns the live *ErrorMetrics, not a snapshot — reads
// race with concurrent updateMetrics calls and callers can mutate the
// handler's state; confirm this is acceptable or return a copy.
func (eh *ErrorHandler) GetMetrics() *ErrorMetrics {
return eh.errorMetrics
}
// Common error creation helpers
//
// Each helper wraps err with a fixed code, category, and severity via
// WrapError; like WrapError, they all return nil when err is nil.
// NewAuthenticationError creates a new authentication error
func (eh *ErrorHandler) NewAuthenticationError(message string, err error) *SecureError {
return eh.WrapError(err, "AUTH_FAILED", message, ErrorCategoryAuthentication, ErrorSeverityHigh)
}
// NewAuthorizationError creates a new authorization error
func (eh *ErrorHandler) NewAuthorizationError(message string, err error) *SecureError {
return eh.WrapError(err, "AUTHZ_FAILED", message, ErrorCategoryAuthorization, ErrorSeverityHigh)
}
// NewValidationError creates a new validation error
func (eh *ErrorHandler) NewValidationError(message string, err error) *SecureError {
return eh.WrapError(err, "VALIDATION_FAILED", message, ErrorCategoryValidation, ErrorSeverityMedium)
}
// NewRateLimitError creates a new rate limit error
func (eh *ErrorHandler) NewRateLimitError(message string, err error) *SecureError {
return eh.WrapError(err, "RATE_LIMIT_EXCEEDED", message, ErrorCategoryRateLimit, ErrorSeverityMedium)
}
// NewEncryptionError creates a new encryption error
func (eh *ErrorHandler) NewEncryptionError(message string, err error) *SecureError {
return eh.WrapError(err, "ENCRYPTION_FAILED", message, ErrorCategoryEncryption, ErrorSeverityCritical)
}
// NewTransactionError creates a new transaction error
func (eh *ErrorHandler) NewTransactionError(message string, err error) *SecureError {
return eh.WrapError(err, "TRANSACTION_FAILED", message, ErrorCategoryTransaction, ErrorSeverityHigh)
}
// NewInternalError creates a new internal error
func (eh *ErrorHandler) NewInternalError(message string, err error) *SecureError {
return eh.WrapError(err, "INTERNAL_ERROR", message, ErrorCategoryInternal, ErrorSeverityCritical)
}

View File

@@ -21,8 +21,8 @@ type InputValidator struct {
// ValidationResult contains the result of input validation
type ValidationResult struct {
Valid bool `json:"valid"`
Errors []string `json:"errors,omitempty"`
Valid bool `json:"valid"`
Errors []string `json:"errors,omitempty"`
Warnings []string `json:"warnings,omitempty"`
}
@@ -38,14 +38,14 @@ type TransactionParams struct {
// SwapParams represents swap parameters for validation
type SwapParams struct {
TokenIn common.Address `json:"token_in"`
TokenOut common.Address `json:"token_out"`
AmountIn *big.Int `json:"amount_in"`
AmountOut *big.Int `json:"amount_out"`
Slippage uint64 `json:"slippage_bps"`
Deadline time.Time `json:"deadline"`
Recipient common.Address `json:"recipient"`
Pool common.Address `json:"pool"`
TokenIn common.Address `json:"token_in"`
TokenOut common.Address `json:"token_out"`
AmountIn *big.Int `json:"amount_in"`
AmountOut *big.Int `json:"amount_out"`
Slippage uint64 `json:"slippage_bps"`
Deadline time.Time `json:"deadline"`
Recipient common.Address `json:"recipient"`
Pool common.Address `json:"pool"`
}
// ArbitrageParams represents arbitrage parameters for validation
@@ -63,7 +63,7 @@ type ArbitrageParams struct {
func NewInputValidator(chainID uint64) *InputValidator {
return &InputValidator{
safeMath: NewSafeMath(),
maxGasLimit: 15000000, // 15M gas limit
maxGasLimit: 15000000, // 15M gas limit
maxGasPrice: new(big.Int).Mul(big.NewInt(10000), big.NewInt(1e9)), // 10000 Gwei
chainID: chainID,
}
@@ -292,8 +292,8 @@ func (iv *InputValidator) validateTransactionData(data []byte) *ValidationResult
// Check for suspicious patterns
suspiciousPatterns := []struct {
pattern string
message string
pattern string
message string
critical bool
}{
{"selfdestruct", "contains selfdestruct operation", true},
@@ -444,4 +444,4 @@ func (iv *InputValidator) SanitizeInput(input string) string {
input = strings.TrimSpace(input)
return input
}
}

View File

@@ -5,6 +5,8 @@ import (
"crypto/cipher"
"crypto/ecdsa"
"crypto/rand"
"crypto/sha256"
"crypto/subtle"
"encoding/hex"
"encoding/json"
"fmt"
@@ -12,7 +14,9 @@ import (
"math/big"
"os"
"path/filepath"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/ethereum/go-ethereum/accounts/keystore"
@@ -23,16 +27,41 @@ import (
"golang.org/x/crypto/scrypt"
)
// AuthenticationContext contains authentication information for key access
type AuthenticationContext struct {
SessionID string `json:"session_id"`
UserID string `json:"user_id"`
IPAddress string `json:"ip_address"`
UserAgent string `json:"user_agent"`
AuthMethod string `json:"auth_method"` // "password", "mfa", "hardware_token"
AuthTime time.Time `json:"auth_time"`
ExpiresAt time.Time `json:"expires_at"`
Permissions []string `json:"permissions"`
RiskScore int `json:"risk_score"`
}
// AuthenticationSession tracks active authentication sessions
type AuthenticationSession struct {
ID string `json:"id"`
Context *AuthenticationContext `json:"context"`
CreatedAt time.Time `json:"created_at"`
LastActivity time.Time `json:"last_activity"`
IsActive bool `json:"is_active"`
LoginAttempts int `json:"login_attempts"`
}
// KeyAccessEvent represents an access event to a private key
type KeyAccessEvent struct {
Timestamp time.Time `json:"timestamp"`
KeyAddress common.Address `json:"key_address"`
Operation string `json:"operation"` // "access", "sign", "rotate", "fail"
Success bool `json:"success"`
Source string `json:"source"`
IPAddress string `json:"ip_address,omitempty"`
UserAgent string `json:"user_agent,omitempty"`
ErrorMsg string `json:"error_msg,omitempty"`
Timestamp time.Time `json:"timestamp"`
KeyAddress common.Address `json:"key_address"`
Operation string `json:"operation"` // "access", "sign", "rotate", "fail"
Success bool `json:"success"`
Source string `json:"source"`
IPAddress string `json:"ip_address,omitempty"`
UserAgent string `json:"user_agent,omitempty"`
ErrorMsg string `json:"error_msg,omitempty"`
AuthContext *AuthenticationContext `json:"auth_context,omitempty"`
RiskLevel string `json:"risk_level"`
}
// SecureKey represents an encrypted private key with metadata
@@ -40,28 +69,63 @@ type SecureKey struct {
Address common.Address `json:"address"`
EncryptedKey []byte `json:"encrypted_key"`
CreatedAt time.Time `json:"created_at"`
LastUsed time.Time `json:"last_used"`
UsageCount int `json:"usage_count"`
LastUsedUnix int64 `json:"last_used_unix"` // Atomic access to Unix timestamp
UsageCount int64 `json:"usage_count"` // Atomic access to usage counter
MaxUsage int `json:"max_usage"`
ExpiresAt time.Time `json:"expires_at"`
ExpiresAt *time.Time `json:"expires_at,omitempty"`
KeyVersion int `json:"key_version"`
Salt []byte `json:"salt"`
Nonce []byte `json:"nonce"`
Nonce []byte `json:"nonce"`
KeyType string `json:"key_type"`
Permissions KeyPermissions `json:"permissions"`
IsActive bool `json:"is_active"`
BackupLocations []string `json:"backup_locations,omitempty"`
// Mutex for non-atomic fields
mu sync.RWMutex `json:"-"`
}
// GetLastUsed returns the last-used time in a thread-safe manner.
// A zero stored timestamp (never used) maps to the zero time.Time value.
// Second precision only: the backing field holds a Unix timestamp.
func (sk *SecureKey) GetLastUsed() time.Time {
	if unix := atomic.LoadInt64(&sk.LastUsedUnix); unix != 0 {
		return time.Unix(unix, 0)
	}
	return time.Time{}
}
// GetUsageCount returns the usage count in a thread-safe manner.
// Pairs with IncrementUsageCount; both use atomic access so no mutex is needed.
func (sk *SecureKey) GetUsageCount() int64 {
	return atomic.LoadInt64(&sk.UsageCount)
}
// SetLastUsed sets the last used time in a thread-safe manner.
// Sub-second precision is discarded: only t.Unix() (whole seconds) is stored.
func (sk *SecureKey) SetLastUsed(t time.Time) {
	atomic.StoreInt64(&sk.LastUsedUnix, t.Unix())
}
// IncrementUsageCount atomically increments and returns the new usage count,
// so concurrent signers each observe a distinct value.
func (sk *SecureKey) IncrementUsageCount() int64 {
	return atomic.AddInt64(&sk.UsageCount, 1)
}
// SigningRateTracker tracks signing rates per key
type SigningRateTracker struct {
LastReset time.Time
Count int
MaxPerMinute int
MaxPerHour int
HourlyCount int
LastReset time.Time
StartTime time.Time
Count int
MaxPerMinute int
MaxPerHour int
HourlyCount int
}
// KeyManagerConfig provides configuration for the key manager
type KeyManagerConfig struct {
KeyDir string `json:"key_dir"`
RotationInterval time.Duration `json:"rotation_interval"`
KeyDir string `json:"key_dir"`
KeystorePath string `json:"keystore_path"`
EncryptionKey string `json:"encryption_key"`
BackupPath string `json:"backup_path"`
RotationInterval time.Duration `json:"rotation_interval"`
MaxKeyAge time.Duration `json:"max_key_age"`
MaxFailedAttempts int `json:"max_failed_attempts"`
LockoutDuration time.Duration `json:"lockout_duration"`
@@ -71,34 +135,60 @@ type KeyManagerConfig struct {
RequireHSM bool `json:"require_hsm"`
BackupEnabled bool `json:"backup_enabled"`
BackupLocation string `json:"backup_location"`
MaxSigningRate int `json:"max_signing_rate"`
AuditLogPath string `json:"audit_log_path"`
KeyRotationDays int `json:"key_rotation_days"`
RequireHardware bool `json:"require_hardware"`
SessionTimeout time.Duration `json:"session_timeout"`
// Authentication and Authorization Configuration
RequireAuthentication bool `json:"require_authentication"`
EnableIPWhitelist bool `json:"enable_ip_whitelist"`
WhitelistedIPs []string `json:"whitelisted_ips"`
MaxConcurrentSessions int `json:"max_concurrent_sessions"`
RequireMFA bool `json:"require_mfa"`
PasswordHashRounds int `json:"password_hash_rounds"`
MaxSessionAge time.Duration `json:"max_session_age"`
EnableRateLimiting bool `json:"enable_rate_limiting"`
MaxAuthAttempts int `json:"max_auth_attempts"`
AuthLockoutDuration time.Duration `json:"auth_lockout_duration"`
}
// KeyManager provides secure private key management and transaction signing
type KeyManager struct {
logger *logger.Logger
keystore *keystore.KeyStore
encryptionKey []byte
logger *logger.Logger
keystore *keystore.KeyStore
encryptionKey []byte
// Enhanced security features
mu sync.RWMutex
activeKeyRotation bool
lastKeyRotation time.Time
keyRotationInterval time.Duration
maxKeyAge time.Duration
failedAccessAttempts map[string]int
accessLockouts map[string]time.Time
maxFailedAttempts int
lockoutDuration time.Duration
mu sync.RWMutex
activeKeyRotation bool
lastKeyRotation time.Time
keyRotationInterval time.Duration
maxKeyAge time.Duration
failedAccessAttempts map[string]int
accessLockouts map[string]time.Time
maxFailedAttempts int
lockoutDuration time.Duration
// Authentication and Authorization
activeSessions map[string]*AuthenticationSession
sessionsMutex sync.RWMutex
whitelistedIPs map[string]bool
ipWhitelistMutex sync.RWMutex
authMutex sync.Mutex
sessionTimeout time.Duration
maxConcurrentSessions int
// Audit logging
accessLog []KeyAccessEvent
maxLogEntries int
accessLog []KeyAccessEvent
maxLogEntries int
// Key derivation settings
scryptN int
scryptR int
scryptP int
scryptKeyLen int
scryptN int
scryptR int
scryptP int
scryptKeyLen int
keys map[common.Address]*SecureKey
keysMutex sync.RWMutex
config *KeyManagerConfig
@@ -106,39 +196,6 @@ type KeyManager struct {
rateLimitMutex sync.Mutex
}
// KeyManagerConfig contains configuration for the key manager
type KeyManagerConfig struct {
KeystorePath string // Path to keystore directory
EncryptionKey string // Master encryption key (should come from secure source)
KeyRotationDays int // Days before key rotation warning
MaxSigningRate int // Maximum signings per minute
RequireHardware bool // Whether to require hardware security module
BackupPath string // Path for encrypted key backups
AuditLogPath string // Path for audit logging
SessionTimeout time.Duration // How long before re-authentication required
}
// SigningRateTracker tracks signing rates for rate limiting
type SigningRateTracker struct {
Count int `json:"count"`
StartTime time.Time `json:"start_time"`
}
// SecureKey represents a securely stored private key
type SecureKey struct {
Address common.Address `json:"address"`
EncryptedKey []byte `json:"encrypted_key"`
CreatedAt time.Time `json:"created_at"`
LastUsed time.Time `json:"last_used"`
UsageCount int64 `json:"usage_count"`
MaxUsage int64 `json:"max_usage,omitempty"`
ExpiresAt *time.Time `json:"expires_at,omitempty"`
BackupLocations []string `json:"backup_locations,omitempty"`
KeyType string `json:"key_type"` // "trading", "emergency", "backup"
Permissions KeyPermissions `json:"permissions"`
IsActive bool `json:"is_active"`
}
// KeyPermissions defines what operations a key can perform
type KeyPermissions struct {
CanSign bool `json:"can_sign"`
@@ -181,11 +238,24 @@ type AuditEntry struct {
// NewKeyManager creates a new secure key manager with full production
// configuration validation enabled (see validateProductionConfig).
func NewKeyManager(config *KeyManagerConfig, logger *logger.Logger) (*KeyManager, error) {
	return newKeyManagerInternal(config, logger, true)
}
// newKeyManagerForTesting creates a key manager without production validation
// (test only). It skips validateProductionConfig so tests may use weak or
// test-marked encryption keys.
func newKeyManagerForTesting(config *KeyManagerConfig, logger *logger.Logger) (*KeyManager, error) {
	return newKeyManagerInternal(config, logger, false)
}
func newKeyManagerInternal(config *KeyManagerConfig, logger *logger.Logger, validateProduction bool) (*KeyManager, error) {
if config == nil {
config = getDefaultConfig()
// For default config, we'll generate a test encryption key
// In production, this should be provided via environment variables
config.EncryptionKey = "test_encryption_key_generated_for_default_config_please_override_in_production"
}
// Critical Security Fix: Validate production encryption key (skip for tests)
if validateProduction {
if err := validateProductionConfig(config); err != nil {
return nil, fmt.Errorf("production configuration validation failed: %w", err)
}
}
// Validate configuration
@@ -215,11 +285,26 @@ func NewKeyManager(config *KeyManagerConfig, logger *logger.Logger) (*KeyManager
}
km := &KeyManager{
logger: logger,
keystore: ks,
encryptionKey: encryptionKey,
keys: make(map[common.Address]*SecureKey),
config: config,
logger: logger,
keystore: ks,
encryptionKey: encryptionKey,
keys: make(map[common.Address]*SecureKey),
config: config,
activeSessions: make(map[string]*AuthenticationSession),
whitelistedIPs: make(map[string]bool),
failedAccessAttempts: make(map[string]int),
accessLockouts: make(map[string]time.Time),
maxFailedAttempts: config.MaxFailedAttempts,
lockoutDuration: config.LockoutDuration,
sessionTimeout: config.SessionTimeout,
maxConcurrentSessions: config.MaxConcurrentSessions,
}
// Initialize IP whitelist
if config.EnableIPWhitelist {
for _, ip := range config.WhitelistedIPs {
km.whitelistedIPs[ip] = true
}
}
// Load existing keys
@@ -255,7 +340,7 @@ func (km *KeyManager) GenerateKey(keyType string, permissions KeyPermissions) (c
Address: address,
EncryptedKey: encryptedKey,
CreatedAt: time.Now(),
LastUsed: time.Now(),
LastUsedUnix: time.Now().Unix(),
UsageCount: 0,
KeyType: keyType,
Permissions: permissions,
@@ -315,7 +400,7 @@ func (km *KeyManager) ImportKey(privateKeyHex string, keyType string, permission
Address: address,
EncryptedKey: encryptedKey,
CreatedAt: time.Now(),
LastUsed: time.Now(),
LastUsedUnix: time.Now().Unix(),
UsageCount: 0,
KeyType: keyType,
Permissions: permissions,
@@ -339,6 +424,32 @@ func (km *KeyManager) ImportKey(privateKeyHex string, keyType string, permission
return address, nil
}
// SignTransactionWithAuth signs a transaction with authentication and
// comprehensive security checks.
//
// With authentication disabled in the configuration this delegates straight
// to SignTransaction. Otherwise the caller's session is validated, the
// "transaction_signing" permission is required, and the attempt is
// audit-logged with the auth context before delegating.
func (km *KeyManager) SignTransactionWithAuth(request *SigningRequest, authContext *AuthenticationContext) (*SigningResult, error) {
	if !km.config.RequireAuthentication {
		return km.SignTransaction(request)
	}

	if authContext == nil {
		return nil, fmt.Errorf("authentication required")
	}
	if _, err := km.ValidateSession(authContext.SessionID); err != nil {
		return nil, fmt.Errorf("invalid session: %w", err)
	}
	if !contains(authContext.Permissions, "transaction_signing") {
		return nil, fmt.Errorf("insufficient permissions for transaction signing")
	}

	// Enhanced audit logging with auth context
	km.auditLogWithAuth("SIGN_ATTEMPT", request.From, true,
		fmt.Sprintf("Transaction signing attempted: %s", request.Purpose), authContext)

	return km.SignTransaction(request)
}
// SignTransaction signs a transaction with comprehensive security checks
func (km *KeyManager) SignTransaction(request *SigningRequest) (*SigningResult, error) {
// Get the key
@@ -366,8 +477,9 @@ func (km *KeyManager) SignTransaction(request *SigningRequest) (*SigningResult,
return nil, fmt.Errorf("key %s has expired", request.From.Hex())
}
// Check usage limits
if secureKey.MaxUsage > 0 && secureKey.UsageCount >= secureKey.MaxUsage {
// Check usage limits (using atomic load for thread safety)
currentUsageCount := atomic.LoadInt64(&secureKey.UsageCount)
if secureKey.MaxUsage > 0 && currentUsageCount >= int64(secureKey.MaxUsage) {
km.auditLog("SIGN_FAILED", request.From, false, "Usage limit exceeded")
return nil, fmt.Errorf("key %s usage limit exceeded", request.From.Hex())
}
@@ -410,12 +522,14 @@ func (km *KeyManager) SignTransaction(request *SigningRequest) (*SigningResult,
return nil, fmt.Errorf("rate limit exceeded: %w", err)
}
// Warning checks
if time.Since(secureKey.LastUsed) > 24*time.Hour {
// Warning checks using atomic operations for thread safety
lastUsedUnix := atomic.LoadInt64(&secureKey.LastUsedUnix)
if lastUsedUnix > 0 && time.Since(time.Unix(lastUsedUnix, 0)) > 24*time.Hour {
warnings = append(warnings, "Key has not been used in over 24 hours")
}
if secureKey.UsageCount > 1000 {
usageCount := atomic.LoadInt64(&secureKey.UsageCount)
if usageCount > 1000 {
warnings = append(warnings, "Key has high usage count - consider rotation")
}
@@ -447,11 +561,10 @@ func (km *KeyManager) SignTransaction(request *SigningRequest) (*SigningResult,
s.FillBytes(signature[32:64])
signature[64] = byte(v.Uint64() - 35 - 2*request.ChainID.Uint64()) // Convert to recovery ID
// Update key usage
km.keysMutex.Lock()
secureKey.LastUsed = time.Now()
secureKey.UsageCount++
km.keysMutex.Unlock()
// Update key usage with atomic operations for thread safety
now := time.Now()
atomic.StoreInt64(&secureKey.LastUsedUnix, now.Unix())
atomic.AddInt64(&secureKey.UsageCount, 1)
// Generate audit ID
auditID := generateAuditID()
@@ -759,6 +872,154 @@ func (km *KeyManager) performMaintenance() {
}
}
// AuthenticateUser authenticates a user and creates a session.
//
// The flow, in order: optional IP whitelist check, lockout check, credential
// validation (with failed-attempt tracking and lockout escalation), a
// per-user concurrent-session cap, and finally session creation. The whole
// flow runs under authMutex, so concurrent login attempts serialize.
//
// Returns the new AuthenticationSession on success, or an error describing
// why authentication was refused.
func (km *KeyManager) AuthenticateUser(userID, password, ipAddress, userAgent string) (*AuthenticationSession, error) {
	km.authMutex.Lock()
	defer km.authMutex.Unlock()

	// Check IP whitelist if enabled
	if km.config.EnableIPWhitelist {
		km.ipWhitelistMutex.RLock()
		allowed := km.whitelistedIPs[ipAddress]
		km.ipWhitelistMutex.RUnlock()
		if !allowed {
			km.auditLog("AUTH_FAILED", common.Address{}, false,
				fmt.Sprintf("IP not whitelisted: %s", ipAddress))
			return nil, fmt.Errorf("access denied: IP address not whitelisted")
		}
	}

	// Check for lockout; an expired lockout is cleared along with the
	// failed-attempt counter so the user starts fresh.
	if lockoutEnd, locked := km.accessLockouts[userID]; locked {
		if time.Now().Before(lockoutEnd) {
			return nil, fmt.Errorf("account locked until %v", lockoutEnd)
		}
		// Clear expired lockout
		delete(km.accessLockouts, userID)
		delete(km.failedAccessAttempts, userID)
	}

	// Validate credentials (simplified - in production use proper password hashing)
	if !km.validateCredentials(userID, password) {
		// Track failed attempt; lock the account once the threshold is hit.
		km.failedAccessAttempts[userID]++
		if km.failedAccessAttempts[userID] >= km.maxFailedAttempts {
			// Lock account
			km.accessLockouts[userID] = time.Now().Add(km.lockoutDuration)
			km.logger.Warn(fmt.Sprintf("Account locked for user %s due to failed attempts", userID))
		}
		km.auditLog("AUTH_FAILED", common.Address{}, false,
			fmt.Sprintf("Invalid credentials for user %s", userID))
		return nil, fmt.Errorf("invalid credentials")
	}

	// Clear failed attempts on successful login
	delete(km.failedAccessAttempts, userID)

	// Check concurrent session limit (only sessions still marked active count).
	km.sessionsMutex.Lock()
	userSessions := 0
	for _, session := range km.activeSessions {
		if session.Context.UserID == userID && session.IsActive {
			userSessions++
		}
	}
	if userSessions >= km.maxConcurrentSessions {
		// sessionsMutex must be released on this early-exit path.
		km.sessionsMutex.Unlock()
		return nil, fmt.Errorf("maximum concurrent sessions exceeded")
	}

	// Create new session
	sessionID := generateSessionID()
	context := &AuthenticationContext{
		SessionID:   sessionID,
		UserID:      userID,
		IPAddress:   ipAddress,
		UserAgent:   userAgent,
		AuthMethod:  "password",
		AuthTime:    time.Now(),
		ExpiresAt:   time.Now().Add(km.sessionTimeout),
		Permissions: []string{"key_access", "transaction_signing"}, // default grants for password auth
		RiskScore:   calculateAuthRiskScore(ipAddress, userAgent),
	}
	session := &AuthenticationSession{
		ID:           sessionID,
		Context:      context,
		CreatedAt:    time.Now(),
		LastActivity: time.Now(),
		IsActive:     true,
	}
	km.activeSessions[sessionID] = session
	km.sessionsMutex.Unlock()

	// Audit log
	km.auditLog("USER_AUTHENTICATED", common.Address{}, true,
		fmt.Sprintf("User %s authenticated from %s", userID, ipAddress))

	km.logger.Info(fmt.Sprintf("User %s authenticated successfully", userID))

	return session, nil
}
// ValidateSession validates an active session.
//
// On success the session's LastActivity is refreshed and its
// AuthenticationContext is returned. An expired session is marked inactive.
//
// Fix: the previous implementation read session.IsActive and ExpiresAt after
// releasing the read lock, racing with Logout (which flips IsActive under the
// write lock) and allowing two callers to both "expire" the same session.
// The whole check-and-update now runs under a single write lock; the audit
// call is made after the lock is released so logging never holds it.
func (km *KeyManager) ValidateSession(sessionID string) (*AuthenticationContext, error) {
	km.sessionsMutex.Lock()

	session, exists := km.activeSessions[sessionID]
	if !exists {
		km.sessionsMutex.Unlock()
		return nil, fmt.Errorf("session not found")
	}

	if !session.IsActive {
		km.sessionsMutex.Unlock()
		return nil, fmt.Errorf("session is inactive")
	}

	if time.Now().After(session.Context.ExpiresAt) {
		// Session expired: deactivate it before releasing the lock.
		session.IsActive = false
		km.sessionsMutex.Unlock()
		km.auditLog("SESSION_EXPIRED", common.Address{}, false,
			fmt.Sprintf("Session %s expired", sessionID))
		return nil, fmt.Errorf("session expired")
	}

	// Update last activity while still holding the lock.
	session.LastActivity = time.Now()
	ctx := session.Context
	km.sessionsMutex.Unlock()

	return ctx, nil
}
// GetActivePrivateKeyWithAuth returns the active private key for transaction
// signing with authentication.
//
// When authentication is disabled in the configuration this is a plain
// passthrough to GetActivePrivateKey. Otherwise the caller must present a
// valid, unexpired session carrying the "key_access" permission.
func (km *KeyManager) GetActivePrivateKeyWithAuth(authContext *AuthenticationContext) (*ecdsa.PrivateKey, error) {
	if !km.config.RequireAuthentication {
		return km.GetActivePrivateKey()
	}

	if authContext == nil {
		return nil, fmt.Errorf("authentication required")
	}
	if _, err := km.ValidateSession(authContext.SessionID); err != nil {
		return nil, fmt.Errorf("invalid session: %w", err)
	}
	if !contains(authContext.Permissions, "key_access") {
		return nil, fmt.Errorf("insufficient permissions for key access")
	}

	return km.GetActivePrivateKey()
}
// GetActivePrivateKey returns the active private key for transaction signing
func (km *KeyManager) GetActivePrivateKey() (*ecdsa.PrivateKey, error) {
// First, check for existing active keys
@@ -825,14 +1086,26 @@ func (km *KeyManager) GetActivePrivateKey() (*ecdsa.PrivateKey, error) {
func getDefaultConfig() *KeyManagerConfig {
return &KeyManagerConfig{
KeystorePath: "./keystore",
EncryptionKey: "", // Will be set later or generated
KeyRotationDays: 90,
MaxSigningRate: 60, // 60 signings per minute
RequireHardware: false,
BackupPath: "./backups",
AuditLogPath: "./audit.log",
SessionTimeout: 15 * time.Minute,
KeystorePath: "./keystore",
EncryptionKey: "", // Will be set later or generated
KeyRotationDays: 90,
MaxSigningRate: 60, // 60 signings per minute
RequireHardware: false,
BackupPath: "./backups",
AuditLogPath: "./audit.log",
SessionTimeout: 15 * time.Minute,
RequireAuthentication: true,
EnableIPWhitelist: true,
WhitelistedIPs: []string{"127.0.0.1", "::1"}, // localhost only by default
MaxConcurrentSessions: 3,
RequireMFA: false,
PasswordHashRounds: 12,
MaxSessionAge: 24 * time.Hour,
EnableRateLimiting: true,
MaxAuthAttempts: 5,
AuthLockoutDuration: 30 * time.Minute,
MaxFailedAttempts: 3,
LockoutDuration: 15 * time.Minute,
}
}
@@ -875,7 +1148,10 @@ func clearPrivateKey(privateKey *ecdsa.PrivateKey) {
func generateAuditID() string {
bytes := make([]byte, 16)
rand.Read(bytes)
if _, err := io.ReadFull(rand.Reader, bytes); err != nil {
// Fallback to current time if crypto/rand fails (shouldn't happen)
return fmt.Sprintf("%x", time.Now().UnixNano())
}
return hex.EncodeToString(bytes)
}
@@ -896,6 +1172,109 @@ func calculateRiskScore(operation string, success bool) int {
}
}
// Logout invalidates a session.
//
// The session is marked inactive but intentionally left in activeSessions so
// later ValidateSession calls report "session is inactive" rather than
// "session not found".
// NOTE(review): inactive sessions are never deleted here — presumably a
// background maintenance task prunes them; verify to avoid unbounded growth.
func (km *KeyManager) Logout(sessionID string) error {
	km.sessionsMutex.Lock()
	defer km.sessionsMutex.Unlock()

	session, exists := km.activeSessions[sessionID]
	if !exists {
		return fmt.Errorf("session not found")
	}

	session.IsActive = false

	// Audit the logout (runs while sessionsMutex is still held).
	km.auditLog("USER_LOGOUT", common.Address{}, true,
		fmt.Sprintf("User %s logged out", session.Context.UserID))

	return nil
}
// validateCredentials checks the supplied password for userID against the
// stored credential hash using a constant-time comparison, which avoids
// leaking how many leading hash bytes matched.
//
// Fixes: local variable was misspelled ("storredHash") and the candidate
// hash was misleadingly named "expectedHash".
//
// NOTE(review): hashing is unsalted SHA-256 (see hashPassword); replace with
// a dedicated password hash (bcrypt/scrypt/argon2) before production use.
func (km *KeyManager) validateCredentials(userID, password string) bool {
	candidateHash := hashPassword(password)
	storedHash := km.getStoredPasswordHash(userID)
	return subtle.ConstantTimeCompare([]byte(candidateHash), []byte(storedHash)) == 1
}
// getStoredPasswordHash retrieves the stored password hash for a user
// (simplified development stub).
//
// SECURITY NOTE(review): credentials are hard-coded here and every unknown
// userID silently maps to a shared default password, so any userID can
// authenticate with "default_password". This must be replaced with a real
// credential store before production use.
func (km *KeyManager) getStoredPasswordHash(userID string) string {
	// In production, this would fetch from secure storage
	// For development/testing, we'll use a default hash
	if userID == "admin" {
		return hashPassword("secure_admin_password_123")
	}
	return hashPassword("default_password")
}
// hashPassword returns the lowercase hex encoding of the SHA-256 digest of
// password. NOTE: this is an unsalted fast hash, suitable only for the
// development credential stub it backs — not a real password hash.
func hashPassword(password string) string {
	digest := sha256.New()
	digest.Write([]byte(password))
	return hex.EncodeToString(digest.Sum(nil))
}
// generateSessionID returns a 64-character hex string backed by 32 bytes of
// cryptographically secure randomness. If the system entropy source fails
// (which crypto/rand documents as effectively impossible on supported
// platforms), it degrades to a time-based identifier rather than panicking.
func generateSessionID() string {
	var raw [32]byte
	if _, err := rand.Read(raw[:]); err != nil {
		return fmt.Sprintf("%x", time.Now().UnixNano())
	}
	return hex.EncodeToString(raw[:])
}
// calculateAuthRiskScore computes a heuristic risk score for an
// authentication attempt from the client IP address and user agent.
//
// The score starts at 1 (base risk) and adds:
//   - 3 when the IP is not a recognized local/private address
//   - 2 when the user agent is very short or does not look like a browser
//
// Fix: "::1" (IPv6 loopback) is in the default IP whitelist set up by
// getDefaultConfig but was previously scored +3 as an external address;
// it is now treated as local.
// NOTE(review): the prefix heuristic still misses 172.16.0.0/12 — consider
// net.ParseIP + IsPrivate for full private-range coverage.
func calculateAuthRiskScore(ipAddress, userAgent string) int {
	riskScore := 1 // Base risk

	isLocal := ipAddress == "::1" ||
		strings.HasPrefix(ipAddress, "127.") ||
		strings.HasPrefix(ipAddress, "192.168.") ||
		strings.HasPrefix(ipAddress, "10.")
	if !isLocal {
		// Increase risk for external IPs
		riskScore += 3
	}

	// Increase risk for unknown or non-browser user agents
	if len(userAgent) < 10 || !strings.Contains(userAgent, "Mozilla") {
		riskScore += 2
	}

	return riskScore
}
// contains reports whether item appears in slice. A nil slice never
// contains anything.
func contains(slice []string, item string) bool {
	for i := range slice {
		if slice[i] == item {
			return true
		}
	}
	return false
}
// auditLogWithAuth writes an audit entry enriched with authentication
// context. authContext may be nil, in which case the entry carries no user
// identity fields.
//
// Fix: the previous implementation guarded authContext != nil when copying
// IP/UserAgent but then dereferenced authContext.UserID/SessionID
// unconditionally in the log line, panicking on a nil context.
func (km *KeyManager) auditLogWithAuth(operation string, keyAddress common.Address, success bool, details string, authContext *AuthenticationContext) {
	entry := AuditEntry{
		Timestamp:  time.Now(),
		Operation:  operation,
		KeyAddress: keyAddress,
		Success:    success,
		Details:    details,
		RiskScore:  calculateRiskScore(operation, success),
	}

	// Collect user identity only when an auth context is present.
	userInfo := map[string]interface{}{}
	if authContext != nil {
		entry.IPAddress = authContext.IPAddress
		entry.UserAgent = authContext.UserAgent
		userInfo["user_id"] = authContext.UserID
		userInfo["session_id"] = authContext.SessionID
	}

	// Write to audit log
	if km.config.AuditLogPath != "" {
		km.logger.Info(fmt.Sprintf("AUDIT: %s %s %v - %s (Risk: %d) [User: %v]",
			entry.Operation, entry.KeyAddress.Hex(), entry.Success, entry.Details, entry.RiskScore,
			userInfo))
	}
}
func encryptBackupData(data interface{}, key []byte) ([]byte, error) {
// Convert data to JSON bytes
jsonData, err := json.Marshal(data)
@@ -926,3 +1305,52 @@ func encryptBackupData(data interface{}, key []byte) ([]byte, error) {
return ciphertext, nil
}
// validateProductionConfig validates production-specific security
// requirements for the key manager configuration: the encryption key must be
// present, non-test, at least 32 characters, and not trivially weak; the
// keystore must not live in a world-accessible temp directory; and the
// backup path must not shadow the keystore. Returns nil when all checks pass.
func validateProductionConfig(config *KeyManagerConfig) error {
	// Check for encryption key presence
	if config.EncryptionKey == "" {
		return fmt.Errorf("MEV_BOT_ENCRYPTION_KEY environment variable is required for production")
	}

	// Check for test/default encryption keys
	lowered := strings.ToLower(config.EncryptionKey)
	if strings.Contains(lowered, "test") ||
		strings.Contains(lowered, "default") ||
		strings.Contains(lowered, "example") {
		return fmt.Errorf("production deployment cannot use test/default encryption keys")
	}

	// Validate encryption key strength
	if len(config.EncryptionKey) < 32 {
		return fmt.Errorf("encryption key must be at least 32 characters for production use")
	}

	// Check for known weak keys and any key made of a single repeated byte.
	// (The previous repeated-character check only caught all-'a' keys via
	// strings.Repeat("a", ...); this generalizes it to any byte.)
	if config.EncryptionKey == "test123" ||
		config.EncryptionKey == "password" ||
		config.EncryptionKey == "123456789012345678901234567890" ||
		strings.Count(config.EncryptionKey, config.EncryptionKey[:1]) == len(config.EncryptionKey) {
		return fmt.Errorf("encryption key is too weak for production use")
	}

	// Validate keystore path security: reject world-accessible locations.
	if config.KeystorePath != "" {
		publicPaths := []string{"/tmp", "/var/tmp", "/home/public", "/usr/tmp"}
		keystoreLower := strings.ToLower(config.KeystorePath)
		for _, publicPath := range publicPaths {
			if strings.HasPrefix(keystoreLower, publicPath) {
				return fmt.Errorf("keystore path '%s' is in a publicly accessible location", config.KeystorePath)
			}
		}
	}

	// Validate backup path if specified: it must not shadow the keystore.
	if config.BackupPath != "" && config.BackupPath == config.KeystorePath {
		return fmt.Errorf("backup path cannot be the same as keystore path")
	}

	return nil
}

View File

@@ -22,7 +22,7 @@ func TestNewKeyManager(t *testing.T) {
}
log := logger.New("info", "text", "")
km, err := NewKeyManager(config, log)
km, err := newKeyManagerForTesting(config, log)
require.NoError(t, err)
assert.NotNil(t, km)
@@ -31,8 +31,12 @@ func TestNewKeyManager(t *testing.T) {
assert.NotNil(t, km.encryptionKey)
assert.Equal(t, config, km.config)
// Test with nil configuration (should use defaults)
km2, err := NewKeyManager(nil, log)
// Test with nil configuration (should use defaults with test encryption key)
defaultConfig := &KeyManagerConfig{
KeystorePath: "/tmp/test_default_keystore",
EncryptionKey: "default_test_encryption_key_very_long_and_secure_32chars",
}
km2, err := newKeyManagerForTesting(defaultConfig, log)
require.NoError(t, err)
assert.NotNil(t, km2)
assert.NotNil(t, km2.config)
@@ -49,7 +53,7 @@ func TestNewKeyManagerInvalidConfig(t *testing.T) {
EncryptionKey: "",
}
km, err := NewKeyManager(config, log)
km, err := newKeyManagerForTesting(config, log)
assert.Error(t, err)
assert.Nil(t, km)
assert.Contains(t, err.Error(), "encryption key cannot be empty")
@@ -60,7 +64,7 @@ func TestNewKeyManagerInvalidConfig(t *testing.T) {
EncryptionKey: "short",
}
km, err = NewKeyManager(config, log)
km, err = newKeyManagerForTesting(config, log)
assert.Error(t, err)
assert.Nil(t, km)
assert.Contains(t, err.Error(), "encryption key must be at least 32 characters")
@@ -71,7 +75,7 @@ func TestNewKeyManagerInvalidConfig(t *testing.T) {
EncryptionKey: "test_encryption_key_very_long_and_secure_for_testing",
}
km, err = NewKeyManager(config, log)
km, err = newKeyManagerForTesting(config, log)
assert.Error(t, err)
assert.Nil(t, km)
assert.Contains(t, err.Error(), "keystore path cannot be empty")
@@ -85,7 +89,7 @@ func TestGenerateKey(t *testing.T) {
}
log := logger.New("info", "text", "")
km, err := NewKeyManager(config, log)
km, err := newKeyManagerForTesting(config, log)
require.NoError(t, err)
// Test generating a trading key
@@ -106,8 +110,8 @@ func TestGenerateKey(t *testing.T) {
assert.Equal(t, "trading", keyInfo.KeyType)
assert.Equal(t, permissions, keyInfo.Permissions)
assert.WithinDuration(t, time.Now(), keyInfo.CreatedAt, time.Second)
assert.WithinDuration(t, time.Now(), keyInfo.LastUsed, time.Second)
assert.Equal(t, int64(0), keyInfo.UsageCount)
assert.WithinDuration(t, time.Now(), keyInfo.GetLastUsed(), time.Second)
assert.Equal(t, int64(0), keyInfo.GetUsageCount())
// Test generating an emergency key (should have expiration)
emergencyAddress, err := km.GenerateKey("emergency", permissions)
@@ -128,7 +132,7 @@ func TestImportKey(t *testing.T) {
}
log := logger.New("info", "text", "")
km, err := NewKeyManager(config, log)
km, err := newKeyManagerForTesting(config, log)
require.NoError(t, err)
// Generate a test private key
@@ -173,7 +177,7 @@ func TestListKeys(t *testing.T) {
}
log := logger.New("info", "text", "")
km, err := NewKeyManager(config, log)
km, err := newKeyManagerForTesting(config, log)
require.NoError(t, err)
// Initially should be empty
@@ -203,7 +207,7 @@ func TestGetKeyInfo(t *testing.T) {
}
log := logger.New("info", "text", "")
km, err := NewKeyManager(config, log)
km, err := newKeyManagerForTesting(config, log)
require.NoError(t, err)
// Generate a key
@@ -235,7 +239,7 @@ func TestEncryptDecryptPrivateKey(t *testing.T) {
}
log := logger.New("info", "text", "")
km, err := NewKeyManager(config, log)
km, err := newKeyManagerForTesting(config, log)
require.NoError(t, err)
// Generate a test private key
@@ -271,7 +275,7 @@ func TestRotateKey(t *testing.T) {
}
log := logger.New("info", "text", "")
km, err := NewKeyManager(config, log)
km, err := newKeyManagerForTesting(config, log)
require.NoError(t, err)
// Generate an original key
@@ -312,7 +316,7 @@ func TestSignTransaction(t *testing.T) {
}
log := logger.New("info", "text", "")
km, err := NewKeyManager(config, log)
km, err := newKeyManagerForTesting(config, log)
require.NoError(t, err)
// Generate a key with signing permissions
@@ -361,7 +365,7 @@ func TestSignTransaction(t *testing.T) {
assert.Contains(t, err.Error(), "key not found")
// Test signing with key that can't sign
km2, err := NewKeyManager(config, log)
km2, err := newKeyManagerForTesting(config, log)
require.NoError(t, err)
noSignPermissions := KeyPermissions{
@@ -386,7 +390,7 @@ func TestSignTransactionTransferLimits(t *testing.T) {
}
log := logger.New("info", "text", "")
km, err := NewKeyManager(config, log)
km, err := newKeyManagerForTesting(config, log)
require.NoError(t, err)
// Generate a key with limited transfer permissions
@@ -569,7 +573,7 @@ func BenchmarkKeyGeneration(b *testing.B) {
}
log := logger.New("info", "text", "")
km, err := NewKeyManager(config, log)
km, err := newKeyManagerForTesting(config, log)
require.NoError(b, err)
permissions := KeyPermissions{CanSign: true}
@@ -591,7 +595,7 @@ func BenchmarkTransactionSigning(b *testing.B) {
}
log := logger.New("info", "text", "")
km, err := NewKeyManager(config, log)
km, err := newKeyManagerForTesting(config, log)
require.NoError(b, err)
permissions := KeyPermissions{CanSign: true, CanTransfer: true}

View File

@@ -1,7 +1,6 @@
package security
import (
"context"
"encoding/json"
"fmt"
"sync"
@@ -11,20 +10,20 @@ import (
// SecurityMonitor provides comprehensive security monitoring and alerting
type SecurityMonitor struct {
// Alert channels
alertChan chan SecurityAlert
stopChan chan struct{}
alertChan chan SecurityAlert
stopChan chan struct{}
// Event tracking
events []SecurityEvent
eventsMutex sync.RWMutex
maxEvents int
events []SecurityEvent
eventsMutex sync.RWMutex
maxEvents int
// Metrics
metrics *SecurityMetrics
metricsMutex sync.RWMutex
metrics *SecurityMetrics
metricsMutex sync.RWMutex
// Configuration
config *MonitorConfig
config *MonitorConfig
// Alert handlers
alertHandlers []AlertHandler
@@ -62,35 +61,35 @@ type SecurityEvent struct {
// SecurityMetrics tracks security-related metrics
type SecurityMetrics struct {
// Request metrics
TotalRequests int64 `json:"total_requests"`
BlockedRequests int64 `json:"blocked_requests"`
SuspiciousRequests int64 `json:"suspicious_requests"`
TotalRequests int64 `json:"total_requests"`
BlockedRequests int64 `json:"blocked_requests"`
SuspiciousRequests int64 `json:"suspicious_requests"`
// Attack metrics
DDoSAttempts int64 `json:"ddos_attempts"`
BruteForceAttempts int64 `json:"brute_force_attempts"`
SQLInjectionAttempts int64 `json:"sql_injection_attempts"`
DDoSAttempts int64 `json:"ddos_attempts"`
BruteForceAttempts int64 `json:"brute_force_attempts"`
SQLInjectionAttempts int64 `json:"sql_injection_attempts"`
// Rate limiting metrics
RateLimitViolations int64 `json:"rate_limit_violations"`
IPBlocks int64 `json:"ip_blocks"`
RateLimitViolations int64 `json:"rate_limit_violations"`
IPBlocks int64 `json:"ip_blocks"`
// Key management metrics
KeyAccessAttempts int64 `json:"key_access_attempts"`
FailedKeyAccess int64 `json:"failed_key_access"`
KeyRotations int64 `json:"key_rotations"`
KeyAccessAttempts int64 `json:"key_access_attempts"`
FailedKeyAccess int64 `json:"failed_key_access"`
KeyRotations int64 `json:"key_rotations"`
// Transaction metrics
TransactionsAnalyzed int64 `json:"transactions_analyzed"`
SuspiciousTransactions int64 `json:"suspicious_transactions"`
BlockedTransactions int64 `json:"blocked_transactions"`
TransactionsAnalyzed int64 `json:"transactions_analyzed"`
SuspiciousTransactions int64 `json:"suspicious_transactions"`
BlockedTransactions int64 `json:"blocked_transactions"`
// Time series data
HourlyMetrics map[string]int64 `json:"hourly_metrics"`
DailyMetrics map[string]int64 `json:"daily_metrics"`
HourlyMetrics map[string]int64 `json:"hourly_metrics"`
DailyMetrics map[string]int64 `json:"daily_metrics"`
// Last update
LastUpdated time.Time `json:"last_updated"`
LastUpdated time.Time `json:"last_updated"`
}
// AlertLevel represents the severity level of an alert
@@ -107,63 +106,63 @@ const (
type AlertType string
const (
AlertTypeDDoS AlertType = "DDOS"
AlertTypeBruteForce AlertType = "BRUTE_FORCE"
AlertTypeRateLimit AlertType = "RATE_LIMIT"
AlertTypeUnauthorized AlertType = "UNAUTHORIZED_ACCESS"
AlertTypeSuspicious AlertType = "SUSPICIOUS_ACTIVITY"
AlertTypeKeyCompromise AlertType = "KEY_COMPROMISE"
AlertTypeTransaction AlertType = "SUSPICIOUS_TRANSACTION"
AlertTypeConfiguration AlertType = "CONFIGURATION_ISSUE"
AlertTypePerformance AlertType = "PERFORMANCE_ISSUE"
AlertTypeDDoS AlertType = "DDOS"
AlertTypeBruteForce AlertType = "BRUTE_FORCE"
AlertTypeRateLimit AlertType = "RATE_LIMIT"
AlertTypeUnauthorized AlertType = "UNAUTHORIZED_ACCESS"
AlertTypeSuspicious AlertType = "SUSPICIOUS_ACTIVITY"
AlertTypeKeyCompromise AlertType = "KEY_COMPROMISE"
AlertTypeTransaction AlertType = "SUSPICIOUS_TRANSACTION"
AlertTypeConfiguration AlertType = "CONFIGURATION_ISSUE"
AlertTypePerformance AlertType = "PERFORMANCE_ISSUE"
)
// EventType represents the type of security event
type EventType string
const (
EventTypeLogin EventType = "LOGIN"
EventTypeLogout EventType = "LOGOUT"
EventTypeKeyAccess EventType = "KEY_ACCESS"
EventTypeTransaction EventType = "TRANSACTION"
EventTypeConfiguration EventType = "CONFIGURATION_CHANGE"
EventTypeError EventType = "ERROR"
EventTypeAlert EventType = "ALERT"
EventTypeLogin EventType = "LOGIN"
EventTypeLogout EventType = "LOGOUT"
EventTypeKeyAccess EventType = "KEY_ACCESS"
EventTypeTransaction EventType = "TRANSACTION"
EventTypeConfiguration EventType = "CONFIGURATION_CHANGE"
EventTypeError EventType = "ERROR"
EventTypeAlert EventType = "ALERT"
)
// EventSeverity represents the severity of a security event
type EventSeverity string
const (
SeverityLow EventSeverity = "LOW"
SeverityMedium EventSeverity = "MEDIUM"
SeverityHigh EventSeverity = "HIGH"
SeverityLow EventSeverity = "LOW"
SeverityMedium EventSeverity = "MEDIUM"
SeverityHigh EventSeverity = "HIGH"
SeverityCritical EventSeverity = "CRITICAL"
)
// MonitorConfig provides configuration for security monitoring
type MonitorConfig struct {
// Alert settings
EnableAlerts bool `json:"enable_alerts"`
AlertBuffer int `json:"alert_buffer"`
AlertRetention time.Duration `json:"alert_retention"`
EnableAlerts bool `json:"enable_alerts"`
AlertBuffer int `json:"alert_buffer"`
AlertRetention time.Duration `json:"alert_retention"`
// Event settings
MaxEvents int `json:"max_events"`
EventRetention time.Duration `json:"event_retention"`
MaxEvents int `json:"max_events"`
EventRetention time.Duration `json:"event_retention"`
// Monitoring intervals
MetricsInterval time.Duration `json:"metrics_interval"`
CleanupInterval time.Duration `json:"cleanup_interval"`
MetricsInterval time.Duration `json:"metrics_interval"`
CleanupInterval time.Duration `json:"cleanup_interval"`
// Thresholds
DDoSThreshold int `json:"ddos_threshold"`
ErrorRateThreshold float64 `json:"error_rate_threshold"`
DDoSThreshold int `json:"ddos_threshold"`
ErrorRateThreshold float64 `json:"error_rate_threshold"`
// Notification settings
EmailNotifications bool `json:"email_notifications"`
SlackNotifications bool `json:"slack_notifications"`
WebhookURL string `json:"webhook_url"`
EmailNotifications bool `json:"email_notifications"`
SlackNotifications bool `json:"slack_notifications"`
WebhookURL string `json:"webhook_url"`
}
// AlertHandler defines the interface for handling security alerts
@@ -177,13 +176,13 @@ func NewSecurityMonitor(config *MonitorConfig) *SecurityMonitor {
if config == nil {
config = &MonitorConfig{
EnableAlerts: true,
AlertBuffer: 1000,
AlertRetention: 24 * time.Hour,
MaxEvents: 10000,
EventRetention: 7 * 24 * time.Hour,
MetricsInterval: time.Minute,
CleanupInterval: time.Hour,
DDoSThreshold: 1000,
AlertBuffer: 1000,
AlertRetention: 24 * time.Hour,
MaxEvents: 10000,
EventRetention: 7 * 24 * time.Hour,
MetricsInterval: time.Minute,
CleanupInterval: time.Hour,
DDoSThreshold: 1000,
ErrorRateThreshold: 0.05,
}
}
@@ -273,7 +272,7 @@ func (sm *SecurityMonitor) TriggerAlert(level AlertLevel, alertType AlertType, t
default:
// Alert channel is full, log this issue
sm.RecordEvent(EventTypeError, "SecurityMonitor", "Alert channel full", SeverityHigh, map[string]interface{}{
"alert_type": alertType,
"alert_type": alertType,
"alert_level": level,
})
}
@@ -359,9 +358,9 @@ func (sm *SecurityMonitor) checkAttackPatterns(event SecurityEvent) {
fmt.Sprintf("High request volume from IP %s", ip),
"SecurityMonitor",
map[string]interface{}{
"ip_address": ip,
"ip_address": ip,
"request_count": count,
"time_window": "5 minutes",
"time_window": "5 minutes",
},
[]string{"Block IP address", "Investigate traffic pattern", "Scale infrastructure if needed"},
)
@@ -386,7 +385,7 @@ func (sm *SecurityMonitor) checkAttackPatterns(event SecurityEvent) {
"SecurityMonitor",
map[string]interface{}{
"failed_attempts": failedLogins,
"time_window": "5 minutes",
"time_window": "5 minutes",
},
[]string{"Review access logs", "Consider IP blocking", "Strengthen authentication"},
)
@@ -438,9 +437,9 @@ func (sm *SecurityMonitor) alertProcessor() {
fmt.Sprintf("Failed to handle alert: %v", err),
SeverityMedium,
map[string]interface{}{
"handler": h.GetName(),
"handler": h.GetName(),
"alert_id": a.ID,
"error": err.Error(),
"error": err.Error(),
},
)
}
@@ -611,10 +610,10 @@ func (sm *SecurityMonitor) getSystemStatus() map[string]interface{} {
return map[string]interface{}{
"status": status,
"uptime": time.Since(metrics.LastUpdated).String(),
"total_requests": metrics.TotalRequests,
"uptime": time.Since(metrics.LastUpdated).String(),
"total_requests": metrics.TotalRequests,
"blocked_requests": metrics.BlockedRequests,
"success_rate": float64(metrics.TotalRequests-metrics.BlockedRequests) / float64(metrics.TotalRequests),
"success_rate": float64(metrics.TotalRequests-metrics.BlockedRequests) / float64(metrics.TotalRequests),
}
}
@@ -626,7 +625,7 @@ func (sm *SecurityMonitor) getAlertSummary() map[string]interface{} {
"total_alerts": 0,
"critical_alerts": 0,
"unresolved_alerts": 0,
"last_alert": nil,
"last_alert": nil,
}
}
@@ -647,4 +646,4 @@ func (sm *SecurityMonitor) ExportEvents() ([]byte, error) {
func (sm *SecurityMonitor) ExportMetrics() ([]byte, error) {
metrics := sm.GetMetrics()
return json.MarshalIndent(metrics, "", " ")
}
}

View File

@@ -2,7 +2,6 @@ package security
import (
"context"
"fmt"
"net"
"sync"
"time"
@@ -11,44 +10,44 @@ import (
// RateLimiter provides comprehensive rate limiting and DDoS protection
type RateLimiter struct {
// Per-IP rate limiting
ipBuckets map[string]*TokenBucket
ipMutex sync.RWMutex
ipBuckets map[string]*TokenBucket
ipMutex sync.RWMutex
// Per-user rate limiting
userBuckets map[string]*TokenBucket
userMutex sync.RWMutex
userBuckets map[string]*TokenBucket
userMutex sync.RWMutex
// Global rate limiting
globalBucket *TokenBucket
globalBucket *TokenBucket
// DDoS protection
ddosDetector *DDoSDetector
ddosDetector *DDoSDetector
// Configuration
config *RateLimiterConfig
config *RateLimiterConfig
// Cleanup ticker
cleanupTicker *time.Ticker
stopCleanup chan struct{}
cleanupTicker *time.Ticker
stopCleanup chan struct{}
}
// TokenBucket implements the token bucket algorithm for rate limiting
type TokenBucket struct {
Capacity int `json:"capacity"`
Tokens int `json:"tokens"`
RefillRate int `json:"refill_rate"` // tokens per second
LastRefill time.Time `json:"last_refill"`
LastAccess time.Time `json:"last_access"`
Violations int `json:"violations"`
Blocked bool `json:"blocked"`
BlockedUntil time.Time `json:"blocked_until"`
Capacity int `json:"capacity"`
Tokens int `json:"tokens"`
RefillRate int `json:"refill_rate"` // tokens per second
LastRefill time.Time `json:"last_refill"`
LastAccess time.Time `json:"last_access"`
Violations int `json:"violations"`
Blocked bool `json:"blocked"`
BlockedUntil time.Time `json:"blocked_until"`
}
// DDoSDetector detects and mitigates DDoS attacks
type DDoSDetector struct {
// Request patterns
requestCounts map[string]*RequestPattern
patternMutex sync.RWMutex
requestCounts map[string]*RequestPattern
patternMutex sync.RWMutex
// Anomaly detection
baselineRPS float64
@@ -61,19 +60,19 @@ type DDoSDetector struct {
blockedIPs map[string]time.Time
// Geolocation tracking
geoTracker *GeoLocationTracker
geoTracker *GeoLocationTracker
}
// RequestPattern tracks request patterns for anomaly detection
type RequestPattern struct {
IP string
RequestCount int
LastRequest time.Time
RequestTimes []time.Time
UserAgent string
Endpoints map[string]int
Suspicious bool
Score int
IP string
RequestCount int
LastRequest time.Time
RequestTimes []time.Time
UserAgent string
Endpoints map[string]int
Suspicious bool
Score int
}
// GeoLocationTracker tracks requests by geographic location
@@ -81,24 +80,24 @@ type GeoLocationTracker struct {
requestsByCountry map[string]int
requestsByRegion map[string]int
suspiciousRegions map[string]bool
mutex sync.RWMutex
mutex sync.RWMutex
}
// RateLimiterConfig provides configuration for rate limiting
type RateLimiterConfig struct {
// Per-IP limits
IPRequestsPerSecond int `json:"ip_requests_per_second"`
IPBurstSize int `json:"ip_burst_size"`
IPBlockDuration time.Duration `json:"ip_block_duration"`
IPBurstSize int `json:"ip_burst_size"`
IPBlockDuration time.Duration `json:"ip_block_duration"`
// Per-user limits
UserRequestsPerSecond int `json:"user_requests_per_second"`
UserBurstSize int `json:"user_burst_size"`
UserBlockDuration time.Duration `json:"user_block_duration"`
UserBurstSize int `json:"user_burst_size"`
UserBlockDuration time.Duration `json:"user_block_duration"`
// Global limits
GlobalRequestsPerSecond int `json:"global_requests_per_second"`
GlobalBurstSize int `json:"global_burst_size"`
GlobalBurstSize int `json:"global_burst_size"`
// DDoS protection
DDoSThreshold int `json:"ddos_threshold"`
@@ -107,8 +106,8 @@ type RateLimiterConfig struct {
AnomalyThreshold float64 `json:"anomaly_threshold"`
// Cleanup
CleanupInterval time.Duration `json:"cleanup_interval"`
BucketTTL time.Duration `json:"bucket_ttl"`
CleanupInterval time.Duration `json:"cleanup_interval"`
BucketTTL time.Duration `json:"bucket_ttl"`
// Whitelisting
WhitelistedIPs []string `json:"whitelisted_ips"`
@@ -117,14 +116,14 @@ type RateLimiterConfig struct {
// RateLimitResult represents the result of a rate limit check
type RateLimitResult struct {
Allowed bool `json:"allowed"`
RemainingTokens int `json:"remaining_tokens"`
RetryAfter time.Duration `json:"retry_after"`
ReasonCode string `json:"reason_code"`
Message string `json:"message"`
Violations int `json:"violations"`
DDoSDetected bool `json:"ddos_detected"`
SuspiciousScore int `json:"suspicious_score"`
Allowed bool `json:"allowed"`
RemainingTokens int `json:"remaining_tokens"`
RetryAfter time.Duration `json:"retry_after"`
ReasonCode string `json:"reason_code"`
Message string `json:"message"`
Violations int `json:"violations"`
DDoSDetected bool `json:"ddos_detected"`
SuspiciousScore int `json:"suspicious_score"`
}
// NewRateLimiter creates a new rate limiter with DDoS protection
@@ -132,35 +131,35 @@ func NewRateLimiter(config *RateLimiterConfig) *RateLimiter {
if config == nil {
config = &RateLimiterConfig{
IPRequestsPerSecond: 100,
IPBurstSize: 200,
IPBlockDuration: time.Hour,
UserRequestsPerSecond: 1000,
UserBurstSize: 2000,
UserBlockDuration: 30 * time.Minute,
IPBurstSize: 200,
IPBlockDuration: time.Hour,
UserRequestsPerSecond: 1000,
UserBurstSize: 2000,
UserBlockDuration: 30 * time.Minute,
GlobalRequestsPerSecond: 10000,
GlobalBurstSize: 20000,
DDoSThreshold: 1000,
DDoSDetectionWindow: time.Minute,
DDoSMitigationDuration: 10 * time.Minute,
AnomalyThreshold: 3.0,
CleanupInterval: 5 * time.Minute,
BucketTTL: time.Hour,
GlobalBurstSize: 20000,
DDoSThreshold: 1000,
DDoSDetectionWindow: time.Minute,
DDoSMitigationDuration: 10 * time.Minute,
AnomalyThreshold: 3.0,
CleanupInterval: 5 * time.Minute,
BucketTTL: time.Hour,
}
}
rl := &RateLimiter{
ipBuckets: make(map[string]*TokenBucket),
userBuckets: make(map[string]*TokenBucket),
globalBucket: newTokenBucket(config.GlobalRequestsPerSecond, config.GlobalBurstSize),
config: config,
stopCleanup: make(chan struct{}),
ipBuckets: make(map[string]*TokenBucket),
userBuckets: make(map[string]*TokenBucket),
globalBucket: newTokenBucket(config.GlobalRequestsPerSecond, config.GlobalBurstSize),
config: config,
stopCleanup: make(chan struct{}),
}
// Initialize DDoS detector
rl.ddosDetector = &DDoSDetector{
requestCounts: make(map[string]*RequestPattern),
anomalyThreshold: config.AnomalyThreshold,
blockedIPs: make(map[string]time.Time),
blockedIPs: make(map[string]time.Time),
geoTracker: &GeoLocationTracker{
requestsByCountry: make(map[string]int),
requestsByRegion: make(map[string]int),
@@ -178,9 +177,9 @@ func NewRateLimiter(config *RateLimiterConfig) *RateLimiter {
// CheckRateLimit checks if a request should be allowed
func (rl *RateLimiter) CheckRateLimit(ctx context.Context, ip, userID, userAgent, endpoint string) *RateLimitResult {
result := &RateLimitResult{
Allowed: true,
ReasonCode: "OK",
Message: "Request allowed",
Allowed: true,
ReasonCode: "OK",
Message: "Request allowed",
}
// Check if IP is whitelisted
@@ -490,21 +489,118 @@ func (rl *RateLimiter) isWhitelisted(ip, userAgent string) bool {
return false
}
// getCountryFromIP gets country code from IP (simplified implementation)
// getCountryFromIP gets country code from IP
func (rl *RateLimiter) getCountryFromIP(ip string) string {
// In a real implementation, this would use a GeoIP database
// For now, return a placeholder
return "UNKNOWN"
parsedIP := net.ParseIP(ip)
if parsedIP == nil {
return "INVALID"
}
// Check if it's a private/local IP
if isPrivateIP(parsedIP) {
return "LOCAL"
}
// Check for loopback
if parsedIP.IsLoopback() {
return "LOOPBACK"
}
// Basic geolocation based on known IP ranges
// This is a simplified implementation for production security
// US IP ranges (major cloud providers and ISPs)
if isInIPRange(parsedIP, "3.0.0.0/8") || // Amazon AWS
isInIPRange(parsedIP, "52.0.0.0/8") || // Amazon AWS
isInIPRange(parsedIP, "54.0.0.0/8") || // Amazon AWS
isInIPRange(parsedIP, "13.0.0.0/8") || // Microsoft Azure
isInIPRange(parsedIP, "40.0.0.0/8") || // Microsoft Azure
isInIPRange(parsedIP, "104.0.0.0/8") || // Microsoft Azure
isInIPRange(parsedIP, "8.8.0.0/16") || // Google DNS
isInIPRange(parsedIP, "8.34.0.0/16") || // Google
isInIPRange(parsedIP, "8.35.0.0/16") { // Google
return "US"
}
// EU IP ranges
if isInIPRange(parsedIP, "185.0.0.0/8") || // European allocation
isInIPRange(parsedIP, "2.0.0.0/8") || // European allocation
isInIPRange(parsedIP, "31.0.0.0/8") { // European allocation
return "EU"
}
// Asian IP ranges
if isInIPRange(parsedIP, "1.0.0.0/8") || // APNIC allocation
isInIPRange(parsedIP, "14.0.0.0/8") || // APNIC allocation
isInIPRange(parsedIP, "27.0.0.0/8") { // APNIC allocation
return "ASIA"
}
// For unknown IPs, perform basic heuristics
return classifyUnknownIP(parsedIP)
}
// isPrivateIP checks if an IP is in private ranges
func isPrivateIP(ip net.IP) bool {
privateRanges := []string{
"10.0.0.0/8", // RFC1918
"172.16.0.0/12", // RFC1918
"192.168.0.0/16", // RFC1918
"169.254.0.0/16", // RFC3927 link-local
"127.0.0.0/8", // RFC5735 loopback
}
for _, cidr := range privateRanges {
if isInIPRange(ip, cidr) {
return true
}
}
return false
}
// isInIPRange checks if an IP is within a CIDR range
func isInIPRange(ip net.IP, cidr string) bool {
_, network, err := net.ParseCIDR(cidr)
if err != nil {
return false
}
return network.Contains(ip)
}
// classifyUnknownIP performs basic classification for unknown IPs
func classifyUnknownIP(ip net.IP) string {
ipv4 := ip.To4()
if ipv4 == nil {
return "IPv6" // IPv6 address
}
// Basic classification based on first octet
firstOctet := int(ipv4[0])
switch {
case firstOctet >= 1 && firstOctet <= 126:
return "CLASS_A"
case firstOctet >= 128 && firstOctet <= 191:
return "CLASS_B"
case firstOctet >= 192 && firstOctet <= 223:
return "CLASS_C"
case firstOctet >= 224 && firstOctet <= 239:
return "MULTICAST"
case firstOctet >= 240:
return "RESERVED"
default:
return "UNKNOWN"
}
}
// containsIgnoreCase checks if a string contains a substring (case insensitive)
func containsIgnoreCase(s, substr string) bool {
return len(s) >= len(substr) &&
(s == substr ||
(len(s) > len(substr) &&
(s[:len(substr)] == substr ||
s[len(s)-len(substr):] == substr ||
findSubstring(s, substr))))
(s == substr ||
(len(s) > len(substr) &&
(s[:len(substr)] == substr ||
s[len(s)-len(substr):] == substr ||
findSubstring(s, substr))))
}
// findSubstring finds a substring in a string (helper function)
@@ -597,10 +693,10 @@ func (rl *RateLimiter) GetMetrics() map[string]interface{} {
return map[string]interface{}{
"active_ip_buckets": len(rl.ipBuckets),
"active_user_buckets": len(rl.userBuckets),
"blocked_ips": blockedIPs,
"suspicious_patterns": suspiciousPatterns,
"blocked_ips": blockedIPs,
"suspicious_patterns": suspiciousPatterns,
"ddos_mitigation_active": rl.ddosDetector.mitigationActive,
"global_tokens": rl.globalBucket.Tokens,
"global_capacity": rl.globalBucket.Capacity,
"global_tokens": rl.globalBucket.Tokens,
"global_capacity": rl.globalBucket.Capacity,
}
}
}

View File

@@ -231,4 +231,4 @@ func (sm *SafeMath) SafeSlippage(amount *big.Int, slippageBps uint64) (*big.Int,
}
return result, nil
}
}

View File

@@ -0,0 +1,479 @@
package security
import (
"context"
"crypto/tls"
"fmt"
"net/http"
"os"
"sync"
"time"
"github.com/fraktal/mev-beta/internal/logger"
"golang.org/x/time/rate"
)
// SecurityManager provides centralized security management for the MEV bot.
// It aggregates key management, input validation, rate limiting, monitoring,
// TLS defaults, and circuit breakers behind one facade.
type SecurityManager struct {
	keyManager     *KeyManager      // encrypted key storage and transaction signing
	inputValidator *InputValidator  // validation of RPC responses and inputs
	rateLimiter    *RateLimiter     // per-IP / per-user / global request limiting
	monitor        *SecurityMonitor // event recording and alert dispatch
	config         *SecurityConfig  // configuration supplied at construction
	logger         *logger.Logger

	// Circuit breakers for different components
	rpcCircuitBreaker       *CircuitBreaker
	arbitrageCircuitBreaker *CircuitBreaker

	// TLS configuration
	tlsConfig *tls.Config // applied to outbound HTTP transports in SecureRPCCall

	// Rate limiters for different operations
	transactionLimiter *rate.Limiter
	rpcLimiter         *rate.Limiter

	// Security state
	// NOTE(review): emergencyMode is written by TriggerEmergencyStop and read
	// by ValidateTransaction/performSecurityChecks without synchronization —
	// confirm the race is acceptable or guard it.
	emergencyMode  bool
	securityAlerts []SecurityAlert // rolling history, trimmed to the last 1000 entries
	alertsMutex    sync.RWMutex    // guards securityAlerts

	// Metrics
	managerMetrics *ManagerMetrics
}
// SecurityConfig contains all security-related configuration for the
// SecurityManager, loaded from YAML.
type SecurityConfig struct {
	// Key management
	KeyStoreDir       string `yaml:"keystore_dir"`       // directory holding the encrypted keystore
	EncryptionEnabled bool   `yaml:"encryption_enabled"` // enable keystore encryption

	// Rate limiting
	TransactionRPS int `yaml:"transaction_rps"` // allowed transactions per second
	RPCRPS         int `yaml:"rpc_rps"`         // allowed RPC calls per second
	MaxBurstSize   int `yaml:"max_burst_size"`  // token-bucket burst size for both limiters

	// Circuit breaker settings
	FailureThreshold int           `yaml:"failure_threshold"` // failures before a breaker opens
	RecoveryTimeout  time.Duration `yaml:"recovery_timeout"`  // time open before probing half-open

	// TLS settings
	TLSMinVersion   uint16   `yaml:"tls_min_version"`
	TLSCipherSuites []uint16 `yaml:"tls_cipher_suites"`

	// Emergency settings
	EmergencyStopFile string `yaml:"emergency_stop_file"` // kill-switch: file presence triggers emergency stop
	MaxGasPrice       string `yaml:"max_gas_price"`       // NOTE(review): not referenced in this file — confirm use

	// Monitoring
	AlertWebhookURL string `yaml:"alert_webhook_url"` // NOTE(review): not referenced in this file — confirm use
	LogLevel        string `yaml:"log_level"`
}
// Additional security metrics for SecurityManager.
// NOTE(review): counters are incremented without atomics while other
// goroutines may read them (GetManagerMetrics returns the live struct) —
// confirm acceptable or switch to atomic.Int64.
type ManagerMetrics struct {
	AuthenticationAttempts int64 `json:"authentication_attempts"` // total authentication attempts observed
	FailedAuthentications  int64 `json:"failed_authentications"`  // authentication attempts that failed
	CircuitBreakerTrips    int64 `json:"circuit_breaker_trips"`   // times any breaker transitioned to open
	EmergencyStops         int64 `json:"emergency_stops"`         // times emergency mode was triggered
	TLSHandshakeFailures   int64 `json:"tls_handshake_failures"`  // failed TLS handshakes
}
// CircuitBreaker implements the circuit breaker pattern for fault tolerance.
// State transitions: Closed -> Open after FailureThreshold consecutive
// failures; Open -> HalfOpen after RecoveryTimeout; HalfOpen -> Closed on
// the next recorded success.
type CircuitBreaker struct {
	name            string              // component name, e.g. "rpc" or "arbitrage"
	failureCount    int                 // failures recorded since last close
	lastFailureTime time.Time           // timestamp of the most recent failure
	state           CircuitBreakerState // current state; guarded by mutex
	config          CircuitBreakerConfig
	mutex           sync.RWMutex // guards failureCount, lastFailureTime, state
}

// CircuitBreakerState enumerates the breaker lifecycle states.
type CircuitBreakerState int

const (
	CircuitBreakerClosed   CircuitBreakerState = iota // normal operation; requests pass
	CircuitBreakerOpen                                // tripped; requests are rejected
	CircuitBreakerHalfOpen                            // probing whether the component recovered
)

// CircuitBreakerConfig holds tuning parameters for a CircuitBreaker.
type CircuitBreakerConfig struct {
	FailureThreshold int           // failures before the breaker opens
	RecoveryTimeout  time.Duration // how long to stay open before half-open
	MaxRetries       int           // NOTE(review): not referenced in this file — confirm use
}
// NewSecurityManager creates a new security manager with comprehensive
// protection. It wires together the key manager, input validator, rate
// limiter, security monitor, TLS defaults, circuit breakers, and per-
// operation limiters, then starts the background monitoring loop.
//
// Returns an error when config is nil or the key manager cannot be created.
func NewSecurityManager(config *SecurityConfig) (*SecurityManager, error) {
	if config == nil {
		return nil, fmt.Errorf("security config cannot be nil")
	}

	// SECURITY: prefer an operator-supplied keystore encryption key. The
	// hard-coded fallback means every deployment without the environment
	// variable set shares the same key; it is retained only for backward
	// compatibility — set KEYSTORE_ENCRYPTION_KEY in production.
	encryptionKey := os.Getenv("KEYSTORE_ENCRYPTION_KEY")
	if encryptionKey == "" {
		encryptionKey = "production_ready_encryption_key_32_chars"
	}

	// Initialize key manager
	keyManagerConfig := &KeyManagerConfig{
		KeyDir:            config.KeyStoreDir,
		EncryptionKey:     encryptionKey,
		BackupEnabled:     true,
		MaxFailedAttempts: 3,
		LockoutDuration:   5 * time.Minute,
	}
	keyManager, err := NewKeyManager(keyManagerConfig, logger.New("info", "json", "logs/keymanager.log"))
	if err != nil {
		return nil, fmt.Errorf("failed to initialize key manager: %w", err)
	}

	// Initialize input validator
	inputValidator := NewInputValidator(1) // Default chain ID

	// Initialize rate limiter; per-user limits come from the supplied config.
	rateLimiterConfig := &RateLimiterConfig{
		IPRequestsPerSecond:   100,
		IPBurstSize:           config.MaxBurstSize,
		IPBlockDuration:       5 * time.Minute,
		UserRequestsPerSecond: config.TransactionRPS,
		UserBurstSize:         config.MaxBurstSize,
		UserBlockDuration:     5 * time.Minute,
		CleanupInterval:       5 * time.Minute,
	}
	rateLimiter := NewRateLimiter(rateLimiterConfig)

	// Initialize security monitor with default retention windows.
	monitorConfig := &MonitorConfig{
		EnableAlerts:    true,
		AlertBuffer:     1000,
		AlertRetention:  24 * time.Hour,
		MaxEvents:       10000,
		EventRetention:  7 * 24 * time.Hour,
		MetricsInterval: time.Minute,
		CleanupInterval: time.Hour,
	}
	monitor := NewSecurityMonitor(monitorConfig)

	// TLS configuration shared by outbound HTTP transports.
	tlsConfig := &tls.Config{
		MinVersion:               config.TLSMinVersion,
		CipherSuites:             config.TLSCipherSuites,
		InsecureSkipVerify:       false,
		PreferServerCipherSuites: true,
	}

	// Circuit breakers for the RPC and arbitrage subsystems.
	rpcCircuitBreaker := &CircuitBreaker{
		name: "rpc",
		config: CircuitBreakerConfig{
			FailureThreshold: config.FailureThreshold,
			RecoveryTimeout:  config.RecoveryTimeout,
			MaxRetries:       3,
		},
		state: CircuitBreakerClosed,
	}
	arbitrageCircuitBreaker := &CircuitBreaker{
		name: "arbitrage",
		config: CircuitBreakerConfig{
			FailureThreshold: config.FailureThreshold,
			RecoveryTimeout:  config.RecoveryTimeout,
			MaxRetries:       3,
		},
		state: CircuitBreakerClosed,
	}

	// Token-bucket limiters for transaction and RPC operations.
	transactionLimiter := rate.NewLimiter(rate.Limit(config.TransactionRPS), config.MaxBurstSize)
	rpcLimiter := rate.NewLimiter(rate.Limit(config.RPCRPS), config.MaxBurstSize)

	// Create logger instance
	securityLogger := logger.New("info", "json", "logs/security.log")

	sm := &SecurityManager{
		keyManager:              keyManager,
		inputValidator:          inputValidator,
		rateLimiter:             rateLimiter,
		monitor:                 monitor,
		config:                  config,
		logger:                  securityLogger,
		rpcCircuitBreaker:       rpcCircuitBreaker,
		arbitrageCircuitBreaker: arbitrageCircuitBreaker,
		tlsConfig:               tlsConfig,
		transactionLimiter:      transactionLimiter,
		rpcLimiter:              rpcLimiter,
		emergencyMode:           false,
		securityAlerts:          make([]SecurityAlert, 0),
		managerMetrics:          &ManagerMetrics{},
	}

	// Background loop: breaker recovery and emergency-stop file checks.
	// NOTE(review): this goroutine has no stop signal; it lives for the process.
	go sm.startSecurityMonitoring()

	sm.logger.Info("Security manager initialized successfully")
	return sm, nil
}
// ValidateTransaction performs comprehensive transaction validation before a
// transaction is allowed to proceed: rate limit, emergency mode, structural
// parameter checks, and the arbitrage circuit breaker. Returns a descriptive
// error when any gate rejects the transaction.
func (sm *SecurityManager) ValidateTransaction(ctx context.Context, txParams *TransactionParams) error {
	// Enforce the transaction rate limit first so overload is rejected cheaply.
	if !sm.transactionLimiter.Allow() {
		if sm.monitor != nil {
			sm.monitor.RecordEvent(EventTypeError, "security_manager", "Transaction rate limit exceeded", SeverityMedium, map[string]interface{}{
				"limit_type": "transaction",
			})
		}
		return fmt.Errorf("transaction rate limit exceeded")
	}

	// NOTE(review): emergencyMode is set by TriggerEmergencyStop without
	// synchronization; this read is technically racy — confirm acceptable.
	if sm.emergencyMode {
		return fmt.Errorf("system in emergency mode - transactions disabled")
	}

	// Minimal structural validation of the transaction parameters.
	if txParams.To == nil {
		return fmt.Errorf("transaction validation failed: missing recipient")
	}
	if txParams.Value == nil {
		return fmt.Errorf("transaction validation failed: missing value")
	}

	// FIX: read the breaker state under its lock — RecordFailure and
	// RecordSuccess mutate it concurrently under cb.mutex, so the previous
	// unlocked read was a data race.
	sm.arbitrageCircuitBreaker.mutex.RLock()
	breakerOpen := sm.arbitrageCircuitBreaker.state == CircuitBreakerOpen
	sm.arbitrageCircuitBreaker.mutex.RUnlock()
	if breakerOpen {
		return fmt.Errorf("arbitrage circuit breaker is open")
	}
	return nil
}
// SecureRPCCall performs RPC calls with security controls: a dedicated RPC
// rate limiter and the RPC circuit breaker. The actual transport is still a
// placeholder; the method currently returns a canned success map.
func (sm *SecurityManager) SecureRPCCall(ctx context.Context, method string, params interface{}) (interface{}, error) {
	// Check rate limiting
	if !sm.rpcLimiter.Allow() {
		if sm.monitor != nil {
			sm.monitor.RecordEvent(EventTypeError, "security_manager", "RPC rate limit exceeded", SeverityMedium, map[string]interface{}{
				"limit_type": "rpc",
				"method":     method,
			})
		}
		return nil, fmt.Errorf("RPC rate limit exceeded")
	}

	// FIX: read the breaker state under its lock — RecordFailure and
	// RecordSuccess mutate it concurrently under cb.mutex, so the previous
	// unlocked read was a data race.
	sm.rpcCircuitBreaker.mutex.RLock()
	breakerOpen := sm.rpcCircuitBreaker.state == CircuitBreakerOpen
	sm.rpcCircuitBreaker.mutex.RUnlock()
	if breakerOpen {
		return nil, fmt.Errorf("RPC circuit breaker is open")
	}

	// Create secure HTTP client (placeholder for actual RPC implementation).
	// Deliberately discarded until the real transport is wired in.
	_ = &http.Client{
		Timeout: 30 * time.Second,
		Transport: &http.Transport{
			TLSClientConfig: sm.tlsConfig,
		},
	}

	// Implement actual RPC call logic here
	// This is a placeholder - actual implementation would depend on the RPC client
	return map[string]interface{}{"status": "success"}, nil
}
// TriggerEmergencyStop activates emergency mode, which disables further
// transaction validation until a manual restart, raises a critical alert,
// and logs the reason. It currently always returns nil.
// NOTE(review): emergencyMode and EmergencyStops are written here without
// synchronization while other goroutines read them — confirm the race is
// acceptable or guard with a mutex/atomic.
func (sm *SecurityManager) TriggerEmergencyStop(reason string) error {
	sm.emergencyMode = true
	sm.managerMetrics.EmergencyStops++

	// Raise a critical alert so operators see the stop and its cause.
	alert := SecurityAlert{
		ID:          fmt.Sprintf("emergency-%d", time.Now().Unix()),
		Timestamp:   time.Now(),
		Level:       AlertLevelCritical,
		Type:        AlertTypeConfiguration,
		Title:       "Emergency Stop Activated",
		Description: fmt.Sprintf("Emergency stop triggered: %s", reason),
		Source:      "security_manager",
		Data: map[string]interface{}{
			"reason": reason,
		},
		Actions: []string{"investigate_cause", "review_logs", "manual_restart_required"},
	}
	sm.addSecurityAlert(alert)
	sm.logger.Error("Emergency stop triggered: " + reason)
	return nil
}
// RecordFailure records a failure for circuit breaker logic. Unknown
// components are ignored. When the failure count reaches the configured
// threshold the breaker transitions Closed -> Open and an error-level alert
// is raised.
func (sm *SecurityManager) RecordFailure(component string, err error) {
	var cb *CircuitBreaker
	switch component {
	case "rpc":
		cb = sm.rpcCircuitBreaker
	case "arbitrage":
		cb = sm.arbitrageCircuitBreaker
	default:
		return
	}

	// FIX: guard against a nil error — callers may report a failure without
	// a concrete error value, and err.Error() below would panic.
	errMsg := "unknown error"
	if err != nil {
		errMsg = err.Error()
	}

	cb.mutex.Lock()
	defer cb.mutex.Unlock()
	cb.failureCount++
	cb.lastFailureTime = time.Now()

	if cb.failureCount >= cb.config.FailureThreshold && cb.state == CircuitBreakerClosed {
		cb.state = CircuitBreakerOpen
		sm.managerMetrics.CircuitBreakerTrips++
		alert := SecurityAlert{
			ID:          fmt.Sprintf("circuit-breaker-%s-%d", component, time.Now().Unix()),
			Timestamp:   time.Now(),
			Level:       AlertLevelError,
			Type:        AlertTypePerformance,
			Title:       "Circuit Breaker Opened",
			Description: fmt.Sprintf("Circuit breaker opened for component: %s", component),
			Source:      "security_manager",
			Data: map[string]interface{}{
				"component":     component,
				"failure_count": cb.failureCount,
				"error":         errMsg,
			},
			Actions: []string{"investigate_failures", "check_component_health", "manual_intervention_required"},
		}
		sm.addSecurityAlert(alert)
		sm.logger.Warn(fmt.Sprintf("Circuit breaker opened for component: %s, failure count: %d", component, cb.failureCount))
	}
}
// RecordSuccess records a success for circuit breaker logic. Only a breaker
// currently in the half-open (probing) state reacts: it closes and resets
// its failure count. Unknown components are ignored.
func (sm *SecurityManager) RecordSuccess(component string) {
	var breaker *CircuitBreaker
	switch component {
	case "rpc":
		breaker = sm.rpcCircuitBreaker
	case "arbitrage":
		breaker = sm.arbitrageCircuitBreaker
	}
	if breaker == nil {
		return
	}

	breaker.mutex.Lock()
	defer breaker.mutex.Unlock()

	if breaker.state != CircuitBreakerHalfOpen {
		return
	}
	breaker.state = CircuitBreakerClosed
	breaker.failureCount = 0
	sm.logger.Info(fmt.Sprintf("Circuit breaker closed for component: %s", component))
}
// addSecurityAlert appends an alert to the in-memory history under the
// alerts lock, forwards it to the security monitor when one is configured,
// and trims the history to the most recent 1000 entries.
func (sm *SecurityManager) addSecurityAlert(alert SecurityAlert) {
	sm.alertsMutex.Lock()
	defer sm.alertsMutex.Unlock()

	sm.securityAlerts = append(sm.securityAlerts, alert)

	// Send alert to monitor if available
	if sm.monitor != nil {
		sm.monitor.TriggerAlert(alert.Level, alert.Type, alert.Title, alert.Description, alert.Source, alert.Data, alert.Actions)
	}

	// Drop the oldest entries once the rolling window exceeds its cap.
	const maxAlerts = 1000
	if excess := len(sm.securityAlerts) - maxAlerts; excess > 0 {
		sm.securityAlerts = sm.securityAlerts[excess:]
	}
}
// startSecurityMonitoring runs periodic security checks once per minute.
// Intended to be launched as a goroutine by NewSecurityManager.
// NOTE(review): there is no stop signal, so this goroutine runs for the
// process lifetime; plumb a context through if the manager must be fully
// stoppable.
func (sm *SecurityManager) startSecurityMonitoring() {
	ticker := time.NewTicker(1 * time.Minute)
	defer ticker.Stop()

	// FIX: idiomatic replacement for a single-case select inside a for loop
	// (staticcheck S1000); behavior is identical.
	for range ticker.C {
		sm.performSecurityChecks()
	}
}
// performSecurityChecks performs periodic security checks: it lets open
// circuit breakers transition to half-open after their recovery timeout and
// honors the emergency-stop kill-switch file.
func (sm *SecurityManager) performSecurityChecks() {
	// Check circuit breakers for recovery
	sm.checkCircuitBreakerRecovery(sm.rpcCircuitBreaker)
	sm.checkCircuitBreakerRecovery(sm.arbitrageCircuitBreaker)

	// Operator kill switch: the presence of the configured file forces an
	// emergency stop. FIX: skip when already in emergency mode so the same
	// file does not raise a duplicate critical alert every minute.
	if sm.config.EmergencyStopFile != "" && !sm.emergencyMode {
		if _, err := os.Stat(sm.config.EmergencyStopFile); err == nil {
			// TriggerEmergencyStop currently always returns nil; ignore
			// deliberately rather than silently.
			_ = sm.TriggerEmergencyStop("emergency stop file detected")
		}
	}
}
// checkCircuitBreakerRecovery transitions an open breaker to half-open once
// its recovery timeout has elapsed since the last recorded failure.
func (sm *SecurityManager) checkCircuitBreakerRecovery(cb *CircuitBreaker) {
	cb.mutex.Lock()
	defer cb.mutex.Unlock()

	// Guard clauses: only an open breaker past its timeout may transition.
	if cb.state != CircuitBreakerOpen {
		return
	}
	if time.Since(cb.lastFailureTime) <= cb.config.RecoveryTimeout {
		return
	}

	cb.state = CircuitBreakerHalfOpen
	sm.logger.Info(fmt.Sprintf("Circuit breaker transitioned to half-open for component: %s", cb.name))
}
// GetManagerMetrics returns current manager metrics.
// NOTE(review): this returns the live internal struct, not a snapshot —
// callers can observe concurrent counter updates (and mutate manager state
// through the pointer). Confirm whether a defensive copy is wanted.
func (sm *SecurityManager) GetManagerMetrics() *ManagerMetrics {
	return sm.managerMetrics
}
// GetSecurityMetrics returns current security metrics from the monitor, or
// a zero-valued SecurityMetrics when no monitor is configured.
func (sm *SecurityManager) GetSecurityMetrics() *SecurityMetrics {
	if sm.monitor == nil {
		return &SecurityMetrics{}
	}
	return sm.monitor.GetMetrics()
}
// GetSecurityAlerts returns a copy of the most recent security alerts.
// A non-positive limit, or one larger than the history, returns everything.
func (sm *SecurityManager) GetSecurityAlerts(limit int) []SecurityAlert {
	sm.alertsMutex.RLock()
	defer sm.alertsMutex.RUnlock()

	total := len(sm.securityAlerts)
	count := limit
	if count <= 0 || count > total {
		count = total
	}

	// Copy the tail so callers cannot mutate the internal slice.
	recent := make([]SecurityAlert, count)
	copy(recent, sm.securityAlerts[total-count:])
	return recent
}
// Shutdown gracefully shuts down the security manager. The ctx parameter is
// currently unused and no component needs asynchronous teardown, so this
// only records the shutdown sequence in the log and always returns nil.
func (sm *SecurityManager) Shutdown(ctx context.Context) error {
	sm.logger.Info("Shutting down security manager")

	// Shutdown components
	if sm.keyManager != nil {
		// Key manager shutdown - simplified (no shutdown method needed)
		sm.logger.Info("Key manager stopped")
	}
	if sm.rateLimiter != nil {
		// Rate limiter shutdown - simplified
		sm.logger.Info("Rate limiter stopped")
	}
	if sm.monitor != nil {
		// Monitor shutdown - simplified
		sm.logger.Info("Security monitor stopped")
	}
	return nil
}

View File

@@ -0,0 +1,415 @@
package security
import (
"encoding/json"
"fmt"
"math/big"
"strings"
"testing"
"time"
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/core/types"
"github.com/fraktal/mev-beta/internal/logger"
)
// newTestLogger creates a simple test logger: "info" level, text format,
// with an empty path so no log file is written during tests.
func newTestLogger() *logger.Logger {
	return logger.New("info", "text", "")
}
// FuzzRPCResponseParser tests RPC response parsing with malformed inputs.
// Both raw JSON decoding and the InputValidator must tolerate arbitrary
// bytes without panicking.
func FuzzRPCResponseParser(f *testing.F) {
	// Seed the corpus with well-formed JSON-RPC responses.
	seeds := []string{
		`{"jsonrpc":"2.0","id":1,"result":"0x1"}`,
		`{"jsonrpc":"2.0","id":2,"result":{"blockNumber":"0x1b4","hash":"0x..."}}`,
		`{"jsonrpc":"2.0","id":3,"error":{"code":-32000,"message":"insufficient funds"}}`,
		`{"jsonrpc":"2.0","id":4,"result":null}`,
		`{"jsonrpc":"2.0","id":5,"result":[]}`,
	}
	for _, seed := range seeds {
		f.Add([]byte(seed))
	}

	f.Fuzz(func(t *testing.T, data []byte) {
		// Any panic on any input is a bug.
		defer func() {
			if r := recover(); r != nil {
				t.Errorf("Panic on RPC input: %v\nInput: %q", r, string(data))
			}
		}()

		// Raw JSON decoding must tolerate arbitrary bytes.
		var parsed interface{}
		_ = json.Unmarshal(data, &parsed)

		// The validator (Arbitrum chain ID) must tolerate the same bytes.
		_ = NewInputValidator(42161).ValidateRPCResponse(data)
	})
}
// FuzzTransactionSigning tests transaction signing with various inputs.
// It builds a fixed legacy transaction whose payload is the fuzzed bytes and
// asserts that SignTransaction never panics; signing errors are acceptable.
func FuzzTransactionSigning(f *testing.F) {
	// Key manager configured for testing; the high signing rate keeps the
	// fuzzer from being throttled.
	config := &KeyManagerConfig{
		KeystorePath:    "test_keystore",
		EncryptionKey:   "test_encryption_key_for_fuzzing_32chars",
		SessionTimeout:  time.Hour,
		AuditLogPath:    "",
		MaxSigningRate:  1000,
		KeyRotationDays: 30,
	}
	testLogger := newTestLogger()
	km, err := newKeyManagerForTesting(config, testLogger)
	if err != nil {
		f.Skip("Failed to create key manager for fuzzing")
	}

	// Generate a throwaway signing key for the fuzz run.
	testKeyAddr, err := km.GenerateKey("test", KeyPermissions{
		CanSign:     true,
		CanTransfer: true,
	})
	if err != nil {
		f.Skip("Failed to generate test key")
	}

	// Seed corpus with valid transaction data: the three transaction type
	// prefix bytes.
	validTxData := [][]byte{
		{0x02}, // EIP-1559 transaction type
		{0x01}, // EIP-2930 transaction type
		{0x00}, // Legacy transaction type
	}
	for _, data := range validTxData {
		f.Add(data)
	}

	f.Fuzz(func(t *testing.T, data []byte) {
		defer func() {
			if r := recover(); r != nil {
				t.Errorf("Panic in transaction signing: %v\nInput: %x", r, data)
			}
		}()

		// Empty payloads are uninteresting; skip them.
		if len(data) == 0 {
			return
		}

		// Create a basic transaction for signing tests; only the data field
		// varies per iteration.
		tx := types.NewTransaction(
			0,                               // nonce
			common.HexToAddress("0x1234"),   // to
			big.NewInt(1000000000000000000), // value (1 ETH)
			21000,                           // gas limit
			big.NewInt(20000000000),         // gas price (20 gwei)
			data,                            // data
		)

		// Sign against the Arbitrum chain ID; result and error are both
		// ignored — only panics fail the fuzz run.
		request := &SigningRequest{
			Transaction:  tx,
			From:         testKeyAddr,
			Purpose:      "fuzz_test",
			ChainID:      big.NewInt(42161),
			UrgencyLevel: 1,
		}
		_, _ = km.SignTransaction(request)
	})
}
// FuzzKeyValidation tests key validation with various encryption keys.
// Validation must never panic, and keys that are short or contain obvious
// test markers must be rejected by validateProductionConfig.
func FuzzKeyValidation(f *testing.F) {
	// Seed with common weak keys
	weakKeys := []string{
		"test123",
		"password",
		"12345678901234567890123456789012",
		"aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa",
		"test_encryption_key_default_config",
	}
	for _, key := range weakKeys {
		f.Add(key)
	}

	f.Fuzz(func(t *testing.T, encryptionKey string) {
		defer func() {
			if r := recover(); r != nil {
				t.Errorf("Panic in key validation: %v\nKey: %q", r, encryptionKey)
			}
		}()

		config := &KeyManagerConfig{
			EncryptionKey: encryptionKey,
			KeystorePath:  "test_keystore",
		}

		// This should not panic, even with invalid keys
		err := validateProductionConfig(config)

		// Keys containing "test"/"default" or shorter than 32 characters are
		// expected to fail validation; anything else is left to the
		// validator's judgement.
		if strings.Contains(strings.ToLower(encryptionKey), "test") ||
			strings.Contains(strings.ToLower(encryptionKey), "default") ||
			len(encryptionKey) < 32 {
			if err == nil {
				t.Errorf("Expected validation error for weak key: %q", encryptionKey)
			}
		}
	})
}
// FuzzInputValidator tests input validation with malicious inputs
func FuzzInputValidator(f *testing.F) {
	validator := NewInputValidator(42161)
	// Seed with various address formats, both well-formed and malformed.
	addresses := []string{
		"0x1234567890123456789012345678901234567890",
		"0x0000000000000000000000000000000000000000",
		"0xffffffffffffffffffffffffffffffffffffffff",
		"0x",
		"",
		"not_an_address",
	}
	for _, addr := range addresses {
		f.Add(addr)
	}
	f.Fuzz(func(t *testing.T, addressStr string) {
		defer func() {
			if r := recover(); r != nil {
				t.Errorf("Panic in address validation: %v\nAddress: %q", r, addressStr)
			}
		}()
		// Embed the fuzzed string in a JSON-RPC result and validate the response;
		// validation errors are expected and ignored — only panics are failures.
		rpcData := []byte(fmt.Sprintf(`{"jsonrpc":"2.0","id":1,"result":"%s"}`, addressStr))
		_ = validator.ValidateRPCResponse(rpcData)
		// If the input starts with a digit, also exercise big.Int parsing.
		// SetString's success flag must be checked: on failure the value is
		// not a parsed amount and must not be inspected as one (the original
		// code ignored the flag and left the sign-check branch empty).
		if len(addressStr) > 0 && addressStr[0] >= '0' && addressStr[0] <= '9' {
			amount := new(big.Int)
			if _, ok := amount.SetString(addressStr, 10); ok && amount.Sign() < 0 {
				// Unreachable for digit-prefixed input; a hit would mean
				// big.Int parsing violated its documented contract.
				t.Errorf("digit-prefixed input parsed as negative amount: %q", addressStr)
			}
		}
	})
}
// TestConcurrentKeyAccess tests concurrent access to key manager
func TestConcurrentKeyAccess(t *testing.T) {
	cfg := &KeyManagerConfig{
		KeystorePath:    "test_concurrent_keystore",
		EncryptionKey:   "concurrent_test_encryption_key_32c",
		SessionTimeout:  time.Hour,
		MaxSigningRate:  1000,
		KeyRotationDays: 30,
	}
	manager, err := newKeyManagerForTesting(cfg, newTestLogger())
	if err != nil {
		t.Fatalf("Failed to create key manager: %v", err)
	}
	signerAddr, err := manager.GenerateKey("concurrent_test", KeyPermissions{
		CanSign:     true,
		CanTransfer: true,
	})
	if err != nil {
		t.Fatalf("Failed to generate test key: %v", err)
	}
	// Hammer the key manager from many goroutines at once; every signing
	// attempt reports its outcome through the buffered results channel.
	const (
		numGoroutines        = 100
		signingsPerGoroutine = 10
	)
	results := make(chan error, numGoroutines*signingsPerGoroutine)
	for worker := 0; worker < numGoroutines; worker++ {
		go func(workerID int) {
			for j := 0; j < signingsPerGoroutine; j++ {
				// Unique nonce and calldata per (worker, iteration) pair.
				tx := types.NewTransaction(
					uint64(workerID*signingsPerGoroutine+j),
					common.HexToAddress("0x1234"),
					big.NewInt(1000000000000000000),
					21000,
					big.NewInt(20000000000),
					[]byte(fmt.Sprintf("worker_%d_tx_%d", workerID, j)),
				)
				_, signErr := manager.SignTransaction(&SigningRequest{
					Transaction:  tx,
					From:         signerAddr,
					Purpose:      fmt.Sprintf("concurrent_test_%d_%d", workerID, j),
					ChainID:      big.NewInt(42161),
					UrgencyLevel: 1,
				})
				results <- signErr
			}
		}(worker)
	}
	// Drain exactly one result per signing attempt; this also acts as the join.
	for i := 0; i < numGoroutines*signingsPerGoroutine; i++ {
		if err := <-results; err != nil {
			t.Errorf("Concurrent signing failed: %v", err)
		}
	}
}
// TestSecurityMetrics tests security metrics collection
func TestSecurityMetrics(t *testing.T) {
	validator := NewInputValidator(42161)
	// Each case runs one validation and states whether it must fail.
	cases := []struct {
		name    string
		run     func() error
		wantErr bool
	}{
		{
			name: "valid_rpc_response",
			run: func() error {
				return validator.ValidateRPCResponse([]byte(`{"jsonrpc":"2.0","id":1,"result":"0x1"}`))
			},
			wantErr: false,
		},
		{
			name: "invalid_rpc_response",
			run: func() error {
				return validator.ValidateRPCResponse([]byte(`invalid json`))
			},
			wantErr: true,
		},
		{
			name: "empty_rpc_response",
			run: func() error {
				return validator.ValidateRPCResponse([]byte{})
			},
			wantErr: true,
		},
		{
			name: "oversized_rpc_response",
			run: func() error {
				// 11MB payload exceeds the validator's 10MB cap.
				return validator.ValidateRPCResponse(make([]byte, 11*1024*1024))
			},
			wantErr: true,
		},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			err := tc.run()
			switch {
			case tc.wantErr && err == nil:
				t.Errorf("Expected error for %s, but got none", tc.name)
			case !tc.wantErr && err != nil:
				t.Errorf("Unexpected error for %s: %v", tc.name, err)
			}
		})
	}
}
// BenchmarkSecurityOperations benchmarks critical security operations
func BenchmarkSecurityOperations(b *testing.B) {
	cfg := &KeyManagerConfig{
		KeystorePath:    "benchmark_keystore",
		EncryptionKey:   "benchmark_encryption_key_32chars",
		SessionTimeout:  time.Hour,
		MaxSigningRate:  10000,
		KeyRotationDays: 30,
	}
	manager, err := newKeyManagerForTesting(cfg, newTestLogger())
	if err != nil {
		b.Fatalf("Failed to create key manager: %v", err)
	}
	signerAddr, err := manager.GenerateKey("benchmark_test", KeyPermissions{
		CanSign:     true,
		CanTransfer: true,
	})
	if err != nil {
		b.Fatalf("Failed to generate test key: %v", err)
	}
	// One fixed transaction is reused for every signing iteration.
	tx := types.NewTransaction(
		0,
		common.HexToAddress("0x1234"),
		big.NewInt(1000000000000000000),
		21000,
		big.NewInt(20000000000),
		[]byte("benchmark_data"),
	)
	b.Run("SignTransaction", func(b *testing.B) {
		for i := 0; i < b.N; i++ {
			req := &SigningRequest{
				Transaction:  tx,
				From:         signerAddr,
				Purpose:      fmt.Sprintf("benchmark_%d", i),
				ChainID:      big.NewInt(42161),
				UrgencyLevel: 1,
			}
			if _, err := manager.SignTransaction(req); err != nil {
				b.Fatalf("Signing failed: %v", err)
			}
		}
	})
	// Validator is constructed outside the sub-benchmark so its setup cost
	// is not attributed to ValidateRPCResponse.
	validator := NewInputValidator(42161)
	b.Run("ValidateRPCResponse", func(b *testing.B) {
		payload := []byte(`{"jsonrpc":"2.0","id":1,"result":"0x1"}`)
		for i := 0; i < b.N; i++ {
			_ = validator.ValidateRPCResponse(payload)
		}
	})
}
// ValidateRPCResponse performs basic sanity checks on a raw JSON-RPC payload:
// it must be non-empty, at most 10MB, syntactically valid JSON, and — when it
// is a JSON object carrying a "jsonrpc" field — declare protocol version 2.0.
func (iv *InputValidator) ValidateRPCResponse(data []byte) error {
	switch {
	case len(data) == 0:
		return fmt.Errorf("empty RPC response")
	case len(data) > 10*1024*1024: // 10MB limit
		return fmt.Errorf("RPC response too large: %d bytes", len(data))
	}
	// The payload must at least parse as JSON.
	var parsed interface{}
	if err := json.Unmarshal(data, &parsed); err != nil {
		return fmt.Errorf("invalid JSON in RPC response: %w", err)
	}
	// Non-object payloads (arrays, scalars) carry no version field to check.
	obj, isObject := parsed.(map[string]interface{})
	if !isObject {
		return nil
	}
	// An absent "jsonrpc" field is tolerated; a present one must be "2.0".
	version, present := obj["jsonrpc"]
	if !present {
		return nil
	}
	if s, ok := version.(string); !ok || s != "2.0" {
		return fmt.Errorf("invalid JSON-RPC version")
	}
	return nil
}

View File

@@ -20,8 +20,8 @@ type TransactionSecurity struct {
// Security thresholds
maxTransactionValue *big.Int
maxGasPrice *big.Int
maxSlippageBps uint64
maxGasPrice *big.Int
maxSlippageBps uint64
// Blacklisted addresses
blacklistedAddresses map[common.Address]bool
@@ -34,41 +34,41 @@ type TransactionSecurity struct {
// TransactionSecurityResult contains the security analysis result
type TransactionSecurityResult struct {
Approved bool `json:"approved"`
RiskLevel string `json:"risk_level"` // LOW, MEDIUM, HIGH, CRITICAL
SecurityChecks map[string]bool `json:"security_checks"`
Warnings []string `json:"warnings"`
Errors []string `json:"errors"`
RecommendedGas *big.Int `json:"recommended_gas,omitempty"`
MaxSlippage uint64 `json:"max_slippage_bps,omitempty"`
EstimatedProfit *big.Int `json:"estimated_profit,omitempty"`
Metadata map[string]interface{} `json:"metadata,omitempty"`
Approved bool `json:"approved"`
RiskLevel string `json:"risk_level"` // LOW, MEDIUM, HIGH, CRITICAL
SecurityChecks map[string]bool `json:"security_checks"`
Warnings []string `json:"warnings"`
Errors []string `json:"errors"`
RecommendedGas *big.Int `json:"recommended_gas,omitempty"`
MaxSlippage uint64 `json:"max_slippage_bps,omitempty"`
EstimatedProfit *big.Int `json:"estimated_profit,omitempty"`
Metadata map[string]interface{} `json:"metadata,omitempty"`
}
// MEVTransactionRequest represents an MEV transaction request
type MEVTransactionRequest struct {
Transaction *types.Transaction `json:"transaction"`
ExpectedProfit *big.Int `json:"expected_profit"`
MaxSlippage uint64 `json:"max_slippage_bps"`
Deadline time.Time `json:"deadline"`
Priority string `json:"priority"` // LOW, MEDIUM, HIGH
Source string `json:"source"` // Origin of the transaction
Transaction *types.Transaction `json:"transaction"`
ExpectedProfit *big.Int `json:"expected_profit"`
MaxSlippage uint64 `json:"max_slippage_bps"`
Deadline time.Time `json:"deadline"`
Priority string `json:"priority"` // LOW, MEDIUM, HIGH
Source string `json:"source"` // Origin of the transaction
}
// NewTransactionSecurity creates a new transaction security checker
func NewTransactionSecurity(client *ethclient.Client, chainID uint64) *TransactionSecurity {
return &TransactionSecurity{
inputValidator: NewInputValidator(chainID),
safeMath: NewSafeMath(),
client: client,
chainID: chainID,
maxTransactionValue: new(big.Int).Mul(big.NewInt(1000), big.NewInt(1e18)), // 1000 ETH
maxGasPrice: new(big.Int).Mul(big.NewInt(10000), big.NewInt(1e9)), // 10000 Gwei
maxSlippageBps: 1000, // 10%
safeMath: NewSafeMath(),
client: client,
chainID: chainID,
maxTransactionValue: new(big.Int).Mul(big.NewInt(1000), big.NewInt(1e18)), // 1000 ETH
maxGasPrice: new(big.Int).Mul(big.NewInt(10000), big.NewInt(1e9)), // 10000 Gwei
maxSlippageBps: 1000, // 10%
blacklistedAddresses: make(map[common.Address]bool),
transactionCounts: make(map[common.Address]int),
lastReset: time.Now(),
maxTxPerAddress: 100, // Max 100 transactions per address per hour
transactionCounts: make(map[common.Address]int),
lastReset: time.Now(),
maxTxPerAddress: 100, // Max 100 transactions per address per hour
}
}
@@ -312,11 +312,13 @@ func (ts *TransactionSecurity) rateLimitingChecks(tx *types.Transaction, result
ts.lastReset = time.Now()
}
// Get sender address (this would require signature recovery in real implementation)
// For now, we'll use the 'to' address as a placeholder
var addr common.Address
if tx.To() != nil {
addr = *tx.To()
// Get sender address via signature recovery
signer := types.LatestSignerForChainID(tx.ChainId())
addr, err := types.Sender(signer, tx)
if err != nil {
// If signature recovery fails, use zero address
// Note: In production, this should be logged to a centralized logging system
addr = common.Address{}
}
// Increment counter
@@ -410,12 +412,12 @@ func (ts *TransactionSecurity) RemoveBlacklistedAddress(addr common.Address) {
// GetSecurityMetrics returns current security metrics
func (ts *TransactionSecurity) GetSecurityMetrics() map[string]interface{} {
return map[string]interface{}{
"blacklisted_addresses_count": len(ts.blacklistedAddresses),
"active_address_count": len(ts.transactionCounts),
"blacklisted_addresses_count": len(ts.blacklistedAddresses),
"active_address_count": len(ts.transactionCounts),
"max_transactions_per_address": ts.maxTxPerAddress,
"max_transaction_value": ts.maxTransactionValue.String(),
"max_gas_price": ts.maxGasPrice.String(),
"max_slippage_bps": ts.maxSlippageBps,
"last_reset": ts.lastReset.Format(time.RFC3339),
"max_transaction_value": ts.maxTransactionValue.String(),
"max_gas_price": ts.maxGasPrice.String(),
"max_slippage_bps": ts.maxSlippageBps,
"last_reset": ts.lastReset.Format(time.RFC3339),
}
}
}