package validation

import (
	"context"
	"math/big"
	"testing"
	"time"

	"github.com/ethereum/go-ethereum/common"
	"github.com/your-org/mev-bot/pkg/types"
)

// createValidSwapEvent returns a fresh UniswapV2 swap event on pool 0x1111
// (token0 0x2222, token1 0x3333) that the tests below expect to pass
// DefaultValidationRules. A new value is built on every call so individual
// test cases can mutate it without affecting each other.
func createValidSwapEvent() *types.SwapEvent {
	return &types.SwapEvent{
		TxHash:         common.HexToHash("0x1234"),
		BlockNumber:    1000,
		LogIndex:       0,
		Timestamp:      time.Now(),
		PoolAddress:    common.HexToAddress("0x1111"),
		Protocol:       types.ProtocolUniswapV2,
		Token0:         common.HexToAddress("0x2222"),
		Token1:         common.HexToAddress("0x3333"),
		Token0Decimals: 18,
		Token1Decimals: 6,
		Amount0In:      big.NewInt(1000),
		Amount1In:      big.NewInt(0),
		Amount0Out:     big.NewInt(0),
		Amount1Out:     big.NewInt(500),
		Sender:         common.HexToAddress("0x4444"),
		Recipient:      common.HexToAddress("0x5555"),
	}
}

// createValidPoolInfo returns a fresh, active UniswapV2 pool description with
// the same address and tokens as createValidSwapEvent, expected to pass
// DefaultValidationRules.
func createValidPoolInfo() *types.PoolInfo {
	return &types.PoolInfo{
		Address:        common.HexToAddress("0x1111"),
		Protocol:       types.ProtocolUniswapV2,
		Token0:         common.HexToAddress("0x2222"),
		Token1:         common.HexToAddress("0x3333"),
		Token0Decimals: 18,
		Token1Decimals: 6,
		Reserve0:       big.NewInt(1000000),
		Reserve1:       big.NewInt(500000),
		IsActive:       true,
	}
}

// TestNewValidator checks that NewValidator returns a non-nil validator both
// with nil rules and with explicit default rules.
func TestNewValidator(t *testing.T) {
	validator := NewValidator(nil)
	if validator == nil {
		t.Fatal("NewValidator returned nil")
	}

	rules := DefaultValidationRules()
	validator = NewValidator(rules)
	if validator == nil {
		t.Fatal("NewValidator with rules returned nil")
	}
}

// TestDefaultValidationRules checks the flags and amount bounds of the
// default rule set.
func TestDefaultValidationRules(t *testing.T) {
	rules := DefaultValidationRules()
	if rules == nil {
		t.Fatal("DefaultValidationRules returned nil")
	}

	if !rules.RejectZeroAddresses {
		t.Error("DefaultValidationRules RejectZeroAddresses should be true")
	}
	if !rules.RejectZeroAmounts {
		t.Error("DefaultValidationRules RejectZeroAmounts should be true")
	}
	if rules.MinAmount == nil {
		t.Error("DefaultValidationRules MinAmount should not be nil")
	}
	if rules.MaxAmount == nil {
		t.Error("DefaultValidationRules MaxAmount should not be nil")
	}
	// Only compare when both bounds exist; calling Cmp on a nil *big.Int panics.
	if rules.MinAmount != nil && rules.MaxAmount != nil && rules.MinAmount.Cmp(rules.MaxAmount) > 0 {
		t.Errorf("DefaultValidationRules MinAmount %s exceeds MaxAmount %s", rules.MinAmount, rules.MaxAmount)
	}
}

