saving in place

This commit is contained in:
Krypto Kajun
2025-10-04 09:31:02 -05:00
parent 76c1b5cee1
commit f358f49aa9
295 changed files with 72071 additions and 17209 deletions

View File

@@ -0,0 +1,627 @@
package math
import (
"context"
"fmt"
"math/big"
"sort"
"time"
"github.com/ethereum/go-ethereum/common"
"github.com/fraktal/mev-beta/pkg/types"
)
// Use the canonical ArbitrageOpportunity from types package
// Extended fields for advanced calculations can be added as needed
// ExchangeStep represents one step in the arbitrage execution
type ExchangeStep struct {
Exchange ExchangeType
Pool *PoolData
TokenIn TokenInfo
TokenOut TokenInfo
AmountIn *UniversalDecimal
AmountOut *UniversalDecimal
PriceImpact *UniversalDecimal
EstimatedGas uint64
}
// RiskAssessment evaluates the risk level of an arbitrage opportunity
type RiskAssessment struct {
Overall RiskLevel
Liquidity RiskLevel
PriceImpact RiskLevel
Competition RiskLevel
Slippage RiskLevel
GasPrice RiskLevel
Warnings []string
OverallRisk float64 // Numeric representation of overall risk (0.0 to 1.0)
}
// RiskLevel represents different risk categories
type RiskLevel string
const (
RiskLow RiskLevel = "low"
RiskMedium RiskLevel = "medium"
RiskHigh RiskLevel = "high"
RiskCritical RiskLevel = "critical"
)
// ArbitrageCalculator performs precise arbitrage calculations
type ArbitrageCalculator struct {
pricingEngine *ExchangePricingEngine
decimalConverter *DecimalConverter
gasEstimator GasEstimator
// Configuration
minProfitThreshold *UniversalDecimal
maxPriceImpact *UniversalDecimal
maxSlippage *UniversalDecimal
maxGasPriceGwei *UniversalDecimal
}
// GasEstimator interface for gas cost calculations
type GasEstimator interface {
EstimateSwapGas(exchange ExchangeType, poolData *PoolData) (uint64, error)
EstimateFlashSwapGas(route []*PoolData) (uint64, error)
GetCurrentGasPrice() (*UniversalDecimal, error)
}
// NewArbitrageCalculator creates a new arbitrage calculator
func NewArbitrageCalculator(gasEstimator GasEstimator) *ArbitrageCalculator {
dc := NewDecimalConverter()
// Default configuration
minProfit, _ := dc.FromString("0.01", 18, "ETH") // 0.01 ETH minimum
maxImpact, _ := dc.FromString("2", 4, "PERCENT") // 2% max price impact
maxSlip, _ := dc.FromString("1", 4, "PERCENT") // 1% max slippage
maxGas, _ := dc.FromString("50", 9, "GWEI") // 50 gwei max gas
return &ArbitrageCalculator{
pricingEngine: NewExchangePricingEngine(),
decimalConverter: dc,
gasEstimator: gasEstimator,
minProfitThreshold: minProfit,
maxPriceImpact: maxImpact,
maxSlippage: maxSlip,
maxGasPriceGwei: maxGas,
}
}
// CalculateArbitrageOpportunity performs comprehensive arbitrage analysis
func (calc *ArbitrageCalculator) CalculateArbitrageOpportunity(
path []*PoolData,
inputAmount *UniversalDecimal,
inputToken TokenInfo,
outputToken TokenInfo,
) (*types.ArbitrageOpportunity, error) {
if len(path) == 0 {
return nil, fmt.Errorf("empty arbitrage path")
}
// Step 1: Calculate execution route with amounts
route, err := calc.calculateExecutionRoute(path, inputAmount, inputToken)
if err != nil {
return nil, fmt.Errorf("error calculating execution route: %w", err)
}
// Step 2: Get final output amount
finalOutput := route[len(route)-1].AmountOut
// Step 3: Calculate gas costs
totalGasCost, err := calc.calculateTotalGasCost(route)
if err != nil {
return nil, fmt.Errorf("error calculating gas cost: %w", err)
}
// Step 4: Calculate profits (convert to common denomination - ETH)
_, netProfit, profitPercentage, err := calc.calculateProfits(
inputAmount, finalOutput, totalGasCost, inputToken, outputToken)
if err != nil {
return nil, fmt.Errorf("error calculating profits: %w", err)
}
// Step 5: Calculate total price impact
totalPriceImpact, err := calc.calculateTotalPriceImpact(route)
if err != nil {
return nil, fmt.Errorf("error calculating price impact: %w", err)
}
// Step 6: Calculate minimum output with slippage (we don't use this in the final result)
_, err = calc.calculateMinimumOutput(finalOutput)
if err != nil {
return nil, fmt.Errorf("error calculating minimum output: %w", err)
}
// Step 7: Assess risks
riskAssessment := calc.assessRisks(route, totalPriceImpact, netProfit)
// Step 8: Calculate confidence and execution time
confidence := calc.calculateConfidence(riskAssessment, netProfit, totalPriceImpact)
executionTime := calc.estimateExecutionTime(route)
// Convert path to string array
pathStrings := make([]string, len(path))
for i, pool := range path {
pathStrings[i] = pool.Address // Address is already a string
}
// Convert pools to string array
poolStrings := make([]string, len(path))
for i, pool := range path {
poolStrings[i] = pool.Address // Address is already a string
}
opportunity := &types.ArbitrageOpportunity{
Path: pathStrings,
Pools: poolStrings,
AmountIn: inputAmount.Value,
Profit: netProfit.Value,
NetProfit: netProfit.Value,
GasEstimate: totalGasCost.Value,
ROI: func() float64 { f, _ := profitPercentage.Value.Float64(); return f }(),
Protocol: "multi", // Default protocol for multi-step arbitrage
ExecutionTime: executionTime,
Confidence: confidence,
PriceImpact: func() float64 { f, _ := totalPriceImpact.Value.Float64(); return f }(),
MaxSlippage: 0.01, // Default 1% max slippage
TokenIn: common.HexToAddress(inputToken.Address),
TokenOut: common.HexToAddress(outputToken.Address),
Timestamp: time.Now().Unix(),
Risk: riskAssessment.OverallRisk,
}
return opportunity, nil
}
// calculateExecutionRoute calculates amounts through each step of the arbitrage
func (calc *ArbitrageCalculator) calculateExecutionRoute(
path []*PoolData,
inputAmount *UniversalDecimal,
inputToken TokenInfo,
) ([]ExchangeStep, error) {
route := make([]ExchangeStep, len(path))
currentAmount := inputAmount
currentToken := inputToken
for i, pool := range path {
// Determine output token for this step
var outputToken TokenInfo
if currentToken.Address == pool.Token0.Address {
outputToken = TokenInfo{
Address: pool.Token1.Address,
Symbol: "TOKEN1", // In a real implementation, you'd fetch the actual symbol
Decimals: 18,
}
} else if currentToken.Address == pool.Token1.Address {
outputToken = TokenInfo{
Address: pool.Token0.Address,
Symbol: "TOKEN0", // In a real implementation, you'd fetch the actual symbol
Decimals: 18,
}
} else {
return nil, fmt.Errorf("token %s not found in pool %s", currentToken.Symbol, pool.Address)
}
// For this simplified implementation, we'll calculate a mock amount out
// In a real implementation, you'd use the pricer's CalculateAmountOut method
amountOut := currentAmount // Simple 1:1 for this example
priceImpact := &UniversalDecimal{Value: big.NewInt(0), Decimals: 4, Symbol: "PERCENT"} // No impact in mock
// Estimate gas for this step
estimatedGas, err := calc.gasEstimator.EstimateSwapGas(ExchangeUniswapV3, pool) // Using a mock exchange type
if err != nil {
return nil, fmt.Errorf("error estimating gas for pool %s: %w", pool.Address, err)
}
// Create execution step
route[i] = ExchangeStep{
Exchange: ExchangeUniswapV3, // Using a mock exchange type
Pool: pool,
TokenIn: currentToken,
TokenOut: outputToken,
AmountIn: currentAmount,
AmountOut: amountOut,
PriceImpact: priceImpact,
EstimatedGas: estimatedGas,
}
// Update for next iteration
currentAmount = amountOut
currentToken = outputToken
}
return route, nil
}
// calculateTotalGasCost calculates the total gas cost for the entire route
func (calc *ArbitrageCalculator) calculateTotalGasCost(route []ExchangeStep) (*UniversalDecimal, error) {
// Get current gas price
gasPrice, err := calc.gasEstimator.GetCurrentGasPrice()
if err != nil {
return nil, fmt.Errorf("error getting gas price: %w", err)
}
// Sum up all gas estimates
totalGas := uint64(0)
for _, step := range route {
totalGas += step.EstimatedGas
}
// Add flash swap overhead if multi-step
if len(route) > 1 {
flashSwapGas, err := calc.gasEstimator.EstimateFlashSwapGas([]*PoolData{})
if err == nil {
totalGas += flashSwapGas
}
}
// Convert to gas cost in ETH
totalGasBig := big.NewInt(int64(totalGas))
totalGasDecimal, err := NewUniversalDecimal(totalGasBig, 0, "GAS")
if err != nil {
return nil, err
}
return calc.decimalConverter.Multiply(totalGasDecimal, gasPrice, 18, "ETH")
}
// calculateProfits calculates gross profit, net profit, and profit percentage
func (calc *ArbitrageCalculator) calculateProfits(
inputAmount, outputAmount, gasCost *UniversalDecimal,
inputToken, outputToken TokenInfo,
) (*UniversalDecimal, *UniversalDecimal, *UniversalDecimal, error) {
// Convert amounts to common denomination (ETH) for comparison
inputETH := calc.convertToETH(inputAmount, inputToken)
outputETH := calc.convertToETH(outputAmount, outputToken)
// Gross profit = output - input (in ETH terms)
grossProfit, err := calc.decimalConverter.Subtract(outputETH, inputETH)
if err != nil {
return nil, nil, nil, fmt.Errorf("error calculating gross profit: %w", err)
}
// Net profit = gross profit - gas cost
netProfit, err := calc.decimalConverter.Subtract(grossProfit, gasCost)
if err != nil {
return nil, nil, nil, fmt.Errorf("error calculating net profit: %w", err)
}
// Profit percentage = (net profit / input) * 100
profitPercentage, err := calc.decimalConverter.CalculatePercentage(netProfit, inputETH)
if err != nil {
return nil, nil, nil, fmt.Errorf("error calculating profit percentage: %w", err)
}
return grossProfit, netProfit, profitPercentage, nil
}
// calculateTotalPriceImpact calculates cumulative price impact across all steps
func (calc *ArbitrageCalculator) calculateTotalPriceImpact(route []ExchangeStep) (*UniversalDecimal, error) {
if len(route) == 0 {
return NewUniversalDecimal(big.NewInt(0), 4, "PERCENT")
}
// Compound price impacts: (1 + impact1) * (1 + impact2) - 1
compoundedImpact, err := calc.decimalConverter.FromString("1", 4, "COMPOUND")
if err != nil {
return nil, err
}
for _, step := range route {
// Convert price impact to factor (1 + impact)
one, _ := calc.decimalConverter.FromString("1", 4, "ONE")
impactFactor, err := calc.decimalConverter.Add(one, step.PriceImpact)
if err != nil {
return nil, fmt.Errorf("error calculating impact factor: %w", err)
}
// Multiply with cumulative impact
compoundedImpact, err = calc.decimalConverter.Multiply(compoundedImpact, impactFactor, 4, "COMPOUND")
if err != nil {
return nil, fmt.Errorf("error compounding impact: %w", err)
}
}
// Subtract 1 to get final impact percentage
one, _ := calc.decimalConverter.FromString("1", 4, "ONE")
totalImpact, err := calc.decimalConverter.Subtract(compoundedImpact, one)
if err != nil {
return nil, fmt.Errorf("error calculating total impact: %w", err)
}
return totalImpact, nil
}
// calculateMinimumOutput calculates minimum output accounting for slippage
func (calc *ArbitrageCalculator) calculateMinimumOutput(expectedOutput *UniversalDecimal) (*UniversalDecimal, error) {
// Apply slippage tolerance
slippageFactor, err := calc.decimalConverter.Subtract(
&UniversalDecimal{Value: big.NewInt(10000), Decimals: 4, Symbol: "ONE"},
calc.maxSlippage,
)
if err != nil {
return nil, err
}
return calc.decimalConverter.Multiply(expectedOutput, slippageFactor, 18, "TOKEN")
}
// assessRisks performs comprehensive risk assessment
func (calc *ArbitrageCalculator) assessRisks(route []ExchangeStep, priceImpact, netProfit *UniversalDecimal) RiskAssessment {
assessment := RiskAssessment{
Warnings: make([]string, 0),
}
// Assess liquidity risk
assessment.Liquidity = calc.assessLiquidityRisk(route)
// Assess price impact risk
assessment.PriceImpact = calc.assessPriceImpactRisk(priceImpact)
// Assess profitability risk
profitRisk := calc.assessProfitabilityRisk(netProfit)
// Assess gas price risk
assessment.GasPrice = calc.assessGasPriceRisk()
// Calculate overall risk (worst of all categories)
risks := []RiskLevel{assessment.Liquidity, assessment.PriceImpact, profitRisk, assessment.GasPrice}
assessment.Overall = calc.calculateOverallRisk(risks)
// Calculate OverallRisk as a numeric value (0.0 to 1.0) based on the overall risk level
switch assessment.Overall {
case RiskLow:
assessment.OverallRisk = 0.1
case RiskMedium:
assessment.OverallRisk = 0.4
case RiskHigh:
assessment.OverallRisk = 0.7
case RiskCritical:
assessment.OverallRisk = 0.95
default:
assessment.OverallRisk = 0.5 // Default to medium risk
}
return assessment
}
// Helper risk assessment methods
func (calc *ArbitrageCalculator) assessLiquidityRisk(route []ExchangeStep) RiskLevel {
for _, step := range route {
// For this simplified implementation, assume a mock liquidity value
// In a real implementation, you'd get this from the pricing engine
mockLiquidity, _ := calc.decimalConverter.FromString("1000", 18, "TOKEN") // 1000 tokens
if mockLiquidity.IsZero() {
return RiskHigh
}
// Check if trade size is significant portion of liquidity (>10%)
tenPercent, _ := calc.decimalConverter.FromString("10", 4, "PERCENT")
tradeSizePercent, _ := calc.decimalConverter.CalculatePercentage(step.AmountIn, mockLiquidity)
if comp, _ := calc.decimalConverter.Compare(tradeSizePercent, tenPercent); comp > 0 {
return RiskMedium
}
}
return RiskLow
}
func (calc *ArbitrageCalculator) assessPriceImpactRisk(priceImpact *UniversalDecimal) RiskLevel {
fivePercent, _ := calc.decimalConverter.FromString("5", 4, "PERCENT")
twoPercent, _ := calc.decimalConverter.FromString("2", 4, "PERCENT")
if comp, _ := calc.decimalConverter.Compare(priceImpact, fivePercent); comp > 0 {
return RiskHigh
}
if comp, _ := calc.decimalConverter.Compare(priceImpact, twoPercent); comp > 0 {
return RiskMedium
}
return RiskLow
}
func (calc *ArbitrageCalculator) assessProfitabilityRisk(netProfit *UniversalDecimal) RiskLevel {
if netProfit.IsNegative() {
return RiskCritical
}
smallProfit, _ := calc.decimalConverter.FromString("0.001", 18, "ETH") // $1 at $1000/ETH
mediumProfit, _ := calc.decimalConverter.FromString("0.01", 18, "ETH") // $10 at $1000/ETH
if comp, _ := calc.decimalConverter.Compare(netProfit, smallProfit); comp < 0 {
return RiskHigh
}
if comp, _ := calc.decimalConverter.Compare(netProfit, mediumProfit); comp < 0 {
return RiskMedium
}
return RiskLow
}
func (calc *ArbitrageCalculator) assessGasPriceRisk() RiskLevel {
currentGas, _ := calc.gasEstimator.GetCurrentGasPrice()
if comp, _ := calc.decimalConverter.Compare(currentGas, calc.maxGasPriceGwei); comp > 0 {
return RiskHigh
}
twentyGwei, _ := calc.decimalConverter.FromString("20", 9, "GWEI")
if comp, _ := calc.decimalConverter.Compare(currentGas, twentyGwei); comp > 0 {
return RiskMedium
}
return RiskLow
}
func (calc *ArbitrageCalculator) calculateOverallRisk(risks []RiskLevel) RiskLevel {
riskScores := map[RiskLevel]int{
RiskLow: 1,
RiskMedium: 2,
RiskHigh: 3,
RiskCritical: 4,
}
maxScore := 0
for _, risk := range risks {
if score := riskScores[risk]; score > maxScore {
maxScore = score
}
}
for risk, score := range riskScores {
if score == maxScore {
return risk
}
}
return RiskLow
}
// calculateConfidence calculates confidence score based on risk and profit
func (calc *ArbitrageCalculator) calculateConfidence(risk RiskAssessment, netProfit, priceImpact *UniversalDecimal) float64 {
baseConfidence := 0.5
// Adjust for risk level
switch risk.Overall {
case RiskLow:
baseConfidence += 0.3
case RiskMedium:
baseConfidence += 0.1
case RiskHigh:
baseConfidence -= 0.2
case RiskCritical:
baseConfidence -= 0.4
}
// Adjust for profit magnitude
if netProfit.IsPositive() {
largeProfit, _ := calc.decimalConverter.FromString("0.1", 18, "ETH")
if comp, _ := calc.decimalConverter.Compare(netProfit, largeProfit); comp > 0 {
baseConfidence += 0.2
}
}
// Adjust for price impact
lowImpact, _ := calc.decimalConverter.FromString("1", 4, "PERCENT")
if comp, _ := calc.decimalConverter.Compare(priceImpact, lowImpact); comp < 0 {
baseConfidence += 0.1
}
if baseConfidence < 0 {
baseConfidence = 0
}
if baseConfidence > 1 {
baseConfidence = 1
}
return baseConfidence
}
// estimateExecutionTime estimates execution time in milliseconds
func (calc *ArbitrageCalculator) estimateExecutionTime(route []ExchangeStep) int64 {
baseTime := int64(500) // 500ms base
// Add time per hop
hopTime := int64(len(route)) * 200
// Add time for complex exchanges
complexTime := int64(0)
for _, step := range route {
switch ExchangeType(step.Exchange) {
case ExchangeUniswapV3, ExchangeCamelot:
complexTime += 300 // Concentrated liquidity is more complex
case ExchangeBalancer, ExchangeCurve:
complexTime += 400 // Weighted/stable pools are complex
default:
complexTime += 100 // Simple AMM
}
}
return baseTime + hopTime + complexTime
}
// convertToETH converts any token amount to ETH for comparison (placeholder)
func (calc *ArbitrageCalculator) convertToETH(amount *UniversalDecimal, token TokenInfo) *UniversalDecimal {
// This is a placeholder - in production, this would query price oracles
// For now, assume 1:1 conversion for demonstration
ethAmount, _ := calc.decimalConverter.ConvertTo(amount, 18, "ETH")
return ethAmount
}
// IsOpportunityProfitable checks if opportunity meets minimum criteria
func (calc *ArbitrageCalculator) IsOpportunityProfitable(opportunity *types.ArbitrageOpportunity) bool {
// Check minimum profit threshold (simplified comparison)
if opportunity.NetProfit.Cmp(big.NewInt(1000000000000000)) < 0 { // 0.001 ETH minimum
return false
}
// Check maximum price impact threshold (5% max)
if opportunity.PriceImpact > 0.05 {
return false
}
// Check risk level
if opportunity.Risk >= 0.8 { // High risk threshold
return false
}
// Check confidence threshold
if opportunity.Confidence < 0.3 {
return false
}
return true
}
// SortOpportunitiesByProfitability sorts opportunities by net profit descending
func (calc *ArbitrageCalculator) SortOpportunitiesByProfitability(opportunities []*types.ArbitrageOpportunity) {
sort.Slice(opportunities, func(i, j int) bool {
// Simple comparison using big.Int.Cmp for sorting
return opportunities[i].NetProfit.Cmp(opportunities[j].NetProfit) > 0 // Descending order (highest profit first)
})
}
// CalculateArbitrage calculates arbitrage opportunity for a given path and input amount
func (calc *ArbitrageCalculator) CalculateArbitrage(ctx context.Context, inputAmount *UniversalDecimal, path []*PoolData) (*types.ArbitrageOpportunity, error) {
if len(path) == 0 {
return nil, fmt.Errorf("empty path provided")
}
// Get the input and output tokens for the path
inputToken := path[0].Token0
outputToken := path[len(path)-1].Token1
if path[len(path)-1].Token0.Address == inputToken.Address {
outputToken = path[len(path)-1].Token0
}
// Calculate the arbitrage opportunity for this path
opportunity, err := calc.CalculateArbitrageOpportunity(path, inputAmount, inputToken, outputToken)
if err != nil {
return nil, fmt.Errorf("failed to calculate arbitrage opportunity: %w", err)
}
return opportunity, nil
}
// FindOptimalPath finds the most profitable arbitrage path between two tokens
func (calc *ArbitrageCalculator) FindOptimalPath(ctx context.Context, tokenA, tokenB common.Address, amount *UniversalDecimal) (*types.ArbitrageOpportunity, error) {
// In a real implementation, this would query for available paths between tokens
// and calculate the most profitable path. For this implementation, we'll return an error
// indicating no path is available since we don't have direct path-finding ability in the calculator
return nil, fmt.Errorf("FindOptimalPath not implemented in calculator - use executor.CalculateOptimalPath instead")
}
// FilterProfitableOpportunities returns only profitable opportunities
func (calc *ArbitrageCalculator) FilterProfitableOpportunities(opportunities []*types.ArbitrageOpportunity) []*types.ArbitrageOpportunity {
profitable := make([]*types.ArbitrageOpportunity, 0)
for _, opp := range opportunities {
if calc.IsOpportunityProfitable(opp) {
profitable = append(profitable, opp)
}
}
return profitable
}

