package validation

import (
	"fmt"
	"math/big"
	"regexp"
	"strings"
	"time"

	"github.com/ethereum/go-ethereum/common"
	"github.com/ethereum/go-ethereum/core/types"
	"github.com/fraktal/mev-beta/internal/logger"
)

// defaultMaxDeadlineHours is the deadline window used when a ValidationConfig
// does not specify a positive MaxDeadlineHours.
const defaultMaxDeadlineHours = 24

// Shared read-only values. None of these *big.Int values may be mutated; they
// are only used as comparison operands or as factors for fresh results.
var (
	weiPerGwei  = big.NewInt(1e9)
	weiPerEther = big.NewInt(1e18)

	// maxAmount caps amounts accepted by ValidateAmount (10^30 wei) to reject
	// absurd values early.
	maxAmount = new(big.Int).Exp(big.NewInt(10), big.NewInt(30), nil)

	// maxSlippageBps is the largest tolerance ValidateSlippage accepts:
	// 5000 basis points = 50%.
	maxSlippageBps = big.NewInt(5000)

	// sandwichSlippageBps (500 bps = 5%) is the tolerance above which swaps
	// get a sandwich-attack warning.
	sandwichSlippageBps = big.NewInt(500)

	// highGasPrice (100 Gwei) triggers a warning in validateGas.
	highGasPrice = new(big.Int).Mul(big.NewInt(100), weiPerGwei)

	// largeValue (10 ETH) triggers a "large value transfer" warning.
	largeValue = new(big.Int).Mul(big.NewInt(10), weiPerEther)

	// largeTradeThreshold is 100 * 10^18 base units, i.e. "100 tokens" only
	// for 18-decimal tokens.
	largeTradeThreshold = new(big.Int).Mul(big.NewInt(100), weiPerEther)

	// validFeeTiers lists accepted Uniswap V3 fee tiers:
	// 0.01%, 0.05%, 0.3%, 1%.
	validFeeTiers = []uint32{100, 500, 3000, 10000}

	// maliciousAddresses is a placeholder blacklist; it would be populated
	// from a real source in production. Treat as read-only.
	maliciousAddresses = map[common.Address]bool{
		// Add known malicious addresses here
	}

	// sanitizePattern matches every character SanitizeInput strips.
	sanitizePattern = regexp.MustCompile(`[^\w\s\-\.]`)

	// hexDigitsPattern matches a (possibly empty) run of hex digits.
	hexDigitsPattern = regexp.MustCompile(`^[0-9a-fA-F]*$`)
)

// InputValidator provides comprehensive validation for transaction parameters and user inputs
type InputValidator struct {
	logger      *logger.Logger
	maxGasLimit uint64
	maxGasPrice *big.Int // wei
	maxValue    *big.Int // wei
	// maxDeadline bounds how far in the future a deadline may lie;
	// a non-positive value falls back to defaultMaxDeadlineHours.
	maxDeadline    time.Duration
	allowedMethods map[string]bool // lower-case selectors ("0x" + 8 hex digits); empty means all allowed

	// Regex patterns for validation
	addressPattern   *regexp.Regexp
	txHashPattern    *regexp.Regexp
	blockHashPattern *regexp.Regexp
	hexDataPattern   *regexp.Regexp
}

// ValidationConfig contains configuration for input validation
type ValidationConfig struct {
	MaxGasLimit     uint64 `json:"max_gas_limit"`
	MaxGasPriceGwei int64  `json:"max_gas_price_gwei"`
	MaxValueEther   int64  `json:"max_value_ether"`
	// AllowedMethods holds 4-byte selectors such as "0xa9059cbb"; they are
	// matched case-insensitively. Empty means every method is allowed.
	AllowedMethods []string `json:"allowed_methods"`
	// RequireDeadline is currently not consulted: deadline validation always
	// rejects a zero deadline.
	RequireDeadline bool `json:"require_deadline"`
	// MaxDeadlineHours bounds how far in the future a deadline may be;
	// values <= 0 fall back to 24 hours.
	MaxDeadlineHours int `json:"max_deadline_hours"`
}

// TransactionValidationResult contains the result of transaction validation
type TransactionValidationResult struct {
	IsValid       bool     `json:"is_valid"`
	Errors        []string `json:"errors"`
	Warnings      []string `json:"warnings"`
	RiskLevel     string   `json:"risk_level"` // "low", "medium", "high", "critical"
	EstimatedCost *big.Int `json:"estimated_cost,omitempty"`
}

