package security import ( "fmt" "math/big" "regexp" "strings" "time" "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/core/types" ) // InputValidator provides comprehensive input validation for all MEV bot operations type InputValidator struct { safeMath *SafeMath maxGasLimit uint64 maxGasPrice *big.Int chainID uint64 } // ValidationResult contains the result of input validation type ValidationResult struct { Valid bool `json:"valid"` Errors []string `json:"errors,omitempty"` Warnings []string `json:"warnings,omitempty"` } // TransactionParams represents transaction parameters for validation type TransactionParams struct { To *common.Address `json:"to"` Value *big.Int `json:"value"` Data []byte `json:"data"` Gas uint64 `json:"gas"` GasPrice *big.Int `json:"gas_price"` Nonce uint64 `json:"nonce"` } // SwapParams represents swap parameters for validation type SwapParams struct { TokenIn common.Address `json:"token_in"` TokenOut common.Address `json:"token_out"` AmountIn *big.Int `json:"amount_in"` AmountOut *big.Int `json:"amount_out"` Slippage uint64 `json:"slippage_bps"` Deadline time.Time `json:"deadline"` Recipient common.Address `json:"recipient"` Pool common.Address `json:"pool"` } // ArbitrageParams represents arbitrage parameters for validation type ArbitrageParams struct { BuyPool common.Address `json:"buy_pool"` SellPool common.Address `json:"sell_pool"` Token common.Address `json:"token"` AmountIn *big.Int `json:"amount_in"` MinProfit *big.Int `json:"min_profit"` MaxGasPrice *big.Int `json:"max_gas_price"` Deadline time.Time `json:"deadline"` } // NewInputValidator creates a new input validator with security limits func NewInputValidator(chainID uint64) *InputValidator { return &InputValidator{ safeMath: NewSafeMath(), maxGasLimit: 15000000, // 15M gas limit maxGasPrice: new(big.Int).Mul(big.NewInt(10000), big.NewInt(1e9)), // 10000 Gwei chainID: chainID, } } // ValidateAddress validates an Ethereum address func (iv *InputValidator) ValidateAddress(addr common.Address) *ValidationResult { result := &ValidationResult{Valid: true} // Check for zero address if addr == (common.Address{}) { result.Valid = false result.Errors = append(result.Errors, "address cannot be zero address") return result } // Check for known malicious addresses (extend this list as needed) maliciousAddresses := []common.Address{ // Add known malicious addresses here common.HexToAddress("0x0000000000000000000000000000000000000000"), } for _, malicious := range maliciousAddresses { if addr == malicious { result.Valid = false result.Errors = append(result.Errors, "address is flagged as malicious") return result } } // Check for suspicious patterns addrStr := addr.Hex() if strings.Contains(strings.ToLower(addrStr), "dead") || strings.Contains(strings.ToLower(addrStr), "beef") { result.Warnings = append(result.Warnings, "address contains suspicious patterns") } return result } // ValidateTransaction validates a complete transaction func (iv *InputValidator) ValidateTransaction(tx *types.Transaction) *ValidationResult { result := &ValidationResult{Valid: true} // Validate chain ID if tx.ChainId().Uint64() != iv.chainID { result.Valid = false result.Errors = append(result.Errors, fmt.Sprintf("invalid chain ID: expected %d, got %d", iv.chainID, tx.ChainId().Uint64())) } // Validate gas limit if tx.Gas() > iv.maxGasLimit { result.Valid = false result.Errors = append(result.Errors, fmt.Sprintf("gas limit %d exceeds maximum %d", tx.Gas(), iv.maxGasLimit)) } if tx.Gas() < 21000 { result.Valid = false result.Errors = append(result.Errors, "gas limit below minimum 21000") } // Validate gas price if tx.GasPrice() != nil { if err := iv.safeMath.ValidateGasPrice(tx.GasPrice()); err != nil { result.Valid = false result.Errors = append(result.Errors, fmt.Sprintf("invalid gas price: %v", err)) } } // Validate transaction value if tx.Value() != nil { if err := iv.safeMath.ValidateTransactionValue(tx.Value()); err != nil { result.Valid = false result.Errors = append(result.Errors, fmt.Sprintf("invalid transaction value: %v", err)) } } // Validate recipient address if tx.To() != nil { addrResult := iv.ValidateAddress(*tx.To()) if !addrResult.Valid { result.Valid = false result.Errors = append(result.Errors, "invalid recipient address") result.Errors = append(result.Errors, addrResult.Errors...) } result.Warnings = append(result.Warnings, addrResult.Warnings...) } // Validate transaction data for suspicious patterns if len(tx.Data()) > 0 { dataResult := iv.validateTransactionData(tx.Data()) if !dataResult.Valid { result.Valid = false result.Errors = append(result.Errors, dataResult.Errors...) } result.Warnings = append(result.Warnings, dataResult.Warnings...) } return result } // ValidateSwapParams validates swap parameters func (iv *InputValidator) ValidateSwapParams(params *SwapParams) *ValidationResult { result := &ValidationResult{Valid: true} // Validate addresses for _, addr := range []common.Address{params.TokenIn, params.TokenOut, params.Recipient, params.Pool} { addrResult := iv.ValidateAddress(addr) if !addrResult.Valid { result.Valid = false result.Errors = append(result.Errors, addrResult.Errors...) } result.Warnings = append(result.Warnings, addrResult.Warnings...) } // Validate tokens are different if params.TokenIn == params.TokenOut { result.Valid = false result.Errors = append(result.Errors, "token in and token out cannot be the same") } // Validate amounts if params.AmountIn == nil || params.AmountIn.Sign() <= 0 { result.Valid = false result.Errors = append(result.Errors, "amount in must be positive") } if params.AmountOut == nil || params.AmountOut.Sign() <= 0 { result.Valid = false result.Errors = append(result.Errors, "amount out must be positive") } // Validate slippage if params.Slippage > 10000 { // Max 100% result.Valid = false result.Errors = append(result.Errors, "slippage cannot exceed 100%") } if params.Slippage > 500 { // Warn if > 5% result.Warnings = append(result.Warnings, "slippage above 5% detected") } // Validate deadline if params.Deadline.Before(time.Now()) { result.Valid = false result.Errors = append(result.Errors, "deadline is in the past") } if params.Deadline.After(time.Now().Add(1 * time.Hour)) { result.Warnings = append(result.Warnings, "deadline is more than 1 hour in the future") } return result } // ValidateArbitrageParams validates arbitrage parameters func (iv *InputValidator) ValidateArbitrageParams(params *ArbitrageParams) *ValidationResult { result := &ValidationResult{Valid: true} // Validate addresses for _, addr := range []common.Address{params.BuyPool, params.SellPool, params.Token} { addrResult := iv.ValidateAddress(addr) if !addrResult.Valid { result.Valid = false result.Errors = append(result.Errors, addrResult.Errors...) } result.Warnings = append(result.Warnings, addrResult.Warnings...) } // Validate pools are different if params.BuyPool == params.SellPool { result.Valid = false result.Errors = append(result.Errors, "buy pool and sell pool cannot be the same") } // Validate amounts if params.AmountIn == nil || params.AmountIn.Sign() <= 0 { result.Valid = false result.Errors = append(result.Errors, "amount in must be positive") } if params.MinProfit == nil || params.MinProfit.Sign() <= 0 { result.Valid = false result.Errors = append(result.Errors, "minimum profit must be positive") } // Validate gas price if params.MaxGasPrice != nil { if err := iv.safeMath.ValidateGasPrice(params.MaxGasPrice); err != nil { result.Valid = false result.Errors = append(result.Errors, fmt.Sprintf("invalid max gas price: %v", err)) } } // Validate deadline if params.Deadline.Before(time.Now()) { result.Valid = false result.Errors = append(result.Errors, "deadline is in the past") } // Check if arbitrage is potentially profitable if params.AmountIn != nil && params.MinProfit != nil { // Rough profitability check (at least 0.1% profit) minProfitThreshold, _ := iv.safeMath.SafePercent(params.AmountIn, 10) // 0.1% if params.MinProfit.Cmp(minProfitThreshold) < 0 { result.Warnings = append(result.Warnings, "minimum profit threshold is very low") } } return result } // validateTransactionData validates transaction data for suspicious patterns func (iv *InputValidator) validateTransactionData(data []byte) *ValidationResult { result := &ValidationResult{Valid: true} // Check data size if len(data) > 100000 { // 100KB limit result.Valid = false result.Errors = append(result.Errors, "transaction data exceeds size limit") return result } // Convert to hex string for pattern matching dataHex := common.Bytes2Hex(data) // Check for suspicious patterns suspiciousPatterns := []struct { pattern string message string critical bool }{ {"selfdestruct", "contains selfdestruct operation", true}, {"delegatecall", "contains delegatecall operation", false}, {"create2", "contains create2 operation", false}, {"ff" + strings.Repeat("00", 19), "contains potential burn address", false}, } for _, suspicious := range suspiciousPatterns { if strings.Contains(strings.ToLower(dataHex), strings.ToLower(suspicious.pattern)) { if suspicious.critical { result.Valid = false result.Errors = append(result.Errors, "transaction "+suspicious.message) } else { result.Warnings = append(result.Warnings, "transaction "+suspicious.message) } } } // Check for known function selectors of risky operations if len(data) >= 4 { selector := common.Bytes2Hex(data[:4]) riskySelectors := map[string]string{ "ff6cae96": "selfdestruct function", "9dc29fac": "burn function", "42966c68": "burn function (alternative)", } if message, exists := riskySelectors[selector]; exists { result.Warnings = append(result.Warnings, "transaction calls "+message) } } return result } // ValidateString validates string inputs for injection attacks func (iv *InputValidator) ValidateString(input, fieldName string, maxLength int) *ValidationResult { result := &ValidationResult{Valid: true} // Check length if len(input) > maxLength { result.Valid = false result.Errors = append(result.Errors, fmt.Sprintf("%s exceeds maximum length %d", fieldName, maxLength)) } // Check for null bytes if strings.Contains(input, "\x00") { result.Valid = false result.Errors = append(result.Errors, fmt.Sprintf("%s contains null bytes", fieldName)) } // Check for control characters controlCharPattern := regexp.MustCompile(`[\x00-\x1f\x7f-\x9f]`) if controlCharPattern.MatchString(input) { result.Valid = false result.Errors = append(result.Errors, fmt.Sprintf("%s contains control characters", fieldName)) } // Check for SQL injection patterns sqlPatterns := []string{ "'", "\"", "--", "/*", "*/", "xp_", "sp_", "exec", "execute", "select", "insert", "update", "delete", "drop", "create", "alter", "union", "join", "script", "javascript", } lowerInput := strings.ToLower(input) for _, pattern := range sqlPatterns { if strings.Contains(lowerInput, pattern) { result.Warnings = append(result.Warnings, fmt.Sprintf("%s contains potentially dangerous pattern: %s", fieldName, pattern)) } } return result } // ValidateNumericString validates numeric string inputs func (iv *InputValidator) ValidateNumericString(input, fieldName string) *ValidationResult { result := &ValidationResult{Valid: true} // Check if string is numeric numericPattern := regexp.MustCompile(`^[0-9]+(\.[0-9]+)?$`) if !numericPattern.MatchString(input) { result.Valid = false result.Errors = append(result.Errors, fmt.Sprintf("%s must be numeric", fieldName)) return result } // Check for leading zeros (except for decimals) if len(input) > 1 && input[0] == '0' && input[1] != '.' { result.Warnings = append(result.Warnings, fmt.Sprintf("%s has leading zeros", fieldName)) } // Check for reasonable decimal places if strings.Contains(input, ".") { parts := strings.Split(input, ".") if len(parts[1]) > 18 { result.Warnings = append(result.Warnings, fmt.Sprintf("%s has excessive decimal places", fieldName)) } } return result } // ValidateBatchSize validates batch operation sizes func (iv *InputValidator) ValidateBatchSize(size int, operation string) *ValidationResult { result := &ValidationResult{Valid: true} maxBatchSizes := map[string]int{ "transaction": 100, "swap": 50, "arbitrage": 20, "query": 1000, } maxSize, exists := maxBatchSizes[operation] if !exists { maxSize = 50 // Default } if size <= 0 { result.Valid = false result.Errors = append(result.Errors, "batch size must be positive") } if size > maxSize { result.Valid = false result.Errors = append(result.Errors, fmt.Sprintf("batch size %d exceeds maximum %d for %s operations", size, maxSize, operation)) } if size > maxSize/2 { result.Warnings = append(result.Warnings, fmt.Sprintf("large batch size %d for %s operations", size, operation)) } return result } // SanitizeInput sanitizes string input by removing dangerous characters func (iv *InputValidator) SanitizeInput(input string) string { // Remove null bytes input = strings.ReplaceAll(input, "\x00", "") // Remove control characters except newline and tab controlCharPattern := regexp.MustCompile(`[\x00-\x08\x0b\x0c\x0e-\x1f\x7f-\x9f]`) input = controlCharPattern.ReplaceAllString(input, "") // Trim whitespace input = strings.TrimSpace(input) return input } // ValidateExternalData performs comprehensive validation for data from external sources func (iv *InputValidator) ValidateExternalData(data []byte, source string, maxSize int) *ValidationResult { result := &ValidationResult{Valid: true} // Comprehensive bounds checking if data == nil { result.Valid = false result.Errors = append(result.Errors, "external data cannot be nil") return result } // Check size limits if len(data) > maxSize { result.Valid = false result.Errors = append(result.Errors, fmt.Sprintf("external data size %d exceeds maximum %d for source %s", len(data), maxSize, source)) return result } // Check for obviously malformed data patterns if len(data) > 0 { // Check for all-zero data (suspicious) allZero := true for _, b := range data { if b != 0 { allZero = false break } } if allZero && len(data) > 32 { result.Warnings = append(result.Warnings, "external data appears to be all zeros") } // Check for repetitive patterns that might indicate malformed data if len(data) >= 4 { pattern := data[:4] repetitive := true for i := 4; i < len(data) && i < 1000; i += 4 { // Check first 1KB for performance if i+4 <= len(data) { for j := 0; j < 4; j++ { if data[i+j] != pattern[j] { repetitive = false break } } if !repetitive { break } } } if repetitive && len(data) > 64 { result.Warnings = append(result.Warnings, "external data contains highly repetitive patterns") } } } return result } // ValidateArrayBounds validates array access bounds to prevent buffer overflows func (iv *InputValidator) ValidateArrayBounds(arrayLen, index int, operation string) *ValidationResult { result := &ValidationResult{Valid: true} if arrayLen < 0 { result.Valid = false result.Errors = append(result.Errors, fmt.Sprintf("negative array length %d in operation %s", arrayLen, operation)) return result } if index < 0 { result.Valid = false result.Errors = append(result.Errors, fmt.Sprintf("negative array index %d in operation %s", index, operation)) return result } if index >= arrayLen { result.Valid = false result.Errors = append(result.Errors, fmt.Sprintf("array index %d exceeds length %d in operation %s", index, arrayLen, operation)) return result } // Maximum reasonable array size (prevent DoS) const maxArraySize = 100000 if arrayLen > maxArraySize { result.Valid = false result.Errors = append(result.Errors, fmt.Sprintf("array length %d exceeds maximum %d in operation %s", arrayLen, maxArraySize, operation)) return result } return result } // ValidateBufferAccess validates buffer access operations func (iv *InputValidator) ValidateBufferAccess(bufferSize, offset, length int, operation string) *ValidationResult { result := &ValidationResult{Valid: true} if bufferSize < 0 { result.Valid = false result.Errors = append(result.Errors, fmt.Sprintf("negative buffer size %d in operation %s", bufferSize, operation)) return result } if offset < 0 { result.Valid = false result.Errors = append(result.Errors, fmt.Sprintf("negative buffer offset %d in operation %s", offset, operation)) return result } if length < 0 { result.Valid = false result.Errors = append(result.Errors, fmt.Sprintf("negative buffer length %d in operation %s", length, operation)) return result } if offset+length > bufferSize { result.Valid = false result.Errors = append(result.Errors, fmt.Sprintf("buffer access [%d:%d] exceeds buffer size %d in operation %s", offset, offset+length, bufferSize, operation)) return result } // Check for integer overflow in offset+length calculation if offset > 0 && length > 0 { // Use uint64 to detect overflow sum := uint64(offset) + uint64(length) if sum > uint64(^uint(0)>>1) { // Max int value result.Valid = false result.Errors = append(result.Errors, fmt.Sprintf("integer overflow in buffer access calculation: offset %d + length %d in operation %s", offset, length, operation)) return result } } return result } // ValidateMemoryAllocation validates memory allocation requests func (iv *InputValidator) ValidateMemoryAllocation(size int, purpose string) *ValidationResult { result := &ValidationResult{Valid: true} if size < 0 { result.Valid = false result.Errors = append(result.Errors, fmt.Sprintf("negative memory allocation size %d for purpose %s", size, purpose)) return result } if size == 0 { result.Warnings = append(result.Warnings, fmt.Sprintf("zero memory allocation for purpose %s", purpose)) return result } // Set reasonable limits based on purpose limits := map[string]int{ "transaction_data": 1024 * 1024, // 1MB "abi_decoding": 512 * 1024, // 512KB "log_message": 64 * 1024, // 64KB "swap_params": 4 * 1024, // 4KB "address_list": 100 * 1024, // 100KB "default": 256 * 1024, // 256KB } limit, exists := limits[purpose] if !exists { limit = limits["default"] } if size > limit { result.Valid = false result.Errors = append(result.Errors, fmt.Sprintf("memory allocation size %d exceeds limit %d for purpose %s", size, limit, purpose)) return result } // Warn for large allocations if size > limit/2 { result.Warnings = append(result.Warnings, fmt.Sprintf("large memory allocation %d for purpose %s", size, purpose)) } return result }