// TestValidator_ValidateSwapEvent runs ValidateSwapEvent against the default
// rules and against custom rule sets that each trigger one rejection path:
// amount bounds, protocol allow-list, pool/token blacklists and all-zero amounts.
func TestValidator_ValidateSwapEvent(t *testing.T) {
	ctx := context.Background()

	tests := []struct {
		name    string
		rules   *ValidationRules
		event   *types.SwapEvent
		wantErr bool
	}{
		{
			name:    "valid event with default rules",
			rules:   DefaultValidationRules(),
			event:   createValidSwapEvent(),
			wantErr: false,
		},
		{
			name:    "nil event",
			rules:   DefaultValidationRules(),
			event:   nil,
			wantErr: true,
		},
		{
			name:  "amount below minimum",
			rules: DefaultValidationRules(),
			event: func() *types.SwapEvent {
				e := createValidSwapEvent()
				e.Amount0In = big.NewInt(1) // Below default minimum
				return e
			}(),
			wantErr: true,
		},
		{
			name: "amount above maximum",
			rules: &ValidationRules{
				RejectZeroAddresses: true,
				RejectZeroAmounts:   true,
				MinAmount:           big.NewInt(1),
				MaxAmount:           big.NewInt(100),
				AllowedProtocols:    make(map[types.ProtocolType]bool),
				BlacklistedPools:    make(map[common.Address]bool),
				BlacklistedTokens:   make(map[common.Address]bool),
			},
			event: func() *types.SwapEvent {
				e := createValidSwapEvent()
				e.Amount0In = big.NewInt(1000) // Above maximum
				return e
			}(),
			wantErr: true,
		},
		{
			name: "protocol not allowed",
			rules: &ValidationRules{
				RejectZeroAddresses: true,
				RejectZeroAmounts:   true,
				MinAmount:           big.NewInt(1),
				MaxAmount:           big.NewInt(1e18),
				AllowedProtocols: map[types.ProtocolType]bool{
					types.ProtocolUniswapV3: true,
				},
				BlacklistedPools:  make(map[common.Address]bool),
				BlacklistedTokens: make(map[common.Address]bool),
			},
			event:   createValidSwapEvent(), // UniswapV2
			wantErr: true,
		},
		{
			name: "blacklisted pool",
			rules: &ValidationRules{
				RejectZeroAddresses: true,
				RejectZeroAmounts:   true,
				MinAmount:           big.NewInt(1),
				MaxAmount:           big.NewInt(1e18),
				AllowedProtocols:    make(map[types.ProtocolType]bool),
				BlacklistedPools: map[common.Address]bool{
					common.HexToAddress("0x1111"): true,
				},
				BlacklistedTokens: make(map[common.Address]bool),
			},
			event:   createValidSwapEvent(),
			wantErr: true,
		},
		{
			name: "blacklisted token",
			rules: &ValidationRules{
				RejectZeroAddresses: true,
				RejectZeroAmounts:   true,
				MinAmount:           big.NewInt(1),
				MaxAmount:           big.NewInt(1e18),
				AllowedProtocols:    make(map[types.ProtocolType]bool),
				BlacklistedPools:    make(map[common.Address]bool),
				BlacklistedTokens: map[common.Address]bool{
					common.HexToAddress("0x2222"): true,
				},
			},
			event:   createValidSwapEvent(),
			wantErr: true,
		},
		{
			name:  "zero amounts when rejected",
			rules: DefaultValidationRules(),
			event: func() *types.SwapEvent {
				e := createValidSwapEvent()
				e.Amount0In = big.NewInt(0)
				e.Amount1In = big.NewInt(0)
				e.Amount0Out = big.NewInt(0)
				e.Amount1Out = big.NewInt(0)
				return e
			}(),
			wantErr: true,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			validator := NewValidator(tt.rules)
			err := validator.ValidateSwapEvent(ctx, tt.event)
			switch {
			case tt.wantErr && err == nil:
				t.Error("ValidateSwapEvent() expected error, got nil")
			case !tt.wantErr && err != nil:
				t.Errorf("ValidateSwapEvent() unexpected error: %v", err)
			}
		})
	}
}

