package contracts

import (
	"bytes"
	"context"
	"fmt"
	"sort"

	"github.com/ethereum/go-ethereum/common"

	"github.com/fraktal/mev-beta/internal/logger"
)

// FunctionSignature represents a known function signature.
type FunctionSignature struct {
	// Name is the human-readable signature and doubles as the registry key.
	Name string
	// Selector holds the 4-byte function selector matched against call data.
	Selector []byte
	// AllowedTypes lists the contract types the function may be called on.
	AllowedTypes []ContractType
}

// SignatureValidator validates function calls against contract types.
type SignatureValidator struct {
	detector   *ContractDetector
	logger     *logger.Logger
	signatures map[string]*FunctionSignature
}

// NewSignatureValidator creates a new function signature validator whose
// registry is pre-populated with the known function signatures.
func NewSignatureValidator(detector *ContractDetector, logger *logger.Logger) *SignatureValidator {
	sv := &SignatureValidator{
		detector:   detector,
		logger:     logger,
		signatures: make(map[string]*FunctionSignature),
	}

	// Initialize known function signatures.
	sv.initializeSignatures()

	return sv
}

// initializeSignatures initializes the known function signatures and their
// allowed contract types.
func (sv *SignatureValidator) initializeSignatures() {
	// register stores one signature under its name. The selector is passed
	// as an array and the allowed types variadically so that every entry
	// owns its own backing storage.
	register := func(name string, selector [4]byte, allowed ...ContractType) {
		sv.signatures[name] = &FunctionSignature{
			Name:         name,
			Selector:     selector[:],
			AllowedTypes: allowed,
		}
	}

	// ERC-20 token functions.
	register("name()", [4]byte{0x06, 0xfd, 0xde, 0x03}, ContractTypeERC20Token)
	register("symbol()", [4]byte{0x95, 0xd8, 0x9b, 0x41}, ContractTypeERC20Token)
	register("decimals()", [4]byte{0x31, 0x3c, 0xe5, 0x67}, ContractTypeERC20Token)
	register("totalSupply()", [4]byte{0x18, 0x16, 0x0d, 0xdd}, ContractTypeERC20Token)
	// NOTE: 0x70a08231 is the selector of balanceOf(address); the registry
	// name omits the parameter type and is kept as-is for compatibility.
	register("balanceOf()", [4]byte{0x70, 0xa0, 0x82, 0x31}, ContractTypeERC20Token)

	// Pool functions shared by Uniswap V2 and V3.
	register("token0()", [4]byte{0x0d, 0xfe, 0x16, 0x81}, ContractTypeUniswapV2Pool, ContractTypeUniswapV3Pool)
	register("token1()", [4]byte{0xd2, 0x12, 0x20, 0xa7}, ContractTypeUniswapV2Pool, ContractTypeUniswapV3Pool)

	// Uniswap V2 Pool specific functions.
	register("getReserves()", [4]byte{0x09, 0x02, 0xf1, 0xac}, ContractTypeUniswapV2Pool)

	// Uniswap V3 Pool specific functions.
	register("slot0()", [4]byte{0x38, 0x50, 0xc7, 0xbd}, ContractTypeUniswapV3Pool)
	register("fee()", [4]byte{0xdd, 0xca, 0x3f, 0x43}, ContractTypeUniswapV3Pool)
	// keccak256("liquidity()")[:4] is 0x1a686502 (was 0x1a68650f).
	register("liquidity()", [4]byte{0x1a, 0x68, 0x65, 0x02}, ContractTypeUniswapV3Pool)
	// keccak256("tickSpacing()")[:4] is 0xd0c93a7c (was 0xd0c93207).
	register("tickSpacing()", [4]byte{0xd0, 0xc9, 0x3a, 0x7c}, ContractTypeUniswapV3Pool)

	// Router functions.
	register("WETH()", [4]byte{0xad, 0x5c, 0x46, 0x48}, ContractTypeUniswapV2Router, ContractTypeUniswapV3Router)
	register("factory()", [4]byte{0xc4, 0x5a, 0x01, 0x55}, ContractTypeUniswapV2Router, ContractTypeUniswapV3Router)
}

// ValidationResult contains the result of function signature validation.
type ValidationResult struct {
	// IsValid reports whether the call is considered permissible.
	IsValid bool
	// FunctionName is the registry name of the matched signature, if any.
	FunctionName string
	// ContractType is the type reported by the contract detector.
	ContractType ContractType
	// Error is set when detection failed or the call is not allowed.
	Error error
	// Warnings carries non-fatal notes (unknown selector, low confidence).
	Warnings []string
}

