package arbitrage

import (
	"context"
	"log/slog"
	"math/big"
	"os"
	"testing"

	"github.com/ethereum/go-ethereum/common"
	"github.com/your-org/mev-bot/pkg/types"
)

// setupGasEstimatorTest builds a GasEstimator from DefaultGasEstimatorConfig.
// Its logger only emits error-level records, which keeps test output quiet.
func setupGasEstimatorTest(t *testing.T) *GasEstimator {
	t.Helper()

	logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{
		Level: slog.LevelError,
	}))

	config := DefaultGasEstimatorConfig()
	return NewGasEstimator(config, logger)
}

// TestGasEstimator_EstimateGasCost checks two things about EstimateGasCost:
//   - it returns an error for nil and zero gas prices;
//   - for valid inputs, the gas units implied by the result (cost / gasPrice)
//     fall within the expected range for each path.
func TestGasEstimator_EstimateGasCost(t *testing.T) {
	ge := setupGasEstimatorTest(t)
	ctx := context.Background()

	tests := []struct {
		name       string
		path       *Path
		gasPrice   *big.Int
		wantError  bool
		wantGasMin uint64
		wantGasMax uint64
	}{
		{
			name: "single V2 swap",
			path: &Path{
				Pools: []*types.PoolInfo{
					{
						Address:  common.HexToAddress("0x1111"),
						Protocol: types.ProtocolUniswapV2,
					},
				},
			},
			gasPrice:   big.NewInt(1e9), // 1 gwei
			wantError:  false,
			wantGasMin: 130000, // Base + V2
			wantGasMax: 160000,
		},
		{
			name: "single V3 swap",
			path: &Path{
				Pools: []*types.PoolInfo{
					{
						Address:  common.HexToAddress("0x2222"),
						Protocol: types.ProtocolUniswapV3,
					},
				},
			},
			gasPrice:   big.NewInt(2e9), // 2 gwei
			wantError:  false,
			wantGasMin: 190000, // Base + V3
			wantGasMax: 230000,
		},
		{
			name: "multi-hop path",
			path: &Path{
				Pools: []*types.PoolInfo{
					{
						Address:  common.HexToAddress("0x3333"),
						Protocol: types.ProtocolUniswapV2,
					},
					{
						Address:  common.HexToAddress("0x4444"),
						Protocol: types.ProtocolUniswapV3,
					},
					{
						Address:  common.HexToAddress("0x5555"),
						Protocol: types.ProtocolCurve,
					},
				},
			},
			gasPrice:   big.NewInt(1e9),
			wantError:  false,
			wantGasMin: 450000, // Base + V2 + V3 + Curve
			wantGasMax: 550000,
		},
		{
			name: "nil gas price",
			path: &Path{
				Pools: []*types.PoolInfo{
					{
						Address:  common.HexToAddress("0x6666"),
						Protocol: types.ProtocolUniswapV2,
					},
				},
			},
			gasPrice:  nil,
			wantError: true,
		},
		{
			name: "zero gas price",
			path: &Path{
				Pools: []*types.PoolInfo{
					{
						Address:  common.HexToAddress("0x7777"),
						Protocol: types.ProtocolUniswapV2,
					},
				},
			},
			gasPrice:  big.NewInt(0),
			wantError: true,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			gasCost, err := ge.EstimateGasCost(ctx, tt.path, tt.gasPrice)

			if tt.wantError {
				if err == nil {
					t.Error("expected error, got nil")
				}
				return
			}

			if err != nil {
				t.Fatalf("unexpected error: %v", err)
			}

			if gasCost == nil {
				t.Fatal("gas cost is nil")
			}

			if gasCost.Sign() <= 0 {
				t.Error("gas cost is not positive")
			}

			// Divide the total cost by the gas price (integer division) to
			// recover the gas units the estimator charged for. Guard the
			// uint64 conversion: big.Int.Uint64 is undefined for values that
			// do not fit, e.g. a negative cost that the non-fatal Sign check
			// above lets through.
			impliedGasUnits := new(big.Int).Div(gasCost, tt.gasPrice)
			if !impliedGasUnits.IsUint64() {
				t.Fatalf("implied gas units %s do not fit in uint64", impliedGasUnits.String())
			}
			gasUnits := impliedGasUnits.Uint64()

			if gasUnits < tt.wantGasMin || gasUnits > tt.wantGasMax {
				t.Errorf("gas units %d not in range [%d, %d]", gasUnits, tt.wantGasMin, tt.wantGasMax)
			}

			t.Logf("Path with %d pools: gas=%d units, cost=%s wei", len(tt.path.Pools), gasUnits, gasCost.String())
		})
	}
}

