Files
mev-beta/pkg/security/contract_validator.go
Krypto Kajun 850223a953 fix(multicall): resolve critical multicall parsing corruption issues
- Added comprehensive bounds checking to prevent buffer overruns in multicall parsing
- Implemented graduated validation system (Strict/Moderate/Permissive) to reduce false positives
- Added LRU caching system for address validation with 10-minute TTL
- Enhanced ABI decoder with missing Universal Router and Arbitrum-specific DEX signatures
- Fixed duplicate function declarations and import conflicts across multiple files
- Added error recovery mechanisms with multiple fallback strategies
- Updated tests to handle new validation behavior for suspicious addresses
- Fixed parser test expectations for improved validation system
- Applied gofmt formatting fixes to ensure code style compliance
- Fixed mutex copying issues in monitoring package by introducing MetricsSnapshot
- Resolved critical security vulnerabilities in heuristic address extraction
- Progress: Updated TODO audit from 10% to 35% complete

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
2025-10-17 00:12:55 -05:00

565 lines
18 KiB
Go

package security
import (
"context"
"encoding/hex"
"fmt"
"math/big"
"sync"
"time"
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/core/types"
"github.com/ethereum/go-ethereum/crypto"
"github.com/ethereum/go-ethereum/ethclient"
"github.com/fraktal/mev-beta/internal/logger"
)
// ContractInfo represents information about a verified contract.
//
// Values are either built by on-chain validation (validateContractOnChain) or
// supplied by callers of AddTrustedContract. The same pointer is shared
// between the trusted list, the validation cache and every
// ContractValidationResult that returns it.
type ContractInfo struct {
	// Address is the contract address; it is also the key in the validator's maps.
	Address common.Address `json:"address"`
	// BytecodeHash is the Keccak-256 hash of the code fetched with CodeAt, as a
	// hex string, when produced by on-chain validation. AddTrustedContract only
	// requires it to be non-empty; it does not compare it with on-chain code.
	BytecodeHash string `json:"bytecode_hash"`
	// Name is "Unknown Contract" after on-chain validation unless bytecode
	// verification is enabled and the address matches the built-in table of
	// known contracts.
	Name string `json:"name"`
	// Version is "unknown" after on-chain validation.
	Version string `json:"version"`
	// DeployedAt is filled from getDeploymentInfo, which is currently a stub,
	// so it is 0 after on-chain validation. NOTE(review): presumably intended
	// to be the deployment block number - confirm.
	DeployedAt *big.Int `json:"deployed_at"`
	// Deployer is likewise the zero address after on-chain validation.
	Deployer common.Address `json:"deployer"`
	// VerifiedAt is set when the entry is created or trusted; cached entries
	// older than ContractValidatorConfig.CacheTimeout are validated again.
	VerifiedAt time.Time `json:"verified_at"`
	// IsWhitelisted is true for trusted contracts and recognised known
	// addresses. It lowers the risk score and satisfies RequireWhitelist.
	IsWhitelisted bool `json:"is_whitelisted"`
	// RiskLevel is the qualitative risk assessment used by calculateRiskScore.
	RiskLevel RiskLevel `json:"risk_level"`
	// Permissions constrain transactions sent to the contract
	// (checked in ValidateTransaction / checkDailyLimit).
	Permissions ContractPermissions `json:"permissions"`
	// ABIHash and SourceCodeHash are not populated anywhere in this file.
	ABIHash string `json:"abi_hash,omitempty"`
	SourceCodeHash string `json:"source_code_hash,omitempty"`
}
// ContractPermissions defines what operations are allowed with a contract.
//
// Note that the zero value denies value transfers (CanSendValue false) and
// sets no per-transaction or daily cap (nil big.Ints).
type ContractPermissions struct {
	// CanInteract is intended to allow any interaction at all.
	// NOTE(review): ValidateTransaction does not currently check it.
	CanInteract bool `json:"can_interact"`
	// CanSendValue must be true for transactions carrying a positive value.
	CanSendValue bool `json:"can_send_value"`
	// MaxValueWei caps the value of a single transaction; nil means no cap.
	MaxValueWei *big.Int `json:"max_value_wei,omitempty"`
	// AllowedMethods is documented by getDefaultPermissions as "empty means
	// all methods allowed". It is not enforced in this file.
	AllowedMethods []string `json:"allowed_methods,omitempty"`
	// RequireConfirm is not enforced in this file.
	RequireConfirm bool `json:"require_confirmation"`
	// DailyLimit caps the cumulative value sent to the contract within the
	// 24h window tracked by checkDailyLimit; nil means no daily cap.
	DailyLimit *big.Int `json:"daily_limit,omitempty"`
}
// RiskLevel represents the risk assessment of a contract, ordered from the
// least severe (RiskLevelLow) to the most severe (RiskLevelBlocked).
type RiskLevel int

