package validation

import (
	"context"
	"fmt"
	"math"
	"math/big"
	"time"

	"github.com/ethereum/go-ethereum"
	"github.com/ethereum/go-ethereum/common"
	"github.com/ethereum/go-ethereum/ethclient"

	"github.com/fraktal/mev-beta/internal/logger"
	"github.com/fraktal/mev-beta/pkg/pools"
	"github.com/fraktal/mev-beta/pkg/security"
	"github.com/fraktal/mev-beta/pkg/uniswap"
)

// EVM opcodes used by the bytecode heuristics.
const (
	opPush1        = 0x60 // first PUSHn opcode; PUSHn carries n immediate bytes
	opPush32       = 0x7f // last PUSHn opcode
	opDelegateCall = 0xf4
)

// PoolValidator provides comprehensive security validation for liquidity pools.
//
// It performs no locking: its maps (trusted factories, banned addresses and
// the validation cache) are read and written without synchronization.
// NOTE(review): callers must not use one PoolValidator from several
// goroutines unless access is serialized externally — confirm.
type PoolValidator struct {
	client            *ethclient.Client
	logger            *logger.Logger
	create2Calculator *pools.CREATE2Calculator
	trustedFactories  map[common.Address]string // factory address -> name
	bannedAddresses   map[common.Address]string // banned address -> reason
	validationCache   map[common.Address]*ValidationResult
	cacheTimeout      time.Duration
}

// ValidationResult contains the result of pool validation.
type ValidationResult struct {
	IsValid         bool           `json:"is_valid"`
	SecurityScore   int            `json:"security_score"` // 0-100, higher is better
	Warnings        []string       `json:"warnings"`
	Errors          []string       `json:"errors"`
	PoolType        string         `json:"pool_type"` // "uniswap_v3", "uniswap_v2", etc.
	Factory         string         `json:"factory"`
	Token0          common.Address `json:"token0"`
	Token1          common.Address `json:"token1"`
	Fee             uint32         `json:"fee,omitempty"`
	CreationBlock   uint64         `json:"creation_block,omitempty"`
	ValidatedAt     time.Time      `json:"validated_at"`
	FactoryVerified bool           `json:"factory_verified"`
	InterfaceValid  bool           `json:"interface_valid"`
	TokensValid     bool           `json:"tokens_valid"`
}

// PoolValidationConfig contains configuration for pool validation.
type PoolValidationConfig struct {
	RequireFactoryVerification bool          // Whether factory verification is mandatory
	MinSecurityScore           int           // Minimum security score to accept (0-100)
	MaxValidationTime          time.Duration // Maximum time to spend on validation; <= 0 selects the default
	AllowUnknownFactories      bool          // Whether to allow pools from unknown factories
	RequireTokenValidation     bool          // Whether to validate token contracts
}

// NewPoolValidator creates a new pool validator with the built-in trusted
// factory list and a 5 minute result cache.
func NewPoolValidator(client *ethclient.Client, logger *logger.Logger) *PoolValidator {
	pv := &PoolValidator{
		client:            client,
		logger:            logger,
		create2Calculator: pools.NewCREATE2Calculator(logger, client),
		trustedFactories:  make(map[common.Address]string),
		bannedAddresses:   make(map[common.Address]string),
		validationCache:   make(map[common.Address]*ValidationResult),
		cacheTimeout:      5 * time.Minute,
	}

	pv.initializeTrustedFactories()
	pv.initializeBannedAddresses()

	return pv
}