// SwapParams represents swap transaction parameters
type SwapParams struct {
	TokenIn           common.Address `json:"token_in"`
	TokenOut          common.Address `json:"token_out"`
	AmountIn          *big.Int       `json:"amount_in"`
	AmountOutMinimum  *big.Int       `json:"amount_out_minimum"`
	Fee               uint32         `json:"fee"`
	Recipient         common.Address `json:"recipient"`
	Deadline          uint64         `json:"deadline"`
	SlippageTolerance *big.Int       `json:"slippage_tolerance"` // in basis points
}

// ArbitrageParams represents arbitrage transaction parameters
type ArbitrageParams struct {
	Path            []common.Address `json:"path"`
	AmountIn        *big.Int         `json:"amount_in"`
	MinAmountOut    *big.Int         `json:"min_amount_out"`
	Deadline        uint64           `json:"deadline"`
	MaxGasPrice     *big.Int         `json:"max_gas_price"`
	ProfitThreshold *big.Int         `json:"profit_threshold"`
	MaxSlippageBps  *big.Int         `json:"max_slippage_bps"`
}

// LiquidityParams represents liquidity provision parameters
type LiquidityParams struct {
	Token0         common.Address `json:"token0"`
	Token1         common.Address `json:"token1"`
	Fee            uint32         `json:"fee"`
	TickLower      int32          `json:"tick_lower"`
	TickUpper      int32          `json:"tick_upper"`
	Amount0Desired *big.Int       `json:"amount0_desired"`
	Amount1Desired *big.Int       `json:"amount1_desired"`
	Amount0Min     *big.Int       `json:"amount0_min"`
	Amount1Min     *big.Int       `json:"amount1_min"`
	Recipient      common.Address `json:"recipient"`
	Deadline       uint64         `json:"deadline"`
}

// NewInputValidator creates a new input validator. A nil config selects
// getDefaultValidationConfig().
func NewInputValidator(config *ValidationConfig, logger *logger.Logger) *InputValidator {
	if config == nil {
		config = getDefaultValidationConfig()
	}

	validator := &InputValidator{
		logger:      logger,
		maxGasLimit: config.MaxGasLimit,
		// Convert in big.Int space: an int64 product such as
		// MaxValueEther*1e18 overflows above ~9 ETH (the 1000 ETH default
		// used to wrap to roughly 3.88 ETH).
		maxGasPrice:      new(big.Int).Mul(big.NewInt(config.MaxGasPriceGwei), weiPerGwei),
		maxValue:         new(big.Int).Mul(big.NewInt(config.MaxValueEther), weiPerEther),
		maxDeadline:      deadlineWindow(config.MaxDeadlineHours),
		allowedMethods:   make(map[string]bool, len(config.AllowedMethods)),
		addressPattern:   regexp.MustCompile(`^0x[a-fA-F0-9]{40}$`),
		txHashPattern:    regexp.MustCompile(`^0x[a-fA-F0-9]{64}$`),
		blockHashPattern: regexp.MustCompile(`^0x[a-fA-F0-9]{64}$`),
		hexDataPattern:   regexp.MustCompile(`^0x[a-fA-F0-9]*$`),
	}

	// validateData formats selectors with %x (lower case), so normalize the
	// configured entries to match regardless of how they were written.
	for _, method := range config.AllowedMethods {
		validator.allowedMethods[strings.ToLower(method)] = true
	}

	return validator
}

// deadlineWindow converts a configured hour count into a duration, falling
// back to defaultMaxDeadlineHours when hours is not positive.
func deadlineWindow(hours int) time.Duration {
	if hours <= 0 {
		hours = defaultMaxDeadlineHours
	}
	return time.Duration(hours) * time.Hour
}

// newValidationResult returns an initially valid result at the given risk
// level, with empty (non-nil) Errors and Warnings so they encode as [] in JSON.
func newValidationResult(risk string) *TransactionValidationResult {
	return &TransactionValidationResult{
		IsValid:   true,
		Errors:    make([]string, 0),
		Warnings:  make([]string, 0),
		RiskLevel: risk,
	}
}

