package arbitrage import ( "context" "log/slog" "math/big" "os" "testing" "time" "github.com/ethereum/go-ethereum/common" "github.com/your-org/mev-bot/pkg/cache" mevtypes "github.com/your-org/mev-bot/pkg/types" ) func setupDetectorTest(t *testing.T) (*Detector, *cache.PoolCache) { t.Helper() logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{ Level: slog.LevelError, })) poolCache := cache.NewPoolCache() // Create components pathFinderConfig := DefaultPathFinderConfig() pathFinder := NewPathFinder(poolCache, pathFinderConfig, logger) gasEstimator := NewGasEstimator(nil, logger) calculatorConfig := DefaultCalculatorConfig() calculator := NewCalculator(calculatorConfig, gasEstimator, logger) detectorConfig := DefaultDetectorConfig() detector := NewDetector(detectorConfig, pathFinder, calculator, poolCache, logger) return detector, poolCache } func addTestPoolsForArbitrage(t *testing.T, cache *cache.PoolCache) (common.Address, common.Address) { t.Helper() ctx := context.Background() tokenA := common.HexToAddress("0x1111111111111111111111111111111111111111") tokenB := common.HexToAddress("0x2222222222222222222222222222222222222222") // Add two pools with different prices for arbitrage pool1 := &mevtypes.PoolInfo{ Address: common.HexToAddress("0xAAAA"), Protocol: mevtypes.ProtocolUniswapV2, PoolType: "constant-product", Token0: tokenA, Token1: tokenB, Token0Decimals: 18, Token1Decimals: 18, Reserve0: new(big.Int).Mul(big.NewInt(1000000), big.NewInt(1e18)), Reserve1: new(big.Int).Mul(big.NewInt(1100000), big.NewInt(1e18)), // Higher price Liquidity: new(big.Int).Mul(big.NewInt(1000000), big.NewInt(1e18)), Fee: 30, IsActive: true, BlockNumber: 1000, } pool2 := &mevtypes.PoolInfo{ Address: common.HexToAddress("0xBBBB"), Protocol: mevtypes.ProtocolUniswapV3, PoolType: "constant-product", Token0: tokenA, Token1: tokenB, Token0Decimals: 18, Token1Decimals: 18, Reserve0: new(big.Int).Mul(big.NewInt(1000000), big.NewInt(1e18)), Reserve1: new(big.Int).Mul(big.NewInt(900000), big.NewInt(1e18)), // Lower price Liquidity: new(big.Int).Mul(big.NewInt(1000000), big.NewInt(1e18)), Fee: 30, IsActive: true, BlockNumber: 1000, } err := cache.Add(ctx, pool1) if err != nil { t.Fatalf("failed to add pool1: %v", err) } err = cache.Add(ctx, pool2) if err != nil { t.Fatalf("failed to add pool2: %v", err) } return tokenA, tokenB } func TestDetector_DetectOpportunities(t *testing.T) { detector, poolCache := setupDetectorTest(t) ctx := context.Background() tokenA, _ := addTestPoolsForArbitrage(t, poolCache) tests := []struct { name string token common.Address wantError bool wantOppMin int }{ { name: "detect opportunities for token", token: tokenA, wantError: false, wantOppMin: 0, // May or may not find profitable opportunities }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { opportunities, err := detector.DetectOpportunities(ctx, tt.token) if tt.wantError { if err == nil { t.Error("expected error, got nil") } return } if err != nil { t.Fatalf("unexpected error: %v", err) } if opportunities == nil { t.Fatal("opportunities is nil") } if len(opportunities) < tt.wantOppMin { t.Errorf("got %d opportunities, want at least %d", len(opportunities), tt.wantOppMin) } t.Logf("Found %d opportunities", len(opportunities)) // Validate each opportunity for i, opp := range opportunities { if opp.ID == "" { t.Errorf("opportunity %d has empty ID", i) } if !opp.IsProfitable() { t.Errorf("opportunity %d is not profitable: netProfit=%s", i, opp.NetProfit.String()) } if !opp.CanExecute() { t.Errorf("opportunity %d cannot be executed", i) } t.Logf("Opportunity %d: type=%s, profit=%s, roi=%.2f%%, hops=%d", i, opp.Type, opp.NetProfit.String(), opp.ROI*100, len(opp.Path)) } }) } } func TestDetector_DetectOpportunitiesForSwap(t *testing.T) { detector, poolCache := setupDetectorTest(t) ctx := context.Background() tokenA, tokenB := addTestPoolsForArbitrage(t, poolCache) swapEvent := &mevtypes.SwapEvent{ PoolAddress: common.HexToAddress("0xAAAA"), Protocol: mevtypes.ProtocolUniswapV2, TokenIn: tokenA, TokenOut: tokenB, AmountIn: big.NewInt(1e18), AmountOut: big.NewInt(1e18), BlockNumber: 1000, TxHash: common.HexToHash("0x1234"), } opportunities, err := detector.DetectOpportunitiesForSwap(ctx, swapEvent) if err != nil { t.Fatalf("unexpected error: %v", err) } if opportunities == nil { t.Fatal("opportunities is nil") } t.Logf("Found %d opportunities from swap event", len(opportunities)) } func TestDetector_DetectBetweenTokens(t *testing.T) { detector, poolCache := setupDetectorTest(t) ctx := context.Background() tokenA, tokenB := addTestPoolsForArbitrage(t, poolCache) opportunities, err := detector.DetectBetweenTokens(ctx, tokenA, tokenB) if err != nil { t.Fatalf("unexpected error: %v", err) } if opportunities == nil { t.Fatal("opportunities is nil") } t.Logf("Found %d opportunities between tokens", len(opportunities)) } func TestDetector_FilterProfitable(t *testing.T) { detector, _ := setupDetectorTest(t) opportunities := []*Opportunity{ { ID: "opp1", NetProfit: big.NewInt(1e18), // Profitable ROI: 0.10, Executable: true, }, { ID: "opp2", NetProfit: big.NewInt(-1e17), // Not profitable ROI: -0.05, Executable: false, }, { ID: "opp3", NetProfit: big.NewInt(5e17), // Profitable ROI: 0.05, Executable: true, }, { ID: "opp4", NetProfit: big.NewInt(1e16), // Too small ROI: 0.01, Executable: false, }, } profitable := detector.filterProfitable(opportunities) if len(profitable) != 2 { t.Errorf("got %d profitable opportunities, want 2", len(profitable)) } // Verify all filtered opportunities are profitable for i, opp := range profitable { if !opp.IsProfitable() { t.Errorf("opportunity %d is not profitable", i) } if !opp.CanExecute() { t.Errorf("opportunity %d cannot be executed", i) } } } func TestDetector_IsTokenWhitelisted(t *testing.T) { logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{ Level: slog.LevelError, })) tokenA := common.HexToAddress("0x1111") tokenB := common.HexToAddress("0x2222") tokenC := common.HexToAddress("0x3333") tests := []struct { name string whitelistedTokens []common.Address token common.Address wantWhitelisted bool }{ { name: "no whitelist - all allowed", whitelistedTokens: []common.Address{}, token: tokenA, wantWhitelisted: true, }, { name: "token in whitelist", whitelistedTokens: []common.Address{tokenA, tokenB}, token: tokenA, wantWhitelisted: true, }, { name: "token not in whitelist", whitelistedTokens: []common.Address{tokenA, tokenB}, token: tokenC, wantWhitelisted: false, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { config := DefaultDetectorConfig() config.WhitelistedTokens = tt.whitelistedTokens detector := NewDetector(config, nil, nil, nil, logger) whitelisted := detector.isTokenWhitelisted(tt.token) if whitelisted != tt.wantWhitelisted { t.Errorf("got whitelisted=%v, want %v", whitelisted, tt.wantWhitelisted) } }) } } func TestDetector_UpdateStats(t *testing.T) { detector, _ := setupDetectorTest(t) opportunities := []*Opportunity{ { ID: "opp1", NetProfit: big.NewInt(1e18), ROI: 0.10, Executable: true, }, { ID: "opp2", NetProfit: big.NewInt(5e17), ROI: 0.05, Executable: true, }, { ID: "opp3", NetProfit: big.NewInt(-1e17), // Unprofitable ROI: -0.05, Executable: false, }, } detector.updateStats(opportunities) stats := detector.GetStats() if stats.TotalDetected != 3 { t.Errorf("got TotalDetected=%d, want 3", stats.TotalDetected) } if stats.TotalProfitable != 2 { t.Errorf("got TotalProfitable=%d, want 2", stats.TotalProfitable) } if stats.TotalExecutable != 2 { t.Errorf("got TotalExecutable=%d, want 2", stats.TotalExecutable) } if stats.MaxProfit == nil { t.Fatal("MaxProfit is nil") } expectedMaxProfit := big.NewInt(1e18) if stats.MaxProfit.Cmp(expectedMaxProfit) != 0 { t.Errorf("got MaxProfit=%s, want %s", stats.MaxProfit.String(), expectedMaxProfit.String()) } if stats.TotalProfit == nil { t.Fatal("TotalProfit is nil") } expectedTotalProfit := new(big.Int).Add( new(big.Int).Add(big.NewInt(1e18), big.NewInt(5e17)), big.NewInt(-1e17), ) if stats.TotalProfit.Cmp(expectedTotalProfit) != 0 { t.Errorf("got TotalProfit=%s, want %s", stats.TotalProfit.String(), expectedTotalProfit.String()) } t.Logf("Stats: detected=%d, profitable=%d, executable=%d, maxProfit=%s", stats.TotalDetected, stats.TotalProfitable, stats.TotalExecutable, stats.MaxProfit.String(), ) } func TestDetector_RankOpportunities(t *testing.T) { detector, _ := setupDetectorTest(t) opportunities := []*Opportunity{ {ID: "opp1", Priority: 50}, {ID: "opp2", Priority: 200}, {ID: "opp3", Priority: 100}, {ID: "opp4", Priority: 150}, } ranked := detector.RankOpportunities(opportunities) if len(ranked) != len(opportunities) { t.Errorf("got %d ranked opportunities, want %d", len(ranked), len(opportunities)) } // Verify descending order for i := 0; i < len(ranked)-1; i++ { if ranked[i].Priority < ranked[i+1].Priority { t.Errorf("opportunities not sorted: rank[%d].Priority=%d < rank[%d].Priority=%d", i, ranked[i].Priority, i+1, ranked[i+1].Priority) } } // Verify highest priority is first if ranked[0].ID != "opp2" { t.Errorf("highest priority opportunity is %s, want opp2", ranked[0].ID) } t.Logf("Ranked opportunities: %v", []int{ranked[0].Priority, ranked[1].Priority, ranked[2].Priority, ranked[3].Priority}) } func TestDetector_OpportunityStream(t *testing.T) { detector, _ := setupDetectorTest(t) // Get the stream channel stream := detector.OpportunityStream() if stream == nil { t.Fatal("opportunity stream is nil") } // Create test opportunities opp1 := &Opportunity{ ID: "opp1", NetProfit: big.NewInt(1e18), } opp2 := &Opportunity{ ID: "opp2", NetProfit: big.NewInt(5e17), } // Publish opportunities detector.PublishOpportunity(opp1) detector.PublishOpportunity(opp2) // Read from stream received1 := <-stream if received1.ID != opp1.ID { t.Errorf("got opportunity %s, want %s", received1.ID, opp1.ID) } received2 := <-stream if received2.ID != opp2.ID { t.Errorf("got opportunity %s, want %s", received2.ID, opp2.ID) } t.Log("Successfully published and received opportunities via stream") } func TestDetector_MonitorSwaps(t *testing.T) { detector, poolCache := setupDetectorTest(t) ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) defer cancel() tokenA, tokenB := addTestPoolsForArbitrage(t, poolCache) // Create swap channel swapCh := make(chan *mevtypes.SwapEvent, 10) // Start monitoring in background go detector.MonitorSwaps(ctx, swapCh) // Send a test swap swap := &mevtypes.SwapEvent{ PoolAddress: common.HexToAddress("0xAAAA"), Protocol: mevtypes.ProtocolUniswapV2, TokenIn: tokenA, TokenOut: tokenB, AmountIn: big.NewInt(1e18), AmountOut: big.NewInt(1e18), BlockNumber: 1000, TxHash: common.HexToHash("0x1234"), } swapCh <- swap // Wait a bit for processing time.Sleep(500 * time.Millisecond) // Close swap channel close(swapCh) // Wait for context to timeout <-ctx.Done() t.Log("Swap monitoring completed") } func TestDetector_ScanForOpportunities(t *testing.T) { detector, poolCache := setupDetectorTest(t) ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) defer cancel() tokenA, tokenB := addTestPoolsForArbitrage(t, poolCache) tokens := []common.Address{tokenA, tokenB} interval := 500 * time.Millisecond // Start scanning in background go detector.ScanForOpportunities(ctx, interval, tokens) // Wait for context to timeout <-ctx.Done() t.Log("Opportunity scanning completed") } func TestDefaultDetectorConfig(t *testing.T) { config := DefaultDetectorConfig() if config.MaxPathsToEvaluate != 50 { t.Errorf("got MaxPathsToEvaluate=%d, want 50", config.MaxPathsToEvaluate) } if config.EvaluationTimeout != 5*time.Second { t.Errorf("got EvaluationTimeout=%v, want 5s", config.EvaluationTimeout) } if config.MinInputAmount == nil { t.Fatal("MinInputAmount is nil") } expectedMinInput := new(big.Int).Mul(big.NewInt(1), big.NewInt(1e17)) if config.MinInputAmount.Cmp(expectedMinInput) != 0 { t.Errorf("got MinInputAmount=%s, want %s", config.MinInputAmount.String(), expectedMinInput.String()) } if config.MaxInputAmount == nil { t.Fatal("MaxInputAmount is nil") } expectedMaxInput := new(big.Int).Mul(big.NewInt(10), big.NewInt(1e18)) if config.MaxInputAmount.Cmp(expectedMaxInput) != 0 { t.Errorf("got MaxInputAmount=%s, want %s", config.MaxInputAmount.String(), expectedMaxInput.String()) } if !config.OptimizeInput { t.Error("OptimizeInput should be true") } if config.DefaultGasPrice == nil { t.Fatal("DefaultGasPrice is nil") } if config.DefaultGasPrice.Cmp(big.NewInt(1e9)) != 0 { t.Errorf("got DefaultGasPrice=%s, want 1000000000", config.DefaultGasPrice.String()) } if config.MaxConcurrentEvaluations != 10 { t.Errorf("got MaxConcurrentEvaluations=%d, want 10", config.MaxConcurrentEvaluations) } if len(config.WhitelistedTokens) != 0 { t.Errorf("got %d whitelisted tokens, want 0 (empty)", len(config.WhitelistedTokens)) } }