// Package math provides arbitrary-precision fixed-point helpers for token
// amounts that carry differing numbers of decimal places.
package math

import (
	"fmt"
	"math/big"
	"strings"
)

// maxDecimals is the largest number of decimal places NewUniversalDecimal accepts.
const maxDecimals = 18

// UniversalDecimal represents a token amount with precise decimal handling.
type UniversalDecimal struct {
	Value    *big.Int // Raw value in smallest unit
	Decimals uint8    // Number of decimal places (0-18)
	Symbol   string   // Token symbol for debugging
}

// DecimalConverter handles conversions between different decimal precisions.
//
// The zero value is usable: scaling factors that are not in the cache are
// computed on demand.
type DecimalConverter struct {
	// Cache for common scaling factors to avoid repeated calculations.
	// It is filled once by NewDecimalConverter and only read afterwards.
	scalingFactors map[uint8]*big.Int
}

// NewDecimalConverter creates a new decimal converter with the scaling
// factors 10^0 through 10^18 pre-computed.
func NewDecimalConverter() *DecimalConverter {
	dc := &DecimalConverter{
		scalingFactors: make(map[uint8]*big.Int, maxDecimals+1),
	}
	for i := uint8(0); i <= maxDecimals; i++ {
		dc.scalingFactors[i] = pow10(int(i))
	}
	return dc
}

// NewUniversalDecimal creates a new universal decimal with validation.
//
// It returns an error when decimals exceeds 18. A nil value is treated as
// zero. The value is copied, so later changes to the caller's big.Int do not
// affect the result.
func NewUniversalDecimal(value *big.Int, decimals uint8, symbol string) (*UniversalDecimal, error) {
	if decimals > maxDecimals {
		return nil, fmt.Errorf("decimal places cannot exceed 18, got %d for token %s", decimals, symbol)
	}
	if value == nil {
		value = big.NewInt(0)
	}
	return &UniversalDecimal{
		Value:    new(big.Int).Set(value), // copy to prevent external modification
		Decimals: decimals,
		Symbol:   symbol,
	}, nil
}

// FromString creates a UniversalDecimal from its string representation.
//
// The format is chosen by a heuristic:
//  1. Strings containing a decimal point are always human-readable units.
//  2. Integer strings whose length (including any sign) is >= decimals, with
//     decimals > 0, are taken as the raw value in the smallest unit.
//  3. Shorter integer strings are human-readable units and are multiplied by
//     10^decimals.
//
// Surrounding whitespace is ignored; an empty string or "0" yields zero.
// NOTE(review): rule 2 makes e.g. "1000000" with 6 decimals a raw value (1.0),
// not one million units; callers needing an unambiguous parse should pass a
// decimal point.
func (dc *DecimalConverter) FromString(valueStr string, decimals uint8, symbol string) (*UniversalDecimal, error) {
	// Trim before the empty/zero check so " 0 " and "   " behave like "0" and "".
	valueStr = strings.TrimSpace(valueStr)
	if valueStr == "" || valueStr == "0" {
		return NewUniversalDecimal(big.NewInt(0), decimals, symbol)
	}

	if strings.Contains(valueStr, ".") {
		return dc.fromDecimalString(valueStr, decimals, symbol)
	}

	value, ok := new(big.Int).SetString(valueStr, 10)
	if !ok {
		return nil, fmt.Errorf("invalid number format: %s for token %s", valueStr, symbol)
	}

	// Long integer strings are taken as already being in the smallest unit.
	if decimals > 0 && len(valueStr) >= int(decimals) {
		return NewUniversalDecimal(value, decimals, symbol)
	}

	// Short integer strings are whole units; scale them to the smallest unit.
	scaled := new(big.Int).Mul(value, dc.getScalingFactor(decimals))
	return NewUniversalDecimal(scaled, decimals, symbol)
}

