package validation

import (
	"context"
	"errors"
	"fmt"
	"math/big"

	"github.com/ethereum/go-ethereum/common"
	"github.com/your-org/mev-bot/pkg/types"
)

// Compile-time check that *validator satisfies the Validator interface.
var _ Validator = (*validator)(nil)

// validator implements the Validator interface by layering the configured
// ValidationRules on top of each type's own Validate method.
type validator struct {
	// rules is never nil when the validator is built via NewValidator.
	rules *ValidationRules
}

// NewValidator creates a new validator that enforces rules. A nil rules value
// is replaced with DefaultValidationRules().
func NewValidator(rules *ValidationRules) Validator {
	if rules == nil {
		rules = DefaultValidationRules()
	}
	return &validator{rules: rules}
}

// ValidateSwapEvent validates a swap event. It returns nil if the event is
// acceptable, otherwise the error from the first failing check, in this order:
//
//   - event is nil
//   - event.Validate() fails (its error is returned unwrapped)
//   - RejectZeroAddresses: Token0 or Token1 is the zero address
//   - RejectZeroAmounts: all four in/out amounts are nil or zero
//   - any non-zero amount is below MinAmount or above MaxAmount (a nil bound
//     is not enforced)
//   - AllowedProtocols is non-empty and does not allow event.Protocol
//   - event.PoolAddress is in BlacklistedPools
//   - Token0 or Token1 is in BlacklistedTokens
//
// ctx is currently unused.
func (v *validator) ValidateSwapEvent(ctx context.Context, event *types.SwapEvent) error {
	if event == nil {
		return errors.New("event cannot be nil")
	}

	// First, run the event's built-in validation.
	if err := event.Validate(); err != nil {
		return err
	}

	if v.rules.RejectZeroAddresses {
		// Check each side separately so a zero Token1 is not reported as an
		// invalid Token0.
		if event.Token0 == (common.Address{}) {
			return types.ErrInvalidToken0Address
		}
		if event.Token1 == (common.Address{}) {
			return errors.New("invalid token1 address: zero address")
		}
	}

	if v.rules.RejectZeroAmounts &&
		isZero(event.Amount0In) && isZero(event.Amount1In) &&
		isZero(event.Amount0Out) && isZero(event.Amount1Out) {
		return types.ErrZeroAmounts
	}

	// Nil/zero amounts are exempt from the bounds check (see checkAmountBounds),
	// presumably because a swap normally leaves some in/out legs at zero.
	amounts := []*big.Int{event.Amount0In, event.Amount1In, event.Amount0Out, event.Amount1Out}
	for _, amount := range amounts {
		if err := v.checkAmountBounds(amount); err != nil {
			return err
		}
	}

	// An empty allow-list means every protocol is permitted.
	if len(v.rules.AllowedProtocols) > 0 && !v.rules.AllowedProtocols[event.Protocol] {
		return fmt.Errorf("protocol %s not allowed", event.Protocol)
	}

	if v.rules.BlacklistedPools[event.PoolAddress] {
		return fmt.Errorf("pool %s is blacklisted", event.PoolAddress.Hex())
	}

	if v.hasBlacklistedToken(event.Token0, event.Token1) {
		return errors.New("blacklisted token in swap")
	}

	return nil
}

// ValidatePoolInfo validates pool information. It returns nil if the pool is
// acceptable, otherwise the error from the first failing check, in this order:
//
//   - pool is nil
//   - pool.Validate() fails (its error is returned unwrapped)
//   - AllowedProtocols is non-empty and does not allow pool.Protocol
//   - pool.Address is in BlacklistedPools
//   - Token0 or Token1 is in BlacklistedTokens
//
// ctx is currently unused.
func (v *validator) ValidatePoolInfo(ctx context.Context, pool *types.PoolInfo) error {
	if pool == nil {
		return errors.New("pool cannot be nil")
	}

	// Run the pool's built-in validation.
	if err := pool.Validate(); err != nil {
		return err
	}

	// An empty allow-list means every protocol is permitted.
	if len(v.rules.AllowedProtocols) > 0 && !v.rules.AllowedProtocols[pool.Protocol] {
		return fmt.Errorf("protocol %s not allowed", pool.Protocol)
	}

	if v.rules.BlacklistedPools[pool.Address] {
		return fmt.Errorf("pool %s is blacklisted", pool.Address.Hex())
	}

	if v.hasBlacklistedToken(pool.Token0, pool.Token1) {
		return errors.New("blacklisted token in pool")
	}

	return nil
}

// FilterValid returns the events from events that pass ValidateSwapEvent, in
// their original order. Rejected events (including nil entries) are dropped
// without reporting why. The result is nil when no event passes.
func (v *validator) FilterValid(ctx context.Context, events []*types.SwapEvent) []*types.SwapEvent {
	var valid []*types.SwapEvent
	for _, event := range events {
		if err := v.ValidateSwapEvent(ctx, event); err == nil {
			valid = append(valid, event)
		}
	}
	return valid
}

// GetValidationRules returns the rules this validator enforces. The returned
// pointer is the validator's own rules value, not a copy, so changes made
// through it affect subsequent validations.
func (v *validator) GetValidationRules() *ValidationRules {
	return v.rules
}

// checkAmountBounds returns an error if a non-zero amount lies outside the
// configured MinAmount/MaxAmount range. Nil or zero amounts always pass, and a
// nil bound is not enforced.
func (v *validator) checkAmountBounds(amount *big.Int) error {
	if isZero(amount) {
		return nil
	}
	if lo := v.rules.MinAmount; lo != nil && amount.Cmp(lo) < 0 {
		return fmt.Errorf("amount %s below minimum %s", amount.String(), lo.String())
	}
	if hi := v.rules.MaxAmount; hi != nil && amount.Cmp(hi) > 0 {
		return fmt.Errorf("amount %s exceeds maximum %s", amount.String(), hi.String())
	}
	return nil
}

// hasBlacklistedToken reports whether any of tokens is in BlacklistedTokens.
func (v *validator) hasBlacklistedToken(tokens ...common.Address) bool {
	for _, token := range tokens {
		if v.rules.BlacklistedTokens[token] {
			return true
		}
	}
	return false
}

// isZero reports whether n is nil or equal to zero.
func isZero(n *big.Int) bool {
	return n == nil || n.Sign() == 0
}