package cache

import (
	"bytes"
	"context"
	"fmt"
	"math/big"
	"sort"
	"sync"

	"github.com/ethereum/go-ethereum/common"
	"github.com/your-org/mev-bot/pkg/types"
)

// poolCache implements the PoolCache interface with multi-index support.
//
// The same *types.PoolInfo pointers are reachable through three maps, all
// guarded by mu: byAddress (primary), byTokenPair and byProtocol (secondary).
// Add, Update and Remove keep the three in step.
//
// The cache stores and returns the caller's pointers, not copies. Two
// consequences follow:
//   - Mutating Token0, Token1 or Protocol on a cached pool other than through
//     Update leaves the secondary indexes keyed by the old values.
//   - Fields read through a returned pointer are not protected by mu while
//     Update runs.
type poolCache struct {
	// Primary index: address -> pool
	byAddress map[common.Address]*types.PoolInfo

	// Secondary index: order-independent token pair key -> pools
	byTokenPair map[string][]*types.PoolInfo

	// Tertiary index: protocol -> pools
	byProtocol map[types.ProtocolType][]*types.PoolInfo

	// Guards all three maps.
	mu sync.RWMutex
}

// NewPoolCache creates a new, empty multi-index pool cache.
func NewPoolCache() PoolCache {
	return &poolCache{
		byAddress:   make(map[common.Address]*types.PoolInfo),
		byTokenPair: make(map[string][]*types.PoolInfo),
		byProtocol:  make(map[types.ProtocolType][]*types.PoolInfo),
	}
}

// GetByAddress returns the cached pool for the given contract address, or
// types.ErrPoolNotFound. The returned pointer is the cached pool itself.
func (c *poolCache) GetByAddress(ctx context.Context, address common.Address) (*types.PoolInfo, error) {
	c.mu.RLock()
	defer c.mu.RUnlock()

	pool, exists := c.byAddress[address]
	if !exists {
		return nil, types.ErrPoolNotFound
	}
	return pool, nil
}

// GetByTokenPair returns all pools trading the given token pair. Token order
// does not matter: (a, b) and (b, a) resolve to the same key. The result is a
// fresh slice (never nil) so callers cannot reorder the index, but its
// elements are the cached pointers.
func (c *poolCache) GetByTokenPair(ctx context.Context, token0, token1 common.Address) ([]*types.PoolInfo, error) {
	c.mu.RLock()
	defer c.mu.RUnlock()

	return copyPools(c.byTokenPair[makeTokenPairKey(token0, token1)]), nil
}

// GetByProtocol returns all pools of the given protocol. The result is a
// fresh slice (never nil) whose elements are the cached pointers.
func (c *poolCache) GetByProtocol(ctx context.Context, protocol types.ProtocolType) ([]*types.PoolInfo, error) {
	c.mu.RLock()
	defer c.mu.RUnlock()

	return copyPools(c.byProtocol[protocol]), nil
}

// GetByLiquidity returns the pools whose Liquidity is at least minLiquidity,
// sorted by liquidity in descending order.
//
// Behavior details:
//   - Pools with equal liquidity are ordered by ascending address, so the
//     output and the cut made by limit are deterministic.
//   - A nil minLiquidity applies no lower bound.
//   - Pools with a nil Liquidity are always skipped.
//   - limit <= 0 means no limit.
//   - The result is never nil.
func (c *poolCache) GetByLiquidity(ctx context.Context, minLiquidity *big.Int, limit int) ([]*types.PoolInfo, error) {
	c.mu.RLock()
	defer c.mu.RUnlock()

	pools := make([]*types.PoolInfo, 0)
	for _, pool := range c.byAddress {
		if pool.Liquidity == nil {
			continue
		}
		// big.Int.Cmp dereferences its argument, so a nil bound must be
		// handled here rather than passed through.
		if minLiquidity != nil && pool.Liquidity.Cmp(minLiquidity) < 0 {
			continue
		}
		pools = append(pools, pool)
	}

	sort.Slice(pools, func(i, j int) bool {
		if cmp := pools[i].Liquidity.Cmp(pools[j].Liquidity); cmp != 0 {
			return cmp > 0
		}
		// Map iteration order is random; without a tie-break equal-liquidity
		// pools would come back in a different order on every call.
		return bytes.Compare(pools[i].Address[:], pools[j].Address[:]) < 0
	})

	if limit > 0 && len(pools) > limit {
		pools = pools[:limit]
	}
	return pools, nil
}

// Add inserts pool, or replaces the cached pool with the same Address.
//
// Nil or invalid pools are rejected before the lock is taken, and the cache
// is left untouched. The cache keeps the pointer itself, not a copy.
func (c *poolCache) Add(ctx context.Context, pool *types.PoolInfo) error {
	if pool == nil {
		return fmt.Errorf("pool cannot be nil")
	}
	if err := pool.Validate(); err != nil {
		return fmt.Errorf("invalid pool: %w", err)
	}

	c.mu.Lock()
	defer c.mu.Unlock()

	// Unindex the entry being replaced so it never appears twice in a
	// secondary index.
	if oldPool, exists := c.byAddress[pool.Address]; exists {
		c.removeFromIndexes(oldPool)
	}

	c.byAddress[pool.Address] = pool
	c.addToIndexes(pool)
	return nil
}

