package arbitrage

import (
	"context"
	"log/slog"
	"math/big"
	"os"
	"testing"

	"github.com/ethereum/go-ethereum/common"

	"github.com/your-org/mev-bot/pkg/types"
)

// setupCalculatorTest builds a Calculator from DefaultCalculatorConfig and a
// GasEstimator created with a nil first argument (presumably "no backend" —
// confirm against NewGasEstimator). Logging is limited to error level so test
// output stays quiet.
func setupCalculatorTest(t *testing.T) *Calculator {
	t.Helper()

	logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{
		Level: slog.LevelError,
	}))

	gasEstimator := NewGasEstimator(nil, logger)
	config := DefaultCalculatorConfig()
	calc := NewCalculator(config, gasEstimator, logger)

	return calc
}

// wholeTokens returns n whole tokens scaled to 18 decimals (n * 10^18).
//
// Fixture amounts such as 1000e18 exceed the int64 range, so passing them to
// big.NewInt as constants does not compile; multiplying inside big.Int keeps
// every fixture representable.
func wholeTokens(n int64) *big.Int {
	return new(big.Int).Mul(big.NewInt(n), big.NewInt(1e18))
}

// createTestPath returns a single-pool Path from tokenA to tokenB through a
// "constant-product" pool of the given protocol. The pool holds 1,000,000
// tokens (18 decimals) in each reserve, has Fee set to 30, and is marked
// active at block 1000.
func createTestPath(t *testing.T, poolType types.ProtocolType, tokenA, tokenB string) *Path {
	t.Helper()

	pool := &types.PoolInfo{
		Address:        common.HexToAddress("0xABCD"),
		Protocol:       poolType,
		PoolType:       "constant-product",
		Token0:         common.HexToAddress(tokenA),
		Token1:         common.HexToAddress(tokenB),
		Token0Decimals: 18,
		Token1Decimals: 18,
		Reserve0:       wholeTokens(1000000),
		Reserve1:       wholeTokens(1000000),
		Liquidity:      wholeTokens(1000000),
		Fee:            30, // 0.3%
		IsActive:       true,
		BlockNumber:    1000,
	}

	return &Path{
		Tokens: []common.Address{
			common.HexToAddress(tokenA),
			common.HexToAddress(tokenB),
		},
		Pools: []*types.PoolInfo{pool},
		Type:  OpportunityTypeTwoPool,
	}
}

// TestCalculator_CalculateProfitability checks that invalid inputs (empty
// path, zero or nil amount) are rejected. For a valid path it checks that the
// returned opportunity's amounts are populated and internally consistent
// (gross = output - input, net = gross - gas).
func TestCalculator_CalculateProfitability(t *testing.T) {
	calc := setupCalculatorTest(t)
	ctx := context.Background()

	tokenA := "0x1111111111111111111111111111111111111111"
	tokenB := "0x2222222222222222222222222222222222222222"

	tests := []struct {
		name        string
		path        *Path
		inputAmount *big.Int
		gasPrice    *big.Int
		wantError   bool
	}{
		{
			name:        "valid V2 swap",
			path:        createTestPath(t, types.ProtocolUniswapV2, tokenA, tokenB),
			inputAmount: big.NewInt(1e18), // 1 token
			gasPrice:    big.NewInt(1e9),  // 1 gwei
			wantError:   false,
		},
		{
			name:        "empty path",
			path:        &Path{Pools: []*types.PoolInfo{}},
			inputAmount: big.NewInt(1e18),
			gasPrice:    big.NewInt(1e9),
			wantError:   true,
		},
		{
			name:        "zero input amount",
			path:        createTestPath(t, types.ProtocolUniswapV2, tokenA, tokenB),
			inputAmount: big.NewInt(0),
			gasPrice:    big.NewInt(1e9),
			wantError:   true,
		},
		{
			name:        "nil input amount",
			path:        createTestPath(t, types.ProtocolUniswapV2, tokenA, tokenB),
			inputAmount: nil,
			gasPrice:    big.NewInt(1e9),
			wantError:   true,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			opp, err := calc.CalculateProfitability(ctx, tt.path, tt.inputAmount, tt.gasPrice)

			if tt.wantError {
				if err == nil {
					t.Error("expected error, got nil")
				}
				return
			}

			if err != nil {
				t.Fatalf("unexpected error: %v", err)
			}
			if opp == nil {
				t.Fatal("expected opportunity, got nil")
			}

			// Validate opportunity fields
			if opp.ID == "" {
				t.Error("opportunity ID is empty")
			}
			if len(opp.Path) != len(tt.path.Pools) {
				t.Errorf("got %d path steps, want %d", len(opp.Path), len(tt.path.Pools))
			}

			// Every amount below is dereferenced by Cmp/Sub, so a nil value
			// must stop the subtest rather than panic.
			if opp.InputAmount == nil {
				t.Fatal("input amount is nil")
			}
			if opp.OutputAmount == nil {
				t.Fatal("output amount is nil")
			}
			if opp.GrossProfit == nil {
				t.Fatal("gross profit is nil")
			}
			if opp.GasCost == nil {
				t.Fatal("gas cost is nil")
			}
			if opp.NetProfit == nil {
				t.Fatal("net profit is nil")
			}

			if opp.InputAmount.Cmp(tt.inputAmount) != 0 {
				t.Errorf("input amount mismatch: got %s, want %s",
					opp.InputAmount.String(), tt.inputAmount.String())
			}

			// Verify calculations
			expectedGrossProfit := new(big.Int).Sub(opp.OutputAmount, opp.InputAmount)
			if opp.GrossProfit.Cmp(expectedGrossProfit) != 0 {
				t.Errorf("gross profit mismatch: got %s, want %s",
					opp.GrossProfit.String(), expectedGrossProfit.String())
			}

			expectedNetProfit := new(big.Int).Sub(opp.GrossProfit, opp.GasCost)
			if opp.NetProfit.Cmp(expectedNetProfit) != 0 {
				t.Errorf("net profit mismatch: got %s, want %s",
					opp.NetProfit.String(), expectedNetProfit.String())
			}

			t.Logf("Opportunity: input=%s, output=%s, grossProfit=%s, gasCost=%s, netProfit=%s, roi=%.2f%%, priceImpact=%.2f%%",
				opp.InputAmount.String(),
				opp.OutputAmount.String(),
				opp.GrossProfit.String(),
				opp.GasCost.String(),
				opp.NetProfit.String(),
				opp.ROI*100,
				opp.PriceImpact*100,
			)
		})
	}
}

