feat: create v2-prep branch with comprehensive planning

Restructured project for V2 refactor:

**Structure Changes:**
- Moved all V1 code to orig/ folder (preserved with git mv)
- Created docs/planning/ directory
- Added orig/README_V1.md explaining V1 preservation

**Planning Documents:**
- 00_V2_MASTER_PLAN.md: Complete architecture overview
  - Executive summary of critical V1 issues
  - High-level component architecture diagrams
  - 5-phase implementation roadmap
  - Success metrics and risk mitigation

- 07_TASK_BREAKDOWN.md: Atomic task breakdown
  - 99+ hours of detailed tasks
  - Every task < 2 hours (atomic)
  - Clear dependencies and success criteria
  - Organized by implementation phase

**V2 Key Improvements:**
- Per-exchange parsers (factory pattern)
- Multi-layer strict validation
- Multi-index pool cache
- Background validation pipeline
- Comprehensive observability

**Critical Issues Addressed:**
- Zero address tokens (strict validation + cache enrichment)
- Parsing accuracy (protocol-specific parsers)
- No audit trail (background validation channel)
- Inefficient lookups (multi-index cache)
- Stats disconnection (event-driven metrics)

Next Steps:
1. Review planning documents
2. Begin Phase 1: Foundation (P1-001 through P1-010)
3. Implement parsers in Phase 2
4. Build cache system in Phase 3
5. Add validation pipeline in Phase 4
6. Migrate and test in Phase 5

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
Administrator
2025-11-10 10:14:26 +01:00
parent 1773daffe7
commit 803de231ba
411 changed files with 20390 additions and 8680 deletions

View File

@@ -0,0 +1,737 @@
package math
import (
"context"
"fmt"
"math"
"math/big"
"sort"
"time"
"github.com/ethereum/go-ethereum/common"
"github.com/fraktal/mev-beta/pkg/security"
"github.com/fraktal/mev-beta/pkg/types"
)
// Use the canonical ArbitrageOpportunity from types package
// Extended fields for advanced calculations can be added as needed
// ExchangeStep represents one step in the arbitrage execution
type ExchangeStep struct {
Exchange ExchangeType
Pool *PoolData
TokenIn TokenInfo
TokenOut TokenInfo
AmountIn *UniversalDecimal
AmountOut *UniversalDecimal
PriceImpact *UniversalDecimal
EstimatedGas uint64
}
// RiskAssessment evaluates the risk level of an arbitrage opportunity
type RiskAssessment struct {
Overall RiskLevel
Liquidity RiskLevel
PriceImpact RiskLevel
Competition RiskLevel
Slippage RiskLevel
GasPrice RiskLevel
Warnings []string
OverallRisk float64 // Numeric representation of overall risk (0.0 to 1.0)
}
// RiskLevel represents different risk categories
type RiskLevel string
const (
RiskLow RiskLevel = "low"
RiskMedium RiskLevel = "medium"
RiskHigh RiskLevel = "high"
RiskCritical RiskLevel = "critical"
)
// ArbitrageCalculator performs precise arbitrage calculations
type ArbitrageCalculator struct {
pricingEngine *ExchangePricingEngine
decimalConverter *DecimalConverter
gasEstimator GasEstimator
// Configuration
minProfitThreshold *UniversalDecimal
maxPriceImpact *UniversalDecimal
maxSlippage *UniversalDecimal
maxGasPriceGwei *UniversalDecimal
}
// GasEstimator interface for gas cost calculations
type GasEstimator interface {
EstimateSwapGas(exchange ExchangeType, poolData *PoolData) (uint64, error)
EstimateFlashSwapGas(route []*PoolData) (uint64, error)
GetCurrentGasPrice() (*UniversalDecimal, error)
}
// NewArbitrageCalculator creates a new arbitrage calculator
func NewArbitrageCalculator(gasEstimator GasEstimator) *ArbitrageCalculator {
dc := NewDecimalConverter()
// Default configuration
minProfit, _ := dc.FromString("0.01", 18, "ETH") // 0.01 ETH minimum
maxImpact, _ := dc.FromString("0.02", 4, "PERCENT") // 2% max price impact
maxSlip, _ := dc.FromString("0.01", 4, "PERCENT") // 1% max slippage
maxGas, _ := dc.FromString("50", 9, "GWEI") // 50 gwei max gas
return &ArbitrageCalculator{
pricingEngine: NewExchangePricingEngine(),
decimalConverter: dc,
gasEstimator: gasEstimator,
minProfitThreshold: minProfit,
maxPriceImpact: maxImpact,
maxSlippage: maxSlip,
maxGasPriceGwei: maxGas,
}
}
func toDecimalAmount(ud *UniversalDecimal) types.DecimalAmount {
if ud == nil {
return types.DecimalAmount{}
}
return types.DecimalAmount{
Value: ud.Value.String(),
Decimals: ud.Decimals,
Symbol: ud.Symbol,
}
}
// CalculateArbitrageOpportunity performs comprehensive arbitrage analysis
func (calc *ArbitrageCalculator) CalculateArbitrageOpportunity(
path []*PoolData,
inputAmount *UniversalDecimal,
inputToken TokenInfo,
outputToken TokenInfo,
) (*types.ArbitrageOpportunity, error) {
if len(path) == 0 {
return nil, fmt.Errorf("empty arbitrage path")
}
// Step 1: Calculate execution route with amounts
route, err := calc.calculateExecutionRoute(path, inputAmount, inputToken)
if err != nil {
return nil, fmt.Errorf("error calculating execution route: %w", err)
}
// Step 2: Get final output amount
finalOutput := route[len(route)-1].AmountOut
// Step 3: Calculate gas costs
totalGasCost, err := calc.calculateTotalGasCost(route)
if err != nil {
return nil, fmt.Errorf("error calculating gas cost: %w", err)
}
// Step 4: Calculate profits (convert to common denomination - ETH)
grossProfit, netProfit, profitPercentage, err := calc.calculateProfits(
inputAmount, finalOutput, totalGasCost, inputToken, outputToken)
if err != nil {
return nil, fmt.Errorf("error calculating profits: %w", err)
}
// Step 5: Calculate total price impact
totalPriceImpact, err := calc.calculateTotalPriceImpact(route)
if err != nil {
return nil, fmt.Errorf("error calculating price impact: %w", err)
}
// Step 6: Calculate minimum output with slippage (we don't use this in the final result)
_, err = calc.calculateMinimumOutput(finalOutput)
if err != nil {
return nil, fmt.Errorf("error calculating minimum output: %w", err)
}
// Step 7: Assess risks
riskAssessment := calc.assessRisks(route, totalPriceImpact, netProfit)
// Step 8: Calculate confidence and execution time
confidence := calc.calculateConfidence(riskAssessment, netProfit, totalPriceImpact)
executionTime := calc.estimateExecutionTime(route)
// Convert path to string array
pathStrings := make([]string, len(path))
for i, pool := range path {
pathStrings[i] = pool.Address // Address is already a string
}
// Convert pools to string array
poolStrings := make([]string, len(path))
for i, pool := range path {
poolStrings[i] = pool.Address // Address is already a string
}
opportunity := &types.ArbitrageOpportunity{
Path: pathStrings,
Pools: poolStrings,
AmountIn: inputAmount.Value,
RequiredAmount: inputAmount.Value,
Profit: grossProfit.Value,
NetProfit: netProfit.Value,
EstimatedProfit: grossProfit.Value,
GasEstimate: totalGasCost.Value,
ROI: func() float64 {
// Convert percentage from 4-decimal format to actual percentage
f, _ := profitPercentage.Value.Float64()
return f / 10000.0 // Convert from 4-decimal format to actual percentage
}(),
Protocol: "multi", // Default protocol for multi-step arbitrage
ExecutionTime: executionTime,
Confidence: confidence,
PriceImpact: func() float64 {
// Convert percentage from 4-decimal format to actual percentage
f, _ := totalPriceImpact.Value.Float64()
return f / 10000.0 // Convert from 4-decimal format to actual percentage
}(),
MaxSlippage: 0.01, // Default 1% max slippage
TokenIn: common.HexToAddress(inputToken.Address),
TokenOut: common.HexToAddress(outputToken.Address),
Timestamp: time.Now().Unix(),
DetectedAt: time.Now(),
ExpiresAt: time.Now().Add(5 * time.Minute),
Risk: riskAssessment.OverallRisk,
}
opportunity.Quantities = &types.OpportunityQuantities{
AmountIn: toDecimalAmount(inputAmount),
AmountOut: toDecimalAmount(finalOutput),
GrossProfit: toDecimalAmount(grossProfit),
NetProfit: toDecimalAmount(netProfit),
GasCost: toDecimalAmount(totalGasCost),
ProfitPercent: toDecimalAmount(profitPercentage),
PriceImpact: toDecimalAmount(totalPriceImpact),
}
return opportunity, nil
}
// calculateExecutionRoute calculates amounts through each step of the arbitrage
func (calc *ArbitrageCalculator) calculateExecutionRoute(
path []*PoolData,
inputAmount *UniversalDecimal,
inputToken TokenInfo,
) ([]ExchangeStep, error) {
route := make([]ExchangeStep, len(path))
currentAmount := inputAmount
currentToken := inputToken
for i, pool := range path {
// Determine output token for this step
var outputToken TokenInfo
if currentToken.Address == pool.Token0.Address {
outputToken = TokenInfo{
Address: pool.Token1.Address,
Symbol: "TOKEN1", // In a real implementation, you'd fetch the actual symbol
Decimals: 18,
}
} else if currentToken.Address == pool.Token1.Address {
outputToken = TokenInfo{
Address: pool.Token0.Address,
Symbol: "TOKEN0", // In a real implementation, you'd fetch the actual symbol
Decimals: 18,
}
} else {
return nil, fmt.Errorf("token %s not found in pool %s", currentToken.Symbol, pool.Address)
}
// For this simplified implementation, we'll calculate a mock amount out
// In a real implementation, you'd use the pricer's CalculateAmountOut method
amountOut := currentAmount // Simple 1:1 for this example
priceImpact := &UniversalDecimal{Value: big.NewInt(0), Decimals: 4, Symbol: "PERCENT"} // No impact in mock
// Estimate gas for this step
estimatedGas, err := calc.gasEstimator.EstimateSwapGas(ExchangeUniswapV3, pool) // Using a mock exchange type
if err != nil {
return nil, fmt.Errorf("error estimating gas for pool %s: %w", pool.Address, err)
}
// Create execution step
route[i] = ExchangeStep{
Exchange: ExchangeUniswapV3, // Using a mock exchange type
Pool: pool,
TokenIn: currentToken,
TokenOut: outputToken,
AmountIn: currentAmount,
AmountOut: amountOut,
PriceImpact: priceImpact,
EstimatedGas: estimatedGas,
}
// Update for next iteration
currentAmount = amountOut
currentToken = outputToken
}
return route, nil
}
// calculateTotalGasCost calculates the total gas cost for the entire route
func (calc *ArbitrageCalculator) calculateTotalGasCost(route []ExchangeStep) (*UniversalDecimal, error) {
// Get current gas price
gasPrice, err := calc.gasEstimator.GetCurrentGasPrice()
if err != nil {
return nil, fmt.Errorf("error getting gas price: %w", err)
}
// Sum up all gas estimates
totalGas := uint64(0)
for _, step := range route {
totalGas += step.EstimatedGas
}
// Add flash swap overhead if multi-step
if len(route) > 1 {
flashSwapGas, err := calc.gasEstimator.EstimateFlashSwapGas([]*PoolData{})
if err == nil {
totalGas += flashSwapGas
}
}
// Convert to gas cost in ETH
totalGasInt64, err := security.SafeUint64ToInt64(totalGas)
if err != nil {
// This is very unlikely for gas calculations, but handle safely
// Use maximum safe value as fallback
totalGasInt64 = math.MaxInt64
}
totalGasBig := big.NewInt(totalGasInt64)
totalGasDecimal, err := NewUniversalDecimal(totalGasBig, 0, "GAS")
if err != nil {
return nil, err
}
return calc.decimalConverter.Multiply(totalGasDecimal, gasPrice, 18, "ETH")
}
// calculateProfits calculates gross profit, net profit, and profit percentage
func (calc *ArbitrageCalculator) calculateProfits(
inputAmount, outputAmount, gasCost *UniversalDecimal,
inputToken, outputToken TokenInfo,
) (*UniversalDecimal, *UniversalDecimal, *UniversalDecimal, error) {
// Convert amounts to common denomination (ETH) for comparison
inputETH := calc.convertToETH(inputAmount, inputToken)
outputETH := calc.convertToETH(outputAmount, outputToken)
// Gross profit = output - input (in ETH terms)
grossProfit, err := calc.decimalConverter.Subtract(outputETH, inputETH)
if err != nil {
return nil, nil, nil, fmt.Errorf("error calculating gross profit: %w", err)
}
// Net profit = gross profit - gas cost
netProfit, err := calc.decimalConverter.Subtract(grossProfit, gasCost)
if err != nil {
return nil, nil, nil, fmt.Errorf("error calculating net profit: %w", err)
}
// Profit percentage = (net profit / input) * 100
profitPercentage, err := calc.decimalConverter.CalculatePercentage(netProfit, inputETH)
if err != nil {
return nil, nil, nil, fmt.Errorf("error calculating profit percentage: %w", err)
}
return grossProfit, netProfit, profitPercentage, nil
}
// calculateTotalPriceImpact calculates cumulative price impact across all steps
func (calc *ArbitrageCalculator) calculateTotalPriceImpact(route []ExchangeStep) (*UniversalDecimal, error) {
if len(route) == 0 {
return NewUniversalDecimal(big.NewInt(0), 4, "PERCENT")
}
// Compound price impacts: (1 + impact1) * (1 + impact2) - 1
compoundedImpact, err := calc.decimalConverter.FromString("1", 4, "COMPOUND")
if err != nil {
return nil, err
}
for _, step := range route {
// Convert price impact to factor (1 + impact)
one, _ := calc.decimalConverter.FromString("1", 4, "ONE")
impactFactor, err := calc.decimalConverter.Add(one, step.PriceImpact)
if err != nil {
return nil, fmt.Errorf("error calculating impact factor: %w", err)
}
// Multiply with cumulative impact
compoundedImpact, err = calc.decimalConverter.Multiply(compoundedImpact, impactFactor, 4, "COMPOUND")
if err != nil {
return nil, fmt.Errorf("error compounding impact: %w", err)
}
}
// Subtract 1 to get final impact percentage
one, _ := calc.decimalConverter.FromString("1", 4, "ONE")
totalImpact, err := calc.decimalConverter.Subtract(compoundedImpact, one)
if err != nil {
return nil, fmt.Errorf("error calculating total impact: %w", err)
}
return totalImpact, nil
}
// calculateMinimumOutput calculates minimum output accounting for slippage
func (calc *ArbitrageCalculator) calculateMinimumOutput(expectedOutput *UniversalDecimal) (*UniversalDecimal, error) {
// Apply slippage tolerance
slippageFactor, err := calc.decimalConverter.Subtract(
&UniversalDecimal{Value: big.NewInt(10000), Decimals: 4, Symbol: "ONE"},
calc.maxSlippage,
)
if err != nil {
return nil, err
}
return calc.decimalConverter.Multiply(expectedOutput, slippageFactor, 18, "TOKEN")
}
// assessRisks performs comprehensive risk assessment
func (calc *ArbitrageCalculator) assessRisks(route []ExchangeStep, priceImpact, netProfit *UniversalDecimal) RiskAssessment {
assessment := RiskAssessment{
Warnings: make([]string, 0),
}
// Assess liquidity risk
assessment.Liquidity = calc.assessLiquidityRisk(route)
// Assess price impact risk
assessment.PriceImpact = calc.assessPriceImpactRisk(priceImpact)
// Assess profitability risk
profitRisk := calc.assessProfitabilityRisk(netProfit)
// Assess gas price risk
assessment.GasPrice = calc.assessGasPriceRisk()
// Calculate overall risk (worst of all categories)
risks := []RiskLevel{assessment.Liquidity, assessment.PriceImpact, profitRisk, assessment.GasPrice}
assessment.Overall = calc.calculateOverallRisk(risks)
// Calculate OverallRisk as a numeric value (0.0 to 1.0) based on the overall risk level
switch assessment.Overall {
case RiskLow:
assessment.OverallRisk = 0.1
case RiskMedium:
assessment.OverallRisk = 0.4
case RiskHigh:
assessment.OverallRisk = 0.7
case RiskCritical:
assessment.OverallRisk = 0.95
default:
assessment.OverallRisk = 0.5 // Default to medium risk
}
return assessment
}
// Helper risk assessment methods
func (calc *ArbitrageCalculator) assessLiquidityRisk(route []ExchangeStep) RiskLevel {
for _, step := range route {
// For this simplified implementation, assume a mock liquidity value
// In a real implementation, you'd get this from the pricing engine
mockLiquidity, _ := calc.decimalConverter.FromString("1000", 18, "TOKEN") // 1000 tokens
if mockLiquidity.IsZero() {
return RiskHigh
}
// Check if trade size is significant portion of liquidity (>10%)
tenPercent, _ := calc.decimalConverter.FromString("10", 4, "PERCENT")
tradeSizePercent, _ := calc.decimalConverter.CalculatePercentage(step.AmountIn, mockLiquidity)
if comp, _ := calc.decimalConverter.Compare(tradeSizePercent, tenPercent); comp > 0 {
return RiskMedium
}
}
return RiskLow
}
func (calc *ArbitrageCalculator) assessPriceImpactRisk(priceImpact *UniversalDecimal) RiskLevel {
fivePercent, _ := calc.decimalConverter.FromString("5", 4, "PERCENT")
twoPercent, _ := calc.decimalConverter.FromString("2", 4, "PERCENT")
if comp, _ := calc.decimalConverter.Compare(priceImpact, fivePercent); comp > 0 {
return RiskHigh
}
if comp, _ := calc.decimalConverter.Compare(priceImpact, twoPercent); comp > 0 {
return RiskMedium
}
return RiskLow
}
func (calc *ArbitrageCalculator) assessProfitabilityRisk(netProfit *UniversalDecimal) RiskLevel {
if netProfit.IsNegative() {
return RiskCritical
}
smallProfit, _ := calc.decimalConverter.FromString("0.001", 18, "ETH") // $1 at $1000/ETH
mediumProfit, _ := calc.decimalConverter.FromString("0.01", 18, "ETH") // $10 at $1000/ETH
if comp, _ := calc.decimalConverter.Compare(netProfit, smallProfit); comp < 0 {
return RiskHigh
}
if comp, _ := calc.decimalConverter.Compare(netProfit, mediumProfit); comp < 0 {
return RiskMedium
}
return RiskLow
}
func (calc *ArbitrageCalculator) assessGasPriceRisk() RiskLevel {
currentGas, _ := calc.gasEstimator.GetCurrentGasPrice()
if comp, _ := calc.decimalConverter.Compare(currentGas, calc.maxGasPriceGwei); comp > 0 {
return RiskHigh
}
twentyGwei, _ := calc.decimalConverter.FromString("20", 9, "GWEI")
if comp, _ := calc.decimalConverter.Compare(currentGas, twentyGwei); comp > 0 {
return RiskMedium
}
return RiskLow
}
func (calc *ArbitrageCalculator) calculateOverallRisk(risks []RiskLevel) RiskLevel {
riskScores := map[RiskLevel]int{
RiskLow: 1,
RiskMedium: 2,
RiskHigh: 3,
RiskCritical: 4,
}
maxScore := 0
for _, risk := range risks {
if score := riskScores[risk]; score > maxScore {
maxScore = score
}
}
for risk, score := range riskScores {
if score == maxScore {
return risk
}
}
return RiskLow
}
// calculateConfidence calculates confidence score based on risk and profit
func (calc *ArbitrageCalculator) calculateConfidence(risk RiskAssessment, netProfit, priceImpact *UniversalDecimal) float64 {
baseConfidence := 0.5
// Adjust for risk level
switch risk.Overall {
case RiskLow:
baseConfidence += 0.3
case RiskMedium:
baseConfidence += 0.1
case RiskHigh:
baseConfidence -= 0.2
case RiskCritical:
baseConfidence -= 0.4
}
// Adjust for profit magnitude
if netProfit.IsPositive() {
largeProfit, _ := calc.decimalConverter.FromString("0.1", 18, "ETH")
if comp, _ := calc.decimalConverter.Compare(netProfit, largeProfit); comp > 0 {
baseConfidence += 0.2
}
}
// Adjust for price impact
lowImpact, _ := calc.decimalConverter.FromString("1", 4, "PERCENT")
if comp, _ := calc.decimalConverter.Compare(priceImpact, lowImpact); comp < 0 {
baseConfidence += 0.1
}
if baseConfidence < 0 {
baseConfidence = 0
}
if baseConfidence > 1 {
baseConfidence = 1
}
return baseConfidence
}
// estimateExecutionTime estimates execution time in milliseconds
func (calc *ArbitrageCalculator) estimateExecutionTime(route []ExchangeStep) int64 {
baseTime := int64(500) // 500ms base
// Add time per hop
hopTime := int64(len(route)) * 200
// Add time for complex exchanges
complexTime := int64(0)
for _, step := range route {
switch ExchangeType(step.Exchange) {
case ExchangeUniswapV3, ExchangeCamelot:
complexTime += 300 // Concentrated liquidity is more complex
case ExchangeBalancer, ExchangeCurve:
complexTime += 400 // Weighted/stable pools are complex
default:
complexTime += 100 // Simple AMM
}
}
return baseTime + hopTime + complexTime
}
// convertToETH converts any token amount to ETH for comparison (placeholder)
func (calc *ArbitrageCalculator) convertToETH(amount *UniversalDecimal, token TokenInfo) *UniversalDecimal {
// This is a placeholder - in production, this would query price oracles
// For now, assume 1:1 conversion for demonstration
ethAmount, _ := calc.decimalConverter.ConvertTo(amount, 18, "ETH")
return ethAmount
}
// IsOpportunityProfitable checks if opportunity meets minimum criteria
// IsOpportunityProfitable checks if an opportunity meets profitability criteria
func (calc *ArbitrageCalculator) IsOpportunityProfitable(opportunity *types.ArbitrageOpportunity) bool {
if opportunity == nil {
return false
}
// Check minimum profit threshold
if !calc.checkProfitThreshold(opportunity) {
return false
}
// Check maximum price impact
if !calc.checkPriceImpactThreshold(opportunity) {
return false
}
// Check risk level
if !calc.checkRiskLevel(opportunity) {
return false
}
// Check confidence threshold
if !calc.checkConfidenceThreshold(opportunity) {
return false
}
return true
}
// checkProfitThreshold checks if the opportunity meets minimum profit requirements
func (calc *ArbitrageCalculator) checkProfitThreshold(opportunity *types.ArbitrageOpportunity) bool {
if opportunity.Quantities != nil {
if netProfitUD, err := calc.decimalAmountToUniversal(opportunity.Quantities.NetProfit); err == nil {
if cmp, err := calc.decimalConverter.Compare(netProfitUD, calc.minProfitThreshold); err == nil && cmp < 0 {
return false
}
}
} else if opportunity.NetProfit != nil {
if opportunity.NetProfit.Cmp(calc.minProfitThreshold.Value) < 0 {
return false
}
} else {
return false
}
return true
}
// checkPriceImpactThreshold checks if the opportunity is below maximum price impact
func (calc *ArbitrageCalculator) checkPriceImpactThreshold(opportunity *types.ArbitrageOpportunity) bool {
if opportunity.Quantities != nil {
if impactUD, err := calc.decimalAmountToUniversal(opportunity.Quantities.PriceImpact); err == nil {
if cmp, err := calc.decimalConverter.Compare(impactUD, calc.maxPriceImpact); err == nil && cmp > 0 {
return false
}
}
} else {
maxImpactFloat := float64(calc.maxPriceImpact.Value.Int64()) / math.Pow10(int(calc.maxPriceImpact.Decimals))
if opportunity.PriceImpact > maxImpactFloat {
return false
}
}
return true
}
// checkRiskLevel checks if the opportunity's risk is acceptable
func (calc *ArbitrageCalculator) checkRiskLevel(opportunity *types.ArbitrageOpportunity) bool {
return opportunity.Risk < 0.8 // High risk threshold
}
// checkConfidenceThreshold checks if the opportunity has sufficient confidence
func (calc *ArbitrageCalculator) checkConfidenceThreshold(opportunity *types.ArbitrageOpportunity) bool {
return opportunity.Confidence >= 0.3
}
// SortOpportunitiesByProfitability sorts opportunities by net profit descending
func (calc *ArbitrageCalculator) SortOpportunitiesByProfitability(opportunities []*types.ArbitrageOpportunity) {
sort.Slice(opportunities, func(i, j int) bool {
left, errL := calc.decimalAmountToUniversal(opportunities[i].Quantities.NetProfit)
right, errR := calc.decimalAmountToUniversal(opportunities[j].Quantities.NetProfit)
if errL == nil && errR == nil {
cmp, err := calc.decimalConverter.Compare(left, right)
if err == nil {
return cmp > 0
}
}
// Fallback to canonical big.Int comparison
return opportunities[i].NetProfit.Cmp(opportunities[j].NetProfit) > 0 // Descending order
})
}
func (calc *ArbitrageCalculator) decimalAmountToUniversal(dec types.DecimalAmount) (*UniversalDecimal, error) {
if dec.Value == "" {
return nil, fmt.Errorf("decimal amount empty")
}
val, ok := new(big.Int).SetString(dec.Value, 10)
if !ok {
return nil, fmt.Errorf("invalid decimal amount %s", dec.Value)
}
return NewUniversalDecimal(val, dec.Decimals, dec.Symbol)
}
// CalculateArbitrage calculates arbitrage opportunity for a given path and input amount
func (calc *ArbitrageCalculator) CalculateArbitrage(ctx context.Context, inputAmount *UniversalDecimal, path []*PoolData) (*types.ArbitrageOpportunity, error) {
if len(path) == 0 {
return nil, fmt.Errorf("empty path provided")
}
// Get the input and output tokens for the path
inputToken := path[0].Token0
outputToken := path[len(path)-1].Token1
if path[len(path)-1].Token0.Address == inputToken.Address {
outputToken = path[len(path)-1].Token0
}
// Calculate the arbitrage opportunity for this path
opportunity, err := calc.CalculateArbitrageOpportunity(path, inputAmount, inputToken, outputToken)
if err != nil {
return nil, fmt.Errorf("failed to calculate arbitrage opportunity: %w", err)
}
return opportunity, nil
}
// FindOptimalPath finds the most profitable arbitrage path between two tokens
func (calc *ArbitrageCalculator) FindOptimalPath(ctx context.Context, tokenA, tokenB common.Address, amount *UniversalDecimal) (*types.ArbitrageOpportunity, error) {
// In a real implementation, this would query for available paths between tokens
// and calculate the most profitable path. For this implementation, we'll return an error
// indicating no path is available since we don't have direct path-finding ability in the calculator
return nil, fmt.Errorf("FindOptimalPath not implemented in calculator - use executor.CalculateOptimalPath instead")
}
// FilterProfitableOpportunities returns only profitable opportunities
func (calc *ArbitrageCalculator) FilterProfitableOpportunities(opportunities []*types.ArbitrageOpportunity) []*types.ArbitrageOpportunity {
profitable := make([]*types.ArbitrageOpportunity, 0)
for _, opp := range opportunities {
if calc.IsOpportunityProfitable(opp) {
profitable = append(profitable, opp)
}
}
return profitable
}

