package arbitrage

import (
	"context"
	"log/slog"
	"math/big"
	"os"
	"testing"

	"github.com/ethereum/go-ethereum/common"
	"github.com/your-org/mev-bot/pkg/cache"
	"github.com/your-org/mev-bot/pkg/types"
)

// setupPathFinderTest builds a PathFinder over a fresh pool cache using
// DefaultPathFinderConfig and an error-level text logger. It returns both the
// finder and the cache so that callers can seed pools before searching.
//
// NOTE(review): TestDefaultPathFinderConfig pins the default MinLiquidity at
// 10000 * 10^18, while the fixtures created through addTestPool use liquidity
// values around 1e5. Confirm that the Find* methods do not apply the liquidity
// filter to cached pools; otherwise these fixtures would be filtered out.
func setupPathFinderTest(t *testing.T) (*PathFinder, *cache.PoolCache) {
	t.Helper()

	logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{
		Level: slog.LevelError, // Reduce noise in tests
	}))

	poolCache := cache.NewPoolCache()
	config := DefaultPathFinderConfig()
	pf := NewPathFinder(poolCache, config, logger)

	return pf, poolCache
}

// addTestPool creates an active pool fixture for the token0/token1 pair, adds
// it to poolCache and returns it. Both reserves and the liquidity field are
// set to the same liquidity value. The test is aborted if the cache rejects
// the pool.
func addTestPool(t *testing.T, poolCache *cache.PoolCache, address, token0, token1 string, protocol types.ProtocolType, liquidity int64) *types.PoolInfo {
	t.Helper()

	pool := &types.PoolInfo{
		Address:        common.HexToAddress(address),
		Protocol:       protocol,
		PoolType:       "constant-product",
		Token0:         common.HexToAddress(token0),
		Token1:         common.HexToAddress(token1),
		Token0Decimals: 18,
		Token1Decimals: 18,
		Token0Symbol:   "TOKEN0",
		Token1Symbol:   "TOKEN1",
		Reserve0:       big.NewInt(liquidity),
		Reserve1:       big.NewInt(liquidity),
		Liquidity:      big.NewInt(liquidity),
		Fee:            30, // 0.3%
		IsActive:       true,
		BlockNumber:    1000,
		LastUpdate:     1000,
	}

	if err := poolCache.Add(context.Background(), pool); err != nil {
		t.Fatalf("failed to add pool: %v", err)
	}

	return pool
}

// TestPathFinder_FindTwoPoolPaths seeds three pools for the same token pair
// and checks the number and shape of the two-pool round-trip paths returned,
// as well as the error returned for a pair that has no pools.
func TestPathFinder_FindTwoPoolPaths(t *testing.T) {
	pf, poolCache := setupPathFinderTest(t)
	ctx := context.Background()

	tokenA := "0x1111111111111111111111111111111111111111"
	tokenB := "0x2222222222222222222222222222222222222222"

	// Add three pools for tokenA-tokenB with different liquidity
	pool1 := addTestPool(t, poolCache, "0xAAAA", tokenA, tokenB, types.ProtocolUniswapV2, 100000)
	pool2 := addTestPool(t, poolCache, "0xBBBB", tokenA, tokenB, types.ProtocolUniswapV3, 200000)
	pool3 := addTestPool(t, poolCache, "0xCCCC", tokenA, tokenB, types.ProtocolSushiSwap, 150000)

	tests := []struct {
		name          string
		tokenA        string
		tokenB        string
		wantPathCount int
		wantError     bool
	}{
		{
			name:          "valid two-pool arbitrage",
			tokenA:        tokenA,
			tokenB:        tokenB,
			wantPathCount: 6, // 3 pools = 3 pairs × 2 directions = 6 paths
			wantError:     false,
		},
		{
			name:          "tokens with no pools",
			tokenA:        "0x3333333333333333333333333333333333333333",
			tokenB:        "0x4444444444444444444444444444444444444444",
			wantPathCount: 0,
			wantError:     true,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			paths, err := pf.FindTwoPoolPaths(ctx, common.HexToAddress(tt.tokenA), common.HexToAddress(tt.tokenB))

			if tt.wantError {
				if err == nil {
					t.Errorf("expected error, got nil")
				}
				return
			}

			if err != nil {
				t.Fatalf("unexpected error: %v", err)
			}

			if len(paths) != tt.wantPathCount {
				t.Errorf("got %d paths, want %d", len(paths), tt.wantPathCount)
			}

			// Validate path structure
			for i, path := range paths {
				if path.Type != OpportunityTypeTwoPool {
					t.Errorf("path %d: wrong type: got %s, want %s", i, path.Type, OpportunityTypeTwoPool)
				}

				if len(path.Pools) != 2 {
					t.Errorf("path %d: got %d pools, want 2", i, len(path.Pools))
				}

				if len(path.Tokens) != 3 {
					t.Errorf("path %d: got %d tokens, want 3", i, len(path.Tokens))
					// Skip the indexed check below: a malformed path must fail
					// the test, not panic with an index out of range.
					continue
				}

				// First and last token should be the same (round trip)
				if path.Tokens[0] != path.Tokens[2] {
					t.Errorf("path %d: not a round trip: start=%s, end=%s", i, path.Tokens[0].Hex(), path.Tokens[2].Hex())
				}
			}

			// Verify all pools are used
			poolsUsed := make(map[common.Address]bool)
			for _, path := range paths {
				for _, pool := range path.Pools {
					poolsUsed[pool.Address] = true
				}
			}

			if len(poolsUsed) != 3 {
				t.Errorf("expected all 3 pools to be used, got %d", len(poolsUsed))
			}

			expectedPools := []common.Address{pool1.Address, pool2.Address, pool3.Address}
			for _, expected := range expectedPools {
				if !poolsUsed[expected] {
					t.Errorf("pool %s not used in any path", expected.Hex())
				}
			}
		})
	}
}