// Update runs updateFn on the cached pool at address while holding the write
// lock. It then re-indexes the pool, so changes to Token0, Token1 or Protocol
// show up in the secondary indexes.
//
// Update fails, and restores the pool's fields to their values from before
// the call, if any of the following holds:
//   - updateFn returns an error (returned unwrapped);
//   - updateFn changes Address;
//   - the updated pool fails Validate.
//
// The restore is a shallow struct copy. In-place mutation of pointed-to
// values (for example calling a setter on the *big.Int in Liquidity) is not
// undone.
//
// Returns types.ErrPoolNotFound if address is not cached.
func (c *poolCache) Update(ctx context.Context, address common.Address, updateFn func(*types.PoolInfo) error) error {
	if updateFn == nil {
		return fmt.Errorf("updateFn cannot be nil")
	}

	c.mu.Lock()
	defer c.mu.Unlock()

	pool, exists := c.byAddress[address]
	if !exists {
		return types.ErrPoolNotFound
	}

	before := *pool
	// Unindex with the pre-update keys; the pool is re-indexed below with
	// whatever keys it ends up with.
	c.removeFromIndexes(pool)

	err := applyPoolUpdate(pool, address, updateFn)
	if err != nil {
		// Keep only valid pools in the cache, matching Add's invariant.
		*pool = before
	}

	c.addToIndexes(pool)
	return err
}

// applyPoolUpdate runs updateFn on pool and checks that the result still
// belongs at address and passes Validate.
func applyPoolUpdate(pool *types.PoolInfo, address common.Address, updateFn func(*types.PoolInfo) error) error {
	if err := updateFn(pool); err != nil {
		return err
	}
	if pool.Address != address {
		return fmt.Errorf("pool update must not change address %s to %s", address.Hex(), pool.Address.Hex())
	}
	if err := pool.Validate(); err != nil {
		return fmt.Errorf("pool invalid after update: %w", err)
	}
	return nil
}

// Remove deletes the pool at address from every index, or returns
// types.ErrPoolNotFound if it is not cached.
func (c *poolCache) Remove(ctx context.Context, address common.Address) error {
	c.mu.Lock()
	defer c.mu.Unlock()

	pool, exists := c.byAddress[address]
	if !exists {
		return types.ErrPoolNotFound
	}

	delete(c.byAddress, address)
	c.removeFromIndexes(pool)
	return nil
}

// Count returns the total number of pools in the cache.
func (c *poolCache) Count(ctx context.Context) (int, error) {
	c.mu.RLock()
	defer c.mu.RUnlock()

	return len(c.byAddress), nil
}

// Clear removes all pools from the cache by replacing every index map.
func (c *poolCache) Clear(ctx context.Context) error {
	c.mu.Lock()
	defer c.mu.Unlock()

	c.byAddress = make(map[common.Address]*types.PoolInfo)
	c.byTokenPair = make(map[string][]*types.PoolInfo)
	c.byProtocol = make(map[types.ProtocolType][]*types.PoolInfo)
	return nil
}

// addToIndexes appends pool to the token-pair and protocol indexes using its
// current field values. The caller must hold c.mu for writing.
func (c *poolCache) addToIndexes(pool *types.PoolInfo) {
	pairKey := makeTokenPairKey(pool.Token0, pool.Token1)
	c.byTokenPair[pairKey] = append(c.byTokenPair[pairKey], pool)

	c.byProtocol[pool.Protocol] = append(c.byProtocol[pool.Protocol], pool)
}

// removeFromIndexes removes pool (matched by Address) from the token-pair and
// protocol buckets selected by its current field values. Buckets left empty
// are deleted. The caller must hold c.mu for writing.
func (c *poolCache) removeFromIndexes(pool *types.PoolInfo) {
	pairKey := makeTokenPairKey(pool.Token0, pool.Token1)
	c.byTokenPair[pairKey] = removePoolFromSlice(c.byTokenPair[pairKey], pool.Address)
	if len(c.byTokenPair[pairKey]) == 0 {
		delete(c.byTokenPair, pairKey)
	}

	c.byProtocol[pool.Protocol] = removePoolFromSlice(c.byProtocol[pool.Protocol], pool.Address)
	if len(c.byProtocol[pool.Protocol]) == 0 {
		delete(c.byProtocol, pool.Protocol)
	}
}

// makeTokenPairKey builds an order-independent key for a token pair: the
// numerically smaller address always comes first.
func makeTokenPairKey(token0, token1 common.Address) string {
	// Addresses are fixed-length big-endian byte arrays, so bytes.Compare
	// orders them exactly as comparing their big.Int values would, without
	// allocating.
	if bytes.Compare(token0[:], token1[:]) > 0 {
		token0, token1 = token1, token0
	}
	return token0.Hex() + "-" + token1.Hex()
}

// copyPools returns a new slice holding the same pointers as pools. It
// returns an empty non-nil slice when pools is empty.
func copyPools(pools []*types.PoolInfo) []*types.PoolInfo {
	out := make([]*types.PoolInfo, len(pools))
	copy(out, pools)
	return out
}

// removePoolFromSlice removes the first pool whose Address equals address and
// returns the shortened slice.
//
// It shifts later elements down in place, so the caller must store the
// returned slice and stop using the old one. The vacated last slot is set to
// nil so the backing array does not keep the removed pool reachable.
func removePoolFromSlice(pools []*types.PoolInfo, address common.Address) []*types.PoolInfo {
	for i, pool := range pools {
		if pool.Address != address {
			continue
		}
		last := len(pools) - 1
		copy(pools[i:], pools[i+1:])
		pools[last] = nil
		return pools[:last]
	}
	return pools
}