package execution

import (
	"context"
	"fmt"
	"log/slog"
	"math/big"
	"sync"
	"time"

	"github.com/ethereum/go-ethereum"
	"github.com/ethereum/go-ethereum/common"
	"github.com/ethereum/go-ethereum/ethclient"
	"github.com/your-org/mev-bot/pkg/arbitrage"
)

// RiskManagerConfig contains configuration for risk management.
//
// The *big.Int limits are compared directly against opportunity and
// transaction amounts, so they must be non-nil and expressed in the same unit
// as those amounts (the defaults are written in wei).
type RiskManagerConfig struct {
	// Position limits
	MaxPositionSize    *big.Int // Maximum position size per trade
	MaxDailyVolume     *big.Int // Maximum daily trading volume
	MaxConcurrentTxs   int      // Maximum concurrent transactions
	MaxFailuresPerHour int      // Maximum failures before circuit breaker (not read by the code in this file)

	// Profit validation
	MinProfitAfterGas *big.Int // Minimum profit after gas costs
	MinROI            float64  // Minimum return on investment (e.g., 0.05 = 5%)

	// Slippage protection
	MaxSlippageBPS     uint16        // Maximum acceptable slippage in basis points
	SlippageCheckDelay time.Duration // Delay before execution to check for slippage (not read by the code in this file)

	// Gas limits
	MaxGasPrice *big.Int // Maximum gas price willing to pay
	MaxGasCost  *big.Int // Maximum gas cost per transaction

	// Circuit breaker
	CircuitBreakerEnabled   bool
	CircuitBreakerCooldown  time.Duration
	CircuitBreakerThreshold int // Number of failures (within the last hour) to trigger

	// Simulation
	SimulationEnabled bool
	SimulationTimeout time.Duration
}

// DefaultRiskManagerConfig returns the default risk management configuration.
func DefaultRiskManagerConfig() *RiskManagerConfig {
	const (
		ether = 1e18 // wei per ETH
		gwei  = 1e9  // wei per gwei
	)

	// scaled returns n*unit as a freshly allocated *big.Int.
	scaled := func(n, unit int64) *big.Int {
		return new(big.Int).Mul(big.NewInt(n), big.NewInt(unit))
	}

	return &RiskManagerConfig{
		MaxPositionSize:    scaled(10, ether),  // 10 ETH
		MaxDailyVolume:     scaled(100, ether), // 100 ETH
		MaxConcurrentTxs:   5,
		MaxFailuresPerHour: 10,

		MinProfitAfterGas: scaled(1, ether/100), // 0.01 ETH
		MinROI:            0.03,                 // 3%

		MaxSlippageBPS:     300, // 3%
		SlippageCheckDelay: 100 * time.Millisecond,

		MaxGasPrice: scaled(100, gwei),    // 100 gwei
		MaxGasCost:  scaled(5, ether/100), // 0.05 ETH

		CircuitBreakerEnabled:   true,
		CircuitBreakerCooldown:  10 * time.Minute,
		CircuitBreakerThreshold: 5,

		SimulationEnabled: true,
		SimulationTimeout: 5 * time.Second,
	}
}

// RiskManager manages execution risks: it gates opportunities through a set of
// limit checks, tracks in-flight transactions and daily volume, and trips a
// circuit breaker after repeated failures.
type RiskManager struct {
	config *RiskManagerConfig
	client *ethclient.Client
	logger *slog.Logger

	// State tracking. Every field below is guarded by mu.
	mu                  sync.RWMutex
	activeTxs           map[common.Hash]*ActiveTransaction
	dailyVolume         *big.Int
	dailyVolumeResetAt  time.Time
	recentFailures      []time.Time
	circuitBreakerOpen  bool
	circuitBreakerUntil time.Time
}

// ActiveTransaction tracks an active transaction.
type ActiveTransaction struct {
	Hash         common.Hash
	Opportunity  *arbitrage.Opportunity
	SubmittedAt  time.Time
	GasPrice     *big.Int
	ExpectedCost *big.Int
}