// TestCalculator_CalculateSwapOutputV2 swaps 1000 tokens in both directions
// through a pool with equal 1M-token reserves and expects a positive output
// within (99%, 100%) of the input. It also expects an error when the pool's
// reserves are nil.
func TestCalculator_CalculateSwapOutputV2(t *testing.T) {
	calc := setupCalculatorTest(t)

	tokenA := common.HexToAddress("0x1111")
	tokenB := common.HexToAddress("0x2222")

	pool := &types.PoolInfo{
		Protocol:       types.ProtocolUniswapV2,
		Token0:         tokenA,
		Token1:         tokenB,
		Token0Decimals: 18,
		Token1Decimals: 18,
		Reserve0:       wholeTokens(1000000), // 1M tokens
		Reserve1:       wholeTokens(1000000), // 1M tokens
		Fee:            30,                   // 0.3%
	}

	tests := []struct {
		name        string
		pool        *types.PoolInfo
		tokenIn     common.Address
		tokenOut    common.Address
		amountIn    *big.Int
		wantError   bool
		checkOutput bool
	}{
		{
			name:        "valid swap token0 → token1",
			pool:        pool,
			tokenIn:     tokenA,
			tokenOut:    tokenB,
			amountIn:    wholeTokens(1000), // 1000 tokens
			wantError:   false,
			checkOutput: true,
		},
		{
			name:        "valid swap token1 → token0",
			pool:        pool,
			tokenIn:     tokenB,
			tokenOut:    tokenA,
			amountIn:    wholeTokens(1000),
			wantError:   false,
			checkOutput: true,
		},
		{
			name: "pool with nil reserves",
			pool: &types.PoolInfo{
				Protocol:       types.ProtocolUniswapV2,
				Token0:         tokenA,
				Token1:         tokenB,
				Token0Decimals: 18,
				Token1Decimals: 18,
				Reserve0:       nil,
				Reserve1:       nil,
				Fee:            30,
			},
			tokenIn:   tokenA,
			tokenOut:  tokenB,
			amountIn:  wholeTokens(1000),
			wantError: true,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			amountOut, priceImpact, err := calc.calculateSwapOutputV2(tt.pool, tt.tokenIn, tt.tokenOut, tt.amountIn)

			if tt.wantError {
				if err == nil {
					t.Error("expected error, got nil")
				}
				return
			}

			if err != nil {
				t.Fatalf("unexpected error: %v", err)
			}

			if amountOut == nil {
				t.Fatal("amount out is nil")
			}

			if amountOut.Sign() <= 0 {
				t.Error("amount out is not positive")
			}

			if priceImpact < 0 || priceImpact > 1 {
				t.Errorf("price impact out of range: %f", priceImpact)
			}

			if tt.checkOutput {
				// For equal reserves, output should be slightly less than input due to fees
				expectedMin := new(big.Int).Mul(tt.amountIn, big.NewInt(99))
				expectedMin.Div(expectedMin, big.NewInt(100))

				if amountOut.Cmp(expectedMin) < 0 {
					t.Errorf("output too low: got %s, want at least %s",
						amountOut.String(), expectedMin.String())
				}

				if amountOut.Cmp(tt.amountIn) >= 0 {
					t.Errorf("output should be less than input due to fees: got %s, input %s",
						amountOut.String(), tt.amountIn.String())
				}
			}

			t.Logf("Swap: in=%s, out=%s, impact=%.4f%%",
				tt.amountIn.String(), amountOut.String(), priceImpact*100)
		})
	}
}

