feat: create v2-prep branch with comprehensive planning

Restructured project for V2 refactor:

**Structure Changes:**
- Moved all V1 code to orig/ folder (preserved with git mv)
- Created docs/planning/ directory
- Added orig/README_V1.md explaining V1 preservation

**Planning Documents:**
- 00_V2_MASTER_PLAN.md: Complete architecture overview
  - Executive summary of critical V1 issues
  - High-level component architecture diagrams
  - 5-phase implementation roadmap
  - Success metrics and risk mitigation

- 07_TASK_BREAKDOWN.md: Atomic task breakdown
  - 99+ hours of detailed tasks
  - Every task < 2 hours (atomic)
  - Clear dependencies and success criteria
  - Organized by implementation phase

**V2 Key Improvements:**
- Per-exchange parsers (factory pattern)
- Multi-layer strict validation
- Multi-index pool cache
- Background validation pipeline
- Comprehensive observability

**Critical Issues Addressed:**
- Zero address tokens (strict validation + cache enrichment)
- Parsing accuracy (protocol-specific parsers)
- No audit trail (background validation channel)
- Inefficient lookups (multi-index cache)
- Stats disconnection (event-driven metrics)

Next Steps:
1. Review planning documents
2. Begin Phase 1: Foundation (P1-001 through P1-010)
3. Implement parsers in Phase 2
4. Build cache system in Phase 3
5. Add validation pipeline in Phase 4
6. Migrate and test in Phase 5

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
Administrator
2025-11-10 10:14:26 +01:00
parent 1773daffe7
commit 803de231ba
411 changed files with 20390 additions and 8680 deletions

View File