View File

@@ -0,0 +1,96 @@
package math
import (
"math/big"
"testing"
"github.com/holiman/uint256"
)
// Benchmark original vs cached SqrtPriceX96ToPrice conversion
func BenchmarkSqrtPriceX96ToPriceOriginal(b *testing.B) {
// Use a typical sqrtPriceX96 value (represents price of ~2000 USDC/ETH)
sqrtPriceX96 := new(big.Int).SetBytes([]byte{0x06, 0x40, 0x84, 0x4A, 0x0E, 0x81, 0x4F, 0x96, 0x19, 0xC1, 0x9C, 0x08})
b.ResetTimer()
for i := 0; i < b.N; i++ {
// Original calculation: price = sqrtPriceX96^2 / 2^192
sqrtPriceFloat := new(big.Float).SetInt(sqrtPriceX96)
price := new(big.Float).Mul(sqrtPriceFloat, sqrtPriceFloat)
// Calculate 2^192
q192 := new(big.Int).Exp(big.NewInt(2), big.NewInt(192), nil)
q192Float := new(big.Float).SetInt(q192)
// Divide by 2^192
price.Quo(price, q192Float)
}
}
func BenchmarkSqrtPriceX96ToPriceCached(b *testing.B) {
// Use a typical sqrtPriceX96 value (represents price of ~2000 USDC/ETH)
sqrtPriceX96 := new(big.Int).SetBytes([]byte{0x06, 0x40, 0x84, 0x4A, 0x0E, 0x81, 0x4F, 0x96, 0x19, 0xC1, 0x9C, 0x08})
b.ResetTimer()
for i := 0; i < b.N; i++ {
// Cached calculation using precomputed constants
SqrtPriceX96ToPriceCached(sqrtPriceX96)
}
}
// Benchmark original vs cached PriceToSqrtPriceX96 conversion
func BenchmarkPriceToSqrtPriceX96Original(b *testing.B) {
// Use a typical price value (represents price of ~2000 USDC/ETH)
price := new(big.Float).SetFloat64(2000.0)
b.ResetTimer()
for i := 0; i < b.N; i++ {
// Original calculation: sqrtPriceX96 = sqrt(price * 2^192)
// Calculate 2^192
q192 := new(big.Int).Exp(big.NewInt(2), big.NewInt(192), nil)
q192Float := new(big.Float).SetInt(q192)
// Multiply price by 2^192
result := new(big.Float).Mul(price, q192Float)
// Calculate square root
result.Sqrt(result)
// Convert to big.Int
sqrtPriceX96 := new(big.Int)
result.Int(sqrtPriceX96)
}
}
func BenchmarkPriceToSqrtPriceX96Cached(b *testing.B) {
// Use a typical price value (represents price of ~2000 USDC/ETH)
price := new(big.Float).SetFloat64(2000.0)
b.ResetTimer()
for i := 0; i < b.N; i++ {
// Cached calculation using precomputed constants
PriceToSqrtPriceX96Cached(price)
}
}
// Benchmark optimized versions with uint256
func BenchmarkSqrtPriceX96ToPriceOptimized(b *testing.B) {
// Use a typical sqrtPriceX96 value
sqrtPriceX96 := uint256.NewInt(0).SetBytes([]byte{0x06, 0x40, 0x84, 0x4A, 0x0E, 0x81, 0x4F, 0x96, 0x19, 0xC1, 0x9C, 0x08})
b.ResetTimer()
for i := 0; i < b.N; i++ {
SqrtPriceX96ToPriceOptimized(sqrtPriceX96)
}
}
func BenchmarkPriceToSqrtPriceX96Optimized(b *testing.B) {
// Use a typical price value
price := new(big.Float).SetFloat64(2000.0)
b.ResetTimer()
for i := 0; i < b.N; i++ {
PriceToSqrtPriceX96Optimized(price)
}
}