// TestCalculator_CalculatePriceImpactV2 checks that calculatePriceImpactV2
// returns a value inside the expected band for several swap sizes against
// equal 1M-token reserves.
func TestCalculator_CalculatePriceImpactV2(t *testing.T) {
	calc := setupCalculatorTest(t)

	reserveIn := wholeTokens(1000000)
	reserveOut := wholeTokens(1000000)

	tests := []struct {
		name          string
		amountIn      *big.Int
		amountOut     *big.Int
		wantImpactMin float64
		wantImpactMax float64
	}{
		{
			name:          "small swap",
			amountIn:      wholeTokens(100),
			amountOut:     wholeTokens(99),
			wantImpactMin: 0.0,
			wantImpactMax: 0.01, // < 1%
		},
		{
			name:          "medium swap",
			amountIn:      wholeTokens(10000),
			amountOut:     wholeTokens(9900),
			wantImpactMin: 0.0,
			wantImpactMax: 0.05, // < 5%
		},
		{
			name:          "large swap",
			amountIn:      wholeTokens(100000),
			amountOut:     wholeTokens(90000),
			wantImpactMin: 0.05,
			wantImpactMax: 0.20, // 5-20%
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			impact := calc.calculatePriceImpactV2(reserveIn, reserveOut, tt.amountIn, tt.amountOut)

			if impact < tt.wantImpactMin || impact > tt.wantImpactMax {
				t.Errorf("price impact %.4f%% not in range [%.4f%%, %.4f%%]",
					impact*100, tt.wantImpactMin*100, tt.wantImpactMax*100)
			}

			// big.Int.Int64 is undefined for values outside the int64 range
			// (these reserves are 1e24), so take the ratio exactly via big.Rat.
			// Exactness of the float conversion does not matter for a log line.
			swapShare, _ := new(big.Rat).SetFrac(tt.amountIn, reserveIn).Float64()
			t.Logf("Swap size: %.0f%% of reserves, Impact: %.4f%%",
				swapShare*100,
				impact*100,
			)
		})
	}
}

// TestCalculator_CalculateFeeAmount checks calculateFeeAmount for 30, 5 and 0
// basis-point fees on a 1000-token input.
func TestCalculator_CalculateFeeAmount(t *testing.T) {
	calc := setupCalculatorTest(t)

	tests := []struct {
		name           string
		amountIn       *big.Int
		feeBasisPoints uint32
		protocol       types.ProtocolType
		expectedFee    *big.Int
	}{
		{
			name:           "0.3% fee",
			amountIn:       wholeTokens(1000),
			feeBasisPoints: 30,
			protocol:       types.ProtocolUniswapV2,
			expectedFee:    big.NewInt(3e18), // 1000 * 0.003 = 3
		},
		{
			name:           "0.05% fee",
			amountIn:       wholeTokens(1000),
			feeBasisPoints: 5,
			protocol:       types.ProtocolUniswapV3,
			expectedFee:    big.NewInt(5e17), // 1000 * 0.0005 = 0.5
		},
		{
			name:           "zero fee",
			amountIn:       wholeTokens(1000),
			feeBasisPoints: 0,
			protocol:       types.ProtocolUniswapV2,
			expectedFee:    big.NewInt(0),
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			fee := calc.calculateFeeAmount(tt.amountIn, tt.feeBasisPoints, tt.protocol)
			if fee == nil {
				t.Fatal("fee is nil")
			}
			if fee.Cmp(tt.expectedFee) != 0 {
				t.Errorf("got fee %s, want %s", fee.String(), tt.expectedFee.String())
			}
		})
	}
}

