package contracts

import (
	"context"
	"fmt"
	"math/big"
	"sort"
	"strings"
	"time"

	"github.com/ethereum/go-ethereum"
	"github.com/ethereum/go-ethereum/common"
	"github.com/ethereum/go-ethereum/ethclient"
	"github.com/fraktal/mev-beta/internal/logger"
)

// ContractType represents the detected type of a contract.
type ContractType int

// Contract types that detection can report.
const (
	ContractTypeUnknown ContractType = iota
	ContractTypeERC20Token
	ContractTypeUniswapV2Pool
	ContractTypeUniswapV3Pool
	ContractTypeUniswapV2Router
	ContractTypeUniswapV3Router
	ContractTypeUniversalRouter
	ContractTypeFactory
	ContractTypeEOA // Externally Owned Account
)

// String returns the string representation of the contract type.
// Values without a dedicated name (including ContractTypeUnknown) yield
// "Unknown".
func (ct ContractType) String() string {
	switch ct {
	case ContractTypeERC20Token:
		return "ERC-20"
	case ContractTypeUniswapV2Pool:
		return "UniswapV2Pool"
	case ContractTypeUniswapV3Pool:
		return "UniswapV3Pool"
	case ContractTypeUniswapV2Router:
		return "UniswapV2Router"
	case ContractTypeUniswapV3Router:
		return "UniswapV3Router"
	case ContractTypeUniversalRouter:
		return "UniversalRouter"
	case ContractTypeFactory:
		return "Factory"
	case ContractTypeEOA:
		return "EOA"
	default:
		return "Unknown"
	}
}

// DetectionResult contains the result of contract type detection.
type DetectionResult struct {
	ContractType       ContractType // best-matching type, or ContractTypeUnknown
	IsContract         bool         // true when the address has deployed code
	HasCode            bool         // set together with IsContract
	SupportedFunctions []string     // names of the probes that did not fail
	Confidence         float64      // 0.0 to 1.0
	Error              error        // set when the code lookup itself failed
	Warnings           []string     // non-fatal findings from validation
}

// ContractDetector provides runtime contract type detection.
//
// Results are memoized in a plain map that is read and written without any
// locking, so a detector must not be shared between goroutines unless the
// caller serializes access.
type ContractDetector struct {
	client  *ethclient.Client
	logger  *logger.Logger
	cache   map[common.Address]*DetectionResult
	timeout time.Duration
}

// NewContractDetector creates a new contract detector with an empty cache
// and a 5 second budget per detection.
func NewContractDetector(client *ethclient.Client, logger *logger.Logger) *ContractDetector {
	return &ContractDetector{
		client:  client,
		logger:  logger,
		cache:   make(map[common.Address]*DetectionResult),
		timeout: 5 * time.Second,
	}
}

// DetectContractType determines the type of contract at the given address.
//
// A cached result is returned when one exists. Otherwise the address's code
// is fetched; an address without code is reported as an EOA with full
// confidence, and an address with code is classified by probing known
// function selectors. When the probe confidence exceeds 0.8 a type-specific
// validation call runs, and a validation failure adds a warning and scales
// the confidence by 0.8.
//
// All RPC calls of one detection share a single timeout derived from ctx.
// Results are cached only when detection ran to completion: a failed code
// lookup (reported through DetectionResult.Error) and a detection cut short
// by cancellation or timeout are returned but not stored, so a later call
// can retry instead of seeing a stale failure forever.
func (cd *ContractDetector) DetectContractType(ctx context.Context, address common.Address) *DetectionResult {
	if cached, ok := cd.cache[address]; ok {
		return cached
	}

	result := &DetectionResult{
		ContractType:       ContractTypeUnknown,
		SupportedFunctions: []string{},
		Warnings:           []string{},
	}

	ctxWithTimeout, cancel := context.WithTimeout(ctx, cd.timeout)
	defer cancel()

	// Check if the address has code (is a contract).
	code, err := cd.client.CodeAt(ctxWithTimeout, address, nil)
	if err != nil {
		// Deliberately not cached: the failure says nothing about the address.
		result.Error = fmt.Errorf("failed to get code at address: %w", err)
		return result
	}

	// No code means an externally owned account.
	if len(code) == 0 {
		result.ContractType = ContractTypeEOA
		result.Confidence = 1.0
		cd.cache[address] = result
		return result
	}

	result.IsContract = true
	result.HasCode = true

	// Detect contract type by testing function signatures.
	contractType, confidence, supportedFunctions := cd.detectByFunctionSignatures(ctxWithTimeout, address)
	result.ContractType = contractType
	result.Confidence = confidence
	result.SupportedFunctions = supportedFunctions

	// Additional validation for high-confidence detection.
	if confidence > 0.8 {
		if err := cd.validateContractType(ctxWithTimeout, address, contractType); err != nil {
			result.Warnings = append(result.Warnings, fmt.Sprintf("validation warning: %v", err))
			result.Confidence *= 0.8 // Reduce confidence
		}
	}

	// If the shared deadline expired (or the caller cancelled) some probes
	// never really ran, so the classification is incomplete. Report it, but
	// do not let it poison the cache.
	if ctxErr := ctxWithTimeout.Err(); ctxErr != nil {
		result.Warnings = append(result.Warnings, fmt.Sprintf("detection incomplete: %v", ctxErr))
		return result
	}

	cd.cache[address] = result
	return result
}