// TestPathFinder_FindTriangularPaths seeds two triangles that both start and
// end at tokenA (A-B-C-A and A-B-D-A) and checks that every returned path is a
// three-pool, four-token cycle through tokenA with distinct middle tokens.
func TestPathFinder_FindTriangularPaths(t *testing.T) {
	pf, poolCache := setupPathFinderTest(t)
	ctx := context.Background()

	tokenA := "0x1111111111111111111111111111111111111111" // Starting token
	tokenB := "0x2222222222222222222222222222222222222222"
	tokenC := "0x3333333333333333333333333333333333333333"

	// Create triangular path: A-B, B-C, C-A
	addTestPool(t, poolCache, "0xAA11", tokenA, tokenB, types.ProtocolUniswapV2, 100000)
	addTestPool(t, poolCache, "0xBB22", tokenB, tokenC, types.ProtocolUniswapV3, 100000)
	addTestPool(t, poolCache, "0xCC33", tokenC, tokenA, types.ProtocolSushiSwap, 100000)

	// Add another triangular path: A-B (different pool), B-D, D-A
	tokenD := "0x4444444444444444444444444444444444444444"
	addTestPool(t, poolCache, "0xAA12", tokenA, tokenB, types.ProtocolUniswapV2, 100000)
	addTestPool(t, poolCache, "0xBB44", tokenB, tokenD, types.ProtocolUniswapV3, 100000)
	addTestPool(t, poolCache, "0xDD44", tokenD, tokenA, types.ProtocolSushiSwap, 100000)

	paths, err := pf.FindTriangularPaths(ctx, common.HexToAddress(tokenA))
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}

	if len(paths) == 0 {
		t.Fatal("expected at least one triangular path")
	}

	start := common.HexToAddress(tokenA)

	// Validate path structure
	for i, path := range paths {
		if path.Type != OpportunityTypeTriangular {
			t.Errorf("path %d: wrong type: got %s, want %s", i, path.Type, OpportunityTypeTriangular)
		}

		if len(path.Pools) != 3 {
			t.Errorf("path %d: got %d pools, want 3", i, len(path.Pools))
		}

		if len(path.Tokens) != 4 {
			t.Errorf("path %d: got %d tokens, want 4", i, len(path.Tokens))
			// The checks below index Tokens[0..3]; skip them for a malformed
			// path so the test reports a failure instead of panicking.
			continue
		}

		// First and last token should be tokenA
		if path.Tokens[0] != start {
			t.Errorf("path %d: wrong start token: got %s, want %s", i, path.Tokens[0].Hex(), tokenA)
		}

		if path.Tokens[3] != start {
			t.Errorf("path %d: wrong end token: got %s, want %s", i, path.Tokens[3].Hex(), tokenA)
		}

		// No duplicate tokens in the middle
		if path.Tokens[1] == path.Tokens[2] {
			t.Errorf("path %d: duplicate middle tokens", i)
		}
	}

	t.Logf("found %d triangular paths", len(paths))
}