// fromDecimalString parses a decimal string (e.g., "1.23" or "-0.5") into the
// smallest unit.
//
// An optional leading '+' or '-' applies to the whole number. Either side of
// the point may be empty ("1." and ".5" are accepted). It returns an error if
// there is not exactly one point, if a part contains non-digits, or if the
// fractional part has more digits than decimals.
func (dc *DecimalConverter) fromDecimalString(valueStr string, decimals uint8, symbol string) (*UniversalDecimal, error) {
	parts := strings.Split(valueStr, ".")
	if len(parts) != 2 {
		return nil, fmt.Errorf("invalid decimal format: %s for token %s", valueStr, symbol)
	}
	integerPart, decimalPart := parts[0], parts[1]

	if len(decimalPart) > int(decimals) {
		return nil, fmt.Errorf("decimal part %s has %d digits, but token %s only supports %d decimals",
			decimalPart, len(decimalPart), symbol, decimals)
	}

	// Strip the sign up front so it applies to the whole number. Parsing the
	// signed integer part on its own would turn "-1.5" into -1 + 0.5.
	negative, signed := false, false
	digits := integerPart
	if strings.HasPrefix(digits, "-") || strings.HasPrefix(digits, "+") {
		negative = digits[0] == '-'
		signed = true
		digits = digits[1:]
	}

	// A bare sign with no digits on either side ("-.") is not a number.
	if !isDigits(digits) || (signed && digits == "" && decimalPart == "") {
		return nil, fmt.Errorf("invalid integer part: %s for token %s", integerPart, symbol)
	}
	if !isDigits(decimalPart) {
		return nil, fmt.Errorf("invalid decimal part: %s for token %s", decimalPart, symbol)
	}

	intValue := new(big.Int)
	if digits != "" {
		if _, ok := intValue.SetString(digits, 10); !ok {
			return nil, fmt.Errorf("invalid integer part: %s for token %s", integerPart, symbol)
		}
	}

	decValue := new(big.Int)
	if decimalPart != "" {
		// Right-pad to full precision so "5" with 3 decimals becomes 500.
		padded := decimalPart + strings.Repeat("0", int(decimals)-len(decimalPart))
		if _, ok := decValue.SetString(padded, 10); !ok {
			return nil, fmt.Errorf("invalid decimal part: %s for token %s", decimalPart, symbol)
		}
	}

	total := new(big.Int).Mul(intValue, dc.getScalingFactor(decimals))
	total.Add(total, decValue)
	if negative {
		total.Neg(total)
	}
	return NewUniversalDecimal(total, decimals, symbol)
}

// ToHumanReadable converts to a string that FromString can parse back.
//
// Zero (or a nil value) renders as "0". Whole-unit amounts render as the
// integer unit count. Amounts with a fractional part render as the raw
// smallest-unit integer when that string's length (including sign) is
// >= Decimals, mirroring FromString's raw-value heuristic; otherwise they
// render as a decimal with trailing zeros removed.
func (dc *DecimalConverter) ToHumanReadable(ud *UniversalDecimal) string {
	if ud == nil || ud.Value == nil || ud.Value.Sign() == 0 {
		return "0"
	}
	if ud.Decimals == 0 {
		return ud.Value.String()
	}

	// Split the magnitude and re-attach the sign. big.Int.Div/Mod are
	// Euclidean and would render -0.0005 as "-1.9995".
	sign := ""
	if ud.Value.Sign() < 0 {
		sign = "-"
	}
	magnitude := new(big.Int).Abs(ud.Value)
	integerPart, remainder := new(big.Int).QuoRem(magnitude, dc.getScalingFactor(ud.Decimals), new(big.Int))

	if remainder.Sign() == 0 {
		return sign + integerPart.String()
	}

	// FromString would read a string this long as a raw value, so emit the raw
	// value to keep the round trip exact.
	raw := ud.Value.String()
	if len(raw) >= int(ud.Decimals) {
		return raw
	}

	// Left-pad the fraction to full width, then drop trailing zeros. The
	// remainder is non-zero here, so the result is never empty.
	frac := remainder.String()
	frac = strings.Repeat("0", int(ud.Decimals)-len(frac)) + frac
	frac = strings.TrimRight(frac, "0")
	return sign + integerPart.String() + "." + frac
}

// ConvertTo converts between different decimal precisions.
//
// Increasing precision is exact. Decreasing precision rounds half up, i.e. it
// computes floor((value + factor/2) / factor), so ties round toward positive
// infinity. It returns an error only when toDecimals exceeds 18.
func (dc *DecimalConverter) ConvertTo(from *UniversalDecimal, toDecimals uint8, toSymbol string) (*UniversalDecimal, error) {
	switch {
	case from.Decimals == toDecimals:
		// Same precision: copy with the new symbol.
		return NewUniversalDecimal(from.Value, toDecimals, toSymbol)
	case from.Decimals < toDecimals:
		factor := dc.getScalingFactor(toDecimals - from.Decimals)
		return NewUniversalDecimal(new(big.Int).Mul(from.Value, factor), toDecimals, toSymbol)
	default:
		factor := dc.getScalingFactor(from.Decimals - toDecimals)
		return NewUniversalDecimal(divRoundHalfUp(from.Value, factor), toDecimals, toSymbol)
	}
}

