Files
mev-beta/pkg/math/decimal_handler.go
Krypto Kajun 850223a953 fix(multicall): resolve critical multicall parsing corruption issues
- Added comprehensive bounds checking to prevent buffer overruns in multicall parsing
- Implemented graduated validation system (Strict/Moderate/Permissive) to reduce false positives
- Added LRU caching system for address validation with 10-minute TTL
- Enhanced ABI decoder with missing Universal Router and Arbitrum-specific DEX signatures
- Fixed duplicate function declarations and import conflicts across multiple files
- Added error recovery mechanisms with multiple fallback strategies
- Updated tests to handle new validation behavior for suspicious addresses
- Fixed parser test expectations for improved validation system
- Applied gofmt formatting fixes to ensure code style compliance
- Fixed mutex copying issues in monitoring package by introducing MetricsSnapshot
- Resolved critical security vulnerabilities in heuristic address extraction
- Progress: Updated TODO audit from 10% to 35% complete

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
2025-10-17 00:12:55 -05:00

460 lines
16 KiB
Go

package math
import (
"fmt"
"math/big"
"strings"
)
// UniversalDecimal represents a token amount with precise decimal handling.
//
// The amount is stored as an integer count of the token's smallest unit
// (Value) together with the number of decimal places (Decimals), so the
// human-readable amount is Value / 10^Decimals.
//
// Values created through NewUniversalDecimal are validated (Decimals <= 18,
// non-empty Symbol, |Value| <= 10^60) and own a private copy of Value; values
// built with a struct literal bypass those checks.
type UniversalDecimal struct {
	Value    *big.Int // Raw value in smallest unit
	Decimals uint8    // Number of decimal places (0-18)
	Symbol   string   // Token symbol for debugging
}
// DecimalConverter handles conversions and arithmetic between UniversalDecimal
// values that may use different decimal precisions.
//
// Construct it with NewDecimalConverter, which pre-populates the powers of ten
// for 0-18 decimal places.
//
// NOTE(review): nothing synchronizes access to scalingFactors; confirm that no
// code path mutates it after construction before sharing one converter across
// goroutines.
type DecimalConverter struct {
	// Cache for common scaling factors to avoid repeated calculations:
	// maps a decimal count d to 10^d.
	scalingFactors map[uint8]*big.Int
}
// NewDecimalConverter returns a converter whose cache already holds 10^d for
// every d in 0..18, the decimal range accepted by NewUniversalDecimal.
func NewDecimalConverter() *DecimalConverter {
	const maxCachedDecimals = 18

	factors := make(map[uint8]*big.Int, maxCachedDecimals+1)
	power := big.NewInt(1) // running 10^d
	ten := big.NewInt(10)
	for d := 0; d <= maxCachedDecimals; d++ {
		// Store a distinct copy so later multiplications don't alias entries.
		factors[uint8(d)] = new(big.Int).Set(power)
		power.Mul(power, ten)
	}

	return &DecimalConverter{scalingFactors: factors}
}
// maxUniversalDecimalAbsValue is the largest magnitude (10^60) accepted by
// NewUniversalDecimal. It is computed once because NewUniversalDecimal sits on
// every arithmetic path; it must never be mutated.
var maxUniversalDecimalAbsValue = new(big.Int).Exp(big.NewInt(10), big.NewInt(60), nil)