// ValidatePool performs comprehensive security validation of a pool.
//
// A nil config selects getDefaultConfig(). Every result, including failures,
// is cached per address for cacheTimeout. The returned error is always nil;
// validation failures are reported through ValidationResult.Errors and
// ValidationResult.IsValid.
//
// NOTE(review): the cache is keyed by pool address only, so a result computed
// under one config is returned to later callers passing a different config —
// confirm this is intended.
func (pv *PoolValidator) ValidatePool(ctx context.Context, poolAddr common.Address, config *PoolValidationConfig) (*ValidationResult, error) {
	if config == nil {
		config = pv.getDefaultConfig()
	}

	// Check cache first
	if cached := pv.getCachedResult(poolAddr); cached != nil {
		return cached, nil
	}

	// A zero or negative budget would create an already-expired context,
	// making every RPC below fail (and that failure would then be cached).
	timeout := config.MaxValidationTime
	if timeout <= 0 {
		timeout = pv.getDefaultConfig().MaxValidationTime
	}
	timeoutCtx, cancel := context.WithTimeout(ctx, timeout)
	defer cancel()

	result := &ValidationResult{
		ValidatedAt:   time.Now(),
		SecurityScore: 0,
		Warnings:      make([]string, 0),
		Errors:        make([]string, 0),
	}

	// 1. Basic existence check
	if err := pv.validateBasicExistence(timeoutCtx, poolAddr, result); err != nil {
		result.Errors = append(result.Errors, fmt.Sprintf("Basic validation failed: %v", err))
		result.IsValid = false
		pv.cacheResult(poolAddr, result)
		return result, nil
	}

	// 2. Check against banned addresses
	if err := pv.checkBannedAddresses(poolAddr, result); err != nil {
		result.Errors = append(result.Errors, fmt.Sprintf("Banned address check failed: %v", err))
		result.IsValid = false
		result.SecurityScore = 0
		pv.cacheResult(poolAddr, result)
		return result, nil
	}

	// 3. Detect pool type and validate interface. A failure here is always
	// fatal because the final decision rejects any result carrying errors.
	if err := pv.validatePoolInterface(timeoutCtx, poolAddr, result); err != nil {
		result.Errors = append(result.Errors, fmt.Sprintf("Interface validation failed: %v", err))
		result.SecurityScore -= 30
	} else {
		result.InterfaceValid = true
		result.SecurityScore += 25
	}

	// 4. Validate factory deployment (critical security check). Only a
	// mandatory verification makes a failure fatal; otherwise it is recorded
	// as a warning and the score penalty alone influences the decision.
	if err := pv.validateFactoryDeployment(timeoutCtx, poolAddr, result); err != nil {
		msg := fmt.Sprintf("Factory validation failed: %v", err)
		if config.RequireFactoryVerification && !config.AllowUnknownFactories {
			result.Errors = append(result.Errors, msg)
		} else {
			result.Warnings = append(result.Warnings, msg)
		}
		result.SecurityScore -= 40
	} else {
		result.FactoryVerified = true
		result.SecurityScore += 30
	}

	// 5. Validate token contracts
	if config.RequireTokenValidation {
		if err := pv.validateTokenContracts(timeoutCtx, result); err != nil {
			result.Warnings = append(result.Warnings, fmt.Sprintf("Token validation warning: %v", err))
			result.SecurityScore -= 10
		} else {
			result.TokensValid = true
			result.SecurityScore += 15
		}
	}

	// 6. Additional security checks
	pv.performAdditionalSecurityChecks(timeoutCtx, poolAddr, result)

	// 7. Final validation decision. Clamp first so the reported score is
	// exactly the value compared against MinSecurityScore.
	switch {
	case result.SecurityScore < 0:
		result.SecurityScore = 0
	case result.SecurityScore > 100:
		result.SecurityScore = 100
	}
	result.IsValid = result.SecurityScore >= config.MinSecurityScore && len(result.Errors) == 0

	// Cache the result
	pv.cacheResult(poolAddr, result)

	pv.logger.Debug(fmt.Sprintf("Pool validation complete: %s, valid=%v, score=%d",
		poolAddr.Hex(), result.IsValid, result.SecurityScore))

	return result, nil
}