// riskRank orders the risk levels used in TransactionValidationResult;
// unknown levels rank lowest.
func riskRank(level string) int {
	switch level {
	case "low":
		return 0
	case "medium":
		return 1
	case "high":
		return 2
	case "critical":
		return 3
	}
	return -1
}

// escalateRisk raises result.RiskLevel to level but never lowers it, so the
// order in which checks run cannot downgrade an earlier finding.
func escalateRisk(result *TransactionValidationResult, level string) {
	if riskRank(level) > riskRank(result.RiskLevel) {
		result.RiskLevel = level
	}
}

// ValidateTransaction performs comprehensive validation of a transaction.
//
// Every error and warning found is recorded in the result; IsValid is false
// when any error was recorded. A non-nil error is returned only when tx is nil.
func (iv *InputValidator) ValidateTransaction(tx *types.Transaction) (*TransactionValidationResult, error) {
	if tx == nil {
		return nil, fmt.Errorf("transaction cannot be nil")
	}

	result := newValidationResult("low")

	// 1. Basic transaction validation
	iv.validateBasicTransaction(tx, result)
	// 2. Gas validation
	iv.validateGas(tx, result)
	// 3. Value validation
	iv.validateValue(tx, result)
	// 4. Recipient validation
	iv.validateRecipient(tx, result)
	// 5. Data validation (for contract calls)
	if data := tx.Data(); len(data) > 0 {
		iv.validateData(data, result)
	}
	// 6. Calculate estimated cost
	result.EstimatedCost = iv.calculateEstimatedCost(tx)
	// 7. Determine final validity and risk level
	iv.finalizeValidation(result)

	if len(result.Errors) > 0 && iv.logger != nil {
		iv.logger.Warn(fmt.Sprintf("Transaction validation failed: %v", result.Errors))
	}

	return result, nil
}

// ValidateSwapParams validates swap transaction parameters. A non-nil error
// is returned only when params is nil; validation findings go in the result.
func (iv *InputValidator) ValidateSwapParams(params *SwapParams) (*TransactionValidationResult, error) {
	if params == nil {
		return nil, fmt.Errorf("swap params cannot be nil")
	}

	result := newValidationResult("low")

	// 1. Validate token addresses
	if err := iv.ValidateAddress(params.TokenIn); err != nil {
		result.Errors = append(result.Errors, fmt.Sprintf("Invalid token_in address: %v", err))
	}
	if err := iv.ValidateAddress(params.TokenOut); err != nil {
		result.Errors = append(result.Errors, fmt.Sprintf("Invalid token_out address: %v", err))
	}

	// 2. Check tokens are different
	if params.TokenIn == params.TokenOut {
		result.Errors = append(result.Errors, "token_in and token_out must be different")
	}

	// 3. Validate amounts
	if err := iv.ValidateAmount(params.AmountIn); err != nil {
		result.Errors = append(result.Errors, fmt.Sprintf("Invalid amount_in: %v", err))
	}
	if err := iv.ValidateAmount(params.AmountOutMinimum); err != nil {
		result.Errors = append(result.Errors, fmt.Sprintf("Invalid amount_out_minimum: %v", err))
	}

	// 4. Validate slippage tolerance
	if err := iv.ValidateSlippage(params.SlippageTolerance); err != nil {
		result.Errors = append(result.Errors, fmt.Sprintf("Invalid slippage tolerance: %v", err))
	}

	// 5. Validate fee tier
	if err := iv.validateFeeTier(params.Fee); err != nil {
		result.Errors = append(result.Errors, fmt.Sprintf("Invalid fee tier: %v", err))
	}

	// 6. Validate recipient
	if err := iv.ValidateAddress(params.Recipient); err != nil {
		result.Errors = append(result.Errors, fmt.Sprintf("Invalid recipient address: %v", err))
	}

	// 7. Validate deadline
	if err := iv.validateDeadline(params.Deadline); err != nil {
		result.Errors = append(result.Errors, fmt.Sprintf("Invalid deadline: %v", err))
	}

	// 8. Additional security checks
	iv.performSwapSecurityChecks(params, result)

	iv.finalizeValidation(result)
	return result, nil
}