// ValidateFunctionCall validates if a function can be called on a contract.
//
// The contract type is detected first; a detection error makes the result
// invalid. The first four bytes of functionSelector are then matched against
// the registry. A selector that matches nothing (including one shorter than
// four bytes) is allowed with a warning. A matched selector is valid only if
// the detected contract type is among the signature's allowed types; a
// detection confidence below 0.7 adds a warning but does not invalidate.
func (sv *SignatureValidator) ValidateFunctionCall(ctx context.Context, contractAddress common.Address, functionSelector []byte) *ValidationResult {
	result := &ValidationResult{
		IsValid:  false,
		Warnings: []string{},
	}

	// Detect contract type.
	detection := sv.detector.DetectContractType(ctx, contractAddress)
	result.ContractType = detection.ContractType

	if detection.Error != nil {
		result.Error = fmt.Errorf("contract type detection failed: %w", detection.Error)
		return result
	}

	// Find matching function signature by comparing the 4-byte prefixes.
	var matchedSignature *FunctionSignature
	if len(functionSelector) >= 4 {
		for _, sig := range sv.signatures {
			if len(sig.Selector) >= 4 && bytes.Equal(sig.Selector[:4], functionSelector[:4]) {
				matchedSignature = sig
				result.FunctionName = sig.Name
				break
			}
		}
	}

	// If no signature match found, warn but allow (could be unknown function).
	if matchedSignature == nil {
		result.IsValid = true
		result.Warnings = append(result.Warnings, fmt.Sprintf("unknown function selector: %x", functionSelector))
		return result
	}

	// Check if the detected contract type is allowed for this function.
	allowed := false
	for _, allowedType := range matchedSignature.AllowedTypes {
		if detection.ContractType == allowedType {
			allowed = true
			break
		}
	}

	if !allowed {
		result.Error = fmt.Errorf("function %s cannot be called on contract type %s (allowed types: %v)",
			matchedSignature.Name, detection.ContractType.String(), matchedSignature.AllowedTypes)
		return result
	}

	// Low confidence does not block the call; it is surfaced as a warning.
	if detection.Confidence < 0.7 {
		result.Warnings = append(result.Warnings, fmt.Sprintf("low confidence in contract type detection: %.2f", detection.Confidence))
	}

	result.IsValid = true
	return result
}

// ValidateToken0Call specifically validates token0() function calls.
func (sv *SignatureValidator) ValidateToken0Call(ctx context.Context, contractAddress common.Address) *ValidationResult {
	token0Selector := []byte{0x0d, 0xfe, 0x16, 0x81}
	return sv.ValidateFunctionCall(ctx, contractAddress, token0Selector)
}

// ValidateToken1Call specifically validates token1() function calls.
func (sv *SignatureValidator) ValidateToken1Call(ctx context.Context, contractAddress common.Address) *ValidationResult {
	token1Selector := []byte{0xd2, 0x12, 0x20, 0xa7}
	return sv.ValidateFunctionCall(ctx, contractAddress, token1Selector)
}

// ValidateGetReservesCall specifically validates getReserves() function calls.
func (sv *SignatureValidator) ValidateGetReservesCall(ctx context.Context, contractAddress common.Address) *ValidationResult {
	getReservesSelector := []byte{0x09, 0x02, 0xf1, 0xac}
	return sv.ValidateFunctionCall(ctx, contractAddress, getReservesSelector)
}

// ValidateSlot0Call specifically validates slot0() function calls for Uniswap V3.
func (sv *SignatureValidator) ValidateSlot0Call(ctx context.Context, contractAddress common.Address) *ValidationResult {
	slot0Selector := []byte{0x38, 0x50, 0xc7, 0xbd}
	return sv.ValidateFunctionCall(ctx, contractAddress, slot0Selector)
}

// GetSupportedFunctions returns the names of the registered functions that
// may be called on the given contract type, sorted alphabetically so the
// result does not depend on map iteration order. It returns nil when no
// registered function allows the type.
func (sv *SignatureValidator) GetSupportedFunctions(contractType ContractType) []string {
	var functions []string

	for _, sig := range sv.signatures {
		for _, allowedType := range sig.AllowedTypes {
			if allowedType == contractType {
				functions = append(functions, sig.Name)
				break
			}
		}
	}

	sort.Strings(functions)
	return functions
}