Files
mev-beta/internal/contracts/signature_validator.go
Krypto Kajun 850223a953 fix(multicall): resolve critical multicall parsing corruption issues
- Added comprehensive bounds checking to prevent buffer overruns in multicall parsing
- Implemented graduated validation system (Strict/Moderate/Permissive) to reduce false positives
- Added LRU caching system for address validation with 10-minute TTL
- Enhanced ABI decoder with missing Universal Router and Arbitrum-specific DEX signatures
- Fixed duplicate function declarations and import conflicts across multiple files
- Added error recovery mechanisms with multiple fallback strategies
- Updated tests to handle new validation behavior for suspicious addresses
- Fixed parser test expectations for improved validation system
- Applied gofmt formatting fixes to ensure code style compliance
- Fixed mutex copying issues in monitoring package by introducing MetricsSnapshot
- Resolved critical security vulnerabilities in heuristic address extraction
- Progress: Updated TODO audit from 10% to 35% complete

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
2025-10-17 00:12:55 -05:00

240 lines
7.6 KiB
Go

package contracts
import (
	"context"
	"fmt"
	"sort"

	"github.com/ethereum/go-ethereum/common"
	"github.com/fraktal/mev-beta/internal/logger"
)
// FunctionSignature describes one known contract function: a readable
// signature name, the selector bytes that identify it in calldata, and the
// contract types on which calling it is considered legitimate.
type FunctionSignature struct {
	// Name is the signature text, e.g. "token0()".
	Name string
	// Selector holds the function selector bytes; validation compares the
	// first 4 bytes against the start of the caller-supplied selector.
	Selector []byte
	// AllowedTypes lists the contract types permitted to receive this call.
	AllowedTypes []ContractType
}
// SignatureValidator cross-checks function selectors against the detected
// type of the contract being called, using a registry of known signatures.
// Construct it with NewSignatureValidator, which populates the registry.
type SignatureValidator struct {
	// detector resolves a contract address to its ContractType.
	detector *ContractDetector
	// logger is the logger supplied at construction time.
	logger *logger.Logger
	// signatures is the registry of known functions, keyed by signature name.
	signatures map[string]*FunctionSignature
}
// NewSignatureValidator returns a SignatureValidator that uses detector for
// contract-type detection and logger for logging, with its signature registry
// already filled with the built-in known functions.
func NewSignatureValidator(detector *ContractDetector, logger *logger.Logger) *SignatureValidator {
	validator := new(SignatureValidator)
	validator.detector = detector
	validator.logger = logger
	validator.signatures = make(map[string]*FunctionSignature)

	// Populate the registry before handing the validator out.
	validator.initializeSignatures()

	return validator
}
// initializeSignatures fills sv.signatures with the built-in registry of known
// function signatures, keyed by name, together with the contract types that are
// allowed to expose each one.
//
// Each selector is the first 4 bytes of keccak256 of the canonical Solidity
// signature. Entries are added through the variadic register helper, so every
// FunctionSignature gets its own AllowedTypes slice (no shared backing arrays).
func (sv *SignatureValidator) initializeSignatures() {
	register := func(name string, selector []byte, allowed ...ContractType) {
		sv.signatures[name] = &FunctionSignature{
			Name:         name,
			Selector:     selector,
			AllowedTypes: allowed,
		}
	}

	// ERC-20 token functions.
	register("name()", []byte{0x06, 0xfd, 0xde, 0x03}, ContractTypeERC20Token)
	register("symbol()", []byte{0x95, 0xd8, 0x9b, 0x41}, ContractTypeERC20Token)
	register("decimals()", []byte{0x31, 0x3c, 0xe5, 0x67}, ContractTypeERC20Token)
	register("totalSupply()", []byte{0x18, 0x16, 0x0d, 0xdd}, ContractTypeERC20Token)
	// 0x70a08231 is the selector of balanceOf(address). The "balanceOf()" key
	// and name are kept as-is so existing lookups and reported names are unchanged.
	register("balanceOf()", []byte{0x70, 0xa0, 0x82, 0x31}, ContractTypeERC20Token)

	// Functions exposed by both Uniswap V2 and V3 pools.
	register("token0()", []byte{0x0d, 0xfe, 0x16, 0x81},
		ContractTypeUniswapV2Pool, ContractTypeUniswapV3Pool)
	register("token1()", []byte{0xd2, 0x12, 0x20, 0xa7},
		ContractTypeUniswapV2Pool, ContractTypeUniswapV3Pool)

	// Uniswap V2 pool specific functions.
	register("getReserves()", []byte{0x09, 0x02, 0xf1, 0xac}, ContractTypeUniswapV2Pool)

	// Uniswap V3 pool specific functions.
	register("slot0()", []byte{0x38, 0x50, 0xc7, 0xbd}, ContractTypeUniswapV3Pool)
	register("fee()", []byte{0xdd, 0xca, 0x3f, 0x43}, ContractTypeUniswapV3Pool)
	// Correct selectors: liquidity() = 0x1a686502, tickSpacing() = 0xd0c93a7c.
	// Wrong bytes here make genuine calls fall through as "unknown selector"
	// and bypass the contract-type check entirely.
	register("liquidity()", []byte{0x1a, 0x68, 0x65, 0x02}, ContractTypeUniswapV3Pool)
	register("tickSpacing()", []byte{0xd0, 0xc9, 0x3a, 0x7c}, ContractTypeUniswapV3Pool)

	// Router functions.
	register("WETH()", []byte{0xad, 0x5c, 0x46, 0x48},
		ContractTypeUniswapV2Router, ContractTypeUniswapV3Router)
	// factory() is a public getter on routers and also on Uniswap V2 pairs and
	// V3 pools; allowing pools avoids rejecting legitimate pool.factory() calls.
	register("factory()", []byte{0xc4, 0x5a, 0x01, 0x55},
		ContractTypeUniswapV2Router, ContractTypeUniswapV3Router,
		ContractTypeUniswapV2Pool, ContractTypeUniswapV3Pool)
}
// ValidationResult reports the outcome of checking one function call against
// the signature registry and the detected contract type.
type ValidationResult struct {
	// IsValid reports whether the call passed validation.
	IsValid bool
	// FunctionName is the name of the matched known signature; it stays empty
	// when the selector is not in the registry.
	FunctionName string
	// ContractType is the type reported by contract detection.
	ContractType ContractType
	// Error explains why validation failed, when it did.
	Error error
	// Warnings collects non-fatal concerns, such as an unknown selector or a
	// low-confidence contract detection.
	Warnings []string
}
// ValidateFunctionCall reports whether the function identified by
// functionSelector may be called on the contract at contractAddress.
//
// Only the first 4 bytes of functionSelector are compared, so full calldata may
// be passed. The checks run in this order:
//   - A selector shorter than 4 bytes is rejected with an error before any
//     contract-type detection is attempted.
//   - The contract type is detected; a detection error is returned (wrapped) in
//     the result's Error.
//   - An unknown selector is allowed, with a warning.
//   - A known selector is allowed only if the detected type appears in the
//     signature's AllowedTypes; otherwise Error describes the mismatch.
//   - A detection confidence below 0.7 adds a warning to an otherwise valid result.
func (sv *SignatureValidator) ValidateFunctionCall(ctx context.Context, contractAddress common.Address, functionSelector []byte) *ValidationResult {
	const (
		// selectorLen is the size of an ABI function selector.
		selectorLen = 4
		// minConfidence is the detection confidence below which a valid call
		// still carries a warning.
		minConfidence = 0.7
	)

	result := &ValidationResult{
		IsValid:  false,
		Warnings: []string{},
	}

	// A truncated selector cannot identify any function; previously it was
	// treated as "unknown" and reported valid. Checking first also skips a
	// pointless contract-type detection.
	if len(functionSelector) < selectorLen {
		result.Error = fmt.Errorf("invalid function selector 0x%x: expected at least %d bytes, got %d",
			functionSelector, selectorLen, len(functionSelector))
		return result
	}

	// Detect contract type.
	detection := sv.detector.DetectContractType(ctx, contractAddress)
	result.ContractType = detection.ContractType
	if detection.Error != nil {
		result.Error = fmt.Errorf("contract type detection failed: %w", detection.Error)
		return result
	}

	// Find the registered signature whose selector matches the first 4 bytes.
	wanted := string(functionSelector[:selectorLen])
	var matchedSignature *FunctionSignature
	for _, sig := range sv.signatures {
		if len(sig.Selector) >= selectorLen && string(sig.Selector[:selectorLen]) == wanted {
			matchedSignature = sig
			break
		}
	}

	// If no signature matches, warn but allow (it could be a legitimate
	// function this registry does not know about).
	if matchedSignature == nil {
		result.IsValid = true
		result.Warnings = append(result.Warnings, fmt.Sprintf("unknown function selector: %x", functionSelector))
		return result
	}
	result.FunctionName = matchedSignature.Name

	// Check that the detected contract type is allowed for this function.
	allowed := false
	for _, allowedType := range matchedSignature.AllowedTypes {
		if allowedType == detection.ContractType {
			allowed = true
			break
		}
	}
	if !allowed {
		result.Error = fmt.Errorf("function %s cannot be called on contract type %s (allowed types: %v)",
			matchedSignature.Name, detection.ContractType.String(), matchedSignature.AllowedTypes)
		return result
	}

	if detection.Confidence < minConfidence {
		result.Warnings = append(result.Warnings, fmt.Sprintf("low confidence in contract type detection: %.2f", detection.Confidence))
	}

	result.IsValid = true
	return result
}
// ValidateToken0Call checks whether token0() (selector 0x0dfe1681) may be
// called on the contract at contractAddress, delegating to ValidateFunctionCall.
func (sv *SignatureValidator) ValidateToken0Call(ctx context.Context, contractAddress common.Address) *ValidationResult {
	return sv.ValidateFunctionCall(ctx, contractAddress, []byte{0x0d, 0xfe, 0x16, 0x81})
}
// ValidateToken1Call checks whether token1() (selector 0xd21220a7) may be
// called on the contract at contractAddress, delegating to ValidateFunctionCall.
func (sv *SignatureValidator) ValidateToken1Call(ctx context.Context, contractAddress common.Address) *ValidationResult {
	return sv.ValidateFunctionCall(ctx, contractAddress, []byte{0xd2, 0x12, 0x20, 0xa7})
}
// ValidateGetReservesCall checks whether getReserves() (selector 0x0902f1ac)
// may be called on the contract at contractAddress, delegating to
// ValidateFunctionCall.
func (sv *SignatureValidator) ValidateGetReservesCall(ctx context.Context, contractAddress common.Address) *ValidationResult {
	return sv.ValidateFunctionCall(ctx, contractAddress, []byte{0x09, 0x02, 0xf1, 0xac})
}
// ValidateSlot0Call checks whether the Uniswap V3 slot0() function (selector
// 0x3850c7bd) may be called on the contract at contractAddress, delegating to
// ValidateFunctionCall.
func (sv *SignatureValidator) ValidateSlot0Call(ctx context.Context, contractAddress common.Address) *ValidationResult {
	return sv.ValidateFunctionCall(ctx, contractAddress, []byte{0x38, 0x50, 0xc7, 0xbd})
}
// GetSupportedFunctions returns the names of all registered functions whose
// AllowedTypes include contractType. The names are sorted so the result is
// deterministic regardless of map iteration order. It returns nil when no
// registered function allows contractType.
func (sv *SignatureValidator) GetSupportedFunctions(contractType ContractType) []string {
	var functions []string
	for _, sig := range sv.signatures {
		for _, allowedType := range sig.AllowedTypes {
			if allowedType == contractType {
				functions = append(functions, sig.Name)
				break
			}
		}
	}
	// Go randomizes map iteration order; sort so callers (and tests) see a
	// stable result.
	sort.Strings(functions)
	return functions
}