@@ -1,956 +0,0 @@
package validation
import (
"fmt"
"math"
"math/big"
"regexp"
"strconv"
"strings"
"time"
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/core/types"
"github.com/fraktal/mev-beta/internal/logger"
"github.com/fraktal/mev-beta/internal/utils"
"github.com/fraktal/mev-beta/pkg/security"
)
// safeConvertInt64ToUint64 safely converts an int64 to uint64, ensuring no negative values
func safeConvertInt64ToUint64(v int64) uint64 {
if v < 0 {
return 0
}
return uint64(v)
}
// InputValidator provides comprehensive validation for transaction parameters and user inputs
type InputValidator struct {
logger *logger.Logger
maxGasLimit uint64
maxGasPrice *big.Int
maxValue *big.Int
allowedMethods map[string]bool // method signatures that are allowed
// Regex patterns for validation
addressPattern *regexp.Regexp
txHashPattern *regexp.Regexp
blockHashPattern *regexp.Regexp
hexDataPattern *regexp.Regexp
}
// ValidationConfig contains configuration for input validation
type ValidationConfig struct {
MaxGasLimit uint64 `json:"max_gas_limit"`
MaxGasPriceGwei int64 `json:"max_gas_price_gwei"`
MaxValueEther int64 `json:"max_value_ether"`
AllowedMethods []string `json:"allowed_methods"`
RequireDeadline bool `json:"require_deadline"`
MaxDeadlineHours int `json:"max_deadline_hours"`
}
// TransactionValidationResult contains the result of transaction validation
type TransactionValidationResult struct {
IsValid bool `json:"is_valid"`
Errors []string `json:"errors"`
Warnings []string `json:"warnings"`
RiskLevel string `json:"risk_level"` // "low", "medium", "high", "critical"
EstimatedCost *big.Int `json:"estimated_cost,omitempty"`
}
// SwapParams represents swap transaction parameters
type SwapParams struct {
TokenIn common.Address `json:"token_in"`
TokenOut common.Address `json:"token_out"`
AmountIn *big.Int `json:"amount_in"`
AmountOutMinimum *big.Int `json:"amount_out_minimum"`
Fee uint32 `json:"fee"`
Recipient common.Address `json:"recipient"`
Deadline uint64 `json:"deadline"`
SlippageTolerance *big.Int `json:"slippage_tolerance"` // in basis points
}
// ArbitrageParams represents arbitrage transaction parameters
type ArbitrageParams struct {
Path []common.Address `json:"path"`
AmountIn *big.Int `json:"amount_in"`
MinAmountOut *big.Int `json:"min_amount_out"`
Deadline uint64 `json:"deadline"`
MaxGasPrice *big.Int `json:"max_gas_price"`
ProfitThreshold *big.Int `json:"profit_threshold"`
MaxSlippageBps *big.Int `json:"max_slippage_bps"`
}
// LiquidityParams represents liquidity provision parameters
type LiquidityParams struct {
Token0 common.Address `json:"token0"`
Token1 common.Address `json:"token1"`
Fee uint32 `json:"fee"`
TickLower int32 `json:"tick_lower"`
TickUpper int32 `json:"tick_upper"`
Amount0Desired *big.Int `json:"amount0_desired"`
Amount1Desired *big.Int `json:"amount1_desired"`
Amount0Min *big.Int `json:"amount0_min"`
Amount1Min *big.Int `json:"amount1_min"`
Recipient common.Address `json:"recipient"`
Deadline uint64 `json:"deadline"`
}
// NewInputValidator creates a new input validator
func NewInputValidator(config *ValidationConfig, logger *logger.Logger) *InputValidator {
if config == nil {
config = getDefaultValidationConfig()
}
validator := &InputValidator{
logger: logger,
maxGasLimit: config.MaxGasLimit,
maxGasPrice: big.NewInt(config.MaxGasPriceGwei * 1e9), // Convert Gwei to Wei
maxValue: big.NewInt(config.MaxValueEther * 1e18), // Convert Ether to Wei
allowedMethods: make(map[string]bool),
addressPattern: regexp.MustCompile(`^0x[a-fA-F0-9]{40}$`),
txHashPattern: regexp.MustCompile(`^0x[a-fA-F0-9]{64}$`),
blockHashPattern: regexp.MustCompile(`^0x[a-fA-F0-9]{64}$`),
hexDataPattern: regexp.MustCompile(`^0x[a-fA-F0-9]*$`),
}
// Initialize allowed methods
for _, method := range config.AllowedMethods {
validator.allowedMethods[method] = true
}
return validator
}
// ValidateTransaction performs comprehensive validation of a transaction
func (iv *InputValidator) ValidateTransaction(tx *types.Transaction) (*TransactionValidationResult, error) {
result := &TransactionValidationResult{
IsValid: true,
Errors: make([]string, 0),
Warnings: make([]string, 0),
RiskLevel: "low",
}
// 0. Early check for nil or malformed transactions
if tx == nil {
result.IsValid = false
result.Errors = append(result.Errors, "transaction is nil")
return result, nil
}
// Skip validation for known problematic transactions to reduce log spam
txHash := tx.Hash().Hex()
if iv.isKnownProblematicTransaction(txHash) {
result.IsValid = false
// Don't add to errors to avoid logging spam
return result, nil
}
// 1. Basic transaction validation
iv.validateBasicTransaction(tx, result)
// 2. Gas validation
iv.validateGas(tx, result)
// 3. Value validation
iv.validateValue(tx, result)
// 4. Recipient validation
iv.validateRecipient(tx, result)
// 5. Data validation (for contract calls)
if len(tx.Data()) > 0 {
iv.validateData(tx.Data(), result)
}
// 6. Calculate estimated cost
result.EstimatedCost = iv.calculateEstimatedCost(tx)
// 7. Determine final validity and risk level
iv.finalizeValidation(result)
if len(result.Errors) > 0 {
iv.logger.Warn(fmt.Sprintf("Transaction validation failed: %v", result.Errors))
}
return result, nil
}
// ValidateSwapParams validates swap transaction parameters
func (iv *InputValidator) ValidateSwapParams(params *SwapParams) (*TransactionValidationResult, error) {
result := &TransactionValidationResult{
IsValid: true,
Errors: make([]string, 0),
Warnings: make([]string, 0),
RiskLevel: "low",
}
// 1. Validate token addresses
if err := iv.ValidateAddress(params.TokenIn); err != nil {
result.Errors = append(result.Errors, fmt.Sprintf("Invalid token_in address: %v", err))
}
if err := iv.ValidateAddress(params.TokenOut); err != nil {
result.Errors = append(result.Errors, fmt.Sprintf("Invalid token_out address: %v", err))
}
// 2. Check tokens are different
if params.TokenIn == params.TokenOut {
result.Errors = append(result.Errors, "token_in and token_out must be different")
}
// 3. Validate amounts
if err := iv.ValidateAmount(params.AmountIn); err != nil {
result.Errors = append(result.Errors, fmt.Sprintf("Invalid amount_in: %v", err))
}
if err := iv.ValidateAmount(params.AmountOutMinimum); err != nil {
result.Errors = append(result.Errors, fmt.Sprintf("Invalid amount_out_minimum: %v", err))
}
// 4. Validate slippage tolerance
if err := iv.ValidateSlippage(params.SlippageTolerance); err != nil {
result.Errors = append(result.Errors, fmt.Sprintf("Invalid slippage tolerance: %v", err))
}
// 5. Validate fee tier
if err := iv.validateFeeTier(params.Fee); err != nil {
result.Errors = append(result.Errors, fmt.Sprintf("Invalid fee tier: %v", err))
}
// 6. Validate recipient
if err := iv.ValidateAddress(params.Recipient); err != nil {
result.Errors = append(result.Errors, fmt.Sprintf("Invalid recipient address: %v", err))
}
// 7. Validate deadline
if err := iv.validateDeadline(params.Deadline); err != nil {
result.Errors = append(result.Errors, fmt.Sprintf("Invalid deadline: %v", err))
}
// 8. Additional security checks
iv.performSwapSecurityChecks(params, result)
iv.finalizeValidation(result)
return result, nil
}
// ValidateArbitrageParams validates arbitrage transaction parameters
func (iv *InputValidator) ValidateArbitrageParams(params *ArbitrageParams) (*TransactionValidationResult, error) {
result := &TransactionValidationResult{
IsValid: true,
Errors: make([]string, 0),
Warnings: make([]string, 0),
RiskLevel: "medium", // Arbitrage is inherently riskier
}
// 1. Validate path
if len(params.Path) < 2 {
result.Errors = append(result.Errors, "arbitrage path must have at least 2 tokens")
}
if len(params.Path) > 5 {
result.Warnings = append(result.Warnings, "long arbitrage paths increase gas costs and slippage")
}
for i, addr := range params.Path {
if err := iv.ValidateAddress(addr); err != nil {
result.Errors = append(result.Errors, fmt.Sprintf("Invalid address at path[%d]: %v", i, err))
}
}
// 2. Check for duplicate tokens in path
seen := make(map[common.Address]bool)
for _, addr := range params.Path {
if seen[addr] {
result.Errors = append(result.Errors, fmt.Sprintf("Duplicate token in path: %s", addr.Hex()))
}
seen[addr] = true
}
// 3. Validate amounts
if err := iv.ValidateAmount(params.AmountIn); err != nil {
result.Errors = append(result.Errors, fmt.Sprintf("Invalid amount_in: %v", err))
}
if err := iv.ValidateAmount(params.MinAmountOut); err != nil {
result.Errors = append(result.Errors, fmt.Sprintf("Invalid min_amount_out: %v", err))
}
if err := iv.ValidateAmount(params.ProfitThreshold); err != nil {
result.Errors = append(result.Errors, fmt.Sprintf("Invalid profit_threshold: %v", err))
}
// 4. Validate profit expectation
if params.MinAmountOut.Cmp(params.AmountIn) <= 0 {
result.Errors = append(result.Errors, "min_amount_out must be greater than amount_in for profitable arbitrage")
}
// 5. Validate slippage
if err := iv.ValidateSlippage(params.MaxSlippageBps); err != nil {
result.Errors = append(result.Errors, fmt.Sprintf("Invalid max_slippage: %v", err))
}
// 6. Validate gas price
if params.MaxGasPrice != nil && params.MaxGasPrice.Cmp(iv.maxGasPrice) > 0 {
result.Warnings = append(result.Warnings, "very high gas price may eat into profits")
}
// 7. Validate deadline
if err := iv.validateDeadline(params.Deadline); err != nil {
result.Errors = append(result.Errors, fmt.Sprintf("Invalid deadline: %v", err))
}
iv.finalizeValidation(result)
return result, nil
}
// ValidateAddress validates an Ethereum address
func (iv *InputValidator) ValidateAddress(addr common.Address) error {
if addr == (common.Address{}) {
return fmt.Errorf("address cannot be zero")
}
// Check format using regex
if !iv.addressPattern.MatchString(addr.Hex()) {
return fmt.Errorf("invalid address format")
}
// Check for common invalid addresses
if iv.isKnownInvalidAddress(addr) {
return fmt.Errorf("address is known to be invalid or malicious")
}
return nil
}
// ValidateAmount validates a big.Int amount
func (iv *InputValidator) ValidateAmount(amount *big.Int) error {
if amount == nil {
return fmt.Errorf("amount cannot be nil")
}
if amount.Sign() < 0 {
return fmt.Errorf("amount cannot be negative")
}
if amount.Sign() == 0 {
return fmt.Errorf("amount cannot be zero")
}
// Check for unreasonably large amounts (prevent overflow attacks)
maxAmount := new(big.Int).Exp(big.NewInt(10), big.NewInt(30), nil) // 10^30 wei
if amount.Cmp(maxAmount) > 0 {
return fmt.Errorf("amount exceeds maximum allowed value")
}
return nil
}
// ValidateSlippage validates slippage tolerance in basis points
func (iv *InputValidator) ValidateSlippage(slippageBps *big.Int) error {
if slippageBps == nil {
return fmt.Errorf("slippage cannot be nil")
}
if slippageBps.Sign() < 0 {
return fmt.Errorf("slippage cannot be negative")
}
// Maximum 50% slippage (5000 basis points)
maxSlippage := big.NewInt(5000)
if slippageBps.Cmp(maxSlippage) > 0 {
return fmt.Errorf("slippage tolerance cannot exceed 50%%")
}
return nil
}
// validateBasicTransaction validates basic transaction properties
func (iv *InputValidator) validateBasicTransaction(tx *types.Transaction, result *TransactionValidationResult) {
// Check nonce
if tx.Nonce() > 1000000 {
result.Warnings = append(result.Warnings, "unusually high nonce")
}
// Check transaction size
txSize := len(tx.Data()) + 200 // approximate overhead
if txSize > 128*1024 { // 128KB limit
result.Errors = append(result.Errors, "transaction size exceeds limit")
}
}
// validateGas validates gas-related parameters
func (iv *InputValidator) validateGas(tx *types.Transaction, result *TransactionValidationResult) {
// Validate gas limit
if tx.Gas() == 0 {
result.Errors = append(result.Errors, "gas limit cannot be zero")
}
if tx.Gas() > iv.maxGasLimit {
result.Errors = append(result.Errors, fmt.Sprintf("gas limit %d exceeds maximum %d", tx.Gas(), iv.maxGasLimit))
}
// Validate gas price
if tx.GasPrice() != nil {
if tx.GasPrice().Sign() == 0 {
result.Errors = append(result.Errors, "gas price cannot be zero")
}
if tx.GasPrice().Cmp(iv.maxGasPrice) > 0 {
result.Errors = append(result.Errors, fmt.Sprintf("gas price exceeds maximum"))
}
// Warn about very high gas prices
highGasPrice := new(big.Int).Mul(big.NewInt(100), big.NewInt(1e9)) // 100 Gwei
if tx.GasPrice().Cmp(highGasPrice) > 0 {
result.Warnings = append(result.Warnings, "very high gas price")
result.RiskLevel = "medium"
}
}
// Validate gas fee cap and tip for EIP-1559 transactions
if tx.GasFeeCap() != nil {
if tx.GasFeeCap().Sign() == 0 {
result.Errors = append(result.Errors, "gas fee cap cannot be zero")
}
if tx.GasFeeCap().Cmp(iv.maxGasPrice) > 0 {
result.Errors = append(result.Errors, "gas fee cap exceeds maximum")
}
}
if tx.GasTipCap() != nil {
if tx.GasTipCap().Sign() < 0 {
result.Errors = append(result.Errors, "gas tip cap cannot be negative")
}
if tx.GasFeeCap() != nil && tx.GasTipCap().Cmp(tx.GasFeeCap()) > 0 {
result.Errors = append(result.Errors, "gas tip cap cannot exceed gas fee cap")
}
}
}
// validateValue validates the transaction value
func (iv *InputValidator) validateValue(tx *types.Transaction, result *TransactionValidationResult) {
if tx.Value() == nil {
return
}
if tx.Value().Sign() < 0 {
result.Errors = append(result.Errors, "transaction value cannot be negative")
}
if tx.Value().Cmp(iv.maxValue) > 0 {
result.Errors = append(result.Errors, "transaction value exceeds maximum allowed")
}
// Warn about large value transfers
largeValue := new(big.Int).Mul(big.NewInt(10), big.NewInt(1e18)) // 10 ETH
if tx.Value().Cmp(largeValue) > 0 {
result.Warnings = append(result.Warnings, "large value transfer")
result.RiskLevel = "high"
}
}
// validateRecipient validates the transaction recipient
func (iv *InputValidator) validateRecipient(tx *types.Transaction, result *TransactionValidationResult) {
if tx.To() == nil {
// Contract creation transaction
result.Warnings = append(result.Warnings, "contract creation transaction")
result.RiskLevel = "high"
return
}
// Check for zero address
if *tx.To() == (common.Address{}) {
result.Errors = append(result.Errors, "recipient cannot be zero address")
}
// Check for known malicious addresses
if iv.isKnownInvalidAddress(*tx.To()) {
result.Errors = append(result.Errors, "recipient is known malicious address")
}
}
// validateData validates transaction data for contract calls
func (iv *InputValidator) validateData(data []byte, result *TransactionValidationResult) {
if len(data) == 0 {
return
}
if len(data) < 4 {
result.Errors = append(result.Errors, "invalid function call data")
return
}
// Extract function selector
selector := data[:4]
methodSig := fmt.Sprintf("0x%x", selector)
// Check if method is allowed
if len(iv.allowedMethods) > 0 && !iv.allowedMethods[methodSig] {
result.Errors = append(result.Errors, fmt.Sprintf("method %s not allowed", methodSig))
}
// Check for suspicious patterns
if iv.hasSuspiciousPatterns(data) {
result.Warnings = append(result.Warnings, "transaction data contains suspicious patterns")
result.RiskLevel = "high"
}
}
// validateFeeTier validates Uniswap V3 fee tiers
func (iv *InputValidator) validateFeeTier(fee uint32) error {
validFees := []uint32{100, 500, 3000, 10000} // 0.01%, 0.05%, 0.3%, 1%
for _, validFee := range validFees {
if fee == validFee {
return nil
}
}
return fmt.Errorf("invalid fee tier: %d (must be one of: 100, 500, 3000, 10000)", fee)
}
// validateDeadline validates transaction deadline
func (iv *InputValidator) validateDeadline(deadline uint64) error {
if deadline == 0 {
return fmt.Errorf("deadline cannot be zero")
}
now := safeConvertInt64ToUint64(time.Now().Unix())
if deadline <= now {
return fmt.Errorf("deadline must be in the future")
}
// Warn about very long deadlines
maxDeadline := now + 24*60*60 // 24 hours from now
if deadline > maxDeadline {
return fmt.Errorf("deadline too far in future (max 24 hours)")
}
return nil
}
// performSwapSecurityChecks performs additional security checks for swap parameters
func (iv *InputValidator) performSwapSecurityChecks(params *SwapParams, result *TransactionValidationResult) {
// Check for sandwich attack vulnerability
if params.SlippageTolerance != nil && params.SlippageTolerance.Cmp(big.NewInt(500)) > 0 { // >5%
result.Warnings = append(result.Warnings, "high slippage tolerance increases sandwich attack risk")
result.RiskLevel = "medium"
}
// Check for MEV vulnerability
if params.AmountIn != nil {
// Large trades are more susceptible to MEV
largeTradeThreshold := new(big.Int).Mul(big.NewInt(100), big.NewInt(1e18)) // 100 tokens
if params.AmountIn.Cmp(largeTradeThreshold) > 0 {
result.Warnings = append(result.Warnings, "large trade may be subject to MEV attacks")
}
}
// Check deadline proximity
now := safeConvertInt64ToUint64(time.Now().Unix())
if params.Deadline-now < 60 { // Less than 1 minute
result.Warnings = append(result.Warnings, "very short deadline may cause transaction failures")
}
}
// calculateEstimatedCost estimates the total cost of a transaction
func (iv *InputValidator) calculateEstimatedCost(tx *types.Transaction) *big.Int {
cost := new(big.Int)
// Gas cost
if tx.GasPrice() != nil {
gasInt64, err := security.SafeUint64ToInt64(tx.Gas())
if err != nil {
// Log the error but use a safe fallback
iv.logger.Error("Gas value exceeds int64 maximum", "gas", tx.Gas(), "error", err)
gasInt64 = math.MaxInt64 // Use maximum safe value as fallback
}
gasCost := new(big.Int).Mul(big.NewInt(gasInt64), tx.GasPrice())
cost.Add(cost, gasCost)
} else if tx.GasFeeCap() != nil {
// For EIP-1559 transactions, use fee cap as estimate
gasInt64, err := security.SafeUint64ToInt64(tx.Gas())
if err != nil {
// Log the error but use a safe fallback
iv.logger.Error("Gas value exceeds int64 maximum", "gas", tx.Gas(), "error", err)
gasInt64 = math.MaxInt64 // Use maximum safe value as fallback
}
gasCost := new(big.Int).Mul(big.NewInt(gasInt64), tx.GasFeeCap())
cost.Add(cost, gasCost)
}
// Value transfer
if tx.Value() != nil {
cost.Add(cost, tx.Value())
}
return cost
}
// finalizeValidation determines final validation result
func (iv *InputValidator) finalizeValidation(result *TransactionValidationResult) {
if len(result.Errors) > 0 {
result.IsValid = false
result.RiskLevel = "critical"
return
}
// Adjust risk level based on warnings
if len(result.Warnings) > 2 {
if result.RiskLevel == "low" {
result.RiskLevel = "medium"
} else if result.RiskLevel == "medium" {
result.RiskLevel = "high"
}
}
}
// Helper functions
func (iv *InputValidator) isKnownInvalidAddress(addr common.Address) bool {
// Check against known malicious addresses
// This would be populated from a real blacklist in production
maliciousAddresses := map[common.Address]bool{
// Add known malicious addresses here
}
return maliciousAddresses[addr]
}
func (iv *InputValidator) isKnownProblematicTransaction(txHash string) bool {
// List of known problematic transaction hashes that should be skipped
problematicTxs := map[string]bool{
"0xe79e4719c6770b41405f691c18be3346b691e220d730d6b61abb5dd3ac9d71f0": true,
// Add other problematic transaction hashes here
}
return problematicTxs[txHash]
}
func (iv *InputValidator) hasSuspiciousPatterns(data []byte) bool {
// Check for suspicious patterns in transaction data
// This is a simplified implementation
// Check for self-destruct calls
if len(data) >= 4 {
// selfdestruct selector: 0xff
if data[0] == 0xff {
return true
}
}
// Check for delegate calls to unknown addresses
// This would require more sophisticated analysis in production
return false
}
func getDefaultValidationConfig() *ValidationConfig {
return &ValidationConfig{
MaxGasLimit: 10000000, // 10M gas
MaxGasPriceGwei: 500, // 500 Gwei
MaxValueEther: 1000, // 1000 ETH
AllowedMethods: []string{}, // Empty means all methods allowed
RequireDeadline: true,
MaxDeadlineHours: 24,
}
}
// Legacy validation functions (keeping for backward compatibility)
// ValidateEthereumAddress validates an Ethereum address string
func (iv *InputValidator) ValidateEthereumAddress(address string) error {
if !iv.addressPattern.MatchString(address) {
return fmt.Errorf("invalid Ethereum address format")
}
return nil
}
// ValidateTransactionHash validates a transaction hash string
func (iv *InputValidator) ValidateTransactionHash(hash string) error {
if !iv.txHashPattern.MatchString(hash) {
return fmt.Errorf("invalid transaction hash format")
}
return nil
}
// ValidateBlockHash validates a block hash string
func (iv *InputValidator) ValidateBlockHash(hash string) error {
if !iv.blockHashPattern.MatchString(hash) {
return fmt.Errorf("invalid block hash format")
}
return nil
}
// ValidateEvent validates an event structure with comprehensive checks
func (iv *InputValidator) ValidateEvent(event interface{}) error {
if event == nil {
return fmt.Errorf("event cannot be nil")
}
// Use reflection to validate event structure based on type
eventType := fmt.Sprintf("%T", event)
iv.logger.Debug(fmt.Sprintf("Validating event of type: %s", eventType))
// Type-specific validation based on event structure
switch e := event.(type) {
case map[string]interface{}:
return iv.validateEventMap(e)
default:
// For other types, perform basic structural validation
return iv.validateEventStructure(event)
}
}
// validateEventMap validates map-based event structures
func (iv *InputValidator) validateEventMap(eventMap map[string]interface{}) error {
// Check for required common fields
requiredFields := []string{"type", "timestamp"}
for _, field := range requiredFields {
if _, exists := eventMap[field]; !exists {
return fmt.Errorf("missing required field: %s", field)
}
}
// Validate timestamp if present
if timestamp, ok := eventMap["timestamp"]; ok {
if err := iv.validateTimestamp(timestamp); err != nil {
return fmt.Errorf("invalid timestamp: %w", err)
}
}
// Validate addresses if present
addressFields := []string{"address", "token0", "token1", "pool", "sender", "recipient"}
for _, field := range addressFields {
if addr, exists := eventMap[field]; exists {
if addrStr, ok := addr.(string); ok {
// PHASE 2 FIX: Use safe address conversion
conversionResult := utils.SafeHexToAddress(addrStr)
if !conversionResult.IsValid {
return fmt.Errorf("invalid address in field %s: %v", field, conversionResult.Error)
}
if err := iv.ValidateCommonAddress(conversionResult.Address); err != nil {
return fmt.Errorf("invalid address in field %s: %w", field, err)
}
}
}
}
// Validate amounts if present
amountFields := []string{"amount", "amount0", "amount1", "amountIn", "amountOut", "value"}
for _, field := range amountFields {
if amount, exists := eventMap[field]; exists {
if err := iv.validateAmount(amount); err != nil {
return fmt.Errorf("invalid amount in field %s: %w", field, err)
}
}
}
iv.logger.Debug("Event map validation completed successfully")
return nil
}
// validateEventStructure validates arbitrary event structures using reflection
func (iv *InputValidator) validateEventStructure(event interface{}) error {
// Basic structural validation
eventStr := fmt.Sprintf("%+v", event)
// Check if event structure is not empty
if len(eventStr) < 10 {
return fmt.Errorf("event structure appears to be empty or malformed")
}
// Check for common patterns that indicate valid events
validPatterns := []string{
"BlockNumber", "TxHash", "Address", "Token", "Amount", "Pool",
"block", "transaction", "address", "token", "amount", "pool",
}
hasValidPattern := false
for _, pattern := range validPatterns {
if strings.Contains(eventStr, pattern) {
hasValidPattern = true
break
}
}
if !hasValidPattern {
iv.logger.Warn(fmt.Sprintf("Event structure may not contain expected fields: %s", eventStr[:min(100, len(eventStr))]))
}
iv.logger.Debug("Event structure validation completed")
return nil
}
// validateTimestamp validates timestamp values in various formats
func (iv *InputValidator) validateTimestamp(timestamp interface{}) error {
switch ts := timestamp.(type) {
case int64:
if ts < 0 || ts > time.Now().Unix()+86400 { // Not more than 1 day in future
return fmt.Errorf("timestamp out of valid range")
}
case uint64:
if ts > safeConvertInt64ToUint64(time.Now().Unix()+86400) { // Not more than 1 day in future
return fmt.Errorf("timestamp out of valid range")
}
case time.Time:
if ts.Before(time.Unix(0, 0)) || ts.After(time.Now().Add(24*time.Hour)) {
return fmt.Errorf("timestamp out of valid range")
}
case string:
// Try to parse as RFC3339 or Unix timestamp
if _, err := time.Parse(time.RFC3339, ts); err != nil {
if _, err := strconv.ParseInt(ts, 10, 64); err != nil {
return fmt.Errorf("invalid timestamp format")
}
}
default:
return fmt.Errorf("unsupported timestamp type: %T", timestamp)
}
return nil
}
// validateAmount validates amount values in various formats
func (iv *InputValidator) validateAmount(amount interface{}) error {
switch a := amount.(type) {
case *big.Int:
if a == nil {
return fmt.Errorf("amount cannot be nil")
}
if a.Sign() < 0 {
return fmt.Errorf("amount cannot be negative")
}
// Check for unreasonably large amounts (> 1e30)
maxAmount := new(big.Int).Exp(big.NewInt(10), big.NewInt(30), nil)
if a.Cmp(maxAmount) > 0 {
return fmt.Errorf("amount exceeds maximum allowed value")
}
case int64:
if a < 0 {
return fmt.Errorf("amount cannot be negative")
}
case uint64:
// Always valid for uint64
case string:
if _, ok := new(big.Int).SetString(a, 10); !ok {
return fmt.Errorf("invalid amount format")
}
case float64:
if a < 0 {
return fmt.Errorf("amount cannot be negative")
}
if a > 1e30 {
return fmt.Errorf("amount exceeds maximum allowed value")
}
default:
return fmt.Errorf("unsupported amount type: %T", amount)
}
return nil
}
// min returns the minimum of two integers
func min(a, b int) int {
if a < b {
return a
}
return b
}
// ValidateHexData validates hex data string
func (iv *InputValidator) ValidateHexData(data string) error {
if !iv.hexDataPattern.MatchString(data) {
return fmt.Errorf("invalid hex data format")
}
return nil
}
// SanitizeInput sanitizes string inputs to prevent injection attacks
func SanitizeInput(input string) string {
// Remove potentially dangerous characters
reg := regexp.MustCompile(`[^\w\s\-\.]`)
sanitized := reg.ReplaceAllString(input, "")
// Limit length
if len(sanitized) > 1000 {
sanitized = sanitized[:1000]
}
return strings.TrimSpace(sanitized)
}
// ValidateHexString validates a hex string
func ValidateHexString(hexStr string) error {
if !strings.HasPrefix(hexStr, "0x") {
return fmt.Errorf("hex string must start with 0x")
}
hexStr = hexStr[2:] // Remove 0x prefix
if len(hexStr)%2 != 0 {
return fmt.Errorf("hex string must have even length")
}
matched, err := regexp.MatchString("^[0-9a-fA-F]*$", hexStr)
if err != nil {
return err
}
if !matched {
return fmt.Errorf("invalid hex characters")
}
return nil
}
// ValidateCommonAddress validates an Ethereum address (common.Address type)
func (iv *InputValidator) ValidateCommonAddress(addr common.Address) error {
return iv.ValidateAddress(addr)
}
// ValidateBigInt validates a big.Int value with context
func (iv *InputValidator) ValidateBigInt(value *big.Int, fieldName string) error {
if value == nil {
return fmt.Errorf("%s cannot be nil", fieldName)
}
if value.Sign() < 0 {
return fmt.Errorf("%s cannot be negative", fieldName)
}
if value.Sign() == 0 {
return fmt.Errorf("%s cannot be zero", fieldName)
}
// Check for unreasonably large values
maxValue := new(big.Int).Exp(big.NewInt(10), big.NewInt(30), nil)
if value.Cmp(maxValue) > 0 {
return fmt.Errorf("%s exceeds maximum allowed value", fieldName)
}
return nil
}
// ValidateSlippageTolerance validates slippage tolerance (same as ValidateSlippage)
func (iv *InputValidator) ValidateSlippageTolerance(slippage interface{}) error {
switch v := slippage.(type) {
case *big.Int:
return iv.ValidateSlippage(v)
case float64:
if v < 0 {
return fmt.Errorf("slippage cannot be negative")
}
if v > 50.0 { // 50% maximum
return fmt.Errorf("slippage tolerance cannot exceed 50%%")
}
return nil
default:
return fmt.Errorf("unsupported slippage type: must be *big.Int or float64")
}
}
// ValidateDeadline validates a deadline timestamp (public wrapper for validateDeadline)
func (iv *InputValidator) ValidateDeadline(deadline uint64) error {
return iv.validateDeadline(deadline)
}