// TestPathFinder_FindMultiHopPaths seeds the chain A-B-C-D plus a B-D shortcut
// and checks path counts for several maxHops values, the errors returned for
// out-of-range maxHops, and that each hop's pool contains both tokens it
// connects.
func TestPathFinder_FindMultiHopPaths(t *testing.T) {
	pf, poolCache := setupPathFinderTest(t)
	ctx := context.Background()

	tokenA := "0x1111111111111111111111111111111111111111"
	tokenB := "0x2222222222222222222222222222222222222222"
	tokenC := "0x3333333333333333333333333333333333333333"
	tokenD := "0x4444444444444444444444444444444444444444"

	// Create path: A → B → C → D
	addTestPool(t, poolCache, "0xAB11", tokenA, tokenB, types.ProtocolUniswapV2, 100000)
	addTestPool(t, poolCache, "0xBC22", tokenB, tokenC, types.ProtocolUniswapV3, 100000)
	addTestPool(t, poolCache, "0xCD33", tokenC, tokenD, types.ProtocolSushiSwap, 100000)

	// Add alternative path: A → B → D (shorter)
	addTestPool(t, poolCache, "0xBD44", tokenB, tokenD, types.ProtocolUniswapV2, 100000)

	tests := []struct {
		name          string
		startToken    string
		endToken      string
		maxHops       int
		wantPathCount int
		wantError     bool
	}{
		{
			name:          "2-hop path",
			startToken:    tokenA,
			endToken:      tokenC,
			maxHops:       2,
			wantPathCount: 1, // A → B → C
			wantError:     false,
		},
		{
			name:          "3-hop path with alternatives",
			startToken:    tokenA,
			endToken:      tokenD,
			maxHops:       3,
			wantPathCount: 2, // A → B → D (2 hops) and A → B → C → D (3 hops)
			wantError:     false,
		},
		{
			name:          "invalid maxHops too small",
			startToken:    tokenA,
			endToken:      tokenD,
			maxHops:       1,
			wantPathCount: 0,
			wantError:     true,
		},
		{
			name:          "invalid maxHops too large",
			startToken:    tokenA,
			endToken:      tokenD,
			maxHops:       10,
			wantPathCount: 0,
			wantError:     true,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			paths, err := pf.FindMultiHopPaths(
				ctx,
				common.HexToAddress(tt.startToken),
				common.HexToAddress(tt.endToken),
				tt.maxHops,
			)

			if tt.wantError {
				if err == nil {
					t.Errorf("expected error, got nil")
				}
				return
			}

			if err != nil {
				t.Fatalf("unexpected error: %v", err)
			}

			if len(paths) != tt.wantPathCount {
				t.Errorf("got %d paths, want %d", len(paths), tt.wantPathCount)
			}

			// Validate path structure
			for i, path := range paths {
				if path.Type != OpportunityTypeMultiHop {
					t.Errorf("path %d: wrong type: got %s, want %s", i, path.Type, OpportunityTypeMultiHop)
				}

				if len(path.Pools) > tt.maxHops {
					t.Errorf("path %d: too many hops: got %d, max %d", i, len(path.Pools), tt.maxHops)
				}

				if len(path.Tokens) != len(path.Pools)+1 {
					t.Errorf("path %d: token count mismatch: got %d tokens, %d pools", i, len(path.Tokens), len(path.Pools))
					// Everything below indexes Tokens relative to Pools; with
					// mismatched lengths those reads could panic, so stop here
					// for this path.
					continue
				}

				// Verify start and end tokens
				if path.Tokens[0] != common.HexToAddress(tt.startToken) {
					t.Errorf("path %d: wrong start token: got %s, want %s", i, path.Tokens[0].Hex(), tt.startToken)
				}

				if path.Tokens[len(path.Tokens)-1] != common.HexToAddress(tt.endToken) {
					t.Errorf("path %d: wrong end token: got %s, want %s", i, path.Tokens[len(path.Tokens)-1].Hex(), tt.endToken)
				}

				// Verify pool connections
				for j, pool := range path.Pools {
					tokenIn := path.Tokens[j]
					tokenOut := path.Tokens[j+1]

					// Check that pool contains both tokens
					hasTokenIn := pool.Token0 == tokenIn || pool.Token1 == tokenIn
					hasTokenOut := pool.Token0 == tokenOut || pool.Token1 == tokenOut

					if !hasTokenIn {
						t.Errorf("path %d, pool %d: doesn't contain input token %s", i, j, tokenIn.Hex())
					}

					if !hasTokenOut {
						t.Errorf("path %d, pool %d: doesn't contain output token %s", i, j, tokenOut.Hex())
					}
				}
			}

			t.Logf("test %s: found %d paths", tt.name, len(paths))
		})
	}
}