// ValidateArbitrageParams validates arbitrage transaction parameters. A
// non-nil error is returned only when params is nil; validation findings go
// in the result, which starts at "medium" risk.
func (iv *InputValidator) ValidateArbitrageParams(params *ArbitrageParams) (*TransactionValidationResult, error) {
	if params == nil {
		return nil, fmt.Errorf("arbitrage params cannot be nil")
	}

	result := newValidationResult("medium") // Arbitrage is inherently riskier

	// 1. Validate path
	if len(params.Path) < 2 {
		result.Errors = append(result.Errors, "arbitrage path must have at least 2 tokens")
	}
	if len(params.Path) > 5 {
		result.Warnings = append(result.Warnings, "long arbitrage paths increase gas costs and slippage")
	}
	for i, addr := range params.Path {
		if err := iv.ValidateAddress(addr); err != nil {
			result.Errors = append(result.Errors, fmt.Sprintf("Invalid address at path[%d]: %v", i, err))
		}
	}

	// 2. Check for duplicate tokens in path.
	// NOTE(review): this rejects cyclic paths (A -> B -> A), yet step 4
	// compares MinAmountOut with AmountIn, which is only meaningful when the
	// path starts and ends in the same token. Confirm the intended path shape.
	seen := make(map[common.Address]bool, len(params.Path))
	for _, addr := range params.Path {
		if seen[addr] {
			result.Errors = append(result.Errors, fmt.Sprintf("Duplicate token in path: %s", addr.Hex()))
		}
		seen[addr] = true
	}

	// 3. Validate amounts
	if err := iv.ValidateAmount(params.AmountIn); err != nil {
		result.Errors = append(result.Errors, fmt.Sprintf("Invalid amount_in: %v", err))
	}
	if err := iv.ValidateAmount(params.MinAmountOut); err != nil {
		result.Errors = append(result.Errors, fmt.Sprintf("Invalid min_amount_out: %v", err))
	}
	if err := iv.ValidateAmount(params.ProfitThreshold); err != nil {
		result.Errors = append(result.Errors, fmt.Sprintf("Invalid profit_threshold: %v", err))
	}

	// 4. Validate profit expectation. Nil amounts were already reported in
	// step 3; calling Cmp on them would panic.
	if params.AmountIn != nil && params.MinAmountOut != nil && params.MinAmountOut.Cmp(params.AmountIn) <= 0 {
		result.Errors = append(result.Errors, "min_amount_out must be greater than amount_in for profitable arbitrage")
	}

	// 5. Validate slippage
	if err := iv.ValidateSlippage(params.MaxSlippageBps); err != nil {
		result.Errors = append(result.Errors, fmt.Sprintf("Invalid max_slippage: %v", err))
	}

	// 6. Validate gas price
	if params.MaxGasPrice != nil && params.MaxGasPrice.Cmp(iv.maxGasPrice) > 0 {
		result.Warnings = append(result.Warnings, "very high gas price may eat into profits")
	}

	// 7. Validate deadline
	if err := iv.validateDeadline(params.Deadline); err != nil {
		result.Errors = append(result.Errors, fmt.Sprintf("Invalid deadline: %v", err))
	}

	iv.finalizeValidation(result)
	return result, nil
}

// ValidateAddress validates an Ethereum address: it must be non-zero,
// well-formed and not on the blacklist.
func (iv *InputValidator) ValidateAddress(addr common.Address) error {
	if addr == (common.Address{}) {
		return fmt.Errorf("address cannot be zero")
	}

	// Check format using regex
	if !iv.addressPattern.MatchString(addr.Hex()) {
		return fmt.Errorf("invalid address format")
	}

	// Check for common invalid addresses
	if iv.isKnownInvalidAddress(addr) {
		return fmt.Errorf("address is known to be invalid or malicious")
	}

	return nil
}

// ValidateAmount validates a big.Int amount: it must be non-nil, strictly
// positive and at most 10^30.
func (iv *InputValidator) ValidateAmount(amount *big.Int) error {
	if amount == nil {
		return fmt.Errorf("amount cannot be nil")
	}
	switch amount.Sign() {
	case -1:
		return fmt.Errorf("amount cannot be negative")
	case 0:
		return fmt.Errorf("amount cannot be zero")
	}

	// Check for unreasonably large amounts (prevent overflow attacks)
	if amount.Cmp(maxAmount) > 0 {
		return fmt.Errorf("amount exceeds maximum allowed value")
	}

	return nil
}