// validateBasicExistence checks that the pool address holds contract code.
// Code shorter than 100 bytes is accepted but penalized with a warning.
func (pv *PoolValidator) validateBasicExistence(ctx context.Context, poolAddr common.Address, result *ValidationResult) error {
	code, err := pv.client.CodeAt(ctx, poolAddr, nil)
	if err != nil {
		return fmt.Errorf("failed to get contract code: %w", err)
	}

	if len(code) == 0 {
		return fmt.Errorf("no contract code at address %s", poolAddr.Hex())
	}

	// Basic code size check - legitimate pools should have substantial code
	if len(code) < 100 {
		result.Warnings = append(result.Warnings, "Contract has very small code size")
		result.SecurityScore -= 10
	}

	return nil
}

// checkBannedAddresses verifies the pool is not on the banned list.
func (pv *PoolValidator) checkBannedAddresses(poolAddr common.Address, result *ValidationResult) error {
	if reason, banned := pv.bannedAddresses[poolAddr]; banned {
		return fmt.Errorf("pool %s is banned: %s", poolAddr.Hex(), reason)
	}
	return nil
}

// validatePoolInterface detects the pool type (V3 via slot0, then V2 via
// getReserves) and validates the matching interface, filling PoolType,
// Token0, Token1 and Fee on result.
func (pv *PoolValidator) validatePoolInterface(ctx context.Context, poolAddr common.Address, result *ValidationResult) error {
	// Check for Uniswap V3 interface
	if pv.isUniswapV3Pool(ctx, poolAddr) {
		result.PoolType = "uniswap_v3"
		return pv.validateUniswapV3Interface(ctx, poolAddr, result)
	}

	// Check for Uniswap V2 interface
	if pv.isUniswapV2Pool(ctx, poolAddr) {
		result.PoolType = "uniswap_v2"
		return pv.validateUniswapV2Interface(ctx, poolAddr, result)
	}

	// Unknown pool type
	result.PoolType = "unknown"
	result.Warnings = append(result.Warnings, "Unknown pool type")
	return fmt.Errorf("unknown pool interface")
}

// isUniswapV3Pool reports whether a slot0() call on the pool succeeds.
func (pv *PoolValidator) isUniswapV3Pool(ctx context.Context, poolAddr common.Address) bool {
	// Try to call slot0() function (unique to Uniswap V3)
	slot0ABI := `[{"inputs":[],"name":"slot0","outputs":[{"internalType":"uint160","name":"sqrtPriceX96","type":"uint160"},{"internalType":"int24","name":"tick","type":"int24"},{"internalType":"uint16","name":"observationIndex","type":"uint16"},{"internalType":"uint16","name":"observationCardinality","type":"uint16"},{"internalType":"uint16","name":"observationCardinalityNext","type":"uint16"},{"internalType":"uint8","name":"feeProtocol","type":"uint8"},{"internalType":"bool","name":"unlocked","type":"bool"}],"stateMutability":"view","type":"function"}]`

	contractABI, err := uniswap.ParseABI(slot0ABI)
	if err != nil {
		return false
	}

	callData, err := contractABI.Pack("slot0")
	if err != nil {
		return false
	}

	_, err = pv.client.CallContract(ctx, ethereum.CallMsg{
		To:   &poolAddr,
		Data: callData,
	}, nil)

	return err == nil
}

// isUniswapV2Pool reports whether a getReserves() call on the pool succeeds.
func (pv *PoolValidator) isUniswapV2Pool(ctx context.Context, poolAddr common.Address) bool {
	// Try to call getReserves() function (standard in Uniswap V2)
	reservesABI := `[{"inputs":[],"name":"getReserves","outputs":[{"internalType":"uint112","name":"_reserve0","type":"uint112"},{"internalType":"uint112","name":"_reserve1","type":"uint112"},{"internalType":"uint32","name":"_blockTimestampLast","type":"uint32"}],"stateMutability":"view","type":"function"}]`

	contractABI, err := uniswap.ParseABI(reservesABI)
	if err != nil {
		return false
	}

	callData, err := contractABI.Pack("getReserves")
	if err != nil {
		return false
	}

	_, err = pv.client.CallContract(ctx, ethereum.CallMsg{
		To:   &poolAddr,
		Data: callData,
	}, nil)

	return err == nil
}

