package cache

import (
	"context"
	"errors"
	"math/big"
	"testing"

	"github.com/ethereum/go-ethereum/common"

	"github.com/your-org/mev-bot/pkg/types"
)

// createTestPool builds a PoolInfo fixture with the given address, token pair
// and protocol. Decimals, reserves and liquidity are fixed non-zero values and
// the pool is marked active, so the fixture is valid unless a test mutates it.
func createTestPool(address, token0, token1 string, protocol types.ProtocolType) *types.PoolInfo {
	return &types.PoolInfo{
		Address:        common.HexToAddress(address),
		Protocol:       protocol,
		Token0:         common.HexToAddress(token0),
		Token1:         common.HexToAddress(token1),
		Token0Decimals: 18,
		Token1Decimals: 6,
		Reserve0:       big.NewInt(1000000),
		Reserve1:       big.NewInt(500000),
		Liquidity:      big.NewInt(100000),
		IsActive:       true,
	}
}

// mustAddPool adds pool to c and aborts the test if Add reports an error.
// Setup failures then surface at their cause instead of as confusing
// assertion failures further down the test.
func mustAddPool(t *testing.T, ctx context.Context, c interface {
	Add(ctx context.Context, pool *types.PoolInfo) error
}, pool *types.PoolInfo) {
	t.Helper()
	if err := c.Add(ctx, pool); err != nil {
		t.Fatalf("Add(%s) failed: %v", pool.Address.Hex(), err)
	}
}

func TestNewPoolCache(t *testing.T) {
	cache := NewPoolCache()
	if cache == nil {
		t.Fatal("NewPoolCache returned nil")
	}
}

func TestPoolCache_Add(t *testing.T) {
	ctx := context.Background()
	cache := NewPoolCache()
	pool := createTestPool("0x1111", "0x2222", "0x3333", types.ProtocolUniswapV2)

	err := cache.Add(ctx, pool)
	if err != nil {
		t.Fatalf("Add() failed: %v", err)
	}

	// Verify it was added
	retrieved, err := cache.GetByAddress(ctx, pool.Address)
	if err != nil {
		t.Fatalf("GetByAddress() failed: %v", err)
	}
	if retrieved.Address != pool.Address {
		t.Errorf("Retrieved pool address = %v, want %v", retrieved.Address, pool.Address)
	}
}

func TestPoolCache_Add_NilPool(t *testing.T) {
	ctx := context.Background()
	cache := NewPoolCache()

	err := cache.Add(ctx, nil)
	if err == nil {
		t.Error("Add(nil) expected error, got nil")
	}
}

func TestPoolCache_Add_InvalidPool(t *testing.T) {
	ctx := context.Background()
	cache := NewPoolCache()

	// Pool with zero address
	pool := createTestPool("0x0000", "0x2222", "0x3333", types.ProtocolUniswapV2)
	pool.Address = common.Address{}

	err := cache.Add(ctx, pool)
	if err == nil {
		t.Error("Add(invalid pool) expected error, got nil")
	}
}

func TestPoolCache_Add_Update(t *testing.T) {
	ctx := context.Background()
	cache := NewPoolCache()
	pool := createTestPool("0x1111", "0x2222", "0x3333", types.ProtocolUniswapV2)

	// Add pool
	err := cache.Add(ctx, pool)
	if err != nil {
		t.Fatalf("Add() failed: %v", err)
	}

	// Update pool (same address, different reserves)
	updatedPool := createTestPool("0x1111", "0x2222", "0x3333", types.ProtocolUniswapV2)
	updatedPool.Reserve0 = big.NewInt(2000000)
	err = cache.Add(ctx, updatedPool)
	if err != nil {
		t.Fatalf("Add(update) failed: %v", err)
	}

	// Verify it was updated
	retrieved, err := cache.GetByAddress(ctx, pool.Address)
	if err != nil {
		t.Fatalf("GetByAddress() failed: %v", err)
	}
	if retrieved.Reserve0.Cmp(big.NewInt(2000000)) != 0 {
		t.Errorf("Updated pool reserve0 = %v, want %v", retrieved.Reserve0, big.NewInt(2000000))
	}
}

