diff --git a/pkg/cache/pool_cache.go b/pkg/cache/pool_cache.go new file mode 100644 index 0000000..d8df870 --- /dev/null +++ b/pkg/cache/pool_cache.go @@ -0,0 +1,255 @@ +package cache + +import ( + "context" + "fmt" + "math/big" + "sort" + "sync" + + "github.com/ethereum/go-ethereum/common" + + "github.com/your-org/mev-bot/pkg/types" +) + +// poolCache implements the PoolCache interface with multi-index support +type poolCache struct { + // Primary index: address -> pool + byAddress map[common.Address]*types.PoolInfo + + // Secondary index: token pair -> pools + byTokenPair map[string][]*types.PoolInfo + + // Tertiary index: protocol -> pools + byProtocol map[types.ProtocolType][]*types.PoolInfo + + // Mutex for thread safety + mu sync.RWMutex +} + +// NewPoolCache creates a new multi-index pool cache +func NewPoolCache() PoolCache { + return &poolCache{ + byAddress: make(map[common.Address]*types.PoolInfo), + byTokenPair: make(map[string][]*types.PoolInfo), + byProtocol: make(map[types.ProtocolType][]*types.PoolInfo), + } +} + +// GetByAddress retrieves a pool by its contract address +func (c *poolCache) GetByAddress(ctx context.Context, address common.Address) (*types.PoolInfo, error) { + c.mu.RLock() + defer c.mu.RUnlock() + + pool, exists := c.byAddress[address] + if !exists { + return nil, types.ErrPoolNotFound + } + + return pool, nil +} + +// GetByTokenPair retrieves all pools for a given token pair +func (c *poolCache) GetByTokenPair(ctx context.Context, token0, token1 common.Address) ([]*types.PoolInfo, error) { + c.mu.RLock() + defer c.mu.RUnlock() + + key := makeTokenPairKey(token0, token1) + pools := c.byTokenPair[key] + + if len(pools) == 0 { + return []*types.PoolInfo{}, nil + } + + // Return a copy to prevent external modification + result := make([]*types.PoolInfo, len(pools)) + copy(result, pools) + + return result, nil +} + +// GetByProtocol retrieves all pools for a given protocol +func (c *poolCache) GetByProtocol(ctx context.Context, protocol types.ProtocolType) ([]*types.PoolInfo, error) { + c.mu.RLock() + defer c.mu.RUnlock() + + pools := c.byProtocol[protocol] + + if len(pools) == 0 { + return []*types.PoolInfo{}, nil + } + + // Return a copy to prevent external modification + result := make([]*types.PoolInfo, len(pools)) + copy(result, pools) + + return result, nil +} + +// GetByLiquidity retrieves pools sorted by liquidity (descending) +func (c *poolCache) GetByLiquidity(ctx context.Context, minLiquidity *big.Int, limit int) ([]*types.PoolInfo, error) { + c.mu.RLock() + defer c.mu.RUnlock() + + // Collect all pools with sufficient liquidity + var pools []*types.PoolInfo + for _, pool := range c.byAddress { + if pool.Liquidity != nil && pool.Liquidity.Cmp(minLiquidity) >= 0 { + pools = append(pools, pool) + } + } + + // Sort by liquidity (descending) + sort.Slice(pools, func(i, j int) bool { + return pools[i].Liquidity.Cmp(pools[j].Liquidity) > 0 + }) + + // Apply limit + if limit > 0 && len(pools) > limit { + pools = pools[:limit] + } + + return pools, nil +} + +// Add adds or updates a pool in the cache +func (c *poolCache) Add(ctx context.Context, pool *types.PoolInfo) error { + if pool == nil { + return fmt.Errorf("pool cannot be nil") + } + + if err := pool.Validate(); err != nil { + return fmt.Errorf("invalid pool: %w", err) + } + + c.mu.Lock() + defer c.mu.Unlock() + + // Remove old indexes if pool exists + if oldPool, exists := c.byAddress[pool.Address]; exists { + c.removeFromIndexes(oldPool) + } + + // Add to primary index + c.byAddress[pool.Address] = pool + + // Add to secondary indexes + c.addToIndexes(pool) + + return nil +} + +// Update updates pool information +func (c *poolCache) Update(ctx context.Context, address common.Address, updateFn func(*types.PoolInfo) error) error { + c.mu.Lock() + defer c.mu.Unlock() + + pool, exists := c.byAddress[address] + if !exists { + return types.ErrPoolNotFound + } + + // Remove from indexes before update + c.removeFromIndexes(pool) + + // Apply update + if err := updateFn(pool); err != nil { + // Re-add to indexes even on error to maintain consistency + c.addToIndexes(pool) + return err + } + + // Validate after update + if err := pool.Validate(); err != nil { + // Re-add to indexes even on error + c.addToIndexes(pool) + return fmt.Errorf("pool invalid after update: %w", err) + } + + // Re-add to indexes + c.addToIndexes(pool) + + return nil +} + +// Remove removes a pool from the cache +func (c *poolCache) Remove(ctx context.Context, address common.Address) error { + c.mu.Lock() + defer c.mu.Unlock() + + pool, exists := c.byAddress[address] + if !exists { + return types.ErrPoolNotFound + } + + // Remove from all indexes + delete(c.byAddress, address) + c.removeFromIndexes(pool) + + return nil +} + +// Count returns the total number of pools in the cache +func (c *poolCache) Count(ctx context.Context) (int, error) { + c.mu.RLock() + defer c.mu.RUnlock() + + return len(c.byAddress), nil +} + +// Clear removes all pools from the cache +func (c *poolCache) Clear(ctx context.Context) error { + c.mu.Lock() + defer c.mu.Unlock() + + c.byAddress = make(map[common.Address]*types.PoolInfo) + c.byTokenPair = make(map[string][]*types.PoolInfo) + c.byProtocol = make(map[types.ProtocolType][]*types.PoolInfo) + + return nil +} + +// addToIndexes adds a pool to secondary indexes +func (c *poolCache) addToIndexes(pool *types.PoolInfo) { + // Add to token pair index + pairKey := makeTokenPairKey(pool.Token0, pool.Token1) + c.byTokenPair[pairKey] = append(c.byTokenPair[pairKey], pool) + + // Add to protocol index + c.byProtocol[pool.Protocol] = append(c.byProtocol[pool.Protocol], pool) +} + +// removeFromIndexes removes a pool from secondary indexes +func (c *poolCache) removeFromIndexes(pool *types.PoolInfo) { + // Remove from token pair index + pairKey := makeTokenPairKey(pool.Token0, pool.Token1) + c.byTokenPair[pairKey] = removePoolFromSlice(c.byTokenPair[pairKey], pool.Address) + if len(c.byTokenPair[pairKey]) == 0 { + delete(c.byTokenPair, pairKey) + } + + // Remove from protocol index + c.byProtocol[pool.Protocol] = removePoolFromSlice(c.byProtocol[pool.Protocol], pool.Address) + if len(c.byProtocol[pool.Protocol]) == 0 { + delete(c.byProtocol, pool.Protocol) + } +} + +// makeTokenPairKey creates a consistent key for a token pair +func makeTokenPairKey(token0, token1 common.Address) string { + // Always use the smaller address first for consistency + if token0.Big().Cmp(token1.Big()) < 0 { + return token0.Hex() + "-" + token1.Hex() + } + return token1.Hex() + "-" + token0.Hex() +} + +// removePoolFromSlice removes a pool with the given address from a slice +func removePoolFromSlice(pools []*types.PoolInfo, address common.Address) []*types.PoolInfo { + for i, pool := range pools { + if pool.Address == address { + return append(pools[:i], pools[i+1:]...) + } + } + return pools +} diff --git a/pkg/cache/pool_cache_test.go b/pkg/cache/pool_cache_test.go new file mode 100644 index 0000000..b31865d --- /dev/null +++ b/pkg/cache/pool_cache_test.go @@ -0,0 +1,534 @@ +package cache + +import ( + "context" + "math/big" + "testing" + + "github.com/ethereum/go-ethereum/common" + + "github.com/your-org/mev-bot/pkg/types" +) + +func createTestPool(address, token0, token1 string, protocol types.ProtocolType) *types.PoolInfo { + return &types.PoolInfo{ + Address: common.HexToAddress(address), + Protocol: protocol, + Token0: common.HexToAddress(token0), + Token1: common.HexToAddress(token1), + Token0Decimals: 18, + Token1Decimals: 6, + Reserve0: big.NewInt(1000000), + Reserve1: big.NewInt(500000), + Liquidity: big.NewInt(100000), + IsActive: true, + } +} + +func TestNewPoolCache(t *testing.T) { + cache := NewPoolCache() + if cache == nil { + t.Fatal("NewPoolCache returned nil") + } +} + +func TestPoolCache_Add(t *testing.T) { + ctx := context.Background() + cache := NewPoolCache() + + pool := createTestPool("0x1111", "0x2222", "0x3333", types.ProtocolUniswapV2) + + err := cache.Add(ctx, pool) + if err != nil { + t.Fatalf("Add() failed: %v", err) + } + + // Verify it was added + retrieved, err := cache.GetByAddress(ctx, pool.Address) + if err != nil { + t.Fatalf("GetByAddress() failed: %v", err) + } + + if retrieved.Address != pool.Address { + t.Errorf("Retrieved pool address = %v, want %v", retrieved.Address, pool.Address) + } +} + +func TestPoolCache_Add_NilPool(t *testing.T) { + ctx := context.Background() + cache := NewPoolCache() + + err := cache.Add(ctx, nil) + if err == nil { + t.Error("Add(nil) expected error, got nil") + } +} + +func TestPoolCache_Add_InvalidPool(t *testing.T) { + ctx := context.Background() + cache := NewPoolCache() + + // Pool with zero address + pool := createTestPool("0x0000", "0x2222", "0x3333", types.ProtocolUniswapV2) + pool.Address = common.Address{} + + err := cache.Add(ctx, pool) + if err == nil { + t.Error("Add(invalid pool) expected error, got nil") + } +} + +func TestPoolCache_Add_Update(t *testing.T) { + ctx := context.Background() + cache := NewPoolCache() + + pool := createTestPool("0x1111", "0x2222", "0x3333", types.ProtocolUniswapV2) + + // Add pool + err := cache.Add(ctx, pool) + if err != nil { + t.Fatalf("Add() failed: %v", err) + } + + // Update pool (same address, different reserves) + updatedPool := createTestPool("0x1111", "0x2222", "0x3333", types.ProtocolUniswapV2) + updatedPool.Reserve0 = big.NewInt(2000000) + + err = cache.Add(ctx, updatedPool) + if err != nil { + t.Fatalf("Add(update) failed: %v", err) + } + + // Verify it was updated + retrieved, err := cache.GetByAddress(ctx, pool.Address) + if err != nil { + t.Fatalf("GetByAddress() failed: %v", err) + } + + if retrieved.Reserve0.Cmp(big.NewInt(2000000)) != 0 { + t.Errorf("Updated pool reserve0 = %v, want %v", retrieved.Reserve0, big.NewInt(2000000)) + } +} + +func TestPoolCache_GetByAddress(t *testing.T) { + ctx := context.Background() + cache := NewPoolCache() + + pool := createTestPool("0x1111", "0x2222", "0x3333", types.ProtocolUniswapV2) + cache.Add(ctx, pool) + + tests := []struct { + name string + address common.Address + wantErr bool + }{ + { + name: "existing pool", + address: common.HexToAddress("0x1111"), + wantErr: false, + }, + { + name: "non-existing pool", + address: common.HexToAddress("0x9999"), + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := cache.GetByAddress(ctx, tt.address) + + if tt.wantErr { + if err == nil { + t.Error("GetByAddress() expected error, got nil") + } + if got != nil { + t.Error("GetByAddress() expected nil pool on error") + } + } else { + if err != nil { + t.Errorf("GetByAddress() unexpected error: %v", err) + } + if got == nil { + t.Error("GetByAddress() returned nil pool") + } + } + }) + } +} + +func TestPoolCache_GetByTokenPair(t *testing.T) { + ctx := context.Background() + cache := NewPoolCache() + + // Add pools with same token pair + pool1 := createTestPool("0x1111", "0x2222", "0x3333", types.ProtocolUniswapV2) + pool2 := createTestPool("0x4444", "0x2222", "0x3333", types.ProtocolUniswapV3) + + cache.Add(ctx, pool1) + cache.Add(ctx, pool2) + + tests := []struct { + name string + token0 common.Address + token1 common.Address + wantCount int + }{ + { + name: "existing pair", + token0: common.HexToAddress("0x2222"), + token1: common.HexToAddress("0x3333"), + wantCount: 2, + }, + { + name: "existing pair (reversed)", + token0: common.HexToAddress("0x3333"), + token1: common.HexToAddress("0x2222"), + wantCount: 2, + }, + { + name: "non-existing pair", + token0: common.HexToAddress("0x9999"), + token1: common.HexToAddress("0x8888"), + wantCount: 0, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + pools, err := cache.GetByTokenPair(ctx, tt.token0, tt.token1) + if err != nil { + t.Fatalf("GetByTokenPair() error: %v", err) + } + + if len(pools) != tt.wantCount { + t.Errorf("GetByTokenPair() count = %d, want %d", len(pools), tt.wantCount) + } + }) + } +} + +func TestPoolCache_GetByProtocol(t *testing.T) { + ctx := context.Background() + cache := NewPoolCache() + + // Add pools with different protocols + pool1 := createTestPool("0x1111", "0x2222", "0x3333", types.ProtocolUniswapV2) + pool2 := createTestPool("0x4444", "0x5555", "0x6666", types.ProtocolUniswapV2) + pool3 := createTestPool("0x7777", "0x8888", "0x9999", types.ProtocolUniswapV3) + + cache.Add(ctx, pool1) + cache.Add(ctx, pool2) + cache.Add(ctx, pool3) + + tests := []struct { + name string + protocol types.ProtocolType + wantCount int + }{ + { + name: "uniswap v2", + protocol: types.ProtocolUniswapV2, + wantCount: 2, + }, + { + name: "uniswap v3", + protocol: types.ProtocolUniswapV3, + wantCount: 1, + }, + { + name: "curve (none)", + protocol: types.ProtocolCurve, + wantCount: 0, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + pools, err := cache.GetByProtocol(ctx, tt.protocol) + if err != nil { + t.Fatalf("GetByProtocol() error: %v", err) + } + + if len(pools) != tt.wantCount { + t.Errorf("GetByProtocol() count = %d, want %d", len(pools), tt.wantCount) + } + }) + } +} + +func TestPoolCache_GetByLiquidity(t *testing.T) { + ctx := context.Background() + cache := NewPoolCache() + + // Add pools with different liquidity + pool1 := createTestPool("0x1111", "0x2222", "0x3333", types.ProtocolUniswapV2) + pool1.Liquidity = big.NewInt(1000) + + pool2 := createTestPool("0x4444", "0x5555", "0x6666", types.ProtocolUniswapV2) + pool2.Liquidity = big.NewInt(5000) + + pool3 := createTestPool("0x7777", "0x8888", "0x9999", types.ProtocolUniswapV3) + pool3.Liquidity = big.NewInt(10000) + + cache.Add(ctx, pool1) + cache.Add(ctx, pool2) + cache.Add(ctx, pool3) + + tests := []struct { + name string + minLiquidity *big.Int + limit int + wantCount int + wantFirst *big.Int // Expected liquidity of first result + }{ + { + name: "all pools", + minLiquidity: big.NewInt(0), + limit: 0, + wantCount: 3, + wantFirst: big.NewInt(10000), + }, + { + name: "min 2000", + minLiquidity: big.NewInt(2000), + limit: 0, + wantCount: 2, + wantFirst: big.NewInt(10000), + }, + { + name: "limit 2", + minLiquidity: big.NewInt(0), + limit: 2, + wantCount: 2, + wantFirst: big.NewInt(10000), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + pools, err := cache.GetByLiquidity(ctx, tt.minLiquidity, tt.limit) + if err != nil { + t.Fatalf("GetByLiquidity() error: %v", err) + } + + if len(pools) != tt.wantCount { + t.Errorf("GetByLiquidity() count = %d, want %d", len(pools), tt.wantCount) + } + + if len(pools) > 0 && pools[0].Liquidity.Cmp(tt.wantFirst) != 0 { + t.Errorf("GetByLiquidity() first liquidity = %v, want %v", pools[0].Liquidity, tt.wantFirst) + } + }) + } +} + +func TestPoolCache_Update(t *testing.T) { + ctx := context.Background() + cache := NewPoolCache() + + pool := createTestPool("0x1111", "0x2222", "0x3333", types.ProtocolUniswapV2) + cache.Add(ctx, pool) + + // Update reserves + err := cache.Update(ctx, pool.Address, func(p *types.PoolInfo) error { + p.Reserve0 = big.NewInt(9999999) + return nil + }) + + if err != nil { + t.Fatalf("Update() failed: %v", err) + } + + // Verify update + retrieved, err := cache.GetByAddress(ctx, pool.Address) + if err != nil { + t.Fatalf("GetByAddress() failed: %v", err) + } + + if retrieved.Reserve0.Cmp(big.NewInt(9999999)) != 0 { + t.Errorf("Update() reserve0 = %v, want %v", retrieved.Reserve0, big.NewInt(9999999)) + } +} + +func TestPoolCache_Update_NonExistent(t *testing.T) { + ctx := context.Background() + cache := NewPoolCache() + + err := cache.Update(ctx, common.HexToAddress("0x9999"), func(p *types.PoolInfo) error { + return nil + }) + + if err == nil { + t.Error("Update(non-existent) expected error, got nil") + } +} + +func TestPoolCache_Update_Error(t *testing.T) { + ctx := context.Background() + cache := NewPoolCache() + + pool := createTestPool("0x1111", "0x2222", "0x3333", types.ProtocolUniswapV2) + cache.Add(ctx, pool) + + // Update with error + testErr := fmt.Errorf("test error") + err := cache.Update(ctx, pool.Address, func(p *types.PoolInfo) error { + return testErr + }) + + if err != testErr { + t.Errorf("Update() error = %v, want %v", err, testErr) + } +} + +func TestPoolCache_Update_InvalidAfterUpdate(t *testing.T) { + ctx := context.Background() + cache := NewPoolCache() + + pool := createTestPool("0x1111", "0x2222", "0x3333", types.ProtocolUniswapV2) + cache.Add(ctx, pool) + + // Make pool invalid + err := cache.Update(ctx, pool.Address, func(p *types.PoolInfo) error { + p.Token0Decimals = 0 // Invalid + return nil + }) + + if err == nil { + t.Error("Update(invalid) expected error, got nil") + } +} + +func TestPoolCache_Remove(t *testing.T) { + ctx := context.Background() + cache := NewPoolCache() + + pool := createTestPool("0x1111", "0x2222", "0x3333", types.ProtocolUniswapV2) + cache.Add(ctx, pool) + + // Remove pool + err := cache.Remove(ctx, pool.Address) + if err != nil { + t.Fatalf("Remove() failed: %v", err) + } + + // Verify removal + _, err = cache.GetByAddress(ctx, pool.Address) + if err == nil { + t.Error("GetByAddress() after Remove() expected error, got nil") + } + + // Verify removed from token pair index + pools, _ := cache.GetByTokenPair(ctx, pool.Token0, pool.Token1) + if len(pools) != 0 { + t.Error("Pool still in token pair index after removal") + } + + // Verify removed from protocol index + pools, _ = cache.GetByProtocol(ctx, pool.Protocol) + if len(pools) != 0 { + t.Error("Pool still in protocol index after removal") + } +} + +func TestPoolCache_Remove_NonExistent(t *testing.T) { + ctx := context.Background() + cache := NewPoolCache() + + err := cache.Remove(ctx, common.HexToAddress("0x9999")) + if err == nil { + t.Error("Remove(non-existent) expected error, got nil") + } +} + +func TestPoolCache_Count(t *testing.T) { + ctx := context.Background() + cache := NewPoolCache() + + // Initially empty + count, err := cache.Count(ctx) + if err != nil { + t.Fatalf("Count() failed: %v", err) + } + if count != 0 { + t.Errorf("Count() = %d, want 0", count) + } + + // Add pools + pool1 := createTestPool("0x1111", "0x2222", "0x3333", types.ProtocolUniswapV2) + pool2 := createTestPool("0x4444", "0x5555", "0x6666", types.ProtocolUniswapV3) + + cache.Add(ctx, pool1) + cache.Add(ctx, pool2) + + count, err = cache.Count(ctx) + if err != nil { + t.Fatalf("Count() failed: %v", err) + } + if count != 2 { + t.Errorf("Count() = %d, want 2", count) + } +} + +func TestPoolCache_Clear(t *testing.T) { + ctx := context.Background() + cache := NewPoolCache() + + // Add pools + pool1 := createTestPool("0x1111", "0x2222", "0x3333", types.ProtocolUniswapV2) + pool2 := createTestPool("0x4444", "0x5555", "0x6666", types.ProtocolUniswapV3) + + cache.Add(ctx, pool1) + cache.Add(ctx, pool2) + + // Clear cache + err := cache.Clear(ctx) + if err != nil { + t.Fatalf("Clear() failed: %v", err) + } + + // Verify cleared + count, _ := cache.Count(ctx) + if count != 0 { + t.Errorf("Count() after Clear() = %d, want 0", count) + } +} + +func Test_makeTokenPairKey(t *testing.T) { + token0 := common.HexToAddress("0x1111") + token1 := common.HexToAddress("0x2222") + + // Should be consistent regardless of order + key1 := makeTokenPairKey(token0, token1) + key2 := makeTokenPairKey(token1, token0) + + if key1 != key2 { + t.Errorf("makeTokenPairKey() not consistent: %s != %s", key1, key2) + } +} + +func Test_removePoolFromSlice(t *testing.T) { + pool1 := createTestPool("0x1111", "0x2222", "0x3333", types.ProtocolUniswapV2) + pool2 := createTestPool("0x4444", "0x5555", "0x6666", types.ProtocolUniswapV3) + pool3 := createTestPool("0x7777", "0x8888", "0x9999", types.ProtocolCurve) + + pools := []*types.PoolInfo{pool1, pool2, pool3} + + // Remove middle pool + result := removePoolFromSlice(pools, pool2.Address) + + if len(result) != 2 { + t.Errorf("removePoolFromSlice() length = %d, want 2", len(result)) + } + + if result[0].Address != pool1.Address || result[1].Address != pool3.Address { + t.Error("removePoolFromSlice() incorrect pools remaining") + } + + // Remove non-existent pool + result = removePoolFromSlice(result, common.HexToAddress("0x9999")) + if len(result) != 2 { + t.Errorf("removePoolFromSlice(non-existent) length = %d, want 2", len(result)) + } +}