// detectByFunctionSignatures detects contract type by testing known function
// signatures.
//
// Four probe sets are scored (ERC-20, Uniswap V2 pool, Uniswap V3 pool,
// router). A set is a candidate when its score exceeds its threshold (0.6,
// or 0.5 for the router set). The candidate with the highest score wins; on
// equal scores the more specific type wins, in the order V3 pool, V2 pool,
// router, ERC-20. It returns the winning type (ContractTypeUnknown when there
// is no candidate), its score, and the de-duplicated names of all probes
// that succeeded.
func (cd *ContractDetector) detectByFunctionSignatures(ctx context.Context, address common.Address) (ContractType, float64, []string) {
	supportedFunctions := []string{}

	// balanceOf(address) takes one argument. A bare selector is rejected by
	// contracts that validate calldata length, so the probe carries one
	// 32-byte argument word holding the zero address.
	balanceOfCall := append([]byte{0x70, 0xa0, 0x82, 0x31}, make([]byte, 32)...)

	// Test ERC-20 functions.
	erc20Functions := map[string][]byte{
		"name()":        {0x06, 0xfd, 0xde, 0x03},
		"symbol()":      {0x95, 0xd8, 0x9b, 0x41},
		"decimals()":    {0x31, 0x3c, 0xe5, 0x67},
		"totalSupply()": {0x18, 0x16, 0x0d, 0xdd},
		"balanceOf()":   balanceOfCall,
	}
	erc20Score := cd.testFunctionSignatures(ctx, address, erc20Functions, &supportedFunctions)

	// Test Uniswap V2 Pool functions.
	v2PoolFunctions := map[string][]byte{
		"token0()":               {0x0d, 0xfe, 0x16, 0x81},
		"token1()":               {0xd2, 0x12, 0x20, 0xa7},
		"getReserves()":          {0x09, 0x02, 0xf1, 0xac},
		"price0CumulativeLast()": {0x54, 0x1c, 0x5c, 0xfa},
		"kLast()":                {0x7d, 0xc0, 0xd1, 0xd0},
	}
	v2PoolScore := cd.testFunctionSignatures(ctx, address, v2PoolFunctions, &supportedFunctions)

	// Test Uniswap V3 Pool functions.
	v3PoolFunctions := map[string][]byte{
		"token0()":      {0x0d, 0xfe, 0x16, 0x81},
		"token1()":      {0xd2, 0x12, 0x20, 0xa7},
		"fee()":         {0xdd, 0xca, 0x3f, 0x43},
		"slot0()":       {0x38, 0x50, 0xc7, 0xbd},
		"liquidity()":   {0x1a, 0x68, 0x65, 0x0f},
		"tickSpacing()": {0xd0, 0xc9, 0x32, 0x07},
	}
	v3PoolScore := cd.testFunctionSignatures(ctx, address, v3PoolFunctions, &supportedFunctions)

	// Test Router functions.
	// NOTE(review): swapExactTokensForTokens takes arguments and is probed
	// with a bare selector, so it presumably reverts on real routers. That
	// would cap this score at 2/3, below the 0.7 cut-off IsRouter applies.
	// Confirm and consider a probe that can succeed without arguments.
	routerFunctions := map[string][]byte{
		"WETH()":                     {0xad, 0x5c, 0x46, 0x48},
		"swapExactTokensForTokens()": {0x38, 0xed, 0x17, 0x39},
		"factory()":                  {0xc4, 0x5a, 0x01, 0x55},
	}
	routerScore := cd.testFunctionSignatures(ctx, address, routerFunctions, &supportedFunctions)

	// token0()/token1() are probed by both pool sets; report each name once.
	seen := make(map[string]struct{}, len(supportedFunctions))
	unique := supportedFunctions[:0]
	for _, name := range supportedFunctions {
		if _, dup := seen[name]; dup {
			continue
		}
		seen[name] = struct{}{}
		unique = append(unique, name)
	}
	supportedFunctions = unique

	// Candidates are listed most-specific first. The strict comparison below
	// keeps the earlier entry on equal scores, which makes the outcome
	// deterministic (a pool that also answers the ERC-20 probes is reported
	// as a pool).
	type candidate struct {
		kind      ContractType
		score     float64
		threshold float64
	}
	candidates := []candidate{
		{ContractTypeUniswapV3Pool, v3PoolScore, 0.6},
		{ContractTypeUniswapV2Pool, v2PoolScore, 0.6},
		{ContractTypeUniswapV2Router, routerScore, 0.5},
		{ContractTypeERC20Token, erc20Score, 0.6},
	}

	bestType := ContractTypeUnknown
	bestScore := 0.0
	for _, c := range candidates {
		if c.score > c.threshold && c.score > bestScore {
			bestType = c.kind
			bestScore = c.score
		}
	}

	return bestType, bestScore, supportedFunctions
}