// NewRiskManager creates a new risk manager.
//
// A nil config is replaced by DefaultRiskManagerConfig and a nil logger by
// slog.Default. The client is used only by SimulateExecution.
func NewRiskManager(
	config *RiskManagerConfig,
	client *ethclient.Client,
	logger *slog.Logger,
) *RiskManager {
	if config == nil {
		config = DefaultRiskManagerConfig()
	}
	if logger == nil {
		// logger.With on a nil *slog.Logger would panic.
		logger = slog.Default()
	}

	return &RiskManager{
		config:             config,
		client:             client,
		logger:             logger.With("component", "risk_manager"),
		activeTxs:          make(map[common.Hash]*ActiveTransaction),
		dailyVolume:        big.NewInt(0),
		dailyVolumeResetAt: time.Now().Add(24 * time.Hour),
		recentFailures:     make([]time.Time, 0),
	}
}

// RiskAssessment contains the result of risk assessment.
type RiskAssessment struct {
	Approved         bool              // False when any hard check rejected the trade
	Reason           string            // Human-readable rejection reason; empty when approved
	Warnings         []string          // Non-fatal findings collected during assessment
	SimulationResult *SimulationResult // Set only when simulation ran without returning an error
}

// SimulationResult contains simulation results.
type SimulationResult struct {
	Success        bool
	ActualOutput   *big.Int // First 32 bytes of the call's return data, if present
	GasUsed        uint64   // Currently the transaction's gas limit, not a measured value
	Revert         bool
	RevertReason   string
	SlippageActual float64 // (expected - actual) / expected, as a fraction
}

// AssessRisk performs comprehensive risk assessment of opp as executed by tx.
//
// Checks run in a fixed order (circuit breaker, concurrency, position size,
// daily volume, gas price, gas cost, minimum profit, minimum ROI, slippage,
// simulation) and the first failing check rejects the trade: the returned
// assessment has Approved == false and Reason set, with a nil error.
//
// A non-nil error is returned only for unusable input: a nil opp or tx, or nil
// amount fields that the checks would otherwise dereference.
func (rm *RiskManager) AssessRisk(
	ctx context.Context,
	opp *arbitrage.Opportunity,
	tx *SwapTransaction,
) (*RiskAssessment, error) {
	switch {
	case opp == nil:
		return nil, fmt.Errorf("assess risk: nil opportunity")
	case tx == nil:
		return nil, fmt.Errorf("assess risk: nil transaction")
	case opp.InputAmount == nil || opp.NetProfit == nil:
		return nil, fmt.Errorf("assess risk: opportunity %v has nil input amount or net profit", opp.ID)
	case tx.MaxFeePerGas == nil:
		return nil, fmt.Errorf("assess risk: transaction for opportunity %v has nil max fee per gas", opp.ID)
	}

	rm.logger.Debug("assessing risk",
		"opportunityID", opp.ID,
		"inputAmount", opp.InputAmount.String(),
	)

	assessment := &RiskAssessment{
		Approved: true,
		Warnings: make([]string, 0),
	}

	// reject marks the assessment as denied with a formatted reason.
	reject := func(format string, args ...any) (*RiskAssessment, error) {
		assessment.Approved = false
		assessment.Reason = fmt.Sprintf(format, args...)
		return assessment, nil
	}

	// Check circuit breaker. The deadline is returned by the same locked call
	// that made the decision, so it is never read without the mutex.
	if allowed, until := rm.circuitBreakerState(); !allowed {
		return reject("circuit breaker open until %s", until.Format(time.RFC3339))
	}

	// Check concurrent transactions
	if !rm.checkConcurrentLimit() {
		return reject("concurrent transaction limit reached: %d", rm.config.MaxConcurrentTxs)
	}

	// Check position size
	if !rm.checkPositionSize(opp.InputAmount) {
		return reject("position size %s exceeds limit %s",
			opp.InputAmount.String(), rm.config.MaxPositionSize.String())
	}

	// Check daily volume
	if !rm.checkDailyVolume(opp.InputAmount) {
		return reject("daily volume limit reached: %s", rm.config.MaxDailyVolume.String())
	}

	// Check gas price
	if !rm.checkGasPrice(tx.MaxFeePerGas) {
		return reject("gas price %s exceeds limit %s",
			tx.MaxFeePerGas.String(), rm.config.MaxGasPrice.String())
	}

	// Check gas cost. SetUint64 is used because converting the gas limit to
	// int64 would turn values above math.MaxInt64 into a negative cost that
	// trivially passes the limit.
	gasCost := new(big.Int).Mul(tx.MaxFeePerGas, new(big.Int).SetUint64(tx.GasLimit))
	if !rm.checkGasCost(gasCost) {
		return reject("gas cost %s exceeds limit %s",
			gasCost.String(), rm.config.MaxGasCost.String())
	}

	// Check minimum profit
	if !rm.checkMinProfit(opp.NetProfit) {
		return reject("profit %s below minimum %s",
			opp.NetProfit.String(), rm.config.MinProfitAfterGas.String())
	}

	// Check minimum ROI
	if !rm.checkMinROI(opp.ROI) {
		return reject("ROI %.2f%% below minimum %.2f%%", opp.ROI*100, rm.config.MinROI*100)
	}

	// Check slippage
	if !rm.checkSlippage(tx.Slippage) {
		return reject("slippage %d bps exceeds limit %d bps", tx.Slippage, rm.config.MaxSlippageBPS)
	}

	// Simulate execution
	if rm.config.SimulationEnabled {
		simResult, err := rm.SimulateExecution(ctx, tx)
		if err != nil {
			// An error from the simulator itself is only a warning; a
			// simulated revert (below) is a hard rejection.
			assessment.Warnings = append(assessment.Warnings,
				fmt.Sprintf("simulation failed: %v", err))
		} else {
			assessment.SimulationResult = simResult

			if !simResult.Success || simResult.Revert {
				return reject("simulation failed: %s", simResult.RevertReason)
			}

			// Excessive simulated slippage is reported but does not reject.
			if simResult.SlippageActual > float64(rm.config.MaxSlippageBPS)/10000.0 {
				assessment.Warnings = append(assessment.Warnings,
					fmt.Sprintf("high slippage detected: %.2f%%", simResult.SlippageActual*100))
			}
		}
	}

	rm.logger.Info("risk assessment passed",
		"opportunityID", opp.ID,
		"warnings", len(assessment.Warnings),
	)

	return assessment, nil
}

