Files
mev-beta/internal/contracts/detector.go
Krypto Kajun 850223a953 fix(multicall): resolve critical multicall parsing corruption issues
- Added comprehensive bounds checking to prevent buffer overruns in multicall parsing
- Implemented graduated validation system (Strict/Moderate/Permissive) to reduce false positives
- Added LRU caching system for address validation with 10-minute TTL
- Enhanced ABI decoder with missing Universal Router and Arbitrum-specific DEX signatures
- Fixed duplicate function declarations and import conflicts across multiple files
- Added error recovery mechanisms with multiple fallback strategies
- Updated tests to handle new validation behavior for suspicious addresses
- Fixed parser test expectations for improved validation system
- Applied gofmt formatting fixes to ensure code style compliance
- Fixed mutex copying issues in monitoring package by introducing MetricsSnapshot
- Resolved critical security vulnerabilities in heuristic address extraction
- Progress: Updated TODO audit from 10% to 35% complete

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
2025-10-17 00:12:55 -05:00

350 lines
11 KiB
Go

package contracts
import (
"context"
"fmt"
"math/big"
"strings"
"time"
"github.com/ethereum/go-ethereum"
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/ethclient"
"github.com/fraktal/mev-beta/internal/logger"
)
// ContractType enumerates the kinds of on-chain account the detector can
// classify an address as.
type ContractType int

const (
	// ContractTypeUnknown is the zero value: nothing was matched.
	ContractTypeUnknown ContractType = iota
	ContractTypeERC20Token
	ContractTypeUniswapV2Pool
	ContractTypeUniswapV3Pool
	ContractTypeUniswapV2Router
	ContractTypeUniswapV3Router
	ContractTypeUniversalRouter
	ContractTypeFactory
	ContractTypeEOA // Externally Owned Account
)