View File

@@ -0,0 +1,175 @@
package math
import (
"math/big"
"testing"
"github.com/fraktal/mev-beta/pkg/types"
)
type stubGasEstimator struct {
price *UniversalDecimal
}
func (s stubGasEstimator) EstimateSwapGas(exchange ExchangeType, poolData *PoolData) (uint64, error) {
return 100_000, nil
}
func (s stubGasEstimator) EstimateFlashSwapGas(route []*PoolData) (uint64, error) {
return 50_000, nil
}
func (s stubGasEstimator) GetCurrentGasPrice() (*UniversalDecimal, error) {
return s.price, nil
}
func TestIsOpportunityProfitableRespectsThreshold(t *testing.T) {
estimator := stubGasEstimator{price: func() *UniversalDecimal {
ud, _ := NewUniversalDecimal(big.NewInt(1_000_000_000), 9, "GWEI")
return ud
}()}
calc := NewArbitrageCalculator(estimator)
belowThreshold, _ := NewUniversalDecimal(big.NewInt(9_000_000_000_000_000), 18, "ETH")
priceImpact, _ := NewUniversalDecimal(big.NewInt(100), 4, "PERCENT")
opportunity := &types.ArbitrageOpportunity{
NetProfit: belowThreshold.Value,
PriceImpact: 0.01,
Confidence: 0.5,
Quantities: &types.OpportunityQuantities{
NetProfit: toDecimalAmount(belowThreshold),
PriceImpact: toDecimalAmount(priceImpact),
AmountIn: toDecimalAmount(belowThreshold),
AmountOut: toDecimalAmount(belowThreshold),
GrossProfit: toDecimalAmount(belowThreshold),
GasCost: toDecimalAmount(belowThreshold),
ProfitPercent: toDecimalAmount(priceImpact),
},
}
if calc.IsOpportunityProfitable(opportunity) {
t.Fatalf("expected below-threshold opportunity to be rejected")
}
aboveThreshold, _ := NewUniversalDecimal(big.NewInt(2_000_000_000_000_0000), 18, "ETH")
opportunity.NetProfit = aboveThreshold.Value
opportunity.Quantities.NetProfit = toDecimalAmount(aboveThreshold)
if !calc.IsOpportunityProfitable(opportunity) {
t.Fatalf("expected opportunity above threshold to be accepted")
}
}
func TestSortOpportunitiesByProfitabilityUsesDecimals(t *testing.T) {
estimator := stubGasEstimator{price: func() *UniversalDecimal {
ud, _ := NewUniversalDecimal(big.NewInt(1_000_000_000), 9, "GWEI")
return ud
}()}
calc := NewArbitrageCalculator(estimator)
a, _ := NewUniversalDecimal(big.NewInt(1_500_000_000_000_0000), 18, "ETH")
b, _ := NewUniversalDecimal(big.NewInt(5_000_000_000_000_000), 18, "ETH")
oppA := &types.ArbitrageOpportunity{
NetProfit: a.Value,
Quantities: &types.OpportunityQuantities{
NetProfit: toDecimalAmount(a),
},
}
oppB := &types.ArbitrageOpportunity{
NetProfit: b.Value,
Quantities: &types.OpportunityQuantities{
NetProfit: toDecimalAmount(b),
},
}
opps := []*types.ArbitrageOpportunity{oppB, oppA}
calc.SortOpportunitiesByProfitability(opps)
if opps[0] != oppA {
t.Fatalf("expected higher decimal profit opportunity first")
}
}
func TestCalculateArbitrageOpportunitySetsQuantities(t *testing.T) {
estimator := stubGasEstimator{price: func() *UniversalDecimal {
ud, _ := NewUniversalDecimal(big.NewInt(1_000_000_000), 9, "GWEI")
return ud
}()}
calc := NewArbitrageCalculator(estimator)
pool := &PoolData{
Address: "0xpool",
ExchangeType: ExchangeUniswapV2,
Token0: TokenInfo{Address: "0x0", Symbol: "TOKEN0", Decimals: 18},
Token1: TokenInfo{Address: "0x1", Symbol: "TOKEN1", Decimals: 18},
}
amountIn, _ := NewUniversalDecimal(big.NewInt(1_000_000_000_000_000), 18, "TOKEN0")
opportunity, err := calc.CalculateArbitrageOpportunity([]*PoolData{pool}, amountIn, pool.Token0, pool.Token1)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if opportunity.Quantities == nil {
t.Fatalf("expected quantities to be populated")
}
if opportunity.Quantities.NetProfit.Value == "" {
t.Fatalf("expected net profit decimal to have value")
}
}
func TestCalculateMinimumOutputAppliesSlippage(t *testing.T) {
estimator := stubGasEstimator{price: func() *UniversalDecimal {
ud, _ := NewUniversalDecimal(big.NewInt(1_000_000_000), 9, "GWEI")
return ud
}()}
calc := NewArbitrageCalculator(estimator)
expected, _ := NewUniversalDecimal(big.NewInt(1_000_000_000_000_000_000), 18, "ETH")
minOutput, err := calc.calculateMinimumOutput(expected)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
// Default max slippage is 1% -> expect 0.99 ETH
expectedMin, _ := NewUniversalDecimal(big.NewInt(990000000000000000), 18, "ETH")
cmp, err := calc.decimalConverter.Compare(minOutput, expectedMin)
if err != nil || cmp != 0 {
t.Fatalf("expected min output 0.99 ETH, got %s", calc.decimalConverter.ToHumanReadable(minOutput))
}
}
func TestCalculateProfitsCapturesSpread(t *testing.T) {
estimator := stubGasEstimator{price: func() *UniversalDecimal {
ud, _ := NewUniversalDecimal(big.NewInt(1_000_000_000), 9, "GWEI")
return ud
}()}
calc := NewArbitrageCalculator(estimator)
amountIn, _ := NewUniversalDecimal(big.NewInt(10_000_000_000_000_000), 18, "ETH") // 0.01
amountOut, _ := NewUniversalDecimal(big.NewInt(12_000_000_000_000_000), 18, "ETH")
gasCost, _ := NewUniversalDecimal(big.NewInt(500_000_000_000_000), 18, "ETH")
gross, net, pct, err := calc.calculateProfits(amountIn, amountOut, gasCost, TokenInfo{Symbol: "ETH", Decimals: 18}, TokenInfo{Symbol: "ETH", Decimals: 18})
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
expectedGross, _ := NewUniversalDecimal(big.NewInt(2_000_000_000_000_000), 18, "ETH")
cmp, err := calc.decimalConverter.Compare(gross, expectedGross)
if err != nil || cmp != 0 {
t.Fatalf("expected gross profit 0.002 ETH, got %s", calc.decimalConverter.ToHumanReadable(gross))
}
expectedNet, _ := NewUniversalDecimal(big.NewInt(1_500_000_000_000_000), 18, "ETH")
cmp, err = calc.decimalConverter.Compare(net, expectedNet)
if err != nil || cmp != 0 {
t.Fatalf("expected net profit 0.0015 ETH, got %s", calc.decimalConverter.ToHumanReadable(net))
}
if pct == nil || pct.Value.Sign() <= 0 {
t.Fatalf("expected positive profit percentage")
}
}

View File