// TestCalculator_CalculatePriority pins calculatePriority's output for several
// (net profit, ROI) pairs, including a negative-profit case.
func TestCalculator_CalculatePriority(t *testing.T) {
	calc := setupCalculatorTest(t)

	tests := []struct {
		name         string
		netProfit    *big.Int
		roi          float64
		wantPriority int
	}{
		{
			name:         "high profit, high ROI",
			netProfit:    wholeTokens(1), // 1 ETH
			roi:          0.50,           // 50%
			wantPriority: 600,            // 100 + 500
		},
		{
			name:         "medium profit, medium ROI",
			netProfit:    big.NewInt(5e17), // 0.5 ETH
			roi:          0.20,             // 20%
			wantPriority: 250,              // 50 + 200
		},
		{
			name:         "low profit, low ROI",
			netProfit:    big.NewInt(1e16), // 0.01 ETH
			roi:          0.05,             // 5%
			wantPriority: 51,               // 1 + 50
		},
		{
			name:         "negative profit",
			netProfit:    big.NewInt(-1e18),
			roi:          -0.10,
			wantPriority: -100,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			priority := calc.calculatePriority(tt.netProfit, tt.roi)
			if priority != tt.wantPriority {
				t.Errorf("got priority %d, want %d", priority, tt.wantPriority)
			}
		})
	}
}

// TestCalculator_IsExecutable overrides the calculator's profit, ROI and
// price-impact thresholds, then checks that isExecutable rejects an
// opportunity failing any single threshold.
func TestCalculator_IsExecutable(t *testing.T) {
	calc := setupCalculatorTest(t)
	calc.config.MinProfitWei = big.NewInt(5e16) // 0.05 ETH
	calc.config.MinROI = 0.05                   // 5%
	calc.config.MaxPriceImpact = 0.10           // 10%

	tests := []struct {
		name           string
		netProfit      *big.Int
		roi            float64
		priceImpact    float64
		wantExecutable bool
	}{
		{
			name:           "meets all criteria",
			netProfit:      big.NewInt(1e17), // 0.1 ETH
			roi:            0.10,             // 10%
			priceImpact:    0.05,             // 5%
			wantExecutable: true,
		},
		{
			name:           "profit too low",
			netProfit:      big.NewInt(1e16), // 0.01 ETH
			roi:            0.10,
			priceImpact:    0.05,
			wantExecutable: false,
		},
		{
			name:           "ROI too low",
			netProfit:      big.NewInt(1e17),
			roi:            0.02, // 2%
			priceImpact:    0.05,
			wantExecutable: false,
		},
		{
			name:           "price impact too high",
			netProfit:      big.NewInt(1e17),
			roi:            0.10,
			priceImpact:    0.15, // 15%
			wantExecutable: false,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			executable := calc.isExecutable(tt.netProfit, tt.roi, tt.priceImpact)
			if executable != tt.wantExecutable {
				t.Errorf("got executable=%v, want %v", executable, tt.wantExecutable)
			}
		})
	}
}

// TestDefaultCalculatorConfig pins the default thresholds returned by
// DefaultCalculatorConfig.
func TestDefaultCalculatorConfig(t *testing.T) {
	config := DefaultCalculatorConfig()

	if config.MinProfitWei == nil {
		t.Fatal("MinProfitWei is nil")
	}

	expectedMinProfit := new(big.Int).Mul(big.NewInt(5), new(big.Int).Exp(big.NewInt(10), big.NewInt(16), nil))
	if config.MinProfitWei.Cmp(expectedMinProfit) != 0 {
		t.Errorf("got MinProfitWei=%s, want %s", config.MinProfitWei.String(), expectedMinProfit.String())
	}

	if config.MinROI != 0.05 {
		t.Errorf("got MinROI=%.4f, want 0.05", config.MinROI)
	}

	if config.MaxPriceImpact != 0.10 {
		t.Errorf("got MaxPriceImpact=%.4f, want 0.10", config.MaxPriceImpact)
	}

	if config.MaxGasPriceGwei != 100 {
		t.Errorf("got MaxGasPriceGwei=%d, want 100", config.MaxGasPriceGwei)
	}

	if config.SlippageTolerance != 0.005 {
		t.Errorf("got SlippageTolerance=%.4f, want 0.005", config.SlippageTolerance)
	}
}