// Risk levels in increasing order of severity. RiskLevelLow is the zero value.
const (
	RiskLevelLow      RiskLevel = iota // 0
	RiskLevelMedium                    // 1
	RiskLevelHigh                      // 2
	RiskLevelCritical                  // 3
	RiskLevelBlocked                   // 4
)

// String returns the display name of r ("Low" ... "Blocked"), or "Unknown"
// for any value outside the declared constants.
func (r RiskLevel) String() string {
	labels := [...]string{
		RiskLevelLow:      "Low",
		RiskLevelMedium:   "Medium",
		RiskLevelHigh:     "High",
		RiskLevelCritical: "Critical",
		RiskLevelBlocked:  "Blocked",
	}
	if r < RiskLevelLow || int(r) >= len(labels) {
		return "Unknown"
	}
	return labels[r]
}
// ContractValidationResult contains the result of contract validation.
//
// ValidateContract returns it on both success and failure. On failure it
// carries whatever was gathered before the failing step.
type ContractValidationResult struct {
	// IsValid is true only when validation succeeded.
	IsValid bool `json:"is_valid"`
	// ContractInfo is the trusted, cached or freshly built info; it can be nil
	// when on-chain validation failed before info was built.
	ContractInfo *ContractInfo `json:"contract_info"`
	// ValidationError mirrors the reason for failure, if any.
	ValidationError string `json:"validation_error,omitempty"`
	// Warnings are non-fatal findings; each one adds 1 to the risk score.
	Warnings []string `json:"warnings"`
	// ChecksPerformed records each check in the order it ran.
	ChecksPerformed []ValidationCheck `json:"checks_performed"`
	RiskScore int `json:"risk_score"` // 1-10 as computed by calculateRiskScore; 0 if no score was computed
}
// ValidationCheck represents a single validation check recorded in
// ContractValidationResult.ChecksPerformed.
type ValidationCheck struct {
	// Name identifies the check, e.g. "Contract Existence" or "Whitelist Check".
	Name string `json:"name"`
	// Passed reports whether the check succeeded.
	Passed bool `json:"passed"`
	// Description is a human-readable summary of the outcome.
	Description string `json:"description"`
	// Error holds the failure detail when Passed is false (omitted otherwise).
	Error string `json:"error,omitempty"`
	// Timestamp is the wall-clock time the check was recorded.
	Timestamp time.Time `json:"timestamp"`
}
// ContractValidator provides secure contract validation and verification.
//
// It keeps a trusted list and a validation cache (guarded by cacheMutex) and
// per-contract value usage for daily limits (guarded by limitsMutex).
// Construct it with NewContractValidator: the zero value has nil maps and a
// nil config.
type ContractValidator struct {
	client *ethclient.Client // used to fetch contract code via CodeAt
	logger *logger.Logger
	// trustedContracts holds entries registered via AddTrustedContract.
	trustedContracts map[common.Address]*ContractInfo
	// contractCache holds trusted entries plus contracts that passed ValidateContract.
	contractCache map[common.Address]*ContractInfo
	// cacheMutex guards trustedContracts and contractCache.
	cacheMutex sync.RWMutex
	config *ContractValidatorConfig
	// Security tracking
	// NOTE(review): interactionCounts is initialised in NewContractValidator
	// but not otherwise used in this file.
	interactionCounts map[common.Address]int64
	// dailyLimits stores the value already sent to each contract in the
	// current window (usage, not the limit; limits come from
	// ContractPermissions.DailyLimit).
	dailyLimits map[common.Address]*big.Int
	// lastResetTime is the start of the current daily-limit window.
	lastResetTime time.Time
	// limitsMutex guards dailyLimits and lastResetTime.
	limitsMutex sync.RWMutex
}
// ContractValidatorConfig provides configuration for the contract validator.
//
// Defaults come from getDefaultValidatorConfig when nil is passed to
// NewContractValidator.
type ContractValidatorConfig struct {
	// EnableBytecodeVerification enables matching against known contracts.
	EnableBytecodeVerification bool `json:"enable_bytecode_verification"`
	// EnableABIValidation is not referenced in this file.
	EnableABIValidation bool `json:"enable_abi_validation"`
	// RequireWhitelist rejects on-chain validated contracts that are not whitelisted.
	RequireWhitelist bool `json:"require_whitelist"`
	// MaxBytecodeSize only triggers a warning when exceeded (0 disables it).
	MaxBytecodeSize int `json:"max_bytecode_size"`
	// CacheTimeout is how long a cached validation is reused.
	CacheTimeout time.Duration `json:"cache_timeout"`
	// MaxRiskScore is the highest risk score (1-10) that still validates.
	MaxRiskScore int `json:"max_risk_score"`
	// BlockUnverifiedContracts, RequireSourceCode and EnableRealTimeValidation
	// are not referenced in this file.
	BlockUnverifiedContracts bool `json:"block_unverified_contracts"`
	RequireSourceCode bool `json:"require_source_code"`
	EnableRealTimeValidation bool `json:"enable_realtime_validation"`
}
// NewContractValidator builds a ContractValidator that reads contract code
// through client and logs through logger.
//
// A nil config is replaced by getDefaultValidatorConfig(). All internal maps
// start empty and the daily-limit window starts at construction time.
func NewContractValidator(client *ethclient.Client, logger *logger.Logger, config *ContractValidatorConfig) *ContractValidator {
	effective := config
	if effective == nil {
		effective = getDefaultValidatorConfig()
	}

	cv := &ContractValidator{client: client, logger: logger, config: effective}
	cv.trustedContracts = make(map[common.Address]*ContractInfo)
	cv.contractCache = make(map[common.Address]*ContractInfo)
	cv.interactionCounts = make(map[common.Address]int64)
	cv.dailyLimits = make(map[common.Address]*big.Int)
	cv.lastResetTime = time.Now()
	return cv
}
// AddTrustedContract registers info as a trusted (whitelisted) contract.
//
// The entry is stored in both the trusted list and the validation cache, so
// ValidateContract accepts the address without any on-chain checks. info is
// stored by pointer and modified in place: IsWhitelisted is forced to true
// and VerifiedAt is set to now. info.RiskLevel is kept as supplied (its zero
// value is already RiskLevelLow).
//
// It returns an error if info is nil, has the zero address, or has an empty
// BytecodeHash. The hash itself is not compared with on-chain code here.
func (cv *ContractValidator) AddTrustedContract(info *ContractInfo) error {
	// Validate before taking the lock; these checks only read info.
	if info == nil {
		return fmt.Errorf("contract info is required")
	}
	if info.Address == (common.Address{}) {
		return fmt.Errorf("invalid contract address")
	}
	if info.BytecodeHash == "" {
		return fmt.Errorf("bytecode hash is required")
	}

	cv.cacheMutex.Lock()
	// info may already be shared with readers (re-adding the same pointer),
	// so its fields are mutated under the write lock.
	info.IsWhitelisted = true
	info.VerifiedAt = time.Now()
	cv.trustedContracts[info.Address] = info
	cv.contractCache[info.Address] = info
	cv.cacheMutex.Unlock()

	cv.logger.Info(fmt.Sprintf("Added trusted contract: %s (%s)", info.Address.Hex(), info.Name))
	return nil
}
// ValidateContract performs comprehensive contract validation for address.
//
// Lookup order:
//  1. the trusted list populated by AddTrustedContract;
//  2. the validation cache, while the entry is younger than config.CacheTimeout;
//  3. on-chain validation (validateContractOnChain), followed by the
//     RequireWhitelist and MaxRiskScore checks. Contracts that pass are cached.
//
// On success result.IsValid is true and result.RiskScore holds the 1-10
// score. On the trusted/cache fast paths the score is recomputed from the
// stored info with no warnings, since warnings are not cached. On failure a
// non-nil error is returned together with the partially filled result, and
// result.ValidationError describes the reason.
func (cv *ContractValidator) ValidateContract(ctx context.Context, address common.Address) (*ContractValidationResult, error) {
	result := &ContractValidationResult{
		IsValid:         false,
		Warnings:        make([]string, 0),
		ChecksPerformed: make([]ValidationCheck, 0),
	}

	// Fast paths: trusted whitelist first, then a still-fresh cache entry.
	if info, check, ok := cv.lookupKnownContract(address); ok {
		result.IsValid = true
		result.ContractInfo = info
		result.ChecksPerformed = append(result.ChecksPerformed, check)
		// Report a score here too; leaving it at 0 fell outside the 1-10 range.
		result.RiskScore = cv.calculateRiskScore(info, result)
		return result, nil
	}

	// Perform real-time validation.
	contractInfo, err := cv.validateContractOnChain(ctx, address, result)
	if err != nil {
		result.ValidationError = err.Error()
		return result, err
	}
	result.ContractInfo = contractInfo
	result.RiskScore = cv.calculateRiskScore(contractInfo, result)

	// Check if contract meets security requirements.
	if cv.config.RequireWhitelist && !contractInfo.IsWhitelisted {
		result.ValidationError = "Contract not whitelisted"
		result.ChecksPerformed = append(result.ChecksPerformed, ValidationCheck{
			Name:        "Whitelist Check",
			Passed:      false,
			Description: "Contract not found in whitelist",
			Error:       "Contract not whitelisted",
			Timestamp:   time.Now(),
		})
		return result, fmt.Errorf("contract not whitelisted: %s", address.Hex())
	}

	if result.RiskScore > cv.config.MaxRiskScore {
		result.ValidationError = fmt.Sprintf("Risk score too high: %d > %d", result.RiskScore, cv.config.MaxRiskScore)
		// Record the failing check, as the whitelist rejection above does.
		result.ChecksPerformed = append(result.ChecksPerformed, ValidationCheck{
			Name:        "Risk Score Check",
			Passed:      false,
			Description: "Contract risk score exceeds the configured maximum",
			Error:       result.ValidationError,
			Timestamp:   time.Now(),
		})
		return result, fmt.Errorf("contract risk score too high: %d", result.RiskScore)
	}

	// Cache the validation result.
	cv.cacheMutex.Lock()
	cv.contractCache[address] = contractInfo
	cv.cacheMutex.Unlock()

	result.IsValid = true
	return result, nil
}