// SimulateExecution simulates the transaction execution with a read-only
// contract call against the latest block, bounded by SimulationTimeout.
//
// Any failure of the call, including transport errors and timeouts, is
// reported as a reverted SimulationResult with a nil error, so callers treat
// it as a rejection. A non-nil error is returned only for a nil tx.
//
// NOTE(review): the call message carries no From address and passes
// MaxFeePerGas as a legacy GasPrice; confirm this matches how the transaction
// is actually sent.
func (rm *RiskManager) SimulateExecution(
	ctx context.Context,
	tx *SwapTransaction,
) (*SimulationResult, error) {
	if tx == nil {
		return nil, fmt.Errorf("simulate execution: nil transaction")
	}

	rm.logger.Debug("simulating execution",
		"to", tx.To.Hex(),
		"gasLimit", tx.GasLimit,
	)

	simCtx, cancel := context.WithTimeout(ctx, rm.config.SimulationTimeout)
	defer cancel()

	// Create call message
	msg := ethereum.CallMsg{
		To:       &tx.To,
		Gas:      tx.GasLimit,
		GasPrice: tx.MaxFeePerGas,
		Value:    tx.Value,
		Data:     tx.Data,
	}

	// Execute simulation (nil block number = latest block).
	result, err := rm.client.CallContract(simCtx, msg, nil)
	if err != nil {
		return &SimulationResult{
			Success:      false,
			Revert:       true,
			RevertReason: err.Error(),
		}, nil
	}

	// Decode result (assuming it returns output amount in the first word).
	var actualOutput *big.Int
	if len(result) >= 32 {
		actualOutput = new(big.Int).SetBytes(result[:32])
	}

	// Calculate actual slippage as the fractional shortfall versus the
	// expected output. Skipped when either amount is unavailable or the
	// expected output is not positive (which would divide by zero).
	var slippageActual float64
	if o := tx.Opportunity; o != nil && actualOutput != nil &&
		o.OutputAmount != nil && o.OutputAmount.Sign() > 0 {
		expected := new(big.Float).SetInt(o.OutputAmount)
		shortfall := new(big.Float).Sub(expected, new(big.Float).SetInt(actualOutput))
		// The accuracy result is ignored: a rounded ratio is sufficient here.
		slippageActual, _ = new(big.Float).Quo(shortfall, expected).Float64()
	}

	return &SimulationResult{
		Success:        true,
		ActualOutput:   actualOutput,
		GasUsed:        tx.GasLimit, // Estimate
		Revert:         false,
		SlippageActual: slippageActual,
	}, nil
}

