package security

import (
	"context"
	"fmt"
	"math/big"
	"sync"
	"time"

	"github.com/ethereum/go-ethereum/common"
	"github.com/ethereum/go-ethereum/core/types"
	"github.com/ethereum/go-ethereum/crypto"
	"github.com/ethereum/go-ethereum/ethclient"

	"github.com/fraktal/mev-beta/internal/logger"
)

// ContractInfo represents information about a verified contract.
type ContractInfo struct {
	Address common.Address `json:"address"`
	// BytecodeHash is the Keccak-256 hash (0x-prefixed hex) of the code returned by CodeAt.
	BytecodeHash string         `json:"bytecode_hash"`
	Name         string         `json:"name"`
	Version      string         `json:"version"`
	DeployedAt   *big.Int       `json:"deployed_at"`
	Deployer     common.Address `json:"deployer"`
	// VerifiedAt is compared against ContractValidatorConfig.CacheTimeout to decide cache freshness.
	VerifiedAt     time.Time           `json:"verified_at"`
	IsWhitelisted  bool                `json:"is_whitelisted"`
	RiskLevel      RiskLevel           `json:"risk_level"`
	Permissions    ContractPermissions `json:"permissions"`
	ABIHash        string              `json:"abi_hash,omitempty"`
	SourceCodeHash string              `json:"source_code_hash,omitempty"`
}

// ContractPermissions defines what operations are allowed with a contract.
//
// ValidateTransaction enforces CanSendValue, MaxValueWei and DailyLimit.
// CanInteract, AllowedMethods and RequireConfirm are carried along but are not
// checked anywhere in this file.
type ContractPermissions struct {
	CanInteract    bool     `json:"can_interact"`
	CanSendValue   bool     `json:"can_send_value"`
	MaxValueWei    *big.Int `json:"max_value_wei,omitempty"`
	AllowedMethods []string `json:"allowed_methods,omitempty"`
	RequireConfirm bool     `json:"require_confirmation"`
	DailyLimit     *big.Int `json:"daily_limit,omitempty"`
}

// RiskLevel represents the risk assessment of a contract.
type RiskLevel int

// Risk levels in increasing order of severity. RiskLevelLow is the zero value.
const (
	RiskLevelLow RiskLevel = iota
	RiskLevelMedium
	RiskLevelHigh
	RiskLevelCritical
	RiskLevelBlocked
)

// String returns the human-readable name of the risk level, or "Unknown".
func (r RiskLevel) String() string {
	switch r {
	case RiskLevelLow:
		return "Low"
	case RiskLevelMedium:
		return "Medium"
	case RiskLevelHigh:
		return "High"
	case RiskLevelCritical:
		return "Critical"
	case RiskLevelBlocked:
		return "Blocked"
	default:
		return "Unknown"
	}
}

// ContractValidationResult contains the result of contract validation.
type ContractValidationResult struct {
	IsValid         bool              `json:"is_valid"`
	ContractInfo    *ContractInfo     `json:"contract_info"`
	ValidationError string            `json:"validation_error,omitempty"`
	Warnings        []string          `json:"warnings"`
	ChecksPerformed []ValidationCheck `json:"checks_performed"`
	RiskScore       int               `json:"risk_score"` // 1-10
}

// ValidationCheck represents a single validation check.
type ValidationCheck struct {
	Name        string    `json:"name"`
	Passed      bool      `json:"passed"`
	Description string    `json:"description"`
	Error       string    `json:"error,omitempty"`
	Timestamp   time.Time `json:"timestamp"`
}

// ContractValidator provides secure contract validation and verification.
type ContractValidator struct {
	client *ethclient.Client
	logger *logger.Logger
	// trustedContracts holds contracts registered through AddTrustedContract.
	trustedContracts map[common.Address]*ContractInfo
	// contractCache holds trusted contracts plus contracts that passed ValidateContract.
	contractCache map[common.Address]*ContractInfo
	cacheMutex    sync.RWMutex
	config        *ContractValidatorConfig

	// Security tracking
	// interactionCounts is initialised but not updated anywhere in this file.
	interactionCounts map[common.Address]int64
	// dailyLimits records the wei already spent per contract in the current
	// 24h window (usage, despite the name); the limit itself lives in
	// ContractPermissions.DailyLimit.
	dailyLimits   map[common.Address]*big.Int
	lastResetTime time.Time
	limitsMutex   sync.RWMutex
}

