feat: create v2-prep branch with comprehensive planning
Restructured project for V2 refactor: **Structure Changes:** - Moved all V1 code to orig/ folder (preserved with git mv) - Created docs/planning/ directory - Added orig/README_V1.md explaining V1 preservation **Planning Documents:** - 00_V2_MASTER_PLAN.md: Complete architecture overview - Executive summary of critical V1 issues - High-level component architecture diagrams - 5-phase implementation roadmap - Success metrics and risk mitigation - 07_TASK_BREAKDOWN.md: Atomic task breakdown - 99+ hours of detailed tasks - Every task < 2 hours (atomic) - Clear dependencies and success criteria - Organized by implementation phase **V2 Key Improvements:** - Per-exchange parsers (factory pattern) - Multi-layer strict validation - Multi-index pool cache - Background validation pipeline - Comprehensive observability **Critical Issues Addressed:** - Zero address tokens (strict validation + cache enrichment) - Parsing accuracy (protocol-specific parsers) - No audit trail (background validation channel) - Inefficient lookups (multi-index cache) - Stats disconnection (event-driven metrics) Next Steps: 1. Review planning documents 2. Begin Phase 1: Foundation (P1-001 through P1-010) 3. Implement parsers in Phase 2 4. Build cache system in Phase 3 5. Add validation pipeline in Phase 4 6. Migrate and test in Phase 5 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
@@ -1,956 +0,0 @@
|
||||
package validation
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math"
|
||||
"math/big"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/ethereum/go-ethereum/common"
|
||||
"github.com/ethereum/go-ethereum/core/types"
|
||||
|
||||
"github.com/fraktal/mev-beta/internal/logger"
|
||||
"github.com/fraktal/mev-beta/internal/utils"
|
||||
"github.com/fraktal/mev-beta/pkg/security"
|
||||
)
|
||||
|
||||
// safeConvertInt64ToUint64 safely converts an int64 to uint64, ensuring no negative values
|
||||
func safeConvertInt64ToUint64(v int64) uint64 {
|
||||
if v < 0 {
|
||||
return 0
|
||||
}
|
||||
return uint64(v)
|
||||
}
|
||||
|
||||
// InputValidator provides comprehensive validation for transaction parameters and user inputs
|
||||
type InputValidator struct {
|
||||
logger *logger.Logger
|
||||
maxGasLimit uint64
|
||||
maxGasPrice *big.Int
|
||||
maxValue *big.Int
|
||||
allowedMethods map[string]bool // method signatures that are allowed
|
||||
// Regex patterns for validation
|
||||
addressPattern *regexp.Regexp
|
||||
txHashPattern *regexp.Regexp
|
||||
blockHashPattern *regexp.Regexp
|
||||
hexDataPattern *regexp.Regexp
|
||||
}
|
||||
|
||||
// ValidationConfig contains configuration for input validation
|
||||
type ValidationConfig struct {
|
||||
MaxGasLimit uint64 `json:"max_gas_limit"`
|
||||
MaxGasPriceGwei int64 `json:"max_gas_price_gwei"`
|
||||
MaxValueEther int64 `json:"max_value_ether"`
|
||||
AllowedMethods []string `json:"allowed_methods"`
|
||||
RequireDeadline bool `json:"require_deadline"`
|
||||
MaxDeadlineHours int `json:"max_deadline_hours"`
|
||||
}
|
||||
|
||||
// TransactionValidationResult contains the result of transaction validation
|
||||
type TransactionValidationResult struct {
|
||||
IsValid bool `json:"is_valid"`
|
||||
Errors []string `json:"errors"`
|
||||
Warnings []string `json:"warnings"`
|
||||
RiskLevel string `json:"risk_level"` // "low", "medium", "high", "critical"
|
||||
EstimatedCost *big.Int `json:"estimated_cost,omitempty"`
|
||||
}
|
||||
|
||||
// SwapParams represents swap transaction parameters
|
||||
type SwapParams struct {
|
||||
TokenIn common.Address `json:"token_in"`
|
||||
TokenOut common.Address `json:"token_out"`
|
||||
AmountIn *big.Int `json:"amount_in"`
|
||||
AmountOutMinimum *big.Int `json:"amount_out_minimum"`
|
||||
Fee uint32 `json:"fee"`
|
||||
Recipient common.Address `json:"recipient"`
|
||||
Deadline uint64 `json:"deadline"`
|
||||
SlippageTolerance *big.Int `json:"slippage_tolerance"` // in basis points
|
||||
}
|
||||
|
||||
// ArbitrageParams represents arbitrage transaction parameters
|
||||
type ArbitrageParams struct {
|
||||
Path []common.Address `json:"path"`
|
||||
AmountIn *big.Int `json:"amount_in"`
|
||||
MinAmountOut *big.Int `json:"min_amount_out"`
|
||||
Deadline uint64 `json:"deadline"`
|
||||
MaxGasPrice *big.Int `json:"max_gas_price"`
|
||||
ProfitThreshold *big.Int `json:"profit_threshold"`
|
||||
MaxSlippageBps *big.Int `json:"max_slippage_bps"`
|
||||
}
|
||||
|
||||
// LiquidityParams represents liquidity provision parameters
|
||||
type LiquidityParams struct {
|
||||
Token0 common.Address `json:"token0"`
|
||||
Token1 common.Address `json:"token1"`
|
||||
Fee uint32 `json:"fee"`
|
||||
TickLower int32 `json:"tick_lower"`
|
||||
TickUpper int32 `json:"tick_upper"`
|
||||
Amount0Desired *big.Int `json:"amount0_desired"`
|
||||
Amount1Desired *big.Int `json:"amount1_desired"`
|
||||
Amount0Min *big.Int `json:"amount0_min"`
|
||||
Amount1Min *big.Int `json:"amount1_min"`
|
||||
Recipient common.Address `json:"recipient"`
|
||||
Deadline uint64 `json:"deadline"`
|
||||
}
|
||||
|
||||
// NewInputValidator creates a new input validator
|
||||
func NewInputValidator(config *ValidationConfig, logger *logger.Logger) *InputValidator {
|
||||
if config == nil {
|
||||
config = getDefaultValidationConfig()
|
||||
}
|
||||
|
||||
validator := &InputValidator{
|
||||
logger: logger,
|
||||
maxGasLimit: config.MaxGasLimit,
|
||||
maxGasPrice: big.NewInt(config.MaxGasPriceGwei * 1e9), // Convert Gwei to Wei
|
||||
maxValue: big.NewInt(config.MaxValueEther * 1e18), // Convert Ether to Wei
|
||||
allowedMethods: make(map[string]bool),
|
||||
addressPattern: regexp.MustCompile(`^0x[a-fA-F0-9]{40}$`),
|
||||
txHashPattern: regexp.MustCompile(`^0x[a-fA-F0-9]{64}$`),
|
||||
blockHashPattern: regexp.MustCompile(`^0x[a-fA-F0-9]{64}$`),
|
||||
hexDataPattern: regexp.MustCompile(`^0x[a-fA-F0-9]*$`),
|
||||
}
|
||||
|
||||
// Initialize allowed methods
|
||||
for _, method := range config.AllowedMethods {
|
||||
validator.allowedMethods[method] = true
|
||||
}
|
||||
|
||||
return validator
|
||||
}
|
||||
|
||||
// ValidateTransaction performs comprehensive validation of a transaction
|
||||
func (iv *InputValidator) ValidateTransaction(tx *types.Transaction) (*TransactionValidationResult, error) {
|
||||
result := &TransactionValidationResult{
|
||||
IsValid: true,
|
||||
Errors: make([]string, 0),
|
||||
Warnings: make([]string, 0),
|
||||
RiskLevel: "low",
|
||||
}
|
||||
|
||||
// 0. Early check for nil or malformed transactions
|
||||
if tx == nil {
|
||||
result.IsValid = false
|
||||
result.Errors = append(result.Errors, "transaction is nil")
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// Skip validation for known problematic transactions to reduce log spam
|
||||
txHash := tx.Hash().Hex()
|
||||
if iv.isKnownProblematicTransaction(txHash) {
|
||||
result.IsValid = false
|
||||
// Don't add to errors to avoid logging spam
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// 1. Basic transaction validation
|
||||
iv.validateBasicTransaction(tx, result)
|
||||
|
||||
// 2. Gas validation
|
||||
iv.validateGas(tx, result)
|
||||
|
||||
// 3. Value validation
|
||||
iv.validateValue(tx, result)
|
||||
|
||||
// 4. Recipient validation
|
||||
iv.validateRecipient(tx, result)
|
||||
|
||||
// 5. Data validation (for contract calls)
|
||||
if len(tx.Data()) > 0 {
|
||||
iv.validateData(tx.Data(), result)
|
||||
}
|
||||
|
||||
// 6. Calculate estimated cost
|
||||
result.EstimatedCost = iv.calculateEstimatedCost(tx)
|
||||
|
||||
// 7. Determine final validity and risk level
|
||||
iv.finalizeValidation(result)
|
||||
|
||||
if len(result.Errors) > 0 {
|
||||
iv.logger.Warn(fmt.Sprintf("Transaction validation failed: %v", result.Errors))
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// ValidateSwapParams validates swap transaction parameters
|
||||
func (iv *InputValidator) ValidateSwapParams(params *SwapParams) (*TransactionValidationResult, error) {
|
||||
result := &TransactionValidationResult{
|
||||
IsValid: true,
|
||||
Errors: make([]string, 0),
|
||||
Warnings: make([]string, 0),
|
||||
RiskLevel: "low",
|
||||
}
|
||||
|
||||
// 1. Validate token addresses
|
||||
if err := iv.ValidateAddress(params.TokenIn); err != nil {
|
||||
result.Errors = append(result.Errors, fmt.Sprintf("Invalid token_in address: %v", err))
|
||||
}
|
||||
|
||||
if err := iv.ValidateAddress(params.TokenOut); err != nil {
|
||||
result.Errors = append(result.Errors, fmt.Sprintf("Invalid token_out address: %v", err))
|
||||
}
|
||||
|
||||
// 2. Check tokens are different
|
||||
if params.TokenIn == params.TokenOut {
|
||||
result.Errors = append(result.Errors, "token_in and token_out must be different")
|
||||
}
|
||||
|
||||
// 3. Validate amounts
|
||||
if err := iv.ValidateAmount(params.AmountIn); err != nil {
|
||||
result.Errors = append(result.Errors, fmt.Sprintf("Invalid amount_in: %v", err))
|
||||
}
|
||||
|
||||
if err := iv.ValidateAmount(params.AmountOutMinimum); err != nil {
|
||||
result.Errors = append(result.Errors, fmt.Sprintf("Invalid amount_out_minimum: %v", err))
|
||||
}
|
||||
|
||||
// 4. Validate slippage tolerance
|
||||
if err := iv.ValidateSlippage(params.SlippageTolerance); err != nil {
|
||||
result.Errors = append(result.Errors, fmt.Sprintf("Invalid slippage tolerance: %v", err))
|
||||
}
|
||||
|
||||
// 5. Validate fee tier
|
||||
if err := iv.validateFeeTier(params.Fee); err != nil {
|
||||
result.Errors = append(result.Errors, fmt.Sprintf("Invalid fee tier: %v", err))
|
||||
}
|
||||
|
||||
// 6. Validate recipient
|
||||
if err := iv.ValidateAddress(params.Recipient); err != nil {
|
||||
result.Errors = append(result.Errors, fmt.Sprintf("Invalid recipient address: %v", err))
|
||||
}
|
||||
|
||||
// 7. Validate deadline
|
||||
if err := iv.validateDeadline(params.Deadline); err != nil {
|
||||
result.Errors = append(result.Errors, fmt.Sprintf("Invalid deadline: %v", err))
|
||||
}
|
||||
|
||||
// 8. Additional security checks
|
||||
iv.performSwapSecurityChecks(params, result)
|
||||
|
||||
iv.finalizeValidation(result)
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// ValidateArbitrageParams validates arbitrage transaction parameters
|
||||
func (iv *InputValidator) ValidateArbitrageParams(params *ArbitrageParams) (*TransactionValidationResult, error) {
|
||||
result := &TransactionValidationResult{
|
||||
IsValid: true,
|
||||
Errors: make([]string, 0),
|
||||
Warnings: make([]string, 0),
|
||||
RiskLevel: "medium", // Arbitrage is inherently riskier
|
||||
}
|
||||
|
||||
// 1. Validate path
|
||||
if len(params.Path) < 2 {
|
||||
result.Errors = append(result.Errors, "arbitrage path must have at least 2 tokens")
|
||||
}
|
||||
|
||||
if len(params.Path) > 5 {
|
||||
result.Warnings = append(result.Warnings, "long arbitrage paths increase gas costs and slippage")
|
||||
}
|
||||
|
||||
for i, addr := range params.Path {
|
||||
if err := iv.ValidateAddress(addr); err != nil {
|
||||
result.Errors = append(result.Errors, fmt.Sprintf("Invalid address at path[%d]: %v", i, err))
|
||||
}
|
||||
}
|
||||
|
||||
// 2. Check for duplicate tokens in path
|
||||
seen := make(map[common.Address]bool)
|
||||
for _, addr := range params.Path {
|
||||
if seen[addr] {
|
||||
result.Errors = append(result.Errors, fmt.Sprintf("Duplicate token in path: %s", addr.Hex()))
|
||||
}
|
||||
seen[addr] = true
|
||||
}
|
||||
|
||||
// 3. Validate amounts
|
||||
if err := iv.ValidateAmount(params.AmountIn); err != nil {
|
||||
result.Errors = append(result.Errors, fmt.Sprintf("Invalid amount_in: %v", err))
|
||||
}
|
||||
|
||||
if err := iv.ValidateAmount(params.MinAmountOut); err != nil {
|
||||
result.Errors = append(result.Errors, fmt.Sprintf("Invalid min_amount_out: %v", err))
|
||||
}
|
||||
|
||||
if err := iv.ValidateAmount(params.ProfitThreshold); err != nil {
|
||||
result.Errors = append(result.Errors, fmt.Sprintf("Invalid profit_threshold: %v", err))
|
||||
}
|
||||
|
||||
// 4. Validate profit expectation
|
||||
if params.MinAmountOut.Cmp(params.AmountIn) <= 0 {
|
||||
result.Errors = append(result.Errors, "min_amount_out must be greater than amount_in for profitable arbitrage")
|
||||
}
|
||||
|
||||
// 5. Validate slippage
|
||||
if err := iv.ValidateSlippage(params.MaxSlippageBps); err != nil {
|
||||
result.Errors = append(result.Errors, fmt.Sprintf("Invalid max_slippage: %v", err))
|
||||
}
|
||||
|
||||
// 6. Validate gas price
|
||||
if params.MaxGasPrice != nil && params.MaxGasPrice.Cmp(iv.maxGasPrice) > 0 {
|
||||
result.Warnings = append(result.Warnings, "very high gas price may eat into profits")
|
||||
}
|
||||
|
||||
// 7. Validate deadline
|
||||
if err := iv.validateDeadline(params.Deadline); err != nil {
|
||||
result.Errors = append(result.Errors, fmt.Sprintf("Invalid deadline: %v", err))
|
||||
}
|
||||
|
||||
iv.finalizeValidation(result)
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// ValidateAddress validates an Ethereum address
|
||||
func (iv *InputValidator) ValidateAddress(addr common.Address) error {
|
||||
if addr == (common.Address{}) {
|
||||
return fmt.Errorf("address cannot be zero")
|
||||
}
|
||||
|
||||
// Check format using regex
|
||||
if !iv.addressPattern.MatchString(addr.Hex()) {
|
||||
return fmt.Errorf("invalid address format")
|
||||
}
|
||||
|
||||
// Check for common invalid addresses
|
||||
if iv.isKnownInvalidAddress(addr) {
|
||||
return fmt.Errorf("address is known to be invalid or malicious")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ValidateAmount validates a big.Int amount
|
||||
func (iv *InputValidator) ValidateAmount(amount *big.Int) error {
|
||||
if amount == nil {
|
||||
return fmt.Errorf("amount cannot be nil")
|
||||
}
|
||||
|
||||
if amount.Sign() < 0 {
|
||||
return fmt.Errorf("amount cannot be negative")
|
||||
}
|
||||
|
||||
if amount.Sign() == 0 {
|
||||
return fmt.Errorf("amount cannot be zero")
|
||||
}
|
||||
|
||||
// Check for unreasonably large amounts (prevent overflow attacks)
|
||||
maxAmount := new(big.Int).Exp(big.NewInt(10), big.NewInt(30), nil) // 10^30 wei
|
||||
if amount.Cmp(maxAmount) > 0 {
|
||||
return fmt.Errorf("amount exceeds maximum allowed value")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ValidateSlippage validates slippage tolerance in basis points
|
||||
func (iv *InputValidator) ValidateSlippage(slippageBps *big.Int) error {
|
||||
if slippageBps == nil {
|
||||
return fmt.Errorf("slippage cannot be nil")
|
||||
}
|
||||
|
||||
if slippageBps.Sign() < 0 {
|
||||
return fmt.Errorf("slippage cannot be negative")
|
||||
}
|
||||
|
||||
// Maximum 50% slippage (5000 basis points)
|
||||
maxSlippage := big.NewInt(5000)
|
||||
if slippageBps.Cmp(maxSlippage) > 0 {
|
||||
return fmt.Errorf("slippage tolerance cannot exceed 50%%")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// validateBasicTransaction validates basic transaction properties
|
||||
func (iv *InputValidator) validateBasicTransaction(tx *types.Transaction, result *TransactionValidationResult) {
|
||||
// Check nonce
|
||||
if tx.Nonce() > 1000000 {
|
||||
result.Warnings = append(result.Warnings, "unusually high nonce")
|
||||
}
|
||||
|
||||
// Check transaction size
|
||||
txSize := len(tx.Data()) + 200 // approximate overhead
|
||||
if txSize > 128*1024 { // 128KB limit
|
||||
result.Errors = append(result.Errors, "transaction size exceeds limit")
|
||||
}
|
||||
}
|
||||
|
||||
// validateGas validates gas-related parameters
|
||||
func (iv *InputValidator) validateGas(tx *types.Transaction, result *TransactionValidationResult) {
|
||||
// Validate gas limit
|
||||
if tx.Gas() == 0 {
|
||||
result.Errors = append(result.Errors, "gas limit cannot be zero")
|
||||
}
|
||||
|
||||
if tx.Gas() > iv.maxGasLimit {
|
||||
result.Errors = append(result.Errors, fmt.Sprintf("gas limit %d exceeds maximum %d", tx.Gas(), iv.maxGasLimit))
|
||||
}
|
||||
|
||||
// Validate gas price
|
||||
if tx.GasPrice() != nil {
|
||||
if tx.GasPrice().Sign() == 0 {
|
||||
result.Errors = append(result.Errors, "gas price cannot be zero")
|
||||
}
|
||||
|
||||
if tx.GasPrice().Cmp(iv.maxGasPrice) > 0 {
|
||||
result.Errors = append(result.Errors, fmt.Sprintf("gas price exceeds maximum"))
|
||||
}
|
||||
|
||||
// Warn about very high gas prices
|
||||
highGasPrice := new(big.Int).Mul(big.NewInt(100), big.NewInt(1e9)) // 100 Gwei
|
||||
if tx.GasPrice().Cmp(highGasPrice) > 0 {
|
||||
result.Warnings = append(result.Warnings, "very high gas price")
|
||||
result.RiskLevel = "medium"
|
||||
}
|
||||
}
|
||||
|
||||
// Validate gas fee cap and tip for EIP-1559 transactions
|
||||
if tx.GasFeeCap() != nil {
|
||||
if tx.GasFeeCap().Sign() == 0 {
|
||||
result.Errors = append(result.Errors, "gas fee cap cannot be zero")
|
||||
}
|
||||
|
||||
if tx.GasFeeCap().Cmp(iv.maxGasPrice) > 0 {
|
||||
result.Errors = append(result.Errors, "gas fee cap exceeds maximum")
|
||||
}
|
||||
}
|
||||
|
||||
if tx.GasTipCap() != nil {
|
||||
if tx.GasTipCap().Sign() < 0 {
|
||||
result.Errors = append(result.Errors, "gas tip cap cannot be negative")
|
||||
}
|
||||
|
||||
if tx.GasFeeCap() != nil && tx.GasTipCap().Cmp(tx.GasFeeCap()) > 0 {
|
||||
result.Errors = append(result.Errors, "gas tip cap cannot exceed gas fee cap")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// validateValue validates the transaction value
|
||||
func (iv *InputValidator) validateValue(tx *types.Transaction, result *TransactionValidationResult) {
|
||||
if tx.Value() == nil {
|
||||
return
|
||||
}
|
||||
|
||||
if tx.Value().Sign() < 0 {
|
||||
result.Errors = append(result.Errors, "transaction value cannot be negative")
|
||||
}
|
||||
|
||||
if tx.Value().Cmp(iv.maxValue) > 0 {
|
||||
result.Errors = append(result.Errors, "transaction value exceeds maximum allowed")
|
||||
}
|
||||
|
||||
// Warn about large value transfers
|
||||
largeValue := new(big.Int).Mul(big.NewInt(10), big.NewInt(1e18)) // 10 ETH
|
||||
if tx.Value().Cmp(largeValue) > 0 {
|
||||
result.Warnings = append(result.Warnings, "large value transfer")
|
||||
result.RiskLevel = "high"
|
||||
}
|
||||
}
|
||||
|
||||
// validateRecipient validates the transaction recipient
|
||||
func (iv *InputValidator) validateRecipient(tx *types.Transaction, result *TransactionValidationResult) {
|
||||
if tx.To() == nil {
|
||||
// Contract creation transaction
|
||||
result.Warnings = append(result.Warnings, "contract creation transaction")
|
||||
result.RiskLevel = "high"
|
||||
return
|
||||
}
|
||||
|
||||
// Check for zero address
|
||||
if *tx.To() == (common.Address{}) {
|
||||
result.Errors = append(result.Errors, "recipient cannot be zero address")
|
||||
}
|
||||
|
||||
// Check for known malicious addresses
|
||||
if iv.isKnownInvalidAddress(*tx.To()) {
|
||||
result.Errors = append(result.Errors, "recipient is known malicious address")
|
||||
}
|
||||
}
|
||||
|
||||
// validateData validates transaction data for contract calls
|
||||
func (iv *InputValidator) validateData(data []byte, result *TransactionValidationResult) {
|
||||
if len(data) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
if len(data) < 4 {
|
||||
result.Errors = append(result.Errors, "invalid function call data")
|
||||
return
|
||||
}
|
||||
|
||||
// Extract function selector
|
||||
selector := data[:4]
|
||||
methodSig := fmt.Sprintf("0x%x", selector)
|
||||
|
||||
// Check if method is allowed
|
||||
if len(iv.allowedMethods) > 0 && !iv.allowedMethods[methodSig] {
|
||||
result.Errors = append(result.Errors, fmt.Sprintf("method %s not allowed", methodSig))
|
||||
}
|
||||
|
||||
// Check for suspicious patterns
|
||||
if iv.hasSuspiciousPatterns(data) {
|
||||
result.Warnings = append(result.Warnings, "transaction data contains suspicious patterns")
|
||||
result.RiskLevel = "high"
|
||||
}
|
||||
}
|
||||
|
||||
// validateFeeTier validates Uniswap V3 fee tiers
|
||||
func (iv *InputValidator) validateFeeTier(fee uint32) error {
|
||||
validFees := []uint32{100, 500, 3000, 10000} // 0.01%, 0.05%, 0.3%, 1%
|
||||
for _, validFee := range validFees {
|
||||
if fee == validFee {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
return fmt.Errorf("invalid fee tier: %d (must be one of: 100, 500, 3000, 10000)", fee)
|
||||
}
|
||||
|
||||
// validateDeadline validates transaction deadline
|
||||
func (iv *InputValidator) validateDeadline(deadline uint64) error {
|
||||
if deadline == 0 {
|
||||
return fmt.Errorf("deadline cannot be zero")
|
||||
}
|
||||
|
||||
now := safeConvertInt64ToUint64(time.Now().Unix())
|
||||
if deadline <= now {
|
||||
return fmt.Errorf("deadline must be in the future")
|
||||
}
|
||||
|
||||
// Warn about very long deadlines
|
||||
maxDeadline := now + 24*60*60 // 24 hours from now
|
||||
if deadline > maxDeadline {
|
||||
return fmt.Errorf("deadline too far in future (max 24 hours)")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// performSwapSecurityChecks performs additional security checks for swap parameters
|
||||
func (iv *InputValidator) performSwapSecurityChecks(params *SwapParams, result *TransactionValidationResult) {
|
||||
// Check for sandwich attack vulnerability
|
||||
if params.SlippageTolerance != nil && params.SlippageTolerance.Cmp(big.NewInt(500)) > 0 { // >5%
|
||||
result.Warnings = append(result.Warnings, "high slippage tolerance increases sandwich attack risk")
|
||||
result.RiskLevel = "medium"
|
||||
}
|
||||
|
||||
// Check for MEV vulnerability
|
||||
if params.AmountIn != nil {
|
||||
// Large trades are more susceptible to MEV
|
||||
largeTradeThreshold := new(big.Int).Mul(big.NewInt(100), big.NewInt(1e18)) // 100 tokens
|
||||
if params.AmountIn.Cmp(largeTradeThreshold) > 0 {
|
||||
result.Warnings = append(result.Warnings, "large trade may be subject to MEV attacks")
|
||||
}
|
||||
}
|
||||
|
||||
// Check deadline proximity
|
||||
now := safeConvertInt64ToUint64(time.Now().Unix())
|
||||
if params.Deadline-now < 60 { // Less than 1 minute
|
||||
result.Warnings = append(result.Warnings, "very short deadline may cause transaction failures")
|
||||
}
|
||||
}
|
||||
|
||||
// calculateEstimatedCost estimates the total cost of a transaction
|
||||
func (iv *InputValidator) calculateEstimatedCost(tx *types.Transaction) *big.Int {
|
||||
cost := new(big.Int)
|
||||
|
||||
// Gas cost
|
||||
if tx.GasPrice() != nil {
|
||||
gasInt64, err := security.SafeUint64ToInt64(tx.Gas())
|
||||
if err != nil {
|
||||
// Log the error but use a safe fallback
|
||||
iv.logger.Error("Gas value exceeds int64 maximum", "gas", tx.Gas(), "error", err)
|
||||
gasInt64 = math.MaxInt64 // Use maximum safe value as fallback
|
||||
}
|
||||
gasCost := new(big.Int).Mul(big.NewInt(gasInt64), tx.GasPrice())
|
||||
cost.Add(cost, gasCost)
|
||||
} else if tx.GasFeeCap() != nil {
|
||||
// For EIP-1559 transactions, use fee cap as estimate
|
||||
gasInt64, err := security.SafeUint64ToInt64(tx.Gas())
|
||||
if err != nil {
|
||||
// Log the error but use a safe fallback
|
||||
iv.logger.Error("Gas value exceeds int64 maximum", "gas", tx.Gas(), "error", err)
|
||||
gasInt64 = math.MaxInt64 // Use maximum safe value as fallback
|
||||
}
|
||||
gasCost := new(big.Int).Mul(big.NewInt(gasInt64), tx.GasFeeCap())
|
||||
cost.Add(cost, gasCost)
|
||||
}
|
||||
|
||||
// Value transfer
|
||||
if tx.Value() != nil {
|
||||
cost.Add(cost, tx.Value())
|
||||
}
|
||||
|
||||
return cost
|
||||
}
|
||||
|
||||
// finalizeValidation determines final validation result
|
||||
func (iv *InputValidator) finalizeValidation(result *TransactionValidationResult) {
|
||||
if len(result.Errors) > 0 {
|
||||
result.IsValid = false
|
||||
result.RiskLevel = "critical"
|
||||
return
|
||||
}
|
||||
|
||||
// Adjust risk level based on warnings
|
||||
if len(result.Warnings) > 2 {
|
||||
if result.RiskLevel == "low" {
|
||||
result.RiskLevel = "medium"
|
||||
} else if result.RiskLevel == "medium" {
|
||||
result.RiskLevel = "high"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Helper functions
|
||||
|
||||
func (iv *InputValidator) isKnownInvalidAddress(addr common.Address) bool {
|
||||
// Check against known malicious addresses
|
||||
// This would be populated from a real blacklist in production
|
||||
maliciousAddresses := map[common.Address]bool{
|
||||
// Add known malicious addresses here
|
||||
}
|
||||
|
||||
return maliciousAddresses[addr]
|
||||
}
|
||||
|
||||
func (iv *InputValidator) isKnownProblematicTransaction(txHash string) bool {
|
||||
// List of known problematic transaction hashes that should be skipped
|
||||
problematicTxs := map[string]bool{
|
||||
"0xe79e4719c6770b41405f691c18be3346b691e220d730d6b61abb5dd3ac9d71f0": true,
|
||||
// Add other problematic transaction hashes here
|
||||
}
|
||||
|
||||
return problematicTxs[txHash]
|
||||
}
|
||||
|
||||
func (iv *InputValidator) hasSuspiciousPatterns(data []byte) bool {
|
||||
// Check for suspicious patterns in transaction data
|
||||
// This is a simplified implementation
|
||||
|
||||
// Check for self-destruct calls
|
||||
if len(data) >= 4 {
|
||||
// selfdestruct selector: 0xff
|
||||
if data[0] == 0xff {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
// Check for delegate calls to unknown addresses
|
||||
// This would require more sophisticated analysis in production
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func getDefaultValidationConfig() *ValidationConfig {
|
||||
return &ValidationConfig{
|
||||
MaxGasLimit: 10000000, // 10M gas
|
||||
MaxGasPriceGwei: 500, // 500 Gwei
|
||||
MaxValueEther: 1000, // 1000 ETH
|
||||
AllowedMethods: []string{}, // Empty means all methods allowed
|
||||
RequireDeadline: true,
|
||||
MaxDeadlineHours: 24,
|
||||
}
|
||||
}
|
||||
|
||||
// Legacy validation functions (keeping for backward compatibility)
|
||||
|
||||
// ValidateEthereumAddress validates an Ethereum address string
|
||||
func (iv *InputValidator) ValidateEthereumAddress(address string) error {
|
||||
if !iv.addressPattern.MatchString(address) {
|
||||
return fmt.Errorf("invalid Ethereum address format")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ValidateTransactionHash validates a transaction hash string
|
||||
func (iv *InputValidator) ValidateTransactionHash(hash string) error {
|
||||
if !iv.txHashPattern.MatchString(hash) {
|
||||
return fmt.Errorf("invalid transaction hash format")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ValidateBlockHash validates a block hash string
|
||||
func (iv *InputValidator) ValidateBlockHash(hash string) error {
|
||||
if !iv.blockHashPattern.MatchString(hash) {
|
||||
return fmt.Errorf("invalid block hash format")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ValidateEvent validates an event structure with comprehensive checks
|
||||
func (iv *InputValidator) ValidateEvent(event interface{}) error {
|
||||
if event == nil {
|
||||
return fmt.Errorf("event cannot be nil")
|
||||
}
|
||||
|
||||
// Use reflection to validate event structure based on type
|
||||
eventType := fmt.Sprintf("%T", event)
|
||||
iv.logger.Debug(fmt.Sprintf("Validating event of type: %s", eventType))
|
||||
|
||||
// Type-specific validation based on event structure
|
||||
switch e := event.(type) {
|
||||
case map[string]interface{}:
|
||||
return iv.validateEventMap(e)
|
||||
default:
|
||||
// For other types, perform basic structural validation
|
||||
return iv.validateEventStructure(event)
|
||||
}
|
||||
}
|
||||
|
||||
// validateEventMap validates map-based event structures
|
||||
func (iv *InputValidator) validateEventMap(eventMap map[string]interface{}) error {
|
||||
// Check for required common fields
|
||||
requiredFields := []string{"type", "timestamp"}
|
||||
for _, field := range requiredFields {
|
||||
if _, exists := eventMap[field]; !exists {
|
||||
return fmt.Errorf("missing required field: %s", field)
|
||||
}
|
||||
}
|
||||
|
||||
// Validate timestamp if present
|
||||
if timestamp, ok := eventMap["timestamp"]; ok {
|
||||
if err := iv.validateTimestamp(timestamp); err != nil {
|
||||
return fmt.Errorf("invalid timestamp: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Validate addresses if present
|
||||
addressFields := []string{"address", "token0", "token1", "pool", "sender", "recipient"}
|
||||
for _, field := range addressFields {
|
||||
if addr, exists := eventMap[field]; exists {
|
||||
if addrStr, ok := addr.(string); ok {
|
||||
// PHASE 2 FIX: Use safe address conversion
|
||||
conversionResult := utils.SafeHexToAddress(addrStr)
|
||||
if !conversionResult.IsValid {
|
||||
return fmt.Errorf("invalid address in field %s: %v", field, conversionResult.Error)
|
||||
}
|
||||
if err := iv.ValidateCommonAddress(conversionResult.Address); err != nil {
|
||||
return fmt.Errorf("invalid address in field %s: %w", field, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Validate amounts if present
|
||||
amountFields := []string{"amount", "amount0", "amount1", "amountIn", "amountOut", "value"}
|
||||
for _, field := range amountFields {
|
||||
if amount, exists := eventMap[field]; exists {
|
||||
if err := iv.validateAmount(amount); err != nil {
|
||||
return fmt.Errorf("invalid amount in field %s: %w", field, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
iv.logger.Debug("Event map validation completed successfully")
|
||||
return nil
|
||||
}
|
||||
|
||||
// validateEventStructure validates arbitrary event structures using reflection
|
||||
func (iv *InputValidator) validateEventStructure(event interface{}) error {
|
||||
// Basic structural validation
|
||||
eventStr := fmt.Sprintf("%+v", event)
|
||||
|
||||
// Check if event structure is not empty
|
||||
if len(eventStr) < 10 {
|
||||
return fmt.Errorf("event structure appears to be empty or malformed")
|
||||
}
|
||||
|
||||
// Check for common patterns that indicate valid events
|
||||
validPatterns := []string{
|
||||
"BlockNumber", "TxHash", "Address", "Token", "Amount", "Pool",
|
||||
"block", "transaction", "address", "token", "amount", "pool",
|
||||
}
|
||||
|
||||
hasValidPattern := false
|
||||
for _, pattern := range validPatterns {
|
||||
if strings.Contains(eventStr, pattern) {
|
||||
hasValidPattern = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !hasValidPattern {
|
||||
iv.logger.Warn(fmt.Sprintf("Event structure may not contain expected fields: %s", eventStr[:min(100, len(eventStr))]))
|
||||
}
|
||||
|
||||
iv.logger.Debug("Event structure validation completed")
|
||||
return nil
|
||||
}
|
||||
|
||||
// validateTimestamp validates timestamp values in various formats
|
||||
func (iv *InputValidator) validateTimestamp(timestamp interface{}) error {
|
||||
switch ts := timestamp.(type) {
|
||||
case int64:
|
||||
if ts < 0 || ts > time.Now().Unix()+86400 { // Not more than 1 day in future
|
||||
return fmt.Errorf("timestamp out of valid range")
|
||||
}
|
||||
case uint64:
|
||||
if ts > safeConvertInt64ToUint64(time.Now().Unix()+86400) { // Not more than 1 day in future
|
||||
return fmt.Errorf("timestamp out of valid range")
|
||||
}
|
||||
case time.Time:
|
||||
if ts.Before(time.Unix(0, 0)) || ts.After(time.Now().Add(24*time.Hour)) {
|
||||
return fmt.Errorf("timestamp out of valid range")
|
||||
}
|
||||
case string:
|
||||
// Try to parse as RFC3339 or Unix timestamp
|
||||
if _, err := time.Parse(time.RFC3339, ts); err != nil {
|
||||
if _, err := strconv.ParseInt(ts, 10, 64); err != nil {
|
||||
return fmt.Errorf("invalid timestamp format")
|
||||
}
|
||||
}
|
||||
default:
|
||||
return fmt.Errorf("unsupported timestamp type: %T", timestamp)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// validateAmount validates amount values in various formats
|
||||
func (iv *InputValidator) validateAmount(amount interface{}) error {
|
||||
switch a := amount.(type) {
|
||||
case *big.Int:
|
||||
if a == nil {
|
||||
return fmt.Errorf("amount cannot be nil")
|
||||
}
|
||||
if a.Sign() < 0 {
|
||||
return fmt.Errorf("amount cannot be negative")
|
||||
}
|
||||
// Check for unreasonably large amounts (> 1e30)
|
||||
maxAmount := new(big.Int).Exp(big.NewInt(10), big.NewInt(30), nil)
|
||||
if a.Cmp(maxAmount) > 0 {
|
||||
return fmt.Errorf("amount exceeds maximum allowed value")
|
||||
}
|
||||
case int64:
|
||||
if a < 0 {
|
||||
return fmt.Errorf("amount cannot be negative")
|
||||
}
|
||||
case uint64:
|
||||
// Always valid for uint64
|
||||
case string:
|
||||
if _, ok := new(big.Int).SetString(a, 10); !ok {
|
||||
return fmt.Errorf("invalid amount format")
|
||||
}
|
||||
case float64:
|
||||
if a < 0 {
|
||||
return fmt.Errorf("amount cannot be negative")
|
||||
}
|
||||
if a > 1e30 {
|
||||
return fmt.Errorf("amount exceeds maximum allowed value")
|
||||
}
|
||||
default:
|
||||
return fmt.Errorf("unsupported amount type: %T", amount)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// min returns the minimum of two integers
|
||||
func min(a, b int) int {
|
||||
if a < b {
|
||||
return a
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
||||
// ValidateHexData validates hex data string
|
||||
func (iv *InputValidator) ValidateHexData(data string) error {
|
||||
if !iv.hexDataPattern.MatchString(data) {
|
||||
return fmt.Errorf("invalid hex data format")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// SanitizeInput sanitizes string inputs to prevent injection attacks
|
||||
func SanitizeInput(input string) string {
|
||||
// Remove potentially dangerous characters
|
||||
reg := regexp.MustCompile(`[^\w\s\-\.]`)
|
||||
sanitized := reg.ReplaceAllString(input, "")
|
||||
|
||||
// Limit length
|
||||
if len(sanitized) > 1000 {
|
||||
sanitized = sanitized[:1000]
|
||||
}
|
||||
|
||||
return strings.TrimSpace(sanitized)
|
||||
}
|
||||
|
||||
// ValidateHexString validates a hex string
|
||||
func ValidateHexString(hexStr string) error {
|
||||
if !strings.HasPrefix(hexStr, "0x") {
|
||||
return fmt.Errorf("hex string must start with 0x")
|
||||
}
|
||||
|
||||
hexStr = hexStr[2:] // Remove 0x prefix
|
||||
|
||||
if len(hexStr)%2 != 0 {
|
||||
return fmt.Errorf("hex string must have even length")
|
||||
}
|
||||
|
||||
matched, err := regexp.MatchString("^[0-9a-fA-F]*$", hexStr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !matched {
|
||||
return fmt.Errorf("invalid hex characters")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ValidateCommonAddress validates an Ethereum address (common.Address type)
|
||||
func (iv *InputValidator) ValidateCommonAddress(addr common.Address) error {
|
||||
return iv.ValidateAddress(addr)
|
||||
}
|
||||
|
||||
// ValidateBigInt validates a big.Int value with context
|
||||
func (iv *InputValidator) ValidateBigInt(value *big.Int, fieldName string) error {
|
||||
if value == nil {
|
||||
return fmt.Errorf("%s cannot be nil", fieldName)
|
||||
}
|
||||
|
||||
if value.Sign() < 0 {
|
||||
return fmt.Errorf("%s cannot be negative", fieldName)
|
||||
}
|
||||
|
||||
if value.Sign() == 0 {
|
||||
return fmt.Errorf("%s cannot be zero", fieldName)
|
||||
}
|
||||
|
||||
// Check for unreasonably large values
|
||||
maxValue := new(big.Int).Exp(big.NewInt(10), big.NewInt(30), nil)
|
||||
if value.Cmp(maxValue) > 0 {
|
||||
return fmt.Errorf("%s exceeds maximum allowed value", fieldName)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ValidateSlippageTolerance validates slippage tolerance (same as ValidateSlippage)
|
||||
func (iv *InputValidator) ValidateSlippageTolerance(slippage interface{}) error {
|
||||
switch v := slippage.(type) {
|
||||
case *big.Int:
|
||||
return iv.ValidateSlippage(v)
|
||||
case float64:
|
||||
if v < 0 {
|
||||
return fmt.Errorf("slippage cannot be negative")
|
||||
}
|
||||
if v > 50.0 { // 50% maximum
|
||||
return fmt.Errorf("slippage tolerance cannot exceed 50%%")
|
||||
}
|
||||
return nil
|
||||
default:
|
||||
return fmt.Errorf("unsupported slippage type: must be *big.Int or float64")
|
||||
}
|
||||
}
|
||||
|
||||
// ValidateDeadline validates a deadline timestamp (public wrapper for validateDeadline)
|
||||
func (iv *InputValidator) ValidateDeadline(deadline uint64) error {
|
||||
return iv.validateDeadline(deadline)
|
||||
}
|
||||
@@ -1,12 +0,0 @@
|
||||
package validation
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
// Basic test to verify package compiles
|
||||
func TestValidation(t *testing.T) {
|
||||
assert.True(t, true)
|
||||
}
|
||||
@@ -1,774 +0,0 @@
|
||||
package validation
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"math/big"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/ethereum/go-ethereum"
|
||||
"github.com/ethereum/go-ethereum/common"
|
||||
"github.com/ethereum/go-ethereum/ethclient"
|
||||
|
||||
"github.com/fraktal/mev-beta/internal/logger"
|
||||
"github.com/fraktal/mev-beta/pkg/pools"
|
||||
"github.com/fraktal/mev-beta/pkg/security"
|
||||
"github.com/fraktal/mev-beta/pkg/uniswap"
|
||||
)
|
||||
|
||||
// PoolValidator provides comprehensive security validation for liquidity pools
|
||||
type PoolValidator struct {
|
||||
client *ethclient.Client
|
||||
logger *logger.Logger
|
||||
create2Calculator *pools.CREATE2Calculator
|
||||
trustedFactories map[common.Address]string // factory address -> name
|
||||
bannedAddresses map[common.Address]string // banned address -> reason
|
||||
validationCache map[common.Address]*ValidationResult
|
||||
cacheTimeout time.Duration
|
||||
}
|
||||
|
||||
// ValidationResult contains the result of pool validation
|
||||
type ValidationResult struct {
|
||||
IsValid bool `json:"is_valid"`
|
||||
SecurityScore int `json:"security_score"` // 0-100, higher is better
|
||||
Warnings []string `json:"warnings"`
|
||||
Errors []string `json:"errors"`
|
||||
PoolType string `json:"pool_type"` // "uniswap_v3", "uniswap_v2", etc.
|
||||
Factory string `json:"factory"`
|
||||
Token0 common.Address `json:"token0"`
|
||||
Token1 common.Address `json:"token1"`
|
||||
Fee uint32 `json:"fee,omitempty"`
|
||||
CreationBlock uint64 `json:"creation_block,omitempty"`
|
||||
ValidatedAt time.Time `json:"validated_at"`
|
||||
FactoryVerified bool `json:"factory_verified"`
|
||||
InterfaceValid bool `json:"interface_valid"`
|
||||
TokensValid bool `json:"tokens_valid"`
|
||||
}
|
||||
|
||||
// PoolValidationConfig contains configuration for pool validation
|
||||
type PoolValidationConfig struct {
|
||||
RequireFactoryVerification bool // Whether factory verification is mandatory
|
||||
MinSecurityScore int // Minimum security score to accept (0-100)
|
||||
MaxValidationTime time.Duration // Maximum time to spend on validation
|
||||
AllowUnknownFactories bool // Whether to allow pools from unknown factories
|
||||
RequireTokenValidation bool // Whether to validate token contracts
|
||||
}
|
||||
|
||||
// NewPoolValidator creates a new pool validator
|
||||
func NewPoolValidator(client *ethclient.Client, logger *logger.Logger) *PoolValidator {
|
||||
pv := &PoolValidator{
|
||||
client: client,
|
||||
logger: logger,
|
||||
create2Calculator: pools.NewCREATE2Calculator(logger, client),
|
||||
trustedFactories: make(map[common.Address]string),
|
||||
bannedAddresses: make(map[common.Address]string),
|
||||
validationCache: make(map[common.Address]*ValidationResult),
|
||||
cacheTimeout: 5 * time.Minute,
|
||||
}
|
||||
|
||||
pv.initializeTrustedFactories()
|
||||
pv.initializeBannedAddresses()
|
||||
|
||||
return pv
|
||||
}
|
||||
|
||||
// ValidatePool performs comprehensive security validation of a pool
|
||||
func (pv *PoolValidator) ValidatePool(ctx context.Context, poolAddr common.Address, config *PoolValidationConfig) (*ValidationResult, error) {
|
||||
if config == nil {
|
||||
config = pv.getDefaultConfig()
|
||||
}
|
||||
|
||||
// Check cache first
|
||||
if cached := pv.getCachedResult(poolAddr); cached != nil {
|
||||
return cached, nil
|
||||
}
|
||||
|
||||
// Create timeout context
|
||||
timeoutCtx, cancel := context.WithTimeout(ctx, config.MaxValidationTime)
|
||||
defer cancel()
|
||||
|
||||
result := &ValidationResult{
|
||||
ValidatedAt: time.Now(),
|
||||
SecurityScore: 0,
|
||||
Warnings: make([]string, 0),
|
||||
Errors: make([]string, 0),
|
||||
}
|
||||
|
||||
// 1. Basic existence check
|
||||
if err := pv.validateBasicExistence(timeoutCtx, poolAddr, result); err != nil {
|
||||
result.Errors = append(result.Errors, fmt.Sprintf("Basic validation failed: %v", err))
|
||||
result.IsValid = false
|
||||
pv.cacheResult(poolAddr, result)
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// 2. Check against banned addresses
|
||||
if err := pv.checkBannedAddresses(poolAddr, result); err != nil {
|
||||
result.Errors = append(result.Errors, fmt.Sprintf("Banned address check failed: %v", err))
|
||||
result.IsValid = false
|
||||
result.SecurityScore = 0
|
||||
pv.cacheResult(poolAddr, result)
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// 3. Detect pool type and validate interface
|
||||
if err := pv.validatePoolInterface(timeoutCtx, poolAddr, result); err != nil {
|
||||
result.Errors = append(result.Errors, fmt.Sprintf("Interface validation failed: %v", err))
|
||||
if config.RequireFactoryVerification {
|
||||
result.IsValid = false
|
||||
}
|
||||
result.SecurityScore -= 30
|
||||
} else {
|
||||
result.InterfaceValid = true
|
||||
result.SecurityScore += 25
|
||||
}
|
||||
|
||||
// 4. Validate factory deployment (critical security check)
|
||||
if err := pv.validateFactoryDeployment(timeoutCtx, poolAddr, result); err != nil {
|
||||
result.Errors = append(result.Errors, fmt.Sprintf("Factory validation failed: %v", err))
|
||||
if config.RequireFactoryVerification {
|
||||
result.IsValid = false
|
||||
}
|
||||
result.SecurityScore -= 40
|
||||
} else {
|
||||
result.FactoryVerified = true
|
||||
result.SecurityScore += 30
|
||||
}
|
||||
|
||||
// 5. Validate token contracts
|
||||
if config.RequireTokenValidation {
|
||||
if err := pv.validateTokenContracts(timeoutCtx, result); err != nil {
|
||||
result.Warnings = append(result.Warnings, fmt.Sprintf("Token validation warning: %v", err))
|
||||
result.SecurityScore -= 10
|
||||
} else {
|
||||
result.TokensValid = true
|
||||
result.SecurityScore += 15
|
||||
}
|
||||
}
|
||||
|
||||
// 6. Additional security checks
|
||||
pv.performAdditionalSecurityChecks(timeoutCtx, poolAddr, result)
|
||||
|
||||
// 7. Final validation decision
|
||||
if result.SecurityScore >= config.MinSecurityScore && len(result.Errors) == 0 {
|
||||
result.IsValid = true
|
||||
} else {
|
||||
result.IsValid = false
|
||||
}
|
||||
|
||||
// Ensure security score is within bounds
|
||||
if result.SecurityScore < 0 {
|
||||
result.SecurityScore = 0
|
||||
} else if result.SecurityScore > 100 {
|
||||
result.SecurityScore = 100
|
||||
}
|
||||
|
||||
// Cache the result
|
||||
pv.cacheResult(poolAddr, result)
|
||||
|
||||
pv.logger.Debug(fmt.Sprintf("Pool validation complete: %s, valid=%v, score=%d",
|
||||
poolAddr.Hex(), result.IsValid, result.SecurityScore))
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// validateBasicExistence checks if the pool contract exists and has code
|
||||
func (pv *PoolValidator) validateBasicExistence(ctx context.Context, poolAddr common.Address, result *ValidationResult) error {
|
||||
// Check if contract has code
|
||||
code, err := pv.client.CodeAt(ctx, poolAddr, nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get contract code: %w", err)
|
||||
}
|
||||
|
||||
if len(code) == 0 {
|
||||
return fmt.Errorf("no contract code at address %s", poolAddr.Hex())
|
||||
}
|
||||
|
||||
// Basic code size check - legitimate pools should have substantial code
|
||||
if len(code) < 100 {
|
||||
result.Warnings = append(result.Warnings, "Contract has very small code size")
|
||||
result.SecurityScore -= 10
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// checkBannedAddresses verifies the pool is not on the banned list
|
||||
func (pv *PoolValidator) checkBannedAddresses(poolAddr common.Address, result *ValidationResult) error {
|
||||
if reason, banned := pv.bannedAddresses[poolAddr]; banned {
|
||||
return fmt.Errorf("pool %s is banned: %s", poolAddr.Hex(), reason)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// validatePoolInterface detects pool type and validates the interface
|
||||
func (pv *PoolValidator) validatePoolInterface(ctx context.Context, poolAddr common.Address, result *ValidationResult) error {
|
||||
// Try to detect pool type by calling standard functions
|
||||
|
||||
// Check for Uniswap V3 interface
|
||||
if pv.isUniswapV3Pool(ctx, poolAddr) {
|
||||
result.PoolType = "uniswap_v3"
|
||||
return pv.validateUniswapV3Interface(ctx, poolAddr, result)
|
||||
}
|
||||
|
||||
// Check for Uniswap V2 interface
|
||||
if pv.isUniswapV2Pool(ctx, poolAddr) {
|
||||
result.PoolType = "uniswap_v2"
|
||||
return pv.validateUniswapV2Interface(ctx, poolAddr, result)
|
||||
}
|
||||
|
||||
// Unknown pool type
|
||||
result.PoolType = "unknown"
|
||||
result.Warnings = append(result.Warnings, "Unknown pool type")
|
||||
return fmt.Errorf("unknown pool interface")
|
||||
}
|
||||
|
||||
// isUniswapV3Pool checks if the pool implements Uniswap V3 interface
|
||||
func (pv *PoolValidator) isUniswapV3Pool(ctx context.Context, poolAddr common.Address) bool {
|
||||
// Try to call slot0() function (unique to Uniswap V3)
|
||||
slot0ABI := `[{"inputs":[],"name":"slot0","outputs":[{"internalType":"uint160","name":"sqrtPriceX96","type":"uint160"},{"internalType":"int24","name":"tick","type":"int24"},{"internalType":"uint16","name":"observationIndex","type":"uint16"},{"internalType":"uint16","name":"observationCardinality","type":"uint16"},{"internalType":"uint16","name":"observationCardinalityNext","type":"uint16"},{"internalType":"uint8","name":"feeProtocol","type":"uint8"},{"internalType":"bool","name":"unlocked","type":"bool"}],"stateMutability":"view","type":"function"}]`
|
||||
|
||||
contractABI, err := uniswap.ParseABI(slot0ABI)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
callData, err := contractABI.Pack("slot0")
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
_, err = pv.client.CallContract(ctx, ethereum.CallMsg{
|
||||
To: &poolAddr,
|
||||
Data: callData,
|
||||
}, nil)
|
||||
|
||||
return err == nil
|
||||
}
|
||||
|
||||
// isUniswapV2Pool checks if the pool implements Uniswap V2 interface
|
||||
func (pv *PoolValidator) isUniswapV2Pool(ctx context.Context, poolAddr common.Address) bool {
|
||||
// Try to call getReserves() function (standard in Uniswap V2)
|
||||
reservesABI := `[{"inputs":[],"name":"getReserves","outputs":[{"internalType":"uint112","name":"_reserve0","type":"uint112"},{"internalType":"uint112","name":"_reserve1","type":"uint112"},{"internalType":"uint32","name":"_blockTimestampLast","type":"uint32"}],"stateMutability":"view","type":"function"}]`
|
||||
|
||||
contractABI, err := uniswap.ParseABI(reservesABI)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
callData, err := contractABI.Pack("getReserves")
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
_, err = pv.client.CallContract(ctx, ethereum.CallMsg{
|
||||
To: &poolAddr,
|
||||
Data: callData,
|
||||
}, nil)
|
||||
|
||||
return err == nil
|
||||
}
|
||||
|
||||
// validateUniswapV3Interface validates Uniswap V3 specific functions
|
||||
func (pv *PoolValidator) validateUniswapV3Interface(ctx context.Context, poolAddr common.Address, result *ValidationResult) error {
|
||||
// Get token addresses and fee
|
||||
token0, token1, fee, err := pv.getUniswapV3PoolInfo(ctx, poolAddr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get V3 pool info: %w", err)
|
||||
}
|
||||
|
||||
result.Token0 = token0
|
||||
result.Token1 = token1
|
||||
result.Fee = fee
|
||||
|
||||
// Validate token ordering (token0 < token1)
|
||||
if token0.Big().Cmp(token1.Big()) >= 0 {
|
||||
return fmt.Errorf("invalid token ordering: token0 must be < token1")
|
||||
}
|
||||
|
||||
// Validate fee tier
|
||||
validFees := []uint32{500, 3000, 10000, 100} // 0.05%, 0.3%, 1%, 0.01%
|
||||
feeValid := false
|
||||
for _, validFee := range validFees {
|
||||
if fee == validFee {
|
||||
feeValid = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !feeValid {
|
||||
result.Warnings = append(result.Warnings, fmt.Sprintf("Unusual fee tier: %d", fee))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// validateUniswapV2Interface validates Uniswap V2 specific functions
|
||||
func (pv *PoolValidator) validateUniswapV2Interface(ctx context.Context, poolAddr common.Address, result *ValidationResult) error {
|
||||
// Get token addresses
|
||||
token0, token1, err := pv.getUniswapV2PoolInfo(ctx, poolAddr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get V2 pool info: %w", err)
|
||||
}
|
||||
|
||||
result.Token0 = token0
|
||||
result.Token1 = token1
|
||||
result.Fee = 3000 // V2 has fixed 0.3% fee
|
||||
|
||||
// Validate token ordering
|
||||
if token0.Big().Cmp(token1.Big()) >= 0 {
|
||||
return fmt.Errorf("invalid token ordering: token0 must be < token1")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// validateFactoryDeployment verifies the pool was deployed by a trusted factory
|
||||
func (pv *PoolValidator) validateFactoryDeployment(ctx context.Context, poolAddr common.Address, result *ValidationResult) error {
|
||||
// For each known factory, try to verify if this pool could have been deployed by it
|
||||
for factoryAddr, factoryName := range pv.trustedFactories {
|
||||
if pv.verifyFactoryDeployment(factoryAddr, factoryName, poolAddr, result) {
|
||||
result.Factory = factoryName
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
return fmt.Errorf("pool not deployed by any trusted factory")
|
||||
}
|
||||
|
||||
// verifyFactoryDeployment verifies a specific factory deployed the pool
|
||||
func (pv *PoolValidator) verifyFactoryDeployment(factoryAddr common.Address, factoryName string, poolAddr common.Address, result *ValidationResult) bool {
|
||||
if result.Token0 == (common.Address{}) || result.Token1 == (common.Address{}) {
|
||||
return false
|
||||
}
|
||||
|
||||
// Use CREATE2 calculator to verify the pool address
|
||||
return pv.create2Calculator.ValidatePoolAddress(factoryName, result.Token0, result.Token1, result.Fee, poolAddr)
|
||||
}
|
||||
|
||||
// validateTokenContracts validates the token contracts in the pool
|
||||
func (pv *PoolValidator) validateTokenContracts(ctx context.Context, result *ValidationResult) error {
|
||||
if result.Token0 == (common.Address{}) || result.Token1 == (common.Address{}) {
|
||||
return fmt.Errorf("token addresses not available")
|
||||
}
|
||||
|
||||
// Validate token0
|
||||
if err := pv.validateTokenContract(ctx, result.Token0); err != nil {
|
||||
return fmt.Errorf("token0 validation failed: %w", err)
|
||||
}
|
||||
|
||||
// Validate token1
|
||||
if err := pv.validateTokenContract(ctx, result.Token1); err != nil {
|
||||
return fmt.Errorf("token1 validation failed: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// validateTokenContract validates a single token contract
|
||||
func (pv *PoolValidator) validateTokenContract(ctx context.Context, tokenAddr common.Address) error {
|
||||
// Check if contract exists
|
||||
code, err := pv.client.CodeAt(ctx, tokenAddr, nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get token contract code: %w", err)
|
||||
}
|
||||
|
||||
if len(code) == 0 {
|
||||
return fmt.Errorf("no contract code at token address %s", tokenAddr.Hex())
|
||||
}
|
||||
|
||||
// Try to call standard ERC20 functions
|
||||
return pv.validateERC20Interface(ctx, tokenAddr)
|
||||
}
|
||||
|
||||
// validateERC20Interface validates ERC20 token interface
|
||||
func (pv *PoolValidator) validateERC20Interface(ctx context.Context, tokenAddr common.Address) error {
|
||||
// Try to call totalSupply() function
|
||||
erc20ABI := `[{"inputs":[],"name":"totalSupply","outputs":[{"internalType":"uint256","name":"","type":"uint256"}],"stateMutability":"view","type":"function"}]`
|
||||
|
||||
contractABI, err := uniswap.ParseABI(erc20ABI)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse ERC20 ABI: %w", err)
|
||||
}
|
||||
|
||||
callData, err := contractABI.Pack("totalSupply")
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to pack totalSupply call: %w", err)
|
||||
}
|
||||
|
||||
_, err = pv.client.CallContract(ctx, ethereum.CallMsg{
|
||||
To: &tokenAddr,
|
||||
Data: callData,
|
||||
}, nil)
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("totalSupply call failed: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// performAdditionalSecurityChecks performs various security checks
|
||||
func (pv *PoolValidator) performAdditionalSecurityChecks(ctx context.Context, poolAddr common.Address, result *ValidationResult) {
|
||||
// Check contract creation time
|
||||
if creationBlock := pv.getContractCreationBlock(ctx, poolAddr); creationBlock > 0 {
|
||||
result.CreationBlock = creationBlock
|
||||
|
||||
// Warn about very new contracts
|
||||
currentBlock, err := pv.client.BlockNumber(ctx)
|
||||
if err == nil && currentBlock-creationBlock < 100 {
|
||||
result.Warnings = append(result.Warnings, "Pool is very new (< 100 blocks old)")
|
||||
result.SecurityScore -= 5
|
||||
}
|
||||
}
|
||||
|
||||
// Check for common attack patterns
|
||||
pv.checkForAttackPatterns(ctx, poolAddr, result)
|
||||
}
|
||||
|
||||
// getContractCreationBlock attempts to find when the contract was created using binary search
|
||||
func (pv *PoolValidator) getContractCreationBlock(ctx context.Context, addr common.Address) uint64 {
|
||||
pv.logger.Debug(fmt.Sprintf("Finding creation block for contract %s", addr.Hex()))
|
||||
|
||||
// Get latest block number first
|
||||
latestBlock, err := pv.client.BlockNumber(ctx)
|
||||
if err != nil {
|
||||
pv.logger.Warn(fmt.Sprintf("Failed to get latest block number: %v", err))
|
||||
return 0
|
||||
}
|
||||
|
||||
// Check if contract exists at latest block
|
||||
codeAtLatest, err := pv.client.CodeAt(ctx, addr, new(big.Int).SetUint64(latestBlock))
|
||||
if err != nil || len(codeAtLatest) == 0 {
|
||||
pv.logger.Debug(fmt.Sprintf("Contract %s does not exist at latest block", addr.Hex()))
|
||||
return 0
|
||||
}
|
||||
|
||||
// Binary search to find creation block
|
||||
// Start with a reasonable range - most pools created in last 10M blocks
|
||||
searchStart := uint64(0)
|
||||
if latestBlock > 10000000 {
|
||||
searchStart = latestBlock - 10000000
|
||||
}
|
||||
|
||||
creationBlock := pv.binarySearchCreationBlock(ctx, addr, searchStart, latestBlock)
|
||||
|
||||
if creationBlock > 0 {
|
||||
pv.logger.Debug(fmt.Sprintf("Contract %s created at block %d", addr.Hex(), creationBlock))
|
||||
}
|
||||
|
||||
return creationBlock
|
||||
}
|
||||
|
||||
// binarySearchCreationBlock performs binary search to find the exact creation block
|
||||
func (pv *PoolValidator) binarySearchCreationBlock(ctx context.Context, addr common.Address, start, end uint64) uint64 {
|
||||
// Limit search iterations to prevent infinite loops
|
||||
maxIterations := 50
|
||||
iteration := 0
|
||||
|
||||
for start <= end && iteration < maxIterations {
|
||||
iteration++
|
||||
mid := (start + end) / 2
|
||||
|
||||
// Check if contract exists at mid block
|
||||
code, err := pv.client.CodeAt(ctx, addr, new(big.Int).SetUint64(mid))
|
||||
if err != nil {
|
||||
pv.logger.Debug(fmt.Sprintf("Error checking code at block %d: %v", mid, err))
|
||||
break
|
||||
}
|
||||
|
||||
hasCode := len(code) > 0
|
||||
|
||||
if hasCode {
|
||||
// Contract exists at mid, check if it exists at mid-1
|
||||
if mid == 0 {
|
||||
return mid
|
||||
}
|
||||
|
||||
prevCode, err := pv.client.CodeAt(ctx, addr, new(big.Int).SetUint64(mid-1))
|
||||
if err != nil || len(prevCode) == 0 {
|
||||
// Contract doesn't exist at mid-1 but exists at mid
|
||||
return mid
|
||||
}
|
||||
|
||||
// Contract exists at both mid and mid-1, search earlier
|
||||
end = mid - 1
|
||||
} else {
|
||||
// Contract doesn't exist at mid, search later
|
||||
start = mid + 1
|
||||
}
|
||||
|
||||
// Add small delay to avoid rate limiting
|
||||
if iteration%10 == 0 {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return 0
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// If we couldn't find exact block, return start as best estimate
|
||||
if start <= end {
|
||||
return start
|
||||
}
|
||||
|
||||
return 0
|
||||
}
|
||||
|
||||
// checkForAttackPatterns looks for common malicious patterns
|
||||
func (pv *PoolValidator) checkForAttackPatterns(ctx context.Context, poolAddr common.Address, result *ValidationResult) {
|
||||
// Check if contract is a proxy (may be suspicious)
|
||||
if pv.isProxyContract(ctx, poolAddr) {
|
||||
result.Warnings = append(result.Warnings, "Contract appears to be a proxy")
|
||||
result.SecurityScore -= 10
|
||||
}
|
||||
|
||||
// Check for unusual bytecode patterns
|
||||
if pv.hasUnusualBytecode(ctx, poolAddr) {
|
||||
result.Warnings = append(result.Warnings, "Unusual bytecode patterns detected")
|
||||
result.SecurityScore -= 15
|
||||
}
|
||||
}
|
||||
|
||||
// isProxyContract checks if the contract is a proxy
|
||||
func (pv *PoolValidator) isProxyContract(ctx context.Context, addr common.Address) bool {
|
||||
code, err := pv.client.CodeAt(ctx, addr, nil)
|
||||
if err != nil || len(code) == 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
// Look for common proxy patterns (delegatecall, etc.)
|
||||
codeHex := hex.EncodeToString(code)
|
||||
return strings.Contains(codeHex, "f4") // delegatecall opcode
|
||||
}
|
||||
|
||||
// hasUnusualBytecode checks for suspicious bytecode patterns
|
||||
func (pv *PoolValidator) hasUnusualBytecode(ctx context.Context, addr common.Address) bool {
|
||||
code, err := pv.client.CodeAt(ctx, addr, nil)
|
||||
if err != nil || len(code) == 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
// Check for unusual patterns
|
||||
if len(code) > 50000 {
|
||||
return true // Unusually large contract
|
||||
}
|
||||
|
||||
// Check for high entropy (potential obfuscation)
|
||||
entropy := pv.calculateEntropy(code)
|
||||
return entropy > 7.5 // High entropy threshold
|
||||
}
|
||||
|
||||
// calculateEntropy calculates Shannon entropy of bytecode
|
||||
func (pv *PoolValidator) calculateEntropy(data []byte) float64 {
|
||||
if len(data) == 0 {
|
||||
return 0
|
||||
}
|
||||
|
||||
freq := make(map[byte]int)
|
||||
for _, b := range data {
|
||||
freq[b]++
|
||||
}
|
||||
|
||||
entropy := 0.0
|
||||
length := float64(len(data))
|
||||
|
||||
for _, f := range freq {
|
||||
p := float64(f) / length
|
||||
if p > 0 {
|
||||
entropy -= p * (float64(f) / length)
|
||||
}
|
||||
}
|
||||
|
||||
return entropy
|
||||
}
|
||||
|
||||
// Helper functions to get pool information
|
||||
|
||||
func (pv *PoolValidator) getUniswapV3PoolInfo(ctx context.Context, poolAddr common.Address) (common.Address, common.Address, uint32, error) {
|
||||
poolABI := `[{"inputs":[],"name":"token0","outputs":[{"internalType":"address","name":"","type":"address"}],"stateMutability":"view","type":"function"},{"inputs":[],"name":"token1","outputs":[{"internalType":"address","name":"","type":"address"}],"stateMutability":"view","type":"function"},{"inputs":[],"name":"fee","outputs":[{"internalType":"uint24","name":"","type":"uint24"}],"stateMutability":"view","type":"function"}]`
|
||||
|
||||
contractABI, err := uniswap.ParseABI(poolABI)
|
||||
if err != nil {
|
||||
return common.Address{}, common.Address{}, 0, err
|
||||
}
|
||||
|
||||
// Get token0
|
||||
token0Data, err := contractABI.Pack("token0")
|
||||
if err != nil {
|
||||
return common.Address{}, common.Address{}, 0, err
|
||||
}
|
||||
|
||||
token0Result, err := pv.client.CallContract(ctx, ethereum.CallMsg{
|
||||
To: &poolAddr,
|
||||
Data: token0Data,
|
||||
}, nil)
|
||||
if err != nil {
|
||||
return common.Address{}, common.Address{}, 0, err
|
||||
}
|
||||
|
||||
token0Unpacked, err := contractABI.Unpack("token0", token0Result)
|
||||
if err != nil {
|
||||
return common.Address{}, common.Address{}, 0, err
|
||||
}
|
||||
|
||||
// Get token1
|
||||
token1Data, err := contractABI.Pack("token1")
|
||||
if err != nil {
|
||||
return common.Address{}, common.Address{}, 0, err
|
||||
}
|
||||
|
||||
token1Result, err := pv.client.CallContract(ctx, ethereum.CallMsg{
|
||||
To: &poolAddr,
|
||||
Data: token1Data,
|
||||
}, nil)
|
||||
if err != nil {
|
||||
return common.Address{}, common.Address{}, 0, err
|
||||
}
|
||||
|
||||
token1Unpacked, err := contractABI.Unpack("token1", token1Result)
|
||||
if err != nil {
|
||||
return common.Address{}, common.Address{}, 0, err
|
||||
}
|
||||
|
||||
// Get fee
|
||||
feeData, err := contractABI.Pack("fee")
|
||||
if err != nil {
|
||||
return common.Address{}, common.Address{}, 0, err
|
||||
}
|
||||
|
||||
feeResult, err := pv.client.CallContract(ctx, ethereum.CallMsg{
|
||||
To: &poolAddr,
|
||||
Data: feeData,
|
||||
}, nil)
|
||||
if err != nil {
|
||||
return common.Address{}, common.Address{}, 0, err
|
||||
}
|
||||
|
||||
feeUnpacked, err := contractABI.Unpack("fee", feeResult)
|
||||
if err != nil {
|
||||
return common.Address{}, common.Address{}, 0, err
|
||||
}
|
||||
|
||||
token0 := token0Unpacked[0].(common.Address)
|
||||
token1 := token1Unpacked[0].(common.Address)
|
||||
fee := feeUnpacked[0].(*big.Int).Uint64()
|
||||
|
||||
feeUint32, err := security.SafeUint32(fee)
|
||||
if err != nil {
|
||||
return common.Address{}, common.Address{}, 0, fmt.Errorf("invalid fee conversion: %w", err)
|
||||
}
|
||||
return token0, token1, feeUint32, nil
|
||||
}
|
||||
|
||||
func (pv *PoolValidator) getUniswapV2PoolInfo(ctx context.Context, poolAddr common.Address) (common.Address, common.Address, error) {
|
||||
poolABI := `[{"inputs":[],"name":"token0","outputs":[{"internalType":"address","name":"","type":"address"}],"stateMutability":"view","type":"function"},{"inputs":[],"name":"token1","outputs":[{"internalType":"address","name":"","type":"address"}],"stateMutability":"view","type":"function"}]`
|
||||
|
||||
contractABI, err := uniswap.ParseABI(poolABI)
|
||||
if err != nil {
|
||||
return common.Address{}, common.Address{}, err
|
||||
}
|
||||
|
||||
// Get token0
|
||||
token0Data, err := contractABI.Pack("token0")
|
||||
if err != nil {
|
||||
return common.Address{}, common.Address{}, err
|
||||
}
|
||||
|
||||
token0Result, err := pv.client.CallContract(ctx, ethereum.CallMsg{
|
||||
To: &poolAddr,
|
||||
Data: token0Data,
|
||||
}, nil)
|
||||
if err != nil {
|
||||
return common.Address{}, common.Address{}, err
|
||||
}
|
||||
|
||||
token0Unpacked, err := contractABI.Unpack("token0", token0Result)
|
||||
if err != nil {
|
||||
return common.Address{}, common.Address{}, err
|
||||
}
|
||||
|
||||
// Get token1
|
||||
token1Data, err := contractABI.Pack("token1")
|
||||
if err != nil {
|
||||
return common.Address{}, common.Address{}, err
|
||||
}
|
||||
|
||||
token1Result, err := pv.client.CallContract(ctx, ethereum.CallMsg{
|
||||
To: &poolAddr,
|
||||
Data: token1Data,
|
||||
}, nil)
|
||||
if err != nil {
|
||||
return common.Address{}, common.Address{}, err
|
||||
}
|
||||
|
||||
token1Unpacked, err := contractABI.Unpack("token1", token1Result)
|
||||
if err != nil {
|
||||
return common.Address{}, common.Address{}, err
|
||||
}
|
||||
|
||||
token0 := token0Unpacked[0].(common.Address)
|
||||
token1 := token1Unpacked[0].(common.Address)
|
||||
|
||||
return token0, token1, nil
|
||||
}
|
||||
|
||||
// Configuration and caching methods
|
||||
|
||||
func (pv *PoolValidator) getDefaultConfig() *PoolValidationConfig {
|
||||
return &PoolValidationConfig{
|
||||
RequireFactoryVerification: true,
|
||||
MinSecurityScore: 70,
|
||||
MaxValidationTime: 10 * time.Second,
|
||||
AllowUnknownFactories: false,
|
||||
RequireTokenValidation: true,
|
||||
}
|
||||
}
|
||||
|
||||
func (pv *PoolValidator) getCachedResult(addr common.Address) *ValidationResult {
|
||||
if result, exists := pv.validationCache[addr]; exists {
|
||||
if time.Since(result.ValidatedAt) < pv.cacheTimeout {
|
||||
return result
|
||||
}
|
||||
delete(pv.validationCache, addr)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (pv *PoolValidator) cacheResult(addr common.Address, result *ValidationResult) {
|
||||
pv.validationCache[addr] = result
|
||||
}
|
||||
|
||||
// Initialization methods
|
||||
|
||||
func (pv *PoolValidator) initializeTrustedFactories() {
|
||||
// Uniswap V3
|
||||
pv.trustedFactories[common.HexToAddress("0x1F98431c8aD98523631AE4a59f267346ea31F984")] = "uniswap_v3"
|
||||
// Uniswap V2
|
||||
pv.trustedFactories[common.HexToAddress("0x5C69bEe701ef814a2B6a3EDD4B1652CB9cc5aA6f")] = "uniswap_v2"
|
||||
// SushiSwap
|
||||
pv.trustedFactories[common.HexToAddress("0xC0AEe478e3658e2610c5F7A4A2E1777cE9e4f2Ac")] = "sushiswap"
|
||||
// Camelot V3 (Arbitrum)
|
||||
pv.trustedFactories[common.HexToAddress("0x1a3c9B1d2F0529D97f2afC5136Cc23e58f1FD35B")] = "camelot_v3"
|
||||
}
|
||||
|
||||
func (pv *PoolValidator) initializeBannedAddresses() {
|
||||
// Add known malicious or problematic pool addresses
|
||||
// This would be populated with real banned addresses in production
|
||||
}
|
||||
|
||||
// AddTrustedFactory adds a new trusted factory
|
||||
func (pv *PoolValidator) AddTrustedFactory(factoryAddr common.Address, name string) {
|
||||
pv.trustedFactories[factoryAddr] = name
|
||||
pv.logger.Info(fmt.Sprintf("Added trusted factory: %s (%s)", name, factoryAddr.Hex()))
|
||||
}
|
||||
|
||||
// BanAddress adds an address to the banned list
|
||||
func (pv *PoolValidator) BanAddress(addr common.Address, reason string) {
|
||||
pv.bannedAddresses[addr] = reason
|
||||
// Clear from cache if present
|
||||
delete(pv.validationCache, addr)
|
||||
pv.logger.Warn(fmt.Sprintf("Banned address: %s (reason: %s)", addr.Hex(), reason))
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,265 +0,0 @@
|
||||
package validation
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math/big"
|
||||
)
|
||||
|
||||
// PriceImpactThresholds defines the acceptable price impact levels
|
||||
type PriceImpactThresholds struct {
|
||||
// Low risk: < 0.5% price impact
|
||||
LowThreshold float64
|
||||
// Medium risk: 0.5% - 2% price impact
|
||||
MediumThreshold float64
|
||||
// High risk: 2% - 5% price impact
|
||||
HighThreshold float64
|
||||
// Extreme risk: > 5% price impact (typically unprofitable due to slippage)
|
||||
ExtremeThreshold float64
|
||||
// Maximum acceptable: Reject anything above this (e.g., 10%)
|
||||
MaxAcceptable float64
|
||||
}
|
||||
|
||||
// DefaultPriceImpactThresholds returns conservative production-ready thresholds
|
||||
func DefaultPriceImpactThresholds() *PriceImpactThresholds {
|
||||
return &PriceImpactThresholds{
|
||||
LowThreshold: 0.5, // 0.5%
|
||||
MediumThreshold: 2.0, // 2%
|
||||
HighThreshold: 5.0, // 5%
|
||||
ExtremeThreshold: 10.0, // 10%
|
||||
MaxAcceptable: 15.0, // 15% - reject anything higher
|
||||
}
|
||||
}
|
||||
|
||||
// AggressivePriceImpactThresholds returns more aggressive thresholds for higher volumes
|
||||
func AggressivePriceImpactThresholds() *PriceImpactThresholds {
|
||||
return &PriceImpactThresholds{
|
||||
LowThreshold: 1.0, // 1%
|
||||
MediumThreshold: 3.0, // 3%
|
||||
HighThreshold: 7.0, // 7%
|
||||
ExtremeThreshold: 15.0, // 15%
|
||||
MaxAcceptable: 25.0, // 25%
|
||||
}
|
||||
}
|
||||
|
||||
// ConservativePriceImpactThresholds returns very conservative thresholds for safety
|
||||
func ConservativePriceImpactThresholds() *PriceImpactThresholds {
|
||||
return &PriceImpactThresholds{
|
||||
LowThreshold: 0.1, // 0.1%
|
||||
MediumThreshold: 0.5, // 0.5%
|
||||
HighThreshold: 1.0, // 1%
|
||||
ExtremeThreshold: 2.0, // 2%
|
||||
MaxAcceptable: 5.0, // 5%
|
||||
}
|
||||
}
|
||||
|
||||
// PriceImpactRiskLevel represents the risk level of a price impact
|
||||
type PriceImpactRiskLevel string
|
||||
|
||||
const (
|
||||
RiskLevelNegligible PriceImpactRiskLevel = "Negligible" // < 0.1%
|
||||
RiskLevelLow PriceImpactRiskLevel = "Low" // 0.1-0.5%
|
||||
RiskLevelMedium PriceImpactRiskLevel = "Medium" // 0.5-2%
|
||||
RiskLevelHigh PriceImpactRiskLevel = "High" // 2-5%
|
||||
RiskLevelExtreme PriceImpactRiskLevel = "Extreme" // 5-10%
|
||||
RiskLevelUnacceptable PriceImpactRiskLevel = "Unacceptable" // > 10%
|
||||
)
|
||||
|
||||
// PriceImpactValidationResult contains the result of price impact validation
|
||||
type PriceImpactValidationResult struct {
|
||||
PriceImpact float64 // The calculated price impact percentage
|
||||
RiskLevel PriceImpactRiskLevel // The risk categorization
|
||||
IsAcceptable bool // Whether this price impact is acceptable
|
||||
Recommendation string // Human-readable recommendation
|
||||
Details map[string]interface{} // Additional details
|
||||
}
|
||||
|
||||
// PriceImpactValidator validates price impacts against configured thresholds
|
||||
type PriceImpactValidator struct {
|
||||
thresholds *PriceImpactThresholds
|
||||
}
|
||||
|
||||
// NewPriceImpactValidator creates a new price impact validator
|
||||
func NewPriceImpactValidator(thresholds *PriceImpactThresholds) *PriceImpactValidator {
|
||||
if thresholds == nil {
|
||||
thresholds = DefaultPriceImpactThresholds()
|
||||
}
|
||||
return &PriceImpactValidator{
|
||||
thresholds: thresholds,
|
||||
}
|
||||
}
|
||||
|
||||
// ValidatePriceImpact validates a price impact percentage
|
||||
func (piv *PriceImpactValidator) ValidatePriceImpact(priceImpact float64) *PriceImpactValidationResult {
|
||||
result := &PriceImpactValidationResult{
|
||||
PriceImpact: priceImpact,
|
||||
Details: make(map[string]interface{}),
|
||||
}
|
||||
|
||||
// Determine risk level
|
||||
result.RiskLevel = piv.categorizePriceImpact(priceImpact)
|
||||
|
||||
// Determine if acceptable
|
||||
result.IsAcceptable = priceImpact <= piv.thresholds.MaxAcceptable
|
||||
|
||||
// Generate recommendation
|
||||
result.Recommendation = piv.generateRecommendation(priceImpact, result.RiskLevel)
|
||||
|
||||
// Add threshold details
|
||||
result.Details["thresholds"] = map[string]float64{
|
||||
"low": piv.thresholds.LowThreshold,
|
||||
"medium": piv.thresholds.MediumThreshold,
|
||||
"high": piv.thresholds.HighThreshold,
|
||||
"extreme": piv.thresholds.ExtremeThreshold,
|
||||
"max": piv.thresholds.MaxAcceptable,
|
||||
}
|
||||
|
||||
// Add risk-specific details
|
||||
result.Details["risk_level"] = string(result.RiskLevel)
|
||||
result.Details["acceptable"] = result.IsAcceptable
|
||||
result.Details["price_impact_percent"] = priceImpact
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// categorizePriceImpact categorizes the price impact into risk levels
|
||||
func (piv *PriceImpactValidator) categorizePriceImpact(priceImpact float64) PriceImpactRiskLevel {
|
||||
switch {
|
||||
case priceImpact < 0.1:
|
||||
return RiskLevelNegligible
|
||||
case priceImpact < piv.thresholds.LowThreshold:
|
||||
return RiskLevelLow
|
||||
case priceImpact < piv.thresholds.MediumThreshold:
|
||||
return RiskLevelMedium
|
||||
case priceImpact < piv.thresholds.HighThreshold:
|
||||
return RiskLevelHigh
|
||||
case priceImpact < piv.thresholds.ExtremeThreshold:
|
||||
return RiskLevelExtreme
|
||||
default:
|
||||
return RiskLevelUnacceptable
|
||||
}
|
||||
}
|
||||
|
||||
// generateRecommendation generates a recommendation based on price impact
|
||||
func (piv *PriceImpactValidator) generateRecommendation(priceImpact float64, riskLevel PriceImpactRiskLevel) string {
|
||||
switch riskLevel {
|
||||
case RiskLevelNegligible:
|
||||
return fmt.Sprintf("Excellent: Price impact of %.4f%% is negligible. Safe to execute.", priceImpact)
|
||||
case RiskLevelLow:
|
||||
return fmt.Sprintf("Good: Price impact of %.4f%% is low. Execute with standard slippage protection.", priceImpact)
|
||||
case RiskLevelMedium:
|
||||
return fmt.Sprintf("Moderate: Price impact of %.4f%% is medium. Use enhanced slippage protection and consider splitting the trade.", priceImpact)
|
||||
case RiskLevelHigh:
|
||||
return fmt.Sprintf("Caution: Price impact of %.4f%% is high. Strongly recommend splitting into smaller trades or waiting for better liquidity.", priceImpact)
|
||||
case RiskLevelExtreme:
|
||||
return fmt.Sprintf("Warning: Price impact of %.4f%% is extreme. Trade size is too large for current liquidity. Split trade or skip.", priceImpact)
|
||||
case RiskLevelUnacceptable:
|
||||
return fmt.Sprintf("Reject: Price impact of %.4f%% exceeds maximum acceptable threshold (%.2f%%). Do not execute.", priceImpact, piv.thresholds.MaxAcceptable)
|
||||
default:
|
||||
return "Unknown risk level"
|
||||
}
|
||||
}
|
||||
|
||||
// ValidatePriceImpactWithLiquidity validates price impact considering trade size and liquidity
|
||||
func (piv *PriceImpactValidator) ValidatePriceImpactWithLiquidity(tradeSize, liquidity *big.Int) *PriceImpactValidationResult {
|
||||
if tradeSize == nil || liquidity == nil || liquidity.Sign() == 0 {
|
||||
return &PriceImpactValidationResult{
|
||||
PriceImpact: 0,
|
||||
RiskLevel: RiskLevelUnacceptable,
|
||||
IsAcceptable: false,
|
||||
Recommendation: "Invalid input: trade size or liquidity is nil/zero",
|
||||
Details: make(map[string]interface{}),
|
||||
}
|
||||
}
|
||||
|
||||
// Calculate price impact: tradeSize / (liquidity + tradeSize) * 100
|
||||
tradeSizeFloat := new(big.Float).SetInt(tradeSize)
|
||||
liquidityFloat := new(big.Float).SetInt(liquidity)
|
||||
|
||||
// Price impact = tradeSize / (liquidity + tradeSize)
|
||||
denominator := new(big.Float).Add(liquidityFloat, tradeSizeFloat)
|
||||
priceImpactRatio := new(big.Float).Quo(tradeSizeFloat, denominator)
|
||||
priceImpactPercent, _ := priceImpactRatio.Float64()
|
||||
priceImpactPercent *= 100.0
|
||||
|
||||
result := piv.ValidatePriceImpact(priceImpactPercent)
|
||||
|
||||
// Add liquidity-specific details
|
||||
result.Details["trade_size"] = tradeSize.String()
|
||||
result.Details["liquidity"] = liquidity.String()
|
||||
result.Details["trade_to_liquidity_ratio"] = new(big.Float).Quo(tradeSizeFloat, liquidityFloat).Text('f', 6)
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// ShouldRejectTrade determines if a trade should be rejected based on price impact
|
||||
func (piv *PriceImpactValidator) ShouldRejectTrade(priceImpact float64) bool {
|
||||
return priceImpact > piv.thresholds.MaxAcceptable
|
||||
}
|
||||
|
||||
// ShouldSplitTrade determines if a trade should be split based on price impact
|
||||
func (piv *PriceImpactValidator) ShouldSplitTrade(priceImpact float64) bool {
|
||||
return priceImpact >= piv.thresholds.MediumThreshold
|
||||
}
|
||||
|
||||
// GetRecommendedSplitCount recommends how many parts to split a trade into
|
||||
func (piv *PriceImpactValidator) GetRecommendedSplitCount(priceImpact float64) int {
|
||||
switch {
|
||||
case priceImpact < piv.thresholds.MediumThreshold:
|
||||
return 1 // No split needed
|
||||
case priceImpact < piv.thresholds.HighThreshold:
|
||||
return 2 // Split into 2
|
||||
case priceImpact < piv.thresholds.ExtremeThreshold:
|
||||
return 4 // Split into 4
|
||||
case priceImpact < piv.thresholds.MaxAcceptable:
|
||||
return 8 // Split into 8
|
||||
default:
|
||||
return 0 // Reject trade
|
||||
}
|
||||
}
|
||||
|
||||
// CalculateMaxTradeSize calculates the maximum trade size for a given price impact target
|
||||
func (piv *PriceImpactValidator) CalculateMaxTradeSize(liquidity *big.Int, targetPriceImpact float64) *big.Int {
|
||||
if liquidity == nil || liquidity.Sign() == 0 {
|
||||
return big.NewInt(0)
|
||||
}
|
||||
|
||||
// From: priceImpact = tradeSize / (liquidity + tradeSize)
|
||||
// Solve for tradeSize: tradeSize = (priceImpact * liquidity) / (1 - priceImpact)
|
||||
|
||||
priceImpactDecimal := targetPriceImpact / 100.0
|
||||
if priceImpactDecimal >= 1.0 {
|
||||
return big.NewInt(0) // Invalid: 100% price impact or more
|
||||
}
|
||||
|
||||
liquidityFloat := new(big.Float).SetInt(liquidity)
|
||||
priceImpactFloat := big.NewFloat(priceImpactDecimal)
|
||||
|
||||
// numerator = priceImpact * liquidity
|
||||
numerator := new(big.Float).Mul(priceImpactFloat, liquidityFloat)
|
||||
|
||||
// denominator = 1 - priceImpact
|
||||
denominator := new(big.Float).Sub(big.NewFloat(1.0), priceImpactFloat)
|
||||
|
||||
// maxTradeSize = numerator / denominator
|
||||
maxTradeSize := new(big.Float).Quo(numerator, denominator)
|
||||
|
||||
result, _ := maxTradeSize.Int(nil)
|
||||
return result
|
||||
}
|
||||
|
||||
// GetThresholds returns the current threshold configuration
|
||||
func (piv *PriceImpactValidator) GetThresholds() *PriceImpactThresholds {
|
||||
return piv.thresholds
|
||||
}
|
||||
|
||||
// SetThresholds updates the threshold configuration
|
||||
func (piv *PriceImpactValidator) SetThresholds(thresholds *PriceImpactThresholds) {
|
||||
if thresholds != nil {
|
||||
piv.thresholds = thresholds
|
||||
}
|
||||
}
|
||||
|
||||
// FormatPriceImpact formats a price impact value for display
|
||||
func FormatPriceImpact(priceImpact float64) string {
|
||||
return fmt.Sprintf("%.4f%%", priceImpact)
|
||||
}
|
||||
@@ -1,242 +0,0 @@
|
||||
package validation
|
||||
|
||||
import (
|
||||
"math/big"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestDefaultPriceImpactThresholds(t *testing.T) {
|
||||
thresholds := DefaultPriceImpactThresholds()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
value float64
|
||||
expected float64
|
||||
}{
|
||||
{"Low threshold", thresholds.LowThreshold, 0.5},
|
||||
{"Medium threshold", thresholds.MediumThreshold, 2.0},
|
||||
{"High threshold", thresholds.HighThreshold, 5.0},
|
||||
{"Extreme threshold", thresholds.ExtremeThreshold, 10.0},
|
||||
{"Max acceptable", thresholds.MaxAcceptable, 15.0},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if tt.value != tt.expected {
|
||||
t.Errorf("%s = %v, want %v", tt.name, tt.value, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCategorizePriceImpact(t *testing.T) {
|
||||
validator := NewPriceImpactValidator(DefaultPriceImpactThresholds())
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
priceImpact float64
|
||||
expectedLevel PriceImpactRiskLevel
|
||||
}{
|
||||
{"Negligible impact", 0.05, RiskLevelNegligible},
|
||||
{"Low impact", 0.3, RiskLevelLow},
|
||||
{"Medium impact", 1.0, RiskLevelMedium},
|
||||
{"High impact", 3.0, RiskLevelHigh},
|
||||
{"Extreme impact", 7.0, RiskLevelExtreme},
|
||||
{"Unacceptable impact", 20.0, RiskLevelUnacceptable},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := validator.ValidatePriceImpact(tt.priceImpact)
|
||||
if result.RiskLevel != tt.expectedLevel {
|
||||
t.Errorf("Risk level = %v, want %v", result.RiskLevel, tt.expectedLevel)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestShouldRejectTrade(t *testing.T) {
|
||||
validator := NewPriceImpactValidator(DefaultPriceImpactThresholds())
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
priceImpact float64
|
||||
shouldReject bool
|
||||
}{
|
||||
{"Low impact - accept", 0.5, false},
|
||||
{"Medium impact - accept", 2.0, false},
|
||||
{"High impact - accept", 5.0, false},
|
||||
{"Extreme impact - accept", 10.0, false},
|
||||
{"At max threshold - accept", 15.0, false},
|
||||
{"Above max threshold - reject", 15.1, true},
|
||||
{"Very high - reject", 30.0, true},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := validator.ShouldRejectTrade(tt.priceImpact)
|
||||
if result != tt.shouldReject {
|
||||
t.Errorf("ShouldRejectTrade(%v) = %v, want %v", tt.priceImpact, result, tt.shouldReject)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestShouldSplitTrade(t *testing.T) {
|
||||
validator := NewPriceImpactValidator(DefaultPriceImpactThresholds())
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
priceImpact float64
|
||||
shouldSplit bool
|
||||
}{
|
||||
{"Negligible - no split", 0.1, false},
|
||||
{"Low - no split", 0.5, false},
|
||||
{"Just below medium - no split", 1.9, false},
|
||||
{"At medium threshold - split", 2.0, true},
|
||||
{"High - split", 5.0, true},
|
||||
{"Extreme - split", 10.0, true},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := validator.ShouldSplitTrade(tt.priceImpact)
|
||||
if result != tt.shouldSplit {
|
||||
t.Errorf("ShouldSplitTrade(%v) = %v, want %v", tt.priceImpact, result, tt.shouldSplit)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetRecommendedSplitCount(t *testing.T) {
|
||||
validator := NewPriceImpactValidator(DefaultPriceImpactThresholds())
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
priceImpact float64
|
||||
expectedSplit int
|
||||
}{
|
||||
{"Low impact - no split", 0.5, 1},
|
||||
{"Medium impact - split in 2", 2.5, 2},
|
||||
{"High impact - split in 4", 6.0, 4},
|
||||
{"Extreme impact - split in 8", 12.0, 8},
|
||||
{"Unacceptable - reject (0)", 20.0, 0},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := validator.GetRecommendedSplitCount(tt.priceImpact)
|
||||
if result != tt.expectedSplit {
|
||||
t.Errorf("GetRecommendedSplitCount(%v) = %v, want %v", tt.priceImpact, result, tt.expectedSplit)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCalculateMaxTradeSize(t *testing.T) {
|
||||
validator := NewPriceImpactValidator(DefaultPriceImpactThresholds())
|
||||
|
||||
liquidity := big.NewInt(1000000) // 1M units of liquidity
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
liquidity *big.Int
|
||||
targetPriceImpact float64
|
||||
expectedApproximate int64 // Approximate expected value
|
||||
}{
|
||||
{"0.5% impact", liquidity, 0.5, 5025}, // ~0.5% of 1M
|
||||
{"1% impact", liquidity, 1.0, 10101}, // ~1% of 1M
|
||||
{"2% impact", liquidity, 2.0, 20408}, // ~2% of 1M
|
||||
{"5% impact", liquidity, 5.0, 52631}, // ~5% of 1M
|
||||
{"10% impact", liquidity, 10.0, 111111}, // ~10% of 1M
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := validator.CalculateMaxTradeSize(tt.liquidity, tt.targetPriceImpact)
|
||||
|
||||
// Check if result is within 5% of expected value
|
||||
resultInt64 := result.Int64()
|
||||
lowerBound := int64(float64(tt.expectedApproximate) * 0.95)
|
||||
upperBound := int64(float64(tt.expectedApproximate) * 1.05)
|
||||
|
||||
if resultInt64 < lowerBound || resultInt64 > upperBound {
|
||||
t.Errorf("CalculateMaxTradeSize() = %v, expected approximately %v (±5%%)", result, tt.expectedApproximate)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidatePriceImpactWithLiquidity(t *testing.T) {
|
||||
validator := NewPriceImpactValidator(DefaultPriceImpactThresholds())
|
||||
|
||||
liquidity := big.NewInt(1000000) // 1M units
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
tradeSize *big.Int
|
||||
liquidity *big.Int
|
||||
expectedRiskLevel PriceImpactRiskLevel
|
||||
}{
|
||||
{"Small trade", big.NewInt(1000), liquidity, RiskLevelNegligible},
|
||||
{"Medium trade", big.NewInt(20000), liquidity, RiskLevelMedium},
|
||||
{"Large trade", big.NewInt(100000), liquidity, RiskLevelExtreme},
|
||||
{"Very large trade", big.NewInt(500000), liquidity, RiskLevelUnacceptable},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := validator.ValidatePriceImpactWithLiquidity(tt.tradeSize, tt.liquidity)
|
||||
if result.RiskLevel != tt.expectedRiskLevel {
|
||||
t.Errorf("Risk level = %v, want %v (price impact: %.2f%%)",
|
||||
result.RiskLevel, tt.expectedRiskLevel, result.PriceImpact)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestConservativeThresholds(t *testing.T) {
|
||||
validator := NewPriceImpactValidator(ConservativePriceImpactThresholds())
|
||||
|
||||
// Test that conservative thresholds are more strict
|
||||
// With conservative: High=1.0%, Extreme=2.0%
|
||||
// So 1.0% exactly is at the boundary and goes to Extreme
|
||||
result := validator.ValidatePriceImpact(1.0)
|
||||
|
||||
if result.RiskLevel != RiskLevelExtreme {
|
||||
t.Errorf("With conservative thresholds, 1%% should be Extreme risk, got %v", result.RiskLevel)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAggressiveThresholds(t *testing.T) {
|
||||
validator := NewPriceImpactValidator(AggressivePriceImpactThresholds())
|
||||
|
||||
// Test that aggressive thresholds are more lenient
|
||||
// With aggressive: Low=1.0%, Medium=3.0%
|
||||
// So 2.0% falls in the Medium range (between 1.0 and 3.0)
|
||||
result := validator.ValidatePriceImpact(2.0)
|
||||
|
||||
if result.RiskLevel != RiskLevelMedium {
|
||||
t.Errorf("With aggressive thresholds, 2%% should be Medium risk, got %v", result.RiskLevel)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkValidatePriceImpact(b *testing.B) {
|
||||
validator := NewPriceImpactValidator(DefaultPriceImpactThresholds())
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
validator.ValidatePriceImpact(2.5)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkValidatePriceImpactWithLiquidity(b *testing.B) {
|
||||
validator := NewPriceImpactValidator(DefaultPriceImpactThresholds())
|
||||
tradeSize := big.NewInt(50000)
|
||||
liquidity := big.NewInt(1000000)
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
validator.ValidatePriceImpactWithLiquidity(tradeSize, liquidity)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user