235 lines
6.8 KiB
Go
235 lines
6.8 KiB
Go
package security
|
|
|
|
import (
|
|
"errors"
|
|
"fmt"
|
|
"math"
|
|
"math/big"
|
|
)
|
|
|
|
// Sentinel errors produced by the helpers in this package. They are returned
// either directly or wrapped with %w, so callers should test for them with
// errors.Is rather than ==.
var (
	// ErrIntegerOverflow signals that a value or result is above an upper bound
	// (a fixed-width integer limit or a configured maximum).
	ErrIntegerOverflow = errors.New("integer overflow detected")

	// ErrIntegerUnderflow signals that a value or result would be negative
	// where only non-negative values are allowed.
	ErrIntegerUnderflow = errors.New("integer underflow detected")

	// ErrDivisionByZero signals an attempt to divide by zero.
	ErrDivisionByZero = errors.New("division by zero")

	// ErrInvalidConversion signals an unusable input, such as a nil *big.Int.
	ErrInvalidConversion = errors.New("invalid type conversion")
)
|
|
|
|
// SafeMath holds the upper limits that its arithmetic and validation methods
// enforce. Construct it with NewSafeMath. The methods pass these fields to
// (*big.Int).Cmp, so a zero-value SafeMath (nil limits) is not usable.
type SafeMath struct {
	// MaxGasPrice is the largest gas price, in wei, accepted by ValidateGasPrice.
	MaxGasPrice *big.Int

	// MaxTransactionValue bounds the results of SafeAdd and SafeMultiply and
	// the values accepted by ValidateTransactionValue.
	MaxTransactionValue *big.Int
}
|
|
|
|
// NewSafeMath creates a new SafeMath instance with security limits
|
|
func NewSafeMath() *SafeMath {
|
|
// 10000 Gwei max gas price
|
|
maxGasPrice := new(big.Int).Mul(big.NewInt(10000), big.NewInt(1e9))
|
|
// 10000 ETH max transaction value
|
|
maxTxValue := new(big.Int).Mul(big.NewInt(10000), big.NewInt(1e18))
|
|
|
|
return &SafeMath{
|
|
MaxGasPrice: maxGasPrice,
|
|
MaxTransactionValue: maxTxValue,
|
|
}
|
|
}
|
|
|
|
// SafeUint8 safely converts uint64 to uint8 with overflow check
|
|
func SafeUint8(val uint64) (uint8, error) {
|
|
if val > math.MaxUint8 {
|
|
return 0, fmt.Errorf("%w: value %d exceeds uint8 max %d", ErrIntegerOverflow, val, math.MaxUint8)
|
|
}
|
|
return uint8(val), nil
|
|
}
|
|
|
|
// SafeUint32 safely converts uint64 to uint32 with overflow check
|
|
func SafeUint32(val uint64) (uint32, error) {
|
|
if val > math.MaxUint32 {
|
|
return 0, fmt.Errorf("%w: value %d exceeds uint32 max %d", ErrIntegerOverflow, val, math.MaxUint32)
|
|
}
|
|
return uint32(val), nil
|
|
}
|
|
|
|
// SafeUint64FromBigInt safely converts big.Int to uint64
|
|
func SafeUint64FromBigInt(val *big.Int) (uint64, error) {
|
|
if val == nil {
|
|
return 0, fmt.Errorf("%w: nil value", ErrInvalidConversion)
|
|
}
|
|
if val.Sign() < 0 {
|
|
return 0, fmt.Errorf("%w: negative value %s", ErrIntegerUnderflow, val.String())
|
|
}
|
|
if val.BitLen() > 64 {
|
|
return 0, fmt.Errorf("%w: value %s exceeds uint64 max", ErrIntegerOverflow, val.String())
|
|
}
|
|
return val.Uint64(), nil
|
|
}
|
|
|
|
// SafeAdd performs safe addition with overflow check
|
|
func (sm *SafeMath) SafeAdd(a, b *big.Int) (*big.Int, error) {
|
|
if a == nil || b == nil {
|
|
return nil, fmt.Errorf("%w: nil operand", ErrInvalidConversion)
|
|
}
|
|
|
|
result := new(big.Int).Add(a, b)
|
|
|
|
// Check against maximum transaction value
|
|
if result.Cmp(sm.MaxTransactionValue) > 0 {
|
|
return nil, fmt.Errorf("%w: sum exceeds max transaction value", ErrIntegerOverflow)
|
|
}
|
|
|
|
return result, nil
|
|
}
|
|
|
|
// SafeSubtract performs safe subtraction with underflow check
|
|
func (sm *SafeMath) SafeSubtract(a, b *big.Int) (*big.Int, error) {
|
|
if a == nil || b == nil {
|
|
return nil, fmt.Errorf("%w: nil operand", ErrInvalidConversion)
|
|
}
|
|
|
|
result := new(big.Int).Sub(a, b)
|
|
|
|
// Check for negative result (underflow)
|
|
if result.Sign() < 0 {
|
|
return nil, fmt.Errorf("%w: subtraction would result in negative value", ErrIntegerUnderflow)
|
|
}
|
|
|
|
return result, nil
|
|
}
|
|
|
|
// SafeMultiply performs safe multiplication with overflow check
|
|
func (sm *SafeMath) SafeMultiply(a, b *big.Int) (*big.Int, error) {
|
|
if a == nil || b == nil {
|
|
return nil, fmt.Errorf("%w: nil operand", ErrInvalidConversion)
|
|
}
|
|
|
|
// Check for zero to avoid unnecessary computation
|
|
if a.Sign() == 0 || b.Sign() == 0 {
|
|
return big.NewInt(0), nil
|
|
}
|
|
|
|
result := new(big.Int).Mul(a, b)
|
|
|
|
// Check against maximum transaction value
|
|
if result.Cmp(sm.MaxTransactionValue) > 0 {
|
|
return nil, fmt.Errorf("%w: product exceeds max transaction value", ErrIntegerOverflow)
|
|
}
|
|
|
|
return result, nil
|
|
}
|
|
|
|
// SafeDivide performs safe division with zero check
|
|
func (sm *SafeMath) SafeDivide(a, b *big.Int) (*big.Int, error) {
|
|
if a == nil || b == nil {
|
|
return nil, fmt.Errorf("%w: nil operand", ErrInvalidConversion)
|
|
}
|
|
|
|
if b.Sign() == 0 {
|
|
return nil, ErrDivisionByZero
|
|
}
|
|
|
|
return new(big.Int).Div(a, b), nil
|
|
}
|
|
|
|
// SafePercent calculates percentage safely (value * percent / 100)
|
|
func (sm *SafeMath) SafePercent(value *big.Int, percent uint64) (*big.Int, error) {
|
|
if value == nil {
|
|
return nil, fmt.Errorf("%w: nil value", ErrInvalidConversion)
|
|
}
|
|
|
|
if percent > 10000 { // Max 100.00% with 2 decimal precision
|
|
return nil, fmt.Errorf("%w: percent %d exceeds 10000 (100%%)", ErrIntegerOverflow, percent)
|
|
}
|
|
|
|
percentBig := big.NewInt(int64(percent))
|
|
hundred := big.NewInt(100)
|
|
|
|
temp := new(big.Int).Mul(value, percentBig)
|
|
result := new(big.Int).Div(temp, hundred)
|
|
|
|
return result, nil
|
|
}
|
|
|
|
// ValidateGasPrice ensures gas price is within safe bounds
|
|
func (sm *SafeMath) ValidateGasPrice(gasPrice *big.Int) error {
|
|
if gasPrice == nil {
|
|
return fmt.Errorf("gas price cannot be nil")
|
|
}
|
|
|
|
if gasPrice.Sign() < 0 {
|
|
return fmt.Errorf("gas price cannot be negative")
|
|
}
|
|
|
|
if gasPrice.Cmp(sm.MaxGasPrice) > 0 {
|
|
return fmt.Errorf("gas price %s exceeds maximum %s", gasPrice.String(), sm.MaxGasPrice.String())
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// ValidateTransactionValue ensures transaction value is within safe bounds
|
|
func (sm *SafeMath) ValidateTransactionValue(value *big.Int) error {
|
|
if value == nil {
|
|
return fmt.Errorf("transaction value cannot be nil")
|
|
}
|
|
|
|
if value.Sign() < 0 {
|
|
return fmt.Errorf("transaction value cannot be negative")
|
|
}
|
|
|
|
if value.Cmp(sm.MaxTransactionValue) > 0 {
|
|
return fmt.Errorf("transaction value %s exceeds maximum %s", value.String(), sm.MaxTransactionValue.String())
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// CalculateMinimumProfit calculates minimum profit required for a trade
|
|
func (sm *SafeMath) CalculateMinimumProfit(gasPrice, gasLimit *big.Int) (*big.Int, error) {
|
|
if err := sm.ValidateGasPrice(gasPrice); err != nil {
|
|
return nil, fmt.Errorf("invalid gas price: %w", err)
|
|
}
|
|
|
|
// Calculate gas cost
|
|
gasCost, err := sm.SafeMultiply(gasPrice, gasLimit)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to calculate gas cost: %w", err)
|
|
}
|
|
|
|
// Add 20% buffer for safety
|
|
buffer, err := sm.SafePercent(gasCost, 120)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to calculate buffer: %w", err)
|
|
}
|
|
|
|
return buffer, nil
|
|
}
|
|
|
|
// SafeSlippage calculates safe slippage amount
|
|
func (sm *SafeMath) SafeSlippage(amount *big.Int, slippageBps uint64) (*big.Int, error) {
|
|
if amount == nil {
|
|
return nil, fmt.Errorf("%w: nil amount", ErrInvalidConversion)
|
|
}
|
|
|
|
// Slippage in basis points (1 bp = 0.01%)
|
|
if slippageBps > 10000 { // Max 100%
|
|
return nil, fmt.Errorf("%w: slippage %d bps exceeds maximum", ErrIntegerOverflow, slippageBps)
|
|
}
|
|
|
|
// Calculate slippage amount
|
|
slippageAmount := new(big.Int).Mul(amount, big.NewInt(int64(slippageBps)))
|
|
slippageAmount.Div(slippageAmount, big.NewInt(10000))
|
|
|
|
// Calculate amount after slippage
|
|
result := new(big.Int).Sub(amount, slippageAmount)
|
|
if result.Sign() < 0 {
|
|
return nil, fmt.Errorf("%w: slippage exceeds amount", ErrIntegerUnderflow)
|
|
}
|
|
|
|
return result, nil
|
|
}
|