// Multiply performs precise multiplication between different decimal tokens.
//
// The raw product carries a.Decimals+b.Decimals decimal places and is rescaled
// to resultDecimals, rounding half up when precision is reduced.
func (dc *DecimalConverter) Multiply(a, b *UniversalDecimal, resultDecimals uint8, resultSymbol string) (*UniversalDecimal, error) {
	product := new(big.Int).Mul(a.Value, b.Value)

	// int arithmetic: the sum of two uint8 decimals can exceed 255.
	diff := int(a.Decimals) + int(b.Decimals) - int(resultDecimals)

	var adjusted *big.Int
	if diff >= 0 {
		adjusted = divRoundHalfUp(product, dc.scaleFor(diff))
	} else {
		adjusted = new(big.Int).Mul(product, dc.scaleFor(-diff))
	}
	return NewUniversalDecimal(adjusted, resultDecimals, resultSymbol)
}

// Divide performs precise division between different decimal tokens.
//
// With n = numerator raw / 10^nd and d = denominator raw / 10^dd, the raw
// result at rd decimals is nRaw * 10^(rd + dd - nd) / dRaw. The quotient uses
// big.Int.Div (Euclidean division), which truncates for non-negative operands.
// It returns an error when the denominator is zero.
func (dc *DecimalConverter) Divide(numerator, denominator *UniversalDecimal, resultDecimals uint8, resultSymbol string) (*UniversalDecimal, error) {
	if denominator.Value.Sign() == 0 {
		return nil, fmt.Errorf("division by zero: %s / %s", numerator.Symbol, denominator.Symbol)
	}

	exp := int(resultDecimals) + int(denominator.Decimals) - int(numerator.Decimals)

	dividend := new(big.Int).Set(numerator.Value)
	divisor := new(big.Int).Set(denominator.Value)
	if exp >= 0 {
		dividend.Mul(dividend, dc.scaleFor(exp))
	} else {
		// A negative exponent means the numerator is over-scaled; scale the
		// divisor up instead of dividing the numerator and losing digits.
		divisor.Mul(divisor, dc.scaleFor(-exp))
	}

	quotient := new(big.Int).Div(dividend, divisor)
	return NewUniversalDecimal(quotient, resultDecimals, resultSymbol)
}

// Add adds two UniversalDecimals with the same precision.
//
// It returns an error when the decimals differ. If the symbols differ the
// result's symbol is "A+B".
func (dc *DecimalConverter) Add(a, b *UniversalDecimal) (*UniversalDecimal, error) {
	if a.Decimals != b.Decimals {
		return nil, fmt.Errorf("cannot add tokens with different decimals: %s(%d) + %s(%d)",
			a.Symbol, a.Decimals, b.Symbol, b.Decimals)
	}

	sum := new(big.Int).Add(a.Value, b.Value)
	resultSymbol := a.Symbol
	if a.Symbol != b.Symbol {
		resultSymbol = fmt.Sprintf("%s+%s", a.Symbol, b.Symbol)
	}
	return NewUniversalDecimal(sum, a.Decimals, resultSymbol)
}

// Subtract subtracts b from a; both must have the same precision.
//
// It returns an error when the decimals differ. If the symbols differ the
// result's symbol is "A-B".
func (dc *DecimalConverter) Subtract(a, b *UniversalDecimal) (*UniversalDecimal, error) {
	if a.Decimals != b.Decimals {
		return nil, fmt.Errorf("cannot subtract tokens with different decimals: %s(%d) - %s(%d)",
			a.Symbol, a.Decimals, b.Symbol, b.Decimals)
	}

	diff := new(big.Int).Sub(a.Value, b.Value)
	resultSymbol := a.Symbol
	if a.Symbol != b.Symbol {
		resultSymbol = fmt.Sprintf("%s-%s", a.Symbol, b.Symbol)
	}
	return NewUniversalDecimal(diff, a.Decimals, resultSymbol)
}

// Compare returns -1, 0, or 1 for a < b, a == b, a > b respectively.
//
// Operands with different decimals are compared exactly: the lower-precision
// value is scaled up rather than the other rounded down. The returned error is
// always nil and is kept for interface compatibility.
func (dc *DecimalConverter) Compare(a, b *UniversalDecimal) (int, error) {
	left, right := dc.alignPrecision(a, b)
	return left.Cmp(right), nil
}

// IsZero checks if the value is zero.
func (ud *UniversalDecimal) IsZero() bool {
	return ud.Value.Sign() == 0
}

// IsPositive checks if the value is positive.
func (ud *UniversalDecimal) IsPositive() bool {
	return ud.Value.Sign() > 0
}

// IsNegative checks if the value is negative.
func (ud *UniversalDecimal) IsNegative() bool {
	return ud.Value.Sign() < 0
}

// Copy creates a deep copy of the UniversalDecimal.
func (ud *UniversalDecimal) Copy() *UniversalDecimal {
	return &UniversalDecimal{
		Value:    new(big.Int).Set(ud.Value),
		Decimals: ud.Decimals,
		Symbol:   ud.Symbol,
	}
}