func TestPoolCache_GetByAddress(t *testing.T) {
	ctx := context.Background()
	cache := NewPoolCache()
	pool := createTestPool("0x1111", "0x2222", "0x3333", types.ProtocolUniswapV2)
	mustAddPool(t, ctx, cache, pool)

	tests := []struct {
		name    string
		address common.Address
		wantErr bool
	}{
		{
			name:    "existing pool",
			address: common.HexToAddress("0x1111"),
			wantErr: false,
		},
		{
			name:    "non-existing pool",
			address: common.HexToAddress("0x9999"),
			wantErr: true,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			got, err := cache.GetByAddress(ctx, tt.address)
			if tt.wantErr {
				if err == nil {
					t.Error("GetByAddress() expected error, got nil")
				}
				if got != nil {
					t.Error("GetByAddress() expected nil pool on error")
				}
				return
			}
			if err != nil {
				t.Errorf("GetByAddress() unexpected error: %v", err)
			}
			if got == nil {
				t.Error("GetByAddress() returned nil pool")
			}
		})
	}
}

func TestPoolCache_GetByTokenPair(t *testing.T) {
	ctx := context.Background()
	cache := NewPoolCache()

	// Add pools with same token pair
	pool1 := createTestPool("0x1111", "0x2222", "0x3333", types.ProtocolUniswapV2)
	pool2 := createTestPool("0x4444", "0x2222", "0x3333", types.ProtocolUniswapV3)
	mustAddPool(t, ctx, cache, pool1)
	mustAddPool(t, ctx, cache, pool2)

	tests := []struct {
		name      string
		token0    common.Address
		token1    common.Address
		wantCount int
	}{
		{
			name:      "existing pair",
			token0:    common.HexToAddress("0x2222"),
			token1:    common.HexToAddress("0x3333"),
			wantCount: 2,
		},
		{
			name:      "existing pair (reversed)",
			token0:    common.HexToAddress("0x3333"),
			token1:    common.HexToAddress("0x2222"),
			wantCount: 2,
		},
		{
			name:      "non-existing pair",
			token0:    common.HexToAddress("0x9999"),
			token1:    common.HexToAddress("0x8888"),
			wantCount: 0,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			pools, err := cache.GetByTokenPair(ctx, tt.token0, tt.token1)
			if err != nil {
				t.Fatalf("GetByTokenPair() error: %v", err)
			}
			if len(pools) != tt.wantCount {
				t.Errorf("GetByTokenPair() count = %d, want %d", len(pools), tt.wantCount)
			}
		})
	}
}

func TestPoolCache_GetByProtocol(t *testing.T) {
	ctx := context.Background()
	cache := NewPoolCache()

	// Add pools with different protocols
	pool1 := createTestPool("0x1111", "0x2222", "0x3333", types.ProtocolUniswapV2)
	pool2 := createTestPool("0x4444", "0x5555", "0x6666", types.ProtocolUniswapV2)
	pool3 := createTestPool("0x7777", "0x8888", "0x9999", types.ProtocolUniswapV3)
	mustAddPool(t, ctx, cache, pool1)
	mustAddPool(t, ctx, cache, pool2)
	mustAddPool(t, ctx, cache, pool3)

	tests := []struct {
		name      string
		protocol  types.ProtocolType
		wantCount int
	}{
		{
			name:      "uniswap v2",
			protocol:  types.ProtocolUniswapV2,
			wantCount: 2,
		},
		{
			name:      "uniswap v3",
			protocol:  types.ProtocolUniswapV3,
			wantCount: 1,
		},
		{
			name:      "curve (none)",
			protocol:  types.ProtocolCurve,
			wantCount: 0,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			pools, err := cache.GetByProtocol(ctx, tt.protocol)
			if err != nil {
				t.Fatalf("GetByProtocol() error: %v", err)
			}
			if len(pools) != tt.wantCount {
				t.Errorf("GetByProtocol() count = %d, want %d", len(pools), tt.wantCount)
			}
		})
	}
}

