Files
mev-beta/pkg/execution/risk_manager.go
Administrator 688311f1e0 fix(compilation): resolve type system and interface errors
- Add GetPoolsByToken method to cache interface and implementation
- Fix interface pointer types (use interface not *interface)
- Fix SwapEvent.TokenIn/TokenOut usage to use GetInputToken/GetOutputToken methods
- Fix ethereum.CallMsg import and usage
- Fix parser factory and validator initialization in main.go
- Remove unused variables and imports

WIP: Still fixing main.go config struct field mismatches

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
2025-11-10 19:46:06 +01:00

500 lines
14 KiB
Go

package execution
import (
"context"
"fmt"
"log/slog"
"math/big"
"sync"
"time"
"github.com/ethereum/go-ethereum"
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/ethclient"
"github.com/your-org/mev-bot/pkg/arbitrage"
)
// RiskManagerConfig contains configuration for risk management.
//
// Monetary limits are *big.Int amounts in wei (DefaultRiskManagerConfig builds
// them from 1e18 = 1 ETH and 1e9 = 1 gwei). The *big.Int limits are compared
// with Cmp and must therefore be non-nil. NewRiskManager only substitutes
// defaults when the whole config is nil, so start from
// DefaultRiskManagerConfig and override individual fields.
type RiskManagerConfig struct {
// Position limits
MaxPositionSize *big.Int // Maximum input amount (opp.InputAmount) per trade, inclusive
MaxDailyVolume *big.Int // Maximum summed input amount within the 24h volume window, inclusive
MaxConcurrentTxs int // Maximum tracked (in-flight) transactions; a new trade needs strictly fewer than this
MaxFailuresPerHour int // NOTE(review): not read by RiskManager in this file; the breaker trips on CircuitBreakerThreshold failures within 1h
// Profit validation
MinProfitAfterGas *big.Int // Minimum opp.NetProfit required, inclusive
MinROI float64 // Minimum return on investment as a fraction (e.g., 0.05 = 5%), inclusive
// Slippage protection
MaxSlippageBPS uint16 // Maximum acceptable slippage in basis points (100 bps = 1%)
SlippageCheckDelay time.Duration // NOTE(review): not read by RiskManager in this file
// Gas limits
MaxGasPrice *big.Int // Maximum MaxFeePerGas willing to pay, in wei per gas
MaxGasCost *big.Int // Maximum MaxFeePerGas * GasLimit per transaction, in wei
// Circuit breaker
CircuitBreakerEnabled bool // When false, recorded failures never block execution
CircuitBreakerCooldown time.Duration // How long the breaker stays open once tripped
CircuitBreakerThreshold int // Failures within the last hour that trip the breaker
// Simulation
SimulationEnabled bool // Simulate the transaction (CallContract) before approving
SimulationTimeout time.Duration // Deadline applied to each simulation call
}
// DefaultRiskManagerConfig returns a conservative baseline configuration.
// Every monetary limit is expressed in wei; callers typically copy the result
// and override the fields they care about.
func DefaultRiskManagerConfig() *RiskManagerConfig {
	// scaled returns n * unit as a fresh big.Int (1e18 = 1 ETH, 1e16 = 0.01 ETH, 1e9 = 1 gwei).
	scaled := func(n, unit int64) *big.Int {
		return new(big.Int).Mul(big.NewInt(n), big.NewInt(unit))
	}

	cfg := &RiskManagerConfig{}

	// Position limits.
	cfg.MaxPositionSize = scaled(10, 1e18) // 10 ETH
	cfg.MaxDailyVolume = scaled(100, 1e18) // 100 ETH
	cfg.MaxConcurrentTxs = 5
	cfg.MaxFailuresPerHour = 10

	// Profit validation.
	cfg.MinProfitAfterGas = scaled(1, 1e16) // 0.01 ETH
	cfg.MinROI = 0.03                       // 3%

	// Slippage protection.
	cfg.MaxSlippageBPS = 300 // 3%
	cfg.SlippageCheckDelay = 100 * time.Millisecond

	// Gas limits.
	cfg.MaxGasPrice = scaled(100, 1e9) // 100 gwei
	cfg.MaxGasCost = scaled(5, 1e16)   // 0.05 ETH

	// Circuit breaker.
	cfg.CircuitBreakerEnabled = true
	cfg.CircuitBreakerCooldown = 10 * time.Minute
	cfg.CircuitBreakerThreshold = 5

	// Simulation.
	cfg.SimulationEnabled = true
	cfg.SimulationTimeout = 5 * time.Second

	return cfg
}
// RiskManager manages execution risks.
//
// It enforces the limits in RiskManagerConfig, tracks in-flight transactions
// and daily volume, and runs a failure-driven circuit breaker. config, client
// and logger are set by NewRiskManager. The state-tracking fields below are
// meant to be accessed while holding mu.
type RiskManager struct {
config *RiskManagerConfig
client *ethclient.Client // used by SimulateExecution (CallContract)
logger *slog.Logger
// State tracking
mu sync.RWMutex
activeTxs map[common.Hash]*ActiveTransaction // in-flight transactions keyed by tx hash
dailyVolume *big.Int // summed InputAmount of tracked transactions in the current window
dailyVolumeResetAt time.Time // instant after which dailyVolume is zeroed (24h window)
recentFailures []time.Time // failure timestamps, pruned to the last hour by RecordFailure
circuitBreakerOpen bool
circuitBreakerUntil time.Time // breaker blocks execution until this instant
}
// ActiveTransaction tracks an active transaction.
//
// Entries are created by TrackTransaction and removed by UntrackTransaction;
// their count is what MaxConcurrentTxs limits.
type ActiveTransaction struct {
Hash common.Hash
Opportunity *arbitrage.Opportunity
SubmittedAt time.Time // when TrackTransaction was called
GasPrice *big.Int // gas price passed to TrackTransaction
ExpectedCost *big.Int // GasPrice * Opportunity.GasCost as computed by TrackTransaction; NOTE(review): confirm GasCost units
}
// NewRiskManager creates a RiskManager.
//
// A nil config falls back to DefaultRiskManagerConfig(), and a nil logger falls
// back to slog.Default(); calling logger.With on a nil *slog.Logger would panic.
// client is used for transaction simulation. The daily-volume window starts
// now and lasts 24 hours.
func NewRiskManager(
	config *RiskManagerConfig,
	client *ethclient.Client,
	logger *slog.Logger,
) *RiskManager {
	if config == nil {
		config = DefaultRiskManagerConfig()
	}
	if logger == nil {
		logger = slog.Default()
	}

	return &RiskManager{
		config:             config,
		client:             client,
		logger:             logger.With("component", "risk_manager"),
		activeTxs:          make(map[common.Hash]*ActiveTransaction),
		dailyVolume:        big.NewInt(0),
		dailyVolumeResetAt: time.Now().Add(24 * time.Hour),
		recentFailures:     make([]time.Time, 0),
	}
}
// RiskAssessment contains the result of risk assessment.
type RiskAssessment struct {
Approved bool // true only if every check passed
Reason string // why the trade was rejected; empty when Approved
Warnings []string // non-fatal findings (simulation errors, high simulated slippage)
SimulationResult *SimulationResult // nil when simulation is disabled, returned an error, or an earlier check rejected the trade
}
// SimulationResult contains simulation results.
type SimulationResult struct {
Success bool
ActualOutput *big.Int // first 32 bytes of the call's return data read as an unsigned integer; nil if fewer than 32 bytes came back
GasUsed uint64 // currently the transaction's GasLimit, not measured gas usage
Revert bool // set when the simulated call returned any error (not only an EVM revert)
RevertReason string // text of that error
SlippageActual float64 // (expected - actual) / expected output as a fraction; 0 when not computable, negative when output beat expectations
}
// AssessRisk performs a comprehensive pre-execution risk assessment of opp,
// to be executed via tx.
//
// Limit checks run in a fixed order, and the first failure rejects the trade:
// circuit breaker, concurrent transactions, position size, daily volume, gas
// price, gas cost (MaxFeePerGas * GasLimit), minimum net profit, minimum ROI
// and requested slippage. When simulation is enabled, the transaction is then
// simulated. A failed or reverted simulation rejects the trade; a simulation
// error or high simulated slippage only adds a warning.
//
// A rejection is reported through Approved=false and Reason with a nil error.
// A non-nil error is returned only for unusable input: a nil opp or tx, or a
// nil amount that the checks need to compare.
func (rm *RiskManager) AssessRisk(
	ctx context.Context,
	opp *arbitrage.Opportunity,
	tx *SwapTransaction,
) (*RiskAssessment, error) {
	if err := rm.assessmentInputError(opp, tx); err != nil {
		return nil, err
	}

	rm.logger.Debug("assessing risk",
		"opportunityID", opp.ID,
		"inputAmount", opp.InputAmount.String(),
	)

	assessment := &RiskAssessment{
		Approved: true,
		Warnings: make([]string, 0),
	}

	if reason := rm.firstLimitViolation(opp, tx); reason != "" {
		assessment.Approved = false
		assessment.Reason = reason
		return assessment, nil
	}

	// Simulate execution
	if rm.config.SimulationEnabled {
		simResult, err := rm.SimulateExecution(ctx, tx)
		if err != nil {
			assessment.Warnings = append(assessment.Warnings, fmt.Sprintf("simulation failed: %v", err))
		} else {
			assessment.SimulationResult = simResult
			if !simResult.Success || simResult.Revert {
				assessment.Approved = false
				assessment.Reason = fmt.Sprintf("simulation failed: %s", simResult.RevertReason)
				return assessment, nil
			}
			// Simulated slippage above the configured limit is reported but does
			// not reject the trade.
			if simResult.SlippageActual > float64(rm.config.MaxSlippageBPS)/10000.0 {
				assessment.Warnings = append(assessment.Warnings,
					fmt.Sprintf("high slippage detected: %.2f%%", simResult.SlippageActual*100))
			}
		}
	}

	rm.logger.Info("risk assessment passed",
		"opportunityID", opp.ID,
		"warnings", len(assessment.Warnings),
	)

	return assessment, nil
}