// TestGasEstimator_EstimatePoolGas checks the per-protocol gas lookup against
// the estimator's config:
//   - SushiSwap uses the V2 cost;
//   - an unrecognised protocol also falls back to the V2 cost.
func TestGasEstimator_EstimatePoolGas(t *testing.T) {
	ge := setupGasEstimatorTest(t)

	tests := []struct {
		name     string
		protocol types.ProtocolType
		wantGas  uint64
	}{
		{
			name:     "UniswapV2",
			protocol: types.ProtocolUniswapV2,
			wantGas:  ge.config.V2SwapGas,
		},
		{
			name:     "UniswapV3",
			protocol: types.ProtocolUniswapV3,
			wantGas:  ge.config.V3SwapGas,
		},
		{
			name:     "SushiSwap",
			protocol: types.ProtocolSushiSwap,
			wantGas:  ge.config.V2SwapGas,
		},
		{
			name:     "Curve",
			protocol: types.ProtocolCurve,
			wantGas:  ge.config.CurveSwapGas,
		},
		{
			name:     "Unknown protocol",
			protocol: types.ProtocolType("unknown"),
			wantGas:  ge.config.V2SwapGas, // Default to V2
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			gas := ge.estimatePoolGas(tt.protocol)

			if gas != tt.wantGas {
				t.Errorf("got %d gas, want %d", gas, tt.wantGas)
			}
		})
	}
}

// TestGasEstimator_EstimateGasLimit checks that EstimateGasLimit returns a
// limit within the expected range for single-pool and three-pool paths.
func TestGasEstimator_EstimateGasLimit(t *testing.T) {
	ge := setupGasEstimatorTest(t)
	ctx := context.Background()

	tests := []struct {
		name       string
		path       *Path
		wantGasMin uint64
		wantGasMax uint64
		wantError  bool
	}{
		{
			name: "single pool",
			path: &Path{
				Pools: []*types.PoolInfo{
					{
						Address:  common.HexToAddress("0x1111"),
						Protocol: types.ProtocolUniswapV2,
					},
				},
			},
			wantGasMin: 130000,
			wantGasMax: 160000,
			wantError:  false,
		},
		{
			name: "three pools",
			path: &Path{
				Pools: []*types.PoolInfo{
					{Protocol: types.ProtocolUniswapV2},
					{Protocol: types.ProtocolUniswapV3},
					{Protocol: types.ProtocolCurve},
				},
			},
			wantGasMin: 450000,
			wantGasMax: 550000,
			wantError:  false,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			gasLimit, err := ge.EstimateGasLimit(ctx, tt.path)

			if tt.wantError {
				if err == nil {
					t.Error("expected error, got nil")
				}
				return
			}

			if err != nil {
				t.Fatalf("unexpected error: %v", err)
			}

			if gasLimit < tt.wantGasMin || gasLimit > tt.wantGasMax {
				t.Errorf("gas limit %d not in range [%d, %d]", gasLimit, tt.wantGasMin, tt.wantGasMax)
			}

			t.Logf("Gas limit for %d pools: %d", len(tt.path.Pools), gasLimit)
		})
	}
}