// ContractValidatorConfig provides configuration for the contract validator.
type ContractValidatorConfig struct {
	EnableBytecodeVerification bool          `json:"enable_bytecode_verification"`
	EnableABIValidation        bool          `json:"enable_abi_validation"`
	RequireWhitelist           bool          `json:"require_whitelist"`
	MaxBytecodeSize            int           `json:"max_bytecode_size"`
	CacheTimeout               time.Duration `json:"cache_timeout"`
	MaxRiskScore               int           `json:"max_risk_score"`
	BlockUnverifiedContracts   bool          `json:"block_unverified_contracts"`
	RequireSourceCode          bool          `json:"require_source_code"`
	EnableRealTimeValidation   bool          `json:"enable_realtime_validation"`
}

// NewContractValidator creates a new contract validator. A nil config is
// replaced by getDefaultValidatorConfig().
func NewContractValidator(client *ethclient.Client, logger *logger.Logger, config *ContractValidatorConfig) *ContractValidator {
	if config == nil {
		config = getDefaultValidatorConfig()
	}

	return &ContractValidator{
		client:            client,
		logger:            logger,
		config:            config,
		trustedContracts:  make(map[common.Address]*ContractInfo),
		contractCache:     make(map[common.Address]*ContractInfo),
		interactionCounts: make(map[common.Address]int64),
		dailyLimits:       make(map[common.Address]*big.Int),
		lastResetTime:     time.Now(),
	}
}

// AddTrustedContract adds a contract to the trusted list and the validation cache.
//
// The passed info is mutated and stored as-is: IsWhitelisted is forced to true
// and VerifiedAt is set to now. The caller's RiskLevel is kept; an unset
// RiskLevel is already RiskLevelLow because that is the zero value.
// It returns an error if info is nil, has a zero address, or has no BytecodeHash.
func (cv *ContractValidator) AddTrustedContract(info *ContractInfo) error {
	if info == nil {
		return fmt.Errorf("contract info is nil")
	}

	cv.cacheMutex.Lock()
	defer cv.cacheMutex.Unlock()

	if info.Address == (common.Address{}) {
		return fmt.Errorf("invalid contract address")
	}
	if info.BytecodeHash == "" {
		return fmt.Errorf("bytecode hash is required")
	}

	info.IsWhitelisted = true
	info.VerifiedAt = time.Now()

	cv.trustedContracts[info.Address] = info
	cv.contractCache[info.Address] = info

	cv.logger.Info(fmt.Sprintf("Added trusted contract: %s (%s)", info.Address.Hex(), info.Name))
	return nil
}