// String implements fmt.Stringer. Any value outside the declared constants
// (including negative values) renders as "Unknown".
func (ct ContractType) String() string {
	labels := [...]string{
		ContractTypeUnknown:         "Unknown",
		ContractTypeERC20Token:      "ERC-20",
		ContractTypeUniswapV2Pool:   "UniswapV2Pool",
		ContractTypeUniswapV3Pool:   "UniswapV3Pool",
		ContractTypeUniswapV2Router: "UniswapV2Router",
		ContractTypeUniswapV3Router: "UniswapV3Router",
		ContractTypeUniversalRouter: "UniversalRouter",
		ContractTypeFactory:         "Factory",
		ContractTypeEOA:             "EOA",
	}
	if ct < 0 || int(ct) >= len(labels) {
		return "Unknown"
	}
	return labels[ct]
}
// DetectionResult contains the result of contract type detection for a single
// address, as produced by ContractDetector.DetectContractType.
type DetectionResult struct {
	// ContractType is the best-matching classification; ContractTypeUnknown
	// when nothing matched, ContractTypeEOA when the address has no code.
	ContractType ContractType
	// IsContract reports whether non-empty bytecode was found at the address.
	IsContract bool
	// HasCode is set together with IsContract by DetectContractType.
	HasCode bool
	// SupportedFunctions lists the probe signatures (e.g. "token0()") that the
	// contract answered successfully during detection.
	SupportedFunctions []string
	Confidence         float64 // 0.0 to 1.0
	// Error is set when detection could not complete (e.g. the code lookup
	// failed); the other fields then hold their initial values.
	Error error
	// Warnings collects non-fatal issues, such as a failed post-detection
	// validation of the chosen type.
	Warnings []string
}
// ContractDetector provides runtime contract type detection by probing
// addresses through an Ethereum JSON-RPC client and memoizing the results
// per address.
//
// NOTE(review): cache is a plain map and no method in this file takes a lock,
// so a detector should not be shared across goroutines without external
// synchronization.
type ContractDetector struct {
	client  *ethclient.Client
	logger  *logger.Logger                      // not referenced by any method in this file
	cache   map[common.Address]*DetectionResult // memoized results; only ClearCache removes entries
	timeout time.Duration                       // single deadline for each DetectContractType call
}
// NewContractDetector returns a detector that issues its RPC calls through
// client, starts with an empty result cache, and gives each detection a
// five-second deadline.
func NewContractDetector(client *ethclient.Client, logger *logger.Logger) *ContractDetector {
	cd := &ContractDetector{client: client, logger: logger}
	cd.cache = make(map[common.Address]*DetectionResult)
	cd.timeout = 5 * time.Second
	return cd
}
// DetectContractType determines the type of contract at the given address.
//
// It fetches the account's code at the latest block. Empty code classifies
// the address as an EOA with confidence 1.0. Otherwise known function
// selectors are probed (detectByFunctionSignatures). When the best match
// scores above 0.8, a type-specific validation also runs. If that validation
// fails, its error is added to Warnings and Confidence is scaled by 0.8.
//
// The whole detection shares one deadline of cd.timeout derived from ctx.
// The returned result is never nil; failures are reported via result.Error.
//
// Results are memoized per address only when they are trustworthy. A failed
// code lookup, or a detection cut short by the deadline or by ctx being
// cancelled, is returned without being cached so a later call can retry.
func (cd *ContractDetector) DetectContractType(ctx context.Context, address common.Address) *DetectionResult {
	if cached, ok := cd.cache[address]; ok {
		return cached
	}

	result := &DetectionResult{
		ContractType:       ContractTypeUnknown,
		IsContract:         false,
		HasCode:            false,
		SupportedFunctions: []string{},
		Confidence:         0.0,
		Warnings:           []string{},
	}

	ctxWithTimeout, cancel := context.WithTimeout(ctx, cd.timeout)
	defer cancel()

	code, err := cd.client.CodeAt(ctxWithTimeout, address, nil)
	if err != nil {
		// Not cached: a failed lookup (timeout, dropped connection, rate limit)
		// says nothing about the address itself. Caching it would pin the
		// address to "Unknown" until ClearCache is called.
		result.Error = fmt.Errorf("failed to get code at address: %w", err)
		return result
	}

	// No code means an externally owned account.
	if len(code) == 0 {
		result.ContractType = ContractTypeEOA
		result.IsContract = false
		result.Confidence = 1.0
		cd.cache[address] = result
		return result
	}

	result.IsContract = true
	result.HasCode = true

	contractType, confidence, supportedFunctions := cd.detectByFunctionSignatures(ctxWithTimeout, address)
	result.ContractType = contractType
	result.Confidence = confidence
	result.SupportedFunctions = supportedFunctions

	// Additional validation for high-confidence detection.
	if confidence > 0.8 {
		if err := cd.validateContractType(ctxWithTimeout, address, contractType); err != nil {
			result.Warnings = append(result.Warnings, fmt.Sprintf("validation warning: %v", err))
			result.Confidence *= 0.8 // Reduce confidence
		}
	}

	// Probes that fail with a non-revert error (including a context error)
	// are excluded from scoring. An expired or cancelled context therefore
	// means the classification rests on partial evidence: report it, but
	// don't cache it.
	if err := ctxWithTimeout.Err(); err != nil {
		result.Warnings = append(result.Warnings, fmt.Sprintf("detection incomplete: %v", err))
		return result
	}

	cd.cache[address] = result
	return result
}
// detectByFunctionSignatures classifies the contract at address by probing
// groups of well-known function selectors with eth_call.
//
// Each group is scored by testFunctionSignatures. A group becomes a candidate
// only when its score clears the group's threshold, and the highest-scoring
// candidate wins. Ties go to the more specific type, in the order
// V3 pool > V2 pool > router > ERC-20.
//
// It returns the winning type (ContractTypeUnknown if no group qualified),
// its score, and the names of every probe the contract answered.
func (cd *ContractDetector) detectByFunctionSignatures(ctx context.Context, address common.Address) (ContractType, float64, []string) {
	supportedFunctions := []string{}
	scores := make(map[ContractType]float64)

	// Selectors are the first 4 bytes of keccak256 of the canonical signature.

	// balanceOf(address) needs its argument. The bare selector fails ABI
	// decoding and reverts on compliant tokens, so one zero word is appended
	// (a balance query for the zero address).
	balanceOfCall := append([]byte{0x70, 0xa0, 0x82, 0x31}, make([]byte, 32)...)

	// Test ERC-20 functions.
	erc20Functions := map[string][]byte{
		"name()":        {0x06, 0xfd, 0xde, 0x03},
		"symbol()":      {0x95, 0xd8, 0x9b, 0x41},
		"decimals()":    {0x31, 0x3c, 0xe5, 0x67},
		"totalSupply()": {0x18, 0x16, 0x0d, 0xdd},
		"balanceOf()":   balanceOfCall, // key kept unchanged for existing consumers
	}
	if score := cd.testFunctionSignatures(ctx, address, erc20Functions, &supportedFunctions); score > 0.6 {
		scores[ContractTypeERC20Token] = score
	}

	// Test Uniswap V2 pool functions.
	v2PoolFunctions := map[string][]byte{
		"token0()":               {0x0d, 0xfe, 0x16, 0x81},
		"token1()":               {0xd2, 0x12, 0x20, 0xa7},
		"getReserves()":          {0x09, 0x02, 0xf1, 0xac},
		"price0CumulativeLast()": {0x59, 0x09, 0xc0, 0xd5},
		"kLast()":                {0x74, 0x64, 0xfc, 0x3d},
	}
	if score := cd.testFunctionSignatures(ctx, address, v2PoolFunctions, &supportedFunctions); score > 0.6 {
		scores[ContractTypeUniswapV2Pool] = score
	}

	// Test Uniswap V3 pool functions.
	v3PoolFunctions := map[string][]byte{
		"token0()":      {0x0d, 0xfe, 0x16, 0x81},
		"token1()":      {0xd2, 0x12, 0x20, 0xa7},
		"fee()":         {0xdd, 0xca, 0x3f, 0x43},
		"slot0()":       {0x38, 0x50, 0xc7, 0xbd},
		"liquidity()":   {0x1a, 0x68, 0x65, 0x0f},
		"tickSpacing()": {0xd0, 0xc9, 0x3a, 0x7c},
	}
	if score := cd.testFunctionSignatures(ctx, address, v3PoolFunctions, &supportedFunctions); score > 0.6 {
		scores[ContractTypeUniswapV3Pool] = score
	}

	// Test router functions. swapExactTokensForTokens is probed without
	// arguments and is therefore expected to revert, so routers top out at
	// 2/3 here; the lower 0.5 threshold accounts for that.
	routerFunctions := map[string][]byte{
		"WETH()":                     {0xad, 0x5c, 0x46, 0x48},
		"swapExactTokensForTokens()": {0x38, 0xed, 0x17, 0x39},
		"factory()":                  {0xc4, 0x5a, 0x01, 0x55},
	}
	if score := cd.testFunctionSignatures(ctx, address, routerFunctions, &supportedFunctions); score > 0.5 {
		scores[ContractTypeUniswapV2Router] = score
	}

	// Pick the best score in a fixed order so the result is deterministic
	// (ranging over a map is not). The strict ">" lets earlier, more specific
	// types win ties. This matters because Uniswap V2 pairs are themselves
	// ERC-20 LP tokens and can score as highly on the ERC-20 probes.
	priority := []ContractType{
		ContractTypeUniswapV3Pool,
		ContractTypeUniswapV2Pool,
		ContractTypeUniswapV2Router,
		ContractTypeERC20Token,
	}
	bestType := ContractTypeUnknown
	bestScore := 0.0
	for _, candidate := range priority {
		if score, ok := scores[candidate]; ok && score > bestScore {
			bestType, bestScore = candidate, score
		}
	}
	return bestType, bestScore, supportedFunctions
}
// testFunctionSignatures probes each calldata in functions against address
// with eth_call and returns the fraction of probes the contract supports
// (0.0 when none could be evaluated).
//
// Each probe is classified as follows:
//   - Supported: the call succeeds and returns at least one 32-byte ABI word.
//     Every probed function returns a value, whereas a contract whose fallback
//     accepts arbitrary calldata succeeds with empty output. Such an empty
//     success must not be read as support, or that contract would match every
//     signature.
//   - Unsupported: the call reverts, or succeeds without return data.
//   - Excluded from the ratio: any other error. It is presumably transport or
//     node trouble rather than a property of the contract.
//
// Once ctx is done the remaining probes are skipped, which excludes them just
// as their context errors would. Names of supported probes are appended to
// *supportedFunctions once each, even when several groups probe the same
// function.
func (cd *ContractDetector) testFunctionSignatures(ctx context.Context, address common.Address, functions map[string][]byte, supportedFunctions *[]string) float64 {
	supported := 0
	evaluated := 0

	for funcName, calldata := range functions {
		if ctx.Err() != nil {
			break
		}

		out, err := cd.client.CallContract(ctx, ethereum.CallMsg{
			To:   &address,
			Data: calldata,
		}, nil)

		switch {
		case err == nil:
			evaluated++
			if len(out) < 32 {
				continue // empty/short success: fallback, not a real implementation
			}
			supported++
			alreadyListed := false
			for _, name := range *supportedFunctions {
				if name == funcName {
					alreadyListed = true
					break
				}
			}
			if !alreadyListed {
				*supportedFunctions = append(*supportedFunctions, funcName)
			}
		case strings.Contains(err.Error(), "execution reverted"):
			evaluated++
		default:
			// Not a revert: it might be a network error, so don't count it
			// against the contract.
		}
	}

	if evaluated == 0 {
		return 0.0
	}
	return float64(supported) / float64(evaluated)
}
// validateContractType runs the type-specific sanity check for a contract
// that detection has already classified. ERC-20 tokens and Uniswap V2/V3
// pools have a dedicated check; every other type passes with nil.
func (cd *ContractDetector) validateContractType(ctx context.Context, address common.Address, contractType ContractType) error {
	var check func(context.Context, common.Address) error
	switch contractType {
	case ContractTypeERC20Token:
		check = cd.validateERC20
	case ContractTypeUniswapV2Pool:
		check = cd.validateUniswapV2Pool
	case ContractTypeUniswapV3Pool:
		check = cd.validateUniswapV3Pool
	}
	if check == nil {
		return nil // No additional validation for other types
	}
	return check(ctx, address)
}
// validateERC20 sanity-checks a contract classified as an ERC-20 token by
// calling decimals() and requiring the first returned word to be at most 18.
//
// It returns an error if the call fails, if the return data is shorter than
// one 32-byte ABI word, or if the decoded value exceeds 18.
func (cd *ContractDetector) validateERC20(ctx context.Context, address common.Address) error {
	decimalsData := []byte{0x31, 0x3c, 0xe5, 0x67} // decimals()
	result, err := cd.client.CallContract(ctx, ethereum.CallMsg{
		To:   &address,
		Data: decimalsData,
	}, nil)
	if err != nil {
		return fmt.Errorf("decimals() call failed: %w", err)
	}

	// A compliant decimals() always returns at least one ABI word.
	if len(result) < 32 {
		return fmt.Errorf("unexpected decimals() return length: %d", len(result))
	}

	// Compare as a big.Int. Calling Uint64() on a value >= 2^64 gives an
	// undefined result (in practice the low 64 bits), so a garbage word
	// could otherwise pass as a small number.
	decimals := new(big.Int).SetBytes(result[:32])
	if decimals.Cmp(big.NewInt(18)) > 0 {
		return fmt.Errorf("unrealistic decimals value: %d", decimals)
	}
	return nil
}
// validateUniswapV2Pool sanity-checks a contract classified as a Uniswap V2
// pool by calling getReserves().
//
// Uniswap V2 returns (uint112 reserve0, uint112 reserve1, uint32
// blockTimestampLast), which is three 32-byte ABI words. Responses of at
// least three words are accepted, so forks that append extra return values
// are tolerated. A shorter response cannot be a V2-style getReserves() and
// is reported as an error, as is a failed call.
func (cd *ContractDetector) validateUniswapV2Pool(ctx context.Context, address common.Address) error {
	getReservesData := []byte{0x09, 0x02, 0xf1, 0xac} // getReserves()
	result, err := cd.client.CallContract(ctx, ethereum.CallMsg{
		To:   &address,
		Data: getReservesData,
	}, nil)
	if err != nil {
		return fmt.Errorf("getReserves() call failed: %w", err)
	}

	const minReturnLen = 3 * 32 // reserve0, reserve1, blockTimestampLast
	if len(result) < minReturnLen {
		return fmt.Errorf("unexpected getReserves() return length: %d", len(result))
	}
	return nil
}
// validateUniswapV3Pool sanity-checks a contract classified as a Uniswap V3
// pool by calling slot0(). It fails if the call errors or if fewer than 32
// bytes come back, since at least the leading sqrtPriceX96 word is expected.
func (cd *ContractDetector) validateUniswapV3Pool(ctx context.Context, address common.Address) error {
	msg := ethereum.CallMsg{
		To:   &address,
		Data: []byte{0x38, 0x50, 0xc7, 0xbd}, // slot0()
	}
	ret, err := cd.client.CallContract(ctx, msg, nil)
	switch {
	case err != nil:
		return fmt.Errorf("slot0() call failed: %w", err)
	case len(ret) < 32:
		return fmt.Errorf("unexpected slot0() return length: %d", len(ret))
	default:
		return nil
	}
}
// IsERC20Token reports whether address was detected as an ERC-20 token with
// confidence strictly above 0.7. It uses the cached result when one exists.
func (cd *ContractDetector) IsERC20Token(ctx context.Context, address common.Address) bool {
	res := cd.DetectContractType(ctx, address)
	if res.ContractType != ContractTypeERC20Token {
		return false
	}
	return res.Confidence > 0.7
}
// IsUniswapPool reports whether address was detected as a Uniswap V2 or V3
// pool with confidence strictly above 0.7.
func (cd *ContractDetector) IsUniswapPool(ctx context.Context, address common.Address) bool {
	res := cd.DetectContractType(ctx, address)
	switch res.ContractType {
	case ContractTypeUniswapV2Pool, ContractTypeUniswapV3Pool:
		return res.Confidence > 0.7
	}
	return false
}
// IsRouter reports whether address was detected as a Uniswap V2, Uniswap V3
// or Universal router with confidence strictly above 0.7.
func (cd *ContractDetector) IsRouter(ctx context.Context, address common.Address) bool {
	res := cd.DetectContractType(ctx, address)
	switch res.ContractType {
	case ContractTypeUniswapV2Router, ContractTypeUniswapV3Router, ContractTypeUniversalRouter:
		return res.Confidence > 0.7
	default:
		return false
	}
}
// ClearCache drops every memoized detection result by swapping in a fresh,
// empty map; later lookups re-run detection.
func (cd *ContractDetector) ClearCache() {
	cd.cache = map[common.Address]*DetectionResult{}
}
// GetCacheSize reports how many addresses currently have a memoized result.
func (cd *ContractDetector) GetCacheSize() int {
	entries := len(cd.cache)
	return entries
}