// getScalingFactor returns 10^decimals, using the cache when it has the value.
//
// Misses are computed but not stored: writing here would panic on a zero-value
// converter (nil map) and would race under concurrent use. The returned value
// may be shared and must not be modified.
func (dc *DecimalConverter) getScalingFactor(decimals uint8) *big.Int {
	if factor, ok := dc.scalingFactors[decimals]; ok {
		return factor
	}
	return pow10(int(decimals))
}

// scaleFor returns 10^exp for a non-negative int exponent, using the cache for
// exponents within the pre-computed range.
func (dc *DecimalConverter) scaleFor(exp int) *big.Int {
	if exp >= 0 && exp <= maxDecimals {
		return dc.getScalingFactor(uint8(exp))
	}
	return pow10(exp)
}

// alignPrecision returns the raw values of a and b expressed at the finer of
// their two precisions. Inputs are not modified.
func (dc *DecimalConverter) alignPrecision(a, b *UniversalDecimal) (*big.Int, *big.Int) {
	left, right := a.Value, b.Value
	switch {
	case a.Decimals < b.Decimals:
		left = new(big.Int).Mul(left, dc.getScalingFactor(b.Decimals-a.Decimals))
	case a.Decimals > b.Decimals:
		right = new(big.Int).Mul(right, dc.getScalingFactor(a.Decimals-b.Decimals))
	}
	return left, right
}

// pow10 returns a newly allocated 10^n.
func pow10(n int) *big.Int {
	return new(big.Int).Exp(big.NewInt(10), big.NewInt(int64(n)), nil)
}

// divRoundHalfUp returns floor((value + factor/2) / factor) for a positive
// factor, i.e. value/factor rounded to nearest with ties toward +infinity.
func divRoundHalfUp(value, factor *big.Int) *big.Int {
	half := new(big.Int).Div(factor, big.NewInt(2))
	shifted := new(big.Int).Add(value, half)
	return shifted.Div(shifted, factor)
}

// isDigits reports whether s consists solely of ASCII digits (true for "").
func isDigits(s string) bool {
	for i := 0; i < len(s); i++ {
		if s[i] < '0' || s[i] > '9' {
			return false
		}
	}
	return true
}

// ToWei converts any decimal precision to 18-decimal wei representation.
func (dc *DecimalConverter) ToWei(ud *UniversalDecimal) *UniversalDecimal {
	// ConvertTo only fails when the target exceeds 18 decimals, which cannot
	// happen for the constant 18, so the error is safely ignored.
	weiValue, _ := dc.ConvertTo(ud, maxDecimals, "WEI")
	return weiValue
}

// FromWei converts 18-decimal wei to the specified decimal precision.
//
// A nil weiValue is treated as zero. The result is nil when targetDecimals
// exceeds 18, because the signature has no error return.
func (dc *DecimalConverter) FromWei(weiValue *big.Int, targetDecimals uint8, targetSymbol string) *UniversalDecimal {
	if weiValue == nil {
		weiValue = new(big.Int)
	}
	weiDecimal := &UniversalDecimal{
		Value:    new(big.Int).Set(weiValue),
		Decimals: maxDecimals,
		Symbol:   "WEI",
	}
	result, err := dc.ConvertTo(weiDecimal, targetDecimals, targetSymbol)
	if err != nil {
		return nil
	}
	return result
}

// CalculatePercentage calculates value as a percentage of total.
//
// The result has 4 decimal places and the symbol "PERCENT" (e.g., 1.5000% is
// 15000 with 4 decimals). Operands with different decimals are aligned to the
// finer precision, so no digits of value are rounded away before dividing. The
// quotient uses big.Int.Div. It returns an error when total is zero.
func (dc *DecimalConverter) CalculatePercentage(value, total *UniversalDecimal) (*UniversalDecimal, error) {
	if total.IsZero() {
		return nil, fmt.Errorf("cannot calculate percentage with zero total")
	}

	valueRaw, totalRaw := dc.alignPrecision(value, total)

	// (value / total) * 100 with 4 decimal places: multiply by 100 * 10^4.
	numerator := new(big.Int).Mul(valueRaw, big.NewInt(1000000))
	percentage := new(big.Int).Div(numerator, totalRaw)

	return NewUniversalDecimal(percentage, 4, "PERCENT")
}

// String returns "<human readable value> <symbol>" for debugging.
func (ud *UniversalDecimal) String() string {
	// A zero-value converter computes the one factor it needs on demand, which
	// is cheaper than building the full cache on every call.
	var dc DecimalConverter
	return fmt.Sprintf("%s %s", dc.ToHumanReadable(ud), ud.Symbol)
}