// TestGasEstimator_EstimateOptimalGasPrice checks that EstimateOptimalGasPrice
// returns a non-nil, non-negative price without error across a range of net
// profits, including zero and negative profit.
func TestGasEstimator_EstimateOptimalGasPrice(t *testing.T) {
	ge := setupGasEstimatorTest(t)
	ctx := context.Background()

	path := &Path{
		Pools: []*types.PoolInfo{
			{
				Address:  common.HexToAddress("0x1111"),
				Protocol: types.ProtocolUniswapV2,
			},
		},
	}

	tests := []struct {
		name            string
		netProfit       *big.Int
		currentGasPrice *big.Int
		// useCurrentPrice marks cases where the optimal price is expected to
		// equal the current price. The check below only logs a difference;
		// it never fails the test.
		useCurrentPrice bool
	}{
		{
			name:            "high profit, low gas price",
			netProfit:       big.NewInt(1e18), // 1 ETH profit
			currentGasPrice: big.NewInt(1e9),  // 1 gwei
			useCurrentPrice: true,             // Should use current (it's lower than max)
		},
		{
			name:            "low profit",
			netProfit:       big.NewInt(1e16), // 0.01 ETH profit
			currentGasPrice: big.NewInt(1e9),  // 1 gwei
			useCurrentPrice: true,
		},
		{
			name:            "zero profit",
			netProfit:       big.NewInt(0),
			currentGasPrice: big.NewInt(1e9),
			useCurrentPrice: true,
		},
		{
			name:            "negative profit",
			netProfit:       big.NewInt(-1e18),
			currentGasPrice: big.NewInt(1e9),
			useCurrentPrice: true,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			optimalPrice, err := ge.EstimateOptimalGasPrice(ctx, tt.netProfit, path, tt.currentGasPrice)

			if err != nil {
				t.Fatalf("unexpected error: %v", err)
			}

			if optimalPrice == nil {
				t.Fatal("optimal gas price is nil")
			}

			if optimalPrice.Sign() < 0 {
				t.Error("optimal gas price is negative")
			}

			// Informational only: this test does not pin the pricing
			// strategy, so a difference is logged rather than reported as a
			// failure.
			if tt.useCurrentPrice && optimalPrice.Cmp(tt.currentGasPrice) != 0 {
				t.Logf("optimal price %s differs from current %s", optimalPrice.String(), tt.currentGasPrice.String())
			}

			t.Logf("Net profit: %s, Current: %s, Optimal: %s",
				tt.netProfit.String(),
				tt.currentGasPrice.String(),
				optimalPrice.String(),
			)
		})
	}
}

// TestGasEstimator_CompareGasCosts checks that CompareGasCosts returns one
// comparison per opportunity, each with:
//   - a non-empty ID;
//   - a positive gas estimate;
//   - a positive efficiency.
//
// It also checks that GetMostEfficientOpportunity picks the comparison with
// the highest efficiency.
func TestGasEstimator_CompareGasCosts(t *testing.T) {
	ge := setupGasEstimatorTest(t)
	ctx := context.Background()

	opportunities := []*Opportunity{
		{
			ID:        "opp1",
			Type:      OpportunityTypeTwoPool,
			NetProfit: big.NewInt(1e18), // 1 ETH
			ROI:       0.10,
			Path: []*PathStep{
				{
					PoolAddress: common.HexToAddress("0x1111"),
					Protocol:    types.ProtocolUniswapV2,
				},
			},
		},
		{
			ID:        "opp2",
			Type:      OpportunityTypeMultiHop,
			NetProfit: big.NewInt(5e17), // 0.5 ETH
			ROI:       0.15,
			Path: []*PathStep{
				{
					PoolAddress: common.HexToAddress("0x2222"),
					Protocol:    types.ProtocolUniswapV3,
				},
				{
					PoolAddress: common.HexToAddress("0x3333"),
					Protocol:    types.ProtocolUniswapV2,
				},
			},
		},
		{
			ID:        "opp3",
			Type:      OpportunityTypeTriangular,
			NetProfit: big.NewInt(2e18), // 2 ETH
			ROI:       0.20,
			Path: []*PathStep{
				{
					PoolAddress: common.HexToAddress("0x4444"),
					Protocol:    types.ProtocolUniswapV2,
				},
				{
					PoolAddress: common.HexToAddress("0x5555"),
					Protocol:    types.ProtocolUniswapV3,
				},
				{
					PoolAddress: common.HexToAddress("0x6666"),
					Protocol:    types.ProtocolCurve,
				},
			},
		},
	}

	gasPrice := big.NewInt(1e9) // 1 gwei

	comparisons, err := ge.CompareGasCosts(ctx, opportunities, gasPrice)
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}

	if len(comparisons) != len(opportunities) {
		t.Errorf("got %d comparisons, want %d", len(comparisons), len(opportunities))
	}

	for i, comp := range comparisons {
		t.Logf("Comparison %d: ID=%s, Type=%s, Hops=%d, Gas=%s, Profit=%s, ROI=%.2f%%, Efficiency=%.4f",
			i,
			comp.OpportunityID,
			comp.Type,
			comp.HopCount,
			comp.EstimatedGas.String(),
			comp.NetProfit.String(),
			comp.ROI*100,
			comp.Efficiency,
		)

		if comp.OpportunityID == "" {
			t.Error("opportunity ID is empty")
		}

		if comp.EstimatedGas == nil || comp.EstimatedGas.Sign() <= 0 {
			t.Error("estimated gas is invalid")
		}

		if comp.Efficiency <= 0 {
			t.Error("efficiency should be positive for profitable opportunities")
		}
	}

	// Test GetMostEfficientOpportunity
	mostEfficient := ge.GetMostEfficientOpportunity(comparisons)
	if mostEfficient == nil {
		t.Fatal("most efficient opportunity is nil")
	}

	t.Logf("Most efficient: %s with efficiency %.4f", mostEfficient.OpportunityID, mostEfficient.Efficiency)

	// Verify it's actually the most efficient
	for _, comp := range comparisons {
		if comp.Efficiency > mostEfficient.Efficiency {
			t.Errorf("found more efficient opportunity: %s (%.4f) > %s (%.4f)",
				comp.OpportunityID, comp.Efficiency,
				mostEfficient.OpportunityID, mostEfficient.Efficiency,
			)
		}
	}
}