// ValidateContract performs comprehensive contract validation.
//
// Lookup order:
//  1. The trusted whitelist.
//  2. The validation cache (entries younger than config.CacheTimeout).
//  3. A fresh on-chain validation.
//
// A freshly validated contract is cached only if it satisfies
// config.RequireWhitelist (when set) and its risk score does not exceed
// config.MaxRiskScore. On failure, the partially filled result is returned
// together with a non-nil error.
func (cv *ContractValidator) ValidateContract(ctx context.Context, address common.Address) (*ContractValidationResult, error) {
	result := &ContractValidationResult{
		IsValid:         false,
		Warnings:        make([]string, 0),
		ChecksPerformed: make([]ValidationCheck, 0),
	}

	// Snapshot both lookups under one read lock. VerifiedAt is read while the
	// lock is held because AddTrustedContract may update it under the write lock.
	cv.cacheMutex.RLock()
	trusted, isTrusted := cv.trustedContracts[address]
	cached, isCached := cv.contractCache[address]
	cacheFresh := isCached && time.Since(cached.VerifiedAt) < cv.config.CacheTimeout
	cv.cacheMutex.RUnlock()

	if isTrusted {
		result.IsValid = true
		result.ContractInfo = trusted
		// Fill in the score so it stays within the documented 1-10 range
		// instead of the zero value.
		result.RiskScore = cv.calculateRiskScore(trusted, result)
		result.ChecksPerformed = append(result.ChecksPerformed, ValidationCheck{
			Name:        "Trusted Contract Check",
			Passed:      true,
			Description: "Contract found in trusted whitelist",
			Timestamp:   time.Now(),
		})
		return result, nil
	}

	if cacheFresh {
		result.IsValid = true
		result.ContractInfo = cached
		// Warnings from the original validation are not cached, so this score
		// can only be lower than or equal to the score computed at that time.
		result.RiskScore = cv.calculateRiskScore(cached, result)
		result.ChecksPerformed = append(result.ChecksPerformed, ValidationCheck{
			Name:        "Cache Check",
			Passed:      true,
			Description: "Contract found in validation cache",
			Timestamp:   time.Now(),
		})
		return result, nil
	}

	// Perform real-time validation
	contractInfo, err := cv.validateContractOnChain(ctx, address, result)
	if err != nil {
		result.ValidationError = err.Error()
		return result, err
	}

	result.ContractInfo = contractInfo
	result.RiskScore = cv.calculateRiskScore(contractInfo, result)

	// Check if contract meets security requirements
	if cv.config.RequireWhitelist && !contractInfo.IsWhitelisted {
		result.ValidationError = "Contract not whitelisted"
		result.ChecksPerformed = append(result.ChecksPerformed, ValidationCheck{
			Name:        "Whitelist Check",
			Passed:      false,
			Description: "Contract not found in whitelist",
			Error:       "Contract not whitelisted",
			Timestamp:   time.Now(),
		})
		return result, fmt.Errorf("contract not whitelisted: %s", address.Hex())
	}

	if result.RiskScore > cv.config.MaxRiskScore {
		result.ValidationError = fmt.Sprintf("Risk score too high: %d > %d", result.RiskScore, cv.config.MaxRiskScore)
		// Record the failed check, as the whitelist branch above does.
		result.ChecksPerformed = append(result.ChecksPerformed, ValidationCheck{
			Name:        "Risk Score Check",
			Passed:      false,
			Description: "Contract risk score exceeds configured maximum",
			Error:       result.ValidationError,
			Timestamp:   time.Now(),
		})
		return result, fmt.Errorf("contract risk score too high: %d", result.RiskScore)
	}

	// Cache the validation result
	cv.cacheMutex.Lock()
	cv.contractCache[address] = contractInfo
	cv.cacheMutex.Unlock()

	result.IsValid = true
	return result, nil
}

