diff --git a/pkg/validation/validator.go b/pkg/validation/validator.go new file mode 100644 index 0000000..9765f9b --- /dev/null +++ b/pkg/validation/validator.go @@ -0,0 +1,141 @@ +package validation + +import ( + "context" + "fmt" + "math/big" + + "github.com/ethereum/go-ethereum/common" + + "github.com/your-org/mev-bot/pkg/types" +) + +// validator implements the Validator interface +type validator struct { + rules *ValidationRules +} + +// NewValidator creates a new validator with the given rules +func NewValidator(rules *ValidationRules) Validator { + if rules == nil { + rules = DefaultValidationRules() + } + return &validator{ + rules: rules, + } +} + +// ValidateSwapEvent validates a swap event +func (v *validator) ValidateSwapEvent(ctx context.Context, event *types.SwapEvent) error { + if event == nil { + return fmt.Errorf("event cannot be nil") + } + + // First, run built-in validation + if err := event.Validate(); err != nil { + return err + } + + // Additional validation based on rules + if v.rules.RejectZeroAddresses { + if event.Token0 == (common.Address{}) || event.Token1 == (common.Address{}) { + return types.ErrInvalidToken0Address + } + } + + if v.rules.RejectZeroAmounts { + if isZero(event.Amount0In) && isZero(event.Amount1In) && + isZero(event.Amount0Out) && isZero(event.Amount1Out) { + return types.ErrZeroAmounts + } + } + + // Check amount thresholds + amounts := []*big.Int{event.Amount0In, event.Amount1In, event.Amount0Out, event.Amount1Out} + for _, amount := range amounts { + if amount == nil || amount.Sign() == 0 { + continue + } + + if v.rules.MinAmount != nil && amount.Cmp(v.rules.MinAmount) < 0 { + return fmt.Errorf("amount %s below minimum %s", amount.String(), v.rules.MinAmount.String()) + } + + if v.rules.MaxAmount != nil && amount.Cmp(v.rules.MaxAmount) > 0 { + return fmt.Errorf("amount %s exceeds maximum %s", amount.String(), v.rules.MaxAmount.String()) + } + } + + // Check if protocol is allowed + if len(v.rules.AllowedProtocols) > 0 { + if !v.rules.AllowedProtocols[event.Protocol] { + return fmt.Errorf("protocol %s not allowed", event.Protocol) + } + } + + // Check blacklisted pools + if v.rules.BlacklistedPools[event.PoolAddress] { + return fmt.Errorf("pool %s is blacklisted", event.PoolAddress.Hex()) + } + + // Check blacklisted tokens + if v.rules.BlacklistedTokens[event.Token0] || v.rules.BlacklistedTokens[event.Token1] { + return fmt.Errorf("blacklisted token in swap") + } + + return nil +} + +// ValidatePoolInfo validates pool information +func (v *validator) ValidatePoolInfo(ctx context.Context, pool *types.PoolInfo) error { + if pool == nil { + return fmt.Errorf("pool cannot be nil") + } + + // Run built-in validation + if err := pool.Validate(); err != nil { + return err + } + + // Check if protocol is allowed + if len(v.rules.AllowedProtocols) > 0 { + if !v.rules.AllowedProtocols[pool.Protocol] { + return fmt.Errorf("protocol %s not allowed", pool.Protocol) + } + } + + // Check blacklisted pool + if v.rules.BlacklistedPools[pool.Address] { + return fmt.Errorf("pool %s is blacklisted", pool.Address.Hex()) + } + + // Check blacklisted tokens + if v.rules.BlacklistedTokens[pool.Token0] || v.rules.BlacklistedTokens[pool.Token1] { + return fmt.Errorf("blacklisted token in pool") + } + + return nil +} + +// FilterValid filters a slice of swap events, returning only valid ones +func (v *validator) FilterValid(ctx context.Context, events []*types.SwapEvent) []*types.SwapEvent { + var valid []*types.SwapEvent + + for _, event := range events { + if err := v.ValidateSwapEvent(ctx, event); err == nil { + valid = append(valid, event) + } + } + + return valid +} + +// GetValidationRules returns the current validation rules +func (v *validator) GetValidationRules() *ValidationRules { + return v.rules +} + +// isZero checks if a big.Int is nil or zero +func isZero(n *big.Int) bool { + return n == nil || n.Sign() == 0 +} diff --git a/pkg/validation/validator_test.go b/pkg/validation/validator_test.go new file mode 100644 index 0000000..46eef80 --- /dev/null +++ b/pkg/validation/validator_test.go @@ -0,0 +1,395 @@ +package validation + +import ( + "context" + "math/big" + "testing" + "time" + + "github.com/ethereum/go-ethereum/common" + + "github.com/your-org/mev-bot/pkg/types" +) + +func createValidSwapEvent() *types.SwapEvent { + return &types.SwapEvent{ + TxHash: common.HexToHash("0x1234"), + BlockNumber: 1000, + LogIndex: 0, + Timestamp: time.Now(), + PoolAddress: common.HexToAddress("0x1111"), + Protocol: types.ProtocolUniswapV2, + Token0: common.HexToAddress("0x2222"), + Token1: common.HexToAddress("0x3333"), + Token0Decimals: 18, + Token1Decimals: 6, + Amount0In: big.NewInt(1000), + Amount1In: big.NewInt(0), + Amount0Out: big.NewInt(0), + Amount1Out: big.NewInt(500), + Sender: common.HexToAddress("0x4444"), + Recipient: common.HexToAddress("0x5555"), + } +} + +func createValidPoolInfo() *types.PoolInfo { + return &types.PoolInfo{ + Address: common.HexToAddress("0x1111"), + Protocol: types.ProtocolUniswapV2, + Token0: common.HexToAddress("0x2222"), + Token1: common.HexToAddress("0x3333"), + Token0Decimals: 18, + Token1Decimals: 6, + Reserve0: big.NewInt(1000000), + Reserve1: big.NewInt(500000), + IsActive: true, + } +} + +func TestNewValidator(t *testing.T) { + validator := NewValidator(nil) + if validator == nil { + t.Fatal("NewValidator returned nil") + } + + rules := DefaultValidationRules() + validator = NewValidator(rules) + if validator == nil { + t.Fatal("NewValidator with rules returned nil") + } +} + +func TestDefaultValidationRules(t *testing.T) { + rules := DefaultValidationRules() + + if rules == nil { + t.Fatal("DefaultValidationRules returned nil") + } + + if !rules.RejectZeroAddresses { + t.Error("DefaultValidationRules RejectZeroAddresses should be true") + } + + if !rules.RejectZeroAmounts { + t.Error("DefaultValidationRules RejectZeroAmounts should be true") + } + + if rules.MinAmount == nil { + t.Error("DefaultValidationRules MinAmount should not be nil") + } + + if rules.MaxAmount == nil { + t.Error("DefaultValidationRules MaxAmount should not be nil") + } +} + +func TestValidator_ValidateSwapEvent(t *testing.T) { + ctx := context.Background() + + tests := []struct { + name string + rules *ValidationRules + event *types.SwapEvent + wantErr bool + }{ + { + name: "valid event with default rules", + rules: DefaultValidationRules(), + event: createValidSwapEvent(), + wantErr: false, + }, + { + name: "nil event", + rules: DefaultValidationRules(), + event: nil, + wantErr: true, + }, + { + name: "amount below minimum", + rules: DefaultValidationRules(), + event: func() *types.SwapEvent { + e := createValidSwapEvent() + e.Amount0In = big.NewInt(1) // Below default minimum + return e + }(), + wantErr: true, + }, + { + name: "amount above maximum", + rules: &ValidationRules{ + RejectZeroAddresses: true, + RejectZeroAmounts: true, + MinAmount: big.NewInt(1), + MaxAmount: big.NewInt(100), + AllowedProtocols: make(map[types.ProtocolType]bool), + BlacklistedPools: make(map[common.Address]bool), + BlacklistedTokens: make(map[common.Address]bool), + }, + event: func() *types.SwapEvent { + e := createValidSwapEvent() + e.Amount0In = big.NewInt(1000) // Above maximum + return e + }(), + wantErr: true, + }, + { + name: "protocol not allowed", + rules: &ValidationRules{ + RejectZeroAddresses: true, + RejectZeroAmounts: true, + MinAmount: big.NewInt(1), + MaxAmount: big.NewInt(1e18), + AllowedProtocols: map[types.ProtocolType]bool{ + types.ProtocolUniswapV3: true, + }, + BlacklistedPools: make(map[common.Address]bool), + BlacklistedTokens: make(map[common.Address]bool), + }, + event: createValidSwapEvent(), // UniswapV2 + wantErr: true, + }, + { + name: "blacklisted pool", + rules: &ValidationRules{ + RejectZeroAddresses: true, + RejectZeroAmounts: true, + MinAmount: big.NewInt(1), + MaxAmount: big.NewInt(1e18), + AllowedProtocols: make(map[types.ProtocolType]bool), + BlacklistedPools: map[common.Address]bool{ + common.HexToAddress("0x1111"): true, + }, + BlacklistedTokens: make(map[common.Address]bool), + }, + event: createValidSwapEvent(), + wantErr: true, + }, + { + name: "blacklisted token", + rules: &ValidationRules{ + RejectZeroAddresses: true, + RejectZeroAmounts: true, + MinAmount: big.NewInt(1), + MaxAmount: big.NewInt(1e18), + AllowedProtocols: make(map[types.ProtocolType]bool), + BlacklistedPools: make(map[common.Address]bool), + BlacklistedTokens: map[common.Address]bool{ + common.HexToAddress("0x2222"): true, + }, + }, + event: createValidSwapEvent(), + wantErr: true, + }, + { + name: "zero amounts when rejected", + rules: DefaultValidationRules(), + event: func() *types.SwapEvent { + e := createValidSwapEvent() + e.Amount0In = big.NewInt(0) + e.Amount1In = big.NewInt(0) + e.Amount0Out = big.NewInt(0) + e.Amount1Out = big.NewInt(0) + return e + }(), + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + validator := NewValidator(tt.rules) + err := validator.ValidateSwapEvent(ctx, tt.event) + + if tt.wantErr { + if err == nil { + t.Error("ValidateSwapEvent() expected error, got nil") + } + } else { + if err != nil { + t.Errorf("ValidateSwapEvent() unexpected error: %v", err) + } + } + }) + } +} + +func TestValidator_ValidatePoolInfo(t *testing.T) { + ctx := context.Background() + + tests := []struct { + name string + rules *ValidationRules + pool *types.PoolInfo + wantErr bool + }{ + { + name: "valid pool with default rules", + rules: DefaultValidationRules(), + pool: createValidPoolInfo(), + wantErr: false, + }, + { + name: "nil pool", + rules: DefaultValidationRules(), + pool: nil, + wantErr: true, + }, + { + name: "protocol not allowed", + rules: &ValidationRules{ + AllowedProtocols: map[types.ProtocolType]bool{ + types.ProtocolUniswapV3: true, + }, + BlacklistedPools: make(map[common.Address]bool), + BlacklistedTokens: make(map[common.Address]bool), + }, + pool: createValidPoolInfo(), // UniswapV2 + wantErr: true, + }, + { + name: "blacklisted pool", + rules: &ValidationRules{ + AllowedProtocols: make(map[types.ProtocolType]bool), + BlacklistedPools: map[common.Address]bool{ + common.HexToAddress("0x1111"): true, + }, + BlacklistedTokens: make(map[common.Address]bool), + }, + pool: createValidPoolInfo(), + wantErr: true, + }, + { + name: "blacklisted token", + rules: &ValidationRules{ + AllowedProtocols: make(map[types.ProtocolType]bool), + BlacklistedPools: make(map[common.Address]bool), + BlacklistedTokens: map[common.Address]bool{ + common.HexToAddress("0x2222"): true, + }, + }, + pool: createValidPoolInfo(), + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + validator := NewValidator(tt.rules) + err := validator.ValidatePoolInfo(ctx, tt.pool) + + if tt.wantErr { + if err == nil { + t.Error("ValidatePoolInfo() expected error, got nil") + } + } else { + if err != nil { + t.Errorf("ValidatePoolInfo() unexpected error: %v", err) + } + } + }) + } +} + +func TestValidator_FilterValid(t *testing.T) { + ctx := context.Background() + + validEvent1 := createValidSwapEvent() + validEvent2 := createValidSwapEvent() + validEvent2.PoolAddress = common.HexToAddress("0x9999") + + invalidEvent := createValidSwapEvent() + invalidEvent.Amount0In = big.NewInt(0) + invalidEvent.Amount1In = big.NewInt(0) + invalidEvent.Amount0Out = big.NewInt(0) + invalidEvent.Amount1Out = big.NewInt(0) + + tests := []struct { + name string + rules *ValidationRules + events []*types.SwapEvent + wantCount int + }{ + { + name: "all valid events", + rules: DefaultValidationRules(), + events: []*types.SwapEvent{validEvent1, validEvent2}, + wantCount: 2, + }, + { + name: "mixed valid and invalid", + rules: DefaultValidationRules(), + events: []*types.SwapEvent{validEvent1, invalidEvent, validEvent2}, + wantCount: 2, + }, + { + name: "all invalid events", + rules: DefaultValidationRules(), + events: []*types.SwapEvent{invalidEvent}, + wantCount: 0, + }, + { + name: "empty slice", + rules: DefaultValidationRules(), + events: []*types.SwapEvent{}, + wantCount: 0, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + validator := NewValidator(tt.rules) + valid := validator.FilterValid(ctx, tt.events) + + if len(valid) != tt.wantCount { + t.Errorf("FilterValid() count = %d, want %d", len(valid), tt.wantCount) + } + }) + } +} + +func TestValidator_GetValidationRules(t *testing.T) { + rules := DefaultValidationRules() + validator := NewValidator(rules) + + retrievedRules := validator.GetValidationRules() + if retrievedRules != rules { + t.Error("GetValidationRules() returned different rules") + } +} + +func Test_isZero_Validation(t *testing.T) { + tests := []struct { + name string + n *big.Int + want bool + }{ + { + name: "nil is zero", + n: nil, + want: true, + }, + { + name: "zero value is zero", + n: big.NewInt(0), + want: true, + }, + { + name: "positive value is not zero", + n: big.NewInt(100), + want: false, + }, + { + name: "negative value is not zero", + n: big.NewInt(-100), + want: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := isZero(tt.n); got != tt.want { + t.Errorf("isZero() = %v, want %v", got, tt.want) + } + }) + } +}