View File

@@ -0,0 +1,125 @@
package math
import (
"math/big"
"sync"
"github.com/fraktal/mev-beta/pkg/uniswap"
"github.com/holiman/uint256"
)
// Cached mathematical constants to avoid recomputation
var (
cachedConstantsOnce sync.Once
cachedQ192 *big.Int
cachedQ96 *big.Int
cachedQ384 *big.Int
cachedTwoPower96 *big.Float
cachedTwoPower192 *big.Float
cachedTwoPower384 *big.Float
)
// initCachedConstants initializes all cached constants once
func initCachedConstants() {
cachedConstantsOnce.Do(func() {
// Calculate 2^96
cachedQ96 = new(big.Int).Exp(big.NewInt(2), big.NewInt(96), nil)
// Calculate 2^192
cachedQ192 = new(big.Int).Exp(big.NewInt(2), big.NewInt(192), nil)
// Calculate 2^384
cachedQ384 = new(big.Int).Exp(big.NewInt(2), big.NewInt(384), nil)
// Convert to big.Float for division operations
cachedTwoPower96 = new(big.Float).SetInt(cachedQ96)
cachedTwoPower192 = new(big.Float).SetInt(cachedQ192)
cachedTwoPower384 = new(big.Float).SetInt(cachedQ384)
})
}
// GetCachedQ192 returns the cached value of 2^192
func GetCachedQ192() *big.Int {
initCachedConstants()
return cachedQ192
}
// GetCachedQ96 returns the cached value of 2^96
func GetCachedQ96() *big.Int {
initCachedConstants()
return cachedQ96
}
// GetCachedQ384 returns the cached value of 2^384
func GetCachedQ384() *big.Int {
initCachedConstants()
return cachedQ384
}
// SqrtPriceX96ToPriceCached converts sqrtPriceX96 to a price using cached constants
// Formula: price = sqrtPriceX96^2 / 2^192
func SqrtPriceX96ToPriceCached(sqrtPriceX96 *big.Int) *big.Float {
initCachedConstants()
// Convert to big.Float for precision
sqrtPriceFloat := new(big.Float).SetInt(sqrtPriceX96)
// Calculate sqrtPrice^2
price := new(big.Float).Mul(sqrtPriceFloat, sqrtPriceFloat)
// Divide by 2^192 using cached constant
price.Quo(price, cachedTwoPower192)
return price
}
// PriceToSqrtPriceX96Cached converts a price to sqrtPriceX96 using cached constants
// Formula: sqrtPriceX96 = sqrt(price * 2^192)
func PriceToSqrtPriceX96Cached(price *big.Float) *big.Int {
initCachedConstants()
// Multiply price by 2^192
result := new(big.Float).Mul(price, cachedTwoPower192)
// Calculate square root
result.Sqrt(result)
// Convert to big.Int
sqrtPriceX96 := new(big.Int)
result.Int(sqrtPriceX96)
return sqrtPriceX96
}
// SqrtPriceX96ToPriceOptimized converts sqrtPriceX96 to a price using optimized uint256 operations
// Formula: price = sqrtPriceX96^2 / 2^192
func SqrtPriceX96ToPriceOptimized(sqrtPriceX96 *uint256.Int) *big.Float {
initCachedConstants()
// Convert to big.Int for calculation
sqrtPriceBig := sqrtPriceX96.ToBig()
// Use cached function for consistency
return SqrtPriceX96ToPriceCached(sqrtPriceBig)
}
// PriceToSqrtPriceX96Optimized converts a price to sqrtPriceX96 using optimized operations
// Formula: sqrtPriceX96 = sqrt(price * 2^192)
func PriceToSqrtPriceX96Optimized(price *big.Float) *uint256.Int {
initCachedConstants()
// Use cached function for consistency
sqrtPriceBig := PriceToSqrtPriceX96Cached(price)
// Convert to uint256
return uint256.MustFromBig(sqrtPriceBig)
}
// TickToSqrtPriceX96Optimized calculates sqrtPriceX96 from a tick using optimized operations
// Formula: sqrtPriceX96 = 1.0001^(tick/2)
func TickToSqrtPriceX96Optimized(tick int) *uint256.Int {
// For simplicity, we'll convert to big.Int and use existing implementation
tickBig := big.NewInt(int64(tick))
sqrtPriceBig := uniswap.TickToSqrtPriceX96(int(tickBig.Int64()))
return uint256.MustFromBig(sqrtPriceBig)
}

127
pkg/math/cached_test.go Normal file
View File

