package security

import (
	"errors"
	"fmt"
	"math"
	"math/big"
)

var (
	// ErrIntegerOverflow indicates an integer overflow would occur.
	ErrIntegerOverflow = errors.New("integer overflow detected")
	// ErrIntegerUnderflow indicates an integer underflow would occur.
	ErrIntegerUnderflow = errors.New("integer underflow detected")
	// ErrDivisionByZero indicates division by zero was attempted.
	ErrDivisionByZero = errors.New("division by zero")
	// ErrInvalidConversion indicates an invalid type conversion. It is also
	// used to report nil operands.
	ErrInvalidConversion = errors.New("invalid type conversion")
)

// SafeMath provides safe mathematical operations with overflow protection.
//
// Construct it with NewSafeMath. Methods that check results against
// MaxTransactionValue or MaxGasPrice dereference those fields, so both must
// be non-nil.
type SafeMath struct {
	// MaxGasPrice is the maximum allowed gas price in wei.
	MaxGasPrice *big.Int
	// MaxTransactionValue is the maximum allowed transaction value.
	MaxTransactionValue *big.Int
}

// NewSafeMath creates a SafeMath with the default security limits: a maximum
// gas price of 10000 gwei and a maximum transaction value of 10000 ETH, both
// expressed in wei.
func NewSafeMath() *SafeMath {
	// 10000 gwei max gas price (1 gwei = 1e9 wei).
	maxGasPrice := new(big.Int).Mul(big.NewInt(10000), big.NewInt(1e9))
	// 10000 ETH max transaction value (1 ETH = 1e18 wei).
	maxTxValue := new(big.Int).Mul(big.NewInt(10000), big.NewInt(1e18))
	return &SafeMath{
		MaxGasPrice:         maxGasPrice,
		MaxTransactionValue: maxTxValue,
	}
}

// SafeUint8 converts val to uint8, returning ErrIntegerOverflow if val does
// not fit.
func SafeUint8(val uint64) (uint8, error) {
	if val > math.MaxUint8 {
		// Explicit uint64 so the variadic argument is never an untyped int.
		return 0, fmt.Errorf("%w: value %d exceeds uint8 max %d", ErrIntegerOverflow, val, uint64(math.MaxUint8))
	}
	return uint8(val), nil
}

// SafeUint32 converts val to uint32, returning ErrIntegerOverflow if val does
// not fit.
func SafeUint32(val uint64) (uint32, error) {
	if val > math.MaxUint32 {
		// math.MaxUint32 passed bare to a variadic ...any defaults to int and
		// overflows it on 32-bit platforms (compile error); keep it uint64.
		return 0, fmt.Errorf("%w: value %d exceeds uint32 max %d", ErrIntegerOverflow, val, uint64(math.MaxUint32))
	}
	return uint32(val), nil
}

// SafeUint64FromBigInt converts val to uint64. It returns ErrInvalidConversion
// for a nil val, ErrIntegerUnderflow for a negative val, and
// ErrIntegerOverflow when val needs more than 64 bits.
func SafeUint64FromBigInt(val *big.Int) (uint64, error) {
	if val == nil {
		return 0, fmt.Errorf("%w: nil value", ErrInvalidConversion)
	}
	if val.Sign() < 0 {
		return 0, fmt.Errorf("%w: negative value %s", ErrIntegerUnderflow, val.String())
	}
	if val.BitLen() > 64 {
		return 0, fmt.Errorf("%w: value %s exceeds uint64 max", ErrIntegerOverflow, val.String())
	}
	return val.Uint64(), nil
}