// ValidateSlippage validates slippage tolerance in basis points: it must be
// non-nil, non-negative and at most 5000 (50%).
func (iv *InputValidator) ValidateSlippage(slippageBps *big.Int) error {
	if slippageBps == nil {
		return fmt.Errorf("slippage cannot be nil")
	}
	if slippageBps.Sign() < 0 {
		return fmt.Errorf("slippage cannot be negative")
	}
	if slippageBps.Cmp(maxSlippageBps) > 0 {
		return fmt.Errorf("slippage tolerance cannot exceed 50%%")
	}
	return nil
}

// validateBasicTransaction validates basic transaction properties
func (iv *InputValidator) validateBasicTransaction(tx *types.Transaction, result *TransactionValidationResult) {
	// Check nonce
	if tx.Nonce() > 1000000 {
		result.Warnings = append(result.Warnings, "unusually high nonce")
	}

	// Check transaction size: calldata plus a rough fixed overhead,
	// against a 128KB limit.
	txSize := len(tx.Data()) + 200
	if txSize > 128*1024 {
		result.Errors = append(result.Errors, "transaction size exceeds limit")
	}
}

// validateGas validates the gas limit, legacy gas price and EIP-1559 fee/tip
// caps against the configured maxima.
func (iv *InputValidator) validateGas(tx *types.Transaction, result *TransactionValidationResult) {
	// Validate gas limit
	gas := tx.Gas()
	if gas == 0 {
		result.Errors = append(result.Errors, "gas limit cannot be zero")
	}
	if gas > iv.maxGasLimit {
		result.Errors = append(result.Errors, fmt.Sprintf("gas limit %d exceeds maximum %d", gas, iv.maxGasLimit))
	}

	// NOTE(review): in go-ethereum these accessors appear to return non-nil
	// copies, and for legacy transactions GasPrice, GasFeeCap and GasTipCap
	// report the same value, so a bad legacy gas price is likely reported by
	// both the gas-price and fee-cap checks below. Confirm whether that is
	// intended.

	// Validate gas price
	if gasPrice := tx.GasPrice(); gasPrice != nil {
		if gasPrice.Sign() == 0 {
			result.Errors = append(result.Errors, "gas price cannot be zero")
		}
		if gasPrice.Cmp(iv.maxGasPrice) > 0 {
			result.Errors = append(result.Errors, "gas price exceeds maximum")
		}
		// Warn about very high gas prices (> 100 Gwei)
		if gasPrice.Cmp(highGasPrice) > 0 {
			result.Warnings = append(result.Warnings, "very high gas price")
			escalateRisk(result, "medium")
		}
	}

	// Validate gas fee cap and tip for EIP-1559 transactions
	feeCap := tx.GasFeeCap()
	if feeCap != nil {
		if feeCap.Sign() == 0 {
			result.Errors = append(result.Errors, "gas fee cap cannot be zero")
		}
		if feeCap.Cmp(iv.maxGasPrice) > 0 {
			result.Errors = append(result.Errors, "gas fee cap exceeds maximum")
		}
	}

	if tipCap := tx.GasTipCap(); tipCap != nil {
		if tipCap.Sign() < 0 {
			result.Errors = append(result.Errors, "gas tip cap cannot be negative")
		}
		if feeCap != nil && tipCap.Cmp(feeCap) > 0 {
			result.Errors = append(result.Errors, "gas tip cap cannot exceed gas fee cap")
		}
	}
}

// validateValue validates the transaction value against the configured
// maximum and warns on transfers above 10 ETH.
func (iv *InputValidator) validateValue(tx *types.Transaction, result *TransactionValidationResult) {
	value := tx.Value()
	if value == nil {
		return
	}

	if value.Sign() < 0 {
		result.Errors = append(result.Errors, "transaction value cannot be negative")
	}
	if value.Cmp(iv.maxValue) > 0 {
		result.Errors = append(result.Errors, "transaction value exceeds maximum allowed")
	}

	// Warn about large value transfers
	if value.Cmp(largeValue) > 0 {
		result.Warnings = append(result.Warnings, "large value transfer")
		escalateRisk(result, "high")
	}
}