// validateContractOnChain fetches the code at address and builds a ContractInfo
// for it.
//
// Each check performed is appended to result.ChecksPerformed. It returns an
// error if the code cannot be fetched or the address has no code. Missing
// deployment info is logged and replaced with zero values rather than
// treated as an error.
func (cv *ContractValidator) validateContractOnChain(ctx context.Context, address common.Address, result *ContractValidationResult) (*ContractInfo, error) {
	// Check if address is a contract (nil block number = latest state)
	bytecode, err := cv.client.CodeAt(ctx, address, nil)
	if err != nil {
		result.ChecksPerformed = append(result.ChecksPerformed, ValidationCheck{
			Name:        "Bytecode Retrieval",
			Passed:      false,
			Description: "Failed to retrieve contract bytecode",
			Error:       err.Error(),
			Timestamp:   time.Now(),
		})
		return nil, fmt.Errorf("failed to get contract bytecode: %w", err)
	}

	if len(bytecode) == 0 {
		result.ChecksPerformed = append(result.ChecksPerformed, ValidationCheck{
			Name:        "Contract Existence",
			Passed:      false,
			Description: "Address is not a contract (no bytecode)",
			Error:       "No bytecode found",
			Timestamp:   time.Now(),
		})
		return nil, fmt.Errorf("address is not a contract: %s", address.Hex())
	}

	result.ChecksPerformed = append(result.ChecksPerformed, ValidationCheck{
		Name:        "Contract Existence",
		Passed:      true,
		Description: fmt.Sprintf("Contract bytecode found (%d bytes)", len(bytecode)),
		Timestamp:   time.Now(),
	})

	// Oversized bytecode only produces a warning (which raises the risk score).
	if cv.config.MaxBytecodeSize > 0 && len(bytecode) > cv.config.MaxBytecodeSize {
		result.Warnings = append(result.Warnings, fmt.Sprintf("Large bytecode size: %d bytes", len(bytecode)))
	}

	bytecodeHash := crypto.Keccak256Hash(bytecode).Hex()

	deployedAt, deployer, err := cv.getDeploymentInfo(ctx, address)
	if err != nil {
		cv.logger.Warn(fmt.Sprintf("Could not retrieve deployment info for %s: %v", address.Hex(), err))
		deployedAt = big.NewInt(0)
		deployer = common.Address{}
	}

	contractInfo := &ContractInfo{
		Address:       address,
		BytecodeHash:  bytecodeHash,
		Name:          "Unknown Contract",
		Version:       "unknown",
		DeployedAt:    deployedAt,
		Deployer:      deployer,
		VerifiedAt:    time.Now(),
		IsWhitelisted: false,
		RiskLevel:     cv.assessRiskLevel(bytecode, result),
		Permissions:   cv.getDefaultPermissions(),
	}

	// Known contracts may override Name, IsWhitelisted and RiskLevel.
	if cv.config.EnableBytecodeVerification {
		cv.verifyBytecodeSignature(bytecode, contractInfo, result)
	}

	result.ChecksPerformed = append(result.ChecksPerformed, ValidationCheck{
		Name:        "Bytecode Validation",
		Passed:      true,
		Description: "Bytecode hash calculated and verified",
		Timestamp:   time.Now(),
	})

	return contractInfo, nil
}

// getDeploymentInfo retrieves deployment information for a contract.
//
// Placeholder: it always returns a zero block, a zero deployer and an error.
// In production, you would need to scan blocks or use an indexer.
func (cv *ContractValidator) getDeploymentInfo(ctx context.Context, address common.Address) (*big.Int, common.Address, error) {
	return big.NewInt(0), common.Address{}, fmt.Errorf("deployment info not available")
}

// EVM opcodes inspected by assessRiskLevel.
const (
	evmOpReturnDataSize byte = 0x3d
	evmOpPush1          byte = 0x60
	evmOpPush32         byte = 0x7f
	evmOpDelegateCall   byte = 0xf4
	evmOpSelfDestruct   byte = 0xff
)

// riskyEVMOpcodes lists the opcodes that each add one risk factor when present.
var riskyEVMOpcodes = []byte{
	evmOpSelfDestruct,   // SELFDESTRUCT
	evmOpDelegateCall,   // DELEGATECALL
	evmOpReturnDataSize, // RETURNDATASIZE (often used in proxy patterns)
}

// evmOpcodesPresent walks code one instruction at a time and marks every byte
// value that appears in opcode position. The immediate operands of
// PUSH1..PUSH32 are skipped, so literal data bytes are not mistaken for
// instructions. A PUSH truncated at the end of code simply ends the walk.
// NOTE(review): any trailing non-code data (e.g. compiler-appended metadata)
// is also walked as if it were code, which can still yield false positives.
func evmOpcodesPresent(code []byte) [256]bool {
	var seen [256]bool
	for pc := 0; pc < len(code); pc++ {
		op := code[pc]
		seen[op] = true
		if op >= evmOpPush1 && op <= evmOpPush32 {
			// PUSHn carries n = op-0x5f immediate bytes.
			pc += int(op-evmOpPush1) + 1
		}
	}
	return seen
}