// NewUniversalDecimal creates a validated UniversalDecimal.
//
// It returns an error when decimals exceeds 18, when symbol is empty, or when
// |value| exceeds 10^60. A nil value is treated as zero. The result holds its
// own copy of value, so later changes to the caller's *big.Int do not affect it.
func NewUniversalDecimal(value *big.Int, decimals uint8, symbol string) (*UniversalDecimal, error) {
	if decimals > 18 {
		return nil, fmt.Errorf("decimal places cannot exceed 18, got %d for token %s", decimals, symbol)
	}
	if symbol == "" {
		return nil, fmt.Errorf("symbol cannot be empty")
	}

	if value == nil {
		value = new(big.Int)
	}

	// big.Int cannot overflow; the bound (on the order of 2^256 / 10^18) keeps
	// amounts within what fixed-width financial calculations can represent.
	// CmpAbs compares magnitudes without allocating an Abs copy.
	if value.CmpAbs(maxUniversalDecimalAbsValue) > 0 {
		return nil, fmt.Errorf("value %s exceeds maximum safe value for token %s", value.String(), symbol)
	}

	return &UniversalDecimal{
		Value:    new(big.Int).Set(value), // defensive copy
		Decimals: decimals,
		Symbol:   symbol,
	}, nil
}
// FromString parses valueStr into a UniversalDecimal with the given decimals
// and symbol. Surrounding whitespace is ignored, and "" or "0" yields zero.
//
// Strings containing '.' are always read as human-readable token units (see
// fromDecimalString). Plain integers are classified heuristically, so some
// inputs are ambiguous:
//   - values below 1000 (including every negative value) are human-readable
//     units and are multiplied by 10^decimals;
//   - otherwise, if decimals > 0 and the string is at least decimals
//     characters long, the value is taken as a raw smallest-unit amount;
//   - otherwise values above 1000 are raw and exactly 1000 is human-readable.
//
// For example, with 18 decimals "5" means 5 tokens (5e18 raw), while
// "1000000000000000000" means 1e18 raw, i.e. 1 token.
func (dc *DecimalConverter) FromString(valueStr string, decimals uint8, symbol string) (*UniversalDecimal, error) {
	// Trim first so whitespace-padded empty or zero input takes the zero path.
	valueStr = strings.TrimSpace(valueStr)
	if valueStr == "" || valueStr == "0" {
		return NewUniversalDecimal(big.NewInt(0), decimals, symbol)
	}

	if strings.Contains(valueStr, ".") {
		return dc.fromDecimalString(valueStr, decimals, symbol)
	}

	value, ok := new(big.Int).SetString(valueStr, 10)
	if !ok {
		return nil, fmt.Errorf("invalid number format: %s for token %s", valueStr, symbol)
	}

	// Compare as big.Int rather than via Int64(): Int64 is undefined for values
	// outside int64 (in practice it keeps the low 64 bits, so 2^64 reads as 0)
	// and would misclassify huge raw amounts as small human-readable ones.
	versus1000 := value.Cmp(big.NewInt(1000))

	var treatAsRaw bool
	switch {
	case versus1000 < 0:
		treatAsRaw = false
	case decimals > 0 && len(valueStr) >= int(decimals):
		treatAsRaw = true
	default:
		treatAsRaw = versus1000 > 0
	}

	if treatAsRaw {
		return NewUniversalDecimal(value, decimals, symbol)
	}
	scaledValue := new(big.Int).Mul(value, dc.getScalingFactor(decimals))
	return NewUniversalDecimal(scaledValue, decimals, symbol)
}
// fromDecimalString parses a human-readable decimal string such as "1.23",
// "-0.5", ".75" or "5." into the token's smallest unit.
//
// An optional leading '+' or '-' applies to the whole number. Apart from that,
// the integer and fractional parts may contain only ASCII digits, at least one
// of them must be non-empty, and the fractional part may not have more digits
// than decimals (input is never silently truncated).
func (dc *DecimalConverter) fromDecimalString(valueStr string, decimals uint8, symbol string) (*UniversalDecimal, error) {
	parts := strings.Split(valueStr, ".")
	if len(parts) != 2 {
		return nil, fmt.Errorf("invalid decimal format: %s for token %s", valueStr, symbol)
	}
	integerPart, decimalPart := parts[0], parts[1]

	// Validate decimal part doesn't exceed token decimals.
	if len(decimalPart) > int(decimals) {
		return nil, fmt.Errorf("decimal part %s has %d digits, but token %s only supports %d decimals",
			decimalPart, len(decimalPart), symbol, decimals)
	}

	// Strip the sign up front so it applies to integer and fraction alike.
	// Parsing "-1" and "5" separately and adding them would turn "-1.5" into
	// -0.5, and "-0.5" would lose its sign entirely.
	negative := false
	intDigits := integerPart
	switch {
	case strings.HasPrefix(intDigits, "-"):
		negative = true
		intDigits = intDigits[1:]
	case strings.HasPrefix(intDigits, "+"):
		intDigits = intDigits[1:]
	}

	if intDigits == "" && decimalPart == "" {
		return nil, fmt.Errorf("invalid decimal format: %s for token %s", valueStr, symbol)
	}

	// big.Int.SetString accepts a sign, so enforce digits-only explicitly;
	// otherwise "1.-5" would be accepted as 1 - 0.5.
	isDigits := func(s string) bool {
		for i := 0; i < len(s); i++ {
			if s[i] < '0' || s[i] > '9' {
				return false
			}
		}
		return true
	}
	if !isDigits(intDigits) {
		return nil, fmt.Errorf("invalid integer part: %s for token %s", integerPart, symbol)
	}
	if !isDigits(decimalPart) {
		return nil, fmt.Errorf("invalid decimal part: %s for token %s", decimalPart, symbol)
	}

	total := new(big.Int)
	if intDigits != "" {
		if _, ok := total.SetString(intDigits, 10); !ok {
			return nil, fmt.Errorf("invalid integer part: %s for token %s", integerPart, symbol)
		}
		total.Mul(total, dc.getScalingFactor(decimals))
	}
	if decimalPart != "" {
		// Right-pad to full precision: "23" with 18 decimals -> 23 * 10^16.
		padded := decimalPart + strings.Repeat("0", int(decimals)-len(decimalPart))
		fraction, ok := new(big.Int).SetString(padded, 10)
		if !ok {
			return nil, fmt.Errorf("invalid decimal part: %s for token %s", decimalPart, symbol)
		}
		total.Add(total, fraction)
	}
	if negative {
		total.Neg(total)
	}

	return NewUniversalDecimal(total, decimals, symbol)
}
// ToHumanReadable formats ud as a decimal string intended to round-trip
// through FromString.
//
//   - Zero formats as "0"; with 0 decimals the raw integer is returned.
//   - Whole-token amounts print as the token count (10^18 at 18 decimals -> "1").
//   - Fractional amounts whose raw string (sign included) is at least Decimals
//     characters long are returned as the RAW smallest-unit integer, mirroring
//     FromString's length heuristic.
//   - Shorter fractional amounts print as "<int>.<frac>" with trailing zeros
//     removed, e.g. -5 raw at 18 decimals -> "-0.000000000000000005".
//
// NOTE(review): for small Decimals the raw branch does not round-trip. For
// example, 1.50 at 2 decimals prints "150", which FromString reads as 150
// tokens because values below 1000 are treated as human-readable.
func (dc *DecimalConverter) ToHumanReadable(ud *UniversalDecimal) string {
	if ud.Value.Sign() == 0 {
		return "0"
	}
	if ud.Decimals == 0 {
		return ud.Value.String()
	}

	sign := ""
	if ud.Value.Sign() < 0 {
		sign = "-"
	}

	// Split the magnitude with QuoRem (truncating division). Euclidean Div/Mod
	// on a negative value would yield e.g. -1 and 0.999... for -0.000...5.
	scalingFactor := dc.getScalingFactor(ud.Decimals)
	magnitude := new(big.Int).Abs(ud.Value)
	integerPart, remainder := new(big.Int).QuoRem(magnitude, scalingFactor, new(big.Int))

	if remainder.Sign() == 0 {
		return sign + integerPart.String()
	}

	// Long raw strings are emitted as-is for FromString round-tripping.
	if rawStr := ud.Value.String(); len(rawStr) >= int(ud.Decimals) {
		return rawStr
	}

	// remainder < 10^Decimals, so the left-pad count is never negative.
	fraction := remainder.String()
	fraction = strings.Repeat("0", int(ud.Decimals)-len(fraction)) + fraction
	fraction = strings.TrimRight(fraction, "0")
	if fraction == "" {
		return sign + integerPart.String()
	}
	return sign + integerPart.String() + "." + fraction
}
// ConvertTo re-expresses from with toDecimals decimal places under toSymbol.
//
// Raising precision multiplies by 10^(toDecimals-from.Decimals) and is exact.
// Lowering precision divides by 10^(from.Decimals-toDecimals) and rounds to the
// nearest value, with exact halves rounded toward +infinity ("round half up").
// Half the divisor is added before big.Int.Div, which is Euclidean and so
// floors for a positive divisor. This is NOT banker's rounding: 2.5 becomes 3
// and -2.5 becomes -2.
//
// Errors are those of NewUniversalDecimal (toDecimals > 18, empty toSymbol, or
// a result beyond the permitted magnitude).
func (dc *DecimalConverter) ConvertTo(from *UniversalDecimal, toDecimals uint8, toSymbol string) (*UniversalDecimal, error) {
	var convertedValue *big.Int

	switch {
	case from.Decimals == toDecimals:
		// Same precision: NewUniversalDecimal copies the value under the new symbol.
		convertedValue = from.Value
	case from.Decimals < toDecimals:
		// Increase precision (exact multiply).
		scalingFactor := dc.getScalingFactor(toDecimals - from.Decimals)
		convertedValue = new(big.Int).Mul(from.Value, scalingFactor)
	default:
		// Decrease precision: floor((value + scale/2) / scale), i.e. round half up.
		scalingFactor := dc.getScalingFactor(from.Decimals - toDecimals)
		halfScaling := new(big.Int).Div(scalingFactor, big.NewInt(2))
		roundedValue := new(big.Int).Add(from.Value, halfScaling)
		convertedValue = new(big.Int).Div(roundedValue, scalingFactor)
	}

	return NewUniversalDecimal(convertedValue, toDecimals, toSymbol)
}
// Multiply returns a*b expressed with resultDecimals decimal places under
// resultSymbol.
//
// The raw product carries a.Decimals+b.Decimals places. It is divided by
// 10^(that - resultDecimals), rounding half up, or multiplied when
// resultDecimals is larger. Operands whose magnitude exceeds 10^30 are
// rejected, which keeps the raw product within 10^60. Other errors come from
// NewUniversalDecimal.
func (dc *DecimalConverter) Multiply(a, b *UniversalDecimal, resultDecimals uint8, resultSymbol string) (*UniversalDecimal, error) {
	// Bound magnitudes, not signed values: with Cmp a negative operand of any
	// size slipped past the guard.
	maxSafeValue := new(big.Int).Exp(big.NewInt(10), big.NewInt(30), nil)
	if a.Value.CmpAbs(maxSafeValue) > 0 || b.Value.CmpAbs(maxSafeValue) > 0 {
		return nil, fmt.Errorf("values too large for safe multiplication: %s * %s", a.Symbol, b.Symbol)
	}

	product := new(big.Int).Mul(a.Value, b.Value)

	// Adjust for decimal places (division by 10^(a.decimals + b.decimals - result.decimals)).
	totalInputDecimals := a.Decimals + b.Decimals
	var adjustedProduct *big.Int
	if totalInputDecimals >= resultDecimals {
		scalingFactor := dc.getScalingFactor(totalInputDecimals - resultDecimals)
		// Round to nearest, ties toward +infinity (Div is Euclidean).
		halfScaling := new(big.Int).Div(scalingFactor, big.NewInt(2))
		roundedProduct := new(big.Int).Add(product, halfScaling)
		adjustedProduct = new(big.Int).Div(roundedProduct, scalingFactor)
	} else {
		scalingFactor := dc.getScalingFactor(resultDecimals - totalInputDecimals)
		adjustedProduct = new(big.Int).Mul(product, scalingFactor)
	}

	return NewUniversalDecimal(adjustedProduct, resultDecimals, resultSymbol)
}
// Divide computes numerator / denominator with resultDecimals decimal places
// under resultSymbol.
//
// With raw values n (a decimals) and d (b decimals) and result precision c, the
// real quotient is (n/10^a) / (d/10^b). Its raw form at c decimals is
//
//	n * 10^(c + b - a) / d
//
// When c + b - a is negative, the power of ten multiplies the divisor instead.
// big.Int.Div is Euclidean, so the result is not rounded (for a positive
// divisor it is floored). A zero denominator returns an error; other errors
// come from NewUniversalDecimal.
func (dc *DecimalConverter) Divide(numerator, denominator *UniversalDecimal, resultDecimals uint8, resultSymbol string) (*UniversalDecimal, error) {
	if denominator.Value.Sign() == 0 {
		return nil, fmt.Errorf("division by zero: %s / %s", numerator.Symbol, denominator.Symbol)
	}

	// Use int arithmetic: the exponent is negative whenever the numerator has
	// more decimals than denominator+result, and uint8 would wrap it to ~10^250.
	// The numerator's own precision is divided out (-a) and the denominator's is
	// restored (+b); using a + c - b is only correct when a == b.
	exponent := int(resultDecimals) + int(denominator.Decimals) - int(numerator.Decimals)

	powerOfTen := func(n int) *big.Int {
		if n <= 255 {
			return dc.getScalingFactor(uint8(n))
		}
		return new(big.Int).Exp(big.NewInt(10), big.NewInt(int64(n)), nil)
	}

	// Never mutate the operands' values; allocate only for the side we scale.
	dividend, divisor := numerator.Value, denominator.Value
	if exponent >= 0 {
		dividend = new(big.Int).Mul(dividend, powerOfTen(exponent))
	} else {
		divisor = new(big.Int).Mul(divisor, powerOfTen(-exponent))
	}

	quotient := new(big.Int).Div(dividend, divisor)
	return NewUniversalDecimal(quotient, resultDecimals, resultSymbol)
}
// Add returns a+b. Both operands must have the same Decimals.
//
// The result keeps a's Decimals. Its symbol is a.Symbol, or "A+B" when the
// symbols differ. Operands whose magnitude exceeds 10^59 are rejected, so the
// sum (at most 2*10^59) always fits NewUniversalDecimal's 10^60 bound.
func (dc *DecimalConverter) Add(a, b *UniversalDecimal) (*UniversalDecimal, error) {
	if a.Decimals != b.Decimals {
		return nil, fmt.Errorf("cannot add tokens with different decimals: %s(%d) + %s(%d)",
			a.Symbol, a.Decimals, b.Symbol, b.Decimals)
	}

	// Bound magnitudes rather than signed values: with Cmp, large negative
	// operands bypassed the guard and failed later with a less specific error.
	maxSafeValue := new(big.Int).Exp(big.NewInt(10), big.NewInt(59), nil)
	if a.Value.CmpAbs(maxSafeValue) > 0 || b.Value.CmpAbs(maxSafeValue) > 0 {
		return nil, fmt.Errorf("values too large for safe addition: %s + %s", a.Symbol, b.Symbol)
	}

	sum := new(big.Int).Add(a.Value, b.Value)

	resultSymbol := a.Symbol
	if a.Symbol != b.Symbol {
		resultSymbol = fmt.Sprintf("%s+%s", a.Symbol, b.Symbol)
	}

	return NewUniversalDecimal(sum, a.Decimals, resultSymbol)
}
// Subtract returns a-b. Both operands must share the same Decimals.
//
// The result keeps a's Decimals. Its symbol is a.Symbol, or "A-B" when the
// symbols differ. Range validation is left to NewUniversalDecimal.
func (dc *DecimalConverter) Subtract(a, b *UniversalDecimal) (*UniversalDecimal, error) {
	if a.Decimals != b.Decimals {
		return nil, fmt.Errorf("cannot subtract tokens with different decimals: %s(%d) - %s(%d)",
			a.Symbol, a.Decimals, b.Symbol, b.Decimals)
	}

	label := a.Symbol
	if b.Symbol != a.Symbol {
		label = a.Symbol + "-" + b.Symbol
	}

	return NewUniversalDecimal(new(big.Int).Sub(a.Value, b.Value), a.Decimals, label)
}
// Compare returns -1, 0, or 1 for a < b, a == b, a > b respectively, comparing
// the real amounts even when a and b use different decimal precisions.
//
// The operand with fewer decimals is scaled up exactly to the other's
// precision. Rounding the higher-precision operand down would make e.g.
// 1 (0 decimals) and 1.4 (1 decimal) compare equal. No UniversalDecimal is
// constructed, so the error result is always nil; it is kept for API
// compatibility.
func (dc *DecimalConverter) Compare(a, b *UniversalDecimal) (int, error) {
	left, right := a.Value, b.Value

	switch {
	case a.Decimals < b.Decimals:
		left = new(big.Int).Mul(a.Value, dc.getScalingFactor(b.Decimals-a.Decimals))
	case a.Decimals > b.Decimals:
		right = new(big.Int).Mul(b.Value, dc.getScalingFactor(a.Decimals-b.Decimals))
	}

	return left.Cmp(right), nil
}
// IsZero reports whether the raw amount is exactly zero.
func (ud *UniversalDecimal) IsZero() bool {
	// A big.Int is zero exactly when its magnitude has no bits.
	return ud.Value.BitLen() == 0
}
// IsPositive reports whether the raw amount is strictly greater than zero.
func (ud *UniversalDecimal) IsPositive() bool {
	sign := ud.Value.Sign()
	return sign == 1
}
// IsNegative reports whether the raw amount is strictly less than zero.
func (ud *UniversalDecimal) IsNegative() bool {
	sign := ud.Value.Sign()
	return sign == -1
}
// Copy returns an independent duplicate of ud. The big.Int is cloned, so
// mutating one copy's Value never affects the other.
func (ud *UniversalDecimal) Copy() *UniversalDecimal {
	dup := *ud // copies Decimals and Symbol
	dup.Value = new(big.Int).Set(ud.Value)
	return &dup
}
// getScalingFactor returns 10^decimals.
//
// Powers for 0-18 come from the cache filled by NewDecimalConverter. Other
// exponents (e.g. the up-to-36 differences Multiply can request) are computed
// on demand and deliberately NOT stored. Writing to the map here would turn a
// lookup into a mutation, which panics on a zero-value DecimalConverter (nil
// map) and is a data race when a converter is shared between goroutines.
//
// The returned value may be shared with the cache; callers must treat it as
// read-only.
func (dc *DecimalConverter) getScalingFactor(decimals uint8) *big.Int {
	if factor, ok := dc.scalingFactors[decimals]; ok {
		return factor
	}
	return new(big.Int).Exp(big.NewInt(10), big.NewInt(int64(decimals)), nil)
}
// ToWei rescales ud to 18 decimal places via ConvertTo and labels the result
// "WEI". The signature has no error result, so a failed conversion (for
// example a value that exceeds the permitted magnitude once scaled) is
// reported as a nil return.
func (dc *DecimalConverter) ToWei(ud *UniversalDecimal) *UniversalDecimal {
	const weiDecimals = 18
	if wei, err := dc.ConvertTo(ud, weiDecimals, "WEI"); err == nil {
		return wei
	}
	return nil
}
// FromWei treats weiValue as an 18-decimal amount and converts it to
// targetDecimals places labelled targetSymbol via ConvertTo. The signature has
// no error result, so a failed conversion is reported as a nil return.
// weiValue is copied, never modified.
func (dc *DecimalConverter) FromWei(weiValue *big.Int, targetDecimals uint8, targetSymbol string) *UniversalDecimal {
	source := &UniversalDecimal{Decimals: 18, Symbol: "WEI"}
	source.Value = new(big.Int).Set(weiValue)

	converted, err := dc.ConvertTo(source, targetDecimals, targetSymbol)
	if err != nil {
		return nil
	}
	return converted
}
// CalculatePercentage returns value/total*100 as a UniversalDecimal with 4
// decimal places and symbol "PERCENT" (e.g. 1.5000% is raw 15000).
//
// Operands with different precisions are aligned exactly by scaling the
// lower-precision one up. Rounding value down to total's precision would
// discard digits before the division. The quotient uses big.Int.Div, so it is
// floored rather than rounded for a positive total. A zero total returns an
// error; other errors come from NewUniversalDecimal.
func (dc *DecimalConverter) CalculatePercentage(value, total *UniversalDecimal) (*UniversalDecimal, error) {
	if total.IsZero() {
		return nil, fmt.Errorf("cannot calculate percentage with zero total")
	}

	numerator, denominator := value.Value, total.Value
	switch {
	case value.Decimals < total.Decimals:
		numerator = new(big.Int).Mul(numerator, dc.getScalingFactor(total.Decimals-value.Decimals))
	case value.Decimals > total.Decimals:
		denominator = new(big.Int).Mul(denominator, dc.getScalingFactor(value.Decimals-total.Decimals))
	}

	// (value * 100 * 10^4) / total: the extra 10^4 yields 4 decimal places,
	// all in integer arithmetic.
	hundredWithDecimals := big.NewInt(1000000) // 100.0000 in 4-decimal format
	scaled := new(big.Int).Mul(numerator, hundredWithDecimals)
	percentage := new(big.Int).Div(scaled, denominator)

	return NewUniversalDecimal(percentage, 4, "PERCENT")
}
// String implements fmt.Stringer for debugging output: the amount as formatted
// by DecimalConverter.ToHumanReadable, a space, then the symbol (e.g. "1 ETH"
// for 10^18 raw at 18 decimals). A nil receiver yields "<nil>".
func (ud *UniversalDecimal) String() string {
	if ud == nil {
		return "<nil>"
	}
	// NewDecimalConverter would precompute 19 powers of ten on every call;
	// formatting needs at most one, which getScalingFactor computes on demand
	// when it is missing from this empty cache.
	dc := &DecimalConverter{scalingFactors: make(map[uint8]*big.Int, 1)}
	humanReadable := dc.ToHumanReadable(ud)
	return fmt.Sprintf("%s %s", humanReadable, ud.Symbol)
}