package trading import ( "fmt" "math/big" "time" "github.com/ethereum/go-ethereum/common" "github.com/fraktal/mev-beta/internal/logger" "github.com/fraktal/mev-beta/pkg/validation" ) // SlippageProtection provides comprehensive slippage protection for trades type SlippageProtection struct { validator *validation.InputValidator logger *logger.Logger maxSlippagePercent float64 priceUpdateWindow time.Duration emergencyStopLoss float64 minimumLiquidity *big.Int } // TradeParameters represents parameters for a trade type TradeParameters struct { TokenIn common.Address TokenOut common.Address AmountIn *big.Int MinAmountOut *big.Int MaxSlippage float64 Deadline uint64 Pool common.Address ExpectedPrice *big.Float CurrentLiquidity *big.Int } // SlippageCheck represents the result of slippage validation type SlippageCheck struct { IsValid bool CalculatedSlippage float64 MaxAllowedSlippage float64 PriceImpact float64 Warnings []string Errors []string } // NewSlippageProtection creates a new slippage protection instance func NewSlippageProtection(logger *logger.Logger) *SlippageProtection { return &SlippageProtection{ validator: validation.NewInputValidator(), logger: logger, maxSlippagePercent: 5.0, // 5% maximum slippage priceUpdateWindow: 30 * time.Second, emergencyStopLoss: 20.0, // 20% emergency stop loss minimumLiquidity: big.NewInt(10000), // Minimum liquidity threshold } } // ValidateTradeParameters performs comprehensive validation of trade parameters func (sp *SlippageProtection) ValidateTradeParameters(params *TradeParameters) (*SlippageCheck, error) { check := &SlippageCheck{ IsValid: true, Warnings: make([]string, 0), Errors: make([]string, 0), } // Validate input parameters if err := sp.validateInputParameters(params, check); err != nil { return check, err } // Calculate slippage slippage, err := sp.calculateSlippage(params) if err != nil { check.Errors = append(check.Errors, fmt.Sprintf("Failed to calculate slippage: %v", err)) check.IsValid = false return check, nil } check.CalculatedSlippage = slippage // Check slippage limits if slippage > params.MaxSlippage { check.Errors = append(check.Errors, fmt.Sprintf("Calculated slippage %.2f%% exceeds maximum allowed %.2f%%", slippage, params.MaxSlippage)) check.IsValid = false } // Check emergency stop loss if slippage > sp.emergencyStopLoss { check.Errors = append(check.Errors, fmt.Sprintf("Slippage %.2f%% exceeds emergency stop loss %.2f%%", slippage, sp.emergencyStopLoss)) check.IsValid = false } // Calculate price impact priceImpact, err := sp.calculatePriceImpact(params) if err != nil { check.Warnings = append(check.Warnings, fmt.Sprintf("Could not calculate price impact: %v", err)) } else { check.PriceImpact = priceImpact // Warn about high price impact if priceImpact > 3.0 { check.Warnings = append(check.Warnings, fmt.Sprintf("High price impact detected: %.2f%%", priceImpact)) } } // Check liquidity if err := sp.checkLiquidity(params, check); err != nil { check.Errors = append(check.Errors, err.Error()) check.IsValid = false } // Check for sandwich attack protection if err := sp.checkSandwichAttackRisk(params, check); err != nil { check.Warnings = append(check.Warnings, err.Error()) } check.MaxAllowedSlippage = params.MaxSlippage sp.logger.Debug(fmt.Sprintf("Slippage check completed: valid=%t, slippage=%.2f%%, impact=%.2f%%", check.IsValid, check.CalculatedSlippage, check.PriceImpact)) return check, nil } // validateInputParameters validates all input parameters func (sp *SlippageProtection) validateInputParameters(params *TradeParameters, check *SlippageCheck) error { // Validate addresses if err := sp.validator.ValidateCommonAddress(params.TokenIn); err != nil { check.Errors = append(check.Errors, fmt.Sprintf("Invalid TokenIn: %v", err)) check.IsValid = false } if err := sp.validator.ValidateCommonAddress(params.TokenOut); err != nil { check.Errors = append(check.Errors, fmt.Sprintf("Invalid TokenOut: %v", err)) check.IsValid = false } if err := sp.validator.ValidateCommonAddress(params.Pool); err != nil { check.Errors = append(check.Errors, fmt.Sprintf("Invalid Pool: %v", err)) check.IsValid = false } // Check for same token if params.TokenIn == params.TokenOut { check.Errors = append(check.Errors, "TokenIn and TokenOut cannot be the same") check.IsValid = false } // Validate amounts if err := sp.validator.ValidateBigInt(params.AmountIn, "AmountIn"); err != nil { check.Errors = append(check.Errors, fmt.Sprintf("Invalid AmountIn: %v", err)) check.IsValid = false } if err := sp.validator.ValidateBigInt(params.MinAmountOut, "MinAmountOut"); err != nil { check.Errors = append(check.Errors, fmt.Sprintf("Invalid MinAmountOut: %v", err)) check.IsValid = false } // Validate slippage tolerance if err := sp.validator.ValidateSlippageTolerance(params.MaxSlippage); err != nil { check.Errors = append(check.Errors, fmt.Sprintf("Invalid MaxSlippage: %v", err)) check.IsValid = false } // Validate deadline if err := sp.validator.ValidateDeadline(params.Deadline); err != nil { check.Errors = append(check.Errors, fmt.Sprintf("Invalid Deadline: %v", err)) check.IsValid = false } return nil } // calculateSlippage calculates the slippage percentage func (sp *SlippageProtection) calculateSlippage(params *TradeParameters) (float64, error) { if params.ExpectedPrice == nil { return 0, fmt.Errorf("expected price not provided") } // Calculate expected output based on expected price amountInFloat := new(big.Float).SetInt(params.AmountIn) expectedAmountOut := new(big.Float).Mul(amountInFloat, params.ExpectedPrice) // Convert to integer for comparison expectedAmountOutInt, _ := expectedAmountOut.Int(nil) // Calculate slippage percentage if expectedAmountOutInt.Cmp(big.NewInt(0)) == 0 { return 0, fmt.Errorf("expected amount out is zero") } // Slippage = (expected - minimum) / expected * 100 diff := new(big.Int).Sub(expectedAmountOutInt, params.MinAmountOut) slippageFloat := new(big.Float).Quo(new(big.Float).SetInt(diff), new(big.Float).SetInt(expectedAmountOutInt)) slippagePercent, _ := slippageFloat.Float64() return slippagePercent * 100, nil } // calculatePriceImpact calculates the price impact of the trade func (sp *SlippageProtection) calculatePriceImpact(params *TradeParameters) (float64, error) { if params.CurrentLiquidity == nil || params.CurrentLiquidity.Cmp(big.NewInt(0)) == 0 { return 0, fmt.Errorf("current liquidity not available") } // Simple price impact calculation: amount / liquidity * 100 // In practice, this would use more sophisticated AMM math amountFloat := new(big.Float).SetInt(params.AmountIn) liquidityFloat := new(big.Float).SetInt(params.CurrentLiquidity) impact := new(big.Float).Quo(amountFloat, liquidityFloat) impactPercent, _ := impact.Float64() return impactPercent * 100, nil } // checkLiquidity validates that sufficient liquidity exists func (sp *SlippageProtection) checkLiquidity(params *TradeParameters, check *SlippageCheck) error { if params.CurrentLiquidity == nil { return fmt.Errorf("liquidity information not available") } // Check minimum liquidity threshold if params.CurrentLiquidity.Cmp(sp.minimumLiquidity) < 0 { return fmt.Errorf("liquidity %s below minimum threshold %s", params.CurrentLiquidity.String(), sp.minimumLiquidity.String()) } // Check if trade size is reasonable relative to liquidity liquidityFloat := new(big.Float).SetInt(params.CurrentLiquidity) amountFloat := new(big.Float).SetInt(params.AmountIn) ratio := new(big.Float).Quo(amountFloat, liquidityFloat) ratioPercent, _ := ratio.Float64() if ratioPercent > 0.1 { // 10% of liquidity check.Warnings = append(check.Warnings, fmt.Sprintf("Trade size is %.2f%% of available liquidity", ratioPercent*100)) } return nil } // checkSandwichAttackRisk checks for potential sandwich attack risks func (sp *SlippageProtection) checkSandwichAttackRisk(params *TradeParameters, check *SlippageCheck) error { // Check if the trade is large enough to be a sandwich attack target liquidityFloat := new(big.Float).SetInt(params.CurrentLiquidity) amountFloat := new(big.Float).SetInt(params.AmountIn) ratio := new(big.Float).Quo(amountFloat, liquidityFloat) ratioPercent, _ := ratio.Float64() // Large trades are more susceptible to sandwich attacks if ratioPercent > 0.05 { // 5% of liquidity return fmt.Errorf("large trade size (%.2f%% of liquidity) may be vulnerable to sandwich attacks", ratioPercent*100) } // Check slippage tolerance - high tolerance increases sandwich risk if params.MaxSlippage > 1.0 { // 1% return fmt.Errorf("high slippage tolerance (%.2f%%) increases sandwich attack risk", params.MaxSlippage) } return nil } // AdjustForMarketConditions adjusts trade parameters based on current market conditions func (sp *SlippageProtection) AdjustForMarketConditions(params *TradeParameters, volatility float64) *TradeParameters { adjusted := *params // Copy parameters // Increase slippage tolerance during high volatility if volatility > 0.05 { // 5% volatility volatilityMultiplier := 1.0 + volatility adjusted.MaxSlippage = params.MaxSlippage * volatilityMultiplier // Cap at maximum allowed slippage if adjusted.MaxSlippage > sp.maxSlippagePercent { adjusted.MaxSlippage = sp.maxSlippagePercent } sp.logger.Info(fmt.Sprintf("Adjusted slippage tolerance to %.2f%% due to high volatility %.2f%%", adjusted.MaxSlippage, volatility*100)) } return &adjusted } // CreateSafeTradeParameters creates conservative trade parameters func (sp *SlippageProtection) CreateSafeTradeParameters( tokenIn, tokenOut, pool common.Address, amountIn *big.Int, expectedPrice *big.Float, currentLiquidity *big.Int, ) *TradeParameters { // Calculate minimum amount out with conservative slippage conservativeSlippage := 0.5 // 0.5% amountInFloat := new(big.Float).SetInt(amountIn) expectedAmountOut := new(big.Float).Mul(amountInFloat, expectedPrice) // Apply slippage buffer slippageMultiplier := new(big.Float).SetFloat64(1.0 - conservativeSlippage/100.0) minAmountOut := new(big.Float).Mul(expectedAmountOut, slippageMultiplier) minAmountOutInt, _ := minAmountOut.Int(nil) // Set deadline to 5 minutes from now deadline := uint64(time.Now().Add(5 * time.Minute).Unix()) return &TradeParameters{ TokenIn: tokenIn, TokenOut: tokenOut, AmountIn: amountIn, MinAmountOut: minAmountOutInt, MaxSlippage: conservativeSlippage, Deadline: deadline, Pool: pool, ExpectedPrice: expectedPrice, CurrentLiquidity: currentLiquidity, } } // GetEmergencyStopLoss returns the emergency stop loss threshold func (sp *SlippageProtection) GetEmergencyStopLoss() float64 { return sp.emergencyStopLoss } // SetMaxSlippage updates the maximum allowed slippage func (sp *SlippageProtection) SetMaxSlippage(maxSlippage float64) error { if err := sp.validator.ValidateSlippageTolerance(maxSlippage); err != nil { return err } sp.maxSlippagePercent = maxSlippage sp.logger.Info(fmt.Sprintf("Updated maximum slippage to %.2f%%", maxSlippage)) return nil }