// TestPathFinder_FilterPools checks how many pools survive filterPools under
// configs that restrict minimum liquidity, allowed protocols, and when some
// pools are marked inactive.
func TestPathFinder_FilterPools(t *testing.T) {
	logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{
		Level: slog.LevelError,
	}))

	tests := []struct {
		name         string
		config       *PathFinderConfig
		pools        []*types.PoolInfo
		wantFiltered int
	}{
		{
			name: "filter by minimum liquidity",
			config: &PathFinderConfig{
				MinLiquidity: big.NewInt(50000),
				AllowedProtocols: []types.ProtocolType{
					types.ProtocolUniswapV2,
					types.ProtocolUniswapV3,
				},
			},
			pools: []*types.PoolInfo{
				{
					Address:   common.HexToAddress("0x1111"),
					Protocol:  types.ProtocolUniswapV2,
					Liquidity: big.NewInt(100000),
					IsActive:  true,
				},
				{
					Address:   common.HexToAddress("0x2222"),
					Protocol:  types.ProtocolUniswapV2,
					Liquidity: big.NewInt(10000), // Too low
					IsActive:  true,
				},
				{
					Address:   common.HexToAddress("0x3333"),
					Protocol:  types.ProtocolUniswapV3,
					Liquidity: big.NewInt(75000),
					IsActive:  true,
				},
			},
			wantFiltered: 2, // Only 2 pools meet liquidity requirement
		},
		{
			name: "filter by protocol",
			config: &PathFinderConfig{
				MinLiquidity:     big.NewInt(0),
				AllowedProtocols: []types.ProtocolType{types.ProtocolUniswapV2},
			},
			pools: []*types.PoolInfo{
				{
					Address:   common.HexToAddress("0x1111"),
					Protocol:  types.ProtocolUniswapV2,
					Liquidity: big.NewInt(100000),
					IsActive:  true,
				},
				{
					Address:   common.HexToAddress("0x2222"),
					Protocol:  types.ProtocolUniswapV3, // Not allowed
					Liquidity: big.NewInt(100000),
					IsActive:  true,
				},
				{
					Address:   common.HexToAddress("0x3333"),
					Protocol:  types.ProtocolSushiSwap, // Not allowed
					Liquidity: big.NewInt(100000),
					IsActive:  true,
				},
			},
			wantFiltered: 1, // Only UniswapV2 pool
		},
		{
			name: "filter inactive pools",
			config: &PathFinderConfig{
				MinLiquidity: big.NewInt(0),
				AllowedProtocols: []types.ProtocolType{
					types.ProtocolUniswapV2,
				},
			},
			pools: []*types.PoolInfo{
				{
					Address:   common.HexToAddress("0x1111"),
					Protocol:  types.ProtocolUniswapV2,
					Liquidity: big.NewInt(100000),
					IsActive:  true,
				},
				{
					Address:   common.HexToAddress("0x2222"),
					Protocol:  types.ProtocolUniswapV2,
					Liquidity: big.NewInt(100000),
					IsActive:  false, // Inactive
				},
			},
			wantFiltered: 1, // Only active pool
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			poolCache := cache.NewPoolCache()
			pf := NewPathFinder(poolCache, tt.config, logger)

			filtered := pf.filterPools(tt.pools)

			if len(filtered) != tt.wantFiltered {
				t.Errorf("got %d filtered pools, want %d", len(filtered), tt.wantFiltered)
			}
		})
	}
}