// TrackTransaction tracks an active transaction and adds the opportunity's
// input amount to the daily volume.
//
// NOTE(review): ExpectedCost is gasPrice multiplied by opp.GasCost; that is
// only a cost if GasCost holds gas units rather than wei. Confirm against
// arbitrage.Opportunity.
func (rm *RiskManager) TrackTransaction(hash common.Hash, opp *arbitrage.Opportunity, gasPrice *big.Int) {
	rm.mu.Lock()
	defer rm.mu.Unlock()

	rm.activeTxs[hash] = &ActiveTransaction{
		Hash:        hash,
		Opportunity: opp,
		SubmittedAt: time.Now(),
		GasPrice:    gasPrice,
		// SetUint64 avoids the sign flip an int64 conversion would cause for
		// values above math.MaxInt64.
		ExpectedCost: new(big.Int).Mul(gasPrice, new(big.Int).SetUint64(opp.GasCost.Uint64())),
	}

	// Update daily volume
	rm.updateDailyVolume(opp.InputAmount)

	rm.logger.Debug("tracking transaction",
		"hash", hash.Hex(),
		"opportunityID", opp.ID,
	)
}

// UntrackTransaction removes a transaction from tracking. Removing an unknown
// hash is a no-op.
func (rm *RiskManager) UntrackTransaction(hash common.Hash) {
	rm.mu.Lock()
	defer rm.mu.Unlock()

	delete(rm.activeTxs, hash)

	rm.logger.Debug("untracked transaction", "hash", hash.Hex())
}

// RecordFailure records a transaction failure and opens the circuit breaker
// once the number of failures in the last hour reaches
// CircuitBreakerThreshold.
func (rm *RiskManager) RecordFailure(hash common.Hash, reason string) {
	rm.mu.Lock()
	defer rm.mu.Unlock()

	now := time.Now()
	rm.recentFailures = append(rm.recentFailures, now)

	// Drop failures older than one hour, filtering in place so the backing
	// array is reused instead of reallocated on every failure.
	cutoff := now.Add(-1 * time.Hour)
	kept := rm.recentFailures[:0]
	for _, at := range rm.recentFailures {
		if at.After(cutoff) {
			kept = append(kept, at)
		}
	}
	rm.recentFailures = kept

	rm.logger.Warn("recorded failure",
		"hash", hash.Hex(),
		"reason", reason,
		"recentFailures", len(rm.recentFailures),
	)

	// Check if we should open circuit breaker
	if rm.config.CircuitBreakerEnabled && len(rm.recentFailures) >= rm.config.CircuitBreakerThreshold {
		rm.openCircuitBreaker()
	}
}

// RecordSuccess records a successful transaction. It only logs; no manager
// state is read or written, so the mutex is not taken.
func (rm *RiskManager) RecordSuccess(hash common.Hash, actualProfit *big.Int) {
	rm.logger.Info("recorded success",
		"hash", hash.Hex(),
		"actualProfit", actualProfit.String(),
	)
}

// openCircuitBreaker opens the circuit breaker for CircuitBreakerCooldown.
// The caller must hold rm.mu for writing.
func (rm *RiskManager) openCircuitBreaker() {
	rm.circuitBreakerOpen = true
	rm.circuitBreakerUntil = time.Now().Add(rm.config.CircuitBreakerCooldown)

	rm.logger.Error("circuit breaker opened",
		"failures", len(rm.recentFailures),
		"cooldown", rm.config.CircuitBreakerCooldown,
		"until", rm.circuitBreakerUntil,
	)
}

// checkCircuitBreaker checks if circuit breaker allows execution.
func (rm *RiskManager) checkCircuitBreaker() bool {
	allowed, _ := rm.circuitBreakerState()
	return allowed
}

// circuitBreakerState reports whether execution is allowed and, when it is
// not, the time until which the breaker stays open.
//
// An expired breaker is reset here. The decision and the reset happen in one
// exclusive critical section: releasing a read lock to take the write lock
// would let a concurrent RecordFailure re-open the breaker in between, and the
// reset would then wrongly close it again.
func (rm *RiskManager) circuitBreakerState() (allowed bool, until time.Time) {
	if !rm.config.CircuitBreakerEnabled {
		return true, time.Time{}
	}

	rm.mu.Lock()
	defer rm.mu.Unlock()

	if !rm.circuitBreakerOpen {
		return true, time.Time{}
	}
	if time.Now().After(rm.circuitBreakerUntil) {
		// Cooldown elapsed: reset circuit breaker and forget old failures.
		rm.circuitBreakerOpen = false
		rm.recentFailures = rm.recentFailures[:0]
		rm.logger.Info("circuit breaker reset")
		return true, time.Time{}
	}
	return false, rm.circuitBreakerUntil
}

