Files
mev-beta/pkg/security/safemath.go
2025-10-04 09:31:02 -05:00

235 lines
6.8 KiB
Go

package security
import (
"errors"
"fmt"
"math"
"math/big"
)
// Sentinel errors shared by the helpers in this package. The helpers usually
// wrap them with extra context via %w, so callers should match with errors.Is.
var (
	// ErrIntegerOverflow reports a value that is too large for its target
	// type or that exceeds a configured upper limit.
	ErrIntegerOverflow = errors.New("integer overflow detected")

	// ErrIntegerUnderflow reports a negative value where only non-negative
	// values are accepted.
	ErrIntegerUnderflow = errors.New("integer underflow detected")

	// ErrDivisionByZero reports an attempted division by a zero divisor.
	ErrDivisionByZero = errors.New("division by zero")

	// ErrInvalidConversion reports an input that cannot be used at all, such
	// as a nil operand.
	ErrInvalidConversion = errors.New("invalid type conversion")
)
// SafeMath bundles the upper limits that the checked arithmetic and validation
// methods compare against.
type SafeMath struct {
	// MaxGasPrice is the largest gas price, in wei, that ValidateGasPrice
	// accepts.
	MaxGasPrice *big.Int

	// MaxTransactionValue is the largest transaction value accepted by
	// ValidateTransactionValue; SafeAdd and SafeMultiply also reject results
	// above it.
	MaxTransactionValue *big.Int
}
// NewSafeMath returns a SafeMath preloaded with the default limits:
// a maximum gas price of 10000 Gwei and a maximum transaction value of
// 10000 ETH, both expressed in wei. Every call returns freshly allocated
// limits.
func NewSafeMath() *SafeMath {
	var (
		limit       = big.NewInt(10000)
		weiPerGwei  = big.NewInt(1e9)
		weiPerEther = big.NewInt(1e18)
	)

	sm := new(SafeMath)
	sm.MaxGasPrice = new(big.Int).Mul(limit, weiPerGwei)
	sm.MaxTransactionValue = new(big.Int).Mul(limit, weiPerEther)
	return sm
}
// SafeUint8 narrows val to uint8.
//
// When val is larger than math.MaxUint8 it returns 0 and an error that wraps
// ErrIntegerOverflow.
func SafeUint8(val uint64) (uint8, error) {
	const limit = math.MaxUint8

	if val <= limit {
		return uint8(val), nil
	}
	return 0, fmt.Errorf("%w: value %d exceeds uint8 max %d", ErrIntegerOverflow, val, limit)
}
// SafeUint32 narrows val to uint32.
//
// When val is larger than math.MaxUint32 it returns 0 and an error that wraps
// ErrIntegerOverflow.
func SafeUint32(val uint64) (uint32, error) {
	// The limit carries an explicit 64-bit type. Passed to fmt.Errorf as an
	// untyped constant it would default to int, which cannot hold
	// math.MaxUint32 where int is 32 bits wide, so the file would not compile
	// for those targets.
	const maxUint32 uint64 = math.MaxUint32

	if val > maxUint32 {
		return 0, fmt.Errorf("%w: value %d exceeds uint32 max %d", ErrIntegerOverflow, val, maxUint32)
	}
	return uint32(val), nil
}
// SafeUint64FromBigInt converts val to uint64 without truncation.
//
// It returns 0 and a wrapped sentinel error when the conversion is not exact:
// ErrInvalidConversion for a nil val, ErrIntegerUnderflow for a negative val,
// and ErrIntegerOverflow for a val that needs more than 64 bits.
func SafeUint64FromBigInt(val *big.Int) (uint64, error) {
	switch {
	case val == nil:
		return 0, fmt.Errorf("%w: nil value", ErrInvalidConversion)
	case val.Sign() < 0:
		return 0, fmt.Errorf("%w: negative value %s", ErrIntegerUnderflow, val.String())
	case !val.IsUint64():
		// val is non-negative here, so failing IsUint64 means it is too wide.
		return 0, fmt.Errorf("%w: value %s exceeds uint64 max", ErrIntegerOverflow, val.String())
	}
	return val.Uint64(), nil
}
// SafeAdd returns a + b as a new big.Int, leaving both operands untouched.
//
// A nil operand yields an error wrapping ErrInvalidConversion. A sum greater
// than sm.MaxTransactionValue yields an error wrapping ErrIntegerOverflow.
// Only this upper bound is checked.
func (sm *SafeMath) SafeAdd(a, b *big.Int) (*big.Int, error) {
	if a == nil || b == nil {
		return nil, fmt.Errorf("%w: nil operand", ErrInvalidConversion)
	}

	sum := new(big.Int)
	sum.Add(a, b)

	// Reject sums above the configured transaction ceiling.
	if limit := sm.MaxTransactionValue; sum.Cmp(limit) > 0 {
		return nil, fmt.Errorf("%w: sum exceeds max transaction value", ErrIntegerOverflow)
	}
	return sum, nil
}
// SafeSubtract returns a - b as a new big.Int, refusing negative results.
//
// A nil operand yields an error wrapping ErrInvalidConversion. When a is
// smaller than b the difference would be negative, and an error wrapping
// ErrIntegerUnderflow is returned instead.
func (sm *SafeMath) SafeSubtract(a, b *big.Int) (*big.Int, error) {
	if a == nil || b == nil {
		return nil, fmt.Errorf("%w: nil operand", ErrInvalidConversion)
	}

	// a - b is negative exactly when a < b, so compare before subtracting.
	if a.Cmp(b) < 0 {
		return nil, fmt.Errorf("%w: subtraction would result in negative value", ErrIntegerUnderflow)
	}
	return new(big.Int).Sub(a, b), nil
}
// SafeMultiply returns a * b as a new big.Int, leaving both operands untouched.
//
// A nil operand yields an error wrapping ErrInvalidConversion. If either
// operand is zero the result is a fresh zero and the limit is not consulted.
// Otherwise a product greater than sm.MaxTransactionValue yields an error
// wrapping ErrIntegerOverflow. Only this upper bound is checked.
func (sm *SafeMath) SafeMultiply(a, b *big.Int) (*big.Int, error) {
	switch {
	case a == nil, b == nil:
		return nil, fmt.Errorf("%w: nil operand", ErrInvalidConversion)
	case a.Sign() == 0, b.Sign() == 0:
		// Anything times zero is zero; skip the multiplication entirely.
		return new(big.Int), nil
	}

	product := new(big.Int).Mul(a, b)
	if product.Cmp(sm.MaxTransactionValue) > 0 {
		return nil, fmt.Errorf("%w: product exceeds max transaction value", ErrIntegerOverflow)
	}
	return product, nil
}
// SafeDivide returns the quotient a / b as a new big.Int, computed with
// big.Int.Div.
//
// A nil operand yields an error wrapping ErrInvalidConversion; a zero divisor
// yields ErrDivisionByZero itself (not wrapped).
func (sm *SafeMath) SafeDivide(a, b *big.Int) (*big.Int, error) {
	switch {
	case a == nil, b == nil:
		return nil, fmt.Errorf("%w: nil operand", ErrInvalidConversion)
	case b.Sign() == 0:
		return nil, ErrDivisionByZero
	}

	var quotient big.Int
	return quotient.Div(a, b), nil
}
// SafePercent returns value * percent / 100 as a new big.Int, using
// big.Int.Div for the division. value is not modified.
//
// percent is expressed in whole percent units: 120 means 120%, so the largest
// accepted input, 10000, means 10000% (a 100x multiplier). It is not a
// basis-point or two-decimal quantity.
//
// A nil value yields an error wrapping ErrInvalidConversion; a percent above
// 10000 yields an error wrapping ErrIntegerOverflow.
func (sm *SafeMath) SafePercent(value *big.Int, percent uint64) (*big.Int, error) {
	// Largest accepted percent, in whole percent units.
	const maxPercent = 10000

	if value == nil {
		return nil, fmt.Errorf("%w: nil value", ErrInvalidConversion)
	}
	if percent > maxPercent {
		// The previous message described this limit as "100%", which
		// contradicted the division by 100 below.
		return nil, fmt.Errorf("%w: percent %d exceeds maximum of %d%%", ErrIntegerOverflow, percent, maxPercent)
	}

	result := new(big.Int).SetUint64(percent)
	result.Mul(result, value)
	result.Div(result, big.NewInt(100))
	return result, nil
}
// ValidateGasPrice reports whether gasPrice is usable: non-nil, non-negative
// and no greater than sm.MaxGasPrice. It returns nil when all three hold and a
// descriptive error for the first check that fails.
func (sm *SafeMath) ValidateGasPrice(gasPrice *big.Int) error {
	switch {
	case gasPrice == nil:
		return fmt.Errorf("gas price cannot be nil")
	case gasPrice.Sign() < 0:
		return fmt.Errorf("gas price cannot be negative")
	case gasPrice.Cmp(sm.MaxGasPrice) > 0:
		return fmt.Errorf("gas price %s exceeds maximum %s", gasPrice.String(), sm.MaxGasPrice.String())
	default:
		return nil
	}
}
// ValidateTransactionValue reports whether value is usable: non-nil,
// non-negative and no greater than sm.MaxTransactionValue. It returns nil when
// all three hold and a descriptive error for the first check that fails.
func (sm *SafeMath) ValidateTransactionValue(value *big.Int) error {
	switch {
	case value == nil:
		return fmt.Errorf("transaction value cannot be nil")
	case value.Sign() < 0:
		return fmt.Errorf("transaction value cannot be negative")
	case value.Cmp(sm.MaxTransactionValue) > 0:
		return fmt.Errorf("transaction value %s exceeds maximum %s", value.String(), sm.MaxTransactionValue.String())
	default:
		return nil
	}
}
// CalculateMinimumProfit returns the gas cost of a trade plus a 20% safety
// margin, i.e. 120% of gasPrice * gasLimit.
//
// gasPrice must pass ValidateGasPrice. gasLimit must be non-nil and
// non-negative: a negative limit is rejected with an error wrapping
// ErrIntegerUnderflow, because it would otherwise produce a negative
// "minimum" profit. The gas cost is computed with SafeMultiply and the margin
// with SafePercent; their errors are returned wrapped with context.
func (sm *SafeMath) CalculateMinimumProfit(gasPrice, gasLimit *big.Int) (*big.Int, error) {
	if err := sm.ValidateGasPrice(gasPrice); err != nil {
		return nil, fmt.Errorf("invalid gas price: %w", err)
	}

	// A nil gasLimit is left for SafeMultiply to report so that case keeps
	// its existing error.
	if gasLimit != nil && gasLimit.Sign() < 0 {
		return nil, fmt.Errorf("invalid gas limit: %w: negative value %s", ErrIntegerUnderflow, gasLimit.String())
	}

	gasCost, err := sm.SafeMultiply(gasPrice, gasLimit)
	if err != nil {
		return nil, fmt.Errorf("failed to calculate gas cost: %w", err)
	}

	// 120% of the gas cost: the cost itself plus a 20% buffer.
	minProfit, err := sm.SafePercent(gasCost, 120)
	if err != nil {
		return nil, fmt.Errorf("failed to calculate buffer: %w", err)
	}
	return minProfit, nil
}
// SafeSlippage returns amount reduced by slippageBps basis points, as a new
// big.Int: amount - amount*slippageBps/10000, with the division done by
// big.Int.Div. amount is not modified.
//
// A nil amount yields an error wrapping ErrInvalidConversion, a slippageBps
// above 10000 yields an error wrapping ErrIntegerOverflow, and a negative
// result yields an error wrapping ErrIntegerUnderflow.
func (sm *SafeMath) SafeSlippage(amount *big.Int, slippageBps uint64) (*big.Int, error) {
	// Basis points in 100% (1 bp = 0.01%).
	const bpsDenominator = 10000

	switch {
	case amount == nil:
		return nil, fmt.Errorf("%w: nil amount", ErrInvalidConversion)
	case slippageBps > bpsDenominator:
		return nil, fmt.Errorf("%w: slippage %d bps exceeds maximum", ErrIntegerOverflow, slippageBps)
	}

	// Portion of amount given up to slippage.
	deduction := big.NewInt(int64(slippageBps))
	deduction.Mul(amount, deduction)
	deduction.Div(deduction, big.NewInt(bpsDenominator))

	remaining := new(big.Int).Sub(amount, deduction)
	if remaining.Sign() < 0 {
		return nil, fmt.Errorf("%w: slippage exceeds amount", ErrIntegerUnderflow)
	}
	return remaining, nil
}