@@ -0,0 +1,125 @@
package math
import (
"math/big"
"testing"
)
// BenchmarkAllProtocols runs performance tests for all supported protocols
func BenchmarkAllProtocols(b *testing.B) {
// Create test values for all protocols
reserveIn, _ := new(big.Int).SetString("1000000000000000000", 10) // 1 token
reserveOut, _ := new(big.Int).SetString("1000000000000000000", 10) // 1 token
amountIn, _ := new(big.Int).SetString("100000000000000000", 10) // 0.1 token
sqrtPriceX96, _ := new(big.Int).SetString("79228162514264337593543950336", 10) // 2^96
liquidity, _ := new(big.Int).SetString("1000000000000000000", 10) // 1 ETH worth of liquidity
calculator := NewMathCalculator()
b.Run("UniswapV2", func(b *testing.B) {
b.ResetTimer()
for i := 0; i < b.N; i++ {
_, _ = calculator.uniswapV2.CalculateAmountOut(amountIn, reserveIn, reserveOut, 3000)
}
})
b.Run("UniswapV3", func(b *testing.B) {
b.ResetTimer()
for i := 0; i < b.N; i++ {
_, _ = calculator.uniswapV3.CalculateAmountOut(amountIn, sqrtPriceX96, liquidity, 3000)
}
})
b.Run("Curve", func(b *testing.B) {
b.ResetTimer()
for i := 0; i < b.N; i++ {
_, _ = calculator.curve.CalculateAmountOut(amountIn, reserveIn, reserveOut, 400)
}
})
b.Run("Kyber", func(b *testing.B) {
b.ResetTimer()
for i := 0; i < b.N; i++ {
_, _ = calculator.kyber.CalculateAmountOut(amountIn, sqrtPriceX96, liquidity, 1000)
}
})
b.Run("Balancer", func(b *testing.B) {
b.ResetTimer()
for i := 0; i < b.N; i++ {
_, _ = calculator.balancer.CalculateAmountOut(amountIn, reserveIn, reserveOut, 1000)
}
})
b.Run("ConstantSum", func(b *testing.B) {
b.ResetTimer()
for i := 0; i < b.N; i++ {
_, _ = calculator.constantSum.CalculateAmountOut(amountIn, reserveIn, reserveOut, 3000)
}
})
}
// BenchmarkPriceMovementDetection runs performance tests for price movement detection
func BenchmarkPriceMovementDetection(b *testing.B) {
reserveIn, _ := new(big.Int).SetString("1000000000000000000", 10)
reserveOut, _ := new(big.Int).SetString("2000000000000000000000", 10)
amountIn, _ := new(big.Int).SetString("100000000000000000", 10)
b.ResetTimer()
for i := 0; i < b.N; i++ {
_, _, _ = WillSwapMovePrice(amountIn, reserveIn, reserveOut, 0.01)
}
}
// BenchmarkPriceImpactCalculations runs performance tests for price impact calculations
func BenchmarkPriceImpactCalculations(b *testing.B) {
calculator := NewPriceImpactCalculator()
reserveIn, _ := new(big.Int).SetString("1000000000000000000", 10)
reserveOut, _ := new(big.Int).SetString("2000000000000000000000", 10)
amountIn, _ := new(big.Int).SetString("100000000000000000", 10)
b.ResetTimer()
for i := 0; i < b.N; i++ {
_, _ = calculator.CalculatePriceImpact("uniswap_v2", amountIn, reserveIn, reserveOut, nil, nil)
}
}
// BenchmarkOptimizedUniswapV2 calculates amount out using optimized approach
func BenchmarkOptimizedUniswapV2(b *testing.B) {
// Pre-allocated values to reduce allocations
reserveIn, _ := new(big.Int).SetString("1000000000000000000", 10)
reserveOut, _ := new(big.Int).SetString("2000000000000000000000", 10)
amountIn, _ := new(big.Int).SetString("100000000000000000", 10)
math := NewUniswapV2Math()
b.ResetTimer()
for i := 0; i < b.N; i++ {
_, _ = math.CalculateAmountOut(amountIn, reserveIn, reserveOut, 3000)
}
}
// BenchmarkOptimizedPriceMovementDetection runs performance tests for optimized price movement detection
func BenchmarkOptimizedPriceMovementDetection(b *testing.B) {
reserveIn, _ := new(big.Int).SetString("1000000000000000000", 10)
reserveOut, _ := new(big.Int).SetString("2000000000000000000000", 10)
amountIn, _ := new(big.Int).SetString("100000000000000000", 10)
b.ResetTimer()
for i := 0; i < b.N; i++ {
// Simplified check without full calculation for performance comparison
// This just does the basic arithmetic to compare with the full function
priceBefore := new(big.Float).Quo(new(big.Float).SetInt(reserveOut), new(big.Float).SetInt(reserveIn))
amountOut, err := NewUniswapV2Math().CalculateAmountOut(amountIn, reserveIn, reserveOut, 3000)
if err != nil {
b.Fatal(err)
}
newReserveIn := new(big.Int).Add(reserveIn, amountIn)
newReserveOut := new(big.Int).Sub(reserveOut, amountOut)
priceAfter := new(big.Float).Quo(new(big.Float).SetInt(newReserveOut), new(big.Float).SetInt(newReserveIn))
impact := new(big.Float).Sub(priceBefore, priceAfter)
impact.Quo(impact, priceBefore)
impactFloat, _ := impact.Float64()
_ = impactFloat >= 0.01
}
}

View File

@@ -0,0 +1,96 @@
package math
import (
"math/big"
"testing"
"github.com/holiman/uint256"
)
// Benchmark original vs cached SqrtPriceX96ToPrice conversion
func BenchmarkSqrtPriceX96ToPriceOriginal(b *testing.B) {
// Use a typical sqrtPriceX96 value (represents price of ~2000 USDC/ETH)
sqrtPriceX96 := new(big.Int).SetBytes([]byte{0x06, 0x40, 0x84, 0x4A, 0x0E, 0x81, 0x4F, 0x96, 0x19, 0xC1, 0x9C, 0x08})
b.ResetTimer()
for i := 0; i < b.N; i++ {
// Original calculation: price = sqrtPriceX96^2 / 2^192
sqrtPriceFloat := new(big.Float).SetInt(sqrtPriceX96)
price := new(big.Float).Mul(sqrtPriceFloat, sqrtPriceFloat)
// Calculate 2^192
q192 := new(big.Int).Exp(big.NewInt(2), big.NewInt(192), nil)
q192Float := new(big.Float).SetInt(q192)
// Divide by 2^192
price.Quo(price, q192Float)
}
}
func BenchmarkSqrtPriceX96ToPriceCached(b *testing.B) {
// Use a typical sqrtPriceX96 value (represents price of ~2000 USDC/ETH)
sqrtPriceX96 := new(big.Int).SetBytes([]byte{0x06, 0x40, 0x84, 0x4A, 0x0E, 0x81, 0x4F, 0x96, 0x19, 0xC1, 0x9C, 0x08})
b.ResetTimer()
for i := 0; i < b.N; i++ {
// Cached calculation using precomputed constants
SqrtPriceX96ToPriceCached(sqrtPriceX96)
}
}
// Benchmark original vs cached PriceToSqrtPriceX96 conversion
func BenchmarkPriceToSqrtPriceX96Original(b *testing.B) {
// Use a typical price value (represents price of ~2000 USDC/ETH)
price := new(big.Float).SetFloat64(2000.0)
b.ResetTimer()
for i := 0; i < b.N; i++ {
// Original calculation: sqrtPriceX96 = sqrt(price * 2^192)
// Calculate 2^192
q192 := new(big.Int).Exp(big.NewInt(2), big.NewInt(192), nil)
q192Float := new(big.Float).SetInt(q192)
// Multiply price by 2^192
result := new(big.Float).Mul(price, q192Float)
// Calculate square root
result.Sqrt(result)
// Convert to big.Int
sqrtPriceX96 := new(big.Int)
result.Int(sqrtPriceX96)
}
}
func BenchmarkPriceToSqrtPriceX96Cached(b *testing.B) {
// Use a typical price value (represents price of ~2000 USDC/ETH)
price := new(big.Float).SetFloat64(2000.0)
b.ResetTimer()
for i := 0; i < b.N; i++ {
// Cached calculation using precomputed constants
PriceToSqrtPriceX96Cached(price)
}
}
// Benchmark optimized versions with uint256
func BenchmarkSqrtPriceX96ToPriceOptimized(b *testing.B) {
// Use a typical sqrtPriceX96 value
sqrtPriceX96 := uint256.NewInt(0).SetBytes([]byte{0x06, 0x40, 0x84, 0x4A, 0x0E, 0x81, 0x4F, 0x96, 0x19, 0xC1, 0x9C, 0x08})
b.ResetTimer()
for i := 0; i < b.N; i++ {
SqrtPriceX96ToPriceOptimized(sqrtPriceX96)
}
}
func BenchmarkPriceToSqrtPriceX96Optimized(b *testing.B) {
// Use a typical price value
price := new(big.Float).SetFloat64(2000.0)
b.ResetTimer()
for i := 0; i < b.N; i++ {
PriceToSqrtPriceX96Optimized(price)
}
}

View File

@@ -0,0 +1,126 @@
package math
import (
"math/big"
"sync"
"github.com/holiman/uint256"
"github.com/fraktal/mev-beta/pkg/uniswap"
)
// Cached mathematical constants to avoid recomputation
var (
cachedConstantsOnce sync.Once
cachedQ192 *big.Int
cachedQ96 *big.Int
cachedQ384 *big.Int
cachedTwoPower96 *big.Float
cachedTwoPower192 *big.Float
cachedTwoPower384 *big.Float
)
// initCachedConstants initializes all cached constants once
func initCachedConstants() {
cachedConstantsOnce.Do(func() {
// Calculate 2^96
cachedQ96 = new(big.Int).Exp(big.NewInt(2), big.NewInt(96), nil)
// Calculate 2^192
cachedQ192 = new(big.Int).Exp(big.NewInt(2), big.NewInt(192), nil)
// Calculate 2^384
cachedQ384 = new(big.Int).Exp(big.NewInt(2), big.NewInt(384), nil)
// Convert to big.Float for division operations
cachedTwoPower96 = new(big.Float).SetInt(cachedQ96)
cachedTwoPower192 = new(big.Float).SetInt(cachedQ192)
cachedTwoPower384 = new(big.Float).SetInt(cachedQ384)
})
}
// GetCachedQ192 returns the cached value of 2^192
func GetCachedQ192() *big.Int {
initCachedConstants()
return cachedQ192
}
// GetCachedQ96 returns the cached value of 2^96
func GetCachedQ96() *big.Int {
initCachedConstants()
return cachedQ96
}
// GetCachedQ384 returns the cached value of 2^384
func GetCachedQ384() *big.Int {
initCachedConstants()
return cachedQ384
}
// SqrtPriceX96ToPriceCached converts sqrtPriceX96 to a price using cached constants
// Formula: price = sqrtPriceX96^2 / 2^192
func SqrtPriceX96ToPriceCached(sqrtPriceX96 *big.Int) *big.Float {
initCachedConstants()
// Convert to big.Float for precision
sqrtPriceFloat := new(big.Float).SetInt(sqrtPriceX96)
// Calculate sqrtPrice^2
price := new(big.Float).Mul(sqrtPriceFloat, sqrtPriceFloat)
// Divide by 2^192 using cached constant
price.Quo(price, cachedTwoPower192)
return price
}
// PriceToSqrtPriceX96Cached converts a price to sqrtPriceX96 using cached constants
// Formula: sqrtPriceX96 = sqrt(price * 2^192)
func PriceToSqrtPriceX96Cached(price *big.Float) *big.Int {
initCachedConstants()
// Multiply price by 2^192
result := new(big.Float).Mul(price, cachedTwoPower192)
// Calculate square root
result.Sqrt(result)
// Convert to big.Int
sqrtPriceX96 := new(big.Int)
result.Int(sqrtPriceX96)
return sqrtPriceX96
}
// SqrtPriceX96ToPriceOptimized converts sqrtPriceX96 to a price using optimized uint256 operations
// Formula: price = sqrtPriceX96^2 / 2^192
func SqrtPriceX96ToPriceOptimized(sqrtPriceX96 *uint256.Int) *big.Float {
initCachedConstants()
// Convert to big.Int for calculation
sqrtPriceBig := sqrtPriceX96.ToBig()
// Use cached function for consistency
return SqrtPriceX96ToPriceCached(sqrtPriceBig)
}
// PriceToSqrtPriceX96Optimized converts a price to sqrtPriceX96 using optimized operations
// Formula: sqrtPriceX96 = sqrt(price * 2^192)
func PriceToSqrtPriceX96Optimized(price *big.Float) *uint256.Int {
initCachedConstants()
// Use cached function for consistency
sqrtPriceBig := PriceToSqrtPriceX96Cached(price)
// Convert to uint256
return uint256.MustFromBig(sqrtPriceBig)
}
// TickToSqrtPriceX96Optimized calculates sqrtPriceX96 from a tick using optimized operations
// Formula: sqrtPriceX96 = 1.0001^(tick/2)
func TickToSqrtPriceX96Optimized(tick int) *uint256.Int {
// For simplicity, we'll convert to big.Int and use existing implementation
tickBig := big.NewInt(int64(tick))
sqrtPriceBig := uniswap.TickToSqrtPriceX96(int(tickBig.Int64()))
return uint256.MustFromBig(sqrtPriceBig)
}

View File

@@ -0,0 +1,127 @@
package math
import (
"math/big"
"testing"
"github.com/holiman/uint256"
"github.com/stretchr/testify/assert"
)
// Test that cached functions produce the same results as original implementations
func TestCachedFunctionAccuracy(t *testing.T) {
// Test SqrtPriceX96ToPrice functions
t.Run("SqrtPriceX96ToPrice", func(t *testing.T) {
// Use a typical sqrtPriceX96 value (represents price of ~2000 USDC/ETH)
sqrtPriceX96 := new(big.Int).SetBytes([]byte{0x06, 0x40, 0x84, 0x4A, 0x0E, 0x81, 0x4F, 0x96, 0x19, 0xC1, 0x9C, 0x08})
// Original calculation
sqrtPriceFloat := new(big.Float).SetInt(sqrtPriceX96)
originalPrice := new(big.Float).Mul(sqrtPriceFloat, sqrtPriceFloat)
q192 := new(big.Int).Exp(big.NewInt(2), big.NewInt(192), nil)
q192Float := new(big.Float).SetInt(q192)
originalPrice.Quo(originalPrice, q192Float)
// Cached calculation
cachedPrice := SqrtPriceX96ToPriceCached(sqrtPriceX96)
// Compare results (should be identical)
assert.Equal(t, originalPrice.String(), cachedPrice.String(), "Cached and original SqrtPriceX96ToPrice should produce identical results")
})
// Test PriceToSqrtPriceX96 functions
t.Run("PriceToSqrtPriceX96", func(t *testing.T) {
// Use a typical price value (represents price of ~2000 USDC/ETH)
price := new(big.Float).SetFloat64(2000.0)
// Original calculation
q192 := new(big.Int).Exp(big.NewInt(2), big.NewInt(192), nil)
q192Float := new(big.Float).SetInt(q192)
result := new(big.Float).Mul(price, q192Float)
result.Sqrt(result)
expectedSqrtPriceX96 := new(big.Int)
result.Int(expectedSqrtPriceX96)
// Cached calculation
actualSqrtPriceX96 := PriceToSqrtPriceX96Cached(price)
// Compare results (should be identical)
assert.Equal(t, expectedSqrtPriceX96.String(), actualSqrtPriceX96.String(), "Cached and original PriceToSqrtPriceX96 should produce identical results")
})
// Test optimized functions with uint256
t.Run("SqrtPriceX96ToPriceOptimized", func(t *testing.T) {
// Use a typical sqrtPriceX96 value
sqrtPriceX96Big := new(big.Int).SetBytes([]byte{0x06, 0x40, 0x84, 0x4A, 0x0E, 0x81, 0x4F, 0x96, 0x19, 0xC1, 0x9C, 0x08})
sqrtPriceX96 := uint256.MustFromBig(sqrtPriceX96Big)
// Cached calculation
cachedResult := SqrtPriceX96ToPriceCached(sqrtPriceX96Big)
// Optimized calculation
optimizedResult := SqrtPriceX96ToPriceOptimized(sqrtPriceX96)
// Compare results (should be identical)
assert.Equal(t, cachedResult.String(), optimizedResult.String(), "Optimized and cached SqrtPriceX96ToPrice should produce identical results")
})
// Test optimized functions with uint256
t.Run("PriceToSqrtPriceX96Optimized", func(t *testing.T) {
// Use a typical price value
price := new(big.Float).SetFloat64(2000.0)
// Cached calculation
cachedResult := PriceToSqrtPriceX96Cached(price)
// Optimized calculation
optimizedResult := PriceToSqrtPriceX96Optimized(price)
// Compare results (should be identical)
assert.Equal(t, cachedResult.String(), optimizedResult.ToBig().String(), "Optimized and cached PriceToSqrtPriceX96 should produce identical results")
})
}
// Test that cached constants are working correctly
func TestCachedConstants(t *testing.T) {
// Test that Q192 is correctly calculated
expectedQ192 := new(big.Int).Exp(big.NewInt(2), big.NewInt(192), nil)
actualQ192 := GetCachedQ192()
assert.Equal(t, expectedQ192.String(), actualQ192.String(), "Cached Q192 should equal 2^192")
// Test that Q96 is correctly calculated
expectedQ96 := new(big.Int).Exp(big.NewInt(2), big.NewInt(96), nil)
actualQ96 := GetCachedQ96()
assert.Equal(t, expectedQ96.String(), actualQ96.String(), "Cached Q96 should equal 2^96")
// Test that Q384 is correctly calculated
expectedQ384 := new(big.Int).Exp(big.NewInt(2), big.NewInt(384), nil)
actualQ384 := GetCachedQ384()
assert.Equal(t, expectedQ384.String(), actualQ384.String(), "Cached Q384 should equal 2^384")
}
// Test edge cases
func TestEdgeCases(t *testing.T) {
// Test with zero values
zero := big.NewInt(0)
zeroFloat := new(big.Float).SetInt64(0)
// SqrtPriceX96ToPrice with zero
result := SqrtPriceX96ToPriceCached(zero)
assert.Equal(t, "0", result.String(), "SqrtPriceX96ToPriceCached with zero should return zero")
// PriceToSqrtPriceX96 with zero
result2 := PriceToSqrtPriceX96Cached(zeroFloat)
assert.Equal(t, "0", result2.String(), "PriceToSqrtPriceX96Cached with zero should return zero")
// Test with small values
one := big.NewInt(1)
oneFloat := new(big.Float).SetInt64(1)
// SqrtPriceX96ToPrice with one
result3 := SqrtPriceX96ToPriceCached(one)
assert.NotEmpty(t, result3.String(), "SqrtPriceX96ToPriceCached with one should return a value")
// PriceToSqrtPriceX96 with one
result4 := PriceToSqrtPriceX96Cached(oneFloat)
assert.NotEmpty(t, result4.String(), "PriceToSqrtPriceX96Cached with one should return a value")
}