// validateUniswapV3Interface reads token0/token1/fee, records them on result,
// enforces token0 < token1 and warns on fee tiers outside 100/500/3000/10000.
func (pv *PoolValidator) validateUniswapV3Interface(ctx context.Context, poolAddr common.Address, result *ValidationResult) error {
	token0, token1, fee, err := pv.getUniswapV3PoolInfo(ctx, poolAddr)
	if err != nil {
		return fmt.Errorf("failed to get V3 pool info: %w", err)
	}

	result.Token0 = token0
	result.Token1 = token1
	result.Fee = fee

	// Validate token ordering (token0 < token1)
	if token0.Big().Cmp(token1.Big()) >= 0 {
		return fmt.Errorf("invalid token ordering: token0 must be < token1")
	}

	// Validate fee tier
	validFees := []uint32{500, 3000, 10000, 100} // 0.05%, 0.3%, 1%, 0.01%
	feeValid := false
	for _, validFee := range validFees {
		if fee == validFee {
			feeValid = true
			break
		}
	}

	if !feeValid {
		result.Warnings = append(result.Warnings, fmt.Sprintf("Unusual fee tier: %d", fee))
	}

	return nil
}

// validateUniswapV2Interface reads token0/token1, records them on result
// with the fixed 0.3% fee (3000 in V3 fee units) and enforces token0 < token1.
func (pv *PoolValidator) validateUniswapV2Interface(ctx context.Context, poolAddr common.Address, result *ValidationResult) error {
	token0, token1, err := pv.getUniswapV2PoolInfo(ctx, poolAddr)
	if err != nil {
		return fmt.Errorf("failed to get V2 pool info: %w", err)
	}

	result.Token0 = token0
	result.Token1 = token1
	result.Fee = 3000 // V2 has fixed 0.3% fee

	// Validate token ordering
	if token0.Big().Cmp(token1.Big()) >= 0 {
		return fmt.Errorf("invalid token ordering: token0 must be < token1")
	}

	return nil
}

// validateFactoryDeployment verifies the pool address matches a CREATE2
// deployment from one of the trusted factories and records that factory's name.
func (pv *PoolValidator) validateFactoryDeployment(ctx context.Context, poolAddr common.Address, result *ValidationResult) error {
	for factoryAddr, factoryName := range pv.trustedFactories {
		if pv.verifyFactoryDeployment(factoryAddr, factoryName, poolAddr, result) {
			result.Factory = factoryName
			return nil
		}
	}

	return fmt.Errorf("pool not deployed by any trusted factory")
}

// verifyFactoryDeployment reports whether the CREATE2 calculator derives
// poolAddr for the named factory from result's tokens and fee. It returns
// false when the tokens have not been populated.
//
// NOTE(review): factoryAddr is currently unused; the CREATE2 lookup is done by
// factoryName only — confirm the calculator's factory address matches.
func (pv *PoolValidator) verifyFactoryDeployment(factoryAddr common.Address, factoryName string, poolAddr common.Address, result *ValidationResult) bool {
	if result.Token0 == (common.Address{}) || result.Token1 == (common.Address{}) {
		return false
	}

	return pv.create2Calculator.ValidatePoolAddress(factoryName, result.Token0, result.Token1, result.Fee, poolAddr)
}

// validateTokenContracts checks that both pool tokens have code and answer
// totalSupply().
func (pv *PoolValidator) validateTokenContracts(ctx context.Context, result *ValidationResult) error {
	if result.Token0 == (common.Address{}) || result.Token1 == (common.Address{}) {
		return fmt.Errorf("token addresses not available")
	}

	if err := pv.validateTokenContract(ctx, result.Token0); err != nil {
		return fmt.Errorf("token0 validation failed: %w", err)
	}

	if err := pv.validateTokenContract(ctx, result.Token1); err != nil {
		return fmt.Errorf("token1 validation failed: %w", err)
	}

	return nil
}