View File

@@ -1,12 +0,0 @@
package validation
import (
"testing"
"github.com/stretchr/testify/assert"
)
// Basic test to verify package compiles
func TestValidation(t *testing.T) {
assert.True(t, true)
}

View File

@@ -1,774 +0,0 @@
package validation
import (
"context"
"encoding/hex"
"fmt"
"math/big"
"strings"
"time"
"github.com/ethereum/go-ethereum"
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/ethclient"
"github.com/fraktal/mev-beta/internal/logger"
"github.com/fraktal/mev-beta/pkg/pools"
"github.com/fraktal/mev-beta/pkg/security"
"github.com/fraktal/mev-beta/pkg/uniswap"
)
// PoolValidator provides comprehensive security validation for liquidity pools
type PoolValidator struct {
client *ethclient.Client
logger *logger.Logger
create2Calculator *pools.CREATE2Calculator
trustedFactories map[common.Address]string // factory address -> name
bannedAddresses map[common.Address]string // banned address -> reason
validationCache map[common.Address]*ValidationResult
cacheTimeout time.Duration
}
// ValidationResult contains the result of pool validation
type ValidationResult struct {
IsValid bool `json:"is_valid"`
SecurityScore int `json:"security_score"` // 0-100, higher is better
Warnings []string `json:"warnings"`
Errors []string `json:"errors"`
PoolType string `json:"pool_type"` // "uniswap_v3", "uniswap_v2", etc.
Factory string `json:"factory"`
Token0 common.Address `json:"token0"`
Token1 common.Address `json:"token1"`
Fee uint32 `json:"fee,omitempty"`
CreationBlock uint64 `json:"creation_block,omitempty"`
ValidatedAt time.Time `json:"validated_at"`
FactoryVerified bool `json:"factory_verified"`
InterfaceValid bool `json:"interface_valid"`
TokensValid bool `json:"tokens_valid"`
}
// PoolValidationConfig contains configuration for pool validation
type PoolValidationConfig struct {
RequireFactoryVerification bool // Whether factory verification is mandatory
MinSecurityScore int // Minimum security score to accept (0-100)
MaxValidationTime time.Duration // Maximum time to spend on validation
AllowUnknownFactories bool // Whether to allow pools from unknown factories
RequireTokenValidation bool // Whether to validate token contracts
}
// NewPoolValidator creates a new pool validator
func NewPoolValidator(client *ethclient.Client, logger *logger.Logger) *PoolValidator {
pv := &PoolValidator{
client: client,
logger: logger,
create2Calculator: pools.NewCREATE2Calculator(logger, client),
trustedFactories: make(map[common.Address]string),
bannedAddresses: make(map[common.Address]string),
validationCache: make(map[common.Address]*ValidationResult),
cacheTimeout: 5 * time.Minute,
}
pv.initializeTrustedFactories()
pv.initializeBannedAddresses()
return pv
}
// ValidatePool performs comprehensive security validation of a pool
func (pv *PoolValidator) ValidatePool(ctx context.Context, poolAddr common.Address, config *PoolValidationConfig) (*ValidationResult, error) {
if config == nil {
config = pv.getDefaultConfig()
}
// Check cache first
if cached := pv.getCachedResult(poolAddr); cached != nil {
return cached, nil
}
// Create timeout context
timeoutCtx, cancel := context.WithTimeout(ctx, config.MaxValidationTime)
defer cancel()
result := &ValidationResult{
ValidatedAt: time.Now(),
SecurityScore: 0,
Warnings: make([]string, 0),
Errors: make([]string, 0),
}
// 1. Basic existence check
if err := pv.validateBasicExistence(timeoutCtx, poolAddr, result); err != nil {
result.Errors = append(result.Errors, fmt.Sprintf("Basic validation failed: %v", err))
result.IsValid = false
pv.cacheResult(poolAddr, result)
return result, nil
}
// 2. Check against banned addresses
if err := pv.checkBannedAddresses(poolAddr, result); err != nil {
result.Errors = append(result.Errors, fmt.Sprintf("Banned address check failed: %v", err))
result.IsValid = false
result.SecurityScore = 0
pv.cacheResult(poolAddr, result)
return result, nil
}
// 3. Detect pool type and validate interface
if err := pv.validatePoolInterface(timeoutCtx, poolAddr, result); err != nil {
result.Errors = append(result.Errors, fmt.Sprintf("Interface validation failed: %v", err))
if config.RequireFactoryVerification {
result.IsValid = false
}
result.SecurityScore -= 30
} else {
result.InterfaceValid = true
result.SecurityScore += 25
}
// 4. Validate factory deployment (critical security check)
if err := pv.validateFactoryDeployment(timeoutCtx, poolAddr, result); err != nil {
result.Errors = append(result.Errors, fmt.Sprintf("Factory validation failed: %v", err))
if config.RequireFactoryVerification {
result.IsValid = false
}
result.SecurityScore -= 40
} else {
result.FactoryVerified = true
result.SecurityScore += 30
}
// 5. Validate token contracts
if config.RequireTokenValidation {
if err := pv.validateTokenContracts(timeoutCtx, result); err != nil {
result.Warnings = append(result.Warnings, fmt.Sprintf("Token validation warning: %v", err))
result.SecurityScore -= 10
} else {
result.TokensValid = true
result.SecurityScore += 15
}
}
// 6. Additional security checks
pv.performAdditionalSecurityChecks(timeoutCtx, poolAddr, result)
// 7. Final validation decision
if result.SecurityScore >= config.MinSecurityScore && len(result.Errors) == 0 {
result.IsValid = true
} else {
result.IsValid = false
}
// Ensure security score is within bounds
if result.SecurityScore < 0 {
result.SecurityScore = 0
} else if result.SecurityScore > 100 {
result.SecurityScore = 100
}
// Cache the result
pv.cacheResult(poolAddr, result)
pv.logger.Debug(fmt.Sprintf("Pool validation complete: %s, valid=%v, score=%d",
poolAddr.Hex(), result.IsValid, result.SecurityScore))
return result, nil
}
// validateBasicExistence checks if the pool contract exists and has code
func (pv *PoolValidator) validateBasicExistence(ctx context.Context, poolAddr common.Address, result *ValidationResult) error {
// Check if contract has code
code, err := pv.client.CodeAt(ctx, poolAddr, nil)
if err != nil {
return fmt.Errorf("failed to get contract code: %w", err)
}
if len(code) == 0 {
return fmt.Errorf("no contract code at address %s", poolAddr.Hex())
}
// Basic code size check - legitimate pools should have substantial code
if len(code) < 100 {
result.Warnings = append(result.Warnings, "Contract has very small code size")
result.SecurityScore -= 10
}
return nil
}
// checkBannedAddresses verifies the pool is not on the banned list
func (pv *PoolValidator) checkBannedAddresses(poolAddr common.Address, result *ValidationResult) error {
if reason, banned := pv.bannedAddresses[poolAddr]; banned {
return fmt.Errorf("pool %s is banned: %s", poolAddr.Hex(), reason)
}
return nil
}
// validatePoolInterface detects pool type and validates the interface
func (pv *PoolValidator) validatePoolInterface(ctx context.Context, poolAddr common.Address, result *ValidationResult) error {
// Try to detect pool type by calling standard functions
// Check for Uniswap V3 interface
if pv.isUniswapV3Pool(ctx, poolAddr) {
result.PoolType = "uniswap_v3"
return pv.validateUniswapV3Interface(ctx, poolAddr, result)
}
// Check for Uniswap V2 interface
if pv.isUniswapV2Pool(ctx, poolAddr) {
result.PoolType = "uniswap_v2"
return pv.validateUniswapV2Interface(ctx, poolAddr, result)
}
// Unknown pool type
result.PoolType = "unknown"
result.Warnings = append(result.Warnings, "Unknown pool type")
return fmt.Errorf("unknown pool interface")
}
// isUniswapV3Pool checks if the pool implements Uniswap V3 interface
func (pv *PoolValidator) isUniswapV3Pool(ctx context.Context, poolAddr common.Address) bool {
// Try to call slot0() function (unique to Uniswap V3)
slot0ABI := `[{"inputs":[],"name":"slot0","outputs":[{"internalType":"uint160","name":"sqrtPriceX96","type":"uint160"},{"internalType":"int24","name":"tick","type":"int24"},{"internalType":"uint16","name":"observationIndex","type":"uint16"},{"internalType":"uint16","name":"observationCardinality","type":"uint16"},{"internalType":"uint16","name":"observationCardinalityNext","type":"uint16"},{"internalType":"uint8","name":"feeProtocol","type":"uint8"},{"internalType":"bool","name":"unlocked","type":"bool"}],"stateMutability":"view","type":"function"}]`
contractABI, err := uniswap.ParseABI(slot0ABI)
if err != nil {
return false
}
callData, err := contractABI.Pack("slot0")
if err != nil {
return false
}
_, err = pv.client.CallContract(ctx, ethereum.CallMsg{
To: &poolAddr,
Data: callData,
}, nil)
return err == nil
}
// isUniswapV2Pool checks if the pool implements Uniswap V2 interface
func (pv *PoolValidator) isUniswapV2Pool(ctx context.Context, poolAddr common.Address) bool {
// Try to call getReserves() function (standard in Uniswap V2)
reservesABI := `[{"inputs":[],"name":"getReserves","outputs":[{"internalType":"uint112","name":"_reserve0","type":"uint112"},{"internalType":"uint112","name":"_reserve1","type":"uint112"},{"internalType":"uint32","name":"_blockTimestampLast","type":"uint32"}],"stateMutability":"view","type":"function"}]`
contractABI, err := uniswap.ParseABI(reservesABI)
if err != nil {
return false
}
callData, err := contractABI.Pack("getReserves")
if err != nil {
return false
}
_, err = pv.client.CallContract(ctx, ethereum.CallMsg{
To: &poolAddr,
Data: callData,
}, nil)
return err == nil
}
// validateUniswapV3Interface validates Uniswap V3 specific functions
func (pv *PoolValidator) validateUniswapV3Interface(ctx context.Context, poolAddr common.Address, result *ValidationResult) error {
// Get token addresses and fee
token0, token1, fee, err := pv.getUniswapV3PoolInfo(ctx, poolAddr)
if err != nil {
return fmt.Errorf("failed to get V3 pool info: %w", err)
}
result.Token0 = token0
result.Token1 = token1
result.Fee = fee
// Validate token ordering (token0 < token1)
if token0.Big().Cmp(token1.Big()) >= 0 {
return fmt.Errorf("invalid token ordering: token0 must be < token1")
}
// Validate fee tier
validFees := []uint32{500, 3000, 10000, 100} // 0.05%, 0.3%, 1%, 0.01%
feeValid := false
for _, validFee := range validFees {
if fee == validFee {
feeValid = true
break
}
}
if !feeValid {
result.Warnings = append(result.Warnings, fmt.Sprintf("Unusual fee tier: %d", fee))
}
return nil
}
// validateUniswapV2Interface validates Uniswap V2 specific functions
func (pv *PoolValidator) validateUniswapV2Interface(ctx context.Context, poolAddr common.Address, result *ValidationResult) error {
// Get token addresses
token0, token1, err := pv.getUniswapV2PoolInfo(ctx, poolAddr)
if err != nil {
return fmt.Errorf("failed to get V2 pool info: %w", err)
}
result.Token0 = token0
result.Token1 = token1
result.Fee = 3000 // V2 has fixed 0.3% fee
// Validate token ordering
if token0.Big().Cmp(token1.Big()) >= 0 {
return fmt.Errorf("invalid token ordering: token0 must be < token1")
}
return nil
}
// validateFactoryDeployment verifies the pool was deployed by a trusted factory
func (pv *PoolValidator) validateFactoryDeployment(ctx context.Context, poolAddr common.Address, result *ValidationResult) error {
// For each known factory, try to verify if this pool could have been deployed by it
for factoryAddr, factoryName := range pv.trustedFactories {
if pv.verifyFactoryDeployment(factoryAddr, factoryName, poolAddr, result) {
result.Factory = factoryName
return nil
}
}
return fmt.Errorf("pool not deployed by any trusted factory")
}
// verifyFactoryDeployment verifies a specific factory deployed the pool
func (pv *PoolValidator) verifyFactoryDeployment(factoryAddr common.Address, factoryName string, poolAddr common.Address, result *ValidationResult) bool {
if result.Token0 == (common.Address{}) || result.Token1 == (common.Address{}) {
return false
}
// Use CREATE2 calculator to verify the pool address
return pv.create2Calculator.ValidatePoolAddress(factoryName, result.Token0, result.Token1, result.Fee, poolAddr)
}
// validateTokenContracts validates the token contracts in the pool
func (pv *PoolValidator) validateTokenContracts(ctx context.Context, result *ValidationResult) error {
if result.Token0 == (common.Address{}) || result.Token1 == (common.Address{}) {
return fmt.Errorf("token addresses not available")
}
// Validate token0
if err := pv.validateTokenContract(ctx, result.Token0); err != nil {
return fmt.Errorf("token0 validation failed: %w", err)
}
// Validate token1
if err := pv.validateTokenContract(ctx, result.Token1); err != nil {
return fmt.Errorf("token1 validation failed: %w", err)
}
return nil
}
// validateTokenContract validates a single token contract
func (pv *PoolValidator) validateTokenContract(ctx context.Context, tokenAddr common.Address) error {
// Check if contract exists
code, err := pv.client.CodeAt(ctx, tokenAddr, nil)
if err != nil {
return fmt.Errorf("failed to get token contract code: %w", err)
}
if len(code) == 0 {
return fmt.Errorf("no contract code at token address %s", tokenAddr.Hex())
}
// Try to call standard ERC20 functions
return pv.validateERC20Interface(ctx, tokenAddr)
}
// validateERC20Interface validates ERC20 token interface
func (pv *PoolValidator) validateERC20Interface(ctx context.Context, tokenAddr common.Address) error {
// Try to call totalSupply() function
erc20ABI := `[{"inputs":[],"name":"totalSupply","outputs":[{"internalType":"uint256","name":"","type":"uint256"}],"stateMutability":"view","type":"function"}]`
contractABI, err := uniswap.ParseABI(erc20ABI)
if err != nil {
return fmt.Errorf("failed to parse ERC20 ABI: %w", err)
}
callData, err := contractABI.Pack("totalSupply")
if err != nil {
return fmt.Errorf("failed to pack totalSupply call: %w", err)
}
_, err = pv.client.CallContract(ctx, ethereum.CallMsg{
To: &tokenAddr,
Data: callData,
}, nil)
if err != nil {
return fmt.Errorf("totalSupply call failed: %w", err)
}
return nil
}
// performAdditionalSecurityChecks performs various security checks
func (pv *PoolValidator) performAdditionalSecurityChecks(ctx context.Context, poolAddr common.Address, result *ValidationResult) {
// Check contract creation time
if creationBlock := pv.getContractCreationBlock(ctx, poolAddr); creationBlock > 0 {
result.CreationBlock = creationBlock
// Warn about very new contracts
currentBlock, err := pv.client.BlockNumber(ctx)
if err == nil && currentBlock-creationBlock < 100 {
result.Warnings = append(result.Warnings, "Pool is very new (< 100 blocks old)")
result.SecurityScore -= 5
}
}
// Check for common attack patterns
pv.checkForAttackPatterns(ctx, poolAddr, result)
}
// getContractCreationBlock attempts to find when the contract was created using binary search
func (pv *PoolValidator) getContractCreationBlock(ctx context.Context, addr common.Address) uint64 {
pv.logger.Debug(fmt.Sprintf("Finding creation block for contract %s", addr.Hex()))
// Get latest block number first
latestBlock, err := pv.client.BlockNumber(ctx)
if err != nil {
pv.logger.Warn(fmt.Sprintf("Failed to get latest block number: %v", err))
return 0
}
// Check if contract exists at latest block
codeAtLatest, err := pv.client.CodeAt(ctx, addr, new(big.Int).SetUint64(latestBlock))
if err != nil || len(codeAtLatest) == 0 {
pv.logger.Debug(fmt.Sprintf("Contract %s does not exist at latest block", addr.Hex()))
return 0
}
// Binary search to find creation block
// Start with a reasonable range - most pools created in last 10M blocks
searchStart := uint64(0)
if latestBlock > 10000000 {
searchStart = latestBlock - 10000000
}
creationBlock := pv.binarySearchCreationBlock(ctx, addr, searchStart, latestBlock)
if creationBlock > 0 {
pv.logger.Debug(fmt.Sprintf("Contract %s created at block %d", addr.Hex(), creationBlock))
}
return creationBlock
}
// binarySearchCreationBlock performs binary search to find the exact creation block
func (pv *PoolValidator) binarySearchCreationBlock(ctx context.Context, addr common.Address, start, end uint64) uint64 {
// Limit search iterations to prevent infinite loops
maxIterations := 50
iteration := 0
for start <= end && iteration < maxIterations {
iteration++
mid := (start + end) / 2
// Check if contract exists at mid block
code, err := pv.client.CodeAt(ctx, addr, new(big.Int).SetUint64(mid))
if err != nil {
pv.logger.Debug(fmt.Sprintf("Error checking code at block %d: %v", mid, err))
break
}
hasCode := len(code) > 0
if hasCode {
// Contract exists at mid, check if it exists at mid-1
if mid == 0 {
return mid
}
prevCode, err := pv.client.CodeAt(ctx, addr, new(big.Int).SetUint64(mid-1))
if err != nil || len(prevCode) == 0 {
// Contract doesn't exist at mid-1 but exists at mid
return mid
}
// Contract exists at both mid and mid-1, search earlier
end = mid - 1
} else {
// Contract doesn't exist at mid, search later
start = mid + 1
}
// Add small delay to avoid rate limiting
if iteration%10 == 0 {
select {
case <-ctx.Done():
return 0
case <-time.After(100 * time.Millisecond):
}
}
}
// If we couldn't find exact block, return start as best estimate
if start <= end {
return start
}
return 0
}
// checkForAttackPatterns looks for common malicious patterns
func (pv *PoolValidator) checkForAttackPatterns(ctx context.Context, poolAddr common.Address, result *ValidationResult) {
// Check if contract is a proxy (may be suspicious)
if pv.isProxyContract(ctx, poolAddr) {
result.Warnings = append(result.Warnings, "Contract appears to be a proxy")
result.SecurityScore -= 10
}
// Check for unusual bytecode patterns
if pv.hasUnusualBytecode(ctx, poolAddr) {
result.Warnings = append(result.Warnings, "Unusual bytecode patterns detected")
result.SecurityScore -= 15
}
}
// isProxyContract checks if the contract is a proxy
func (pv *PoolValidator) isProxyContract(ctx context.Context, addr common.Address) bool {
code, err := pv.client.CodeAt(ctx, addr, nil)
if err != nil || len(code) == 0 {
return false
}
// Look for common proxy patterns (delegatecall, etc.)
codeHex := hex.EncodeToString(code)
return strings.Contains(codeHex, "f4") // delegatecall opcode
}
// hasUnusualBytecode checks for suspicious bytecode patterns
func (pv *PoolValidator) hasUnusualBytecode(ctx context.Context, addr common.Address) bool {
code, err := pv.client.CodeAt(ctx, addr, nil)
if err != nil || len(code) == 0 {
return false
}
// Check for unusual patterns
if len(code) > 50000 {
return true // Unusually large contract
}
// Check for high entropy (potential obfuscation)
entropy := pv.calculateEntropy(code)
return entropy > 7.5 // High entropy threshold
}
// calculateEntropy calculates Shannon entropy of bytecode
func (pv *PoolValidator) calculateEntropy(data []byte) float64 {
if len(data) == 0 {
return 0
}
freq := make(map[byte]int)
for _, b := range data {
freq[b]++
}
entropy := 0.0
length := float64(len(data))
for _, f := range freq {
p := float64(f) / length
if p > 0 {
entropy -= p * (float64(f) / length)
}
}
return entropy
}
// Helper functions to get pool information
func (pv *PoolValidator) getUniswapV3PoolInfo(ctx context.Context, poolAddr common.Address) (common.Address, common.Address, uint32, error) {
poolABI := `[{"inputs":[],"name":"token0","outputs":[{"internalType":"address","name":"","type":"address"}],"stateMutability":"view","type":"function"},{"inputs":[],"name":"token1","outputs":[{"internalType":"address","name":"","type":"address"}],"stateMutability":"view","type":"function"},{"inputs":[],"name":"fee","outputs":[{"internalType":"uint24","name":"","type":"uint24"}],"stateMutability":"view","type":"function"}]`
contractABI, err := uniswap.ParseABI(poolABI)
if err != nil {
return common.Address{}, common.Address{}, 0, err
}
// Get token0
token0Data, err := contractABI.Pack("token0")
if err != nil {
return common.Address{}, common.Address{}, 0, err
}
token0Result, err := pv.client.CallContract(ctx, ethereum.CallMsg{
To: &poolAddr,
Data: token0Data,
}, nil)
if err != nil {
return common.Address{}, common.Address{}, 0, err
}
token0Unpacked, err := contractABI.Unpack("token0", token0Result)
if err != nil {
return common.Address{}, common.Address{}, 0, err
}
// Get token1
token1Data, err := contractABI.Pack("token1")
if err != nil {
return common.Address{}, common.Address{}, 0, err
}
token1Result, err := pv.client.CallContract(ctx, ethereum.CallMsg{
To: &poolAddr,
Data: token1Data,
}, nil)
if err != nil {
return common.Address{}, common.Address{}, 0, err
}
token1Unpacked, err := contractABI.Unpack("token1", token1Result)
if err != nil {
return common.Address{}, common.Address{}, 0, err
}
// Get fee
feeData, err := contractABI.Pack("fee")
if err != nil {
return common.Address{}, common.Address{}, 0, err
}
feeResult, err := pv.client.CallContract(ctx, ethereum.CallMsg{
To: &poolAddr,
Data: feeData,
}, nil)
if err != nil {
return common.Address{}, common.Address{}, 0, err
}
feeUnpacked, err := contractABI.Unpack("fee", feeResult)
if err != nil {
return common.Address{}, common.Address{}, 0, err
}
token0 := token0Unpacked[0].(common.Address)
token1 := token1Unpacked[0].(common.Address)
fee := feeUnpacked[0].(*big.Int).Uint64()
feeUint32, err := security.SafeUint32(fee)
if err != nil {
return common.Address{}, common.Address{}, 0, fmt.Errorf("invalid fee conversion: %w", err)
}
return token0, token1, feeUint32, nil
}
func (pv *PoolValidator) getUniswapV2PoolInfo(ctx context.Context, poolAddr common.Address) (common.Address, common.Address, error) {
poolABI := `[{"inputs":[],"name":"token0","outputs":[{"internalType":"address","name":"","type":"address"}],"stateMutability":"view","type":"function"},{"inputs":[],"name":"token1","outputs":[{"internalType":"address","name":"","type":"address"}],"stateMutability":"view","type":"function"}]`
contractABI, err := uniswap.ParseABI(poolABI)
if err != nil {
return common.Address{}, common.Address{}, err
}
// Get token0
token0Data, err := contractABI.Pack("token0")
if err != nil {
return common.Address{}, common.Address{}, err
}
token0Result, err := pv.client.CallContract(ctx, ethereum.CallMsg{
To: &poolAddr,
Data: token0Data,
}, nil)
if err != nil {
return common.Address{}, common.Address{}, err
}
token0Unpacked, err := contractABI.Unpack("token0", token0Result)
if err != nil {
return common.Address{}, common.Address{}, err
}
// Get token1
token1Data, err := contractABI.Pack("token1")
if err != nil {
return common.Address{}, common.Address{}, err
}
token1Result, err := pv.client.CallContract(ctx, ethereum.CallMsg{
To: &poolAddr,
Data: token1Data,
}, nil)
if err != nil {
return common.Address{}, common.Address{}, err
}
token1Unpacked, err := contractABI.Unpack("token1", token1Result)
if err != nil {
return common.Address{}, common.Address{}, err
}
token0 := token0Unpacked[0].(common.Address)
token1 := token1Unpacked[0].(common.Address)
return token0, token1, nil
}
// Configuration and caching methods
func (pv *PoolValidator) getDefaultConfig() *PoolValidationConfig {
return &PoolValidationConfig{
RequireFactoryVerification: true,
MinSecurityScore: 70,
MaxValidationTime: 10 * time.Second,
AllowUnknownFactories: false,
RequireTokenValidation: true,
}
}
func (pv *PoolValidator) getCachedResult(addr common.Address) *ValidationResult {
if result, exists := pv.validationCache[addr]; exists {
if time.Since(result.ValidatedAt) < pv.cacheTimeout {
return result
}
delete(pv.validationCache, addr)
}
return nil
}
func (pv *PoolValidator) cacheResult(addr common.Address, result *ValidationResult) {
pv.validationCache[addr] = result
}
// Initialization methods
func (pv *PoolValidator) initializeTrustedFactories() {
// Uniswap V3
pv.trustedFactories[common.HexToAddress("0x1F98431c8aD98523631AE4a59f267346ea31F984")] = "uniswap_v3"
// Uniswap V2
pv.trustedFactories[common.HexToAddress("0x5C69bEe701ef814a2B6a3EDD4B1652CB9cc5aA6f")] = "uniswap_v2"
// SushiSwap
pv.trustedFactories[common.HexToAddress("0xC0AEe478e3658e2610c5F7A4A2E1777cE9e4f2Ac")] = "sushiswap"
// Camelot V3 (Arbitrum)
pv.trustedFactories[common.HexToAddress("0x1a3c9B1d2F0529D97f2afC5136Cc23e58f1FD35B")] = "camelot_v3"
}
func (pv *PoolValidator) initializeBannedAddresses() {
// Add known malicious or problematic pool addresses
// This would be populated with real banned addresses in production
}
// AddTrustedFactory adds a new trusted factory
func (pv *PoolValidator) AddTrustedFactory(factoryAddr common.Address, name string) {
pv.trustedFactories[factoryAddr] = name
pv.logger.Info(fmt.Sprintf("Added trusted factory: %s (%s)", name, factoryAddr.Hex()))
}
// BanAddress adds an address to the banned list
func (pv *PoolValidator) BanAddress(addr common.Address, reason string) {
pv.bannedAddresses[addr] = reason
// Clear from cache if present
delete(pv.validationCache, addr)
pv.logger.Warn(fmt.Sprintf("Banned address: %s (reason: %s)", addr.Hex(), reason))
}