// assessmentInputError rejects nil inputs that would otherwise make the limit
// checks panic (big.Int.Cmp dereferences both operands).
func (rm *RiskManager) assessmentInputError(opp *arbitrage.Opportunity, tx *SwapTransaction) error {
	switch {
	case opp == nil:
		return fmt.Errorf("assess risk: nil opportunity")
	case tx == nil:
		return fmt.Errorf("assess risk: nil swap transaction")
	case opp.InputAmount == nil:
		return fmt.Errorf("assess risk: opportunity %v has nil input amount", opp.ID)
	case opp.NetProfit == nil:
		return fmt.Errorf("assess risk: opportunity %v has nil net profit", opp.ID)
	case tx.MaxFeePerGas == nil:
		return fmt.Errorf("assess risk: opportunity %v has nil max fee per gas", opp.ID)
	}
	return nil
}

// firstLimitViolation runs the static limit checks in order. It returns the
// rejection reason for the first check that fails, or "" if all pass.
func (rm *RiskManager) firstLimitViolation(opp *arbitrage.Opportunity, tx *SwapTransaction) string {
	if !rm.checkCircuitBreaker() {
		// Read the deadline under the lock; RecordFailure may update it concurrently.
		rm.mu.RLock()
		until := rm.circuitBreakerUntil
		rm.mu.RUnlock()
		return fmt.Sprintf("circuit breaker open until %s", until.Format(time.RFC3339))
	}

	if !rm.checkConcurrentLimit() {
		return fmt.Sprintf("concurrent transaction limit reached: %d", rm.config.MaxConcurrentTxs)
	}

	if !rm.checkPositionSize(opp.InputAmount) {
		return fmt.Sprintf("position size %s exceeds limit %s", opp.InputAmount.String(), rm.config.MaxPositionSize.String())
	}

	if !rm.checkDailyVolume(opp.InputAmount) {
		return fmt.Sprintf("daily volume limit reached: %s", rm.config.MaxDailyVolume.String())
	}

	if !rm.checkGasPrice(tx.MaxFeePerGas) {
		return fmt.Sprintf("gas price %s exceeds limit %s", tx.MaxFeePerGas.String(), rm.config.MaxGasPrice.String())
	}

	// GasLimit is a uint64. SetUint64 avoids the wraparound that int64(GasLimit)
	// produces above math.MaxInt64, where a negative cost would pass the check.
	gasCost := new(big.Int).Mul(tx.MaxFeePerGas, new(big.Int).SetUint64(tx.GasLimit))
	if !rm.checkGasCost(gasCost) {
		return fmt.Sprintf("gas cost %s exceeds limit %s", gasCost.String(), rm.config.MaxGasCost.String())
	}

	if !rm.checkMinProfit(opp.NetProfit) {
		return fmt.Sprintf("profit %s below minimum %s", opp.NetProfit.String(), rm.config.MinProfitAfterGas.String())
	}

	if !rm.checkMinROI(opp.ROI) {
		return fmt.Sprintf("ROI %.2f%% below minimum %.2f%%", opp.ROI*100, rm.config.MinROI*100)
	}

	if !rm.checkSlippage(tx.Slippage) {
		return fmt.Sprintf("slippage %d bps exceeds limit %d bps", tx.Slippage, rm.config.MaxSlippageBPS)
	}

	return ""
}
// SimulateExecution simulates tx with a CallContract call against the latest
// block.
//
// Any call error is reported as a reverted SimulationResult (with the error
// text as RevertReason) rather than as a Go error; a Go error is returned only
// for a nil tx. The first 32 bytes of return data are decoded as the output
// amount. When tx.Opportunity carries a positive expected OutputAmount, the
// fractional slippage (expected - actual) / expected is computed. GasUsed is
// the transaction's GasLimit, not a measured value.
//
// SimulationTimeout bounds the call when it is positive. A zero or negative
// timeout no longer produces an already-expired context, which used to make
// every simulation look like a revert.
func (rm *RiskManager) SimulateExecution(
	ctx context.Context,
	tx *SwapTransaction,
) (*SimulationResult, error) {
	if tx == nil {
		return nil, fmt.Errorf("simulate execution: nil swap transaction")
	}

	rm.logger.Debug("simulating execution",
		"to", tx.To.Hex(),
		"gasLimit", tx.GasLimit,
	)

	simCtx := ctx
	if rm.config.SimulationTimeout > 0 {
		var cancel context.CancelFunc
		simCtx, cancel = context.WithTimeout(ctx, rm.config.SimulationTimeout)
		defer cancel()
	}

	// NOTE(review): no From is set, so the call runs from the zero address. With
	// a non-zero GasPrice/Value some nodes check the sender's balance. Confirm
	// this matches how the target contract is expected to be called.
	msg := ethereum.CallMsg{
		To:       &tx.To,
		Gas:      tx.GasLimit,
		GasPrice: tx.MaxFeePerGas,
		Value:    tx.Value,
		Data:     tx.Data,
	}

	result, err := rm.client.CallContract(simCtx, msg, nil)
	if err != nil {
		return &SimulationResult{
			Success:      false,
			Revert:       true,
			RevertReason: err.Error(),
		}, nil
	}

	// Decode result (assuming the call returns the output amount as its first word).
	var actualOutput *big.Int
	if len(result) >= 32 {
		actualOutput = new(big.Int).SetBytes(result[:32])
	}

	// Guard every operand: a nil OutputAmount used to panic on Sign().
	var slippageActual float64
	if opp := tx.Opportunity; opp != nil && opp.OutputAmount != nil && actualOutput != nil && opp.OutputAmount.Sign() > 0 {
		expected := new(big.Float).SetInt(opp.OutputAmount)
		shortfall := new(big.Float).Sub(expected, new(big.Float).SetInt(actualOutput))
		// Float64 rounding accuracy is irrelevant for a slippage ratio.
		slippageActual, _ = new(big.Float).Quo(shortfall, expected).Float64()
	}

	return &SimulationResult{
		Success:        true,
		ActualOutput:   actualOutput,
		GasUsed:        tx.GasLimit, // Estimate
		Revert:         false,
		SlippageActual: slippageActual,
	}, nil
}
// TrackTransaction records hash as an in-flight transaction for opp, submitted
// at gasPrice, and adds opp.InputAmount to the daily volume.
//
// Tracking the same hash again replaces its entry without counting its volume
// a second time. ExpectedCost is gasPrice * opp.GasCost and is left nil when
// either operand is nil. A nil opp is still tracked, so it counts toward the
// concurrency limit, but it contributes no volume. gasPrice is copied, so the
// caller may reuse its big.Int.
func (rm *RiskManager) TrackTransaction(hash common.Hash, opp *arbitrage.Opportunity, gasPrice *big.Int) {
	rm.mu.Lock()
	defer rm.mu.Unlock()

	_, alreadyTracked := rm.activeTxs[hash]

	active := &ActiveTransaction{
		Hash:        hash,
		Opportunity: opp,
		SubmittedAt: time.Now(),
	}
	if gasPrice != nil {
		active.GasPrice = new(big.Int).Set(gasPrice)
	}

	if opp == nil {
		rm.activeTxs[hash] = active
		rm.logger.Warn("tracking transaction without opportunity", "hash", hash.Hex())
		return
	}

	// Multiply by the big.Int directly: the former GasCost.Uint64() -> int64
	// conversion silently truncated and wrapped large values.
	// NOTE(review): if opp.GasCost is already a wei amount rather than a gas
	// quantity, gasPrice * GasCost is not a cost; confirm the field's units.
	if active.GasPrice != nil && opp.GasCost != nil {
		active.ExpectedCost = new(big.Int).Mul(active.GasPrice, opp.GasCost)
	}
	rm.activeTxs[hash] = active

	// Update daily volume (once per hash).
	if !alreadyTracked && opp.InputAmount != nil {
		rm.updateDailyVolume(opp.InputAmount)
	}

	rm.logger.Debug("tracking transaction",
		"hash", hash.Hex(),
		"opportunityID", opp.ID,
	)
}
// UntrackTransaction forgets the in-flight transaction identified by hash,
// freeing its slot under MaxConcurrentTxs. Removing an unknown hash is a no-op.
func (rm *RiskManager) UntrackTransaction(hash common.Hash) {
	rm.mu.Lock()
	defer rm.mu.Unlock()

	delete(rm.activeTxs, hash)

	txHex := hash.Hex()
	rm.logger.Debug("untracked transaction", "hash", txHex)
}
// RecordFailure notes a failed transaction, prunes failures older than one
// hour, and trips the circuit breaker once the remaining count reaches
// CircuitBreakerThreshold (when the breaker is enabled).
func (rm *RiskManager) RecordFailure(hash common.Hash, reason string) {
	rm.mu.Lock()
	defer rm.mu.Unlock()

	rm.recentFailures = append(rm.recentFailures, time.Now())

	// Filter in place: only failures inside the last hour count toward the
	// threshold, and the slice is never shared outside the manager.
	windowStart := time.Now().Add(-time.Hour)
	kept := rm.recentFailures[:0]
	for _, at := range rm.recentFailures {
		if at.After(windowStart) {
			kept = append(kept, at)
		}
	}
	rm.recentFailures = kept

	failureCount := len(rm.recentFailures)
	rm.logger.Warn("recorded failure",
		"hash", hash.Hex(),
		"reason", reason,
		"recentFailures", failureCount,
	)

	if !rm.config.CircuitBreakerEnabled {
		return
	}
	if failureCount >= rm.config.CircuitBreakerThreshold {
		rm.openCircuitBreaker()
	}
}
// RecordSuccess logs a successful transaction and its realized profit.
// It does not alter any tracked state.
func (rm *RiskManager) RecordSuccess(hash common.Hash, actualProfit *big.Int) {
	rm.mu.Lock()
	defer rm.mu.Unlock()

	txHex, profit := hash.Hex(), actualProfit.String()
	rm.logger.Info("recorded success", "hash", txHex, "actualProfit", profit)
}
// openCircuitBreaker trips the breaker for CircuitBreakerCooldown from now.
// The caller must already hold rm.mu for writing (RecordFailure does).
func (rm *RiskManager) openCircuitBreaker() {
	cooldown := rm.config.CircuitBreakerCooldown
	until := time.Now().Add(cooldown)

	rm.circuitBreakerOpen, rm.circuitBreakerUntil = true, until

	rm.logger.Error("circuit breaker opened",
		"failures", len(rm.recentFailures),
		"cooldown", cooldown,
		"until", until,
	)
}
// checkCircuitBreaker reports whether the circuit breaker currently allows
// execution.
//
// It always allows execution when the breaker is disabled. An open breaker
// whose cooldown has elapsed is reset: it is closed and the recent-failure
// history is cleared.
func (rm *RiskManager) checkCircuitBreaker() bool {
	if !rm.config.CircuitBreakerEnabled {
		return true
	}

	// Hold the write lock for the whole check. The old RLock -> Lock "upgrade"
	// released the lock in between, so a concurrent RecordFailure could re-open
	// the breaker and then have that fresh trip (and its failures) wiped here.
	rm.mu.Lock()
	defer rm.mu.Unlock()

	if !rm.circuitBreakerOpen {
		return true
	}
	if !time.Now().After(rm.circuitBreakerUntil) {
		return false
	}

	// Cooldown elapsed: reset circuit breaker.
	rm.circuitBreakerOpen = false
	rm.recentFailures = make([]time.Time, 0)
	rm.logger.Info("circuit breaker reset")
	return true
}
// checkConcurrentLimit reports whether another transaction may be started,
// i.e. fewer than MaxConcurrentTxs transactions are currently tracked.
func (rm *RiskManager) checkConcurrentLimit() bool {
	rm.mu.RLock()
	inFlight := len(rm.activeTxs)
	rm.mu.RUnlock()

	return inFlight < rm.config.MaxConcurrentTxs
}
// checkPositionSize reports whether amount is within MaxPositionSize (inclusive).
func (rm *RiskManager) checkPositionSize(amount *big.Int) bool {
	limit := rm.config.MaxPositionSize
	return limit.Cmp(amount) >= 0
}
// checkDailyVolume reports whether adding amount to the current daily volume
// stays within MaxDailyVolume (inclusive). If the 24-hour window has expired,
// the volume is zeroed and a new window starts first.
func (rm *RiskManager) checkDailyVolume(amount *big.Int) bool {
	// Hold the write lock for the whole check. The old RLock -> Lock "upgrade"
	// released the lock in between, so two callers could both reset the window
	// and wipe volume that a concurrent TrackTransaction had just added.
	rm.mu.Lock()
	defer rm.mu.Unlock()

	now := time.Now()
	if now.After(rm.dailyVolumeResetAt) {
		rm.dailyVolume = big.NewInt(0)
		rm.dailyVolumeResetAt = now.Add(24 * time.Hour)
	}

	newVolume := new(big.Int).Add(rm.dailyVolume, amount)
	return newVolume.Cmp(rm.config.MaxDailyVolume) <= 0
}
// updateDailyVolume adds amount to the running daily volume in place.
// The caller must already hold rm.mu for writing (TrackTransaction does).
func (rm *RiskManager) updateDailyVolume(amount *big.Int) {
	total := rm.dailyVolume
	total.Add(total, amount)
}
// checkGasPrice reports whether gasPrice does not exceed MaxGasPrice.
func (rm *RiskManager) checkGasPrice(gasPrice *big.Int) bool {
	ceiling := rm.config.MaxGasPrice
	return ceiling.Cmp(gasPrice) >= 0
}
// checkGasCost reports whether the total gasCost does not exceed MaxGasCost.
func (rm *RiskManager) checkGasCost(gasCost *big.Int) bool {
	ceiling := rm.config.MaxGasCost
	return ceiling.Cmp(gasCost) >= 0
}
// checkMinProfit reports whether profit meets MinProfitAfterGas (inclusive).
func (rm *RiskManager) checkMinProfit(profit *big.Int) bool {
	floor := rm.config.MinProfitAfterGas
	return floor.Cmp(profit) <= 0
}
// checkMinROI reports whether roi (a fraction, e.g. 0.05 = 5%) meets MinROI.
// A NaN roi never passes.
func (rm *RiskManager) checkMinROI(roi float64) bool {
	floor := rm.config.MinROI
	return floor <= roi
}
// checkSlippage reports whether the requested slippage (basis points) is
// within MaxSlippageBPS (inclusive).
func (rm *RiskManager) checkSlippage(slippageBPS uint16) bool {
	ceiling := rm.config.MaxSlippageBPS
	return ceiling >= slippageBPS
}
// GetActiveTransactions returns a snapshot slice of the tracked transactions
// in unspecified order. The slice is new, but its elements are the manager's
// own *ActiveTransaction values.
func (rm *RiskManager) GetActiveTransactions() []*ActiveTransaction {
	rm.mu.RLock()
	defer rm.mu.RUnlock()

	snapshot := make([]*ActiveTransaction, len(rm.activeTxs))
	i := 0
	for _, active := range rm.activeTxs {
		snapshot[i] = active
		i++
	}
	return snapshot
}
// GetStats returns a point-in-time snapshot of the risk manager's counters,
// keyed by snake_case metric name. Big integers and times are rendered as
// strings (decimal and RFC 3339 respectively).
func (rm *RiskManager) GetStats() map[string]interface{} {
	rm.mu.RLock()
	defer rm.mu.RUnlock()

	stats := make(map[string]interface{}, 5)
	stats["active_transactions"] = len(rm.activeTxs)
	stats["daily_volume"] = rm.dailyVolume.String()
	stats["recent_failures"] = len(rm.recentFailures)
	stats["circuit_breaker_open"] = rm.circuitBreakerOpen
	stats["circuit_breaker_until"] = rm.circuitBreakerUntil.Format(time.RFC3339)
	return stats
}