// testFunctionSignatures tests if a contract supports given function
// signatures.
//
// functions maps a display name to the calldata sent to the contract. Each
// entry is issued as a read-only call against the latest block, in sorted
// name order so that the call sequence and the reported names are stable
// between runs. A call that returns without error counts as supported and
// its name is appended to *supportedFunctions. A call that fails with an
// "execution reverted" error counts as unsupported. Any other failure is
// treated as a possible network error and the function is left out of the
// score entirely.
//
// The return value is supported/considered, or 0.0 when no function could be
// considered.
func (cd *ContractDetector) testFunctionSignatures(ctx context.Context, address common.Address, functions map[string][]byte, supportedFunctions *[]string) float64 {
	names := make([]string, 0, len(functions))
	for name := range functions {
		names = append(names, name)
	}
	sort.Strings(names)

	supported := 0
	total := len(names)

	for _, name := range names {
		_, err := cd.client.CallContract(ctx, ethereum.CallMsg{
			To:   &address,
			Data: functions[name],
		}, nil)

		switch {
		case err == nil:
			supported++
			*supportedFunctions = append(*supportedFunctions, name)
		case !strings.Contains(err.Error(), "execution reverted"):
			// Not a revert: it might be a network error, so it is not
			// counted against the contract.
			total--
		}
	}

	if total == 0 {
		return 0.0
	}
	return float64(supported) / float64(total)
}

// validateContractType performs additional validation for detected contract
// types. Only ERC-20 tokens and Uniswap V2/V3 pools have a validator; every
// other type is accepted without a further call.
func (cd *ContractDetector) validateContractType(ctx context.Context, address common.Address, contractType ContractType) error {
	switch contractType {
	case ContractTypeERC20Token:
		return cd.validateERC20(ctx, address)
	case ContractTypeUniswapV2Pool:
		return cd.validateUniswapV2Pool(ctx, address)
	case ContractTypeUniswapV3Pool:
		return cd.validateUniswapV3Pool(ctx, address)
	default:
		return nil // No additional validation for other types
	}
}