File diff suppressed because it is too large Load Diff

View File

@@ -1,265 +0,0 @@
package validation
import (
"fmt"
"math/big"
)
// PriceImpactThresholds defines the acceptable price impact levels
type PriceImpactThresholds struct {
// Low risk: < 0.5% price impact
LowThreshold float64
// Medium risk: 0.5% - 2% price impact
MediumThreshold float64
// High risk: 2% - 5% price impact
HighThreshold float64
// Extreme risk: > 5% price impact (typically unprofitable due to slippage)
ExtremeThreshold float64
// Maximum acceptable: Reject anything above this (e.g., 10%)
MaxAcceptable float64
}
// DefaultPriceImpactThresholds returns conservative production-ready thresholds
func DefaultPriceImpactThresholds() *PriceImpactThresholds {
return &PriceImpactThresholds{
LowThreshold: 0.5, // 0.5%
MediumThreshold: 2.0, // 2%
HighThreshold: 5.0, // 5%
ExtremeThreshold: 10.0, // 10%
MaxAcceptable: 15.0, // 15% - reject anything higher
}
}
// AggressivePriceImpactThresholds returns more aggressive thresholds for higher volumes
func AggressivePriceImpactThresholds() *PriceImpactThresholds {
return &PriceImpactThresholds{
LowThreshold: 1.0, // 1%
MediumThreshold: 3.0, // 3%
HighThreshold: 7.0, // 7%
ExtremeThreshold: 15.0, // 15%
MaxAcceptable: 25.0, // 25%
}
}
// ConservativePriceImpactThresholds returns very conservative thresholds for safety
func ConservativePriceImpactThresholds() *PriceImpactThresholds {
return &PriceImpactThresholds{
LowThreshold: 0.1, // 0.1%
MediumThreshold: 0.5, // 0.5%
HighThreshold: 1.0, // 1%
ExtremeThreshold: 2.0, // 2%
MaxAcceptable: 5.0, // 5%
}
}
// PriceImpactRiskLevel represents the risk level of a price impact
type PriceImpactRiskLevel string
const (
RiskLevelNegligible PriceImpactRiskLevel = "Negligible" // < 0.1%
RiskLevelLow PriceImpactRiskLevel = "Low" // 0.1-0.5%
RiskLevelMedium PriceImpactRiskLevel = "Medium" // 0.5-2%
RiskLevelHigh PriceImpactRiskLevel = "High" // 2-5%
RiskLevelExtreme PriceImpactRiskLevel = "Extreme" // 5-10%
RiskLevelUnacceptable PriceImpactRiskLevel = "Unacceptable" // > 10%
)
// PriceImpactValidationResult contains the result of price impact validation
type PriceImpactValidationResult struct {
PriceImpact float64 // The calculated price impact percentage
RiskLevel PriceImpactRiskLevel // The risk categorization
IsAcceptable bool // Whether this price impact is acceptable
Recommendation string // Human-readable recommendation
Details map[string]interface{} // Additional details
}
// PriceImpactValidator validates price impacts against configured thresholds
type PriceImpactValidator struct {
thresholds *PriceImpactThresholds
}
// NewPriceImpactValidator creates a new price impact validator
func NewPriceImpactValidator(thresholds *PriceImpactThresholds) *PriceImpactValidator {
if thresholds == nil {
thresholds = DefaultPriceImpactThresholds()
}
return &PriceImpactValidator{
thresholds: thresholds,
}
}
// ValidatePriceImpact validates a price impact percentage
func (piv *PriceImpactValidator) ValidatePriceImpact(priceImpact float64) *PriceImpactValidationResult {
result := &PriceImpactValidationResult{
PriceImpact: priceImpact,
Details: make(map[string]interface{}),
}
// Determine risk level
result.RiskLevel = piv.categorizePriceImpact(priceImpact)
// Determine if acceptable
result.IsAcceptable = priceImpact <= piv.thresholds.MaxAcceptable
// Generate recommendation
result.Recommendation = piv.generateRecommendation(priceImpact, result.RiskLevel)
// Add threshold details
result.Details["thresholds"] = map[string]float64{
"low": piv.thresholds.LowThreshold,
"medium": piv.thresholds.MediumThreshold,
"high": piv.thresholds.HighThreshold,
"extreme": piv.thresholds.ExtremeThreshold,
"max": piv.thresholds.MaxAcceptable,
}
// Add risk-specific details
result.Details["risk_level"] = string(result.RiskLevel)
result.Details["acceptable"] = result.IsAcceptable
result.Details["price_impact_percent"] = priceImpact
return result
}
// categorizePriceImpact categorizes the price impact into risk levels
func (piv *PriceImpactValidator) categorizePriceImpact(priceImpact float64) PriceImpactRiskLevel {
switch {
case priceImpact < 0.1:
return RiskLevelNegligible
case priceImpact < piv.thresholds.LowThreshold:
return RiskLevelLow
case priceImpact < piv.thresholds.MediumThreshold:
return RiskLevelMedium
case priceImpact < piv.thresholds.HighThreshold:
return RiskLevelHigh
case priceImpact < piv.thresholds.ExtremeThreshold:
return RiskLevelExtreme
default:
return RiskLevelUnacceptable
}
}
// generateRecommendation generates a recommendation based on price impact
func (piv *PriceImpactValidator) generateRecommendation(priceImpact float64, riskLevel PriceImpactRiskLevel) string {
switch riskLevel {
case RiskLevelNegligible:
return fmt.Sprintf("Excellent: Price impact of %.4f%% is negligible. Safe to execute.", priceImpact)
case RiskLevelLow:
return fmt.Sprintf("Good: Price impact of %.4f%% is low. Execute with standard slippage protection.", priceImpact)
case RiskLevelMedium:
return fmt.Sprintf("Moderate: Price impact of %.4f%% is medium. Use enhanced slippage protection and consider splitting the trade.", priceImpact)
case RiskLevelHigh:
return fmt.Sprintf("Caution: Price impact of %.4f%% is high. Strongly recommend splitting into smaller trades or waiting for better liquidity.", priceImpact)
case RiskLevelExtreme:
return fmt.Sprintf("Warning: Price impact of %.4f%% is extreme. Trade size is too large for current liquidity. Split trade or skip.", priceImpact)
case RiskLevelUnacceptable:
return fmt.Sprintf("Reject: Price impact of %.4f%% exceeds maximum acceptable threshold (%.2f%%). Do not execute.", priceImpact, piv.thresholds.MaxAcceptable)
default:
return "Unknown risk level"
}
}
// ValidatePriceImpactWithLiquidity validates price impact considering trade size and liquidity
func (piv *PriceImpactValidator) ValidatePriceImpactWithLiquidity(tradeSize, liquidity *big.Int) *PriceImpactValidationResult {
if tradeSize == nil || liquidity == nil || liquidity.Sign() == 0 {
return &PriceImpactValidationResult{
PriceImpact: 0,
RiskLevel: RiskLevelUnacceptable,
IsAcceptable: false,
Recommendation: "Invalid input: trade size or liquidity is nil/zero",
Details: make(map[string]interface{}),
}
}
// Calculate price impact: tradeSize / (liquidity + tradeSize) * 100
tradeSizeFloat := new(big.Float).SetInt(tradeSize)
liquidityFloat := new(big.Float).SetInt(liquidity)
// Price impact = tradeSize / (liquidity + tradeSize)
denominator := new(big.Float).Add(liquidityFloat, tradeSizeFloat)
priceImpactRatio := new(big.Float).Quo(tradeSizeFloat, denominator)
priceImpactPercent, _ := priceImpactRatio.Float64()
priceImpactPercent *= 100.0
result := piv.ValidatePriceImpact(priceImpactPercent)
// Add liquidity-specific details
result.Details["trade_size"] = tradeSize.String()
result.Details["liquidity"] = liquidity.String()
result.Details["trade_to_liquidity_ratio"] = new(big.Float).Quo(tradeSizeFloat, liquidityFloat).Text('f', 6)
return result
}
// ShouldRejectTrade determines if a trade should be rejected based on price impact
func (piv *PriceImpactValidator) ShouldRejectTrade(priceImpact float64) bool {
return priceImpact > piv.thresholds.MaxAcceptable
}
// ShouldSplitTrade determines if a trade should be split based on price impact
func (piv *PriceImpactValidator) ShouldSplitTrade(priceImpact float64) bool {
return priceImpact >= piv.thresholds.MediumThreshold
}
// GetRecommendedSplitCount recommends how many parts to split a trade into
func (piv *PriceImpactValidator) GetRecommendedSplitCount(priceImpact float64) int {
switch {
case priceImpact < piv.thresholds.MediumThreshold:
return 1 // No split needed
case priceImpact < piv.thresholds.HighThreshold:
return 2 // Split into 2
case priceImpact < piv.thresholds.ExtremeThreshold:
return 4 // Split into 4
case priceImpact < piv.thresholds.MaxAcceptable:
return 8 // Split into 8
default:
return 0 // Reject trade
}
}
// CalculateMaxTradeSize calculates the maximum trade size for a given price impact target
func (piv *PriceImpactValidator) CalculateMaxTradeSize(liquidity *big.Int, targetPriceImpact float64) *big.Int {
if liquidity == nil || liquidity.Sign() == 0 {
return big.NewInt(0)
}
// From: priceImpact = tradeSize / (liquidity + tradeSize)
// Solve for tradeSize: tradeSize = (priceImpact * liquidity) / (1 - priceImpact)
priceImpactDecimal := targetPriceImpact / 100.0
if priceImpactDecimal >= 1.0 {
return big.NewInt(0) // Invalid: 100% price impact or more
}
liquidityFloat := new(big.Float).SetInt(liquidity)
priceImpactFloat := big.NewFloat(priceImpactDecimal)
// numerator = priceImpact * liquidity
numerator := new(big.Float).Mul(priceImpactFloat, liquidityFloat)
// denominator = 1 - priceImpact
denominator := new(big.Float).Sub(big.NewFloat(1.0), priceImpactFloat)
// maxTradeSize = numerator / denominator
maxTradeSize := new(big.Float).Quo(numerator, denominator)
result, _ := maxTradeSize.Int(nil)
return result
}
// GetThresholds returns the current threshold configuration
func (piv *PriceImpactValidator) GetThresholds() *PriceImpactThresholds {
return piv.thresholds
}
// SetThresholds updates the threshold configuration
func (piv *PriceImpactValidator) SetThresholds(thresholds *PriceImpactThresholds) {
if thresholds != nil {
piv.thresholds = thresholds
}
}
// FormatPriceImpact formats a price impact value for display
func FormatPriceImpact(priceImpact float64) string {
return fmt.Sprintf("%.4f%%", priceImpact)
}