// validateTokenContract checks that tokenAddr has code and implements totalSupply().
func (pv *PoolValidator) validateTokenContract(ctx context.Context, tokenAddr common.Address) error {
	code, err := pv.client.CodeAt(ctx, tokenAddr, nil)
	if err != nil {
		return fmt.Errorf("failed to get token contract code: %w", err)
	}

	if len(code) == 0 {
		return fmt.Errorf("no contract code at token address %s", tokenAddr.Hex())
	}

	return pv.validateERC20Interface(ctx, tokenAddr)
}

// validateERC20Interface checks that a totalSupply() call on tokenAddr succeeds.
func (pv *PoolValidator) validateERC20Interface(ctx context.Context, tokenAddr common.Address) error {
	erc20ABI := `[{"inputs":[],"name":"totalSupply","outputs":[{"internalType":"uint256","name":"","type":"uint256"}],"stateMutability":"view","type":"function"}]`

	contractABI, err := uniswap.ParseABI(erc20ABI)
	if err != nil {
		return fmt.Errorf("failed to parse ERC20 ABI: %w", err)
	}

	callData, err := contractABI.Pack("totalSupply")
	if err != nil {
		return fmt.Errorf("failed to pack totalSupply call: %w", err)
	}

	_, err = pv.client.CallContract(ctx, ethereum.CallMsg{
		To:   &tokenAddr,
		Data: callData,
	}, nil)

	if err != nil {
		return fmt.Errorf("totalSupply call failed: %w", err)
	}

	return nil
}

// performAdditionalSecurityChecks records the creation block, penalizes pools
// younger than 100 blocks and runs the bytecode attack-pattern checks.
func (pv *PoolValidator) performAdditionalSecurityChecks(ctx context.Context, poolAddr common.Address, result *ValidationResult) {
	if creationBlock := pv.getContractCreationBlock(ctx, poolAddr); creationBlock > 0 {
		result.CreationBlock = creationBlock

		// Written without subtraction so a head that lags creationBlock
		// cannot underflow; such a pool is treated as very new.
		currentBlock, err := pv.client.BlockNumber(ctx)
		if err == nil && currentBlock < creationBlock+100 {
			result.Warnings = append(result.Warnings, "Pool is very new (< 100 blocks old)")
			result.SecurityScore -= 5
		}
	}

	pv.checkForAttackPatterns(ctx, poolAddr, result)
}

// getContractCreationBlock returns the first block, within the last 10M
// blocks, at which addr has code. It returns 0 when the block cannot be
// determined (RPC errors, no code at head, or creation before the window).
func (pv *PoolValidator) getContractCreationBlock(ctx context.Context, addr common.Address) uint64 {
	pv.logger.Debug(fmt.Sprintf("Finding creation block for contract %s", addr.Hex()))

	latestBlock, err := pv.client.BlockNumber(ctx)
	if err != nil {
		pv.logger.Warn(fmt.Sprintf("Failed to get latest block number: %v", err))
		return 0
	}

	codeAtLatest, err := pv.client.CodeAt(ctx, addr, new(big.Int).SetUint64(latestBlock))
	if err != nil || len(codeAtLatest) == 0 {
		pv.logger.Debug(fmt.Sprintf("Contract %s does not exist at latest block", addr.Hex()))
		return 0
	}

	// Start with a reasonable range - most pools created in last 10M blocks
	searchStart := uint64(0)
	if latestBlock > 10000000 {
		searchStart = latestBlock - 10000000
	}

	creationBlock := pv.binarySearchCreationBlock(ctx, addr, searchStart, latestBlock)
	if creationBlock > 0 {
		pv.logger.Debug(fmt.Sprintf("Contract %s created at block %d", addr.Hex(), creationBlock))
	}

	return creationBlock
}