func TestPoolCache_GetByLiquidity(t *testing.T) {
	ctx := context.Background()
	cache := NewPoolCache()

	// Add pools with different liquidity
	pool1 := createTestPool("0x1111", "0x2222", "0x3333", types.ProtocolUniswapV2)
	pool1.Liquidity = big.NewInt(1000)
	pool2 := createTestPool("0x4444", "0x5555", "0x6666", types.ProtocolUniswapV2)
	pool2.Liquidity = big.NewInt(5000)
	pool3 := createTestPool("0x7777", "0x8888", "0x9999", types.ProtocolUniswapV3)
	pool3.Liquidity = big.NewInt(10000)
	mustAddPool(t, ctx, cache, pool1)
	mustAddPool(t, ctx, cache, pool2)
	mustAddPool(t, ctx, cache, pool3)

	tests := []struct {
		name         string
		minLiquidity *big.Int
		limit        int
		wantCount    int
		wantFirst    *big.Int // Expected liquidity of first result
	}{
		{
			name:         "all pools",
			minLiquidity: big.NewInt(0),
			limit:        0,
			wantCount:    3,
			wantFirst:    big.NewInt(10000),
		},
		{
			name:         "min 2000",
			minLiquidity: big.NewInt(2000),
			limit:        0,
			wantCount:    2,
			wantFirst:    big.NewInt(10000),
		},
		{
			name:         "limit 2",
			minLiquidity: big.NewInt(0),
			limit:        2,
			wantCount:    2,
			wantFirst:    big.NewInt(10000),
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			pools, err := cache.GetByLiquidity(ctx, tt.minLiquidity, tt.limit)
			if err != nil {
				t.Fatalf("GetByLiquidity() error: %v", err)
			}
			if len(pools) != tt.wantCount {
				t.Errorf("GetByLiquidity() count = %d, want %d", len(pools), tt.wantCount)
			}
			if len(pools) > 0 && pools[0].Liquidity.Cmp(tt.wantFirst) != 0 {
				t.Errorf("GetByLiquidity() first liquidity = %v, want %v", pools[0].Liquidity, tt.wantFirst)
			}
		})
	}
}

func TestPoolCache_Update(t *testing.T) {
	ctx := context.Background()
	cache := NewPoolCache()
	pool := createTestPool("0x1111", "0x2222", "0x3333", types.ProtocolUniswapV2)
	mustAddPool(t, ctx, cache, pool)

	// Update reserves
	err := cache.Update(ctx, pool.Address, func(p *types.PoolInfo) error {
		p.Reserve0 = big.NewInt(9999999)
		return nil
	})
	if err != nil {
		t.Fatalf("Update() failed: %v", err)
	}

	// Verify update
	retrieved, err := cache.GetByAddress(ctx, pool.Address)
	if err != nil {
		t.Fatalf("GetByAddress() failed: %v", err)
	}
	if retrieved.Reserve0.Cmp(big.NewInt(9999999)) != 0 {
		t.Errorf("Update() reserve0 = %v, want %v", retrieved.Reserve0, big.NewInt(9999999))
	}
}

func TestPoolCache_Update_NonExistent(t *testing.T) {
	ctx := context.Background()
	cache := NewPoolCache()

	err := cache.Update(ctx, common.HexToAddress("0x9999"), func(p *types.PoolInfo) error {
		return nil
	})
	if err == nil {
		t.Error("Update(non-existent) expected error, got nil")
	}
}

func TestPoolCache_Update_Error(t *testing.T) {
	ctx := context.Background()
	cache := NewPoolCache()
	pool := createTestPool("0x1111", "0x2222", "0x3333", types.ProtocolUniswapV2)
	mustAddPool(t, ctx, cache, pool)

	// The error returned by the update callback must reach the caller. Use
	// errors.Is so the assertion still holds if Update wraps the error.
	testErr := errors.New("test error")
	err := cache.Update(ctx, pool.Address, func(p *types.PoolInfo) error {
		return testErr
	})
	if !errors.Is(err, testErr) {
		t.Errorf("Update() error = %v, want %v", err, testErr)
	}
}

func TestPoolCache_Update_InvalidAfterUpdate(t *testing.T) {
	ctx := context.Background()
	cache := NewPoolCache()
	pool := createTestPool("0x1111", "0x2222", "0x3333", types.ProtocolUniswapV2)
	mustAddPool(t, ctx, cache, pool)

	// Make pool invalid
	err := cache.Update(ctx, pool.Address, func(p *types.PoolInfo) error {
		p.Token0Decimals = 0 // Invalid
		return nil
	})
	if err == nil {
		t.Error("Update(invalid) expected error, got nil")
	}
}