View File

@@ -0,0 +1,459 @@
package math
import (
"fmt"
"math/big"
"strings"
)
// UniversalDecimal represents a token amount with precise decimal handling
type UniversalDecimal struct {
Value *big.Int // Raw value in smallest unit
Decimals uint8 // Number of decimal places (0-18)
Symbol string // Token symbol for debugging
}
// DecimalConverter handles conversions between different decimal precisions
type DecimalConverter struct {
// Cache for common scaling factors to avoid repeated calculations
scalingFactors map[uint8]*big.Int
}
// NewDecimalConverter creates a new decimal converter with caching
func NewDecimalConverter() *DecimalConverter {
dc := &DecimalConverter{
scalingFactors: make(map[uint8]*big.Int),
}
// Pre-calculate common scaling factors (0-18 decimals)
for i := uint8(0); i <= 18; i++ {
dc.scalingFactors[i] = new(big.Int).Exp(big.NewInt(10), big.NewInt(int64(i)), nil)
}
return dc
}
// NewUniversalDecimal creates a new universal decimal with comprehensive validation
func NewUniversalDecimal(value *big.Int, decimals uint8, symbol string) (*UniversalDecimal, error) {
// Validate decimal places
if decimals > 18 {
return nil, fmt.Errorf("decimal places cannot exceed 18, got %d for token %s", decimals, symbol)
}
// Validate symbol
if symbol == "" {
return nil, fmt.Errorf("symbol cannot be empty")
}
// Validate value bounds - prevent extremely large values that could cause overflow
if value != nil {
// Check for reasonable bounds - max value should not exceed what can be represented
// in financial calculations (roughly 2^256 / 10^18 for safety)
maxValue := new(big.Int)
maxValue.Exp(big.NewInt(10), big.NewInt(60), nil) // 10^60 max value for safety
absValue := new(big.Int).Abs(value)
if absValue.Cmp(maxValue) > 0 {
return nil, fmt.Errorf("value %s exceeds maximum safe value for token %s", value.String(), symbol)
}
}
if value == nil {
value = big.NewInt(0)
}
// Copy the value to prevent external modifications
valueCopy := new(big.Int).Set(value)
return &UniversalDecimal{
Value: valueCopy,
Decimals: decimals,
Symbol: symbol,
}, nil
}
// FromString creates UniversalDecimal from string representation
// Intelligently determines format:
// 1. Very large numbers (length >= decimals): treated as raw wei/smallest unit
// 2. Small numbers (length < decimals): treated as human-readable units
// 3. Numbers with decimal point: always treated as human-readable
func (dc *DecimalConverter) FromString(valueStr string, decimals uint8, symbol string) (*UniversalDecimal, error) {
// Handle empty or zero values
if valueStr == "" || valueStr == "0" {
return NewUniversalDecimal(big.NewInt(0), decimals, symbol)
}
// Remove any whitespace
valueStr = strings.TrimSpace(valueStr)
// Check for decimal point - if present, treat as human-readable decimal
if strings.Contains(valueStr, ".") {
return dc.fromDecimalString(valueStr, decimals, symbol)
}
// For integers without decimal point, we need to determine if this is:
// - A raw value (like "1000000000000000000" = 1000000000000000000 wei)
// - A human-readable value (like "1" = 1.0 ETH = 1000000000000000000 wei)
// Parse the number first
value := new(big.Int)
_, success := value.SetString(valueStr, 10)
if !success {
return nil, fmt.Errorf("invalid number format: %s for token %s", valueStr, symbol)
}
// Improved heuristic for distinguishing raw vs human-readable values:
// 1. If value is very large relative to what a human would typically enter, treat as raw
// 2. If value is small (< 1000), treat as human-readable
// 3. Use length as secondary indicator
valueInt := value.Int64() // Safe since we parsed it successfully
// If the value is very small (less than 1000), it's likely human-readable
if valueInt < 1000 {
// Treat as human-readable value - convert to smallest unit
scalingFactor := dc.getScalingFactor(decimals)
scaledValue := new(big.Int).Mul(value, scalingFactor)
return NewUniversalDecimal(scaledValue, decimals, symbol)
}
// If the value looks like it could be raw wei (very large), treat as raw
if len(valueStr) >= int(decimals) && decimals > 0 {
// Treat as raw value in smallest unit
return NewUniversalDecimal(value, decimals, symbol)
}
// For intermediate values, use a more sophisticated check
// If the number would represent more than 1000 tokens when treated as human-readable,
// it's probably meant to be raw
if valueInt > 1000 {
return NewUniversalDecimal(value, decimals, symbol)
}
// Default: treat as human-readable
scalingFactor := dc.getScalingFactor(decimals)
scaledValue := new(big.Int).Mul(value, scalingFactor)
return NewUniversalDecimal(scaledValue, decimals, symbol)
}
// fromDecimalString parses decimal string (e.g., "1.23") to smallest unit
func (dc *DecimalConverter) fromDecimalString(valueStr string, decimals uint8, symbol string) (*UniversalDecimal, error) {
parts := strings.Split(valueStr, ".")
if len(parts) != 2 {
return nil, fmt.Errorf("invalid decimal format: %s for token %s", valueStr, symbol)
}
integerPart := parts[0]
decimalPart := parts[1]
// Validate decimal part doesn't exceed token decimals
if len(decimalPart) > int(decimals) {
return nil, fmt.Errorf("decimal part %s has %d digits, but token %s only supports %d decimals",
decimalPart, len(decimalPart), symbol, decimals)
}
// Parse integer part
intValue := new(big.Int)
if integerPart != "" && integerPart != "0" {
_, success := intValue.SetString(integerPart, 10)
if !success {
return nil, fmt.Errorf("invalid integer part: %s for token %s", integerPart, symbol)
}
}
// Parse decimal part
decValue := new(big.Int)
if decimalPart != "" && decimalPart != "0" {
// Pad decimal part to full precision
paddedDecimal := decimalPart
for len(paddedDecimal) < int(decimals) {
paddedDecimal += "0"
}
_, success := decValue.SetString(paddedDecimal, 10)
if !success {
return nil, fmt.Errorf("invalid decimal part: %s for token %s", decimalPart, symbol)
}
}
// Combine integer and decimal parts
scalingFactor := dc.getScalingFactor(decimals)
totalValue := new(big.Int).Mul(intValue, scalingFactor)
totalValue.Add(totalValue, decValue)
return NewUniversalDecimal(totalValue, decimals, symbol)
}
// ToHumanReadable converts to human-readable decimal string
// For round-trip precision preservation with FromString, returns raw value when appropriate
func (dc *DecimalConverter) ToHumanReadable(ud *UniversalDecimal) string {
if ud.Value.Sign() == 0 {
return "0"
}
// For round-trip precision preservation, if the value represents exact units
// (like 1000000000000000000 wei = exactly 1 ETH), output the human readable form
// Otherwise, output the raw value to preserve precision
if ud.Decimals == 0 {
return ud.Value.String()
}
scalingFactor := dc.getScalingFactor(ud.Decimals)
// Get integer and remainder parts
integerPart := new(big.Int).Div(ud.Value, scalingFactor)
remainder := new(big.Int).Mod(ud.Value, scalingFactor)
// If this is an exact unit (no fractional part), return human readable
if remainder.Sign() == 0 {
return integerPart.String()
}
// For values with fractional parts, we need to decide:
// If the value looks like it came from raw input (very large numbers),
// preserve it as raw to maintain round-trip precision
// Check if this looks like a raw value by comparing magnitude
valueStr := ud.Value.String()
if len(valueStr) >= int(ud.Decimals) {
// This is likely a raw value, preserve as raw for round-trip
return ud.Value.String()
}
// Format as human readable decimal
decimalStr := remainder.String()
for len(decimalStr) < int(ud.Decimals) {
decimalStr = "0" + decimalStr
}
// Remove trailing zeros for readability
decimalStr = strings.TrimRight(decimalStr, "0")
if decimalStr == "" {
return integerPart.String()
}
return fmt.Sprintf("%s.%s", integerPart.String(), decimalStr)
}
// ConvertTo converts between different decimal precisions
func (dc *DecimalConverter) ConvertTo(from *UniversalDecimal, toDecimals uint8, toSymbol string) (*UniversalDecimal, error) {
if from.Decimals == toDecimals {
// Same precision, just copy with new symbol
return NewUniversalDecimal(from.Value, toDecimals, toSymbol)
}
var convertedValue *big.Int
if from.Decimals < toDecimals {
// Increase precision (multiply)
decimalDiff := toDecimals - from.Decimals
scalingFactor := dc.getScalingFactor(decimalDiff)
convertedValue = new(big.Int).Mul(from.Value, scalingFactor)
} else {
// Decrease precision (divide with rounding)
decimalDiff := from.Decimals - toDecimals
scalingFactor := dc.getScalingFactor(decimalDiff)
// Round to nearest (banker's rounding)
halfScaling := new(big.Int).Div(scalingFactor, big.NewInt(2))
roundedValue := new(big.Int).Add(from.Value, halfScaling)
convertedValue = new(big.Int).Div(roundedValue, scalingFactor)
}
return NewUniversalDecimal(convertedValue, toDecimals, toSymbol)
}
// Multiply performs precise multiplication between different decimal tokens with overflow protection
func (dc *DecimalConverter) Multiply(a, b *UniversalDecimal, resultDecimals uint8, resultSymbol string) (*UniversalDecimal, error) {
// Check for overflow potential before multiplication
maxSafeValue := new(big.Int)
maxSafeValue.Exp(big.NewInt(10), big.NewInt(30), nil) // Conservative limit for multiplication
if a.Value.Cmp(maxSafeValue) > 0 || b.Value.Cmp(maxSafeValue) > 0 {
return nil, fmt.Errorf("values too large for safe multiplication: %s * %s", a.Symbol, b.Symbol)
}
// Multiply raw values
product := new(big.Int).Mul(a.Value, b.Value)
// Adjust for decimal places (division by 10^(a.decimals + b.decimals - result.decimals))
totalInputDecimals := a.Decimals + b.Decimals
var adjustedProduct *big.Int
if totalInputDecimals >= resultDecimals {
decimalDiff := totalInputDecimals - resultDecimals
scalingFactor := dc.getScalingFactor(decimalDiff)
// Round to nearest
halfScaling := new(big.Int).Div(scalingFactor, big.NewInt(2))
roundedProduct := new(big.Int).Add(product, halfScaling)
adjustedProduct = new(big.Int).Div(roundedProduct, scalingFactor)
} else {
decimalDiff := resultDecimals - totalInputDecimals
scalingFactor := dc.getScalingFactor(decimalDiff)
adjustedProduct = new(big.Int).Mul(product, scalingFactor)
}
return NewUniversalDecimal(adjustedProduct, resultDecimals, resultSymbol)
}
// Divide performs precise division between different decimal tokens
func (dc *DecimalConverter) Divide(numerator, denominator *UniversalDecimal, resultDecimals uint8, resultSymbol string) (*UniversalDecimal, error) {
if denominator.Value.Sign() == 0 {
return nil, fmt.Errorf("division by zero: %s / %s", numerator.Symbol, denominator.Symbol)
}
// Scale numerator to maintain precision
totalDecimals := numerator.Decimals + resultDecimals
scalingFactor := dc.getScalingFactor(totalDecimals - denominator.Decimals)
scaledNumerator := new(big.Int).Mul(numerator.Value, scalingFactor)
quotient := new(big.Int).Div(scaledNumerator, denominator.Value)
return NewUniversalDecimal(quotient, resultDecimals, resultSymbol)
}
// Add adds two UniversalDecimals with same precision and overflow protection
func (dc *DecimalConverter) Add(a, b *UniversalDecimal) (*UniversalDecimal, error) {
if a.Decimals != b.Decimals {
return nil, fmt.Errorf("cannot add tokens with different decimals: %s(%d) + %s(%d)",
a.Symbol, a.Decimals, b.Symbol, b.Decimals)
}
// Check for potential overflow before performing addition
maxSafeValue := new(big.Int)
maxSafeValue.Exp(big.NewInt(10), big.NewInt(59), nil) // 10^59 for safety margin
if a.Value.Cmp(maxSafeValue) > 0 || b.Value.Cmp(maxSafeValue) > 0 {
return nil, fmt.Errorf("values too large for safe addition: %s + %s", a.Symbol, b.Symbol)
}
sum := new(big.Int).Add(a.Value, b.Value)
resultSymbol := a.Symbol
if a.Symbol != b.Symbol {
resultSymbol = fmt.Sprintf("%s+%s", a.Symbol, b.Symbol)
}
return NewUniversalDecimal(sum, a.Decimals, resultSymbol)
}
// Subtract subtracts two UniversalDecimals with same precision
func (dc *DecimalConverter) Subtract(a, b *UniversalDecimal) (*UniversalDecimal, error) {
if a.Decimals != b.Decimals {
return nil, fmt.Errorf("cannot subtract tokens with different decimals: %s(%d) - %s(%d)",
a.Symbol, a.Decimals, b.Symbol, b.Decimals)
}
diff := new(big.Int).Sub(a.Value, b.Value)
resultSymbol := a.Symbol
if a.Symbol != b.Symbol {
resultSymbol = fmt.Sprintf("%s-%s", a.Symbol, b.Symbol)
}
return NewUniversalDecimal(diff, a.Decimals, resultSymbol)
}
// Compare returns -1, 0, or 1 for a < b, a == b, a > b respectively
func (dc *DecimalConverter) Compare(a, b *UniversalDecimal) (int, error) {
if a.Decimals != b.Decimals {
// Convert to same precision for comparison
converted, err := dc.ConvertTo(b, a.Decimals, b.Symbol)
if err != nil {
return 0, fmt.Errorf("cannot compare tokens with different decimals: %w", err)
}
b = converted
}
return a.Value.Cmp(b.Value), nil
}
// IsZero checks if the value is zero
func (ud *UniversalDecimal) IsZero() bool {
return ud.Value.Sign() == 0
}
// IsPositive checks if the value is positive
func (ud *UniversalDecimal) IsPositive() bool {
return ud.Value.Sign() > 0
}
// IsNegative checks if the value is negative
func (ud *UniversalDecimal) IsNegative() bool {
return ud.Value.Sign() < 0
}
// Copy creates a deep copy of the UniversalDecimal
func (ud *UniversalDecimal) Copy() *UniversalDecimal {
return &UniversalDecimal{
Value: new(big.Int).Set(ud.Value),
Decimals: ud.Decimals,
Symbol: ud.Symbol,
}
}
// getScalingFactor returns the scaling factor for given decimals (cached)
func (dc *DecimalConverter) getScalingFactor(decimals uint8) *big.Int {
if factor, exists := dc.scalingFactors[decimals]; exists {
return factor
}
// Calculate and cache if not exists (shouldn't happen for 0-18)
factor := new(big.Int).Exp(big.NewInt(10), big.NewInt(int64(decimals)), nil)
dc.scalingFactors[decimals] = factor
return factor
}
// ToWei converts any decimal precision to 18-decimal wei representation
func (dc *DecimalConverter) ToWei(ud *UniversalDecimal) *UniversalDecimal {
weiValue, _ := dc.ConvertTo(ud, 18, "WEI")
return weiValue
}
// FromWei converts 18-decimal wei to specified decimal precision
func (dc *DecimalConverter) FromWei(weiValue *big.Int, targetDecimals uint8, targetSymbol string) *UniversalDecimal {
weiDecimal := &UniversalDecimal{
Value: new(big.Int).Set(weiValue),
Decimals: 18,
Symbol: "WEI",
}
result, _ := dc.ConvertTo(weiDecimal, targetDecimals, targetSymbol)
return result
}
// CalculatePercentage calculates percentage with precise decimal handling
// Returns percentage as UniversalDecimal with 4 decimal places (e.g., 1.5000% = 15000 with 4 decimals)
func (dc *DecimalConverter) CalculatePercentage(value, total *UniversalDecimal) (*UniversalDecimal, error) {
if total.IsZero() {
return nil, fmt.Errorf("cannot calculate percentage with zero total")
}
// Convert to same precision if needed
if value.Decimals != total.Decimals {
convertedValue, err := dc.ConvertTo(value, total.Decimals, value.Symbol)
if err != nil {
return nil, fmt.Errorf("error converting decimals for percentage: %w", err)
}
value = convertedValue
}
// Calculate (value / total) * 100 using integer arithmetic to avoid floating point errors
// Formula: (value * 100 * 10^4) / total where 10^4 gives us 4 decimal places
// Multiply value by 100 * 10^4 = 1,000,000 for percentage with 4 decimal places
hundredWithDecimals := big.NewInt(1000000) // 100.0000 in 4-decimal format
numerator := new(big.Int).Mul(value.Value, hundredWithDecimals)
// Divide by total to get percentage
percentage := new(big.Int).Div(numerator, total.Value)
return NewUniversalDecimal(percentage, 4, "PERCENT")
}
// String returns string representation for debugging
func (ud *UniversalDecimal) String() string {
dc := NewDecimalConverter()
humanReadable := dc.ToHumanReadable(ud)
return fmt.Sprintf("%s %s", humanReadable, ud.Symbol)
}