// binarySearchCreationBlock returns the first block in [start, end] at which
// addr has code, assuming code presence is monotonic over the range. It
// returns 0 when the answer cannot be established: an RPC error (e.g. a
// non-archive node without historical state), no code at end, the contract
// already existing at start-1 (created before the window), or cancellation.
func (pv *PoolValidator) binarySearchCreationBlock(ctx context.Context, addr common.Address, start, end uint64) uint64 {
	if start > end {
		return 0
	}

	hasCodeAt := func(block uint64) (bool, error) {
		code, err := pv.client.CodeAt(ctx, addr, new(big.Int).SetUint64(block))
		if err != nil {
			return false, err
		}
		return len(code) > 0, nil
	}

	// 64 halvings cover any uint64 range; the cap only guards against misuse.
	const maxIterations = 64

	// Lower-bound search: narrow [lo, hi] to the first block with code.
	lo, hi := start, end
	for iteration := 1; lo < hi; iteration++ {
		if iteration > maxIterations {
			return 0
		}

		mid := lo + (hi-lo)/2 // overflow-safe midpoint, always < hi
		exists, err := hasCodeAt(mid)
		if err != nil {
			// Do not guess: an error says nothing about whether code exists.
			pv.logger.Debug(fmt.Sprintf("Error checking code at block %d: %v", mid, err))
			return 0
		}
		if exists {
			hi = mid
		} else {
			lo = mid + 1
		}

		// Add small delay to avoid rate limiting
		if iteration%10 == 0 {
			select {
			case <-ctx.Done():
				return 0
			case <-time.After(100 * time.Millisecond):
			}
		}
	}

	// The loop never probes end itself; confirm code exists there when the
	// search converged onto it.
	if lo == end {
		exists, err := hasCodeAt(end)
		if err != nil || !exists {
			return 0
		}
	}

	// Converging onto start only proves creation if start-1 had no code.
	if lo == start && start > 0 {
		existedBefore, err := hasCodeAt(start - 1)
		if err != nil || existedBefore {
			return 0
		}
	}

	return lo
}

// checkForAttackPatterns penalizes proxies and unusual bytecode.
func (pv *PoolValidator) checkForAttackPatterns(ctx context.Context, poolAddr common.Address, result *ValidationResult) {
	if pv.isProxyContract(ctx, poolAddr) {
		result.Warnings = append(result.Warnings, "Contract appears to be a proxy")
		result.SecurityScore -= 10
	}

	if pv.hasUnusualBytecode(ctx, poolAddr) {
		result.Warnings = append(result.Warnings, "Unusual bytecode patterns detected")
		result.SecurityScore -= 15
	}
}

// isProxyContract reports whether the contract's bytecode contains a
// DELEGATECALL instruction. It returns false when the code cannot be fetched.
func (pv *PoolValidator) isProxyContract(ctx context.Context, addr common.Address) bool {
	code, err := pv.client.CodeAt(ctx, addr, nil)
	if err != nil || len(code) == 0 {
		return false
	}

	return containsOpcode(code, opDelegateCall)
}

// containsOpcode reports whether op occurs as an instruction in EVM code.
// Immediate data of PUSH1..PUSH32 is skipped so constant bytes equal to op are
// not mistaken for instructions. Compiler-appended metadata at the end of the
// code is not excluded and may, rarely, cause a false positive.
func containsOpcode(code []byte, op byte) bool {
	for i := 0; i < len(code); i++ {
		b := code[i]
		if b == op {
			return true
		}
		if b >= opPush1 && b <= opPush32 {
			i += int(b-opPush1) + 1 // skip the n immediate bytes of PUSHn
		}
	}
	return false
}

// hasUnusualBytecode flags contracts larger than 50000 bytes or whose
// bytecode Shannon entropy exceeds 7.5 bits per byte (possible obfuscation).
func (pv *PoolValidator) hasUnusualBytecode(ctx context.Context, addr common.Address) bool {
	code, err := pv.client.CodeAt(ctx, addr, nil)
	if err != nil || len(code) == 0 {
		return false
	}

	if len(code) > 50000 {
		return true // Unusually large contract
	}

	entropy := pv.calculateEntropy(code)
	return entropy > 7.5 // High entropy threshold
}

