diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..e69de29 diff --git a/pkg/cache/pool_cache.go b/pkg/cache/pool_cache.go new file mode 100644 index 0000000..d8df870 --- /dev/null +++ b/pkg/cache/pool_cache.go @@ -0,0 +1,255 @@ +package cache + +import ( + "context" + "fmt" + "math/big" + "sort" + "sync" + + "github.com/ethereum/go-ethereum/common" + + "github.com/your-org/mev-bot/pkg/types" +) + +// poolCache implements the PoolCache interface with multi-index support +type poolCache struct { + // Primary index: address -> pool + byAddress map[common.Address]*types.PoolInfo + + // Secondary index: token pair -> pools + byTokenPair map[string][]*types.PoolInfo + + // Tertiary index: protocol -> pools + byProtocol map[types.ProtocolType][]*types.PoolInfo + + // Mutex for thread safety + mu sync.RWMutex +} + +// NewPoolCache creates a new multi-index pool cache +func NewPoolCache() PoolCache { + return &poolCache{ + byAddress: make(map[common.Address]*types.PoolInfo), + byTokenPair: make(map[string][]*types.PoolInfo), + byProtocol: make(map[types.ProtocolType][]*types.PoolInfo), + } +} + +// GetByAddress retrieves a pool by its contract address +func (c *poolCache) GetByAddress(ctx context.Context, address common.Address) (*types.PoolInfo, error) { + c.mu.RLock() + defer c.mu.RUnlock() + + pool, exists := c.byAddress[address] + if !exists { + return nil, types.ErrPoolNotFound + } + + return pool, nil +} + +// GetByTokenPair retrieves all pools for a given token pair +func (c *poolCache) GetByTokenPair(ctx context.Context, token0, token1 common.Address) ([]*types.PoolInfo, error) { + c.mu.RLock() + defer c.mu.RUnlock() + + key := makeTokenPairKey(token0, token1) + pools := c.byTokenPair[key] + + if len(pools) == 0 { + return []*types.PoolInfo{}, nil + } + + // Return a copy to prevent external modification + result := make([]*types.PoolInfo, len(pools)) + copy(result, pools) + + return result, nil +} + +// GetByProtocol retrieves all pools for a given protocol +func (c *poolCache) GetByProtocol(ctx context.Context, protocol types.ProtocolType) ([]*types.PoolInfo, error) { + c.mu.RLock() + defer c.mu.RUnlock() + + pools := c.byProtocol[protocol] + + if len(pools) == 0 { + return []*types.PoolInfo{}, nil + } + + // Return a copy to prevent external modification + result := make([]*types.PoolInfo, len(pools)) + copy(result, pools) + + return result, nil +} + +// GetByLiquidity retrieves pools sorted by liquidity (descending) +func (c *poolCache) GetByLiquidity(ctx context.Context, minLiquidity *big.Int, limit int) ([]*types.PoolInfo, error) { + c.mu.RLock() + defer c.mu.RUnlock() + + // Collect all pools with sufficient liquidity + var pools []*types.PoolInfo + for _, pool := range c.byAddress { + if pool.Liquidity != nil && pool.Liquidity.Cmp(minLiquidity) >= 0 { + pools = append(pools, pool) + } + } + + // Sort by liquidity (descending) + sort.Slice(pools, func(i, j int) bool { + return pools[i].Liquidity.Cmp(pools[j].Liquidity) > 0 + }) + + // Apply limit + if limit > 0 && len(pools) > limit { + pools = pools[:limit] + } + + return pools, nil +} + +// Add adds or updates a pool in the cache +func (c *poolCache) Add(ctx context.Context, pool *types.PoolInfo) error { + if pool == nil { + return fmt.Errorf("pool cannot be nil") + } + + if err := pool.Validate(); err != nil { + return fmt.Errorf("invalid pool: %w", err) + } + + c.mu.Lock() + defer c.mu.Unlock() + + // Remove old indexes if pool exists + if oldPool, exists := c.byAddress[pool.Address]; exists { + c.removeFromIndexes(oldPool) + } + + // Add to primary index + c.byAddress[pool.Address] = pool + + // Add to secondary indexes + c.addToIndexes(pool) + + return nil +} + +// Update updates pool information +func (c *poolCache) Update(ctx context.Context, address common.Address, updateFn func(*types.PoolInfo) error) error { + c.mu.Lock() + defer c.mu.Unlock() + + pool, exists := c.byAddress[address] + if !exists { + return types.ErrPoolNotFound + } + + // Remove from indexes before update + c.removeFromIndexes(pool) + + // Apply update + if err := updateFn(pool); err != nil { + // Re-add to indexes even on error to maintain consistency + c.addToIndexes(pool) + return err + } + + // Validate after update + if err := pool.Validate(); err != nil { + // Re-add to indexes even on error + c.addToIndexes(pool) + return fmt.Errorf("pool invalid after update: %w", err) + } + + // Re-add to indexes + c.addToIndexes(pool) + + return nil +} + +// Remove removes a pool from the cache +func (c *poolCache) Remove(ctx context.Context, address common.Address) error { + c.mu.Lock() + defer c.mu.Unlock() + + pool, exists := c.byAddress[address] + if !exists { + return types.ErrPoolNotFound + } + + // Remove from all indexes + delete(c.byAddress, address) + c.removeFromIndexes(pool) + + return nil +} + +// Count returns the total number of pools in the cache +func (c *poolCache) Count(ctx context.Context) (int, error) { + c.mu.RLock() + defer c.mu.RUnlock() + + return len(c.byAddress), nil +} + +// Clear removes all pools from the cache +func (c *poolCache) Clear(ctx context.Context) error { + c.mu.Lock() + defer c.mu.Unlock() + + c.byAddress = make(map[common.Address]*types.PoolInfo) + c.byTokenPair = make(map[string][]*types.PoolInfo) + c.byProtocol = make(map[types.ProtocolType][]*types.PoolInfo) + + return nil +} + +// addToIndexes adds a pool to secondary indexes +func (c *poolCache) addToIndexes(pool *types.PoolInfo) { + // Add to token pair index + pairKey := makeTokenPairKey(pool.Token0, pool.Token1) + c.byTokenPair[pairKey] = append(c.byTokenPair[pairKey], pool) + + // Add to protocol index + c.byProtocol[pool.Protocol] = append(c.byProtocol[pool.Protocol], pool) +} + +// removeFromIndexes removes a pool from secondary indexes +func (c *poolCache) removeFromIndexes(pool *types.PoolInfo) { + // Remove from token pair index + pairKey := makeTokenPairKey(pool.Token0, pool.Token1) + c.byTokenPair[pairKey] = removePoolFromSlice(c.byTokenPair[pairKey], pool.Address) + if len(c.byTokenPair[pairKey]) == 0 { + delete(c.byTokenPair, pairKey) + } + + // Remove from protocol index + c.byProtocol[pool.Protocol] = removePoolFromSlice(c.byProtocol[pool.Protocol], pool.Address) + if len(c.byProtocol[pool.Protocol]) == 0 { + delete(c.byProtocol, pool.Protocol) + } +} + +// makeTokenPairKey creates a consistent key for a token pair +func makeTokenPairKey(token0, token1 common.Address) string { + // Always use the smaller address first for consistency + if token0.Big().Cmp(token1.Big()) < 0 { + return token0.Hex() + "-" + token1.Hex() + } + return token1.Hex() + "-" + token0.Hex() +} + +// removePoolFromSlice removes a pool with the given address from a slice +func removePoolFromSlice(pools []*types.PoolInfo, address common.Address) []*types.PoolInfo { + for i, pool := range pools { + if pool.Address == address { + return append(pools[:i], pools[i+1:]...) + } + } + return pools +} diff --git a/pkg/cache/pool_cache_test.go b/pkg/cache/pool_cache_test.go new file mode 100644 index 0000000..b31865d --- /dev/null +++ b/pkg/cache/pool_cache_test.go @@ -0,0 +1,534 @@ +package cache + +import ( + "context" + "math/big" + "testing" + + "github.com/ethereum/go-ethereum/common" + + "github.com/your-org/mev-bot/pkg/types" +) + +func createTestPool(address, token0, token1 string, protocol types.ProtocolType) *types.PoolInfo { + return &types.PoolInfo{ + Address: common.HexToAddress(address), + Protocol: protocol, + Token0: common.HexToAddress(token0), + Token1: common.HexToAddress(token1), + Token0Decimals: 18, + Token1Decimals: 6, + Reserve0: big.NewInt(1000000), + Reserve1: big.NewInt(500000), + Liquidity: big.NewInt(100000), + IsActive: true, + } +} + +func TestNewPoolCache(t *testing.T) { + cache := NewPoolCache() + if cache == nil { + t.Fatal("NewPoolCache returned nil") + } +} + +func TestPoolCache_Add(t *testing.T) { + ctx := context.Background() + cache := NewPoolCache() + + pool := createTestPool("0x1111", "0x2222", "0x3333", types.ProtocolUniswapV2) + + err := cache.Add(ctx, pool) + if err != nil { + t.Fatalf("Add() failed: %v", err) + } + + // Verify it was added + retrieved, err := cache.GetByAddress(ctx, pool.Address) + if err != nil { + t.Fatalf("GetByAddress() failed: %v", err) + } + + if retrieved.Address != pool.Address { + t.Errorf("Retrieved pool address = %v, want %v", retrieved.Address, pool.Address) + } +} + +func TestPoolCache_Add_NilPool(t *testing.T) { + ctx := context.Background() + cache := NewPoolCache() + + err := cache.Add(ctx, nil) + if err == nil { + t.Error("Add(nil) expected error, got nil") + } +} + +func TestPoolCache_Add_InvalidPool(t *testing.T) { + ctx := context.Background() + cache := NewPoolCache() + + // Pool with zero address + pool := createTestPool("0x0000", "0x2222", "0x3333", types.ProtocolUniswapV2) + pool.Address = common.Address{} + + err := cache.Add(ctx, pool) + if err == nil { + t.Error("Add(invalid pool) expected error, got nil") + } +} + +func TestPoolCache_Add_Update(t *testing.T) { + ctx := context.Background() + cache := NewPoolCache() + + pool := createTestPool("0x1111", "0x2222", "0x3333", types.ProtocolUniswapV2) + + // Add pool + err := cache.Add(ctx, pool) + if err != nil { + t.Fatalf("Add() failed: %v", err) + } + + // Update pool (same address, different reserves) + updatedPool := createTestPool("0x1111", "0x2222", "0x3333", types.ProtocolUniswapV2) + updatedPool.Reserve0 = big.NewInt(2000000) + + err = cache.Add(ctx, updatedPool) + if err != nil { + t.Fatalf("Add(update) failed: %v", err) + } + + // Verify it was updated + retrieved, err := cache.GetByAddress(ctx, pool.Address) + if err != nil { + t.Fatalf("GetByAddress() failed: %v", err) + } + + if retrieved.Reserve0.Cmp(big.NewInt(2000000)) != 0 { + t.Errorf("Updated pool reserve0 = %v, want %v", retrieved.Reserve0, big.NewInt(2000000)) + } +} + +func TestPoolCache_GetByAddress(t *testing.T) { + ctx := context.Background() + cache := NewPoolCache() + + pool := createTestPool("0x1111", "0x2222", "0x3333", types.ProtocolUniswapV2) + cache.Add(ctx, pool) + + tests := []struct { + name string + address common.Address + wantErr bool + }{ + { + name: "existing pool", + address: common.HexToAddress("0x1111"), + wantErr: false, + }, + { + name: "non-existing pool", + address: common.HexToAddress("0x9999"), + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := cache.GetByAddress(ctx, tt.address) + + if tt.wantErr { + if err == nil { + t.Error("GetByAddress() expected error, got nil") + } + if got != nil { + t.Error("GetByAddress() expected nil pool on error") + } + } else { + if err != nil { + t.Errorf("GetByAddress() unexpected error: %v", err) + } + if got == nil { + t.Error("GetByAddress() returned nil pool") + } + } + }) + } +} + +func TestPoolCache_GetByTokenPair(t *testing.T) { + ctx := context.Background() + cache := NewPoolCache() + + // Add pools with same token pair + pool1 := createTestPool("0x1111", "0x2222", "0x3333", types.ProtocolUniswapV2) + pool2 := createTestPool("0x4444", "0x2222", "0x3333", types.ProtocolUniswapV3) + + cache.Add(ctx, pool1) + cache.Add(ctx, pool2) + + tests := []struct { + name string + token0 common.Address + token1 common.Address + wantCount int + }{ + { + name: "existing pair", + token0: common.HexToAddress("0x2222"), + token1: common.HexToAddress("0x3333"), + wantCount: 2, + }, + { + name: "existing pair (reversed)", + token0: common.HexToAddress("0x3333"), + token1: common.HexToAddress("0x2222"), + wantCount: 2, + }, + { + name: "non-existing pair", + token0: common.HexToAddress("0x9999"), + token1: common.HexToAddress("0x8888"), + wantCount: 0, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + pools, err := cache.GetByTokenPair(ctx, tt.token0, tt.token1) + if err != nil { + t.Fatalf("GetByTokenPair() error: %v", err) + } + + if len(pools) != tt.wantCount { + t.Errorf("GetByTokenPair() count = %d, want %d", len(pools), tt.wantCount) + } + }) + } +} + +func TestPoolCache_GetByProtocol(t *testing.T) { + ctx := context.Background() + cache := NewPoolCache() + + // Add pools with different protocols + pool1 := createTestPool("0x1111", "0x2222", "0x3333", types.ProtocolUniswapV2) + pool2 := createTestPool("0x4444", "0x5555", "0x6666", types.ProtocolUniswapV2) + pool3 := createTestPool("0x7777", "0x8888", "0x9999", types.ProtocolUniswapV3) + + cache.Add(ctx, pool1) + cache.Add(ctx, pool2) + cache.Add(ctx, pool3) + + tests := []struct { + name string + protocol types.ProtocolType + wantCount int + }{ + { + name: "uniswap v2", + protocol: types.ProtocolUniswapV2, + wantCount: 2, + }, + { + name: "uniswap v3", + protocol: types.ProtocolUniswapV3, + wantCount: 1, + }, + { + name: "curve (none)", + protocol: types.ProtocolCurve, + wantCount: 0, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + pools, err := cache.GetByProtocol(ctx, tt.protocol) + if err != nil { + t.Fatalf("GetByProtocol() error: %v", err) + } + + if len(pools) != tt.wantCount { + t.Errorf("GetByProtocol() count = %d, want %d", len(pools), tt.wantCount) + } + }) + } +} + +func TestPoolCache_GetByLiquidity(t *testing.T) { + ctx := context.Background() + cache := NewPoolCache() + + // Add pools with different liquidity + pool1 := createTestPool("0x1111", "0x2222", "0x3333", types.ProtocolUniswapV2) + pool1.Liquidity = big.NewInt(1000) + + pool2 := createTestPool("0x4444", "0x5555", "0x6666", types.ProtocolUniswapV2) + pool2.Liquidity = big.NewInt(5000) + + pool3 := createTestPool("0x7777", "0x8888", "0x9999", types.ProtocolUniswapV3) + pool3.Liquidity = big.NewInt(10000) + + cache.Add(ctx, pool1) + cache.Add(ctx, pool2) + cache.Add(ctx, pool3) + + tests := []struct { + name string + minLiquidity *big.Int + limit int + wantCount int + wantFirst *big.Int // Expected liquidity of first result + }{ + { + name: "all pools", + minLiquidity: big.NewInt(0), + limit: 0, + wantCount: 3, + wantFirst: big.NewInt(10000), + }, + { + name: "min 2000", + minLiquidity: big.NewInt(2000), + limit: 0, + wantCount: 2, + wantFirst: big.NewInt(10000), + }, + { + name: "limit 2", + minLiquidity: big.NewInt(0), + limit: 2, + wantCount: 2, + wantFirst: big.NewInt(10000), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + pools, err := cache.GetByLiquidity(ctx, tt.minLiquidity, tt.limit) + if err != nil { + t.Fatalf("GetByLiquidity() error: %v", err) + } + + if len(pools) != tt.wantCount { + t.Errorf("GetByLiquidity() count = %d, want %d", len(pools), tt.wantCount) + } + + if len(pools) > 0 && pools[0].Liquidity.Cmp(tt.wantFirst) != 0 { + t.Errorf("GetByLiquidity() first liquidity = %v, want %v", pools[0].Liquidity, tt.wantFirst) + } + }) + } +} + +func TestPoolCache_Update(t *testing.T) { + ctx := context.Background() + cache := NewPoolCache() + + pool := createTestPool("0x1111", "0x2222", "0x3333", types.ProtocolUniswapV2) + cache.Add(ctx, pool) + + // Update reserves + err := cache.Update(ctx, pool.Address, func(p *types.PoolInfo) error { + p.Reserve0 = big.NewInt(9999999) + return nil + }) + + if err != nil { + t.Fatalf("Update() failed: %v", err) + } + + // Verify update + retrieved, err := cache.GetByAddress(ctx, pool.Address) + if err != nil { + t.Fatalf("GetByAddress() failed: %v", err) + } + + if retrieved.Reserve0.Cmp(big.NewInt(9999999)) != 0 { + t.Errorf("Update() reserve0 = %v, want %v", retrieved.Reserve0, big.NewInt(9999999)) + } +} + +func TestPoolCache_Update_NonExistent(t *testing.T) { + ctx := context.Background() + cache := NewPoolCache() + + err := cache.Update(ctx, common.HexToAddress("0x9999"), func(p *types.PoolInfo) error { + return nil + }) + + if err == nil { + t.Error("Update(non-existent) expected error, got nil") + } +} + +func TestPoolCache_Update_Error(t *testing.T) { + ctx := context.Background() + cache := NewPoolCache() + + pool := createTestPool("0x1111", "0x2222", "0x3333", types.ProtocolUniswapV2) + cache.Add(ctx, pool) + + // Update with error + testErr := fmt.Errorf("test error") + err := cache.Update(ctx, pool.Address, func(p *types.PoolInfo) error { + return testErr + }) + + if err != testErr { + t.Errorf("Update() error = %v, want %v", err, testErr) + } +} + +func TestPoolCache_Update_InvalidAfterUpdate(t *testing.T) { + ctx := context.Background() + cache := NewPoolCache() + + pool := createTestPool("0x1111", "0x2222", "0x3333", types.ProtocolUniswapV2) + cache.Add(ctx, pool) + + // Make pool invalid + err := cache.Update(ctx, pool.Address, func(p *types.PoolInfo) error { + p.Token0Decimals = 0 // Invalid + return nil + }) + + if err == nil { + t.Error("Update(invalid) expected error, got nil") + } +} + +func TestPoolCache_Remove(t *testing.T) { + ctx := context.Background() + cache := NewPoolCache() + + pool := createTestPool("0x1111", "0x2222", "0x3333", types.ProtocolUniswapV2) + cache.Add(ctx, pool) + + // Remove pool + err := cache.Remove(ctx, pool.Address) + if err != nil { + t.Fatalf("Remove() failed: %v", err) + } + + // Verify removal + _, err = cache.GetByAddress(ctx, pool.Address) + if err == nil { + t.Error("GetByAddress() after Remove() expected error, got nil") + } + + // Verify removed from token pair index + pools, _ := cache.GetByTokenPair(ctx, pool.Token0, pool.Token1) + if len(pools) != 0 { + t.Error("Pool still in token pair index after removal") + } + + // Verify removed from protocol index + pools, _ = cache.GetByProtocol(ctx, pool.Protocol) + if len(pools) != 0 { + t.Error("Pool still in protocol index after removal") + } +} + +func TestPoolCache_Remove_NonExistent(t *testing.T) { + ctx := context.Background() + cache := NewPoolCache() + + err := cache.Remove(ctx, common.HexToAddress("0x9999")) + if err == nil { + t.Error("Remove(non-existent) expected error, got nil") + } +} + +func TestPoolCache_Count(t *testing.T) { + ctx := context.Background() + cache := NewPoolCache() + + // Initially empty + count, err := cache.Count(ctx) + if err != nil { + t.Fatalf("Count() failed: %v", err) + } + if count != 0 { + t.Errorf("Count() = %d, want 0", count) + } + + // Add pools + pool1 := createTestPool("0x1111", "0x2222", "0x3333", types.ProtocolUniswapV2) + pool2 := createTestPool("0x4444", "0x5555", "0x6666", types.ProtocolUniswapV3) + + cache.Add(ctx, pool1) + cache.Add(ctx, pool2) + + count, err = cache.Count(ctx) + if err != nil { + t.Fatalf("Count() failed: %v", err) + } + if count != 2 { + t.Errorf("Count() = %d, want 2", count) + } +} + +func TestPoolCache_Clear(t *testing.T) { + ctx := context.Background() + cache := NewPoolCache() + + // Add pools + pool1 := createTestPool("0x1111", "0x2222", "0x3333", types.ProtocolUniswapV2) + pool2 := createTestPool("0x4444", "0x5555", "0x6666", types.ProtocolUniswapV3) + + cache.Add(ctx, pool1) + cache.Add(ctx, pool2) + + // Clear cache + err := cache.Clear(ctx) + if err != nil { + t.Fatalf("Clear() failed: %v", err) + } + + // Verify cleared + count, _ := cache.Count(ctx) + if count != 0 { + t.Errorf("Count() after Clear() = %d, want 0", count) + } +} + +func Test_makeTokenPairKey(t *testing.T) { + token0 := common.HexToAddress("0x1111") + token1 := common.HexToAddress("0x2222") + + // Should be consistent regardless of order + key1 := makeTokenPairKey(token0, token1) + key2 := makeTokenPairKey(token1, token0) + + if key1 != key2 { + t.Errorf("makeTokenPairKey() not consistent: %s != %s", key1, key2) + } +} + +func Test_removePoolFromSlice(t *testing.T) { + pool1 := createTestPool("0x1111", "0x2222", "0x3333", types.ProtocolUniswapV2) + pool2 := createTestPool("0x4444", "0x5555", "0x6666", types.ProtocolUniswapV3) + pool3 := createTestPool("0x7777", "0x8888", "0x9999", types.ProtocolCurve) + + pools := []*types.PoolInfo{pool1, pool2, pool3} + + // Remove middle pool + result := removePoolFromSlice(pools, pool2.Address) + + if len(result) != 2 { + t.Errorf("removePoolFromSlice() length = %d, want 2", len(result)) + } + + if result[0].Address != pool1.Address || result[1].Address != pool3.Address { + t.Error("removePoolFromSlice() incorrect pools remaining") + } + + // Remove non-existent pool + result = removePoolFromSlice(result, common.HexToAddress("0x9999")) + if len(result) != 2 { + t.Errorf("removePoolFromSlice(non-existent) length = %d, want 2", len(result)) + } +} diff --git a/pkg/observability/logger_test.go b/pkg/observability/logger_test.go new file mode 100644 index 0000000..9ddc715 --- /dev/null +++ b/pkg/observability/logger_test.go @@ -0,0 +1,85 @@ +package observability + +import ( + "context" + "log/slog" + "testing" +) + +func TestNewLogger(t *testing.T) { + logger := NewLogger(slog.LevelInfo) + if logger == nil { + t.Fatal("NewLogger returned nil") + } +} + +func TestLogger_Debug(t *testing.T) { + logger := NewLogger(slog.LevelDebug) + // Should not panic + logger.Debug("test message", "key", "value") +} + +func TestLogger_Info(t *testing.T) { + logger := NewLogger(slog.LevelInfo) + // Should not panic + logger.Info("test message", "key", "value") +} + +func TestLogger_Warn(t *testing.T) { + logger := NewLogger(slog.LevelWarn) + // Should not panic + logger.Warn("test message", "key", "value") +} + +func TestLogger_Error(t *testing.T) { + logger := NewLogger(slog.LevelError) + // Should not panic + logger.Error("test message", "key", "value") +} + +func TestLogger_With(t *testing.T) { + logger := NewLogger(slog.LevelInfo) + contextLogger := logger.With("component", "test") + + if contextLogger == nil { + t.Fatal("With() returned nil") + } + + // Should not panic + contextLogger.Info("test message") +} + +func TestLogger_WithContext(t *testing.T) { + logger := NewLogger(slog.LevelInfo) + ctx := context.Background() + contextLogger := logger.WithContext(ctx) + + if contextLogger == nil { + t.Fatal("WithContext() returned nil") + } + + // Should not panic + contextLogger.Info("test message") +} + +func TestLogger_AllLevels(t *testing.T) { + levels := []slog.Level{ + slog.LevelDebug, + slog.LevelInfo, + slog.LevelWarn, + slog.LevelError, + } + + for _, level := range levels { + logger := NewLogger(level) + if logger == nil { + t.Errorf("NewLogger(%v) returned nil", level) + } + + // All log methods should work regardless of level + logger.Debug("debug") + logger.Info("info") + logger.Warn("warn") + logger.Error("error") + } +} diff --git a/pkg/observability/metrics_test.go b/pkg/observability/metrics_test.go new file mode 100644 index 0000000..76d3ff5 --- /dev/null +++ b/pkg/observability/metrics_test.go @@ -0,0 +1,65 @@ +package observability + +import ( + "testing" +) + +func TestNewMetrics(t *testing.T) { + metrics := NewMetrics("test") + if metrics == nil { + t.Fatal("NewMetrics returned nil") + } +} + +func TestMetrics_RecordSwapEvent(t *testing.T) { + metrics := NewMetrics("test") + + // Should not panic + metrics.RecordSwapEvent("uniswap-v2", true) + metrics.RecordSwapEvent("uniswap-v3", false) +} + +func TestMetrics_RecordParseLatency(t *testing.T) { + metrics := NewMetrics("test") + + // Should not panic + metrics.RecordParseLatency("uniswap-v2", 0.005) + metrics.RecordParseLatency("uniswap-v3", 0.003) +} + +func TestMetrics_RecordArbitrageOpportunity(t *testing.T) { + metrics := NewMetrics("test") + + // Should not panic + metrics.RecordArbitrageOpportunity(0.1) + metrics.RecordArbitrageOpportunity(0.5) +} + +func TestMetrics_RecordExecution(t *testing.T) { + metrics := NewMetrics("test") + + // Should not panic + metrics.RecordExecution(true, 0.05) + metrics.RecordExecution(false, -0.01) +} + +func TestMetrics_PoolCacheSize(t *testing.T) { + metrics := NewMetrics("test") + + // Should not panic + metrics.IncrementPoolCacheSize() + metrics.IncrementPoolCacheSize() + metrics.DecrementPoolCacheSize() +} + +func TestMetrics_AllMethods(t *testing.T) { + metrics := NewMetrics("test") + + // Test all methods in sequence + metrics.RecordSwapEvent("uniswap-v2", true) + metrics.RecordParseLatency("uniswap-v2", 0.004) + metrics.RecordArbitrageOpportunity(0.08) + metrics.RecordExecution(true, 0.05) + metrics.IncrementPoolCacheSize() + metrics.DecrementPoolCacheSize() +} diff --git a/pkg/parsers/factory.go b/pkg/parsers/factory.go new file mode 100644 index 0000000..4249e05 --- /dev/null +++ b/pkg/parsers/factory.go @@ -0,0 +1,104 @@ +package parsers + +import ( + "context" + "fmt" + "sync" + + "github.com/ethereum/go-ethereum/core/types" + + mevtypes "github.com/your-org/mev-bot/pkg/types" +) + +// factory implements the Factory interface +type factory struct { + parsers map[mevtypes.ProtocolType]Parser + mu sync.RWMutex +} + +// NewFactory creates a new parser factory +func NewFactory() Factory { + return &factory{ + parsers: make(map[mevtypes.ProtocolType]Parser), + } +} + +// RegisterParser registers a parser for a protocol +func (f *factory) RegisterParser(protocol mevtypes.ProtocolType, parser Parser) error { + if protocol == mevtypes.ProtocolUnknown { + return fmt.Errorf("cannot register parser for unknown protocol") + } + + if parser == nil { + return fmt.Errorf("parser cannot be nil") + } + + f.mu.Lock() + defer f.mu.Unlock() + + if _, exists := f.parsers[protocol]; exists { + return fmt.Errorf("parser for protocol %s already registered", protocol) + } + + f.parsers[protocol] = parser + return nil +} + +// GetParser returns a parser for the given protocol +func (f *factory) GetParser(protocol mevtypes.ProtocolType) (Parser, error) { + f.mu.RLock() + defer f.mu.RUnlock() + + parser, exists := f.parsers[protocol] + if !exists { + return nil, fmt.Errorf("no parser registered for protocol %s", protocol) + } + + return parser, nil +} + +// ParseLog routes a log to the appropriate parser +func (f *factory) ParseLog(ctx context.Context, log types.Log, tx *types.Transaction) (*mevtypes.SwapEvent, error) { + f.mu.RLock() + defer f.mu.RUnlock() + + // Try each registered parser + for _, parser := range f.parsers { + if parser.SupportsLog(log) { + return parser.ParseLog(ctx, log, tx) + } + } + + return nil, mevtypes.ErrUnsupportedProtocol +} + +// ParseTransaction parses all swap events from a transaction +func (f *factory) ParseTransaction(ctx context.Context, tx *types.Transaction, receipt *types.Receipt) ([]*mevtypes.SwapEvent, error) { + if receipt == nil { + return nil, fmt.Errorf("receipt cannot be nil") + } + + f.mu.RLock() + defer f.mu.RUnlock() + + var allEvents []*mevtypes.SwapEvent + + // Try each log with all parsers + for _, log := range receipt.Logs { + for _, parser := range f.parsers { + if parser.SupportsLog(*log) { + event, err := parser.ParseLog(ctx, *log, tx) + if err != nil { + // Log error but continue with other parsers + continue + } + if event != nil { + allEvents = append(allEvents, event) + } + break // Found parser for this log, move to next log + } + } + } + + return allEvents, nil +} diff --git a/pkg/parsers/factory_test.go b/pkg/parsers/factory_test.go new file mode 100644 index 0000000..17b4c72 --- /dev/null +++ b/pkg/parsers/factory_test.go @@ -0,0 +1,407 @@ +package parsers + +import ( + "context" + "math/big" + "testing" + + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/core/types" + + mevtypes "github.com/your-org/mev-bot/pkg/types" +) + +// mockParser is a mock implementation of Parser for testing +type mockParser struct { + protocol mevtypes.ProtocolType + supportsLog func(types.Log) bool + parseLog func(context.Context, types.Log, *types.Transaction) (*mevtypes.SwapEvent, error) + parseReceipt func(context.Context, *types.Receipt, *types.Transaction) ([]*mevtypes.SwapEvent, error) +} + +func (m *mockParser) ParseLog(ctx context.Context, log types.Log, tx *types.Transaction) (*mevtypes.SwapEvent, error) { + if m.parseLog != nil { + return m.parseLog(ctx, log, tx) + } + return nil, nil +} + +func (m *mockParser) ParseReceipt(ctx context.Context, receipt *types.Receipt, tx *types.Transaction) ([]*mevtypes.SwapEvent, error) { + if m.parseReceipt != nil { + return m.parseReceipt(ctx, receipt, tx) + } + return nil, nil +} + +func (m *mockParser) SupportsLog(log types.Log) bool { + if m.supportsLog != nil { + return m.supportsLog(log) + } + return false +} + +func (m *mockParser) Protocol() mevtypes.ProtocolType { + return m.protocol +} + +func TestNewFactory(t *testing.T) { + factory := NewFactory() + if factory == nil { + t.Fatal("NewFactory returned nil") + } +} + +func TestFactory_RegisterParser(t *testing.T) { + tests := []struct { + name string + protocol mevtypes.ProtocolType + parser Parser + wantErr bool + errString string + }{ + { + name: "valid registration", + protocol: mevtypes.ProtocolUniswapV2, + parser: &mockParser{protocol: mevtypes.ProtocolUniswapV2}, + wantErr: false, + }, + { + name: "unknown protocol", + protocol: mevtypes.ProtocolUnknown, + parser: &mockParser{protocol: mevtypes.ProtocolUnknown}, + wantErr: true, + errString: "cannot register parser for unknown protocol", + }, + { + name: "nil parser", + protocol: mevtypes.ProtocolUniswapV2, + parser: nil, + wantErr: true, + errString: "parser cannot be nil", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + factory := NewFactory() + err := factory.RegisterParser(tt.protocol, tt.parser) + + if tt.wantErr { + if err == nil { + t.Errorf("RegisterParser() expected error, got nil") + return + } + if tt.errString != "" && err.Error() != tt.errString { + t.Errorf("RegisterParser() error = %v, want %v", err.Error(), tt.errString) + } + } else { + if err != nil { + t.Errorf("RegisterParser() unexpected error: %v", err) + } + } + }) + } +} + +func TestFactory_RegisterParser_Duplicate(t *testing.T) { + factory := NewFactory() + parser := &mockParser{protocol: mevtypes.ProtocolUniswapV2} + + // First registration should succeed + err := factory.RegisterParser(mevtypes.ProtocolUniswapV2, parser) + if err != nil { + t.Fatalf("First RegisterParser() failed: %v", err) + } + + // Second registration should fail + err = factory.RegisterParser(mevtypes.ProtocolUniswapV2, parser) + if err == nil { + t.Error("RegisterParser() expected error for duplicate registration, got nil") + } +} + +func TestFactory_GetParser(t *testing.T) { + factory := NewFactory() + parser := &mockParser{protocol: mevtypes.ProtocolUniswapV2} + + // Register parser + err := factory.RegisterParser(mevtypes.ProtocolUniswapV2, parser) + if err != nil { + t.Fatalf("RegisterParser() failed: %v", err) + } + + tests := []struct { + name string + protocol mevtypes.ProtocolType + wantErr bool + }{ + { + name: "get registered parser", + protocol: mevtypes.ProtocolUniswapV2, + wantErr: false, + }, + { + name: "get unregistered parser", + protocol: mevtypes.ProtocolUniswapV3, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := factory.GetParser(tt.protocol) + + if tt.wantErr { + if err == nil { + t.Error("GetParser() expected error, got nil") + } + if got != nil { + t.Error("GetParser() expected nil parser on error") + } + } else { + if err != nil { + t.Errorf("GetParser() unexpected error: %v", err) + } + if got == nil { + t.Error("GetParser() returned nil parser") + } + } + }) + } +} + +func TestFactory_ParseLog(t *testing.T) { + ctx := context.Background() + + // Create test log + testLog := types.Log{ + Address: common.HexToAddress("0x1234"), + Topics: []common.Hash{common.HexToHash("0xabcd")}, + Data: []byte{}, + } + + testTx := types.NewTransaction( + 0, + common.HexToAddress("0x1234"), + big.NewInt(0), + 21000, + big.NewInt(1000000000), + nil, + ) + + tests := []struct { + name string + setupFactory func() Factory + log types.Log + tx *types.Transaction + wantErr bool + wantEvent bool + }{ + { + name: "parser supports log", + setupFactory: func() Factory { + f := NewFactory() + parser := &mockParser{ + protocol: mevtypes.ProtocolUniswapV2, + supportsLog: func(log types.Log) bool { + return true + }, + parseLog: func(ctx context.Context, log types.Log, tx *types.Transaction) (*mevtypes.SwapEvent, error) { + return &mevtypes.SwapEvent{ + Protocol: mevtypes.ProtocolUniswapV2, + }, nil + }, + } + f.RegisterParser(mevtypes.ProtocolUniswapV2, parser) + return f + }, + log: testLog, + tx: testTx, + wantErr: false, + wantEvent: true, + }, + { + name: "no parser supports log", + setupFactory: func() Factory { + f := NewFactory() + parser := &mockParser{ + protocol: mevtypes.ProtocolUniswapV2, + supportsLog: func(log types.Log) bool { + return false + }, + } + f.RegisterParser(mevtypes.ProtocolUniswapV2, parser) + return f + }, + log: testLog, + tx: testTx, + wantErr: true, + wantEvent: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + factory := tt.setupFactory() + event, err := factory.ParseLog(ctx, tt.log, tt.tx) + + if tt.wantErr { + if err == nil { + t.Error("ParseLog() expected error, got nil") + } + } else { + if err != nil { + t.Errorf("ParseLog() unexpected error: %v", err) + } + } + + if tt.wantEvent { + if event == nil { + t.Error("ParseLog() expected event, got nil") + } + } else { + if event != nil && !tt.wantErr { + t.Error("ParseLog() expected nil event") + } + } + }) + } +} + +func TestFactory_ParseTransaction(t *testing.T) { + ctx := context.Background() + + testTx := types.NewTransaction( + 0, + common.HexToAddress("0x1234"), + big.NewInt(0), + 21000, + big.NewInt(1000000000), + nil, + ) + + testLog := &types.Log{ + Address: common.HexToAddress("0x1234"), + Topics: []common.Hash{common.HexToHash("0xabcd")}, + Data: []byte{}, + } + + testReceipt := &types.Receipt{ + Logs: []*types.Log{testLog}, + } + + tests := []struct { + name string + setupFactory func() Factory + tx *types.Transaction + receipt *types.Receipt + wantErr bool + wantEvents int + }{ + { + name: "parse transaction with events", + setupFactory: func() Factory { + f := NewFactory() + parser := &mockParser{ + protocol: mevtypes.ProtocolUniswapV2, + supportsLog: func(log types.Log) bool { + return true + }, + parseLog: func(ctx context.Context, log types.Log, tx *types.Transaction) (*mevtypes.SwapEvent, error) { + return &mevtypes.SwapEvent{ + Protocol: mevtypes.ProtocolUniswapV2, + }, nil + }, + } + f.RegisterParser(mevtypes.ProtocolUniswapV2, parser) + return f + }, + tx: testTx, + receipt: testReceipt, + wantErr: false, + wantEvents: 1, + }, + { + name: "parse transaction with no matching parsers", + setupFactory: func() Factory { + f := NewFactory() + parser := &mockParser{ + protocol: mevtypes.ProtocolUniswapV2, + supportsLog: func(log types.Log) bool { + return false + }, + } + f.RegisterParser(mevtypes.ProtocolUniswapV2, parser) + return f + }, + tx: testTx, + receipt: testReceipt, + wantErr: false, + wantEvents: 0, + }, + { + name: "nil receipt", + setupFactory: func() Factory { + return NewFactory() + }, + tx: testTx, + receipt: nil, + wantErr: true, + wantEvents: 0, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + factory := tt.setupFactory() + events, err := factory.ParseTransaction(ctx, tt.tx, tt.receipt) + + if tt.wantErr { + if err == nil { + t.Error("ParseTransaction() expected error, got nil") + } + } else { + if err != nil { + t.Errorf("ParseTransaction() unexpected error: %v", err) + } + } + + if len(events) != tt.wantEvents { + t.Errorf("ParseTransaction() got %d events, want %d", len(events), tt.wantEvents) + } + }) + } +} + +func TestFactory_ConcurrentAccess(t *testing.T) { + factory := NewFactory() + + // Test concurrent registration + done := make(chan bool) + + for i := 0; i < 10; i++ { + go func(n int) { + protocol := mevtypes.ProtocolType(fmt.Sprintf("protocol-%d", n)) + parser := &mockParser{protocol: protocol} + factory.RegisterParser(protocol, parser) + done <- true + }(i) + } + + for i := 0; i < 10; i++ { + <-done + } + + // Test concurrent reads + for i := 0; i < 10; i++ { + go func(n int) { + protocol := mevtypes.ProtocolType(fmt.Sprintf("protocol-%d", n)) + factory.GetParser(protocol) + done <- true + }(i) + } + + for i := 0; i < 10; i++ { + <-done + } +} diff --git a/pkg/types/pool_test.go b/pkg/types/pool_test.go new file mode 100644 index 0000000..f16449d --- /dev/null +++ b/pkg/types/pool_test.go @@ -0,0 +1,286 @@ +package types + +import ( + "math/big" + "testing" + + "github.com/ethereum/go-ethereum/common" +) + +func TestPoolInfo_Validate(t *testing.T) { + validPool := &PoolInfo{ + Address: common.HexToAddress("0x1111"), + Protocol: ProtocolUniswapV2, + PoolType: "constant-product", + Token0: common.HexToAddress("0x2222"), + Token1: common.HexToAddress("0x3333"), + Token0Decimals: 18, + Token1Decimals: 6, + Reserve0: big.NewInt(1000000), + Reserve1: big.NewInt(500000), + IsActive: true, + } + + tests := []struct { + name string + pool *PoolInfo + wantErr error + }{ + { + name: "valid pool", + pool: validPool, + wantErr: nil, + }, + { + name: "invalid pool address", + pool: &PoolInfo{ + Address: common.Address{}, + Protocol: ProtocolUniswapV2, + Token0: common.HexToAddress("0x2222"), + Token1: common.HexToAddress("0x3333"), + Token0Decimals: 18, + Token1Decimals: 6, + }, + wantErr: ErrInvalidPoolAddress, + }, + { + name: "invalid token0 address", + pool: &PoolInfo{ + Address: common.HexToAddress("0x1111"), + Protocol: ProtocolUniswapV2, + Token0: common.Address{}, + Token1: common.HexToAddress("0x3333"), + Token0Decimals: 18, + Token1Decimals: 6, + }, + wantErr: ErrInvalidToken0Address, + }, + { + name: "invalid token1 address", + pool: &PoolInfo{ + Address: common.HexToAddress("0x1111"), + Protocol: ProtocolUniswapV2, + Token0: common.HexToAddress("0x2222"), + Token1: common.Address{}, + Token0Decimals: 18, + Token1Decimals: 6, + }, + wantErr: ErrInvalidToken1Address, + }, + { + name: "invalid token0 decimals - zero", + pool: &PoolInfo{ + Address: common.HexToAddress("0x1111"), + Protocol: ProtocolUniswapV2, + Token0: common.HexToAddress("0x2222"), + Token1: common.HexToAddress("0x3333"), + Token0Decimals: 0, + Token1Decimals: 6, + }, + wantErr: ErrInvalidToken0Decimals, + }, + { + name: "invalid token0 decimals - too high", + pool: &PoolInfo{ + Address: common.HexToAddress("0x1111"), + Protocol: ProtocolUniswapV2, + Token0: common.HexToAddress("0x2222"), + Token1: common.HexToAddress("0x3333"), + Token0Decimals: 19, + Token1Decimals: 6, + }, + wantErr: ErrInvalidToken0Decimals, + }, + { + name: "invalid token1 decimals - zero", + pool: &PoolInfo{ + Address: common.HexToAddress("0x1111"), + Protocol: ProtocolUniswapV2, + Token0: common.HexToAddress("0x2222"), + Token1: common.HexToAddress("0x3333"), + Token0Decimals: 18, + Token1Decimals: 0, + }, + wantErr: ErrInvalidToken1Decimals, + }, + { + name: "invalid token1 decimals - too high", + pool: &PoolInfo{ + Address: common.HexToAddress("0x1111"), + Protocol: ProtocolUniswapV2, + Token0: common.HexToAddress("0x2222"), + Token1: common.HexToAddress("0x3333"), + Token0Decimals: 18, + Token1Decimals: 19, + }, + wantErr: ErrInvalidToken1Decimals, + }, + { + name: "unknown protocol", + pool: &PoolInfo{ + Address: common.HexToAddress("0x1111"), + Protocol: ProtocolUnknown, + Token0: common.HexToAddress("0x2222"), + Token1: common.HexToAddress("0x3333"), + Token0Decimals: 18, + Token1Decimals: 6, + }, + wantErr: ErrUnknownProtocol, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := tt.pool.Validate() + if err != tt.wantErr { + t.Errorf("Validate() error = %v, want %v", err, tt.wantErr) + } + }) + } +} + +func TestPoolInfo_GetTokenPair(t *testing.T) { + tests := []struct { + name string + pool *PoolInfo + wantToken0 common.Address + wantToken1 common.Address + }{ + { + name: "token0 < token1", + pool: &PoolInfo{ + Token0: common.HexToAddress("0x1111"), + Token1: common.HexToAddress("0x2222"), + }, + wantToken0: common.HexToAddress("0x1111"), + wantToken1: common.HexToAddress("0x2222"), + }, + { + name: "token1 < token0", + pool: &PoolInfo{ + Token0: common.HexToAddress("0x2222"), + Token1: common.HexToAddress("0x1111"), + }, + wantToken0: common.HexToAddress("0x1111"), + wantToken1: common.HexToAddress("0x2222"), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + token0, token1 := tt.pool.GetTokenPair() + if token0 != tt.wantToken0 { + t.Errorf("GetTokenPair() token0 = %v, want %v", token0, tt.wantToken0) + } + if token1 != tt.wantToken1 { + t.Errorf("GetTokenPair() token1 = %v, want %v", token1, tt.wantToken1) + } + }) + } +} + +func TestPoolInfo_CalculatePrice(t *testing.T) { + tests := []struct { + name string + pool *PoolInfo + wantPrice string // String representation for comparison + }{ + { + name: "equal decimals", + pool: &PoolInfo{ + Token0Decimals: 18, + Token1Decimals: 18, + Reserve0: big.NewInt(1000000000000000000), // 1e18 + Reserve1: big.NewInt(2000000000000000000), // 2e18 + }, + wantPrice: "2", + }, + { + name: "different decimals - USDC/WETH", + pool: &PoolInfo{ + Token0Decimals: 6, // USDC + Token1Decimals: 18, // WETH + Reserve0: big.NewInt(1000000), // 1 USDC + Reserve1: big.NewInt(1000000000000000000), // 1 WETH + }, + wantPrice: "1000000000000", // 1 WETH = 1,000,000,000,000 scaled USDC + }, + { + name: "zero reserve0", + pool: &PoolInfo{ + Token0Decimals: 18, + Token1Decimals: 18, + Reserve0: big.NewInt(0), + Reserve1: big.NewInt(1000), + }, + wantPrice: "0", + }, + { + name: "nil reserves", + pool: &PoolInfo{ + Token0Decimals: 18, + Token1Decimals: 18, + Reserve0: nil, + Reserve1: nil, + }, + wantPrice: "0", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + price := tt.pool.CalculatePrice() + if price.String() != tt.wantPrice { + t.Errorf("CalculatePrice() = %v, want %v", price.String(), tt.wantPrice) + } + }) + } +} + +func Test_scaleToDecimals(t *testing.T) { + tests := []struct { + name string + amount *big.Int + fromDecimals uint8 + toDecimals uint8 + want *big.Int + }{ + { + name: "same decimals", + amount: big.NewInt(1000), + fromDecimals: 18, + toDecimals: 18, + want: big.NewInt(1000), + }, + { + name: "scale up - 6 to 18 decimals", + amount: big.NewInt(1000000), // 1 USDC (6 decimals) + fromDecimals: 6, + toDecimals: 18, + want: new(big.Int).Mul(big.NewInt(1000000), new(big.Int).Exp(big.NewInt(10), big.NewInt(12), nil)), + }, + { + name: "scale down - 18 to 6 decimals", + amount: new(big.Int).Mul(big.NewInt(1), new(big.Int).Exp(big.NewInt(10), big.NewInt(18), nil)), // 1 ETH + fromDecimals: 18, + toDecimals: 6, + want: big.NewInt(1000000), + }, + { + name: "scale up - 8 to 18 decimals (WBTC to ETH)", + amount: big.NewInt(100000000), // 1 WBTC (8 decimals) + fromDecimals: 8, + toDecimals: 18, + want: new(big.Int).Mul(big.NewInt(100000000), new(big.Int).Exp(big.NewInt(10), big.NewInt(10), nil)), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := scaleToDecimals(tt.amount, tt.fromDecimals, tt.toDecimals) + if got.Cmp(tt.want) != 0 { + t.Errorf("scaleToDecimals() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/pkg/types/swap_test.go b/pkg/types/swap_test.go new file mode 100644 index 0000000..2daaf40 --- /dev/null +++ b/pkg/types/swap_test.go @@ -0,0 +1,259 @@ +package types + +import ( + "math/big" + "testing" + "time" + + "github.com/ethereum/go-ethereum/common" +) + +func TestSwapEvent_Validate(t *testing.T) { + validEvent := &SwapEvent{ + TxHash: common.HexToHash("0x1234"), + BlockNumber: 1000, + LogIndex: 0, + Timestamp: time.Now(), + PoolAddress: common.HexToAddress("0x1111"), + Protocol: ProtocolUniswapV2, + Token0: common.HexToAddress("0x2222"), + Token1: common.HexToAddress("0x3333"), + Token0Decimals: 18, + Token1Decimals: 6, + Amount0In: big.NewInt(1000), + Amount1In: big.NewInt(0), + Amount0Out: big.NewInt(0), + Amount1Out: big.NewInt(500), + Sender: common.HexToAddress("0x4444"), + Recipient: common.HexToAddress("0x5555"), + } + + tests := []struct { + name string + event *SwapEvent + wantErr error + }{ + { + name: "valid event", + event: validEvent, + wantErr: nil, + }, + { + name: "invalid tx hash", + event: &SwapEvent{ + TxHash: common.Hash{}, + PoolAddress: common.HexToAddress("0x1111"), + Token0: common.HexToAddress("0x2222"), + Token1: common.HexToAddress("0x3333"), + Protocol: ProtocolUniswapV2, + Amount0In: big.NewInt(1000), + }, + wantErr: ErrInvalidTxHash, + }, + { + name: "invalid pool address", + event: &SwapEvent{ + TxHash: common.HexToHash("0x1234"), + PoolAddress: common.Address{}, + Token0: common.HexToAddress("0x2222"), + Token1: common.HexToAddress("0x3333"), + Protocol: ProtocolUniswapV2, + Amount0In: big.NewInt(1000), + }, + wantErr: ErrInvalidPoolAddress, + }, + { + name: "invalid token0 address", + event: &SwapEvent{ + TxHash: common.HexToHash("0x1234"), + PoolAddress: common.HexToAddress("0x1111"), + Token0: common.Address{}, + Token1: common.HexToAddress("0x3333"), + Protocol: ProtocolUniswapV2, + Amount0In: big.NewInt(1000), + }, + wantErr: ErrInvalidToken0Address, + }, + { + name: "invalid token1 address", + event: &SwapEvent{ + TxHash: common.HexToHash("0x1234"), + PoolAddress: common.HexToAddress("0x1111"), + Token0: common.HexToAddress("0x2222"), + Token1: common.Address{}, + Protocol: ProtocolUniswapV2, + Amount0In: big.NewInt(1000), + }, + wantErr: ErrInvalidToken1Address, + }, + { + name: "unknown protocol", + event: &SwapEvent{ + TxHash: common.HexToHash("0x1234"), + PoolAddress: common.HexToAddress("0x1111"), + Token0: common.HexToAddress("0x2222"), + Token1: common.HexToAddress("0x3333"), + Protocol: ProtocolUnknown, + Amount0In: big.NewInt(1000), + }, + wantErr: ErrUnknownProtocol, + }, + { + name: "zero amounts", + event: &SwapEvent{ + TxHash: common.HexToHash("0x1234"), + PoolAddress: common.HexToAddress("0x1111"), + Token0: common.HexToAddress("0x2222"), + Token1: common.HexToAddress("0x3333"), + Protocol: ProtocolUniswapV2, + Amount0In: big.NewInt(0), + Amount1In: big.NewInt(0), + Amount0Out: big.NewInt(0), + Amount1Out: big.NewInt(0), + }, + wantErr: ErrZeroAmounts, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := tt.event.Validate() + if err != tt.wantErr { + t.Errorf("Validate() error = %v, want %v", err, tt.wantErr) + } + }) + } +} + +func TestSwapEvent_GetInputToken(t *testing.T) { + tests := []struct { + name string + event *SwapEvent + wantToken common.Address + wantAmount *big.Int + }{ + { + name: "token0 input", + event: &SwapEvent{ + Token0: common.HexToAddress("0x1111"), + Token1: common.HexToAddress("0x2222"), + Amount0In: big.NewInt(1000), + Amount1In: big.NewInt(0), + Amount0Out: big.NewInt(0), + Amount1Out: big.NewInt(500), + }, + wantToken: common.HexToAddress("0x1111"), + wantAmount: big.NewInt(1000), + }, + { + name: "token1 input", + event: &SwapEvent{ + Token0: common.HexToAddress("0x1111"), + Token1: common.HexToAddress("0x2222"), + Amount0In: big.NewInt(0), + Amount1In: big.NewInt(500), + Amount0Out: big.NewInt(1000), + Amount1Out: big.NewInt(0), + }, + wantToken: common.HexToAddress("0x2222"), + wantAmount: big.NewInt(500), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + token, amount := tt.event.GetInputToken() + if token != tt.wantToken { + t.Errorf("GetInputToken() token = %v, want %v", token, tt.wantToken) + } + if amount.Cmp(tt.wantAmount) != 0 { + t.Errorf("GetInputToken() amount = %v, want %v", amount, tt.wantAmount) + } + }) + } +} + +func TestSwapEvent_GetOutputToken(t *testing.T) { + tests := []struct { + name string + event *SwapEvent + wantToken common.Address + wantAmount *big.Int + }{ + { + name: "token0 output", + event: &SwapEvent{ + Token0: common.HexToAddress("0x1111"), + Token1: common.HexToAddress("0x2222"), + Amount0In: big.NewInt(0), + Amount1In: big.NewInt(500), + Amount0Out: big.NewInt(1000), + Amount1Out: big.NewInt(0), + }, + wantToken: common.HexToAddress("0x1111"), + wantAmount: big.NewInt(1000), + }, + { + name: "token1 output", + event: &SwapEvent{ + Token0: common.HexToAddress("0x1111"), + Token1: common.HexToAddress("0x2222"), + Amount0In: big.NewInt(1000), + Amount1In: big.NewInt(0), + Amount0Out: big.NewInt(0), + Amount1Out: big.NewInt(500), + }, + wantToken: common.HexToAddress("0x2222"), + wantAmount: big.NewInt(500), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + token, amount := tt.event.GetOutputToken() + if token != tt.wantToken { + t.Errorf("GetOutputToken() token = %v, want %v", token, tt.wantToken) + } + if amount.Cmp(tt.wantAmount) != 0 { + t.Errorf("GetOutputToken() amount = %v, want %v", amount, tt.wantAmount) + } + }) + } +} + +func Test_isZero(t *testing.T) { + tests := []struct { + name string + n *big.Int + want bool + }{ + { + name: "nil is zero", + n: nil, + want: true, + }, + { + name: "zero value is zero", + n: big.NewInt(0), + want: true, + }, + { + name: "positive value is not zero", + n: big.NewInt(100), + want: false, + }, + { + name: "negative value is not zero", + n: big.NewInt(-100), + want: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := isZero(tt.n); got != tt.want { + t.Errorf("isZero() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/pkg/validation/validator.go b/pkg/validation/validator.go new file mode 100644 index 0000000..9765f9b --- /dev/null +++ b/pkg/validation/validator.go @@ -0,0 +1,141 @@ +package validation + +import ( + "context" + "fmt" + "math/big" + + "github.com/ethereum/go-ethereum/common" + + "github.com/your-org/mev-bot/pkg/types" +) + +// validator implements the Validator interface +type validator struct { + rules *ValidationRules +} + +// NewValidator creates a new validator with the given rules +func NewValidator(rules *ValidationRules) Validator { + if rules == nil { + rules = DefaultValidationRules() + } + return &validator{ + rules: rules, + } +} + +// ValidateSwapEvent validates a swap event +func (v *validator) ValidateSwapEvent(ctx context.Context, event *types.SwapEvent) error { + if event == nil { + return fmt.Errorf("event cannot be nil") + } + + // First, run built-in validation + if err := event.Validate(); err != nil { + return err + } + + // Additional validation based on rules + if v.rules.RejectZeroAddresses { + if event.Token0 == (common.Address{}) || event.Token1 == (common.Address{}) { + return types.ErrInvalidToken0Address + } + } + + if v.rules.RejectZeroAmounts { + if isZero(event.Amount0In) && isZero(event.Amount1In) && + isZero(event.Amount0Out) && isZero(event.Amount1Out) { + return types.ErrZeroAmounts + } + } + + // Check amount thresholds + amounts := []*big.Int{event.Amount0In, event.Amount1In, event.Amount0Out, event.Amount1Out} + for _, amount := range amounts { + if amount == nil || amount.Sign() == 0 { + continue + } + + if v.rules.MinAmount != nil && amount.Cmp(v.rules.MinAmount) < 0 { + return fmt.Errorf("amount %s below minimum %s", amount.String(), v.rules.MinAmount.String()) + } + + if v.rules.MaxAmount != nil && amount.Cmp(v.rules.MaxAmount) > 0 { + return fmt.Errorf("amount %s exceeds maximum %s", amount.String(), v.rules.MaxAmount.String()) + } + } + + // Check if protocol is allowed + if len(v.rules.AllowedProtocols) > 0 { + if !v.rules.AllowedProtocols[event.Protocol] { + return fmt.Errorf("protocol %s not allowed", event.Protocol) + } + } + + // Check blacklisted pools + if v.rules.BlacklistedPools[event.PoolAddress] { + return fmt.Errorf("pool %s is blacklisted", event.PoolAddress.Hex()) + } + + // Check blacklisted tokens + if v.rules.BlacklistedTokens[event.Token0] || v.rules.BlacklistedTokens[event.Token1] { + return fmt.Errorf("blacklisted token in swap") + } + + return nil +} + +// ValidatePoolInfo validates pool information +func (v *validator) ValidatePoolInfo(ctx context.Context, pool *types.PoolInfo) error { + if pool == nil { + return fmt.Errorf("pool cannot be nil") + } + + // Run built-in validation + if err := pool.Validate(); err != nil { + return err + } + + // Check if protocol is allowed + if len(v.rules.AllowedProtocols) > 0 { + if !v.rules.AllowedProtocols[pool.Protocol] { + return fmt.Errorf("protocol %s not allowed", pool.Protocol) + } + } + + // Check blacklisted pool + if v.rules.BlacklistedPools[pool.Address] { + return fmt.Errorf("pool %s is blacklisted", pool.Address.Hex()) + } + + // Check blacklisted tokens + if v.rules.BlacklistedTokens[pool.Token0] || v.rules.BlacklistedTokens[pool.Token1] { + return fmt.Errorf("blacklisted token in pool") + } + + return nil +} + +// FilterValid filters a slice of swap events, returning only valid ones +func (v *validator) FilterValid(ctx context.Context, events []*types.SwapEvent) []*types.SwapEvent { + var valid []*types.SwapEvent + + for _, event := range events { + if err := v.ValidateSwapEvent(ctx, event); err == nil { + valid = append(valid, event) + } + } + + return valid +} + +// GetValidationRules returns the current validation rules +func (v *validator) GetValidationRules() *ValidationRules { + return v.rules +} + +// isZero checks if a big.Int is nil or zero +func isZero(n *big.Int) bool { + return n == nil || n.Sign() == 0 +} diff --git a/pkg/validation/validator_test.go b/pkg/validation/validator_test.go new file mode 100644 index 0000000..46eef80 --- /dev/null +++ b/pkg/validation/validator_test.go @@ -0,0 +1,395 @@ +package validation + +import ( + "context" + "math/big" + "testing" + "time" + + "github.com/ethereum/go-ethereum/common" + + "github.com/your-org/mev-bot/pkg/types" +) + +func createValidSwapEvent() *types.SwapEvent { + return &types.SwapEvent{ + TxHash: common.HexToHash("0x1234"), + BlockNumber: 1000, + LogIndex: 0, + Timestamp: time.Now(), + PoolAddress: common.HexToAddress("0x1111"), + Protocol: types.ProtocolUniswapV2, + Token0: common.HexToAddress("0x2222"), + Token1: common.HexToAddress("0x3333"), + Token0Decimals: 18, + Token1Decimals: 6, + Amount0In: big.NewInt(1000), + Amount1In: big.NewInt(0), + Amount0Out: big.NewInt(0), + Amount1Out: big.NewInt(500), + Sender: common.HexToAddress("0x4444"), + Recipient: common.HexToAddress("0x5555"), + } +} + +func createValidPoolInfo() *types.PoolInfo { + return &types.PoolInfo{ + Address: common.HexToAddress("0x1111"), + Protocol: types.ProtocolUniswapV2, + Token0: common.HexToAddress("0x2222"), + Token1: common.HexToAddress("0x3333"), + Token0Decimals: 18, + Token1Decimals: 6, + Reserve0: big.NewInt(1000000), + Reserve1: big.NewInt(500000), + IsActive: true, + } +} + +func TestNewValidator(t *testing.T) { + validator := NewValidator(nil) + if validator == nil { + t.Fatal("NewValidator returned nil") + } + + rules := DefaultValidationRules() + validator = NewValidator(rules) + if validator == nil { + t.Fatal("NewValidator with rules returned nil") + } +} + +func TestDefaultValidationRules(t *testing.T) { + rules := DefaultValidationRules() + + if rules == nil { + t.Fatal("DefaultValidationRules returned nil") + } + + if !rules.RejectZeroAddresses { + t.Error("DefaultValidationRules RejectZeroAddresses should be true") + } + + if !rules.RejectZeroAmounts { + t.Error("DefaultValidationRules RejectZeroAmounts should be true") + } + + if rules.MinAmount == nil { + t.Error("DefaultValidationRules MinAmount should not be nil") + } + + if rules.MaxAmount == nil { + t.Error("DefaultValidationRules MaxAmount should not be nil") + } +} + +func TestValidator_ValidateSwapEvent(t *testing.T) { + ctx := context.Background() + + tests := []struct { + name string + rules *ValidationRules + event *types.SwapEvent + wantErr bool + }{ + { + name: "valid event with default rules", + rules: DefaultValidationRules(), + event: createValidSwapEvent(), + wantErr: false, + }, + { + name: "nil event", + rules: DefaultValidationRules(), + event: nil, + wantErr: true, + }, + { + name: "amount below minimum", + rules: DefaultValidationRules(), + event: func() *types.SwapEvent { + e := createValidSwapEvent() + e.Amount0In = big.NewInt(1) // Below default minimum + return e + }(), + wantErr: true, + }, + { + name: "amount above maximum", + rules: &ValidationRules{ + RejectZeroAddresses: true, + RejectZeroAmounts: true, + MinAmount: big.NewInt(1), + MaxAmount: big.NewInt(100), + AllowedProtocols: make(map[types.ProtocolType]bool), + BlacklistedPools: make(map[common.Address]bool), + BlacklistedTokens: make(map[common.Address]bool), + }, + event: func() *types.SwapEvent { + e := createValidSwapEvent() + e.Amount0In = big.NewInt(1000) // Above maximum + return e + }(), + wantErr: true, + }, + { + name: "protocol not allowed", + rules: &ValidationRules{ + RejectZeroAddresses: true, + RejectZeroAmounts: true, + MinAmount: big.NewInt(1), + MaxAmount: big.NewInt(1e18), + AllowedProtocols: map[types.ProtocolType]bool{ + types.ProtocolUniswapV3: true, + }, + BlacklistedPools: make(map[common.Address]bool), + BlacklistedTokens: make(map[common.Address]bool), + }, + event: createValidSwapEvent(), // UniswapV2 + wantErr: true, + }, + { + name: "blacklisted pool", + rules: &ValidationRules{ + RejectZeroAddresses: true, + RejectZeroAmounts: true, + MinAmount: big.NewInt(1), + MaxAmount: big.NewInt(1e18), + AllowedProtocols: make(map[types.ProtocolType]bool), + BlacklistedPools: map[common.Address]bool{ + common.HexToAddress("0x1111"): true, + }, + BlacklistedTokens: make(map[common.Address]bool), + }, + event: createValidSwapEvent(), + wantErr: true, + }, + { + name: "blacklisted token", + rules: &ValidationRules{ + RejectZeroAddresses: true, + RejectZeroAmounts: true, + MinAmount: big.NewInt(1), + MaxAmount: big.NewInt(1e18), + AllowedProtocols: make(map[types.ProtocolType]bool), + BlacklistedPools: make(map[common.Address]bool), + BlacklistedTokens: map[common.Address]bool{ + common.HexToAddress("0x2222"): true, + }, + }, + event: createValidSwapEvent(), + wantErr: true, + }, + { + name: "zero amounts when rejected", + rules: DefaultValidationRules(), + event: func() *types.SwapEvent { + e := createValidSwapEvent() + e.Amount0In = big.NewInt(0) + e.Amount1In = big.NewInt(0) + e.Amount0Out = big.NewInt(0) + e.Amount1Out = big.NewInt(0) + return e + }(), + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + validator := NewValidator(tt.rules) + err := validator.ValidateSwapEvent(ctx, tt.event) + + if tt.wantErr { + if err == nil { + t.Error("ValidateSwapEvent() expected error, got nil") + } + } else { + if err != nil { + t.Errorf("ValidateSwapEvent() unexpected error: %v", err) + } + } + }) + } +} + +func TestValidator_ValidatePoolInfo(t *testing.T) { + ctx := context.Background() + + tests := []struct { + name string + rules *ValidationRules + pool *types.PoolInfo + wantErr bool + }{ + { + name: "valid pool with default rules", + rules: DefaultValidationRules(), + pool: createValidPoolInfo(), + wantErr: false, + }, + { + name: "nil pool", + rules: DefaultValidationRules(), + pool: nil, + wantErr: true, + }, + { + name: "protocol not allowed", + rules: &ValidationRules{ + AllowedProtocols: map[types.ProtocolType]bool{ + types.ProtocolUniswapV3: true, + }, + BlacklistedPools: make(map[common.Address]bool), + BlacklistedTokens: make(map[common.Address]bool), + }, + pool: createValidPoolInfo(), // UniswapV2 + wantErr: true, + }, + { + name: "blacklisted pool", + rules: &ValidationRules{ + AllowedProtocols: make(map[types.ProtocolType]bool), + BlacklistedPools: map[common.Address]bool{ + common.HexToAddress("0x1111"): true, + }, + BlacklistedTokens: make(map[common.Address]bool), + }, + pool: createValidPoolInfo(), + wantErr: true, + }, + { + name: "blacklisted token", + rules: &ValidationRules{ + AllowedProtocols: make(map[types.ProtocolType]bool), + BlacklistedPools: make(map[common.Address]bool), + BlacklistedTokens: map[common.Address]bool{ + common.HexToAddress("0x2222"): true, + }, + }, + pool: createValidPoolInfo(), + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + validator := NewValidator(tt.rules) + err := validator.ValidatePoolInfo(ctx, tt.pool) + + if tt.wantErr { + if err == nil { + t.Error("ValidatePoolInfo() expected error, got nil") + } + } else { + if err != nil { + t.Errorf("ValidatePoolInfo() unexpected error: %v", err) + } + } + }) + } +} + +func TestValidator_FilterValid(t *testing.T) { + ctx := context.Background() + + validEvent1 := createValidSwapEvent() + validEvent2 := createValidSwapEvent() + validEvent2.PoolAddress = common.HexToAddress("0x9999") + + invalidEvent := createValidSwapEvent() + invalidEvent.Amount0In = big.NewInt(0) + invalidEvent.Amount1In = big.NewInt(0) + invalidEvent.Amount0Out = big.NewInt(0) + invalidEvent.Amount1Out = big.NewInt(0) + + tests := []struct { + name string + rules *ValidationRules + events []*types.SwapEvent + wantCount int + }{ + { + name: "all valid events", + rules: DefaultValidationRules(), + events: []*types.SwapEvent{validEvent1, validEvent2}, + wantCount: 2, + }, + { + name: "mixed valid and invalid", + rules: DefaultValidationRules(), + events: []*types.SwapEvent{validEvent1, invalidEvent, validEvent2}, + wantCount: 2, + }, + { + name: "all invalid events", + rules: DefaultValidationRules(), + events: []*types.SwapEvent{invalidEvent}, + wantCount: 0, + }, + { + name: "empty slice", + rules: DefaultValidationRules(), + events: []*types.SwapEvent{}, + wantCount: 0, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + validator := NewValidator(tt.rules) + valid := validator.FilterValid(ctx, tt.events) + + if len(valid) != tt.wantCount { + t.Errorf("FilterValid() count = %d, want %d", len(valid), tt.wantCount) + } + }) + } +} + +func TestValidator_GetValidationRules(t *testing.T) { + rules := DefaultValidationRules() + validator := NewValidator(rules) + + retrievedRules := validator.GetValidationRules() + if retrievedRules != rules { + t.Error("GetValidationRules() returned different rules") + } +} + +func Test_isZero_Validation(t *testing.T) { + tests := []struct { + name string + n *big.Int + want bool + }{ + { + name: "nil is zero", + n: nil, + want: true, + }, + { + name: "zero value is zero", + n: big.NewInt(0), + want: true, + }, + { + name: "positive value is not zero", + n: big.NewInt(100), + want: false, + }, + { + name: "negative value is not zero", + n: big.NewInt(-100), + want: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := isZero(tt.n); got != tt.want { + t.Errorf("isZero() = %v, want %v", got, tt.want) + } + }) + } +}