View File

@@ -1,242 +0,0 @@
package validation
import (
"math/big"
"testing"
)
func TestDefaultPriceImpactThresholds(t *testing.T) {
thresholds := DefaultPriceImpactThresholds()
tests := []struct {
name string
value float64
expected float64
}{
{"Low threshold", thresholds.LowThreshold, 0.5},
{"Medium threshold", thresholds.MediumThreshold, 2.0},
{"High threshold", thresholds.HighThreshold, 5.0},
{"Extreme threshold", thresholds.ExtremeThreshold, 10.0},
{"Max acceptable", thresholds.MaxAcceptable, 15.0},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if tt.value != tt.expected {
t.Errorf("%s = %v, want %v", tt.name, tt.value, tt.expected)
}
})
}
}
func TestCategorizePriceImpact(t *testing.T) {
validator := NewPriceImpactValidator(DefaultPriceImpactThresholds())
tests := []struct {
name string
priceImpact float64
expectedLevel PriceImpactRiskLevel
}{
{"Negligible impact", 0.05, RiskLevelNegligible},
{"Low impact", 0.3, RiskLevelLow},
{"Medium impact", 1.0, RiskLevelMedium},
{"High impact", 3.0, RiskLevelHigh},
{"Extreme impact", 7.0, RiskLevelExtreme},
{"Unacceptable impact", 20.0, RiskLevelUnacceptable},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := validator.ValidatePriceImpact(tt.priceImpact)
if result.RiskLevel != tt.expectedLevel {
t.Errorf("Risk level = %v, want %v", result.RiskLevel, tt.expectedLevel)
}
})
}
}
func TestShouldRejectTrade(t *testing.T) {
validator := NewPriceImpactValidator(DefaultPriceImpactThresholds())
tests := []struct {
name string
priceImpact float64
shouldReject bool
}{
{"Low impact - accept", 0.5, false},
{"Medium impact - accept", 2.0, false},
{"High impact - accept", 5.0, false},
{"Extreme impact - accept", 10.0, false},
{"At max threshold - accept", 15.0, false},
{"Above max threshold - reject", 15.1, true},
{"Very high - reject", 30.0, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := validator.ShouldRejectTrade(tt.priceImpact)
if result != tt.shouldReject {
t.Errorf("ShouldRejectTrade(%v) = %v, want %v", tt.priceImpact, result, tt.shouldReject)
}
})
}
}
func TestShouldSplitTrade(t *testing.T) {
validator := NewPriceImpactValidator(DefaultPriceImpactThresholds())
tests := []struct {
name string
priceImpact float64
shouldSplit bool
}{
{"Negligible - no split", 0.1, false},
{"Low - no split", 0.5, false},
{"Just below medium - no split", 1.9, false},
{"At medium threshold - split", 2.0, true},
{"High - split", 5.0, true},
{"Extreme - split", 10.0, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := validator.ShouldSplitTrade(tt.priceImpact)
if result != tt.shouldSplit {
t.Errorf("ShouldSplitTrade(%v) = %v, want %v", tt.priceImpact, result, tt.shouldSplit)
}
})
}
}
func TestGetRecommendedSplitCount(t *testing.T) {
validator := NewPriceImpactValidator(DefaultPriceImpactThresholds())
tests := []struct {
name string
priceImpact float64
expectedSplit int
}{
{"Low impact - no split", 0.5, 1},
{"Medium impact - split in 2", 2.5, 2},
{"High impact - split in 4", 6.0, 4},
{"Extreme impact - split in 8", 12.0, 8},
{"Unacceptable - reject (0)", 20.0, 0},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := validator.GetRecommendedSplitCount(tt.priceImpact)
if result != tt.expectedSplit {
t.Errorf("GetRecommendedSplitCount(%v) = %v, want %v", tt.priceImpact, result, tt.expectedSplit)
}
})
}
}
func TestCalculateMaxTradeSize(t *testing.T) {
validator := NewPriceImpactValidator(DefaultPriceImpactThresholds())
liquidity := big.NewInt(1000000) // 1M units of liquidity
tests := []struct {
name string
liquidity *big.Int
targetPriceImpact float64
expectedApproximate int64 // Approximate expected value
}{
{"0.5% impact", liquidity, 0.5, 5025}, // ~0.5% of 1M
{"1% impact", liquidity, 1.0, 10101}, // ~1% of 1M
{"2% impact", liquidity, 2.0, 20408}, // ~2% of 1M
{"5% impact", liquidity, 5.0, 52631}, // ~5% of 1M
{"10% impact", liquidity, 10.0, 111111}, // ~10% of 1M
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := validator.CalculateMaxTradeSize(tt.liquidity, tt.targetPriceImpact)
// Check if result is within 5% of expected value
resultInt64 := result.Int64()
lowerBound := int64(float64(tt.expectedApproximate) * 0.95)
upperBound := int64(float64(tt.expectedApproximate) * 1.05)
if resultInt64 < lowerBound || resultInt64 > upperBound {
t.Errorf("CalculateMaxTradeSize() = %v, expected approximately %v (±5%%)", result, tt.expectedApproximate)
}
})
}
}
func TestValidatePriceImpactWithLiquidity(t *testing.T) {
validator := NewPriceImpactValidator(DefaultPriceImpactThresholds())
liquidity := big.NewInt(1000000) // 1M units
tests := []struct {
name string
tradeSize *big.Int
liquidity *big.Int
expectedRiskLevel PriceImpactRiskLevel
}{
{"Small trade", big.NewInt(1000), liquidity, RiskLevelNegligible},
{"Medium trade", big.NewInt(20000), liquidity, RiskLevelMedium},
{"Large trade", big.NewInt(100000), liquidity, RiskLevelExtreme},
{"Very large trade", big.NewInt(500000), liquidity, RiskLevelUnacceptable},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := validator.ValidatePriceImpactWithLiquidity(tt.tradeSize, tt.liquidity)
if result.RiskLevel != tt.expectedRiskLevel {
t.Errorf("Risk level = %v, want %v (price impact: %.2f%%)",
result.RiskLevel, tt.expectedRiskLevel, result.PriceImpact)
}
})
}
}
func TestConservativeThresholds(t *testing.T) {
validator := NewPriceImpactValidator(ConservativePriceImpactThresholds())
// Test that conservative thresholds are more strict
// With conservative: High=1.0%, Extreme=2.0%
// So 1.0% exactly is at the boundary and goes to Extreme
result := validator.ValidatePriceImpact(1.0)
if result.RiskLevel != RiskLevelExtreme {
t.Errorf("With conservative thresholds, 1%% should be Extreme risk, got %v", result.RiskLevel)
}
}
func TestAggressiveThresholds(t *testing.T) {
validator := NewPriceImpactValidator(AggressivePriceImpactThresholds())
// Test that aggressive thresholds are more lenient
// With aggressive: Low=1.0%, Medium=3.0%
// So 2.0% falls in the Medium range (between 1.0 and 3.0)
result := validator.ValidatePriceImpact(2.0)
if result.RiskLevel != RiskLevelMedium {
t.Errorf("With aggressive thresholds, 2%% should be Medium risk, got %v", result.RiskLevel)
}
}
func BenchmarkValidatePriceImpact(b *testing.B) {
validator := NewPriceImpactValidator(DefaultPriceImpactThresholds())
b.ResetTimer()
for i := 0; i < b.N; i++ {
validator.ValidatePriceImpact(2.5)
}
}
func BenchmarkValidatePriceImpactWithLiquidity(b *testing.B) {
validator := NewPriceImpactValidator(DefaultPriceImpactThresholds())
tradeSize := big.NewInt(50000)
liquidity := big.NewInt(1000000)
b.ResetTimer()
for i := 0; i < b.N; i++ {
validator.ValidatePriceImpactWithLiquidity(tradeSize, liquidity)
}
}