// assessRiskLevel assesses the risk level of a contract based on its bytecode.
//
// One risk factor is counted for each opcode in riskyEVMOpcodes that appears
// (however often), plus one for bytecode over 20000 bytes. The size factor
// also appends a warning to result. The thresholds are:
//   - 0 factors: Low
//   - up to 2: Medium
//   - up to 4: High
//   - more than 4: Critical
//
// With the current four possible factors, Critical is unreachable.
func (cv *ContractValidator) assessRiskLevel(bytecode []byte, result *ContractValidationResult) RiskLevel {
	riskFactors := 0

	present := evmOpcodesPresent(bytecode)
	for _, op := range riskyEVMOpcodes {
		if present[op] {
			riskFactors++
		}
	}

	// Check bytecode size (larger contracts may be more complex/risky)
	if len(bytecode) > 20000 { // 20KB
		riskFactors++
		result.Warnings = append(result.Warnings, "Large contract size detected")
	}

	switch {
	case riskFactors == 0:
		return RiskLevelLow
	case riskFactors <= 2:
		return RiskLevelMedium
	case riskFactors <= 4:
		return RiskLevelHigh
	default:
		return RiskLevelCritical
	}
}

// knownContractNames maps well-known contract addresses to a friendly name.
// Keys are common.Address values, so lookups do not depend on hex casing.
// Address.Hex() returns the EIP-55 mixed-case form, which never matched the
// lowercase string keys this map used previously.
var knownContractNames = map[common.Address]string{
	// Uniswap V3 Factory
	common.HexToAddress("0x1f98431c8ad98523631ae4a59f267346ea31f984"): "uniswap_v3_factory",
	// Uniswap V3 Router
	common.HexToAddress("0xe592427a0aece92de3edee1f18e0157c05861564"): "uniswap_v3_router",
	// Add more known contracts...
}

// verifyBytecodeSignature marks info as a known, whitelisted, low-risk contract
// when its address is in knownContractNames, and records a passing check in
// result. Matching is by address only; bytecode is currently unused.
func (cv *ContractValidator) verifyBytecodeSignature(bytecode []byte, info *ContractInfo, result *ContractValidationResult) {
	name, known := knownContractNames[info.Address]
	if !known {
		return
	}

	info.Name = name
	info.IsWhitelisted = true
	info.RiskLevel = RiskLevelLow

	result.ChecksPerformed = append(result.ChecksPerformed, ValidationCheck{
		Name:        "Known Contract Verification",
		Passed:      true,
		Description: fmt.Sprintf("Verified as known contract: %s", name),
		Timestamp:   time.Now(),
	})
}

// calculateRiskScore calculates a numerical risk score (1-10).
//
// The score starts at 1. It adds 0/2/5/8 for Low/Medium/High/Critical (Blocked
// sets it to 10), adds 2 when the contract is not whitelisted, adds one per
// warning in result, and is then capped at 10.
func (cv *ContractValidator) calculateRiskScore(info *ContractInfo, result *ContractValidationResult) int {
	score := 1 // Base score

	switch info.RiskLevel {
	case RiskLevelLow:
		score += 0
	case RiskLevelMedium:
		score += 2
	case RiskLevelHigh:
		score += 5
	case RiskLevelCritical:
		score += 8
	case RiskLevelBlocked:
		score = 10
	}

	if !info.IsWhitelisted {
		score += 2
	}

	score += len(result.Warnings)

	if score > 10 {
		score = 10
	}

	return score
}

// getDefaultPermissions returns default permissions for unverified contracts:
//   - interaction allowed;
//   - no value transfers (MaxValueWei 0);
//   - confirmation requested;
//   - a 1 ETH daily limit.
func (cv *ContractValidator) getDefaultPermissions() ContractPermissions {
	return ContractPermissions{
		CanInteract:    true,
		CanSendValue:   false,
		MaxValueWei:    big.NewInt(0),
		AllowedMethods: []string{}, // Empty means all methods allowed
		RequireConfirm: true,
		DailyLimit:     big.NewInt(1000000000000000000), // 1 ETH
	}
}