// calculateEntropy returns the Shannon entropy of data in bits per byte,
// in the range [0, 8]; empty input yields 0.
func (pv *PoolValidator) calculateEntropy(data []byte) float64 {
	if len(data) == 0 {
		return 0
	}

	var freq [256]int
	for _, b := range data {
		freq[b]++
	}

	length := float64(len(data))
	entropy := 0.0
	for _, count := range freq {
		if count == 0 {
			continue
		}
		p := float64(count) / length
		entropy -= p * math.Log2(p)
	}

	return entropy
}

// callSingleOutput parses abiJSON, calls the zero-argument view method on
// addr and returns its single decoded output value.
func (pv *PoolValidator) callSingleOutput(ctx context.Context, addr common.Address, abiJSON, method string) (interface{}, error) {
	contractABI, err := uniswap.ParseABI(abiJSON)
	if err != nil {
		return nil, fmt.Errorf("parse ABI for %s: %w", method, err)
	}

	callData, err := contractABI.Pack(method)
	if err != nil {
		return nil, fmt.Errorf("pack %s call: %w", method, err)
	}

	output, err := pv.client.CallContract(ctx, ethereum.CallMsg{
		To:   &addr,
		Data: callData,
	}, nil)
	if err != nil {
		return nil, fmt.Errorf("%s call failed: %w", method, err)
	}

	values, err := contractABI.Unpack(method, output)
	if err != nil {
		return nil, fmt.Errorf("unpack %s result: %w", method, err)
	}
	if len(values) != 1 {
		return nil, fmt.Errorf("%s returned %d values, expected 1", method, len(values))
	}

	return values[0], nil
}

// callAddressMethod calls a zero-argument view method returning an address.
// The decoded value is type-checked because it comes from an untrusted contract.
func (pv *PoolValidator) callAddressMethod(ctx context.Context, addr common.Address, abiJSON, method string) (common.Address, error) {
	value, err := pv.callSingleOutput(ctx, addr, abiJSON, method)
	if err != nil {
		return common.Address{}, err
	}

	result, ok := value.(common.Address)
	if !ok {
		return common.Address{}, fmt.Errorf("%s returned %T, expected address", method, value)
	}

	return result, nil
}

// getUniswapV3PoolInfo reads token0, token1 and fee from a V3-style pool.
func (pv *PoolValidator) getUniswapV3PoolInfo(ctx context.Context, poolAddr common.Address) (common.Address, common.Address, uint32, error) {
	const poolABI = `[{"inputs":[],"name":"token0","outputs":[{"internalType":"address","name":"","type":"address"}],"stateMutability":"view","type":"function"},{"inputs":[],"name":"token1","outputs":[{"internalType":"address","name":"","type":"address"}],"stateMutability":"view","type":"function"},{"inputs":[],"name":"fee","outputs":[{"internalType":"uint24","name":"","type":"uint24"}],"stateMutability":"view","type":"function"}]`

	token0, err := pv.callAddressMethod(ctx, poolAddr, poolABI, "token0")
	if err != nil {
		return common.Address{}, common.Address{}, 0, err
	}

	token1, err := pv.callAddressMethod(ctx, poolAddr, poolABI, "token1")
	if err != nil {
		return common.Address{}, common.Address{}, 0, err
	}

	feeValue, err := pv.callSingleOutput(ctx, poolAddr, poolABI, "fee")
	if err != nil {
		return common.Address{}, common.Address{}, 0, err
	}

	// The ABI decoder yields *big.Int for uint24 outputs.
	feeBig, ok := feeValue.(*big.Int)
	if !ok || feeBig == nil || !feeBig.IsUint64() {
		return common.Address{}, common.Address{}, 0, fmt.Errorf("fee returned unexpected value %v (%T)", feeValue, feeValue)
	}

	feeUint32, err := security.SafeUint32(feeBig.Uint64())
	if err != nil {
		return common.Address{}, common.Address{}, 0, fmt.Errorf("invalid fee conversion: %w", err)
	}

	return token0, token1, feeUint32, nil
}