func TestPoolCache_Remove(t *testing.T) {
	ctx := context.Background()
	cache := NewPoolCache()
	pool := createTestPool("0x1111", "0x2222", "0x3333", types.ProtocolUniswapV2)
	mustAddPool(t, ctx, cache, pool)

	// Remove pool
	err := cache.Remove(ctx, pool.Address)
	if err != nil {
		t.Fatalf("Remove() failed: %v", err)
	}

	// Verify removal
	_, err = cache.GetByAddress(ctx, pool.Address)
	if err == nil {
		t.Error("GetByAddress() after Remove() expected error, got nil")
	}

	// Verify removed from token pair index
	pools, err := cache.GetByTokenPair(ctx, pool.Token0, pool.Token1)
	if err != nil {
		t.Fatalf("GetByTokenPair() after Remove() error: %v", err)
	}
	if len(pools) != 0 {
		t.Error("Pool still in token pair index after removal")
	}

	// Verify removed from protocol index
	pools, err = cache.GetByProtocol(ctx, pool.Protocol)
	if err != nil {
		t.Fatalf("GetByProtocol() after Remove() error: %v", err)
	}
	if len(pools) != 0 {
		t.Error("Pool still in protocol index after removal")
	}
}

func TestPoolCache_Remove_NonExistent(t *testing.T) {
	ctx := context.Background()
	cache := NewPoolCache()

	err := cache.Remove(ctx, common.HexToAddress("0x9999"))
	if err == nil {
		t.Error("Remove(non-existent) expected error, got nil")
	}
}

func TestPoolCache_Count(t *testing.T) {
	ctx := context.Background()
	cache := NewPoolCache()

	// Initially empty
	count, err := cache.Count(ctx)
	if err != nil {
		t.Fatalf("Count() failed: %v", err)
	}
	if count != 0 {
		t.Errorf("Count() = %d, want 0", count)
	}

	// Add pools
	pool1 := createTestPool("0x1111", "0x2222", "0x3333", types.ProtocolUniswapV2)
	pool2 := createTestPool("0x4444", "0x5555", "0x6666", types.ProtocolUniswapV3)
	mustAddPool(t, ctx, cache, pool1)
	mustAddPool(t, ctx, cache, pool2)

	count, err = cache.Count(ctx)
	if err != nil {
		t.Fatalf("Count() failed: %v", err)
	}
	if count != 2 {
		t.Errorf("Count() = %d, want 2", count)
	}
}

func TestPoolCache_Clear(t *testing.T) {
	ctx := context.Background()
	cache := NewPoolCache()

	// Add pools
	pool1 := createTestPool("0x1111", "0x2222", "0x3333", types.ProtocolUniswapV2)
	pool2 := createTestPool("0x4444", "0x5555", "0x6666", types.ProtocolUniswapV3)
	mustAddPool(t, ctx, cache, pool1)
	mustAddPool(t, ctx, cache, pool2)

	// Clear cache
	err := cache.Clear(ctx)
	if err != nil {
		t.Fatalf("Clear() failed: %v", err)
	}

	// Verify cleared
	count, err := cache.Count(ctx)
	if err != nil {
		t.Fatalf("Count() after Clear() failed: %v", err)
	}
	if count != 0 {
		t.Errorf("Count() after Clear() = %d, want 0", count)
	}
}

func Test_makeTokenPairKey(t *testing.T) {
	token0 := common.HexToAddress("0x1111")
	token1 := common.HexToAddress("0x2222")

	// Should be consistent regardless of order
	key1 := makeTokenPairKey(token0, token1)
	key2 := makeTokenPairKey(token1, token0)

	if key1 != key2 {
		t.Errorf("makeTokenPairKey() not consistent: %s != %s", key1, key2)
	}
}

func Test_removePoolFromSlice(t *testing.T) {
	pool1 := createTestPool("0x1111", "0x2222", "0x3333", types.ProtocolUniswapV2)
	pool2 := createTestPool("0x4444", "0x5555", "0x6666", types.ProtocolUniswapV3)
	pool3 := createTestPool("0x7777", "0x8888", "0x9999", types.ProtocolCurve)

	pools := []*types.PoolInfo{pool1, pool2, pool3}

	// Remove middle pool. Fatalf on a length mismatch: the element checks below
	// index result[1] and would panic on a shorter slice.
	result := removePoolFromSlice(pools, pool2.Address)
	if len(result) != 2 {
		t.Fatalf("removePoolFromSlice() length = %d, want 2", len(result))
	}
	if result[0].Address != pool1.Address || result[1].Address != pool3.Address {
		t.Error("removePoolFromSlice() incorrect pools remaining")
	}

	// Remove non-existent pool
	result = removePoolFromSlice(result, common.HexToAddress("0x9999"))
	if len(result) != 2 {
		t.Errorf("removePoolFromSlice(non-existent) length = %d, want 2", len(result))
	}
}