@@ -0,0 +1,127 @@
package math
import (
"math/big"
"testing"
"github.com/holiman/uint256"
"github.com/stretchr/testify/assert"
)
// Test that cached functions produce the same results as original implementations
func TestCachedFunctionAccuracy(t *testing.T) {
// Test SqrtPriceX96ToPrice functions
t.Run("SqrtPriceX96ToPrice", func(t *testing.T) {
// Use a typical sqrtPriceX96 value (represents price of ~2000 USDC/ETH)
sqrtPriceX96 := new(big.Int).SetBytes([]byte{0x06, 0x40, 0x84, 0x4A, 0x0E, 0x81, 0x4F, 0x96, 0x19, 0xC1, 0x9C, 0x08})
// Original calculation
sqrtPriceFloat := new(big.Float).SetInt(sqrtPriceX96)
originalPrice := new(big.Float).Mul(sqrtPriceFloat, sqrtPriceFloat)
q192 := new(big.Int).Exp(big.NewInt(2), big.NewInt(192), nil)
q192Float := new(big.Float).SetInt(q192)
originalPrice.Quo(originalPrice, q192Float)
// Cached calculation
cachedPrice := SqrtPriceX96ToPriceCached(sqrtPriceX96)
// Compare results (should be identical)
assert.Equal(t, originalPrice.String(), cachedPrice.String(), "Cached and original SqrtPriceX96ToPrice should produce identical results")
})
// Test PriceToSqrtPriceX96 functions
t.Run("PriceToSqrtPriceX96", func(t *testing.T) {
// Use a typical price value (represents price of ~2000 USDC/ETH)
price := new(big.Float).SetFloat64(2000.0)
// Original calculation
q192 := new(big.Int).Exp(big.NewInt(2), big.NewInt(192), nil)
q192Float := new(big.Float).SetInt(q192)
result := new(big.Float).Mul(price, q192Float)
result.Sqrt(result)
expectedSqrtPriceX96 := new(big.Int)
result.Int(expectedSqrtPriceX96)
// Cached calculation
actualSqrtPriceX96 := PriceToSqrtPriceX96Cached(price)
// Compare results (should be identical)
assert.Equal(t, expectedSqrtPriceX96.String(), actualSqrtPriceX96.String(), "Cached and original PriceToSqrtPriceX96 should produce identical results")
})
// Test optimized functions with uint256
t.Run("SqrtPriceX96ToPriceOptimized", func(t *testing.T) {
// Use a typical sqrtPriceX96 value
sqrtPriceX96Big := new(big.Int).SetBytes([]byte{0x06, 0x40, 0x84, 0x4A, 0x0E, 0x81, 0x4F, 0x96, 0x19, 0xC1, 0x9C, 0x08})
sqrtPriceX96 := uint256.MustFromBig(sqrtPriceX96Big)
// Cached calculation
cachedResult := SqrtPriceX96ToPriceCached(sqrtPriceX96Big)
// Optimized calculation
optimizedResult := SqrtPriceX96ToPriceOptimized(sqrtPriceX96)
// Compare results (should be identical)
assert.Equal(t, cachedResult.String(), optimizedResult.String(), "Optimized and cached SqrtPriceX96ToPrice should produce identical results")
})
// Test optimized functions with uint256
t.Run("PriceToSqrtPriceX96Optimized", func(t *testing.T) {
// Use a typical price value
price := new(big.Float).SetFloat64(2000.0)
// Cached calculation
cachedResult := PriceToSqrtPriceX96Cached(price)
// Optimized calculation
optimizedResult := PriceToSqrtPriceX96Optimized(price)
// Compare results (should be identical)
assert.Equal(t, cachedResult.String(), optimizedResult.ToBig().String(), "Optimized and cached PriceToSqrtPriceX96 should produce identical results")
})
}
// Test that cached constants are working correctly
func TestCachedConstants(t *testing.T) {
// Test that Q192 is correctly calculated
expectedQ192 := new(big.Int).Exp(big.NewInt(2), big.NewInt(192), nil)
actualQ192 := GetCachedQ192()
assert.Equal(t, expectedQ192.String(), actualQ192.String(), "Cached Q192 should equal 2^192")
// Test that Q96 is correctly calculated
expectedQ96 := new(big.Int).Exp(big.NewInt(2), big.NewInt(96), nil)
actualQ96 := GetCachedQ96()
assert.Equal(t, expectedQ96.String(), actualQ96.String(), "Cached Q96 should equal 2^96")
// Test that Q384 is correctly calculated
expectedQ384 := new(big.Int).Exp(big.NewInt(2), big.NewInt(384), nil)
actualQ384 := GetCachedQ384()
assert.Equal(t, expectedQ384.String(), actualQ384.String(), "Cached Q384 should equal 2^384")
}
// Test edge cases
func TestEdgeCases(t *testing.T) {
// Test with zero values
zero := big.NewInt(0)
zeroFloat := new(big.Float).SetInt64(0)
// SqrtPriceX96ToPrice with zero
result := SqrtPriceX96ToPriceCached(zero)
assert.Equal(t, "0", result.String(), "SqrtPriceX96ToPriceCached with zero should return zero")
// PriceToSqrtPriceX96 with zero
result2 := PriceToSqrtPriceX96Cached(zeroFloat)
assert.Equal(t, "0", result2.String(), "PriceToSqrtPriceX96Cached with zero should return zero")
// Test with small values
one := big.NewInt(1)
oneFloat := new(big.Float).SetInt64(1)
// SqrtPriceX96ToPrice with one
result3 := SqrtPriceX96ToPriceCached(one)
assert.NotEmpty(t, result3.String(), "SqrtPriceX96ToPriceCached with one should return a value")
// PriceToSqrtPriceX96 with one
result4 := PriceToSqrtPriceX96Cached(oneFloat)
assert.NotEmpty(t, result4.String(), "PriceToSqrtPriceX96Cached with one should return a value")
}

405
pkg/math/decimal_handler.go Normal file
View File