// getUniswapV2PoolInfo reads token0 and token1 from a V2-style pair.
func (pv *PoolValidator) getUniswapV2PoolInfo(ctx context.Context, poolAddr common.Address) (common.Address, common.Address, error) {
	const poolABI = `[{"inputs":[],"name":"token0","outputs":[{"internalType":"address","name":"","type":"address"}],"stateMutability":"view","type":"function"},{"inputs":[],"name":"token1","outputs":[{"internalType":"address","name":"","type":"address"}],"stateMutability":"view","type":"function"}]`

	token0, err := pv.callAddressMethod(ctx, poolAddr, poolABI, "token0")
	if err != nil {
		return common.Address{}, common.Address{}, err
	}

	token1, err := pv.callAddressMethod(ctx, poolAddr, poolABI, "token1")
	if err != nil {
		return common.Address{}, common.Address{}, err
	}

	return token0, token1, nil
}

// getDefaultConfig returns the strict default configuration: mandatory
// factory and token validation, minimum score 70, 10 second budget.
func (pv *PoolValidator) getDefaultConfig() *PoolValidationConfig {
	return &PoolValidationConfig{
		RequireFactoryVerification: true,
		MinSecurityScore:           70,
		MaxValidationTime:          10 * time.Second,
		AllowUnknownFactories:      false,
		RequireTokenValidation:     true,
	}
}

// getCachedResult returns the cached result for addr if it is younger than
// cacheTimeout, evicting a stale entry. The returned pointer is the cached
// value itself, so callers mutating it also change the cache.
func (pv *PoolValidator) getCachedResult(addr common.Address) *ValidationResult {
	if result, exists := pv.validationCache[addr]; exists {
		if time.Since(result.ValidatedAt) < pv.cacheTimeout {
			return result
		}
		delete(pv.validationCache, addr)
	}
	return nil
}

// cacheResult stores result for addr, replacing any previous entry.
func (pv *PoolValidator) cacheResult(addr common.Address, result *ValidationResult) {
	pv.validationCache[addr] = result
}

// initializeTrustedFactories registers the built-in trusted factories.
//
// NOTE(review): the Uniswap V2 and SushiSwap addresses below appear to be the
// Ethereum mainnet factories while Camelot V3 is labelled Arbitrum — confirm
// they match the target chain. Verification currently uses only the names.
func (pv *PoolValidator) initializeTrustedFactories() {
	// Uniswap V3
	pv.trustedFactories[common.HexToAddress("0x1F98431c8aD98523631AE4a59f267346ea31F984")] = "uniswap_v3"

	// Uniswap V2
	pv.trustedFactories[common.HexToAddress("0x5C69bEe701ef814a2B6a3EDD4B1652CB9cc5aA6f")] = "uniswap_v2"

	// SushiSwap
	pv.trustedFactories[common.HexToAddress("0xC0AEe478e3658e2610c5F7A4A2E1777cE9e4f2Ac")] = "sushiswap"

	// Camelot V3 (Arbitrum)
	pv.trustedFactories[common.HexToAddress("0x1a3c9B1d2F0529D97f2afC5136Cc23e58f1FD35B")] = "camelot_v3"
}

// initializeBannedAddresses seeds the banned list; it is currently empty.
func (pv *PoolValidator) initializeBannedAddresses() {
	// Add known malicious or problematic pool addresses
	// This would be populated with real banned addresses in production
}

// AddTrustedFactory adds a new trusted factory.
func (pv *PoolValidator) AddTrustedFactory(factoryAddr common.Address, name string) {
	pv.trustedFactories[factoryAddr] = name
	pv.logger.Info(fmt.Sprintf("Added trusted factory: %s (%s)", name, factoryAddr.Hex()))
}

// BanAddress adds an address to the banned list and evicts any cached result
// for it so the ban takes effect on the next validation.
func (pv *PoolValidator) BanAddress(addr common.Address, reason string) {
	pv.bannedAddresses[addr] = reason
	delete(pv.validationCache, addr)
	pv.logger.Warn(fmt.Sprintf("Banned address: %s (reason: %s)", addr.Hex(), reason))
}