// lookupKnownContract returns the info for address from the trusted list or,
// failing that, from a cache entry younger than config.CacheTimeout. It also
// returns the ValidationCheck describing where the entry was found. ok is
// false when neither source has a usable entry.
func (cv *ContractValidator) lookupKnownContract(address common.Address) (info *ContractInfo, check ValidationCheck, ok bool) {
	cv.cacheMutex.RLock()
	defer cv.cacheMutex.RUnlock()

	if trusted, exists := cv.trustedContracts[address]; exists {
		return trusted, ValidationCheck{
			Name:        "Trusted Contract Check",
			Passed:      true,
			Description: "Contract found in trusted whitelist",
			Timestamp:   time.Now(),
		}, true
	}
	if cached, exists := cv.contractCache[address]; exists && time.Since(cached.VerifiedAt) < cv.config.CacheTimeout {
		return cached, ValidationCheck{
			Name:        "Cache Check",
			Passed:      true,
			Description: "Contract found in validation cache",
			Timestamp:   time.Now(),
		}, true
	}
	return nil, ValidationCheck{}, false
}
// validateContractOnChain fetches the code deployed at address and builds a
// fresh ContractInfo for it, appending every check it runs to result.
//
// It returns an error when the code cannot be fetched or the address holds no
// code. Otherwise it:
//   - warns when the code exceeds config.MaxBytecodeSize;
//   - hashes the code;
//   - looks up deployment data, falling back to zero values with a warning log;
//   - assesses the risk level;
//   - if EnableBytecodeVerification is set, matches against known contracts.
func (cv *ContractValidator) validateContractOnChain(ctx context.Context, address common.Address, result *ContractValidationResult) (*ContractInfo, error) {
	// record stamps a check with the current time and appends it to result.
	record := func(check ValidationCheck) {
		check.Timestamp = time.Now()
		result.ChecksPerformed = append(result.ChecksPerformed, check)
	}

	code, err := cv.client.CodeAt(ctx, address, nil)
	switch {
	case err != nil:
		record(ValidationCheck{
			Name:        "Bytecode Retrieval",
			Passed:      false,
			Description: "Failed to retrieve contract bytecode",
			Error:       err.Error(),
		})
		return nil, fmt.Errorf("failed to get contract bytecode: %w", err)
	case len(code) == 0:
		record(ValidationCheck{
			Name:        "Contract Existence",
			Passed:      false,
			Description: "Address is not a contract (no bytecode)",
			Error:       "No bytecode found",
		})
		return nil, fmt.Errorf("address is not a contract: %s", address.Hex())
	}

	size := len(code)
	record(ValidationCheck{
		Name:        "Contract Existence",
		Passed:      true,
		Description: fmt.Sprintf("Contract bytecode found (%d bytes)", size),
	})

	// The size limit is advisory: exceeding it only produces a warning.
	if limit := cv.config.MaxBytecodeSize; limit > 0 && size > limit {
		result.Warnings = append(result.Warnings, fmt.Sprintf("Large bytecode size: %d bytes", size))
	}

	codeHash := crypto.Keccak256Hash(code).Hex()

	deployedAt, deployer, err := cv.getDeploymentInfo(ctx, address)
	if err != nil {
		cv.logger.Warn(fmt.Sprintf("Could not retrieve deployment info for %s: %v", address.Hex(), err))
		deployedAt, deployer = big.NewInt(0), common.Address{}
	}

	info := &ContractInfo{
		Address:       address,
		BytecodeHash:  codeHash,
		Name:          "Unknown Contract",
		Version:       "unknown",
		DeployedAt:    deployedAt,
		Deployer:      deployer,
		VerifiedAt:    time.Now(),
		IsWhitelisted: false,
		RiskLevel:     cv.assessRiskLevel(code, result),
		Permissions:   cv.getDefaultPermissions(),
	}

	if cv.config.EnableBytecodeVerification {
		cv.verifyBytecodeSignature(code, info, result)
	}

	record(ValidationCheck{
		Name:        "Bytecode Validation",
		Passed:      true,
		Description: "Bytecode hash calculated and verified",
	})
	return info, nil
}
// getDeploymentInfo is meant to return deployment data for address: a
// *big.Int (presumably the deployment block - confirm) and the deployer.
//
// It is currently a placeholder that always returns a zero value, the zero
// address and an error, so callers fall back to defaults. A real
// implementation would need to scan blocks or query an indexer.
func (cv *ContractValidator) getDeploymentInfo(ctx context.Context, address common.Address) (*big.Int, common.Address, error) {
	var deployer common.Address
	return new(big.Int), deployer, fmt.Errorf("deployment info not available")
}
// assessRiskLevel assesses the risk level of a contract from its bytecode and
// records its findings on result.
//
// Each of the following counts as one risk factor:
//   - SELFDESTRUCT (0xff), DELEGATECALL (0xf4) or RETURNDATASIZE (0x3d)
//     appearing as an executable opcode (one factor per distinct opcode);
//   - bytecode larger than 20000 bytes (also appended to result.Warnings).
//
// Factor counts map to levels as follows:
//   - 0 factors: RiskLevelLow
//   - 1-2 factors: RiskLevelMedium
//   - 3-4 factors: RiskLevelHigh
//   - more than 4: RiskLevelCritical
//
// At most 4 factors exist today, so RiskLevelCritical is currently
// unreachable. When risky opcodes are found, an informational "Opcode Scan"
// check naming them is appended to result.ChecksPerformed. It does not change
// the score.
func (cv *ContractValidator) assessRiskLevel(bytecode []byte, result *ContractValidationResult) RiskLevel {
	riskFactors := 0

	// Opcodes whose presence counts as a risk factor.
	dangerousOpcodes := []struct {
		op   byte
		name string
	}{
		{0xff, "SELFDESTRUCT"},
		{0xf4, "DELEGATECALL"},
		{0x3d, "RETURNDATASIZE"}, // often used in proxy patterns
	}

	// Walk the code one instruction at a time. PUSH1..PUSH32 (0x60..0x7f)
	// are followed by 1..32 bytes of immediate data that must be skipped.
	// Otherwise constants such as 0xff..ff masks would be reported as
	// SELFDESTRUCT. A substring search over the hex string is not enough
	// either: it can match across byte boundaries. Any trailing non-code
	// data is still decoded as instructions, so this is an approximation.
	var seen [256]bool
	for pc := 0; pc < len(bytecode); pc++ {
		op := bytecode[pc]
		seen[op] = true
		if op >= 0x60 && op <= 0x7f {
			pc += int(op-0x60) + 1
		}
	}

	found := ""
	for _, d := range dangerousOpcodes {
		if !seen[d.op] {
			continue
		}
		riskFactors++
		if found != "" {
			found += ", "
		}
		found += d.name + " (0x" + hex.EncodeToString([]byte{d.op}) + ")"
	}
	if found != "" {
		result.ChecksPerformed = append(result.ChecksPerformed, ValidationCheck{
			Name:        "Opcode Scan",
			Passed:      true,
			Description: "Potentially risky opcodes present: " + found,
			Timestamp:   time.Now(),
		})
	}

	// Larger contracts may be more complex and therefore riskier.
	if len(bytecode) > 20000 { // 20KB
		riskFactors++
		result.Warnings = append(result.Warnings, "Large contract size detected")
	}

	switch {
	case riskFactors == 0:
		return RiskLevelLow
	case riskFactors <= 2:
		return RiskLevelMedium
	case riskFactors <= 4:
		return RiskLevelHigh
	default:
		return RiskLevelCritical
	}
}
// verifyBytecodeSignature recognises a small set of well-known contracts. On
// a match it renames info, marks it whitelisted with RiskLevelLow and records
// a "Known Contract Verification" check on result.
//
// Despite the name, matching is currently by address only and bytecode is not
// inspected. The table is keyed by common.Address rather than by hex string:
// Address.Hex() returns the EIP-55 mixed-case form, so a string lookup against
// lowercase keys never matched.
// NOTE(review): address-only matching trusts these addresses on whatever
// chain the client is connected to; consider also pinning the bytecode hash.
func (cv *ContractValidator) verifyBytecodeSignature(bytecode []byte, info *ContractInfo, result *ContractValidationResult) {
	knownContracts := map[common.Address]string{
		// Uniswap V3 Factory
		common.HexToAddress("0x1f98431c8ad98523631ae4a59f267346ea31f984"): "uniswap_v3_factory",
		// Uniswap V3 Router
		common.HexToAddress("0xe592427a0aece92de3edee1f18e0157c05861564"): "uniswap_v3_router",
		// Add more known contracts...
	}

	name, known := knownContracts[info.Address]
	if !known {
		return
	}

	info.Name = name
	info.IsWhitelisted = true
	info.RiskLevel = RiskLevelLow
	result.ChecksPerformed = append(result.ChecksPerformed, ValidationCheck{
		Name:        "Known Contract Verification",
		Passed:      true,
		Description: fmt.Sprintf("Verified as known contract: %s", name),
		Timestamp:   time.Now(),
	})
}
// calculateRiskScore converts info's risk assessment into a score from 1
// (lowest risk) to 10 (highest).
//
// The risk level sets the starting point: Low 1, Medium 3, High 6,
// Critical 9, Blocked 10, and any other value is treated like Low. A
// non-whitelisted contract adds 2, and every warning in result adds 1. The
// total is capped at 10.
func (cv *ContractValidator) calculateRiskScore(info *ContractInfo, result *ContractValidationResult) int {
	var score int
	switch info.RiskLevel {
	case RiskLevelBlocked:
		score = 10
	case RiskLevelCritical:
		score = 9
	case RiskLevelHigh:
		score = 6
	case RiskLevelMedium:
		score = 3
	default: // RiskLevelLow and unrecognised levels
		score = 1
	}

	if !info.IsWhitelisted {
		score += 2
	}
	score += len(result.Warnings)

	if score > 10 {
		return 10
	}
	return score
}
// getDefaultPermissions returns the permissions given to contracts that pass
// on-chain validation.
//
// The defaults allow interaction but no value transfers, with a per-transaction
// cap of 0 wei and a daily cap of 1 ETH (10^18 wei). They also require
// confirmation and use an empty method list. Fresh big.Int values are
// allocated on every call, so callers may mutate them freely.
func (cv *ContractValidator) getDefaultPermissions() ContractPermissions {
	oneEther := new(big.Int).Exp(big.NewInt(10), big.NewInt(18), nil)

	var perms ContractPermissions
	perms.CanInteract = true
	perms.CanSendValue = false
	perms.MaxValueWei = new(big.Int)
	perms.AllowedMethods = []string{} // Empty means all methods allowed
	perms.RequireConfirm = true
	perms.DailyLimit = oneEther
	return perms
}
// ValidateTransaction validates tx against the permissions of its target
// contract.
//
// Contract-creation transactions (no recipient) are always allowed.
// Otherwise the recipient must pass ValidateContract, and the value must
// satisfy CanSendValue, MaxValueWei and the daily limit. The daily-limit check
// also records the value as used.
// NOTE(review): CanInteract, AllowedMethods and RequireConfirm are not
// enforced here.
func (cv *ContractValidator) ValidateTransaction(ctx context.Context, tx *types.Transaction) error {
	recipient := tx.To()
	if recipient == nil {
		return nil // Contract creation, allow
	}
	target := *recipient

	result, err := cv.ValidateContract(ctx, target)
	if err != nil {
		return fmt.Errorf("contract validation failed: %w", err)
	}
	if !result.IsValid {
		return fmt.Errorf("transaction to invalid contract: %s", target.Hex())
	}

	perms := result.ContractInfo.Permissions
	value := tx.Value()
	switch {
	case value.Sign() > 0 && !perms.CanSendValue:
		return fmt.Errorf("contract does not allow value transfers: %s", target.Hex())
	case perms.MaxValueWei != nil && value.Cmp(perms.MaxValueWei) > 0:
		return fmt.Errorf("transaction value exceeds limit: %s > %s",
			value.String(), perms.MaxValueWei.String())
	}

	if err := cv.checkDailyLimit(target, value); err != nil {
		return err
	}

	cv.logger.Debug(fmt.Sprintf("Transaction validated for contract %s", target.Hex()))
	return nil
}
// checkDailyLimit enforces the contract's DailyLimit on cumulative value sent
// within the current window, and records value as used when it fits.
//
// The window resets once more than 24 hours have passed since the last reset.
// Contracts that are not in the validation cache, or that have no DailyLimit,
// are not limited, although a zero usage entry is still recorded for them.
func (cv *ContractValidator) checkDailyLimit(contractAddr common.Address, value *big.Int) error {
	// Read the permissions first so the cache and limits locks are never
	// held at the same time.
	cv.cacheMutex.RLock()
	info, cached := cv.contractCache[contractAddr]
	cv.cacheMutex.RUnlock()

	cv.limitsMutex.Lock()
	defer cv.limitsMutex.Unlock()

	if time.Since(cv.lastResetTime) > 24*time.Hour {
		cv.dailyLimits = make(map[common.Address]*big.Int)
		cv.lastResetTime = time.Now()
	}

	used, tracked := cv.dailyLimits[contractAddr]
	if !tracked {
		used = big.NewInt(0)
		cv.dailyLimits[contractAddr] = used
	}

	if !cached || info.Permissions.DailyLimit == nil {
		return nil // no known limit for this contract
	}
	limit := info.Permissions.DailyLimit

	projected := new(big.Int).Add(used, value)
	if projected.Cmp(limit) > 0 {
		return fmt.Errorf("daily limit exceeded for contract %s: %s + %s > %s",
			contractAddr.Hex(),
			used.String(),
			value.String(),
			limit.String())
	}

	cv.dailyLimits[contractAddr] = projected
	return nil
}
// getDefaultValidatorConfig returns the permissive default configuration.
//
// Bytecode verification and real-time validation are on, the whitelist is
// not required, and bytecode above 50000 bytes produces a warning. Validations
// are cached for one hour and risk scores up to 7 are accepted.
func getDefaultValidatorConfig() *ContractValidatorConfig {
	cfg := new(ContractValidatorConfig)
	cfg.EnableBytecodeVerification = true
	cfg.EnableABIValidation = false // Requires additional infrastructure
	cfg.RequireWhitelist = false    // Start permissive, can be tightened
	cfg.MaxBytecodeSize = 50000     // 50KB
	cfg.CacheTimeout = time.Hour
	cfg.MaxRiskScore = 7 // Allow medium-high risk
	cfg.BlockUnverifiedContracts = false
	cfg.RequireSourceCode = false
	cfg.EnableRealTimeValidation = true
	return cfg
}
// GetContractInfo returns the cached info for address and whether it exists.
//
// Trusted contracts are included because AddTrustedContract also caches
// them. CacheTimeout is not applied here, so the entry may be stale. The
// returned pointer is shared with the cache.
func (cv *ContractValidator) GetContractInfo(address common.Address) (*ContractInfo, bool) {
	cv.cacheMutex.RLock()
	info, found := cv.contractCache[address]
	cv.cacheMutex.RUnlock()
	return info, found
}
// ListTrustedContracts returns a snapshot of all trusted contracts.
//
// The map itself is a fresh copy, so callers may modify it safely. The
// *ContractInfo values are shared with the validator and should not be
// mutated.
func (cv *ContractValidator) ListTrustedContracts() map[common.Address]*ContractInfo {
	cv.cacheMutex.RLock()
	defer cv.cacheMutex.RUnlock()

	snapshot := make(map[common.Address]*ContractInfo, len(cv.trustedContracts))
	for address, info := range cv.trustedContracts {
		snapshot[address] = info
	}
	return snapshot
}