// validateRecipient validates the transaction recipient. A nil recipient
// (contract creation) is only a warning.
func (iv *InputValidator) validateRecipient(tx *types.Transaction, result *TransactionValidationResult) {
	to := tx.To()
	if to == nil {
		// Contract creation transaction
		result.Warnings = append(result.Warnings, "contract creation transaction")
		escalateRisk(result, "high")
		return
	}

	// Check for zero address
	if *to == (common.Address{}) {
		result.Errors = append(result.Errors, "recipient cannot be zero address")
	}

	// Check for known malicious addresses
	if iv.isKnownInvalidAddress(*to) {
		result.Errors = append(result.Errors, "recipient is known malicious address")
	}
}

// validateData validates transaction data for contract calls: it requires a
// full 4-byte selector and enforces the allowed-method list when configured.
func (iv *InputValidator) validateData(data []byte, result *TransactionValidationResult) {
	if len(data) == 0 {
		return
	}
	if len(data) < 4 {
		result.Errors = append(result.Errors, "invalid function call data")
		return
	}

	// Extract function selector (lower-case hex, matching allowedMethods keys)
	methodSig := fmt.Sprintf("0x%x", data[:4])

	// Check if method is allowed; an empty list allows everything.
	if len(iv.allowedMethods) > 0 && !iv.allowedMethods[methodSig] {
		result.Errors = append(result.Errors, fmt.Sprintf("method %s not allowed", methodSig))
	}

	// Check for suspicious patterns
	if iv.hasSuspiciousPatterns(data) {
		result.Warnings = append(result.Warnings, "transaction data contains suspicious patterns")
		escalateRisk(result, "high")
	}
}

// validateFeeTier validates Uniswap V3 fee tiers
func (iv *InputValidator) validateFeeTier(fee uint32) error {
	for _, validFee := range validFeeTiers {
		if fee == validFee {
			return nil
		}
	}
	return fmt.Errorf("invalid fee tier: %d (must be one of: 100, 500, 3000, 10000)", fee)
}

// validateDeadline validates a Unix-seconds deadline: it must be non-zero,
// in the future, and no further ahead than the configured window
// (24 hours by default).
func (iv *InputValidator) validateDeadline(deadline uint64) error {
	if deadline == 0 {
		return fmt.Errorf("deadline cannot be zero")
	}

	now := uint64(time.Now().Unix())
	if deadline <= now {
		return fmt.Errorf("deadline must be in the future")
	}

	// Reject (not merely warn about) deadlines beyond the window.
	window := iv.maxDeadline
	if window <= 0 {
		window = deadlineWindow(0)
	}
	if deadline > now+uint64(window/time.Second) {
		return fmt.Errorf("deadline too far in future (max %d hours)", int64(window/time.Hour))
	}

	return nil
}

// performSwapSecurityChecks performs additional security checks for swap
// parameters. It only adds warnings and may raise the risk level.
func (iv *InputValidator) performSwapSecurityChecks(params *SwapParams, result *TransactionValidationResult) {
	// Check for sandwich attack vulnerability (> 5% slippage)
	if params.SlippageTolerance != nil && params.SlippageTolerance.Cmp(sandwichSlippageBps) > 0 {
		result.Warnings = append(result.Warnings, "high slippage tolerance increases sandwich attack risk")
		escalateRisk(result, "medium")
	}

	// Check for MEV vulnerability: large trades are more susceptible
	if params.AmountIn != nil && params.AmountIn.Cmp(largeTradeThreshold) > 0 {
		result.Warnings = append(result.Warnings, "large trade may be subject to MEV attacks")
	}

	// Check deadline proximity. Past/zero deadlines are already errors from
	// validateDeadline; guarding here avoids uint64 underflow.
	now := uint64(time.Now().Unix())
	if params.Deadline > now && params.Deadline-now < 60 { // Less than 1 minute
		result.Warnings = append(result.Warnings, "very short deadline may cause transaction failures")
	}
}

// calculateEstimatedCost estimates the total cost of a transaction:
// gas limit * gas price (or fee cap) plus the transferred value.
func (iv *InputValidator) calculateEstimatedCost(tx *types.Transaction) *big.Int {
	cost := new(big.Int)
	// SetUint64 avoids the int64 conversion overflow for huge gas limits.
	gas := new(big.Int).SetUint64(tx.Gas())

	if gasPrice := tx.GasPrice(); gasPrice != nil {
		cost.Add(cost, new(big.Int).Mul(gas, gasPrice))
	} else if feeCap := tx.GasFeeCap(); feeCap != nil {
		// For EIP-1559 transactions, use fee cap as an upper-bound estimate
		cost.Add(cost, new(big.Int).Mul(gas, feeCap))
	}

	// Value transfer
	if value := tx.Value(); value != nil {
		cost.Add(cost, value)
	}

	return cost
}