// TestGasEstimator_GetMostEfficientOpportunity checks two behaviours:
//   - an empty list yields nil;
//   - otherwise the comparison with the highest Efficiency is returned.
func TestGasEstimator_GetMostEfficientOpportunity(t *testing.T) {
	ge := setupGasEstimatorTest(t)

	tests := []struct {
		name        string
		comparisons []*GasCostComparison
		wantID      string
		wantNil     bool
	}{
		{
			name:        "empty list",
			comparisons: []*GasCostComparison{},
			wantNil:     true,
		},
		{
			name: "single opportunity",
			comparisons: []*GasCostComparison{
				{
					OpportunityID: "opp1",
					Efficiency:    1.5,
				},
			},
			wantID:  "opp1",
			wantNil: false,
		},
		{
			name: "multiple opportunities",
			comparisons: []*GasCostComparison{
				{
					OpportunityID: "opp1",
					Efficiency:    1.5,
				},
				{
					OpportunityID: "opp2",
					Efficiency:    2.8, // Most efficient
				},
				{
					OpportunityID: "opp3",
					Efficiency:    1.2,
				},
			},
			wantID:  "opp2",
			wantNil: false,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			result := ge.GetMostEfficientOpportunity(tt.comparisons)

			if tt.wantNil {
				if result != nil {
					t.Error("expected nil result")
				}
				return
			}

			if result == nil {
				t.Fatal("unexpected nil result")
			}

			if result.OpportunityID != tt.wantID {
				t.Errorf("got opportunity %s, want %s", result.OpportunityID, tt.wantID)
			}
		})
	}
}

// TestDefaultGasEstimatorConfig pins the default gas constants and the gas
// price multiplier returned by DefaultGasEstimatorConfig.
func TestDefaultGasEstimatorConfig(t *testing.T) {
	config := DefaultGasEstimatorConfig()

	if config.BaseGas != 21000 {
		t.Errorf("got BaseGas=%d, want 21000", config.BaseGas)
	}

	if config.GasPerPool != 10000 {
		t.Errorf("got GasPerPool=%d, want 10000", config.GasPerPool)
	}

	if config.V2SwapGas != 120000 {
		t.Errorf("got V2SwapGas=%d, want 120000", config.V2SwapGas)
	}

	if config.V3SwapGas != 180000 {
		t.Errorf("got V3SwapGas=%d, want 180000", config.V3SwapGas)
	}

	if config.CurveSwapGas != 150000 {
		t.Errorf("got CurveSwapGas=%d, want 150000", config.CurveSwapGas)
	}

	if config.GasPriceMultiplier != 1.1 {
		t.Errorf("got GasPriceMultiplier=%.2f, want 1.1", config.GasPriceMultiplier)
	}
}

// BenchmarkGasEstimator_EstimateGasCost measures EstimateGasCost on a
// three-pool (V2 -> V3 -> Curve) path at 1 gwei, reporting allocations.
func BenchmarkGasEstimator_EstimateGasCost(b *testing.B) {
	logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{
		Level: slog.LevelError,
	}))

	// NOTE(review): passes a nil config; this presumably makes
	// NewGasEstimator fall back to its defaults — confirm.
	ge := NewGasEstimator(nil, logger)
	ctx := context.Background()

	path := &Path{
		Pools: []*types.PoolInfo{
			{Protocol: types.ProtocolUniswapV2},
			{Protocol: types.ProtocolUniswapV3},
			{Protocol: types.ProtocolCurve},
		},
	}

	gasPrice := big.NewInt(1e9)

	b.ReportAllocs()
	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		_, err := ge.EstimateGasCost(ctx, path, gasPrice)
		if err != nil {
			b.Fatal(err)
		}
	}
}