// TestValidator_ValidatePoolInfo runs ValidatePoolInfo against the default
// rules and against custom rule sets that only populate the protocol
// allow-list and the pool/token blacklists; all other rule fields are left
// at their zero values.
func TestValidator_ValidatePoolInfo(t *testing.T) {
	ctx := context.Background()

	tests := []struct {
		name    string
		rules   *ValidationRules
		pool    *types.PoolInfo
		wantErr bool
	}{
		{
			name:    "valid pool with default rules",
			rules:   DefaultValidationRules(),
			pool:    createValidPoolInfo(),
			wantErr: false,
		},
		{
			name:    "nil pool",
			rules:   DefaultValidationRules(),
			pool:    nil,
			wantErr: true,
		},
		{
			name: "protocol not allowed",
			rules: &ValidationRules{
				AllowedProtocols: map[types.ProtocolType]bool{
					types.ProtocolUniswapV3: true,
				},
				BlacklistedPools:  make(map[common.Address]bool),
				BlacklistedTokens: make(map[common.Address]bool),
			},
			pool:    createValidPoolInfo(), // UniswapV2
			wantErr: true,
		},
		{
			name: "blacklisted pool",
			rules: &ValidationRules{
				AllowedProtocols: make(map[types.ProtocolType]bool),
				BlacklistedPools: map[common.Address]bool{
					common.HexToAddress("0x1111"): true,
				},
				BlacklistedTokens: make(map[common.Address]bool),
			},
			pool:    createValidPoolInfo(),
			wantErr: true,
		},
		{
			name: "blacklisted token",
			rules: &ValidationRules{
				AllowedProtocols: make(map[types.ProtocolType]bool),
				BlacklistedPools: make(map[common.Address]bool),
				BlacklistedTokens: map[common.Address]bool{
					common.HexToAddress("0x2222"): true,
				},
			},
			pool:    createValidPoolInfo(),
			wantErr: true,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			validator := NewValidator(tt.rules)
			err := validator.ValidatePoolInfo(ctx, tt.pool)
			switch {
			case tt.wantErr && err == nil:
				t.Error("ValidatePoolInfo() expected error, got nil")
			case !tt.wantErr && err != nil:
				t.Errorf("ValidatePoolInfo() unexpected error: %v", err)
			}
		})
	}
}

// TestValidator_FilterValid checks that FilterValid keeps exactly the events
// that pass validation. invalidEvent has all four amounts zeroed, which the
// default rules are expected to reject.
func TestValidator_FilterValid(t *testing.T) {
	ctx := context.Background()

	validEvent1 := createValidSwapEvent()
	validEvent2 := createValidSwapEvent()
	validEvent2.PoolAddress = common.HexToAddress("0x9999")

	invalidEvent := createValidSwapEvent()
	invalidEvent.Amount0In = big.NewInt(0)
	invalidEvent.Amount1In = big.NewInt(0)
	invalidEvent.Amount0Out = big.NewInt(0)
	invalidEvent.Amount1Out = big.NewInt(0)

	tests := []struct {
		name      string
		rules     *ValidationRules
		events    []*types.SwapEvent
		wantCount int
	}{
		{
			name:      "all valid events",
			rules:     DefaultValidationRules(),
			events:    []*types.SwapEvent{validEvent1, validEvent2},
			wantCount: 2,
		},
		{
			name:      "mixed valid and invalid",
			rules:     DefaultValidationRules(),
			events:    []*types.SwapEvent{validEvent1, invalidEvent, validEvent2},
			wantCount: 2,
		},
		{
			name:      "all invalid events",
			rules:     DefaultValidationRules(),
			events:    []*types.SwapEvent{invalidEvent},
			wantCount: 0,
		},
		{
			name:      "empty slice",
			rules:     DefaultValidationRules(),
			events:    []*types.SwapEvent{},
			wantCount: 0,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			validator := NewValidator(tt.rules)
			valid := validator.FilterValid(ctx, tt.events)
			if len(valid) != tt.wantCount {
				t.Errorf("FilterValid() count = %d, want %d", len(valid), tt.wantCount)
			}
		})
	}
}

// TestValidator_GetValidationRules checks that GetValidationRules returns the
// very rules pointer passed to NewValidator (pointer identity, not a copy).
func TestValidator_GetValidationRules(t *testing.T) {
	rules := DefaultValidationRules()
	validator := NewValidator(rules)

	retrievedRules := validator.GetValidationRules()
	if retrievedRules != rules {
		t.Error("GetValidationRules() returned different rules")
	}
}

// Test_isZero_Validation checks isZero: nil and 0 count as zero, while any
// non-zero value (positive or negative) does not.
func Test_isZero_Validation(t *testing.T) {
	tests := []struct {
		name string
		n    *big.Int
		want bool
	}{
		{
			name: "nil is zero",
			n:    nil,
			want: true,
		},
		{
			name: "zero value is zero",
			n:    big.NewInt(0),
			want: true,
		},
		{
			name: "positive value is not zero",
			n:    big.NewInt(100),
			want: false,
		},
		{
			name: "negative value is not zero",
			n:    big.NewInt(-100),
			want: false,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			if got := isZero(tt.n); got != tt.want {
				t.Errorf("isZero() = %v, want %v", got, tt.want)
			}
		})
	}
}