// TestPathFinder_GetOtherToken checks that getOtherToken returns the opposite
// side of a pool for either of its tokens, and the zero address for a token
// that is not in the pool.
func TestPathFinder_GetOtherToken(t *testing.T) {
	pf, _ := setupPathFinderTest(t)

	tokenA := common.HexToAddress("0x1111111111111111111111111111111111111111")
	tokenB := common.HexToAddress("0x2222222222222222222222222222222222222222")
	tokenC := common.HexToAddress("0x3333333333333333333333333333333333333333")

	pool := &types.PoolInfo{
		Token0: tokenA,
		Token1: tokenB,
	}

	tests := []struct {
		name       string
		inputToken common.Address
		wantToken  common.Address
	}{
		{
			name:       "get token1 when input is token0",
			inputToken: tokenA,
			wantToken:  tokenB,
		},
		{
			name:       "get token0 when input is token1",
			inputToken: tokenB,
			wantToken:  tokenA,
		},
		{
			name:       "return zero address for unknown token",
			inputToken: tokenC,
			wantToken:  common.Address{},
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			got := pf.getOtherToken(pool, tt.inputToken)
			if got != tt.wantToken {
				t.Errorf("got %s, want %s", got.Hex(), tt.wantToken.Hex())
			}
		})
	}
}

// TestPathFinder_GetPathSignature checks that getPathSignature produces the
// expected "-"-joined lowercase hex address string for one, two and three
// pools.
func TestPathFinder_GetPathSignature(t *testing.T) {
	pf, _ := setupPathFinderTest(t)

	pool1 := &types.PoolInfo{Address: common.HexToAddress("0xAAAA")}
	pool2 := &types.PoolInfo{Address: common.HexToAddress("0xBBBB")}
	pool3 := &types.PoolInfo{Address: common.HexToAddress("0xCCCC")}

	tests := []struct {
		name    string
		pools   []*types.PoolInfo
		wantSig string
	}{
		{
			name:    "single pool",
			pools:   []*types.PoolInfo{pool1},
			wantSig: "0x000000000000000000000000000000000000aaaa",
		},
		{
			name:    "two pools",
			pools:   []*types.PoolInfo{pool1, pool2},
			wantSig: "0x000000000000000000000000000000000000aaaa-0x000000000000000000000000000000000000bbbb",
		},
		{
			name:    "three pools",
			pools:   []*types.PoolInfo{pool1, pool2, pool3},
			wantSig: "0x000000000000000000000000000000000000aaaa-0x000000000000000000000000000000000000bbbb-0x000000000000000000000000000000000000cccc",
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			got := pf.getPathSignature(tt.pools)
			if got != tt.wantSig {
				t.Errorf("got %s, want %s", got, tt.wantSig)
			}
		})
	}
}

// TestDefaultPathFinderConfig pins the values returned by
// DefaultPathFinderConfig: MaxHops, MinLiquidity, the set of protocols that
// must be allowed, and MaxPathsPerPair.
func TestDefaultPathFinderConfig(t *testing.T) {
	config := DefaultPathFinderConfig()

	if config.MaxHops != 4 {
		t.Errorf("got MaxHops=%d, want 4", config.MaxHops)
	}

	if config.MinLiquidity == nil {
		t.Fatal("MinLiquidity is nil")
	}

	// 10000 * 10^18
	expectedMinLiq := new(big.Int).Mul(big.NewInt(10000), new(big.Int).Exp(big.NewInt(10), big.NewInt(18), nil))
	if config.MinLiquidity.Cmp(expectedMinLiq) != 0 {
		t.Errorf("got MinLiquidity=%s, want %s", config.MinLiquidity.String(), expectedMinLiq.String())
	}

	if len(config.AllowedProtocols) == 0 {
		t.Error("AllowedProtocols is empty")
	}

	expectedProtocols := []types.ProtocolType{
		types.ProtocolUniswapV2,
		types.ProtocolUniswapV3,
		types.ProtocolSushiSwap,
		types.ProtocolCurve,
	}

	for _, expected := range expectedProtocols {
		found := false
		for _, protocol := range config.AllowedProtocols {
			if protocol == expected {
				found = true
				break
			}
		}
		if !found {
			t.Errorf("missing protocol %s in AllowedProtocols", expected)
		}
	}

	if config.MaxPathsPerPair != 10 {
		t.Errorf("got MaxPathsPerPair=%d, want 10", config.MaxPathsPerPair)
	}
}