// checkConcurrentLimit checks concurrent transaction limit.
func (rm *RiskManager) checkConcurrentLimit() bool {
	rm.mu.RLock()
	defer rm.mu.RUnlock()

	return len(rm.activeTxs) < rm.config.MaxConcurrentTxs
}

// checkPositionSize checks position size limit.
func (rm *RiskManager) checkPositionSize(amount *big.Int) bool {
	return amount.Cmp(rm.config.MaxPositionSize) <= 0
}

// checkDailyVolume reports whether adding amount keeps the current 24-hour
// window at or below MaxDailyVolume. It does not record the amount.
//
// The window rollover mutates state, so the whole check runs under the write
// lock rather than upgrading from a read lock.
func (rm *RiskManager) checkDailyVolume(amount *big.Int) bool {
	rm.mu.Lock()
	defer rm.mu.Unlock()

	rm.rollDailyVolumeLocked(time.Now())

	projected := new(big.Int).Add(rm.dailyVolume, amount)
	return projected.Cmp(rm.config.MaxDailyVolume) <= 0
}

// rollDailyVolumeLocked starts a fresh 24-hour volume window if the current
// one has expired. The caller must hold rm.mu for writing.
func (rm *RiskManager) rollDailyVolumeLocked(now time.Time) {
	if now.After(rm.dailyVolumeResetAt) {
		rm.dailyVolume = big.NewInt(0)
		rm.dailyVolumeResetAt = now.Add(24 * time.Hour)
	}
}

// updateDailyVolume updates the daily volume counter. The caller must hold
// rm.mu for writing.
func (rm *RiskManager) updateDailyVolume(amount *big.Int) {
	// Roll first so volume is never credited to an already expired window.
	rm.rollDailyVolumeLocked(time.Now())
	rm.dailyVolume.Add(rm.dailyVolume, amount)
}

// checkGasPrice checks gas price limit.
func (rm *RiskManager) checkGasPrice(gasPrice *big.Int) bool {
	return gasPrice.Cmp(rm.config.MaxGasPrice) <= 0
}

// checkGasCost checks gas cost limit.
func (rm *RiskManager) checkGasCost(gasCost *big.Int) bool {
	return gasCost.Cmp(rm.config.MaxGasCost) <= 0
}

// checkMinProfit checks minimum profit requirement.
func (rm *RiskManager) checkMinProfit(profit *big.Int) bool {
	return profit.Cmp(rm.config.MinProfitAfterGas) >= 0
}

// checkMinROI checks minimum ROI requirement.
func (rm *RiskManager) checkMinROI(roi float64) bool {
	return roi >= rm.config.MinROI
}

// checkSlippage checks slippage limit.
func (rm *RiskManager) checkSlippage(slippageBPS uint16) bool {
	return slippageBPS <= rm.config.MaxSlippageBPS
}

// GetActiveTransactions returns all active transactions in unspecified order.
// The slice is new, but its elements are the tracked records themselves.
func (rm *RiskManager) GetActiveTransactions() []*ActiveTransaction {
	rm.mu.RLock()
	defer rm.mu.RUnlock()

	txs := make([]*ActiveTransaction, 0, len(rm.activeTxs))
	for _, tx := range rm.activeTxs {
		txs = append(txs, tx)
	}

	return txs
}

// GetStats returns risk management statistics as a snapshot taken under the
// read lock. It does not roll the daily volume window or reset an expired
// circuit breaker, so those values reflect the last mutation.
func (rm *RiskManager) GetStats() map[string]any {
	rm.mu.RLock()
	defer rm.mu.RUnlock()

	return map[string]any{
		"active_transactions":   len(rm.activeTxs),
		"daily_volume":          rm.dailyVolume.String(),
		"recent_failures":       len(rm.recentFailures),
		"circuit_breaker_open":  rm.circuitBreakerOpen,
		"circuit_breaker_until": rm.circuitBreakerUntil.Format(time.RFC3339),
	}
}