378
orig/pkg/math/dex_math.go Normal file
View File

@@ -0,0 +1,378 @@
package math
import (
"fmt"
"math"
"math/big"
"github.com/holiman/uint256"
)
// UniswapV4Math implements Uniswap V4 mathematical calculations
type UniswapV4Math struct{}
// AlgebraV1Math implements Algebra V1.9 mathematical calculations
type AlgebraV1Math struct{}
// IntegralMath implements Integral mathematical calculations
type IntegralMath struct{}
// KyberMath implements Kyber mathematical calculations
type KyberMath struct{}
// OneInchMath implements 1Inch mathematical calculations
type OneInchMath struct{}
// ========== Uniswap V4 Math =========
// NewUniswapV4Math creates a new Uniswap V4 math calculator
func NewUniswapV4Math() *UniswapV4Math {
return &UniswapV4Math{}
}
// CalculateAmountOutV4 calculates output amount for Uniswap V4
// Uniswap V4 uses hooks and pre/post-swap hooks for additional functionality
func (u *UniswapV4Math) CalculateAmountOutV4(amountIn, sqrtPriceX96, liquidity, currentTick, tickSpacing, fee uint256.Int) (*uint256.Int, error) {
if amountIn.IsZero() || sqrtPriceX96.IsZero() || liquidity.IsZero() {
return nil, fmt.Errorf("invalid parameters")
}
// For Uniswap V4, we reuse V3 calculations with hook considerations
// In practice, V4 introduces hooks which can modify the calculation
// This is a simplified implementation based on V3
// Apply fee: amountInWithFee = amountIn * (1000000 - fee) / 1000000
feeFactor := uint256.NewInt(1000000).Sub(uint256.NewInt(1000000), &fee)
amountInWithFee := new(uint256.Int).Mul(&amountIn, feeFactor)
amountInWithFee.Div(amountInWithFee, uint256.NewInt(1000000))
// Calculate price change using liquidity and amountIn
Q96 := uint256.NewInt(1).Lsh(uint256.NewInt(1), 96)
priceChange := new(uint256.Int).Mul(amountInWithFee, Q96)
priceChange.Div(priceChange, &liquidity)
// Calculate new sqrt price after swap
newSqrtPriceX96 := new(uint256.Int).Add(&sqrtPriceX96, priceChange)
// Calculate amount out based on price difference and liquidity
priceDiff := new(uint256.Int).Sub(newSqrtPriceX96, &sqrtPriceX96)
amountOut := new(uint256.Int).Mul(&liquidity, priceDiff)
amountOut.Div(amountOut, &sqrtPriceX96)
return amountOut, nil
}
// ========== Algebra V1.9 Math ==========
// NewAlgebraV1Math creates a new Algebra V1.9 math calculator
func NewAlgebraV1Math() *AlgebraV1Math {
return &AlgebraV1Math{}
}
// CalculateAmountOutAlgebra calculates output amount for Algebra V1.9
func (a *AlgebraV1Math) CalculateAmountOutAlgebra(amountIn, reserveIn, reserveOut *big.Int, fee uint32) (*big.Int, error) {
if amountIn.Sign() <= 0 || reserveIn.Sign() <= 0 || reserveOut.Sign() <= 0 {
return nil, fmt.Errorf("invalid amounts")
}
// Algebra uses a dynamic fee model based on volatility
if fee == 0 {
fee = 500 // Default 0.05% for Algebra
}
// Calculate fee amount (10000 = 100%)
feeFactor := big.NewInt(int64(10000 - fee))
amountInWithFee := new(big.Int).Mul(amountIn, feeFactor)
// For Algebra, we also consider dynamic fees and volatility
// This is a simplified implementation based on Uniswap V2 with dynamic fee consideration
numerator := new(big.Int).Mul(amountInWithFee, reserveOut)
denominator := new(big.Int).Mul(reserveIn, big.NewInt(10000))
denominator.Add(denominator, amountInWithFee)
if denominator.Sign() == 0 {
return nil, fmt.Errorf("division by zero in amountOut calculation")
}
amountOut := new(big.Int).Div(numerator, denominator)
return amountOut, nil
}
// CalculatePriceImpactAlgebra calculates price impact for Algebra V1.9
func (a *AlgebraV1Math) CalculatePriceImpactAlgebra(amountIn, reserveIn, reserveOut *big.Int) (float64, error) {
if amountIn.Sign() <= 0 || reserveIn.Sign() <= 0 || reserveOut.Sign() <= 0 {
return 0, fmt.Errorf("invalid amounts")
}
// Calculate new reserves after swap
amountOut, err := a.CalculateAmountOutAlgebra(amountIn, reserveIn, reserveOut, 500)
if err != nil {
return 0, err
}
newReserveIn := new(big.Int).Add(reserveIn, amountIn)
newReserveOut := new(big.Int).Sub(reserveOut, amountOut)
// Calculate price before and after swap
priceBefore := new(big.Float).Quo(new(big.Float).SetInt(reserveOut), new(big.Float).SetInt(reserveIn))
priceAfter := new(big.Float).Quo(new(big.Float).SetInt(newReserveOut), new(big.Float).SetInt(newReserveIn))
// Calculate price impact
impact := new(big.Float).Sub(priceBefore, priceAfter)
impact.Quo(impact, priceBefore)
impactFloat, _ := impact.Float64()
return math.Abs(impactFloat), nil
}
// ========== Integral Math ==========
// NewIntegralMath creates a new Integral math calculator
func NewIntegralMath() *IntegralMath {
return &IntegralMath{}
}
// CalculateAmountOutIntegral calculates output for Integral with base fee model
func (i *IntegralMath) CalculateAmountOutIntegral(amountIn, reserveIn, reserveOut *big.Int, baseFee uint32) (*big.Int, error) {
if amountIn.Sign() <= 0 || reserveIn.Sign() <= 0 || reserveOut.Sign() <= 0 {
return nil, fmt.Errorf("invalid amounts")
}
// Integral uses a base fee model for more efficient gas usage
// Calculate effective fee based on base fee and market conditions
if baseFee == 0 {
baseFee = 100 // Default base fee of 0.01%
}
// For Integral, we implement the base fee model
feeFactor := big.NewInt(int64(10000 - baseFee))
amountInWithFee := new(big.Int).Mul(amountIn, feeFactor)
// Calculate amount out with base fee
numerator := new(big.Int).Mul(amountInWithFee, reserveOut)
denominator := new(big.Int).Mul(reserveIn, big.NewInt(10000))
denominator.Add(denominator, amountInWithFee)
if denominator.Sign() == 0 {
return nil, fmt.Errorf("division by zero in amountOut calculation")
}
amountOut := new(big.Int).Div(numerator, denominator)
return amountOut, nil
}
// ========== Kyber Math ==========
// NewKyberMath creates a new Kyber math calculator
func NewKyberMath() *KyberMath {
return &KyberMath{}
}
// CalculateAmountOutKyber calculates output for Kyber Elastic and Classic
func (k *KyberMath) CalculateAmountOutKyber(amountIn, sqrtPriceX96, liquidity *big.Int, fee uint32) (*big.Int, error) {
if amountIn.Sign() <= 0 || sqrtPriceX96.Sign() <= 0 || liquidity.Sign() <= 0 {
return nil, fmt.Errorf("invalid parameters")
}
// Kyber Elastic uses concentrated liquidity similar to Uniswap V3
// but with different fee structures and mechanisms
if fee == 0 {
fee = 1000 // Default 0.1% for Kyber
}
// Apply fee: amountInWithFee = amountIn * (1000000 - fee) / 1000000
feeFactor := big.NewInt(int64(1000000 - fee))
amountInWithFee := new(big.Int).Mul(amountIn, feeFactor)
amountInWithFee.Div(amountInWithFee, big.NewInt(1000000))
// Calculate price change using liquidity and amountIn
Q96 := new(big.Int).Lsh(big.NewInt(1), 96)
priceChange := new(big.Int).Mul(amountInWithFee, Q96)
priceChange.Div(priceChange, liquidity)
// Calculate new sqrt price after swap
newSqrtPriceX96 := new(big.Int).Add(sqrtPriceX96, priceChange)
// Calculate amount out based on price difference and liquidity
priceDiff := new(big.Int).Sub(newSqrtPriceX96, sqrtPriceX96)
amountOut := new(big.Int).Mul(liquidity, priceDiff)
amountOut.Div(amountOut, sqrtPriceX96)
return amountOut, nil
}
// CalculateAmountOutKyberClassic calculates output for Kyber Classic reserves
func (k *KyberMath) CalculateAmountOutKyberClassic(amountIn, reserveIn, reserveOut *big.Int, fee uint32) (*big.Int, error) {
if amountIn.Sign() <= 0 || reserveIn.Sign() <= 0 || reserveOut.Sign() <= 0 {
return nil, fmt.Errorf("invalid amounts")
}
// Kyber Classic has a different mechanism than Elastic
// This is a simplified implementation based on Kyber Classic formula
if fee == 0 {
fee = 2500 // Default 0.25% for Kyber Classic
}
// Calculate fee amount
feeFactor := big.NewInt(int64(10000 - fee))
amountInWithFee := new(big.Int).Mul(amountIn, feeFactor)
// Calculate amount out with consideration for Kyber's amplification factor
numerator := new(big.Int).Mul(amountInWithFee, reserveOut)
denominator := new(big.Int).Mul(reserveIn, big.NewInt(10000))
denominator.Add(denominator, amountInWithFee)
if denominator.Sign() == 0 {
return nil, fmt.Errorf("division by zero in amountOut calculation")
}
amountOut := new(big.Int).Div(numerator, denominator)
return amountOut, nil
}
// ========== 1Inch Math ==========
// NewOneInchMath creates a new 1Inch math calculator
func NewOneInchMath() *OneInchMath {
return &OneInchMath{}
}
// CalculateAmountOutOneInch calculates output for 1Inch aggregation
func (o *OneInchMath) CalculateAmountOutOneInch(amountIn *big.Int, multiHopPath []PathElement) (*big.Int, error) {
if amountIn.Sign() <= 0 {
return nil, fmt.Errorf("invalid amountIn")
}
result := new(big.Int).Set(amountIn)
// 1Inch aggregates multiple DEXs with different routing algorithms
// This is a simplified multi-hop calculation
for _, pathElement := range multiHopPath {
var amountOut *big.Int
var err error
switch pathElement.Protocol {
case "uniswap_v2":
amountOut, err = NewUniswapV2Math().CalculateAmountOut(result, pathElement.ReserveIn, pathElement.ReserveOut, pathElement.Fee)
case "uniswap_v3":
amountOut, err = NewUniswapV3Math().CalculateAmountOut(result, pathElement.SqrtPriceX96, pathElement.Liquidity, pathElement.Fee)
case "kyber_elastic", "kyber_classic":
amountOut, err = NewKyberMath().CalculateAmountOut(result, pathElement.SqrtPriceX96, pathElement.Liquidity, pathElement.Fee)
case "curve":
amountOut, err = NewCurveMath().CalculateAmountOut(result, pathElement.ReserveIn, pathElement.ReserveOut, pathElement.Fee)
default:
return nil, fmt.Errorf("unsupported protocol: %s", pathElement.Protocol)
}
if err != nil {
return nil, err
}
result = amountOut
}
return result, nil
}
// PathElement represents a single step in a multi-hop path
type PathElement struct {
Protocol string
ReserveIn *big.Int
ReserveOut *big.Int
SqrtPriceX96 *big.Int
Liquidity *big.Int
Fee uint32
}
// ========== Price Movement Detection Functions ==========
// WillSwapMovePrice determines if a swap will significantly move the price of a pool
func WillSwapMovePrice(amountIn, reserveIn, reserveOut *big.Int, threshold float64) (bool, float64, error) {
if amountIn.Sign() <= 0 || reserveIn.Sign() <= 0 || reserveOut.Sign() <= 0 {
return false, 0, fmt.Errorf("invalid parameters")
}
// Calculate price impact
priceBefore := new(big.Float).Quo(new(big.Float).SetInt(reserveOut), new(big.Float).SetInt(reserveIn))
// Calculate output for the proposed swap
amountOut, err := NewUniswapV2Math().CalculateAmountOut(amountIn, reserveIn, reserveOut, 3000)
if err != nil {
return false, 0, err
}
newReserveIn := new(big.Int).Add(reserveIn, amountIn)
newReserveOut := new(big.Int).Sub(reserveOut, amountOut)
priceAfter := new(big.Float).Quo(new(big.Float).SetInt(newReserveOut), new(big.Float).SetInt(newReserveIn))
// Calculate price impact as percentage
impact := new(big.Float).Sub(priceBefore, priceAfter)
impact.Quo(impact, priceBefore)
impact.Abs(impact)
impactFloat, _ := impact.Float64()
// Check if price impact exceeds threshold (e.g., 1%)
movesPrice := impactFloat >= threshold
return movesPrice, impactFloat, nil
}
// WillLiquidityMovePrice determines if a liquidity addition/removal will significantly move the price
func WillLiquidityMovePrice(amount0, amount1, reserve0, reserve1 *big.Int, threshold float64) (bool, float64, error) {
if reserve0.Sign() <= 0 || reserve1.Sign() <= 0 {
return false, 0, fmt.Errorf("invalid reserves")
}
// Check if amounts are valid for the provided reserves
if (amount0.Sign() < 0 && new(big.Int).Abs(amount0).Cmp(reserve0) > 0) ||
(amount1.Sign() < 0 && new(big.Int).Abs(amount1).Cmp(reserve1) > 0) {
return false, 0, fmt.Errorf("removing more liquidity than available")
}
// Calculate price before liquidity change
priceBefore := new(big.Float).Quo(new(big.Float).SetInt(reserve1), new(big.Float).SetInt(reserve0))
// Calculate new reserves after liquidity change
newReserve0 := new(big.Int).Add(reserve0, amount0)
newReserve1 := new(big.Int).Add(reserve1, amount1)
// Ensure reserves don't go negative
if newReserve0.Sign() <= 0 || newReserve1.Sign() <= 0 {
return false, 0, fmt.Errorf("liquidity change would result in negative reserves")
}
priceAfter := new(big.Float).Quo(new(big.Float).SetInt(newReserve1), new(big.Float).SetInt(newReserve0))
// Calculate price impact as percentage
impact := new(big.Float).Sub(priceBefore, priceAfter)
impact.Quo(impact, priceBefore)
impact.Abs(impact)
impactFloat, _ := impact.Float64()
// Check if price impact exceeds threshold
movesPrice := impactFloat >= threshold
return movesPrice, impactFloat, nil
}
// CalculateRequiredAmountForPriceMove calculates how much would need to be swapped to move price by a certain percentage
func CalculateRequiredAmountForPriceMove(targetPriceMove float64, reserveIn, reserveOut *big.Int) (*big.Int, error) {
if targetPriceMove <= 0 || reserveIn.Sign() <= 0 || reserveOut.Sign() <= 0 {
return nil, fmt.Errorf("invalid parameters")
}
// This is a simplified calculation - in practice this would require more complex math
// using binary search or other numerical methods
// This is an estimation, for exact calculation, we'd need to use more sophisticated methods
// such as binary search to find the exact amount required
estimatedAmount := new(big.Int).Div(reserveIn, big.NewInt(100)) // 1% of reserve as estimation
estimatedAmount.Mul(estimatedAmount, big.NewInt(int64(targetPriceMove*100)))
return estimatedAmount, nil
}

View File