// ValidateTransaction validates a transaction against contract permissions.
//
// Contract-creation transactions (nil To) are allowed without checks.
// Otherwise the target is run through ValidateContract, then CanSendValue,
// MaxValueWei and the daily limit are enforced. A successful call records the
// transaction value against the contract's daily usage.
func (cv *ContractValidator) ValidateTransaction(ctx context.Context, tx *types.Transaction) error {
	if tx.To() == nil {
		return nil // Contract creation, allow
	}

	result, err := cv.ValidateContract(ctx, *tx.To())
	if err != nil {
		return fmt.Errorf("contract validation failed: %w", err)
	}

	if !result.IsValid {
		return fmt.Errorf("transaction to invalid contract: %s", tx.To().Hex())
	}

	permissions := result.ContractInfo.Permissions

	if tx.Value().Sign() > 0 && !permissions.CanSendValue {
		return fmt.Errorf("contract does not allow value transfers: %s", tx.To().Hex())
	}

	if permissions.MaxValueWei != nil && tx.Value().Cmp(permissions.MaxValueWei) > 0 {
		return fmt.Errorf("transaction value exceeds limit: %s > %s",
			tx.Value().String(), permissions.MaxValueWei.String())
	}

	if err := cv.checkDailyLimit(*tx.To(), tx.Value()); err != nil {
		return err
	}

	cv.logger.Debug(fmt.Sprintf("Transaction validated for contract %s", tx.To().Hex()))
	return nil
}

// checkDailyLimit adds value to the contract's usage for the current 24h window
// and returns an error if that would exceed the contract's DailyLimit.
//
// Usage counters are cleared once more than 24h have passed since the last
// reset. Contracts missing from the cache, or whose DailyLimit is nil, are not
// limited, and their usage is not accumulated. On error, usage is left unchanged.
// Lock order: limitsMutex, then cacheMutex.
func (cv *ContractValidator) checkDailyLimit(contractAddr common.Address, value *big.Int) error {
	cv.limitsMutex.Lock()
	defer cv.limitsMutex.Unlock()

	if time.Since(cv.lastResetTime) > 24*time.Hour {
		cv.dailyLimits = make(map[common.Address]*big.Int)
		cv.lastResetTime = time.Now()
	}

	currentUsage, exists := cv.dailyLimits[contractAddr]
	if !exists {
		currentUsage = big.NewInt(0)
		cv.dailyLimits[contractAddr] = currentUsage
	}

	cv.cacheMutex.RLock()
	contractInfo, exists := cv.contractCache[contractAddr]
	cv.cacheMutex.RUnlock()

	if !exists {
		return nil // No limit if contract not cached
	}

	if contractInfo.Permissions.DailyLimit == nil {
		return nil // No daily limit set
	}

	newUsage := new(big.Int).Add(currentUsage, value)
	if newUsage.Cmp(contractInfo.Permissions.DailyLimit) > 0 {
		return fmt.Errorf("daily limit exceeded for contract %s: %s + %s > %s",
			contractAddr.Hex(), currentUsage.String(), value.String(),
			contractInfo.Permissions.DailyLimit.String())
	}

	cv.dailyLimits[contractAddr] = newUsage
	return nil
}

// getDefaultValidatorConfig returns the permissive default configuration used
// when NewContractValidator receives a nil config.
func getDefaultValidatorConfig() *ContractValidatorConfig {
	return &ContractValidatorConfig{
		EnableBytecodeVerification: true,
		EnableABIValidation:        false, // Requires additional infrastructure
		RequireWhitelist:           false, // Start permissive, can be tightened
		MaxBytecodeSize:            50000, // 50KB
		CacheTimeout:               1 * time.Hour,
		MaxRiskScore:               7, // Allow medium-high risk
		BlockUnverifiedContracts:   false,
		RequireSourceCode:          false,
		EnableRealTimeValidation:   true,
	}
}

// GetContractInfo returns the cached information for a validated or trusted
// contract. The returned pointer is shared with the cache.
func (cv *ContractValidator) GetContractInfo(address common.Address) (*ContractInfo, bool) {
	cv.cacheMutex.RLock()
	defer cv.cacheMutex.RUnlock()

	if info, exists := cv.contractCache[address]; exists {
		return info, true
	}
	return nil, false
}

// ListTrustedContracts returns a copy of the trusted-contract map. The copy is
// shallow: the *ContractInfo values are shared with the validator.
func (cv *ContractValidator) ListTrustedContracts() map[common.Address]*ContractInfo {
	cv.cacheMutex.RLock()
	defer cv.cacheMutex.RUnlock()

	trusted := make(map[common.Address]*ContractInfo, len(cv.trustedContracts))
	for addr, info := range cv.trustedContracts {
		trusted[addr] = info
	}
	return trusted
}