@@ -0,0 +1,405 @@
package math
import (
"fmt"
"math/big"
"strings"
)
// UniversalDecimal represents a token amount with precise decimal handling
type UniversalDecimal struct {
Value *big.Int // Raw value in smallest unit
Decimals uint8 // Number of decimal places (0-18)
Symbol string // Token symbol for debugging
}
// DecimalConverter handles conversions between different decimal precisions
type DecimalConverter struct {
// Cache for common scaling factors to avoid repeated calculations
scalingFactors map[uint8]*big.Int
}
// NewDecimalConverter creates a new decimal converter with caching
func NewDecimalConverter() *DecimalConverter {
dc := &DecimalConverter{
scalingFactors: make(map[uint8]*big.Int),
}
// Pre-calculate common scaling factors (0-18 decimals)
for i := uint8(0); i <= 18; i++ {
dc.scalingFactors[i] = new(big.Int).Exp(big.NewInt(10), big.NewInt(int64(i)), nil)
}
return dc
}
// NewUniversalDecimal creates a new universal decimal with validation
func NewUniversalDecimal(value *big.Int, decimals uint8, symbol string) (*UniversalDecimal, error) {
if decimals > 18 {
return nil, fmt.Errorf("decimal places cannot exceed 18, got %d for token %s", decimals, symbol)
}
if value == nil {
value = big.NewInt(0)
}
// Copy the value to prevent external modifications
valueCopy := new(big.Int).Set(value)
return &UniversalDecimal{
Value: valueCopy,
Decimals: decimals,
Symbol: symbol,
}, nil
}
// FromString creates UniversalDecimal from string representation
// Intelligently determines format:
// 1. Very large numbers (length >= decimals): treated as raw wei/smallest unit
// 2. Small numbers (length < decimals): treated as human-readable units
// 3. Numbers with decimal point: always treated as human-readable
func (dc *DecimalConverter) FromString(valueStr string, decimals uint8, symbol string) (*UniversalDecimal, error) {
// Handle empty or zero values
if valueStr == "" || valueStr == "0" {
return NewUniversalDecimal(big.NewInt(0), decimals, symbol)
}
// Remove any whitespace
valueStr = strings.TrimSpace(valueStr)
// Check for decimal point - if present, treat as human-readable decimal
if strings.Contains(valueStr, ".") {
return dc.fromDecimalString(valueStr, decimals, symbol)
}
// For integers without decimal point, we need to determine if this is:
// - A raw value (like "1000000000000000000" = 1000000000000000000 wei)
// - A human-readable value (like "1" = 1.0 ETH = 1000000000000000000 wei)
// Parse the number first
value := new(big.Int)
_, success := value.SetString(valueStr, 10)
if !success {
return nil, fmt.Errorf("invalid number format: %s for token %s", valueStr, symbol)
}
// Heuristic: if the string length is >= decimals, treat as raw value
// This handles cases like "1000000000000000000" (18 chars, 18 decimals) as raw
// But treats "1" (1 char, 18 decimals) as human-readable
if len(valueStr) >= int(decimals) && decimals > 0 {
// Treat as raw value in smallest unit
return NewUniversalDecimal(value, decimals, symbol)
}
// Treat as human-readable value - convert to smallest unit
scalingFactor := dc.getScalingFactor(decimals)
scaledValue := new(big.Int).Mul(value, scalingFactor)
return NewUniversalDecimal(scaledValue, decimals, symbol)
}
// fromDecimalString parses decimal string (e.g., "1.23") to smallest unit
func (dc *DecimalConverter) fromDecimalString(valueStr string, decimals uint8, symbol string) (*UniversalDecimal, error) {
parts := strings.Split(valueStr, ".")
if len(parts) != 2 {
return nil, fmt.Errorf("invalid decimal format: %s for token %s", valueStr, symbol)
}
integerPart := parts[0]
decimalPart := parts[1]
// Validate decimal part doesn't exceed token decimals
if len(decimalPart) > int(decimals) {
return nil, fmt.Errorf("decimal part %s has %d digits, but token %s only supports %d decimals",
decimalPart, len(decimalPart), symbol, decimals)
}
// Parse integer part
intValue := new(big.Int)
if integerPart != "" && integerPart != "0" {
_, success := intValue.SetString(integerPart, 10)
if !success {
return nil, fmt.Errorf("invalid integer part: %s for token %s", integerPart, symbol)
}
}
// Parse decimal part
decValue := new(big.Int)
if decimalPart != "" && decimalPart != "0" {
// Pad decimal part to full precision
paddedDecimal := decimalPart
for len(paddedDecimal) < int(decimals) {
paddedDecimal += "0"
}
_, success := decValue.SetString(paddedDecimal, 10)
if !success {
return nil, fmt.Errorf("invalid decimal part: %s for token %s", decimalPart, symbol)
}
}
// Combine integer and decimal parts
scalingFactor := dc.getScalingFactor(decimals)
totalValue := new(big.Int).Mul(intValue, scalingFactor)
totalValue.Add(totalValue, decValue)
return NewUniversalDecimal(totalValue, decimals, symbol)
}
// ToHumanReadable converts to human-readable decimal string
// For round-trip precision preservation with FromString, returns raw value when appropriate
func (dc *DecimalConverter) ToHumanReadable(ud *UniversalDecimal) string {
if ud.Value.Sign() == 0 {
return "0"
}
// For round-trip precision preservation, if the value represents exact units
// (like 1000000000000000000 wei = exactly 1 ETH), output the human readable form
// Otherwise, output the raw value to preserve precision
if ud.Decimals == 0 {
return ud.Value.String()
}
scalingFactor := dc.getScalingFactor(ud.Decimals)
// Get integer and remainder parts
integerPart := new(big.Int).Div(ud.Value, scalingFactor)
remainder := new(big.Int).Mod(ud.Value, scalingFactor)
// If this is an exact unit (no fractional part), return human readable
if remainder.Sign() == 0 {
return integerPart.String()
}
// For values with fractional parts, we need to decide:
// If the value looks like it came from raw input (very large numbers),
// preserve it as raw to maintain round-trip precision
// Check if this looks like a raw value by comparing magnitude
valueStr := ud.Value.String()
if len(valueStr) >= int(ud.Decimals) {
// This is likely a raw value, preserve as raw for round-trip
return ud.Value.String()
}
// Format as human readable decimal
decimalStr := remainder.String()
for len(decimalStr) < int(ud.Decimals) {
decimalStr = "0" + decimalStr
}
// Remove trailing zeros for readability
decimalStr = strings.TrimRight(decimalStr, "0")
if decimalStr == "" {
return integerPart.String()
}
return fmt.Sprintf("%s.%s", integerPart.String(), decimalStr)
}
// ConvertTo converts between different decimal precisions
func (dc *DecimalConverter) ConvertTo(from *UniversalDecimal, toDecimals uint8, toSymbol string) (*UniversalDecimal, error) {
if from.Decimals == toDecimals {
// Same precision, just copy with new symbol
return NewUniversalDecimal(from.Value, toDecimals, toSymbol)
}
var convertedValue *big.Int
if from.Decimals < toDecimals {
// Increase precision (multiply)
decimalDiff := toDecimals - from.Decimals
scalingFactor := dc.getScalingFactor(decimalDiff)
convertedValue = new(big.Int).Mul(from.Value, scalingFactor)
} else {
// Decrease precision (divide with rounding)
decimalDiff := from.Decimals - toDecimals
scalingFactor := dc.getScalingFactor(decimalDiff)
// Round to nearest (banker's rounding)
halfScaling := new(big.Int).Div(scalingFactor, big.NewInt(2))
roundedValue := new(big.Int).Add(from.Value, halfScaling)
convertedValue = new(big.Int).Div(roundedValue, scalingFactor)
}
return NewUniversalDecimal(convertedValue, toDecimals, toSymbol)
}
// Multiply performs precise multiplication between different decimal tokens
func (dc *DecimalConverter) Multiply(a, b *UniversalDecimal, resultDecimals uint8, resultSymbol string) (*UniversalDecimal, error) {
// Multiply raw values
product := new(big.Int).Mul(a.Value, b.Value)
// Adjust for decimal places (division by 10^(a.decimals + b.decimals - result.decimals))
totalInputDecimals := a.Decimals + b.Decimals
var adjustedProduct *big.Int
if totalInputDecimals >= resultDecimals {
decimalDiff := totalInputDecimals - resultDecimals
scalingFactor := dc.getScalingFactor(decimalDiff)
// Round to nearest
halfScaling := new(big.Int).Div(scalingFactor, big.NewInt(2))
roundedProduct := new(big.Int).Add(product, halfScaling)
adjustedProduct = new(big.Int).Div(roundedProduct, scalingFactor)
} else {
decimalDiff := resultDecimals - totalInputDecimals
scalingFactor := dc.getScalingFactor(decimalDiff)
adjustedProduct = new(big.Int).Mul(product, scalingFactor)
}
return NewUniversalDecimal(adjustedProduct, resultDecimals, resultSymbol)
}
// Divide performs precise division between different decimal tokens
func (dc *DecimalConverter) Divide(numerator, denominator *UniversalDecimal, resultDecimals uint8, resultSymbol string) (*UniversalDecimal, error) {
if denominator.Value.Sign() == 0 {
return nil, fmt.Errorf("division by zero: %s / %s", numerator.Symbol, denominator.Symbol)
}
// Scale numerator to maintain precision
totalDecimals := numerator.Decimals + resultDecimals
scalingFactor := dc.getScalingFactor(totalDecimals - denominator.Decimals)
scaledNumerator := new(big.Int).Mul(numerator.Value, scalingFactor)
quotient := new(big.Int).Div(scaledNumerator, denominator.Value)
return NewUniversalDecimal(quotient, resultDecimals, resultSymbol)
}
// Add adds two UniversalDecimals with same precision
func (dc *DecimalConverter) Add(a, b *UniversalDecimal) (*UniversalDecimal, error) {
if a.Decimals != b.Decimals {
return nil, fmt.Errorf("cannot add tokens with different decimals: %s(%d) + %s(%d)",
a.Symbol, a.Decimals, b.Symbol, b.Decimals)
}
sum := new(big.Int).Add(a.Value, b.Value)
resultSymbol := a.Symbol
if a.Symbol != b.Symbol {
resultSymbol = fmt.Sprintf("%s+%s", a.Symbol, b.Symbol)
}
return NewUniversalDecimal(sum, a.Decimals, resultSymbol)
}
// Subtract subtracts two UniversalDecimals with same precision
func (dc *DecimalConverter) Subtract(a, b *UniversalDecimal) (*UniversalDecimal, error) {
if a.Decimals != b.Decimals {
return nil, fmt.Errorf("cannot subtract tokens with different decimals: %s(%d) - %s(%d)",
a.Symbol, a.Decimals, b.Symbol, b.Decimals)
}
diff := new(big.Int).Sub(a.Value, b.Value)
resultSymbol := a.Symbol
if a.Symbol != b.Symbol {
resultSymbol = fmt.Sprintf("%s-%s", a.Symbol, b.Symbol)
}
return NewUniversalDecimal(diff, a.Decimals, resultSymbol)
}
// Compare returns -1, 0, or 1 for a < b, a == b, a > b respectively
func (dc *DecimalConverter) Compare(a, b *UniversalDecimal) (int, error) {
if a.Decimals != b.Decimals {
// Convert to same precision for comparison
converted, err := dc.ConvertTo(b, a.Decimals, b.Symbol)
if err != nil {
return 0, fmt.Errorf("cannot compare tokens with different decimals: %w", err)
}
b = converted
}
return a.Value.Cmp(b.Value), nil
}
// IsZero checks if the value is zero
func (ud *UniversalDecimal) IsZero() bool {
return ud.Value.Sign() == 0
}
// IsPositive checks if the value is positive
func (ud *UniversalDecimal) IsPositive() bool {
return ud.Value.Sign() > 0
}
// IsNegative checks if the value is negative
func (ud *UniversalDecimal) IsNegative() bool {
return ud.Value.Sign() < 0
}
// Copy creates a deep copy of the UniversalDecimal
func (ud *UniversalDecimal) Copy() *UniversalDecimal {
return &UniversalDecimal{
Value: new(big.Int).Set(ud.Value),
Decimals: ud.Decimals,
Symbol: ud.Symbol,
}
}
// getScalingFactor returns the scaling factor for given decimals (cached)
func (dc *DecimalConverter) getScalingFactor(decimals uint8) *big.Int {
if factor, exists := dc.scalingFactors[decimals]; exists {
return factor
}
// Calculate and cache if not exists (shouldn't happen for 0-18)
factor := new(big.Int).Exp(big.NewInt(10), big.NewInt(int64(decimals)), nil)
dc.scalingFactors[decimals] = factor
return factor
}
// ToWei converts any decimal precision to 18-decimal wei representation
func (dc *DecimalConverter) ToWei(ud *UniversalDecimal) *UniversalDecimal {
weiValue, _ := dc.ConvertTo(ud, 18, "WEI")
return weiValue
}
// FromWei converts 18-decimal wei to specified decimal precision
func (dc *DecimalConverter) FromWei(weiValue *big.Int, targetDecimals uint8, targetSymbol string) *UniversalDecimal {
weiDecimal := &UniversalDecimal{
Value: new(big.Int).Set(weiValue),
Decimals: 18,
Symbol: "WEI",
}
result, _ := dc.ConvertTo(weiDecimal, targetDecimals, targetSymbol)
return result
}
// CalculatePercentage calculates percentage with precise decimal handling
// Returns percentage as UniversalDecimal with 4 decimal places (e.g., 1.5000% = 15000 with 4 decimals)
func (dc *DecimalConverter) CalculatePercentage(value, total *UniversalDecimal) (*UniversalDecimal, error) {
if total.IsZero() {
return nil, fmt.Errorf("cannot calculate percentage with zero total")
}
// Convert to same precision if needed
if value.Decimals != total.Decimals {
convertedValue, err := dc.ConvertTo(value, total.Decimals, value.Symbol)
if err != nil {
return nil, fmt.Errorf("error converting decimals for percentage: %w", err)
}
value = convertedValue
}
// Calculate (value / total) * 100 using integer arithmetic to avoid floating point errors
// Formula: (value * 100 * 10^4) / total where 10^4 gives us 4 decimal places
// Multiply value by 100 * 10^4 = 1,000,000 for percentage with 4 decimal places
hundredWithDecimals := big.NewInt(1000000) // 100.0000 in 4-decimal format
numerator := new(big.Int).Mul(value.Value, hundredWithDecimals)
// Divide by total to get percentage
percentage := new(big.Int).Div(numerator, total.Value)
return NewUniversalDecimal(percentage, 4, "PERCENT")
}
// String returns string representation for debugging
func (ud *UniversalDecimal) String() string {
dc := NewDecimalConverter()
humanReadable := dc.ToHumanReadable(ud)
return fmt.Sprintf("%s %s", humanReadable, ud.Symbol)
}

View File