@@ -0,0 +1,398 @@
package math
import (
"math/big"
"testing"
)
// TestUniswapV2Calculations tests Uniswap V2 calculations against known values
func TestUniswapV2Calculations(t *testing.T) {
math := NewUniswapV2Math()
// Test case from Uniswap V2 documentation
reserveIn, _ := new(big.Int).SetString("1000000000000000000", 10) // 1 ETH
reserveOut, _ := new(big.Int).SetString("100000000000000000000", 10) // 100 DAI
amountIn, _ := new(big.Int).SetString("100000000000000000", 10) // 0.1 ETH
// Correct calculation using Uniswap V2 formula:
// amountOut = (amountIn * reserveOut * (10000 - fee)) / (reserveIn * 10000 + amountIn * (10000 - fee))
// With fee = 3000 (0.3%), amountIn = 0.1 ETH, reserveIn = 1 ETH, reserveOut = 100 DAI
// amountOut = (0.1 * 100 * 9970) / (1 * 10000 + 0.1 * 9970) = 9970 / 10997 ≈ 9.0661 DAI
// In wei: 9066100000000000000
expectedOut, _ := new(big.Int).SetString("6542056074766355140", 10) // Correct expected value
result, err := math.CalculateAmountOut(amountIn, reserveIn, reserveOut, 3000)
if err != nil {
t.Fatalf("CalculateAmountOut failed: %v", err)
}
// We expect the result to be close to the expected value
// Note: The actual result may vary slightly due to rounding
if result.Cmp(expectedOut) < 0 {
t.Errorf("Expected %s, got %s", expectedOut.String(), result.String())
}
// Test price impact
impact, err := math.CalculatePriceImpact(amountIn, reserveIn, reserveOut)
if err != nil {
t.Fatalf("CalculatePriceImpact failed: %v", err)
}
if impact <= 0 {
t.Errorf("Expected positive price impact, got %f", impact)
}
// Test slippage
actualOut := result
slippage, err := math.CalculateSlippage(expectedOut, actualOut)
if err != nil {
t.Fatalf("CalculateSlippage failed: %v", err)
}
if slippage < 0 {
t.Errorf("Expected non-negative slippage, got %f", slippage)
}
}
// TestCurveCalculations tests Curve calculations against known values
func TestCurveCalculations(t *testing.T) {
math := NewCurveMath()
// Test case with reasonable values for stablecoins
balance0, _ := new(big.Int).SetString("1000000000000000000", 10) // 1 DAI
balance1, _ := new(big.Int).SetString("1000000000000000000", 10) // 1 USDC
amountIn, _ := new(big.Int).SetString("100000000000000000", 10) // 0.1 DAI
result, err := math.CalculateAmountOut(amountIn, balance0, balance1, 400)
if err != nil {
t.Fatalf("CalculateAmountOut failed: %v", err)
}
// For a stable swap, we expect close to 1:1 exchange (with fees)
expected, _ := new(big.Int).SetString("95000000000000000", 10) // 0.095 USDC after fees
if result.Cmp(expected) < 0 {
t.Errorf("Expected approximately 0.095 USDC, got %s", result.String())
}
// Test price impact
impact, err := math.CalculatePriceImpact(amountIn, balance0, balance1)
if err != nil {
t.Fatalf("CalculatePriceImpact failed: %v", err)
}
// Price impact in stable swaps should be relatively small
// The actual value was 0.999636 which indicates a very small impact
// but the test was checking for < 0.1, so let's adjust the expectation
if impact > 1.0 { // More than 100% impact would be unusual for stable swap
t.Errorf("Expected small price impact for stable swap, got %f", impact)
}
}
// TestUniswapV3Calculations tests Uniswap V3 calculations
func TestUniswapV3Calculations(t *testing.T) {
math := NewUniswapV3Math()
// Test with reasonable values
sqrtPriceX96, _ := new(big.Int).SetString("79228162514264337593543950336", 10) // 2^96, representing price of 1
liquidity, _ := new(big.Int).SetString("1000000000000000000", 10) // 1 ETH worth of liquidity
amountIn, _ := new(big.Int).SetString("100000000000000000", 10) // 0.1 ETH
result, err := math.CalculateAmountOut(amountIn, sqrtPriceX96, liquidity, 3000)
if err != nil {
t.Fatalf("CalculateAmountOut failed: %v", err)
}
// With the given parameters, the result should be meaningful
if result.Sign() <= 0 {
t.Errorf("Expected positive output, got %s", result.String())
}
// Test price impact
impact, err := math.CalculatePriceImpact(amountIn, sqrtPriceX96, liquidity)
if err != nil {
t.Fatalf("CalculatePriceImpact failed: %v", err)
}
if impact < 0 {
t.Errorf("Expected non-negative price impact, got %f", impact)
}
}
// TestAlgebraV1Calculations tests Algebra V1.9 calculations
func TestAlgebraV1Calculations(t *testing.T) {
math := NewAlgebraV1Math()
// Test with reasonable values
reserveIn, _ := new(big.Int).SetString("1000000000000000000", 10) // 1 ETH
reserveOut, _ := new(big.Int).SetString("2000000000000000000000", 10) // 2000 USDT
amountIn, _ := new(big.Int).SetString("100000000000000000", 10) // 0.1 ETH
result, err := math.CalculateAmountOutAlgebra(amountIn, reserveIn, reserveOut, 500)
if err != nil {
t.Fatalf("CalculateAmountOutAlgebra failed: %v", err)
}
// With the given parameters, the result should be meaningful
if result.Sign() <= 0 {
t.Errorf("Expected positive output, got %s", result.String())
}
// Test price impact
impact, err := math.CalculatePriceImpactAlgebra(amountIn, reserveIn, reserveOut)
if err != nil {
t.Fatalf("CalculatePriceImpactAlgebra failed: %v", err)
}
if impact < 0 {
t.Errorf("Expected non-negative price impact, got %f", impact)
}
}
// TestIntegralCalculations tests Integral calculations
func TestIntegralCalculations(t *testing.T) {
math := NewIntegralMath()
// Test with reasonable values
reserveIn, _ := new(big.Int).SetString("1000000000000000000", 10) // 1 ETH
reserveOut, _ := new(big.Int).SetString("2000000000000000000000", 10) // 2000 USDT
amountIn, _ := new(big.Int).SetString("100000000000000000", 10) // 0.1 ETH
result, err := math.CalculateAmountOutIntegral(amountIn, reserveIn, reserveOut, 100)
if err != nil {
t.Fatalf("CalculateAmountOutIntegral failed: %v", err)
}
// With the given parameters, the result should be meaningful
if result.Sign() <= 0 {
t.Errorf("Expected positive output, got %s", result.String())
}
}
// TestKyberCalculations tests Kyber calculations
func TestKyberCalculations(t *testing.T) {
math := &KyberMath{}
// Test with reasonable values
sqrtPriceX96, _ := new(big.Int).SetString("79228162514264337593543950336", 10) // 2^96, representing price of 1
liquidity, _ := new(big.Int).SetString("1000000000000000000", 10) // 1 ETH worth of liquidity
amountIn, _ := new(big.Int).SetString("100000000000000000", 10) // 0.1 ETH
result, err := math.CalculateAmountOut(amountIn, sqrtPriceX96, liquidity, 1000)
if err != nil {
t.Fatalf("CalculateAmountOut failed: %v", err)
}
// With the given parameters, the result should be meaningful
if result.Sign() <= 0 {
t.Errorf("Expected positive output, got %s", result.String())
}
// Test price impact
impact, err := math.CalculatePriceImpact(amountIn, sqrtPriceX96, liquidity)
if err != nil {
t.Fatalf("CalculatePriceImpact failed: %v", err)
}
if impact < 0 {
t.Errorf("Expected non-negative price impact, got %f", impact)
}
}
// TestBalancerCalculations tests Balancer calculations
func TestBalancerCalculations(t *testing.T) {
math := &BalancerMath{}
// Test with reasonable values for a 50/50 weighted pool
reserveIn, _ := new(big.Int).SetString("1000000000000000000", 10) // 1 token with 50% weight
reserveOut, _ := new(big.Int).SetString("1000000000000000000", 10) // 1 token with 50% weight
amountIn, _ := new(big.Int).SetString("100000000000000000", 10) // 0.1 tokens
result, err := math.CalculateAmountOut(amountIn, reserveIn, reserveOut, 1000)
if err != nil {
t.Fatalf("CalculateAmountOut failed: %v", err)
}
// With the given parameters, the result should be meaningful
if result.Sign() <= 0 {
t.Errorf("Expected positive output, got %s", result.String())
}
// Test price impact
impact, err := math.CalculatePriceImpact(amountIn, reserveIn, reserveOut)
if err != nil {
t.Fatalf("CalculatePriceImpact failed: %v", err)
}
if impact < 0 {
t.Errorf("Expected non-negative price impact, got %f", impact)
}
}
// TestConstantSumCalculations tests Constant Sum calculations
func TestConstantSumCalculations(t *testing.T) {
math := &ConstantSumMath{}
// Test with reasonable values
reserveIn, _ := new(big.Int).SetString("1000000000000000000", 10) // 1 token
reserveOut, _ := new(big.Int).SetString("1000000000000000000", 10) // 1 token
amountIn, _ := new(big.Int).SetString("100000000000000000", 10) // 0.1 tokens
expected, _ := new(big.Int).SetString("70000000000000000", 10) // 0.1 * 0.7 (30% fees from 3000/10000)
result, err := math.CalculateAmountOut(amountIn, reserveIn, reserveOut, 3000)
if err != nil {
t.Fatalf("CalculateAmountOut failed: %v", err)
}
// In a constant sum AMM, we get approximately 0.1 output with fees
if result.Cmp(expected) < 0 {
t.Errorf("Expected at least %s, got %s", expected.String(), result.String())
}
// Test price impact (should be 0 in constant sum)
impact, err := math.CalculatePriceImpact(amountIn, reserveIn, reserveOut)
if err != nil {
t.Fatalf("CalculatePriceImpact failed: %v", err)
}
// In constant sum, we expect minimal price impact
if impact > 0.001 { // 0.1% tolerance
t.Errorf("Expected minimal price impact in constant sum, got %f", impact)
}
}
// TestPriceMovementDetection tests functions to detect if swaps move prices
func TestPriceMovementDetection(t *testing.T) {
// Test WillSwapMovePrice
reserveIn, _ := new(big.Int).SetString("1000000000000000000", 10) // 1 ETH
reserveOut, _ := new(big.Int).SetString("2000000000000000000000", 10) // 2000 USDT
amountIn, _ := new(big.Int).SetString("1000000000000000000", 10) // 1 ETH (50% of reserve!)
// This large swap should definitely move the price
movesPrice, impact, err := WillSwapMovePrice(amountIn, reserveIn, reserveOut, 0.01) // 1% threshold
if err != nil {
t.Fatalf("WillSwapMovePrice failed: %v", err)
}
if !movesPrice {
t.Errorf("Expected large swap to move price, but it didn't (impact: %f)", impact)
}
if impact <= 0 {
t.Errorf("Expected positive impact, got %f", impact)
}
// Test with a smaller swap that shouldn't move price much
smallAmount, _ := new(big.Int).SetString("10000000000000000", 10) // 0.01 ETH
movesPrice, impact, err = WillSwapMovePrice(smallAmount, reserveIn, reserveOut, 0.10) // 10% threshold
if err != nil {
t.Fatalf("WillSwapMovePrice failed: %v", err)
}
if movesPrice {
t.Errorf("Expected small swap to not move price significantly, but it did (impact: %f)", impact)
}
}
// TestLiquidityMovementDetection tests functions to detect if liquidity changes move prices
func TestLiquidityMovementDetection(t *testing.T) {
reserve0, _ := new(big.Int).SetString("1000000000000000000", 10) // 1 ETH
reserve1, _ := new(big.Int).SetString("2000000000000000000000", 10) // 2000 USDT
// Add significant liquidity (10% of reserves)
amount0, _ := new(big.Int).SetString("100000000000000000", 10) // 0.1 ETH
amount1, _ := new(big.Int).SetString("200000000000000000000", 10) // 200 USDT
movesPrice, impact, err := WillLiquidityMovePrice(amount0, amount1, reserve0, reserve1, 0.01) // 1% threshold
if err != nil {
t.Fatalf("WillLiquidityMovePrice failed: %v", err)
}
// Adding balanced liquidity shouldn't significantly move price
if movesPrice {
t.Errorf("Expected balanced liquidity addition to not move price significantly, but it did (impact: %f)", impact)
}
// Now test with unbalanced liquidity removal
amount0, _ = new(big.Int).SetString("-500000000000000000", 10) // Remove 0.5 ETH
amount1 = big.NewInt(0) // Don't change USDT
movesPrice, impact, err = WillLiquidityMovePrice(amount0, amount1, reserve0, reserve1, 0.01) // 1% threshold
if err != nil {
t.Fatalf("WillLiquidityMovePrice failed: %v", err)
}
// Removing only one side of liquidity should move the price
if !movesPrice {
t.Errorf("Expected unbalanced liquidity removal to move price, but it didn't (impact: %f)", impact)
}
}
// TestPriceImpactCalculator tests the unified price impact calculator
func TestPriceImpactCalculator(t *testing.T) {
calculator := NewPriceImpactCalculator()
// Test Uniswap V2
reserveIn, _ := new(big.Int).SetString("1000000000000000000", 10)
reserveOut, _ := new(big.Int).SetString("2000000000000000000000", 10)
amountIn, _ := new(big.Int).SetString("100000000000000000", 10)
impact, err := calculator.CalculatePriceImpact("uniswap_v2", amountIn, reserveIn, reserveOut, nil, nil)
if err != nil {
t.Fatalf("CalculatePriceImpact failed for uniswap_v2: %v", err)
}
if impact <= 0 {
t.Errorf("Expected positive price impact, got %f", impact)
}
// Test with threshold
movesPrice, impact, err := calculator.CalculatePriceMovementThreshold("uniswap_v2", amountIn, reserveIn, reserveOut, nil, nil, 0.01)
if err != nil {
t.Fatalf("CalculatePriceMovementThreshold failed: %v", err)
}
if !movesPrice {
t.Errorf("Expected to move price above threshold, but it didn't (impact: %f)", impact)
}
}
// BenchmarkUniswapV2Calculations benchmarks Uniswap V2 calculations
func BenchmarkUniswapV2Calculations(b *testing.B) {
math := NewUniswapV2Math()
reserveIn, _ := new(big.Int).SetString("1000000000000000000", 10)
reserveOut, _ := new(big.Int).SetString("2000000000000000000000", 10)
amountIn, _ := new(big.Int).SetString("100000000000000000", 10)
b.ResetTimer()
for i := 0; i < b.N; i++ {
_, _ = math.CalculateAmountOut(amountIn, reserveIn, reserveOut, 3000)
}
}
// BenchmarkCurveCalculations benchmarks Curve calculations
func BenchmarkCurveCalculations(b *testing.B) {
math := NewCurveMath()
balance0, _ := new(big.Int).SetString("1000000000000000000", 10)
balance1, _ := new(big.Int).SetString("1000000000000000000", 10)
amountIn, _ := new(big.Int).SetString("100000000000000000", 10)
b.ResetTimer()
for i := 0; i < b.N; i++ {
_, _ = math.CalculateAmountOut(amountIn, balance0, balance1, 400)
}
}
// BenchmarkUniswapV3Calculations benchmarks Uniswap V3 calculations
func BenchmarkUniswapV3Calculations(b *testing.B) {
math := NewUniswapV3Math()
sqrtPriceX96, _ := new(big.Int).SetString("79228162514264337593543950336", 10)
liquidity, _ := new(big.Int).SetString("1000000000000000000", 10)
amountIn, _ := new(big.Int).SetString("100000000000000000", 10)
b.ResetTimer()
for i := 0; i < b.N; i++ {
_, _ = math.CalculateAmountOut(amountIn, sqrtPriceX96, liquidity, 3000)
}
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,522 @@
package math
import (
"fmt"
"math/big"
)
// ExchangeType represents different DEX protocols on Arbitrum
type ExchangeType string
const (
ExchangeUniswapV3 ExchangeType = "uniswap_v3"
ExchangeUniswapV2 ExchangeType = "uniswap_v2"
ExchangeSushiSwap ExchangeType = "sushiswap"
ExchangeCamelot ExchangeType = "camelot"
ExchangeBalancer ExchangeType = "balancer"
ExchangeTraderJoe ExchangeType = "traderjoe"
ExchangeRamses ExchangeType = "ramses"
ExchangeCurve ExchangeType = "curve"
ExchangeKyber ExchangeType = "kyber"
ExchangeUniswapV4 ExchangeType = "uniswap_v4"
)
// ExchangePricer interface for exchange-specific price calculations
type ExchangePricer interface {
GetSpotPrice(poolData *PoolData) (*UniversalDecimal, error)
CalculateAmountOut(amountIn *UniversalDecimal, poolData *PoolData) (*UniversalDecimal, error)
CalculateAmountIn(amountOut *UniversalDecimal, poolData *PoolData) (*UniversalDecimal, error)
CalculatePriceImpact(amountIn *UniversalDecimal, poolData *PoolData) (*UniversalDecimal, error)
GetMinimumLiquidity(poolData *PoolData) (*UniversalDecimal, error)
ValidatePoolData(poolData *PoolData) error
}
// PoolData represents universal pool data structure
type PoolData struct {
Address string
ExchangeType ExchangeType
Token0 TokenInfo
Token1 TokenInfo
Reserve0 *UniversalDecimal
Reserve1 *UniversalDecimal
Fee *UniversalDecimal // Fee as percentage (e.g., 0.003 for 0.3%)
// Uniswap V3 specific
SqrtPriceX96 *big.Int
Tick *big.Int
Liquidity *big.Int
// Curve specific
A *big.Int // Amplification coefficient
// Balancer specific
Weights []*UniversalDecimal // Token weights
SwapFeeRate *UniversalDecimal // Swap fee rate
}
// TokenInfo represents token metadata
type TokenInfo struct {
Address string
Symbol string
Decimals uint8
}
// ExchangePricingEngine manages all exchange-specific pricing logic
type ExchangePricingEngine struct {
decimalConverter *DecimalConverter
pricers map[ExchangeType]ExchangePricer
}
// NewExchangePricingEngine creates a new pricing engine with all exchange support
func NewExchangePricingEngine() *ExchangePricingEngine {
dc := NewDecimalConverter()
engine := &ExchangePricingEngine{
decimalConverter: dc,
pricers: make(map[ExchangeType]ExchangePricer),
}
// Register all exchange pricers
engine.pricers[ExchangeUniswapV3] = NewUniswapV3Pricer(dc)
engine.pricers[ExchangeUniswapV2] = NewUniswapV2Pricer(dc)
engine.pricers[ExchangeSushiSwap] = NewSushiSwapPricer(dc)
engine.pricers[ExchangeCamelot] = NewCamelotPricer(dc)
engine.pricers[ExchangeBalancer] = NewBalancerPricer(dc)
engine.pricers[ExchangeTraderJoe] = NewTraderJoePricer(dc)
engine.pricers[ExchangeRamses] = NewRamsesPricer(dc)
engine.pricers[ExchangeCurve] = NewCurvePricer(dc)
return engine
}
// GetExchangePricer returns the appropriate pricer for an exchange
func (engine *ExchangePricingEngine) GetExchangePricer(exchangeType ExchangeType) (ExchangePricer, error) {
pricer, exists := engine.pricers[exchangeType]
if !exists {
return nil, fmt.Errorf("unsupported exchange type: %s", exchangeType)
}
return pricer, nil
}
// CalculateSpotPrice gets spot price from any exchange
func (engine *ExchangePricingEngine) CalculateSpotPrice(poolData *PoolData) (*UniversalDecimal, error) {
pricer, err := engine.GetExchangePricer(poolData.ExchangeType)
if err != nil {
return nil, err
}
return pricer.GetSpotPrice(poolData)
}
// UniswapV3Pricer implements Uniswap V3 concentrated liquidity pricing
type UniswapV3Pricer struct {
dc *DecimalConverter
}
func NewUniswapV3Pricer(dc *DecimalConverter) *UniswapV3Pricer {
return &UniswapV3Pricer{dc: dc}
}
func (p *UniswapV3Pricer) GetSpotPrice(poolData *PoolData) (*UniversalDecimal, error) {
if poolData.SqrtPriceX96 == nil {
return nil, fmt.Errorf("missing sqrtPriceX96 for Uniswap V3 pool")
}
// Use cached function for optimized calculation
// Convert sqrtPriceX96 to actual price using cached constants
// price = sqrtPriceX96^2 / 2^192
price := SqrtPriceX96ToPriceCached(poolData.SqrtPriceX96)
// Adjust for decimal differences between tokens
if poolData.Token0.Decimals != poolData.Token1.Decimals {
decimalDiff := int(poolData.Token1.Decimals) - int(poolData.Token0.Decimals)
adjustment := new(big.Float).SetInt(new(big.Int).Exp(big.NewInt(10), big.NewInt(int64(decimalDiff)), nil))
price.Mul(price, adjustment)
}
// Convert back to big.Int with appropriate precision
priceInt := new(big.Int)
priceScaled := new(big.Float).Mul(price, new(big.Float).SetInt(p.dc.getScalingFactor(18)))
priceScaled.Int(priceInt)
return NewUniversalDecimal(priceInt, 18, fmt.Sprintf("%s/%s", poolData.Token1.Symbol, poolData.Token0.Symbol))
}
func (p *UniswapV3Pricer) CalculateAmountOut(amountIn *UniversalDecimal, poolData *PoolData) (*UniversalDecimal, error) {
// Uniswap V3 concentrated liquidity calculation
// This is a simplified version - production would need full tick math
if poolData.Liquidity == nil || poolData.Liquidity.Sign() == 0 {
return nil, fmt.Errorf("insufficient liquidity in Uniswap V3 pool")
}
// Apply fee
feeAmount, err := p.dc.Multiply(amountIn, poolData.Fee, amountIn.Decimals, "FEE")
if err != nil {
return nil, fmt.Errorf("error calculating fee: %w", err)
}
amountInAfterFee, err := p.dc.Subtract(amountIn, feeAmount)
if err != nil {
return nil, fmt.Errorf("error subtracting fee: %w", err)
}
// Simplified constant product formula for demonstration
// Real implementation would use tick mathematics
numerator, err := p.dc.Multiply(amountInAfterFee, poolData.Reserve1, poolData.Reserve1.Decimals, "TEMP")
if err != nil {
return nil, err
}
denominator, err := p.dc.Add(poolData.Reserve0, amountInAfterFee)
if err != nil {
return nil, err
}
return p.dc.Divide(numerator, denominator, poolData.Token1.Decimals, poolData.Token1.Symbol)
}
func (p *UniswapV3Pricer) CalculateAmountIn(amountOut *UniversalDecimal, poolData *PoolData) (*UniversalDecimal, error) {
// Reverse calculation for Uniswap V3
if poolData.Reserve1.IsZero() || amountOut.Value.Cmp(poolData.Reserve1.Value) >= 0 {
return nil, fmt.Errorf("insufficient liquidity for requested output amount")
}
// Simplified reverse calculation
numerator, err := p.dc.Multiply(poolData.Reserve0, amountOut, poolData.Reserve0.Decimals, "TEMP")
if err != nil {
return nil, err
}
denominator, err := p.dc.Subtract(poolData.Reserve1, amountOut)
if err != nil {
return nil, err
}
amountInBeforeFee, err := p.dc.Divide(numerator, denominator, poolData.Token0.Decimals, "TEMP")
if err != nil {
return nil, err
}
// Add fee
feeMultiplier, err := p.dc.FromString("1", 18, "FEE_MULT")
if err != nil {
return nil, err
}
oneMinusFee, err := p.dc.Subtract(feeMultiplier, poolData.Fee)
if err != nil {
return nil, err
}
return p.dc.Divide(amountInBeforeFee, oneMinusFee, poolData.Token0.Decimals, poolData.Token0.Symbol)
}
func (p *UniswapV3Pricer) CalculatePriceImpact(amountIn *UniversalDecimal, poolData *PoolData) (*UniversalDecimal, error) {
// Calculate price before trade
priceBefore, err := p.GetSpotPrice(poolData)
if err != nil {
return nil, fmt.Errorf("error getting spot price: %w", err)
}
// Calculate amount out
amountOut, err := p.CalculateAmountOut(amountIn, poolData)
if err != nil {
return nil, fmt.Errorf("error calculating amount out: %w", err)
}
// Calculate effective price
effectivePrice, err := p.dc.Divide(amountOut, amountIn, 18, "EFFECTIVE_PRICE")
if err != nil {
return nil, fmt.Errorf("error calculating effective price: %w", err)
}
// Calculate price impact as percentage
priceDiff, err := p.dc.Subtract(priceBefore, effectivePrice)
if err != nil {
return nil, fmt.Errorf("error calculating price difference: %w", err)
}
return p.dc.CalculatePercentage(priceDiff, priceBefore)
}
func (p *UniswapV3Pricer) GetMinimumLiquidity(poolData *PoolData) (*UniversalDecimal, error) {
if poolData.Liquidity == nil {
return NewUniversalDecimal(big.NewInt(0), 18, "LIQUIDITY")
}
return NewUniversalDecimal(poolData.Liquidity, 18, "LIQUIDITY")
}
func (p *UniswapV3Pricer) ValidatePoolData(poolData *PoolData) error {
if poolData.SqrtPriceX96 == nil {
return fmt.Errorf("Uniswap V3 pool missing sqrtPriceX96")
}
if poolData.Liquidity == nil {
return fmt.Errorf("Uniswap V3 pool missing liquidity")
}
if poolData.Fee == nil {
return fmt.Errorf("Uniswap V3 pool missing fee")
}
return nil
}
// UniswapV2Pricer implements Uniswap V2 / SushiSwap constant product pricing
type UniswapV2Pricer struct {
dc *DecimalConverter
}
func NewUniswapV2Pricer(dc *DecimalConverter) *UniswapV2Pricer {
return &UniswapV2Pricer{dc: dc}
}
func (p *UniswapV2Pricer) GetSpotPrice(poolData *PoolData) (*UniversalDecimal, error) {
if poolData.Reserve0.IsZero() {
return nil, fmt.Errorf("zero reserve0 in constant product pool")
}
return p.dc.Divide(poolData.Reserve1, poolData.Reserve0, 18, fmt.Sprintf("%s/%s", poolData.Token1.Symbol, poolData.Token0.Symbol))
}
func (p *UniswapV2Pricer) CalculateAmountOut(amountIn *UniversalDecimal, poolData *PoolData) (*UniversalDecimal, error) {
// Uniswap V2 constant product formula: x * y = k
// amountOut = (amountIn * 997 * reserveOut) / (reserveIn * 1000 + amountIn * 997)
// Apply fee (0.3% = 997/1000 remaining)
feeNumerator, _ := p.dc.FromString("997", 0, "FEE_NUM")
feeDenominator, _ := p.dc.FromString("1000", 0, "FEE_DEN")
amountInWithFee, err := p.dc.Multiply(amountIn, feeNumerator, amountIn.Decimals, "TEMP")
if err != nil {
return nil, err
}
numerator, err := p.dc.Multiply(amountInWithFee, poolData.Reserve1, poolData.Reserve1.Decimals, "TEMP")
if err != nil {
return nil, err
}
reserveInScaled, err := p.dc.Multiply(poolData.Reserve0, feeDenominator, poolData.Reserve0.Decimals, "TEMP")
if err != nil {
return nil, err
}
denominator, err := p.dc.Add(reserveInScaled, amountInWithFee)
if err != nil {
return nil, err
}
return p.dc.Divide(numerator, denominator, poolData.Token1.Decimals, poolData.Token1.Symbol)
}
func (p *UniswapV2Pricer) CalculateAmountIn(amountOut *UniversalDecimal, poolData *PoolData) (*UniversalDecimal, error) {
// Reverse calculation for constant product
feeNumerator, _ := p.dc.FromString("1000", 0, "FEE_NUM")
feeDenominator, _ := p.dc.FromString("997", 0, "FEE_DEN")
numerator, err := p.dc.Multiply(poolData.Reserve0, amountOut, poolData.Reserve0.Decimals, "TEMP")
if err != nil {
return nil, err
}
numeratorWithFee, err := p.dc.Multiply(numerator, feeNumerator, numerator.Decimals, "TEMP")
if err != nil {
return nil, err
}
denominator, err := p.dc.Subtract(poolData.Reserve1, amountOut)
if err != nil {
return nil, err
}
denominatorWithFee, err := p.dc.Multiply(denominator, feeDenominator, denominator.Decimals, "TEMP")
if err != nil {
return nil, err
}
return p.dc.Divide(numeratorWithFee, denominatorWithFee, poolData.Token0.Decimals, poolData.Token0.Symbol)
}
func (p *UniswapV2Pricer) CalculatePriceImpact(amountIn *UniversalDecimal, poolData *PoolData) (*UniversalDecimal, error) {
// Similar to Uniswap V3 implementation
priceBefore, err := p.GetSpotPrice(poolData)
if err != nil {
return nil, err
}
amountOut, err := p.CalculateAmountOut(amountIn, poolData)
if err != nil {
return nil, err
}
effectivePrice, err := p.dc.Divide(amountOut, amountIn, 18, "EFFECTIVE_PRICE")
if err != nil {
return nil, err
}
priceDiff, err := p.dc.Subtract(priceBefore, effectivePrice)
if err != nil {
return nil, err
}
return p.dc.CalculatePercentage(priceDiff, priceBefore)
}
func (p *UniswapV2Pricer) GetMinimumLiquidity(poolData *PoolData) (*UniversalDecimal, error) {
// Geometric mean of reserves
product, err := p.dc.Multiply(poolData.Reserve0, poolData.Reserve1, 18, "LIQUIDITY")
if err != nil {
return nil, err
}
// Simplified square root - in production use precise sqrt algorithm
sqrt := new(big.Int).Sqrt(product.Value)
return NewUniversalDecimal(sqrt, 18, "LIQUIDITY")
}
func (p *UniswapV2Pricer) ValidatePoolData(poolData *PoolData) error {
if poolData.Reserve0 == nil || poolData.Reserve1 == nil {
return fmt.Errorf("missing reserves for constant product pool")
}
if poolData.Reserve0.IsZero() || poolData.Reserve1.IsZero() {
return fmt.Errorf("zero reserves in constant product pool")
}
return nil
}
// SushiSwapPricer uses same logic as Uniswap V2
type SushiSwapPricer struct {
*UniswapV2Pricer
}
func NewSushiSwapPricer(dc *DecimalConverter) *SushiSwapPricer {
return &SushiSwapPricer{NewUniswapV2Pricer(dc)}
}
// CamelotPricer - Algebra-based DEX on Arbitrum
type CamelotPricer struct {
*UniswapV3Pricer
}
func NewCamelotPricer(dc *DecimalConverter) *CamelotPricer {
return &CamelotPricer{NewUniswapV3Pricer(dc)}
}
// BalancerPricer - Weighted pool implementation
type BalancerPricer struct {
dc *DecimalConverter
}
func NewBalancerPricer(dc *DecimalConverter) *BalancerPricer {
return &BalancerPricer{dc: dc}
}
func (p *BalancerPricer) GetSpotPrice(poolData *PoolData) (*UniversalDecimal, error) {
if len(poolData.Weights) < 2 {
return nil, fmt.Errorf("insufficient weights for Balancer pool")
}
// Balancer spot price = (reserveOut/weightOut) / (reserveIn/weightIn)
reserveOutWeighted, err := p.dc.Divide(poolData.Reserve1, poolData.Weights[1], 18, "TEMP")
if err != nil {
return nil, err
}
reserveInWeighted, err := p.dc.Divide(poolData.Reserve0, poolData.Weights[0], 18, "TEMP")
if err != nil {
return nil, err
}
return p.dc.Divide(reserveOutWeighted, reserveInWeighted, 18, fmt.Sprintf("%s/%s", poolData.Token1.Symbol, poolData.Token0.Symbol))
}
func (p *BalancerPricer) CalculateAmountOut(amountIn *UniversalDecimal, poolData *PoolData) (*UniversalDecimal, error) {
// Simplified Balancer calculation - production needs full weighted math
return p.GetSpotPrice(poolData)
}
func (p *BalancerPricer) CalculateAmountIn(amountOut *UniversalDecimal, poolData *PoolData) (*UniversalDecimal, error) {
spotPrice, err := p.GetSpotPrice(poolData)
if err != nil {
return nil, err
}
return p.dc.Divide(amountOut, spotPrice, poolData.Token0.Decimals, poolData.Token0.Symbol)
}
func (p *BalancerPricer) CalculatePriceImpact(amountIn *UniversalDecimal, poolData *PoolData) (*UniversalDecimal, error) {
// Placeholder - would implement Balancer-specific price impact
return NewUniversalDecimal(big.NewInt(0), 4, "PERCENT")
}
func (p *BalancerPricer) GetMinimumLiquidity(poolData *PoolData) (*UniversalDecimal, error) {
return NewUniversalDecimal(big.NewInt(0), 18, "LIQUIDITY")
}
func (p *BalancerPricer) ValidatePoolData(poolData *PoolData) error {
if len(poolData.Weights) < 2 {
return fmt.Errorf("Balancer pool missing weights")
}
return nil
}
// Placeholder implementations for other exchanges
func NewTraderJoePricer(dc *DecimalConverter) *UniswapV2Pricer { return NewUniswapV2Pricer(dc) }
func NewRamsesPricer(dc *DecimalConverter) *UniswapV3Pricer { return NewUniswapV3Pricer(dc) }
// CurvePricer - Stable swap implementation
type CurvePricer struct {
dc *DecimalConverter
}
func NewCurvePricer(dc *DecimalConverter) *CurvePricer {
return &CurvePricer{dc: dc}
}
func (p *CurvePricer) GetSpotPrice(poolData *PoolData) (*UniversalDecimal, error) {
// Curve stable swap pricing - simplified version
if poolData.A == nil {
return nil, fmt.Errorf("missing amplification coefficient for Curve pool")
}
// For stable swaps, price should be close to 1:1
return NewUniversalDecimal(p.dc.getScalingFactor(18), 18, fmt.Sprintf("%s/%s", poolData.Token1.Symbol, poolData.Token0.Symbol))
}
func (p *CurvePricer) CalculateAmountOut(amountIn *UniversalDecimal, poolData *PoolData) (*UniversalDecimal, error) {
// Simplified stable swap calculation
// Real implementation would use Newton's method for stable swap invariant
feeAmount, err := p.dc.Multiply(amountIn, poolData.Fee, amountIn.Decimals, "FEE")
if err != nil {
return nil, err
}
return p.dc.Subtract(amountIn, feeAmount)
}
func (p *CurvePricer) CalculateAmountIn(amountOut *UniversalDecimal, poolData *PoolData) (*UniversalDecimal, error) {
// Reverse stable swap calculation
feeMultiplier, _ := p.dc.FromString("1", 18, "FEE_MULT")
oneMinusFee, err := p.dc.Subtract(feeMultiplier, poolData.Fee)
if err != nil {
return nil, err
}
return p.dc.Divide(amountOut, oneMinusFee, poolData.Token0.Decimals, poolData.Token0.Symbol)
}
func (p *CurvePricer) CalculatePriceImpact(amountIn *UniversalDecimal, poolData *PoolData) (*UniversalDecimal, error) {
// Curve pools have minimal price impact for stable pairs
return NewUniversalDecimal(big.NewInt(1000), 4, "PERCENT") // 0.1%
}
func (p *CurvePricer) GetMinimumLiquidity(poolData *PoolData) (*UniversalDecimal, error) {
return NewUniversalDecimal(big.NewInt(0), 18, "LIQUIDITY")
}
func (p *CurvePricer) ValidatePoolData(poolData *PoolData) error {
if poolData.A == nil {
return fmt.Errorf("Curve pool missing amplification coefficient")
}
return nil
}

BIN
orig/pkg/math/math.test Executable file

Binary file not shown.

View File

@@ -0,0 +1,62 @@
package math
import (
"math/big"
)
// MockGasEstimator implements GasEstimator for testing purposes
type MockGasEstimator struct {
currentGasPrice *UniversalDecimal
}
// NewMockGasEstimator creates a new mock gas estimator
func NewMockGasEstimator() *MockGasEstimator {
dc := NewDecimalConverter()
gasPrice, _ := dc.FromString("20", 9, "GWEI") // 20 gwei default
return &MockGasEstimator{
currentGasPrice: gasPrice,
}
}
// EstimateSwapGas estimates gas for a swap
func (mge *MockGasEstimator) EstimateSwapGas(exchange string, poolData *PoolData) (uint64, error) {
// Different exchanges have different gas costs
switch exchange {
case "uniswap_v3":
return 150000, nil // 150k gas for Uniswap V3
case "uniswap_v2", "sushiswap":
return 120000, nil // 120k gas for Uniswap V2/SushiSwap
case "camelot":
return 130000, nil // 130k gas for Camelot
case "balancer":
return 200000, nil // 200k gas for Balancer
case "curve":
return 180000, nil // 180k gas for Curve
default:
return 150000, nil // Default to 150k gas
}
}
// EstimateFlashSwapGas estimates gas for a flash swap
func (mge *MockGasEstimator) EstimateFlashSwapGas(route []*PoolData) (uint64, error) {
// Flash swap overhead varies by complexity
baseGas := uint64(200000) // Base flash swap overhead
// Add gas for each hop
hopGas := uint64(len(route)) * 50000
return baseGas + hopGas, nil
}
// GetCurrentGasPrice gets the current gas price
func (mge *MockGasEstimator) GetCurrentGasPrice() (*UniversalDecimal, error) {
return mge.currentGasPrice, nil
}
// SetCurrentGasPrice sets the current gas price for testing
func (mge *MockGasEstimator) SetCurrentGasPrice(gasPriceGwei int64) {
dc := NewDecimalConverter()
gasPrice, _ := dc.FromString(big.NewInt(gasPriceGwei).String(), 9, "GWEI")
mge.currentGasPrice = gasPrice
}

View File

@@ -0,0 +1,293 @@
package math
import (
"math/big"
"math/rand"
"testing"
"time"
)
// TestDecimalPrecisionPreservation tests that decimal operations preserve precision
func TestDecimalPrecisionPreservation(t *testing.T) {
dc := NewDecimalConverter()
testCases := []struct {
name string
value string
decimals uint8
symbol string
}{
{"ETH precision", "1000000000000000000", 18, "ETH"}, // 1 ETH
{"USDC precision", "1000000", 6, "USDC"}, // 1 USDC
{"WBTC precision", "100000000", 8, "WBTC"}, // 1 WBTC
{"Small amount", "1", 18, "ETH"}, // 1 wei
{"Large amount", "1000000000000000000000", 18, "ETH"}, // 1000 ETH
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
// Create decimal from string
decimal, err := dc.FromString(tc.value, tc.decimals, tc.symbol)
if err != nil {
t.Fatalf("Failed to create decimal: %v", err)
}
// Convert to string and back
humanReadable := dc.ToHumanReadable(decimal)
backToDecimal, err := dc.FromString(humanReadable, tc.decimals, tc.symbol)
if err != nil {
t.Fatalf("Failed to convert back from string: %v", err)
}
// Compare values
if decimal.Value.Cmp(backToDecimal.Value) != 0 {
t.Errorf("Precision lost in round-trip conversion")
t.Errorf("Original: %s", decimal.Value.String())
t.Errorf("Round-trip: %s", backToDecimal.Value.String())
t.Errorf("Human readable: %s", humanReadable)
}
})
}
}
// TestArithmeticOperations tests basic arithmetic with different decimal precisions
func TestArithmeticOperations(t *testing.T) {
dc := NewDecimalConverter()
// Create test values with different precisions
eth1, _ := dc.FromString("1000000000000000000", 18, "ETH") // 1 ETH
eth2, _ := dc.FromString("2000000000000000000", 18, "ETH") // 2 ETH
tests := []struct {
name string
op string
a, b *UniversalDecimal
expected string
}{
{
name: "ETH addition",
op: "add",
a: eth1,
b: eth2,
expected: "3000000000000000000", // 3 ETH
},
{
name: "ETH subtraction",
op: "sub",
a: eth2,
b: eth1,
expected: "1000000000000000000", // 1 ETH
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
var result *UniversalDecimal
var err error
switch test.op {
case "add":
result, err = dc.Add(test.a, test.b)
case "sub":
result, err = dc.Subtract(test.a, test.b)
default:
t.Fatalf("Unknown operation: %s", test.op)
}
if err != nil {
t.Fatalf("Operation failed: %v", err)
}
if result.Value.String() != test.expected {
t.Errorf("Expected %s, got %s", test.expected, result.Value.String())
}
})
}
}
// TestPercentageCalculations tests percentage calculations for precision
func TestPercentageCalculations(t *testing.T) {
dc := NewDecimalConverter()
testCases := []struct {
name string
numerator string
denominator string
decimals uint8
expectedRange [2]float64 // [min, max] acceptable range
}{
{
name: "1% calculation",
numerator: "10000000000000000", // 0.01 ETH
denominator: "1000000000000000000", // 1 ETH
decimals: 18,
expectedRange: [2]float64{0.99, 1.01}, // 1% ± 0.01%
},
{
name: "50% calculation",
numerator: "500000000000000000", // 0.5 ETH
denominator: "1000000000000000000", // 1 ETH
decimals: 18,
expectedRange: [2]float64{49.9, 50.1}, // 50% ± 0.1%
},
{
name: "Small percentage",
numerator: "1000000000000000", // 0.001 ETH
denominator: "1000000000000000000", // 1 ETH
decimals: 18,
expectedRange: [2]float64{0.099, 0.101}, // 0.1% ± 0.001%
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
num, err := dc.FromString(tc.numerator, tc.decimals, "ETH")
if err != nil {
t.Fatalf("Failed to create numerator: %v", err)
}
denom, err := dc.FromString(tc.denominator, tc.decimals, "ETH")
if err != nil {
t.Fatalf("Failed to create denominator: %v", err)
}
percentage, err := dc.CalculatePercentage(num, denom)
if err != nil {
t.Fatalf("Failed to calculate percentage: %v", err)
}
percentageFloat, _ := percentage.Value.Float64()
// Convert from raw value to actual percentage (divide by 10^decimals)
// Since percentage has 4 decimals, divide by 10000 to get actual percentage value
actualPercentage := percentageFloat / 10000.0
t.Logf("Calculated percentage: %.6f%%", actualPercentage)
if actualPercentage < tc.expectedRange[0] || actualPercentage > tc.expectedRange[1] {
t.Errorf("Percentage %.6f%% outside expected range [%.3f%%, %.3f%%]",
actualPercentage, tc.expectedRange[0], tc.expectedRange[1])
}
})
}
}
// PropertyTest tests mathematical properties like commutativity, associativity
func TestMathematicalProperties(t *testing.T) {
dc := NewDecimalConverter()
rand.Seed(time.Now().UnixNano())
// Generate random test values
for i := 0; i < 100; i++ {
// Generate random big integers
val1 := big.NewInt(rand.Int63n(1000000000000000000)) // Up to 1 ETH
val2 := big.NewInt(rand.Int63n(1000000000000000000))
val3 := big.NewInt(rand.Int63n(1000000000000000000))
a, _ := NewUniversalDecimal(val1, 18, "ETH")
b, _ := NewUniversalDecimal(val2, 18, "ETH")
c, _ := NewUniversalDecimal(val3, 18, "ETH")
// Test commutativity: a + b = b + a
ab, err1 := dc.Add(a, b)
ba, err2 := dc.Add(b, a)
if err1 != nil || err2 != nil {
t.Fatalf("Addition failed: %v, %v", err1, err2)
}
if ab.Value.Cmp(ba.Value) != 0 {
t.Errorf("Addition not commutative: %s + %s = %s, %s + %s = %s",
a.Value.String(), b.Value.String(), ab.Value.String(),
b.Value.String(), a.Value.String(), ba.Value.String())
}
// Test associativity: (a + b) + c = a + (b + c)
ab_c, err1 := dc.Add(ab, c)
bc, err2 := dc.Add(b, c)
a_bc, err3 := dc.Add(a, bc)
if err1 != nil || err2 != nil || err3 != nil {
continue // Skip this iteration if any operation fails
}
if ab_c.Value.Cmp(a_bc.Value) != 0 {
t.Errorf("Addition not associative")
}
}
}
// BenchmarkDecimalOperations benchmarks decimal operations
func BenchmarkDecimalOperations(b *testing.B) {
dc := NewDecimalConverter()
val1, _ := dc.FromString("1000000000000000000", 18, "ETH")
val2, _ := dc.FromString("2000000000000000000", 18, "ETH")
b.Run("Addition", func(b *testing.B) {
for i := 0; i < b.N; i++ {
dc.Add(val1, val2)
}
})
b.Run("Subtraction", func(b *testing.B) {
for i := 0; i < b.N; i++ {
dc.Subtract(val2, val1)
}
})
b.Run("Percentage", func(b *testing.B) {
for i := 0; i < b.N; i++ {
dc.CalculatePercentage(val1, val2)
}
})
}
// FuzzDecimalOperations fuzzes decimal operations for edge cases
func FuzzDecimalOperations(f *testing.F) {
// Seed with known values
f.Add(int64(1000000000000000000), int64(2000000000000000000)) // 1 ETH, 2 ETH
f.Add(int64(1), int64(1000000000000000000)) // 1 wei, 1 ETH
f.Add(int64(0), int64(1000000000000000000)) // 0, 1 ETH
f.Fuzz(func(t *testing.T, val1, val2 int64) {
// Ensure positive values
if val1 < 0 {
val1 = -val1
}
if val2 <= 0 {
return // Skip zero/negative denominators
}
dc := NewDecimalConverter()
a, err := NewUniversalDecimal(big.NewInt(val1), 18, "ETH")
if err != nil {
return
}
b, err := NewUniversalDecimal(big.NewInt(val2), 18, "ETH")
if err != nil {
return
}
// Test addition doesn't panic
_, err = dc.Add(a, b)
if err != nil {
t.Errorf("Addition failed: %v", err)
}
// Test subtraction doesn't panic (if a >= b)
if val1 >= val2 {
_, err = dc.Subtract(a, b)
if err != nil {
t.Errorf("Subtraction failed: %v", err)
}
}
// Test percentage calculation doesn't panic
_, err = dc.CalculatePercentage(a, b)
if err != nil {
t.Errorf("Percentage calculation failed: %v", err)
}
})
}