// SafeAdd returns a new big.Int holding a+b.
//
// It returns ErrInvalidConversion if either operand is nil and
// ErrIntegerOverflow if the sum exceeds sm.MaxTransactionValue. Only the
// upper bound is checked; negative operands and sums are not rejected.
func (sm *SafeMath) SafeAdd(a, b *big.Int) (*big.Int, error) {
	if a == nil || b == nil {
		return nil, fmt.Errorf("%w: nil operand", ErrInvalidConversion)
	}
	result := new(big.Int).Add(a, b)
	if result.Cmp(sm.MaxTransactionValue) > 0 {
		return nil, fmt.Errorf("%w: sum exceeds max transaction value", ErrIntegerOverflow)
	}
	return result, nil
}

// SafeSubtract returns a new big.Int holding a-b.
//
// It returns ErrInvalidConversion if either operand is nil and
// ErrIntegerUnderflow if the difference would be negative.
func (sm *SafeMath) SafeSubtract(a, b *big.Int) (*big.Int, error) {
	if a == nil || b == nil {
		return nil, fmt.Errorf("%w: nil operand", ErrInvalidConversion)
	}
	result := new(big.Int).Sub(a, b)
	if result.Sign() < 0 {
		return nil, fmt.Errorf("%w: subtraction would result in negative value", ErrIntegerUnderflow)
	}
	return result, nil
}

// SafeMultiply returns a new big.Int holding a*b.
//
// It returns ErrInvalidConversion if either operand is nil and
// ErrIntegerOverflow if the product exceeds sm.MaxTransactionValue. Only the
// upper bound is checked; a negative product is not rejected.
func (sm *SafeMath) SafeMultiply(a, b *big.Int) (*big.Int, error) {
	if a == nil || b == nil {
		return nil, fmt.Errorf("%w: nil operand", ErrInvalidConversion)
	}
	// A zero factor makes the product zero; skip the multiplication.
	if a.Sign() == 0 || b.Sign() == 0 {
		return big.NewInt(0), nil
	}
	result := new(big.Int).Mul(a, b)
	if result.Cmp(sm.MaxTransactionValue) > 0 {
		return nil, fmt.Errorf("%w: product exceeds max transaction value", ErrIntegerOverflow)
	}
	return result, nil
}

// SafeDivide returns a new big.Int holding a/b.
//
// It returns ErrInvalidConversion if either operand is nil and
// ErrDivisionByZero if b is zero. The quotient comes from big.Int.Div, which
// implements Euclidean division. This differs from truncated division when
// operands are negative: for example, -7/2 yields -4, not -3.
func (sm *SafeMath) SafeDivide(a, b *big.Int) (*big.Int, error) {
	if a == nil || b == nil {
		return nil, fmt.Errorf("%w: nil operand", ErrInvalidConversion)
	}
	if b.Sign() == 0 {
		return nil, ErrDivisionByZero
	}
	return new(big.Int).Div(a, b), nil
}

// SafePercent returns value * percent / 100.
//
// percent is a whole-number percentage, so 120 yields 1.2x value. It is a
// plain percentage, not a basis-point scale. percent may be at most 10000,
// which caps the result at 100x value; larger percents return
// ErrIntegerOverflow. A nil value returns ErrInvalidConversion. The division
// uses big.Int.Div (Euclidean), so for non-negative value the result rounds
// down.
func (sm *SafeMath) SafePercent(value *big.Int, percent uint64) (*big.Int, error) {
	if value == nil {
		return nil, fmt.Errorf("%w: nil value", ErrInvalidConversion)
	}
	// With a divisor of 100, this cap means 10000% (100x), not 100.00%.
	if percent > 10000 {
		return nil, fmt.Errorf("%w: percent %d exceeds maximum 10000 (100x)", ErrIntegerOverflow, percent)
	}
	result := new(big.Int).Mul(value, new(big.Int).SetUint64(percent))
	return result.Div(result, big.NewInt(100)), nil
}