// validateERC20 validates that a contract is actually an ERC-20 token by
// calling decimals() and checking that it returns exactly one 32-byte word
// whose value lies in the range 0-18.
func (cd *ContractDetector) validateERC20(ctx context.Context, address common.Address) error {
	decimalsData := []byte{0x31, 0x3c, 0xe5, 0x67} // decimals()

	result, err := cd.client.CallContract(ctx, ethereum.CallMsg{
		To:   &address,
		Data: decimalsData,
	}, nil)
	if err != nil {
		return fmt.Errorf("decimals() call failed: %w", err)
	}

	// A single return value is ABI-encoded as one 32-byte word; anything
	// else (including empty data) is not a usable decimals() answer.
	if len(result) != 32 {
		return fmt.Errorf("unexpected decimals() return length: %d", len(result))
	}

	// Compare as a big integer first: Uint64 is undefined for values that do
	// not fit in 64 bits, so a huge word could otherwise look like a small,
	// valid number.
	decimals := new(big.Int).SetBytes(result)
	if !decimals.IsUint64() || decimals.Uint64() > 18 {
		return fmt.Errorf("unrealistic decimals value: %s", decimals.String())
	}

	return nil
}

// validateUniswapV2Pool validates that a contract is actually a Uniswap V2
// pool by calling getReserves() and checking that it returns exactly three
// 32-byte words.
func (cd *ContractDetector) validateUniswapV2Pool(ctx context.Context, address common.Address) error {
	getReservesData := []byte{0x09, 0x02, 0xf1, 0xac} // getReserves()

	result, err := cd.client.CallContract(ctx, ethereum.CallMsg{
		To:   &address,
		Data: getReservesData,
	}, nil)
	if err != nil {
		return fmt.Errorf("getReserves() call failed: %w", err)
	}

	// Two reserves plus a timestamp, each padded to a full word.
	if len(result) != 96 { // 3 * 32 bytes
		return fmt.Errorf("unexpected getReserves() return length: %d", len(result))
	}

	return nil
}

// validateUniswapV3Pool validates that a contract is actually a Uniswap V3
// pool by calling slot0() and checking that it returns at least one 32-byte
// word.
func (cd *ContractDetector) validateUniswapV3Pool(ctx context.Context, address common.Address) error {
	slot0Data := []byte{0x38, 0x50, 0xc7, 0xbd} // slot0()

	result, err := cd.client.CallContract(ctx, ethereum.CallMsg{
		To:   &address,
		Data: slot0Data,
	}, nil)
	if err != nil {
		return fmt.Errorf("slot0() call failed: %w", err)
	}

	// Should return multiple values including sqrtPriceX96.
	if len(result) < 32 {
		return fmt.Errorf("unexpected slot0() return length: %d", len(result))
	}

	return nil
}

// IsERC20Token checks if an address is an ERC-20 token, requiring a
// detection confidence above 0.7.
func (cd *ContractDetector) IsERC20Token(ctx context.Context, address common.Address) bool {
	result := cd.DetectContractType(ctx, address)
	return result.ContractType == ContractTypeERC20Token && result.Confidence > 0.7
}

// IsUniswapPool checks if an address is a Uniswap pool (V2 or V3),
// requiring a detection confidence above 0.7.
func (cd *ContractDetector) IsUniswapPool(ctx context.Context, address common.Address) bool {
	result := cd.DetectContractType(ctx, address)
	isPool := result.ContractType == ContractTypeUniswapV2Pool ||
		result.ContractType == ContractTypeUniswapV3Pool
	return isPool && result.Confidence > 0.7
}

// IsRouter checks if an address is a router contract (V2, V3 or universal),
// requiring a detection confidence above 0.7.
func (cd *ContractDetector) IsRouter(ctx context.Context, address common.Address) bool {
	result := cd.DetectContractType(ctx, address)
	isRouter := result.ContractType == ContractTypeUniswapV2Router ||
		result.ContractType == ContractTypeUniswapV3Router ||
		result.ContractType == ContractTypeUniversalRouter
	return isRouter && result.Confidence > 0.7
}

// ClearCache clears the detection cache by replacing it with an empty map.
func (cd *ContractDetector) ClearCache() {
	cd.cache = make(map[common.Address]*DetectionResult)
}

// GetCacheSize returns the number of cached results.
func (cd *ContractDetector) GetCacheSize() int {
	return len(cd.cache)
}