// finalizeValidation determines the final validation result: any error makes
// the result invalid and "critical"; otherwise more than two warnings raise
// the risk one level (low -> medium, medium -> high).
func (iv *InputValidator) finalizeValidation(result *TransactionValidationResult) {
	if len(result.Errors) > 0 {
		result.IsValid = false
		result.RiskLevel = "critical"
		return
	}

	if len(result.Warnings) > 2 {
		switch result.RiskLevel {
		case "low":
			result.RiskLevel = "medium"
		case "medium":
			result.RiskLevel = "high"
		}
	}
}

// Helper functions

// isKnownInvalidAddress reports whether addr is on the (placeholder) blacklist.
func (iv *InputValidator) isKnownInvalidAddress(addr common.Address) bool {
	return maliciousAddresses[addr]
}

// hasSuspiciousPatterns applies simple heuristics to transaction calldata.
func (iv *InputValidator) hasSuspiciousPatterns(data []byte) bool {
	// NOTE(review): 0xff is the EVM SELFDESTRUCT opcode, not a function
	// selector; this flags any calldata whose selector happens to start with
	// 0xff. Behavior kept as-is pending a real heuristic.
	//
	// Delegate-call detection would require deeper analysis; not implemented.
	return len(data) >= 4 && data[0] == 0xff
}

// getDefaultValidationConfig returns the configuration used when
// NewInputValidator is given a nil config.
func getDefaultValidationConfig() *ValidationConfig {
	return &ValidationConfig{
		MaxGasLimit:      10000000,   // 10M gas
		MaxGasPriceGwei:  500,        // 500 Gwei
		MaxValueEther:    1000,       // 1000 ETH
		AllowedMethods:   []string{}, // Empty means all methods allowed
		RequireDeadline:  true,
		MaxDeadlineHours: defaultMaxDeadlineHours,
	}
}

// Legacy validation functions (keeping for backward compatibility)

// ValidateEthereumAddress validates an Ethereum address string
func (iv *InputValidator) ValidateEthereumAddress(address string) error {
	if !iv.addressPattern.MatchString(address) {
		return fmt.Errorf("invalid Ethereum address format")
	}
	return nil
}

// ValidateTransactionHash validates a transaction hash string
func (iv *InputValidator) ValidateTransactionHash(hash string) error {
	if !iv.txHashPattern.MatchString(hash) {
		return fmt.Errorf("invalid transaction hash format")
	}
	return nil
}

// ValidateBlockHash validates a block hash string
func (iv *InputValidator) ValidateBlockHash(hash string) error {
	if !iv.blockHashPattern.MatchString(hash) {
		return fmt.Errorf("invalid block hash format")
	}
	return nil
}

// ValidateHexData validates hex data string
func (iv *InputValidator) ValidateHexData(data string) error {
	if !iv.hexDataPattern.MatchString(data) {
		return fmt.Errorf("invalid hex data format")
	}
	return nil
}

// SanitizeInput sanitizes string inputs to prevent injection attacks. It
// keeps only ASCII word characters, ASCII whitespace, '-' and '.', truncates
// to 1000 bytes and trims surrounding whitespace.
func SanitizeInput(input string) string {
	// Remove potentially dangerous characters
	sanitized := sanitizePattern.ReplaceAllString(input, "")

	// Limit length. Only ASCII survives the filter above, so byte
	// truncation cannot split a multi-byte rune.
	if len(sanitized) > 1000 {
		sanitized = sanitized[:1000]
	}

	return strings.TrimSpace(sanitized)
}

// ValidateHexString validates a "0x"-prefixed hex string with an even number
// of hex digits.
func ValidateHexString(hexStr string) error {
	if !strings.HasPrefix(hexStr, "0x") {
		return fmt.Errorf("hex string must start with 0x")
	}

	digits := hexStr[2:] // Remove 0x prefix
	if len(digits)%2 != 0 {
		return fmt.Errorf("hex string must have even length")
	}
	if !hexDigitsPattern.MatchString(digits) {
		return fmt.Errorf("invalid hex characters")
	}

	return nil
}