// ValidateGasPrice returns an error if gasPrice is nil, negative, or greater
// than sm.MaxGasPrice.
func (sm *SafeMath) ValidateGasPrice(gasPrice *big.Int) error {
	switch {
	case gasPrice == nil:
		return errors.New("gas price cannot be nil")
	case gasPrice.Sign() < 0:
		return errors.New("gas price cannot be negative")
	case gasPrice.Cmp(sm.MaxGasPrice) > 0:
		return fmt.Errorf("gas price %s exceeds maximum %s", gasPrice.String(), sm.MaxGasPrice.String())
	}
	return nil
}

// ValidateTransactionValue returns an error if value is nil, negative, or
// greater than sm.MaxTransactionValue.
func (sm *SafeMath) ValidateTransactionValue(value *big.Int) error {
	switch {
	case value == nil:
		return errors.New("transaction value cannot be nil")
	case value.Sign() < 0:
		return errors.New("transaction value cannot be negative")
	case value.Cmp(sm.MaxTransactionValue) > 0:
		return fmt.Errorf("transaction value %s exceeds maximum %s", value.String(), sm.MaxTransactionValue.String())
	}
	return nil
}

// CalculateMinimumProfit returns the minimum profit required for a trade: the
// gas cost gasPrice*gasLimit plus a 20% safety buffer, i.e.
// gasPrice * gasLimit * 120 / 100.
//
// gasPrice must pass ValidateGasPrice. gasLimit must be non-nil and
// non-negative, and the gas cost must not exceed sm.MaxTransactionValue.
func (sm *SafeMath) CalculateMinimumProfit(gasPrice, gasLimit *big.Int) (*big.Int, error) {
	if err := sm.ValidateGasPrice(gasPrice); err != nil {
		return nil, fmt.Errorf("invalid gas price: %w", err)
	}
	// SafeMultiply only bounds the product from above, so a negative limit
	// would produce a negative "minimum profit". A nil limit is left to
	// SafeMultiply, which reports it as ErrInvalidConversion.
	if gasLimit != nil && gasLimit.Sign() < 0 {
		return nil, fmt.Errorf("invalid gas limit: %w: negative value %s", ErrIntegerUnderflow, gasLimit.String())
	}

	gasCost, err := sm.SafeMultiply(gasPrice, gasLimit)
	if err != nil {
		return nil, fmt.Errorf("failed to calculate gas cost: %w", err)
	}

	// 120% of the gas cost is the cost itself plus a 20% buffer.
	withBuffer, err := sm.SafePercent(gasCost, 120)
	if err != nil {
		return nil, fmt.Errorf("failed to calculate buffer: %w", err)
	}
	return withBuffer, nil
}

// SafeSlippage returns amount reduced by slippageBps basis points
// (1 bp = 0.01%): amount - floor(amount*slippageBps/10000).
//
// It returns ErrInvalidConversion for a nil amount. It returns
// ErrIntegerOverflow when slippageBps exceeds 10000 (100%) and
// ErrIntegerUnderflow when amount is negative. For valid inputs the result
// lies in [0, amount].
func (sm *SafeMath) SafeSlippage(amount *big.Int, slippageBps uint64) (*big.Int, error) {
	if amount == nil {
		return nil, fmt.Errorf("%w: nil amount", ErrInvalidConversion)
	}
	if slippageBps > 10000 { // Max 100%
		return nil, fmt.Errorf("%w: slippage %d bps exceeds maximum", ErrIntegerOverflow, slippageBps)
	}
	// Negative amounts previously produced nonsensical results (e.g. -1 at
	// 1 bp returned 0) because Div rounds toward negative infinity here.
	if amount.Sign() < 0 {
		return nil, fmt.Errorf("%w: negative amount %s", ErrIntegerUnderflow, amount.String())
	}

	slippageAmount := new(big.Int).Mul(amount, new(big.Int).SetUint64(slippageBps))
	slippageAmount.Div(slippageAmount, big.NewInt(10000))

	// amount >= 0 and slippageBps <= 10000 guarantee slippageAmount <= amount,
	// so the difference cannot be negative.
	return new(big.Int).Sub(amount, slippageAmount), nil
}