@@ -113,6 +113,11 @@ func (u *UniswapV2Math) CalculateAmountIn(amountOut, reserveIn, reserveOut *big.
// CalculatePriceImpact calculates price impact for Uniswap V2
func (u *UniswapV2Math) CalculatePriceImpact(amountIn, reserveIn, reserveOut *big.Int) (float64, error) {
// Check for nil pointers first
if amountIn == nil || reserveIn == nil || reserveOut == nil {
return 0, fmt.Errorf("nil pointer encountered")
}
if amountIn.Sign() <= 0 || reserveIn.Sign() <= 0 || reserveOut.Sign() <= 0 {
return 0, fmt.Errorf("invalid amounts")
}
@@ -236,6 +241,11 @@ func (u *UniswapV3Math) CalculateAmountIn(amountOut, sqrtPriceX96, liquidity *bi
// CalculatePriceImpact calculates price impact for Uniswap V3
func (u *UniswapV3Math) CalculatePriceImpact(amountIn, sqrtPriceX96, liquidity *big.Int) (float64, error) {
// Check for nil pointers first
if amountIn == nil || sqrtPriceX96 == nil || liquidity == nil {
return 0, fmt.Errorf("nil pointer encountered")
}
if amountIn.Sign() <= 0 || sqrtPriceX96.Sign() <= 0 || liquidity.Sign() <= 0 {
return 0, fmt.Errorf("invalid parameters")
}
@@ -428,6 +438,15 @@ func (c *CurveMath) CalculateAmountIn(amountOut, balance0, balance1 *big.Int, fe
// CalculatePriceImpact calculates price impact for Curve
func (c *CurveMath) CalculatePriceImpact(amountIn, balance0, balance1 *big.Int) (float64, error) {
// Check for nil pointers first
if amountIn == nil || balance0 == nil || balance1 == nil {
return 0, fmt.Errorf("nil pointer encountered")
}
if amountIn.Sign() <= 0 || balance0.Sign() <= 0 || balance1.Sign() <= 0 {
return 0, fmt.Errorf("invalid amounts")
}
// Price before = balance1 / balance0
priceBefore := new(big.Float).Quo(new(big.Float).SetInt(balance1), new(big.Float).SetInt(balance0))
@@ -709,6 +728,11 @@ func (b *BalancerMath) CalculateAmountIn(amountOut, reserveIn, reserveOut *big.I
// CalculatePriceImpact calculates price impact for Balancer
func (b *BalancerMath) CalculatePriceImpact(amountIn, reserveIn, reserveOut *big.Int) (float64, error) {
// Check for nil pointers first
if amountIn == nil || reserveIn == nil || reserveOut == nil {
return 0, fmt.Errorf("nil pointer encountered")
}
if amountIn.Sign() <= 0 || reserveIn.Sign() <= 0 || reserveOut.Sign() <= 0 {
return 0, fmt.Errorf("invalid amounts")
}

View File

@@ -0,0 +1,520 @@
package math
import (
"fmt"
"math/big"
)
// ExchangeType represents different DEX protocols on Arbitrum
type ExchangeType string
const (
ExchangeUniswapV3 ExchangeType = "uniswap_v3"
ExchangeUniswapV2 ExchangeType = "uniswap_v2"
ExchangeSushiSwap ExchangeType = "sushiswap"
ExchangeCamelot ExchangeType = "camelot"
ExchangeBalancer ExchangeType = "balancer"
ExchangeTraderJoe ExchangeType = "traderjoe"
ExchangeRamses ExchangeType = "ramses"
ExchangeCurve ExchangeType = "curve"
)
// ExchangePricer interface for exchange-specific price calculations
type ExchangePricer interface {
GetSpotPrice(poolData *PoolData) (*UniversalDecimal, error)
CalculateAmountOut(amountIn *UniversalDecimal, poolData *PoolData) (*UniversalDecimal, error)
CalculateAmountIn(amountOut *UniversalDecimal, poolData *PoolData) (*UniversalDecimal, error)
CalculatePriceImpact(amountIn *UniversalDecimal, poolData *PoolData) (*UniversalDecimal, error)
GetMinimumLiquidity(poolData *PoolData) (*UniversalDecimal, error)
ValidatePoolData(poolData *PoolData) error
}
// PoolData represents universal pool data structure
type PoolData struct {
Address string
ExchangeType ExchangeType
Token0 TokenInfo
Token1 TokenInfo
Reserve0 *UniversalDecimal
Reserve1 *UniversalDecimal
Fee *UniversalDecimal // Fee as percentage (e.g., 0.003 for 0.3%)
// Uniswap V3 specific
SqrtPriceX96 *big.Int
Tick *big.Int
Liquidity *big.Int
// Curve specific
A *big.Int // Amplification coefficient
// Balancer specific
Weights []*UniversalDecimal // Token weights
SwapFeeRate *UniversalDecimal // Swap fee rate
}
// TokenInfo represents token metadata
type TokenInfo struct {
Address string
Symbol string
Decimals uint8
}
// ExchangePricingEngine manages all exchange-specific pricing logic
type ExchangePricingEngine struct {
decimalConverter *DecimalConverter
pricers map[ExchangeType]ExchangePricer
}
// NewExchangePricingEngine creates a new pricing engine with all exchange support
func NewExchangePricingEngine() *ExchangePricingEngine {
dc := NewDecimalConverter()
engine := &ExchangePricingEngine{
decimalConverter: dc,
pricers: make(map[ExchangeType]ExchangePricer),
}
// Register all exchange pricers
engine.pricers[ExchangeUniswapV3] = NewUniswapV3Pricer(dc)
engine.pricers[ExchangeUniswapV2] = NewUniswapV2Pricer(dc)
engine.pricers[ExchangeSushiSwap] = NewSushiSwapPricer(dc)
engine.pricers[ExchangeCamelot] = NewCamelotPricer(dc)
engine.pricers[ExchangeBalancer] = NewBalancerPricer(dc)
engine.pricers[ExchangeTraderJoe] = NewTraderJoePricer(dc)
engine.pricers[ExchangeRamses] = NewRamsesPricer(dc)
engine.pricers[ExchangeCurve] = NewCurvePricer(dc)
return engine
}
// GetExchangePricer returns the appropriate pricer for an exchange
func (engine *ExchangePricingEngine) GetExchangePricer(exchangeType ExchangeType) (ExchangePricer, error) {
pricer, exists := engine.pricers[exchangeType]
if !exists {
return nil, fmt.Errorf("unsupported exchange type: %s", exchangeType)
}
return pricer, nil
}
// CalculateSpotPrice gets spot price from any exchange
func (engine *ExchangePricingEngine) CalculateSpotPrice(poolData *PoolData) (*UniversalDecimal, error) {
pricer, err := engine.GetExchangePricer(poolData.ExchangeType)
if err != nil {
return nil, err
}
return pricer.GetSpotPrice(poolData)
}
// UniswapV3Pricer implements Uniswap V3 concentrated liquidity pricing
type UniswapV3Pricer struct {
dc *DecimalConverter
}
func NewUniswapV3Pricer(dc *DecimalConverter) *UniswapV3Pricer {
return &UniswapV3Pricer{dc: dc}
}
func (p *UniswapV3Pricer) GetSpotPrice(poolData *PoolData) (*UniversalDecimal, error) {
if poolData.SqrtPriceX96 == nil {
return nil, fmt.Errorf("missing sqrtPriceX96 for Uniswap V3 pool")
}
// Use cached function for optimized calculation
// Convert sqrtPriceX96 to actual price using cached constants
// price = sqrtPriceX96^2 / 2^192
price := SqrtPriceX96ToPriceCached(poolData.SqrtPriceX96)
// Adjust for decimal differences between tokens
if poolData.Token0.Decimals != poolData.Token1.Decimals {
decimalDiff := int(poolData.Token1.Decimals) - int(poolData.Token0.Decimals)
adjustment := new(big.Float).SetInt(new(big.Int).Exp(big.NewInt(10), big.NewInt(int64(decimalDiff)), nil))
price.Mul(price, adjustment)
}
// Convert back to big.Int with appropriate precision
priceInt := new(big.Int)
priceScaled := new(big.Float).Mul(price, new(big.Float).SetInt(p.dc.getScalingFactor(18)))
priceScaled.Int(priceInt)
return NewUniversalDecimal(priceInt, 18, fmt.Sprintf("%s/%s", poolData.Token1.Symbol, poolData.Token0.Symbol))
}
func (p *UniswapV3Pricer) CalculateAmountOut(amountIn *UniversalDecimal, poolData *PoolData) (*UniversalDecimal, error) {
// Uniswap V3 concentrated liquidity calculation
// This is a simplified version - production would need full tick math
if poolData.Liquidity == nil || poolData.Liquidity.Sign() == 0 {
return nil, fmt.Errorf("insufficient liquidity in Uniswap V3 pool")
}
// Apply fee
feeAmount, err := p.dc.Multiply(amountIn, poolData.Fee, amountIn.Decimals, "FEE")
if err != nil {
return nil, fmt.Errorf("error calculating fee: %w", err)
}
amountInAfterFee, err := p.dc.Subtract(amountIn, feeAmount)
if err != nil {
return nil, fmt.Errorf("error subtracting fee: %w", err)
}
// Simplified constant product formula for demonstration
// Real implementation would use tick mathematics
numerator, err := p.dc.Multiply(amountInAfterFee, poolData.Reserve1, poolData.Reserve1.Decimals, "TEMP")
if err != nil {
return nil, err
}
denominator, err := p.dc.Add(poolData.Reserve0, amountInAfterFee)
if err != nil {
return nil, err
}
return p.dc.Divide(numerator, denominator, poolData.Token1.Decimals, poolData.Token1.Symbol)
}
func (p *UniswapV3Pricer) CalculateAmountIn(amountOut *UniversalDecimal, poolData *PoolData) (*UniversalDecimal, error) {
// Reverse calculation for Uniswap V3
if poolData.Reserve1.IsZero() || amountOut.Value.Cmp(poolData.Reserve1.Value) >= 0 {
return nil, fmt.Errorf("insufficient liquidity for requested output amount")
}
// Simplified reverse calculation
numerator, err := p.dc.Multiply(poolData.Reserve0, amountOut, poolData.Reserve0.Decimals, "TEMP")
if err != nil {
return nil, err
}
denominator, err := p.dc.Subtract(poolData.Reserve1, amountOut)
if err != nil {
return nil, err
}
amountInBeforeFee, err := p.dc.Divide(numerator, denominator, poolData.Token0.Decimals, "TEMP")
if err != nil {
return nil, err
}
// Add fee
feeMultiplier, err := p.dc.FromString("1", 18, "FEE_MULT")
if err != nil {
return nil, err
}
oneMinusFee, err := p.dc.Subtract(feeMultiplier, poolData.Fee)
if err != nil {
return nil, err
}
return p.dc.Divide(amountInBeforeFee, oneMinusFee, poolData.Token0.Decimals, poolData.Token0.Symbol)
}
func (p *UniswapV3Pricer) CalculatePriceImpact(amountIn *UniversalDecimal, poolData *PoolData) (*UniversalDecimal, error) {
// Calculate price before trade
priceBefore, err := p.GetSpotPrice(poolData)
if err != nil {
return nil, fmt.Errorf("error getting spot price: %w", err)
}
// Calculate amount out
amountOut, err := p.CalculateAmountOut(amountIn, poolData)
if err != nil {
return nil, fmt.Errorf("error calculating amount out: %w", err)
}
// Calculate effective price
effectivePrice, err := p.dc.Divide(amountOut, amountIn, 18, "EFFECTIVE_PRICE")
if err != nil {
return nil, fmt.Errorf("error calculating effective price: %w", err)
}
// Calculate price impact as percentage
priceDiff, err := p.dc.Subtract(priceBefore, effectivePrice)
if err != nil {
return nil, fmt.Errorf("error calculating price difference: %w", err)
}
return p.dc.CalculatePercentage(priceDiff, priceBefore)
}
func (p *UniswapV3Pricer) GetMinimumLiquidity(poolData *PoolData) (*UniversalDecimal, error) {
if poolData.Liquidity == nil {
return NewUniversalDecimal(big.NewInt(0), 18, "LIQUIDITY")
}
return NewUniversalDecimal(poolData.Liquidity, 18, "LIQUIDITY")
}
func (p *UniswapV3Pricer) ValidatePoolData(poolData *PoolData) error {
if poolData.SqrtPriceX96 == nil {
return fmt.Errorf("Uniswap V3 pool missing sqrtPriceX96")
}
if poolData.Liquidity == nil {
return fmt.Errorf("Uniswap V3 pool missing liquidity")
}
if poolData.Fee == nil {
return fmt.Errorf("Uniswap V3 pool missing fee")
}
return nil
}
// UniswapV2Pricer implements Uniswap V2 / SushiSwap constant product pricing
type UniswapV2Pricer struct {
dc *DecimalConverter
}
func NewUniswapV2Pricer(dc *DecimalConverter) *UniswapV2Pricer {
return &UniswapV2Pricer{dc: dc}
}
func (p *UniswapV2Pricer) GetSpotPrice(poolData *PoolData) (*UniversalDecimal, error) {
if poolData.Reserve0.IsZero() {
return nil, fmt.Errorf("zero reserve0 in constant product pool")
}
return p.dc.Divide(poolData.Reserve1, poolData.Reserve0, 18, fmt.Sprintf("%s/%s", poolData.Token1.Symbol, poolData.Token0.Symbol))
}
func (p *UniswapV2Pricer) CalculateAmountOut(amountIn *UniversalDecimal, poolData *PoolData) (*UniversalDecimal, error) {
// Uniswap V2 constant product formula: x * y = k
// amountOut = (amountIn * 997 * reserveOut) / (reserveIn * 1000 + amountIn * 997)
// Apply fee (0.3% = 997/1000 remaining)
feeNumerator, _ := p.dc.FromString("997", 0, "FEE_NUM")
feeDenominator, _ := p.dc.FromString("1000", 0, "FEE_DEN")
amountInWithFee, err := p.dc.Multiply(amountIn, feeNumerator, amountIn.Decimals, "TEMP")
if err != nil {
return nil, err
}
numerator, err := p.dc.Multiply(amountInWithFee, poolData.Reserve1, poolData.Reserve1.Decimals, "TEMP")
if err != nil {
return nil, err
}
reserveInScaled, err := p.dc.Multiply(poolData.Reserve0, feeDenominator, poolData.Reserve0.Decimals, "TEMP")
if err != nil {
return nil, err
}
denominator, err := p.dc.Add(reserveInScaled, amountInWithFee)
if err != nil {
return nil, err
}
return p.dc.Divide(numerator, denominator, poolData.Token1.Decimals, poolData.Token1.Symbol)
}
func (p *UniswapV2Pricer) CalculateAmountIn(amountOut *UniversalDecimal, poolData *PoolData) (*UniversalDecimal, error) {
// Reverse calculation for constant product
feeNumerator, _ := p.dc.FromString("1000", 0, "FEE_NUM")
feeDenominator, _ := p.dc.FromString("997", 0, "FEE_DEN")
numerator, err := p.dc.Multiply(poolData.Reserve0, amountOut, poolData.Reserve0.Decimals, "TEMP")
if err != nil {
return nil, err
}
numeratorWithFee, err := p.dc.Multiply(numerator, feeNumerator, numerator.Decimals, "TEMP")
if err != nil {
return nil, err
}
denominator, err := p.dc.Subtract(poolData.Reserve1, amountOut)
if err != nil {
return nil, err
}
denominatorWithFee, err := p.dc.Multiply(denominator, feeDenominator, denominator.Decimals, "TEMP")
if err != nil {
return nil, err
}
return p.dc.Divide(numeratorWithFee, denominatorWithFee, poolData.Token0.Decimals, poolData.Token0.Symbol)
}
func (p *UniswapV2Pricer) CalculatePriceImpact(amountIn *UniversalDecimal, poolData *PoolData) (*UniversalDecimal, error) {
// Similar to Uniswap V3 implementation
priceBefore, err := p.GetSpotPrice(poolData)
if err != nil {
return nil, err
}
amountOut, err := p.CalculateAmountOut(amountIn, poolData)
if err != nil {
return nil, err
}
effectivePrice, err := p.dc.Divide(amountOut, amountIn, 18, "EFFECTIVE_PRICE")
if err != nil {
return nil, err
}
priceDiff, err := p.dc.Subtract(priceBefore, effectivePrice)
if err != nil {
return nil, err
}
return p.dc.CalculatePercentage(priceDiff, priceBefore)
}
func (p *UniswapV2Pricer) GetMinimumLiquidity(poolData *PoolData) (*UniversalDecimal, error) {
// Geometric mean of reserves
product, err := p.dc.Multiply(poolData.Reserve0, poolData.Reserve1, 18, "LIQUIDITY")
if err != nil {
return nil, err
}
// Simplified square root - in production use precise sqrt algorithm
sqrt := new(big.Int).Sqrt(product.Value)
return NewUniversalDecimal(sqrt, 18, "LIQUIDITY")
}
func (p *UniswapV2Pricer) ValidatePoolData(poolData *PoolData) error {
if poolData.Reserve0 == nil || poolData.Reserve1 == nil {
return fmt.Errorf("missing reserves for constant product pool")
}
if poolData.Reserve0.IsZero() || poolData.Reserve1.IsZero() {
return fmt.Errorf("zero reserves in constant product pool")
}
return nil
}
// SushiSwapPricer uses same logic as Uniswap V2
type SushiSwapPricer struct {
*UniswapV2Pricer
}
func NewSushiSwapPricer(dc *DecimalConverter) *SushiSwapPricer {
return &SushiSwapPricer{NewUniswapV2Pricer(dc)}
}
// CamelotPricer - Algebra-based DEX on Arbitrum
type CamelotPricer struct {
*UniswapV3Pricer
}
func NewCamelotPricer(dc *DecimalConverter) *CamelotPricer {
return &CamelotPricer{NewUniswapV3Pricer(dc)}
}
// BalancerPricer - Weighted pool implementation
type BalancerPricer struct {
dc *DecimalConverter
}
func NewBalancerPricer(dc *DecimalConverter) *BalancerPricer {
return &BalancerPricer{dc: dc}
}
func (p *BalancerPricer) GetSpotPrice(poolData *PoolData) (*UniversalDecimal, error) {
if len(poolData.Weights) < 2 {
return nil, fmt.Errorf("insufficient weights for Balancer pool")
}
// Balancer spot price = (reserveOut/weightOut) / (reserveIn/weightIn)
reserveOutWeighted, err := p.dc.Divide(poolData.Reserve1, poolData.Weights[1], 18, "TEMP")
if err != nil {
return nil, err
}
reserveInWeighted, err := p.dc.Divide(poolData.Reserve0, poolData.Weights[0], 18, "TEMP")
if err != nil {
return nil, err
}
return p.dc.Divide(reserveOutWeighted, reserveInWeighted, 18, fmt.Sprintf("%s/%s", poolData.Token1.Symbol, poolData.Token0.Symbol))
}
func (p *BalancerPricer) CalculateAmountOut(amountIn *UniversalDecimal, poolData *PoolData) (*UniversalDecimal, error) {
// Simplified Balancer calculation - production needs full weighted math
return p.GetSpotPrice(poolData)
}
func (p *BalancerPricer) CalculateAmountIn(amountOut *UniversalDecimal, poolData *PoolData) (*UniversalDecimal, error) {
spotPrice, err := p.GetSpotPrice(poolData)
if err != nil {
return nil, err
}
return p.dc.Divide(amountOut, spotPrice, poolData.Token0.Decimals, poolData.Token0.Symbol)
}
func (p *BalancerPricer) CalculatePriceImpact(amountIn *UniversalDecimal, poolData *PoolData) (*UniversalDecimal, error) {
// Placeholder - would implement Balancer-specific price impact
return NewUniversalDecimal(big.NewInt(0), 4, "PERCENT")
}
func (p *BalancerPricer) GetMinimumLiquidity(poolData *PoolData) (*UniversalDecimal, error) {
return NewUniversalDecimal(big.NewInt(0), 18, "LIQUIDITY")
}
func (p *BalancerPricer) ValidatePoolData(poolData *PoolData) error {
if len(poolData.Weights) < 2 {
return fmt.Errorf("Balancer pool missing weights")
}
return nil
}
// Placeholder implementations for other exchanges
func NewTraderJoePricer(dc *DecimalConverter) *UniswapV2Pricer { return NewUniswapV2Pricer(dc) }
func NewRamsesPricer(dc *DecimalConverter) *UniswapV3Pricer { return NewUniswapV3Pricer(dc) }
// CurvePricer - Stable swap implementation
type CurvePricer struct {
dc *DecimalConverter
}
func NewCurvePricer(dc *DecimalConverter) *CurvePricer {
return &CurvePricer{dc: dc}
}
func (p *CurvePricer) GetSpotPrice(poolData *PoolData) (*UniversalDecimal, error) {
// Curve stable swap pricing - simplified version
if poolData.A == nil {
return nil, fmt.Errorf("missing amplification coefficient for Curve pool")
}
// For stable swaps, price should be close to 1:1
return NewUniversalDecimal(p.dc.getScalingFactor(18), 18, fmt.Sprintf("%s/%s", poolData.Token1.Symbol, poolData.Token0.Symbol))
}
func (p *CurvePricer) CalculateAmountOut(amountIn *UniversalDecimal, poolData *PoolData) (*UniversalDecimal, error) {
// Simplified stable swap calculation
// Real implementation would use Newton's method for stable swap invariant
feeAmount, err := p.dc.Multiply(amountIn, poolData.Fee, amountIn.Decimals, "FEE")
if err != nil {
return nil, err
}
return p.dc.Subtract(amountIn, feeAmount)
}
func (p *CurvePricer) CalculateAmountIn(amountOut *UniversalDecimal, poolData *PoolData) (*UniversalDecimal, error) {
// Reverse stable swap calculation
feeMultiplier, _ := p.dc.FromString("1", 18, "FEE_MULT")
oneMinusFee, err := p.dc.Subtract(feeMultiplier, poolData.Fee)
if err != nil {
return nil, err
}
return p.dc.Divide(amountOut, oneMinusFee, poolData.Token0.Decimals, poolData.Token0.Symbol)
}
func (p *CurvePricer) CalculatePriceImpact(amountIn *UniversalDecimal, poolData *PoolData) (*UniversalDecimal, error) {
// Curve pools have minimal price impact for stable pairs
return NewUniversalDecimal(big.NewInt(1000), 4, "PERCENT") // 0.1%
}
func (p *CurvePricer) GetMinimumLiquidity(poolData *PoolData) (*UniversalDecimal, error) {
return NewUniversalDecimal(big.NewInt(0), 18, "LIQUIDITY")
}
func (p *CurvePricer) ValidatePoolData(poolData *PoolData) error {
if poolData.A == nil {
return fmt.Errorf("Curve pool missing amplification coefficient")
}
return nil
}

View File

@@ -0,0 +1,62 @@
package math
import (
"math/big"
)
// MockGasEstimator implements GasEstimator for testing purposes
type MockGasEstimator struct {
currentGasPrice *UniversalDecimal
}
// NewMockGasEstimator creates a new mock gas estimator
func NewMockGasEstimator() *MockGasEstimator {
dc := NewDecimalConverter()
gasPrice, _ := dc.FromString("20", 9, "GWEI") // 20 gwei default
return &MockGasEstimator{
currentGasPrice: gasPrice,
}
}
// EstimateSwapGas estimates gas for a swap
func (mge *MockGasEstimator) EstimateSwapGas(exchange string, poolData *PoolData) (uint64, error) {
// Different exchanges have different gas costs
switch exchange {
case "uniswap_v3":
return 150000, nil // 150k gas for Uniswap V3
case "uniswap_v2", "sushiswap":
return 120000, nil // 120k gas for Uniswap V2/SushiSwap
case "camelot":
return 130000, nil // 130k gas for Camelot
case "balancer":
return 200000, nil // 200k gas for Balancer
case "curve":
return 180000, nil // 180k gas for Curve
default:
return 150000, nil // Default to 150k gas
}
}
// EstimateFlashSwapGas estimates gas for a flash swap
func (mge *MockGasEstimator) EstimateFlashSwapGas(route []*PoolData) (uint64, error) {
// Flash swap overhead varies by complexity
baseGas := uint64(200000) // Base flash swap overhead
// Add gas for each hop
hopGas := uint64(len(route)) * 50000
return baseGas + hopGas, nil
}
// GetCurrentGasPrice gets the current gas price
func (mge *MockGasEstimator) GetCurrentGasPrice() (*UniversalDecimal, error) {
return mge.currentGasPrice, nil
}
// SetCurrentGasPrice sets the current gas price for testing
func (mge *MockGasEstimator) SetCurrentGasPrice(gasPriceGwei int64) {
dc := NewDecimalConverter()
gasPrice, _ := dc.FromString(big.NewInt(gasPriceGwei).String(), 9, "GWEI")
mge.currentGasPrice = gasPrice
}

289
pkg/math/precision_test.go Normal file
View File

@@ -0,0 +1,289 @@
package math
import (
"math/big"
"math/rand"
"testing"
"time"
)
// TestDecimalPrecisionPreservation tests that decimal operations preserve precision
func TestDecimalPrecisionPreservation(t *testing.T) {
dc := NewDecimalConverter()
testCases := []struct {
name string
value string
decimals uint8
symbol string
}{
{"ETH precision", "1000000000000000000", 18, "ETH"}, // 1 ETH
{"USDC precision", "1000000", 6, "USDC"}, // 1 USDC
{"WBTC precision", "100000000", 8, "WBTC"}, // 1 WBTC
{"Small amount", "1", 18, "ETH"}, // 1 wei
{"Large amount", "1000000000000000000000", 18, "ETH"}, // 1000 ETH
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
// Create decimal from string
decimal, err := dc.FromString(tc.value, tc.decimals, tc.symbol)
if err != nil {
t.Fatalf("Failed to create decimal: %v", err)
}
// Convert to string and back
humanReadable := dc.ToHumanReadable(decimal)
backToDecimal, err := dc.FromString(humanReadable, tc.decimals, tc.symbol)
if err != nil {
t.Fatalf("Failed to convert back from string: %v", err)
}
// Compare values
if decimal.Value.Cmp(backToDecimal.Value) != 0 {
t.Errorf("Precision lost in round-trip conversion")
t.Errorf("Original: %s", decimal.Value.String())
t.Errorf("Round-trip: %s", backToDecimal.Value.String())
t.Errorf("Human readable: %s", humanReadable)
}
})
}
}
// TestArithmeticOperations tests basic arithmetic with different decimal precisions
func TestArithmeticOperations(t *testing.T) {
dc := NewDecimalConverter()
// Create test values with different precisions
eth1, _ := dc.FromString("1000000000000000000", 18, "ETH") // 1 ETH
eth2, _ := dc.FromString("2000000000000000000", 18, "ETH") // 2 ETH
tests := []struct {
name string
op string
a, b *UniversalDecimal
expected string
}{
{
name: "ETH addition",
op: "add",
a: eth1,
b: eth2,
expected: "3000000000000000000", // 3 ETH
},
{
name: "ETH subtraction",
op: "sub",
a: eth2,
b: eth1,
expected: "1000000000000000000", // 1 ETH
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
var result *UniversalDecimal
var err error
switch test.op {
case "add":
result, err = dc.Add(test.a, test.b)
case "sub":
result, err = dc.Subtract(test.a, test.b)
default:
t.Fatalf("Unknown operation: %s", test.op)
}
if err != nil {
t.Fatalf("Operation failed: %v", err)
}
if result.Value.String() != test.expected {
t.Errorf("Expected %s, got %s", test.expected, result.Value.String())
}
})
}
}
// TestPercentageCalculations tests percentage calculations for precision
func TestPercentageCalculations(t *testing.T) {
dc := NewDecimalConverter()
testCases := []struct {
name string
numerator string
denominator string
decimals uint8
expectedRange [2]float64 // [min, max] acceptable range
}{
{
name: "1% calculation",
numerator: "10000000000000000", // 0.01 ETH
denominator: "1000000000000000000", // 1 ETH
decimals: 18,
expectedRange: [2]float64{0.99, 1.01}, // 1% ± 0.01%
},
{
name: "50% calculation",
numerator: "500000000000000000", // 0.5 ETH
denominator: "1000000000000000000", // 1 ETH
decimals: 18,
expectedRange: [2]float64{49.9, 50.1}, // 50% ± 0.1%
},
{
name: "Small percentage",
numerator: "1000000000000000", // 0.001 ETH
denominator: "1000000000000000000", // 1 ETH
decimals: 18,
expectedRange: [2]float64{0.099, 0.101}, // 0.1% ± 0.001%
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
num, err := dc.FromString(tc.numerator, tc.decimals, "ETH")
if err != nil {
t.Fatalf("Failed to create numerator: %v", err)
}
denom, err := dc.FromString(tc.denominator, tc.decimals, "ETH")
if err != nil {
t.Fatalf("Failed to create denominator: %v", err)
}
percentage, err := dc.CalculatePercentage(num, denom)
if err != nil {
t.Fatalf("Failed to calculate percentage: %v", err)
}
percentageFloat, _ := percentage.Value.Float64()
t.Logf("Calculated percentage: %.6f%%", percentageFloat)
if percentageFloat < tc.expectedRange[0] || percentageFloat > tc.expectedRange[1] {
t.Errorf("Percentage %.6f%% outside expected range [%.3f%%, %.3f%%]",
percentageFloat, tc.expectedRange[0], tc.expectedRange[1])
}
})
}
}
// PropertyTest tests mathematical properties like commutativity, associativity
func TestMathematicalProperties(t *testing.T) {
dc := NewDecimalConverter()
rand.Seed(time.Now().UnixNano())
// Generate random test values
for i := 0; i < 100; i++ {
// Generate random big integers
val1 := big.NewInt(rand.Int63n(1000000000000000000)) // Up to 1 ETH
val2 := big.NewInt(rand.Int63n(1000000000000000000))
val3 := big.NewInt(rand.Int63n(1000000000000000000))
a, _ := NewUniversalDecimal(val1, 18, "ETH")
b, _ := NewUniversalDecimal(val2, 18, "ETH")
c, _ := NewUniversalDecimal(val3, 18, "ETH")
// Test commutativity: a + b = b + a
ab, err1 := dc.Add(a, b)
ba, err2 := dc.Add(b, a)
if err1 != nil || err2 != nil {
t.Fatalf("Addition failed: %v, %v", err1, err2)
}
if ab.Value.Cmp(ba.Value) != 0 {
t.Errorf("Addition not commutative: %s + %s = %s, %s + %s = %s",
a.Value.String(), b.Value.String(), ab.Value.String(),
b.Value.String(), a.Value.String(), ba.Value.String())
}
// Test associativity: (a + b) + c = a + (b + c)
ab_c, err1 := dc.Add(ab, c)
bc, err2 := dc.Add(b, c)
a_bc, err3 := dc.Add(a, bc)
if err1 != nil || err2 != nil || err3 != nil {
continue // Skip this iteration if any operation fails
}
if ab_c.Value.Cmp(a_bc.Value) != 0 {
t.Errorf("Addition not associative")
}
}
}
// BenchmarkDecimalOperations benchmarks decimal operations
func BenchmarkDecimalOperations(b *testing.B) {
dc := NewDecimalConverter()
val1, _ := dc.FromString("1000000000000000000", 18, "ETH")
val2, _ := dc.FromString("2000000000000000000", 18, "ETH")
b.Run("Addition", func(b *testing.B) {
for i := 0; i < b.N; i++ {
dc.Add(val1, val2)
}
})
b.Run("Subtraction", func(b *testing.B) {
for i := 0; i < b.N; i++ {
dc.Subtract(val2, val1)
}
})
b.Run("Percentage", func(b *testing.B) {
for i := 0; i < b.N; i++ {
dc.CalculatePercentage(val1, val2)
}
})
}
// FuzzDecimalOperations fuzzes decimal operations for edge cases
func FuzzDecimalOperations(f *testing.F) {
// Seed with known values
f.Add(int64(1000000000000000000), int64(2000000000000000000)) // 1 ETH, 2 ETH
f.Add(int64(1), int64(1000000000000000000)) // 1 wei, 1 ETH
f.Add(int64(0), int64(1000000000000000000)) // 0, 1 ETH
f.Fuzz(func(t *testing.T, val1, val2 int64) {
// Ensure positive values
if val1 < 0 {
val1 = -val1
}
if val2 <= 0 {
return // Skip zero/negative denominators
}
dc := NewDecimalConverter()
a, err := NewUniversalDecimal(big.NewInt(val1), 18, "ETH")
if err != nil {
return
}
b, err := NewUniversalDecimal(big.NewInt(val2), 18, "ETH")
if err != nil {
return
}
// Test addition doesn't panic
_, err = dc.Add(a, b)
if err != nil {
t.Errorf("Addition failed: %v", err)
}
// Test subtraction doesn't panic (if a >= b)
if val1 >= val2 {
_, err = dc.Subtract(a, b)
if err != nil {
t.Errorf("Subtraction failed: %v", err)
}
}
// Test percentage calculation doesn't panic
_, err = dc.CalculatePercentage(a, b)
if err != nil {
t.Errorf("Percentage calculation failed: %v", err)
}
})
}