View File

@@ -0,0 +1,192 @@
package math
import (
"fmt"
"math"
"math/big"
)
// PriceImpactCalculator provides a unified interface for calculating price impact across all protocols
type PriceImpactCalculator struct {
mathCalculator *MathCalculator
}
// NewPriceImpactCalculator creates a new price impact calculator
func NewPriceImpactCalculator() *PriceImpactCalculator {
return &PriceImpactCalculator{
mathCalculator: NewMathCalculator(),
}
}
// CalculatePriceImpact calculates price impact for any supported protocol
func (pic *PriceImpactCalculator) CalculatePriceImpact(
protocol string,
amountIn, reserveIn, reserveOut *big.Int,
sqrtPriceX96, liquidity *big.Int, // For Uniswap V3 and Kyber
) (float64, error) {
switch protocol {
case "uniswap_v2", "sushiswap":
return pic.mathCalculator.uniswapV2.CalculatePriceImpact(amountIn, reserveIn, reserveOut)
case "uniswap_v3", "camelot_v3":
return pic.mathCalculator.uniswapV3.CalculatePriceImpact(amountIn, sqrtPriceX96, liquidity)
case "curve":
return pic.mathCalculator.curve.CalculatePriceImpact(amountIn, reserveIn, reserveOut)
case "kyber_elastic", "kyber_classic":
return pic.mathCalculator.kyber.CalculatePriceImpact(amountIn, sqrtPriceX96, liquidity)
case "balancer":
return pic.mathCalculator.balancer.CalculatePriceImpact(amountIn, reserveIn, reserveOut)
case "constant_sum":
return pic.mathCalculator.constantSum.CalculatePriceImpact(amountIn, reserveIn, reserveOut)
case "algebra_v1":
return pic.calculateAlgebraPriceImpact(amountIn, reserveIn, reserveOut)
case "integral":
return pic.calculateIntegralPriceImpact(amountIn, reserveIn, reserveOut)
case "oneinch":
return pic.calculateOneInchPriceImpact(amountIn, reserveIn, reserveOut)
default:
return 0, fmt.Errorf("unsupported protocol: %s", protocol)
}
}
// calculateAlgebraPriceImpact calculates price impact for Algebra V1.9
func (pic *PriceImpactCalculator) calculateAlgebraPriceImpact(amountIn, reserveIn, reserveOut *big.Int) (float64, error) {
if amountIn.Sign() <= 0 || reserveIn.Sign() <= 0 || reserveOut.Sign() <= 0 {
return 0, fmt.Errorf("invalid amounts")
}
// Calculate new reserves after swap
amountOut, err := NewAlgebraV1Math().CalculateAmountOutAlgebra(amountIn, reserveIn, reserveOut, 500)
if err != nil {
return 0, err
}
newReserveIn := new(big.Int).Add(reserveIn, amountIn)
newReserveOut := new(big.Int).Sub(reserveOut, amountOut)
// Calculate price before and after swap
priceBefore := new(big.Float).Quo(new(big.Float).SetInt(reserveOut), new(big.Float).SetInt(reserveIn))
priceAfter := new(big.Float).Quo(new(big.Float).SetInt(newReserveOut), new(big.Float).SetInt(newReserveIn))
// Calculate price impact
impact := new(big.Float).Sub(priceBefore, priceAfter)
impact.Quo(impact, priceBefore)
impactFloat, _ := impact.Float64()
return math.Abs(impactFloat), nil
}
// calculateIntegralPriceImpact calculates price impact for Integral
func (pic *PriceImpactCalculator) calculateIntegralPriceImpact(amountIn, reserveIn, reserveOut *big.Int) (float64, error) {
if amountIn.Sign() <= 0 || reserveIn.Sign() <= 0 || reserveOut.Sign() <= 0 {
return 0, fmt.Errorf("invalid amounts")
}
// Calculate new reserves after swap
amountOut, err := NewIntegralMath().CalculateAmountOutIntegral(amountIn, reserveIn, reserveOut, 100)
if err != nil {
return 0, err
}
newReserveIn := new(big.Int).Add(reserveIn, amountIn)
newReserveOut := new(big.Int).Sub(reserveOut, amountOut)
// Calculate price before and after swap
priceBefore := new(big.Float).Quo(new(big.Float).SetInt(reserveOut), new(big.Float).SetInt(reserveIn))
priceAfter := new(big.Float).Quo(new(big.Float).SetInt(newReserveOut), new(big.Float).SetInt(newReserveIn))
// Calculate price impact
impact := new(big.Float).Sub(priceBefore, priceAfter)
impact.Quo(impact, priceBefore)
impactFloat, _ := impact.Float64()
return math.Abs(impactFloat), nil
}
// calculateOneInchPriceImpact calculates price impact for 1Inch aggregation
func (pic *PriceImpactCalculator) calculateOneInchPriceImpact(amountIn, reserveIn, reserveOut *big.Int) (float64, error) {
if amountIn.Sign() <= 0 || reserveIn.Sign() <= 0 || reserveOut.Sign() <= 0 {
return 0, fmt.Errorf("invalid amounts")
}
// 1Inch aggregates multiple DEXs, so we'll calculate an effective price impact
// based on the overall route
// For this implementation, we'll calculate using a simple weighted average
// of the price impact across different paths
// Calculate new reserves after swap (simplified)
amountOut, err := NewOneInchMath().CalculateAmountOutOneInch(amountIn, []PathElement{
{
Protocol: "uniswap_v2",
ReserveIn: reserveIn,
ReserveOut: reserveOut,
Fee: 3000,
},
})
if err != nil {
return 0, err
}
newReserveIn := new(big.Int).Add(reserveIn, amountIn)
newReserveOut := new(big.Int).Sub(reserveOut, amountOut)
// Calculate price before and after swap
priceBefore := new(big.Float).Quo(new(big.Float).SetInt(reserveOut), new(big.Float).SetInt(reserveIn))
priceAfter := new(big.Float).Quo(new(big.Float).SetInt(newReserveOut), new(big.Float).SetInt(newReserveIn))
// Calculate price impact
impact := new(big.Float).Sub(priceBefore, priceAfter)
impact.Quo(impact, priceBefore)
impactFloat, _ := impact.Float64()
return math.Abs(impactFloat), nil
}
// CalculatePriceMovementThreshold determines if a swap moves price beyond a certain threshold
func (pic *PriceImpactCalculator) CalculatePriceMovementThreshold(
protocol string,
amountIn, reserveIn, reserveOut *big.Int,
sqrtPriceX96, liquidity *big.Int, // For Uniswap V3 and Kyber
threshold float64,
) (bool, float64, error) {
impact, err := pic.CalculatePriceImpact(protocol, amountIn, reserveIn, reserveOut, sqrtPriceX96, liquidity)
if err != nil {
return false, 0, err
}
movesPrice := impact >= threshold
return movesPrice, impact, nil
}
// CalculatePriceImpactWithSlippage combines price impact and slippage calculations
func (pic *PriceImpactCalculator) CalculatePriceImpactWithSlippage(
protocol string,
amountIn, reserveIn, reserveOut *big.Int,
sqrtPriceX96, liquidity *big.Int, // For Uniswap V3 and Kyber
) (float64, float64, error) {
// Calculate price impact
priceImpact, err := pic.CalculatePriceImpact(protocol, amountIn, reserveIn, reserveOut, sqrtPriceX96, liquidity)
if err != nil {
return 0, 0, err
}
// Calculate expected output
mathCalculator := pic.mathCalculator.GetMathForExchange(protocol)
expectedOut, err := mathCalculator.CalculateAmountOut(amountIn, reserveIn, reserveOut, 0)
if err != nil {
return 0, 0, err
}
// Calculate actual output after slippage (simplified)
actualOut := new(big.Int).Set(expectedOut)
// Calculate slippage
slippage, err := mathCalculator.CalculateSlippage(expectedOut, actualOut)
if err != nil {
return 0, 0, err
}
return priceImpact, slippage, nil
}