feat: create v2-prep branch with comprehensive planning
Restructured project for V2 refactor: **Structure Changes:** - Moved all V1 code to orig/ folder (preserved with git mv) - Created docs/planning/ directory - Added orig/README_V1.md explaining V1 preservation **Planning Documents:** - 00_V2_MASTER_PLAN.md: Complete architecture overview - Executive summary of critical V1 issues - High-level component architecture diagrams - 5-phase implementation roadmap - Success metrics and risk mitigation - 07_TASK_BREAKDOWN.md: Atomic task breakdown - 99+ hours of detailed tasks - Every task < 2 hours (atomic) - Clear dependencies and success criteria - Organized by implementation phase **V2 Key Improvements:** - Per-exchange parsers (factory pattern) - Multi-layer strict validation - Multi-index pool cache - Background validation pipeline - Comprehensive observability **Critical Issues Addressed:** - Zero address tokens (strict validation + cache enrichment) - Parsing accuracy (protocol-specific parsers) - No audit trail (background validation channel) - Inefficient lookups (multi-index cache) - Stats disconnection (event-driven metrics) Next Steps: 1. Review planning documents 2. Begin Phase 1: Foundation (P1-001 through P1-010) 3. Implement parsers in Phase 2 4. Build cache system in Phase 3 5. Add validation pipeline in Phase 4 6. Migrate and test in Phase 5 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
737
orig/pkg/math/arbitrage_calculator.go
Normal file
737
orig/pkg/math/arbitrage_calculator.go
Normal file
@@ -0,0 +1,737 @@
|
||||
package math
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"math"
|
||||
"math/big"
|
||||
"sort"
|
||||
"time"
|
||||
|
||||
"github.com/ethereum/go-ethereum/common"
|
||||
|
||||
"github.com/fraktal/mev-beta/pkg/security"
|
||||
"github.com/fraktal/mev-beta/pkg/types"
|
||||
)
|
||||
|
||||
// Use the canonical ArbitrageOpportunity from types package
|
||||
// Extended fields for advanced calculations can be added as needed
|
||||
|
||||
// ExchangeStep represents one step in the arbitrage execution
|
||||
type ExchangeStep struct {
|
||||
Exchange ExchangeType
|
||||
Pool *PoolData
|
||||
TokenIn TokenInfo
|
||||
TokenOut TokenInfo
|
||||
AmountIn *UniversalDecimal
|
||||
AmountOut *UniversalDecimal
|
||||
PriceImpact *UniversalDecimal
|
||||
EstimatedGas uint64
|
||||
}
|
||||
|
||||
// RiskAssessment evaluates the risk level of an arbitrage opportunity
|
||||
type RiskAssessment struct {
|
||||
Overall RiskLevel
|
||||
Liquidity RiskLevel
|
||||
PriceImpact RiskLevel
|
||||
Competition RiskLevel
|
||||
Slippage RiskLevel
|
||||
GasPrice RiskLevel
|
||||
Warnings []string
|
||||
OverallRisk float64 // Numeric representation of overall risk (0.0 to 1.0)
|
||||
}
|
||||
|
||||
// RiskLevel represents different risk categories
|
||||
type RiskLevel string
|
||||
|
||||
const (
|
||||
RiskLow RiskLevel = "low"
|
||||
RiskMedium RiskLevel = "medium"
|
||||
RiskHigh RiskLevel = "high"
|
||||
RiskCritical RiskLevel = "critical"
|
||||
)
|
||||
|
||||
// ArbitrageCalculator performs precise arbitrage calculations
|
||||
type ArbitrageCalculator struct {
|
||||
pricingEngine *ExchangePricingEngine
|
||||
decimalConverter *DecimalConverter
|
||||
gasEstimator GasEstimator
|
||||
|
||||
// Configuration
|
||||
minProfitThreshold *UniversalDecimal
|
||||
maxPriceImpact *UniversalDecimal
|
||||
maxSlippage *UniversalDecimal
|
||||
maxGasPriceGwei *UniversalDecimal
|
||||
}
|
||||
|
||||
// GasEstimator interface for gas cost calculations
|
||||
type GasEstimator interface {
|
||||
EstimateSwapGas(exchange ExchangeType, poolData *PoolData) (uint64, error)
|
||||
EstimateFlashSwapGas(route []*PoolData) (uint64, error)
|
||||
GetCurrentGasPrice() (*UniversalDecimal, error)
|
||||
}
|
||||
|
||||
// NewArbitrageCalculator creates a new arbitrage calculator
|
||||
func NewArbitrageCalculator(gasEstimator GasEstimator) *ArbitrageCalculator {
|
||||
dc := NewDecimalConverter()
|
||||
|
||||
// Default configuration
|
||||
minProfit, _ := dc.FromString("0.01", 18, "ETH") // 0.01 ETH minimum
|
||||
maxImpact, _ := dc.FromString("0.02", 4, "PERCENT") // 2% max price impact
|
||||
maxSlip, _ := dc.FromString("0.01", 4, "PERCENT") // 1% max slippage
|
||||
maxGas, _ := dc.FromString("50", 9, "GWEI") // 50 gwei max gas
|
||||
|
||||
return &ArbitrageCalculator{
|
||||
pricingEngine: NewExchangePricingEngine(),
|
||||
decimalConverter: dc,
|
||||
gasEstimator: gasEstimator,
|
||||
minProfitThreshold: minProfit,
|
||||
maxPriceImpact: maxImpact,
|
||||
maxSlippage: maxSlip,
|
||||
maxGasPriceGwei: maxGas,
|
||||
}
|
||||
}
|
||||
|
||||
func toDecimalAmount(ud *UniversalDecimal) types.DecimalAmount {
|
||||
if ud == nil {
|
||||
return types.DecimalAmount{}
|
||||
}
|
||||
return types.DecimalAmount{
|
||||
Value: ud.Value.String(),
|
||||
Decimals: ud.Decimals,
|
||||
Symbol: ud.Symbol,
|
||||
}
|
||||
}
|
||||
|
||||
// CalculateArbitrageOpportunity performs comprehensive arbitrage analysis
|
||||
func (calc *ArbitrageCalculator) CalculateArbitrageOpportunity(
|
||||
path []*PoolData,
|
||||
inputAmount *UniversalDecimal,
|
||||
inputToken TokenInfo,
|
||||
outputToken TokenInfo,
|
||||
) (*types.ArbitrageOpportunity, error) {
|
||||
|
||||
if len(path) == 0 {
|
||||
return nil, fmt.Errorf("empty arbitrage path")
|
||||
}
|
||||
|
||||
// Step 1: Calculate execution route with amounts
|
||||
route, err := calc.calculateExecutionRoute(path, inputAmount, inputToken)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error calculating execution route: %w", err)
|
||||
}
|
||||
|
||||
// Step 2: Get final output amount
|
||||
finalOutput := route[len(route)-1].AmountOut
|
||||
|
||||
// Step 3: Calculate gas costs
|
||||
totalGasCost, err := calc.calculateTotalGasCost(route)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error calculating gas cost: %w", err)
|
||||
}
|
||||
|
||||
// Step 4: Calculate profits (convert to common denomination - ETH)
|
||||
grossProfit, netProfit, profitPercentage, err := calc.calculateProfits(
|
||||
inputAmount, finalOutput, totalGasCost, inputToken, outputToken)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error calculating profits: %w", err)
|
||||
}
|
||||
|
||||
// Step 5: Calculate total price impact
|
||||
totalPriceImpact, err := calc.calculateTotalPriceImpact(route)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error calculating price impact: %w", err)
|
||||
}
|
||||
|
||||
// Step 6: Calculate minimum output with slippage (we don't use this in the final result)
|
||||
_, err = calc.calculateMinimumOutput(finalOutput)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error calculating minimum output: %w", err)
|
||||
}
|
||||
|
||||
// Step 7: Assess risks
|
||||
riskAssessment := calc.assessRisks(route, totalPriceImpact, netProfit)
|
||||
|
||||
// Step 8: Calculate confidence and execution time
|
||||
confidence := calc.calculateConfidence(riskAssessment, netProfit, totalPriceImpact)
|
||||
executionTime := calc.estimateExecutionTime(route)
|
||||
|
||||
// Convert path to string array
|
||||
pathStrings := make([]string, len(path))
|
||||
for i, pool := range path {
|
||||
pathStrings[i] = pool.Address // Address is already a string
|
||||
}
|
||||
|
||||
// Convert pools to string array
|
||||
poolStrings := make([]string, len(path))
|
||||
for i, pool := range path {
|
||||
poolStrings[i] = pool.Address // Address is already a string
|
||||
}
|
||||
|
||||
opportunity := &types.ArbitrageOpportunity{
|
||||
Path: pathStrings,
|
||||
Pools: poolStrings,
|
||||
AmountIn: inputAmount.Value,
|
||||
RequiredAmount: inputAmount.Value,
|
||||
Profit: grossProfit.Value,
|
||||
NetProfit: netProfit.Value,
|
||||
EstimatedProfit: grossProfit.Value,
|
||||
GasEstimate: totalGasCost.Value,
|
||||
ROI: func() float64 {
|
||||
// Convert percentage from 4-decimal format to actual percentage
|
||||
f, _ := profitPercentage.Value.Float64()
|
||||
return f / 10000.0 // Convert from 4-decimal format to actual percentage
|
||||
}(),
|
||||
Protocol: "multi", // Default protocol for multi-step arbitrage
|
||||
ExecutionTime: executionTime,
|
||||
Confidence: confidence,
|
||||
PriceImpact: func() float64 {
|
||||
// Convert percentage from 4-decimal format to actual percentage
|
||||
f, _ := totalPriceImpact.Value.Float64()
|
||||
return f / 10000.0 // Convert from 4-decimal format to actual percentage
|
||||
}(),
|
||||
MaxSlippage: 0.01, // Default 1% max slippage
|
||||
TokenIn: common.HexToAddress(inputToken.Address),
|
||||
TokenOut: common.HexToAddress(outputToken.Address),
|
||||
Timestamp: time.Now().Unix(),
|
||||
DetectedAt: time.Now(),
|
||||
ExpiresAt: time.Now().Add(5 * time.Minute),
|
||||
Risk: riskAssessment.OverallRisk,
|
||||
}
|
||||
opportunity.Quantities = &types.OpportunityQuantities{
|
||||
AmountIn: toDecimalAmount(inputAmount),
|
||||
AmountOut: toDecimalAmount(finalOutput),
|
||||
GrossProfit: toDecimalAmount(grossProfit),
|
||||
NetProfit: toDecimalAmount(netProfit),
|
||||
GasCost: toDecimalAmount(totalGasCost),
|
||||
ProfitPercent: toDecimalAmount(profitPercentage),
|
||||
PriceImpact: toDecimalAmount(totalPriceImpact),
|
||||
}
|
||||
|
||||
return opportunity, nil
|
||||
}
|
||||
|
||||
// calculateExecutionRoute calculates amounts through each step of the arbitrage
|
||||
func (calc *ArbitrageCalculator) calculateExecutionRoute(
|
||||
path []*PoolData,
|
||||
inputAmount *UniversalDecimal,
|
||||
inputToken TokenInfo,
|
||||
) ([]ExchangeStep, error) {
|
||||
|
||||
route := make([]ExchangeStep, len(path))
|
||||
currentAmount := inputAmount
|
||||
currentToken := inputToken
|
||||
|
||||
for i, pool := range path {
|
||||
// Determine output token for this step
|
||||
var outputToken TokenInfo
|
||||
if currentToken.Address == pool.Token0.Address {
|
||||
outputToken = TokenInfo{
|
||||
Address: pool.Token1.Address,
|
||||
Symbol: "TOKEN1", // In a real implementation, you'd fetch the actual symbol
|
||||
Decimals: 18,
|
||||
}
|
||||
} else if currentToken.Address == pool.Token1.Address {
|
||||
outputToken = TokenInfo{
|
||||
Address: pool.Token0.Address,
|
||||
Symbol: "TOKEN0", // In a real implementation, you'd fetch the actual symbol
|
||||
Decimals: 18,
|
||||
}
|
||||
} else {
|
||||
return nil, fmt.Errorf("token %s not found in pool %s", currentToken.Symbol, pool.Address)
|
||||
}
|
||||
|
||||
// For this simplified implementation, we'll calculate a mock amount out
|
||||
// In a real implementation, you'd use the pricer's CalculateAmountOut method
|
||||
amountOut := currentAmount // Simple 1:1 for this example
|
||||
priceImpact := &UniversalDecimal{Value: big.NewInt(0), Decimals: 4, Symbol: "PERCENT"} // No impact in mock
|
||||
|
||||
// Estimate gas for this step
|
||||
estimatedGas, err := calc.gasEstimator.EstimateSwapGas(ExchangeUniswapV3, pool) // Using a mock exchange type
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error estimating gas for pool %s: %w", pool.Address, err)
|
||||
}
|
||||
|
||||
// Create execution step
|
||||
route[i] = ExchangeStep{
|
||||
Exchange: ExchangeUniswapV3, // Using a mock exchange type
|
||||
Pool: pool,
|
||||
TokenIn: currentToken,
|
||||
TokenOut: outputToken,
|
||||
AmountIn: currentAmount,
|
||||
AmountOut: amountOut,
|
||||
PriceImpact: priceImpact,
|
||||
EstimatedGas: estimatedGas,
|
||||
}
|
||||
|
||||
// Update for next iteration
|
||||
currentAmount = amountOut
|
||||
currentToken = outputToken
|
||||
}
|
||||
|
||||
return route, nil
|
||||
}
|
||||
|
||||
// calculateTotalGasCost calculates the total gas cost for the entire route
|
||||
func (calc *ArbitrageCalculator) calculateTotalGasCost(route []ExchangeStep) (*UniversalDecimal, error) {
|
||||
// Get current gas price
|
||||
gasPrice, err := calc.gasEstimator.GetCurrentGasPrice()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error getting gas price: %w", err)
|
||||
}
|
||||
|
||||
// Sum up all gas estimates
|
||||
totalGas := uint64(0)
|
||||
for _, step := range route {
|
||||
totalGas += step.EstimatedGas
|
||||
}
|
||||
|
||||
// Add flash swap overhead if multi-step
|
||||
if len(route) > 1 {
|
||||
flashSwapGas, err := calc.gasEstimator.EstimateFlashSwapGas([]*PoolData{})
|
||||
if err == nil {
|
||||
totalGas += flashSwapGas
|
||||
}
|
||||
}
|
||||
|
||||
// Convert to gas cost in ETH
|
||||
totalGasInt64, err := security.SafeUint64ToInt64(totalGas)
|
||||
if err != nil {
|
||||
// This is very unlikely for gas calculations, but handle safely
|
||||
// Use maximum safe value as fallback
|
||||
totalGasInt64 = math.MaxInt64
|
||||
}
|
||||
totalGasBig := big.NewInt(totalGasInt64)
|
||||
totalGasDecimal, err := NewUniversalDecimal(totalGasBig, 0, "GAS")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return calc.decimalConverter.Multiply(totalGasDecimal, gasPrice, 18, "ETH")
|
||||
}
|
||||
|
||||
// calculateProfits calculates gross profit, net profit, and profit percentage
|
||||
func (calc *ArbitrageCalculator) calculateProfits(
|
||||
inputAmount, outputAmount, gasCost *UniversalDecimal,
|
||||
inputToken, outputToken TokenInfo,
|
||||
) (*UniversalDecimal, *UniversalDecimal, *UniversalDecimal, error) {
|
||||
|
||||
// Convert amounts to common denomination (ETH) for comparison
|
||||
inputETH := calc.convertToETH(inputAmount, inputToken)
|
||||
outputETH := calc.convertToETH(outputAmount, outputToken)
|
||||
|
||||
// Gross profit = output - input (in ETH terms)
|
||||
grossProfit, err := calc.decimalConverter.Subtract(outputETH, inputETH)
|
||||
if err != nil {
|
||||
return nil, nil, nil, fmt.Errorf("error calculating gross profit: %w", err)
|
||||
}
|
||||
|
||||
// Net profit = gross profit - gas cost
|
||||
netProfit, err := calc.decimalConverter.Subtract(grossProfit, gasCost)
|
||||
if err != nil {
|
||||
return nil, nil, nil, fmt.Errorf("error calculating net profit: %w", err)
|
||||
}
|
||||
|
||||
// Profit percentage = (net profit / input) * 100
|
||||
profitPercentage, err := calc.decimalConverter.CalculatePercentage(netProfit, inputETH)
|
||||
if err != nil {
|
||||
return nil, nil, nil, fmt.Errorf("error calculating profit percentage: %w", err)
|
||||
}
|
||||
|
||||
return grossProfit, netProfit, profitPercentage, nil
|
||||
}
|
||||
|
||||
// calculateTotalPriceImpact calculates cumulative price impact across all steps
|
||||
func (calc *ArbitrageCalculator) calculateTotalPriceImpact(route []ExchangeStep) (*UniversalDecimal, error) {
|
||||
if len(route) == 0 {
|
||||
return NewUniversalDecimal(big.NewInt(0), 4, "PERCENT")
|
||||
}
|
||||
|
||||
// Compound price impacts: (1 + impact1) * (1 + impact2) - 1
|
||||
compoundedImpact, err := calc.decimalConverter.FromString("1", 4, "COMPOUND")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, step := range route {
|
||||
// Convert price impact to factor (1 + impact)
|
||||
one, _ := calc.decimalConverter.FromString("1", 4, "ONE")
|
||||
impactFactor, err := calc.decimalConverter.Add(one, step.PriceImpact)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error calculating impact factor: %w", err)
|
||||
}
|
||||
|
||||
// Multiply with cumulative impact
|
||||
compoundedImpact, err = calc.decimalConverter.Multiply(compoundedImpact, impactFactor, 4, "COMPOUND")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error compounding impact: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Subtract 1 to get final impact percentage
|
||||
one, _ := calc.decimalConverter.FromString("1", 4, "ONE")
|
||||
totalImpact, err := calc.decimalConverter.Subtract(compoundedImpact, one)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error calculating total impact: %w", err)
|
||||
}
|
||||
|
||||
return totalImpact, nil
|
||||
}
|
||||
|
||||
// calculateMinimumOutput calculates minimum output accounting for slippage
|
||||
func (calc *ArbitrageCalculator) calculateMinimumOutput(expectedOutput *UniversalDecimal) (*UniversalDecimal, error) {
|
||||
// Apply slippage tolerance
|
||||
slippageFactor, err := calc.decimalConverter.Subtract(
|
||||
&UniversalDecimal{Value: big.NewInt(10000), Decimals: 4, Symbol: "ONE"},
|
||||
calc.maxSlippage,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return calc.decimalConverter.Multiply(expectedOutput, slippageFactor, 18, "TOKEN")
|
||||
}
|
||||
|
||||
// assessRisks performs comprehensive risk assessment
|
||||
func (calc *ArbitrageCalculator) assessRisks(route []ExchangeStep, priceImpact, netProfit *UniversalDecimal) RiskAssessment {
|
||||
assessment := RiskAssessment{
|
||||
Warnings: make([]string, 0),
|
||||
}
|
||||
|
||||
// Assess liquidity risk
|
||||
assessment.Liquidity = calc.assessLiquidityRisk(route)
|
||||
|
||||
// Assess price impact risk
|
||||
assessment.PriceImpact = calc.assessPriceImpactRisk(priceImpact)
|
||||
|
||||
// Assess profitability risk
|
||||
profitRisk := calc.assessProfitabilityRisk(netProfit)
|
||||
|
||||
// Assess gas price risk
|
||||
assessment.GasPrice = calc.assessGasPriceRisk()
|
||||
|
||||
// Calculate overall risk (worst of all categories)
|
||||
risks := []RiskLevel{assessment.Liquidity, assessment.PriceImpact, profitRisk, assessment.GasPrice}
|
||||
assessment.Overall = calc.calculateOverallRisk(risks)
|
||||
|
||||
// Calculate OverallRisk as a numeric value (0.0 to 1.0) based on the overall risk level
|
||||
switch assessment.Overall {
|
||||
case RiskLow:
|
||||
assessment.OverallRisk = 0.1
|
||||
case RiskMedium:
|
||||
assessment.OverallRisk = 0.4
|
||||
case RiskHigh:
|
||||
assessment.OverallRisk = 0.7
|
||||
case RiskCritical:
|
||||
assessment.OverallRisk = 0.95
|
||||
default:
|
||||
assessment.OverallRisk = 0.5 // Default to medium risk
|
||||
}
|
||||
|
||||
return assessment
|
||||
}
|
||||
|
||||
// Helper risk assessment methods
|
||||
func (calc *ArbitrageCalculator) assessLiquidityRisk(route []ExchangeStep) RiskLevel {
|
||||
for _, step := range route {
|
||||
// For this simplified implementation, assume a mock liquidity value
|
||||
// In a real implementation, you'd get this from the pricing engine
|
||||
mockLiquidity, _ := calc.decimalConverter.FromString("1000", 18, "TOKEN") // 1000 tokens
|
||||
if mockLiquidity.IsZero() {
|
||||
return RiskHigh
|
||||
}
|
||||
|
||||
// Check if trade size is significant portion of liquidity (>10%)
|
||||
tenPercent, _ := calc.decimalConverter.FromString("10", 4, "PERCENT")
|
||||
tradeSizePercent, _ := calc.decimalConverter.CalculatePercentage(step.AmountIn, mockLiquidity)
|
||||
|
||||
if comp, _ := calc.decimalConverter.Compare(tradeSizePercent, tenPercent); comp > 0 {
|
||||
return RiskMedium
|
||||
}
|
||||
}
|
||||
return RiskLow
|
||||
}
|
||||
|
||||
func (calc *ArbitrageCalculator) assessPriceImpactRisk(priceImpact *UniversalDecimal) RiskLevel {
|
||||
fivePercent, _ := calc.decimalConverter.FromString("5", 4, "PERCENT")
|
||||
twoPercent, _ := calc.decimalConverter.FromString("2", 4, "PERCENT")
|
||||
|
||||
if comp, _ := calc.decimalConverter.Compare(priceImpact, fivePercent); comp > 0 {
|
||||
return RiskHigh
|
||||
}
|
||||
if comp, _ := calc.decimalConverter.Compare(priceImpact, twoPercent); comp > 0 {
|
||||
return RiskMedium
|
||||
}
|
||||
return RiskLow
|
||||
}
|
||||
|
||||
func (calc *ArbitrageCalculator) assessProfitabilityRisk(netProfit *UniversalDecimal) RiskLevel {
|
||||
if netProfit.IsNegative() {
|
||||
return RiskCritical
|
||||
}
|
||||
|
||||
smallProfit, _ := calc.decimalConverter.FromString("0.001", 18, "ETH") // $1 at $1000/ETH
|
||||
mediumProfit, _ := calc.decimalConverter.FromString("0.01", 18, "ETH") // $10 at $1000/ETH
|
||||
|
||||
if comp, _ := calc.decimalConverter.Compare(netProfit, smallProfit); comp < 0 {
|
||||
return RiskHigh
|
||||
}
|
||||
if comp, _ := calc.decimalConverter.Compare(netProfit, mediumProfit); comp < 0 {
|
||||
return RiskMedium
|
||||
}
|
||||
return RiskLow
|
||||
}
|
||||
|
||||
func (calc *ArbitrageCalculator) assessGasPriceRisk() RiskLevel {
|
||||
currentGas, _ := calc.gasEstimator.GetCurrentGasPrice()
|
||||
|
||||
if comp, _ := calc.decimalConverter.Compare(currentGas, calc.maxGasPriceGwei); comp > 0 {
|
||||
return RiskHigh
|
||||
}
|
||||
|
||||
twentyGwei, _ := calc.decimalConverter.FromString("20", 9, "GWEI")
|
||||
if comp, _ := calc.decimalConverter.Compare(currentGas, twentyGwei); comp > 0 {
|
||||
return RiskMedium
|
||||
}
|
||||
|
||||
return RiskLow
|
||||
}
|
||||
|
||||
func (calc *ArbitrageCalculator) calculateOverallRisk(risks []RiskLevel) RiskLevel {
|
||||
riskScores := map[RiskLevel]int{
|
||||
RiskLow: 1,
|
||||
RiskMedium: 2,
|
||||
RiskHigh: 3,
|
||||
RiskCritical: 4,
|
||||
}
|
||||
|
||||
maxScore := 0
|
||||
for _, risk := range risks {
|
||||
if score := riskScores[risk]; score > maxScore {
|
||||
maxScore = score
|
||||
}
|
||||
}
|
||||
|
||||
for risk, score := range riskScores {
|
||||
if score == maxScore {
|
||||
return risk
|
||||
}
|
||||
}
|
||||
return RiskLow
|
||||
}
|
||||
|
||||
// calculateConfidence calculates confidence score based on risk and profit
|
||||
func (calc *ArbitrageCalculator) calculateConfidence(risk RiskAssessment, netProfit, priceImpact *UniversalDecimal) float64 {
|
||||
baseConfidence := 0.5
|
||||
|
||||
// Adjust for risk level
|
||||
switch risk.Overall {
|
||||
case RiskLow:
|
||||
baseConfidence += 0.3
|
||||
case RiskMedium:
|
||||
baseConfidence += 0.1
|
||||
case RiskHigh:
|
||||
baseConfidence -= 0.2
|
||||
case RiskCritical:
|
||||
baseConfidence -= 0.4
|
||||
}
|
||||
|
||||
// Adjust for profit magnitude
|
||||
if netProfit.IsPositive() {
|
||||
largeProfit, _ := calc.decimalConverter.FromString("0.1", 18, "ETH")
|
||||
if comp, _ := calc.decimalConverter.Compare(netProfit, largeProfit); comp > 0 {
|
||||
baseConfidence += 0.2
|
||||
}
|
||||
}
|
||||
|
||||
// Adjust for price impact
|
||||
lowImpact, _ := calc.decimalConverter.FromString("1", 4, "PERCENT")
|
||||
if comp, _ := calc.decimalConverter.Compare(priceImpact, lowImpact); comp < 0 {
|
||||
baseConfidence += 0.1
|
||||
}
|
||||
|
||||
if baseConfidence < 0 {
|
||||
baseConfidence = 0
|
||||
}
|
||||
if baseConfidence > 1 {
|
||||
baseConfidence = 1
|
||||
}
|
||||
|
||||
return baseConfidence
|
||||
}
|
||||
|
||||
// estimateExecutionTime estimates execution time in milliseconds
|
||||
func (calc *ArbitrageCalculator) estimateExecutionTime(route []ExchangeStep) int64 {
|
||||
baseTime := int64(500) // 500ms base
|
||||
|
||||
// Add time per hop
|
||||
hopTime := int64(len(route)) * 200
|
||||
|
||||
// Add time for complex exchanges
|
||||
complexTime := int64(0)
|
||||
for _, step := range route {
|
||||
switch ExchangeType(step.Exchange) {
|
||||
case ExchangeUniswapV3, ExchangeCamelot:
|
||||
complexTime += 300 // Concentrated liquidity is more complex
|
||||
case ExchangeBalancer, ExchangeCurve:
|
||||
complexTime += 400 // Weighted/stable pools are complex
|
||||
default:
|
||||
complexTime += 100 // Simple AMM
|
||||
}
|
||||
}
|
||||
|
||||
return baseTime + hopTime + complexTime
|
||||
}
|
||||
|
||||
// convertToETH converts any token amount to ETH for comparison (placeholder)
|
||||
func (calc *ArbitrageCalculator) convertToETH(amount *UniversalDecimal, token TokenInfo) *UniversalDecimal {
|
||||
// This is a placeholder - in production, this would query price oracles
|
||||
// For now, assume 1:1 conversion for demonstration
|
||||
ethAmount, _ := calc.decimalConverter.ConvertTo(amount, 18, "ETH")
|
||||
return ethAmount
|
||||
}
|
||||
|
||||
// IsOpportunityProfitable checks if opportunity meets minimum criteria
|
||||
// IsOpportunityProfitable checks if an opportunity meets profitability criteria
|
||||
func (calc *ArbitrageCalculator) IsOpportunityProfitable(opportunity *types.ArbitrageOpportunity) bool {
|
||||
if opportunity == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
// Check minimum profit threshold
|
||||
if !calc.checkProfitThreshold(opportunity) {
|
||||
return false
|
||||
}
|
||||
|
||||
// Check maximum price impact
|
||||
if !calc.checkPriceImpactThreshold(opportunity) {
|
||||
return false
|
||||
}
|
||||
|
||||
// Check risk level
|
||||
if !calc.checkRiskLevel(opportunity) {
|
||||
return false
|
||||
}
|
||||
|
||||
// Check confidence threshold
|
||||
if !calc.checkConfidenceThreshold(opportunity) {
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// checkProfitThreshold checks if the opportunity meets minimum profit requirements
|
||||
func (calc *ArbitrageCalculator) checkProfitThreshold(opportunity *types.ArbitrageOpportunity) bool {
|
||||
if opportunity.Quantities != nil {
|
||||
if netProfitUD, err := calc.decimalAmountToUniversal(opportunity.Quantities.NetProfit); err == nil {
|
||||
if cmp, err := calc.decimalConverter.Compare(netProfitUD, calc.minProfitThreshold); err == nil && cmp < 0 {
|
||||
return false
|
||||
}
|
||||
}
|
||||
} else if opportunity.NetProfit != nil {
|
||||
if opportunity.NetProfit.Cmp(calc.minProfitThreshold.Value) < 0 {
|
||||
return false
|
||||
}
|
||||
} else {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// checkPriceImpactThreshold checks if the opportunity is below maximum price impact
|
||||
func (calc *ArbitrageCalculator) checkPriceImpactThreshold(opportunity *types.ArbitrageOpportunity) bool {
|
||||
if opportunity.Quantities != nil {
|
||||
if impactUD, err := calc.decimalAmountToUniversal(opportunity.Quantities.PriceImpact); err == nil {
|
||||
if cmp, err := calc.decimalConverter.Compare(impactUD, calc.maxPriceImpact); err == nil && cmp > 0 {
|
||||
return false
|
||||
}
|
||||
}
|
||||
} else {
|
||||
maxImpactFloat := float64(calc.maxPriceImpact.Value.Int64()) / math.Pow10(int(calc.maxPriceImpact.Decimals))
|
||||
if opportunity.PriceImpact > maxImpactFloat {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// checkRiskLevel checks if the opportunity's risk is acceptable
|
||||
func (calc *ArbitrageCalculator) checkRiskLevel(opportunity *types.ArbitrageOpportunity) bool {
|
||||
return opportunity.Risk < 0.8 // High risk threshold
|
||||
}
|
||||
|
||||
// checkConfidenceThreshold checks if the opportunity has sufficient confidence
|
||||
func (calc *ArbitrageCalculator) checkConfidenceThreshold(opportunity *types.ArbitrageOpportunity) bool {
|
||||
return opportunity.Confidence >= 0.3
|
||||
}
|
||||
|
||||
// SortOpportunitiesByProfitability sorts opportunities by net profit descending
|
||||
func (calc *ArbitrageCalculator) SortOpportunitiesByProfitability(opportunities []*types.ArbitrageOpportunity) {
|
||||
sort.Slice(opportunities, func(i, j int) bool {
|
||||
left, errL := calc.decimalAmountToUniversal(opportunities[i].Quantities.NetProfit)
|
||||
right, errR := calc.decimalAmountToUniversal(opportunities[j].Quantities.NetProfit)
|
||||
if errL == nil && errR == nil {
|
||||
cmp, err := calc.decimalConverter.Compare(left, right)
|
||||
if err == nil {
|
||||
return cmp > 0
|
||||
}
|
||||
}
|
||||
// Fallback to canonical big.Int comparison
|
||||
return opportunities[i].NetProfit.Cmp(opportunities[j].NetProfit) > 0 // Descending order
|
||||
})
|
||||
}
|
||||
|
||||
func (calc *ArbitrageCalculator) decimalAmountToUniversal(dec types.DecimalAmount) (*UniversalDecimal, error) {
|
||||
if dec.Value == "" {
|
||||
return nil, fmt.Errorf("decimal amount empty")
|
||||
}
|
||||
val, ok := new(big.Int).SetString(dec.Value, 10)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid decimal amount %s", dec.Value)
|
||||
}
|
||||
return NewUniversalDecimal(val, dec.Decimals, dec.Symbol)
|
||||
}
|
||||
|
||||
// CalculateArbitrage calculates arbitrage opportunity for a given path and input amount
|
||||
func (calc *ArbitrageCalculator) CalculateArbitrage(ctx context.Context, inputAmount *UniversalDecimal, path []*PoolData) (*types.ArbitrageOpportunity, error) {
|
||||
if len(path) == 0 {
|
||||
return nil, fmt.Errorf("empty path provided")
|
||||
}
|
||||
|
||||
// Get the input and output tokens for the path
|
||||
inputToken := path[0].Token0
|
||||
outputToken := path[len(path)-1].Token1
|
||||
if path[len(path)-1].Token0.Address == inputToken.Address {
|
||||
outputToken = path[len(path)-1].Token0
|
||||
}
|
||||
|
||||
// Calculate the arbitrage opportunity for this path
|
||||
opportunity, err := calc.CalculateArbitrageOpportunity(path, inputAmount, inputToken, outputToken)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to calculate arbitrage opportunity: %w", err)
|
||||
}
|
||||
|
||||
return opportunity, nil
|
||||
}
|
||||
|
||||
// FindOptimalPath finds the most profitable arbitrage path between two tokens
|
||||
func (calc *ArbitrageCalculator) FindOptimalPath(ctx context.Context, tokenA, tokenB common.Address, amount *UniversalDecimal) (*types.ArbitrageOpportunity, error) {
|
||||
// In a real implementation, this would query for available paths between tokens
|
||||
// and calculate the most profitable path. For this implementation, we'll return an error
|
||||
// indicating no path is available since we don't have direct path-finding ability in the calculator
|
||||
return nil, fmt.Errorf("FindOptimalPath not implemented in calculator - use executor.CalculateOptimalPath instead")
|
||||
}
|
||||
|
||||
// FilterProfitableOpportunities returns only profitable opportunities
|
||||
func (calc *ArbitrageCalculator) FilterProfitableOpportunities(opportunities []*types.ArbitrageOpportunity) []*types.ArbitrageOpportunity {
|
||||
profitable := make([]*types.ArbitrageOpportunity, 0)
|
||||
|
||||
for _, opp := range opportunities {
|
||||
if calc.IsOpportunityProfitable(opp) {
|
||||
profitable = append(profitable, opp)
|
||||
}
|
||||
}
|
||||
|
||||
return profitable
|
||||
}
|
||||
175
orig/pkg/math/arbitrage_calculator_test.go
Normal file
175
orig/pkg/math/arbitrage_calculator_test.go
Normal file
@@ -0,0 +1,175 @@
|
||||
package math
|
||||
|
||||
import (
|
||||
"math/big"
|
||||
"testing"
|
||||
|
||||
"github.com/fraktal/mev-beta/pkg/types"
|
||||
)
|
||||
|
||||
type stubGasEstimator struct {
|
||||
price *UniversalDecimal
|
||||
}
|
||||
|
||||
func (s stubGasEstimator) EstimateSwapGas(exchange ExchangeType, poolData *PoolData) (uint64, error) {
|
||||
return 100_000, nil
|
||||
}
|
||||
|
||||
func (s stubGasEstimator) EstimateFlashSwapGas(route []*PoolData) (uint64, error) {
|
||||
return 50_000, nil
|
||||
}
|
||||
|
||||
func (s stubGasEstimator) GetCurrentGasPrice() (*UniversalDecimal, error) {
|
||||
return s.price, nil
|
||||
}
|
||||
|
||||
func TestIsOpportunityProfitableRespectsThreshold(t *testing.T) {
|
||||
estimator := stubGasEstimator{price: func() *UniversalDecimal {
|
||||
ud, _ := NewUniversalDecimal(big.NewInt(1_000_000_000), 9, "GWEI")
|
||||
return ud
|
||||
}()}
|
||||
calc := NewArbitrageCalculator(estimator)
|
||||
|
||||
belowThreshold, _ := NewUniversalDecimal(big.NewInt(9_000_000_000_000_000), 18, "ETH")
|
||||
priceImpact, _ := NewUniversalDecimal(big.NewInt(100), 4, "PERCENT")
|
||||
|
||||
opportunity := &types.ArbitrageOpportunity{
|
||||
NetProfit: belowThreshold.Value,
|
||||
PriceImpact: 0.01,
|
||||
Confidence: 0.5,
|
||||
Quantities: &types.OpportunityQuantities{
|
||||
NetProfit: toDecimalAmount(belowThreshold),
|
||||
PriceImpact: toDecimalAmount(priceImpact),
|
||||
AmountIn: toDecimalAmount(belowThreshold),
|
||||
AmountOut: toDecimalAmount(belowThreshold),
|
||||
GrossProfit: toDecimalAmount(belowThreshold),
|
||||
GasCost: toDecimalAmount(belowThreshold),
|
||||
ProfitPercent: toDecimalAmount(priceImpact),
|
||||
},
|
||||
}
|
||||
|
||||
if calc.IsOpportunityProfitable(opportunity) {
|
||||
t.Fatalf("expected below-threshold opportunity to be rejected")
|
||||
}
|
||||
|
||||
aboveThreshold, _ := NewUniversalDecimal(big.NewInt(2_000_000_000_000_0000), 18, "ETH")
|
||||
opportunity.NetProfit = aboveThreshold.Value
|
||||
opportunity.Quantities.NetProfit = toDecimalAmount(aboveThreshold)
|
||||
if !calc.IsOpportunityProfitable(opportunity) {
|
||||
t.Fatalf("expected opportunity above threshold to be accepted")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSortOpportunitiesByProfitabilityUsesDecimals(t *testing.T) {
|
||||
estimator := stubGasEstimator{price: func() *UniversalDecimal {
|
||||
ud, _ := NewUniversalDecimal(big.NewInt(1_000_000_000), 9, "GWEI")
|
||||
return ud
|
||||
}()}
|
||||
calc := NewArbitrageCalculator(estimator)
|
||||
|
||||
a, _ := NewUniversalDecimal(big.NewInt(1_500_000_000_000_0000), 18, "ETH")
|
||||
b, _ := NewUniversalDecimal(big.NewInt(5_000_000_000_000_000), 18, "ETH")
|
||||
|
||||
oppA := &types.ArbitrageOpportunity{
|
||||
NetProfit: a.Value,
|
||||
Quantities: &types.OpportunityQuantities{
|
||||
NetProfit: toDecimalAmount(a),
|
||||
},
|
||||
}
|
||||
oppB := &types.ArbitrageOpportunity{
|
||||
NetProfit: b.Value,
|
||||
Quantities: &types.OpportunityQuantities{
|
||||
NetProfit: toDecimalAmount(b),
|
||||
},
|
||||
}
|
||||
|
||||
opps := []*types.ArbitrageOpportunity{oppB, oppA}
|
||||
calc.SortOpportunitiesByProfitability(opps)
|
||||
|
||||
if opps[0] != oppA {
|
||||
t.Fatalf("expected higher decimal profit opportunity first")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCalculateArbitrageOpportunitySetsQuantities(t *testing.T) {
|
||||
estimator := stubGasEstimator{price: func() *UniversalDecimal {
|
||||
ud, _ := NewUniversalDecimal(big.NewInt(1_000_000_000), 9, "GWEI")
|
||||
return ud
|
||||
}()}
|
||||
calc := NewArbitrageCalculator(estimator)
|
||||
|
||||
pool := &PoolData{
|
||||
Address: "0xpool",
|
||||
ExchangeType: ExchangeUniswapV2,
|
||||
Token0: TokenInfo{Address: "0x0", Symbol: "TOKEN0", Decimals: 18},
|
||||
Token1: TokenInfo{Address: "0x1", Symbol: "TOKEN1", Decimals: 18},
|
||||
}
|
||||
|
||||
amountIn, _ := NewUniversalDecimal(big.NewInt(1_000_000_000_000_000), 18, "TOKEN0")
|
||||
|
||||
opportunity, err := calc.CalculateArbitrageOpportunity([]*PoolData{pool}, amountIn, pool.Token0, pool.Token1)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if opportunity.Quantities == nil {
|
||||
t.Fatalf("expected quantities to be populated")
|
||||
}
|
||||
if opportunity.Quantities.NetProfit.Value == "" {
|
||||
t.Fatalf("expected net profit decimal to have value")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCalculateMinimumOutputAppliesSlippage(t *testing.T) {
|
||||
estimator := stubGasEstimator{price: func() *UniversalDecimal {
|
||||
ud, _ := NewUniversalDecimal(big.NewInt(1_000_000_000), 9, "GWEI")
|
||||
return ud
|
||||
}()}
|
||||
calc := NewArbitrageCalculator(estimator)
|
||||
|
||||
expected, _ := NewUniversalDecimal(big.NewInt(1_000_000_000_000_000_000), 18, "ETH")
|
||||
minOutput, err := calc.calculateMinimumOutput(expected)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
// Default max slippage is 1% -> expect 0.99 ETH
|
||||
expectedMin, _ := NewUniversalDecimal(big.NewInt(990000000000000000), 18, "ETH")
|
||||
cmp, err := calc.decimalConverter.Compare(minOutput, expectedMin)
|
||||
if err != nil || cmp != 0 {
|
||||
t.Fatalf("expected min output 0.99 ETH, got %s", calc.decimalConverter.ToHumanReadable(minOutput))
|
||||
}
|
||||
}
|
||||
|
||||
func TestCalculateProfitsCapturesSpread(t *testing.T) {
|
||||
estimator := stubGasEstimator{price: func() *UniversalDecimal {
|
||||
ud, _ := NewUniversalDecimal(big.NewInt(1_000_000_000), 9, "GWEI")
|
||||
return ud
|
||||
}()}
|
||||
calc := NewArbitrageCalculator(estimator)
|
||||
|
||||
amountIn, _ := NewUniversalDecimal(big.NewInt(10_000_000_000_000_000), 18, "ETH") // 0.01
|
||||
amountOut, _ := NewUniversalDecimal(big.NewInt(12_000_000_000_000_000), 18, "ETH")
|
||||
gasCost, _ := NewUniversalDecimal(big.NewInt(500_000_000_000_000), 18, "ETH")
|
||||
|
||||
gross, net, pct, err := calc.calculateProfits(amountIn, amountOut, gasCost, TokenInfo{Symbol: "ETH", Decimals: 18}, TokenInfo{Symbol: "ETH", Decimals: 18})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
expectedGross, _ := NewUniversalDecimal(big.NewInt(2_000_000_000_000_000), 18, "ETH")
|
||||
cmp, err := calc.decimalConverter.Compare(gross, expectedGross)
|
||||
if err != nil || cmp != 0 {
|
||||
t.Fatalf("expected gross profit 0.002 ETH, got %s", calc.decimalConverter.ToHumanReadable(gross))
|
||||
}
|
||||
|
||||
expectedNet, _ := NewUniversalDecimal(big.NewInt(1_500_000_000_000_000), 18, "ETH")
|
||||
cmp, err = calc.decimalConverter.Compare(net, expectedNet)
|
||||
if err != nil || cmp != 0 {
|
||||
t.Fatalf("expected net profit 0.0015 ETH, got %s", calc.decimalConverter.ToHumanReadable(net))
|
||||
}
|
||||
|
||||
if pct == nil || pct.Value.Sign() <= 0 {
|
||||
t.Fatalf("expected positive profit percentage")
|
||||
}
|
||||
}
|
||||
125
orig/pkg/math/benchmark_test.go
Normal file
125
orig/pkg/math/benchmark_test.go
Normal file
@@ -0,0 +1,125 @@
|
||||
package math
|
||||
|
||||
import (
|
||||
"math/big"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// BenchmarkAllProtocols runs performance tests for all supported protocols
|
||||
func BenchmarkAllProtocols(b *testing.B) {
|
||||
// Create test values for all protocols
|
||||
reserveIn, _ := new(big.Int).SetString("1000000000000000000", 10) // 1 token
|
||||
reserveOut, _ := new(big.Int).SetString("1000000000000000000", 10) // 1 token
|
||||
amountIn, _ := new(big.Int).SetString("100000000000000000", 10) // 0.1 token
|
||||
sqrtPriceX96, _ := new(big.Int).SetString("79228162514264337593543950336", 10) // 2^96
|
||||
liquidity, _ := new(big.Int).SetString("1000000000000000000", 10) // 1 ETH worth of liquidity
|
||||
|
||||
calculator := NewMathCalculator()
|
||||
|
||||
b.Run("UniswapV2", func(b *testing.B) {
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, _ = calculator.uniswapV2.CalculateAmountOut(amountIn, reserveIn, reserveOut, 3000)
|
||||
}
|
||||
})
|
||||
|
||||
b.Run("UniswapV3", func(b *testing.B) {
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, _ = calculator.uniswapV3.CalculateAmountOut(amountIn, sqrtPriceX96, liquidity, 3000)
|
||||
}
|
||||
})
|
||||
|
||||
b.Run("Curve", func(b *testing.B) {
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, _ = calculator.curve.CalculateAmountOut(amountIn, reserveIn, reserveOut, 400)
|
||||
}
|
||||
})
|
||||
|
||||
b.Run("Kyber", func(b *testing.B) {
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, _ = calculator.kyber.CalculateAmountOut(amountIn, sqrtPriceX96, liquidity, 1000)
|
||||
}
|
||||
})
|
||||
|
||||
b.Run("Balancer", func(b *testing.B) {
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, _ = calculator.balancer.CalculateAmountOut(amountIn, reserveIn, reserveOut, 1000)
|
||||
}
|
||||
})
|
||||
|
||||
b.Run("ConstantSum", func(b *testing.B) {
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, _ = calculator.constantSum.CalculateAmountOut(amountIn, reserveIn, reserveOut, 3000)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// BenchmarkPriceMovementDetection runs performance tests for price movement detection
|
||||
func BenchmarkPriceMovementDetection(b *testing.B) {
|
||||
reserveIn, _ := new(big.Int).SetString("1000000000000000000", 10)
|
||||
reserveOut, _ := new(big.Int).SetString("2000000000000000000000", 10)
|
||||
amountIn, _ := new(big.Int).SetString("100000000000000000", 10)
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, _, _ = WillSwapMovePrice(amountIn, reserveIn, reserveOut, 0.01)
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkPriceImpactCalculations runs performance tests for price impact calculations
|
||||
func BenchmarkPriceImpactCalculations(b *testing.B) {
|
||||
calculator := NewPriceImpactCalculator()
|
||||
|
||||
reserveIn, _ := new(big.Int).SetString("1000000000000000000", 10)
|
||||
reserveOut, _ := new(big.Int).SetString("2000000000000000000000", 10)
|
||||
amountIn, _ := new(big.Int).SetString("100000000000000000", 10)
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, _ = calculator.CalculatePriceImpact("uniswap_v2", amountIn, reserveIn, reserveOut, nil, nil)
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkOptimizedUniswapV2 calculates amount out using optimized approach
|
||||
func BenchmarkOptimizedUniswapV2(b *testing.B) {
|
||||
// Pre-allocated values to reduce allocations
|
||||
reserveIn, _ := new(big.Int).SetString("1000000000000000000", 10)
|
||||
reserveOut, _ := new(big.Int).SetString("2000000000000000000000", 10)
|
||||
amountIn, _ := new(big.Int).SetString("100000000000000000", 10)
|
||||
math := NewUniswapV2Math()
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, _ = math.CalculateAmountOut(amountIn, reserveIn, reserveOut, 3000)
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkOptimizedPriceMovementDetection runs performance tests for optimized price movement detection
|
||||
func BenchmarkOptimizedPriceMovementDetection(b *testing.B) {
|
||||
reserveIn, _ := new(big.Int).SetString("1000000000000000000", 10)
|
||||
reserveOut, _ := new(big.Int).SetString("2000000000000000000000", 10)
|
||||
amountIn, _ := new(big.Int).SetString("100000000000000000", 10)
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
// Simplified check without full calculation for performance comparison
|
||||
// This just does the basic arithmetic to compare with the full function
|
||||
priceBefore := new(big.Float).Quo(new(big.Float).SetInt(reserveOut), new(big.Float).SetInt(reserveIn))
|
||||
amountOut, err := NewUniswapV2Math().CalculateAmountOut(amountIn, reserveIn, reserveOut, 3000)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
newReserveIn := new(big.Int).Add(reserveIn, amountIn)
|
||||
newReserveOut := new(big.Int).Sub(reserveOut, amountOut)
|
||||
priceAfter := new(big.Float).Quo(new(big.Float).SetInt(newReserveOut), new(big.Float).SetInt(newReserveIn))
|
||||
impact := new(big.Float).Sub(priceBefore, priceAfter)
|
||||
impact.Quo(impact, priceBefore)
|
||||
impactFloat, _ := impact.Float64()
|
||||
_ = impactFloat >= 0.01
|
||||
}
|
||||
}
|
||||
96
orig/pkg/math/cached_bench_test.go
Normal file
96
orig/pkg/math/cached_bench_test.go
Normal file
@@ -0,0 +1,96 @@
|
||||
package math
|
||||
|
||||
import (
|
||||
"math/big"
|
||||
"testing"
|
||||
|
||||
"github.com/holiman/uint256"
|
||||
)
|
||||
|
||||
// Benchmark original vs cached SqrtPriceX96ToPrice conversion
|
||||
func BenchmarkSqrtPriceX96ToPriceOriginal(b *testing.B) {
|
||||
// Use a typical sqrtPriceX96 value (represents price of ~2000 USDC/ETH)
|
||||
sqrtPriceX96 := new(big.Int).SetBytes([]byte{0x06, 0x40, 0x84, 0x4A, 0x0E, 0x81, 0x4F, 0x96, 0x19, 0xC1, 0x9C, 0x08})
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
// Original calculation: price = sqrtPriceX96^2 / 2^192
|
||||
sqrtPriceFloat := new(big.Float).SetInt(sqrtPriceX96)
|
||||
price := new(big.Float).Mul(sqrtPriceFloat, sqrtPriceFloat)
|
||||
|
||||
// Calculate 2^192
|
||||
q192 := new(big.Int).Exp(big.NewInt(2), big.NewInt(192), nil)
|
||||
q192Float := new(big.Float).SetInt(q192)
|
||||
|
||||
// Divide by 2^192
|
||||
price.Quo(price, q192Float)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkSqrtPriceX96ToPriceCached(b *testing.B) {
|
||||
// Use a typical sqrtPriceX96 value (represents price of ~2000 USDC/ETH)
|
||||
sqrtPriceX96 := new(big.Int).SetBytes([]byte{0x06, 0x40, 0x84, 0x4A, 0x0E, 0x81, 0x4F, 0x96, 0x19, 0xC1, 0x9C, 0x08})
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
// Cached calculation using precomputed constants
|
||||
SqrtPriceX96ToPriceCached(sqrtPriceX96)
|
||||
}
|
||||
}
|
||||
|
||||
// Benchmark original vs cached PriceToSqrtPriceX96 conversion
|
||||
func BenchmarkPriceToSqrtPriceX96Original(b *testing.B) {
|
||||
// Use a typical price value (represents price of ~2000 USDC/ETH)
|
||||
price := new(big.Float).SetFloat64(2000.0)
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
// Original calculation: sqrtPriceX96 = sqrt(price * 2^192)
|
||||
|
||||
// Calculate 2^192
|
||||
q192 := new(big.Int).Exp(big.NewInt(2), big.NewInt(192), nil)
|
||||
q192Float := new(big.Float).SetInt(q192)
|
||||
|
||||
// Multiply price by 2^192
|
||||
result := new(big.Float).Mul(price, q192Float)
|
||||
|
||||
// Calculate square root
|
||||
result.Sqrt(result)
|
||||
|
||||
// Convert to big.Int
|
||||
sqrtPriceX96 := new(big.Int)
|
||||
result.Int(sqrtPriceX96)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkPriceToSqrtPriceX96Cached(b *testing.B) {
|
||||
// Use a typical price value (represents price of ~2000 USDC/ETH)
|
||||
price := new(big.Float).SetFloat64(2000.0)
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
// Cached calculation using precomputed constants
|
||||
PriceToSqrtPriceX96Cached(price)
|
||||
}
|
||||
}
|
||||
|
||||
// Benchmark optimized versions with uint256
|
||||
func BenchmarkSqrtPriceX96ToPriceOptimized(b *testing.B) {
|
||||
// Use a typical sqrtPriceX96 value
|
||||
sqrtPriceX96 := uint256.NewInt(0).SetBytes([]byte{0x06, 0x40, 0x84, 0x4A, 0x0E, 0x81, 0x4F, 0x96, 0x19, 0xC1, 0x9C, 0x08})
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
SqrtPriceX96ToPriceOptimized(sqrtPriceX96)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkPriceToSqrtPriceX96Optimized(b *testing.B) {
|
||||
// Use a typical price value
|
||||
price := new(big.Float).SetFloat64(2000.0)
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
PriceToSqrtPriceX96Optimized(price)
|
||||
}
|
||||
}
|
||||
126
orig/pkg/math/cached_functions.go
Normal file
126
orig/pkg/math/cached_functions.go
Normal file
@@ -0,0 +1,126 @@
|
||||
package math
|
||||
|
||||
import (
|
||||
"math/big"
|
||||
"sync"
|
||||
|
||||
"github.com/holiman/uint256"
|
||||
|
||||
"github.com/fraktal/mev-beta/pkg/uniswap"
|
||||
)
|
||||
|
||||
// Cached mathematical constants to avoid recomputation
|
||||
var (
|
||||
cachedConstantsOnce sync.Once
|
||||
cachedQ192 *big.Int
|
||||
cachedQ96 *big.Int
|
||||
cachedQ384 *big.Int
|
||||
cachedTwoPower96 *big.Float
|
||||
cachedTwoPower192 *big.Float
|
||||
cachedTwoPower384 *big.Float
|
||||
)
|
||||
|
||||
// initCachedConstants initializes all cached constants once
|
||||
func initCachedConstants() {
|
||||
cachedConstantsOnce.Do(func() {
|
||||
// Calculate 2^96
|
||||
cachedQ96 = new(big.Int).Exp(big.NewInt(2), big.NewInt(96), nil)
|
||||
|
||||
// Calculate 2^192
|
||||
cachedQ192 = new(big.Int).Exp(big.NewInt(2), big.NewInt(192), nil)
|
||||
|
||||
// Calculate 2^384
|
||||
cachedQ384 = new(big.Int).Exp(big.NewInt(2), big.NewInt(384), nil)
|
||||
|
||||
// Convert to big.Float for division operations
|
||||
cachedTwoPower96 = new(big.Float).SetInt(cachedQ96)
|
||||
cachedTwoPower192 = new(big.Float).SetInt(cachedQ192)
|
||||
cachedTwoPower384 = new(big.Float).SetInt(cachedQ384)
|
||||
})
|
||||
}
|
||||
|
||||
// GetCachedQ192 returns the cached value of 2^192
|
||||
func GetCachedQ192() *big.Int {
|
||||
initCachedConstants()
|
||||
return cachedQ192
|
||||
}
|
||||
|
||||
// GetCachedQ96 returns the cached value of 2^96
|
||||
func GetCachedQ96() *big.Int {
|
||||
initCachedConstants()
|
||||
return cachedQ96
|
||||
}
|
||||
|
||||
// GetCachedQ384 returns the cached value of 2^384
|
||||
func GetCachedQ384() *big.Int {
|
||||
initCachedConstants()
|
||||
return cachedQ384
|
||||
}
|
||||
|
||||
// SqrtPriceX96ToPriceCached converts sqrtPriceX96 to a price using cached constants
|
||||
// Formula: price = sqrtPriceX96^2 / 2^192
|
||||
func SqrtPriceX96ToPriceCached(sqrtPriceX96 *big.Int) *big.Float {
|
||||
initCachedConstants()
|
||||
|
||||
// Convert to big.Float for precision
|
||||
sqrtPriceFloat := new(big.Float).SetInt(sqrtPriceX96)
|
||||
|
||||
// Calculate sqrtPrice^2
|
||||
price := new(big.Float).Mul(sqrtPriceFloat, sqrtPriceFloat)
|
||||
|
||||
// Divide by 2^192 using cached constant
|
||||
price.Quo(price, cachedTwoPower192)
|
||||
|
||||
return price
|
||||
}
|
||||
|
||||
// PriceToSqrtPriceX96Cached converts a price to sqrtPriceX96 using cached constants
|
||||
// Formula: sqrtPriceX96 = sqrt(price * 2^192)
|
||||
func PriceToSqrtPriceX96Cached(price *big.Float) *big.Int {
|
||||
initCachedConstants()
|
||||
|
||||
// Multiply price by 2^192
|
||||
result := new(big.Float).Mul(price, cachedTwoPower192)
|
||||
|
||||
// Calculate square root
|
||||
result.Sqrt(result)
|
||||
|
||||
// Convert to big.Int
|
||||
sqrtPriceX96 := new(big.Int)
|
||||
result.Int(sqrtPriceX96)
|
||||
|
||||
return sqrtPriceX96
|
||||
}
|
||||
|
||||
// SqrtPriceX96ToPriceOptimized converts sqrtPriceX96 to a price using optimized uint256 operations
|
||||
// Formula: price = sqrtPriceX96^2 / 2^192
|
||||
func SqrtPriceX96ToPriceOptimized(sqrtPriceX96 *uint256.Int) *big.Float {
|
||||
initCachedConstants()
|
||||
|
||||
// Convert to big.Int for calculation
|
||||
sqrtPriceBig := sqrtPriceX96.ToBig()
|
||||
|
||||
// Use cached function for consistency
|
||||
return SqrtPriceX96ToPriceCached(sqrtPriceBig)
|
||||
}
|
||||
|
||||
// PriceToSqrtPriceX96Optimized converts a price to sqrtPriceX96 using optimized operations
|
||||
// Formula: sqrtPriceX96 = sqrt(price * 2^192)
|
||||
func PriceToSqrtPriceX96Optimized(price *big.Float) *uint256.Int {
|
||||
initCachedConstants()
|
||||
|
||||
// Use cached function for consistency
|
||||
sqrtPriceBig := PriceToSqrtPriceX96Cached(price)
|
||||
|
||||
// Convert to uint256
|
||||
return uint256.MustFromBig(sqrtPriceBig)
|
||||
}
|
||||
|
||||
// TickToSqrtPriceX96Optimized calculates sqrtPriceX96 from a tick using optimized operations
|
||||
// Formula: sqrtPriceX96 = 1.0001^(tick/2)
|
||||
func TickToSqrtPriceX96Optimized(tick int) *uint256.Int {
|
||||
// For simplicity, we'll convert to big.Int and use existing implementation
|
||||
tickBig := big.NewInt(int64(tick))
|
||||
sqrtPriceBig := uniswap.TickToSqrtPriceX96(int(tickBig.Int64()))
|
||||
return uint256.MustFromBig(sqrtPriceBig)
|
||||
}
|
||||
127
orig/pkg/math/cached_test.go
Normal file
127
orig/pkg/math/cached_test.go
Normal file
@@ -0,0 +1,127 @@
|
||||
package math
|
||||
|
||||
import (
|
||||
"math/big"
|
||||
"testing"
|
||||
|
||||
"github.com/holiman/uint256"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
// Test that cached functions produce the same results as original implementations
|
||||
func TestCachedFunctionAccuracy(t *testing.T) {
|
||||
// Test SqrtPriceX96ToPrice functions
|
||||
t.Run("SqrtPriceX96ToPrice", func(t *testing.T) {
|
||||
// Use a typical sqrtPriceX96 value (represents price of ~2000 USDC/ETH)
|
||||
sqrtPriceX96 := new(big.Int).SetBytes([]byte{0x06, 0x40, 0x84, 0x4A, 0x0E, 0x81, 0x4F, 0x96, 0x19, 0xC1, 0x9C, 0x08})
|
||||
|
||||
// Original calculation
|
||||
sqrtPriceFloat := new(big.Float).SetInt(sqrtPriceX96)
|
||||
originalPrice := new(big.Float).Mul(sqrtPriceFloat, sqrtPriceFloat)
|
||||
q192 := new(big.Int).Exp(big.NewInt(2), big.NewInt(192), nil)
|
||||
q192Float := new(big.Float).SetInt(q192)
|
||||
originalPrice.Quo(originalPrice, q192Float)
|
||||
|
||||
// Cached calculation
|
||||
cachedPrice := SqrtPriceX96ToPriceCached(sqrtPriceX96)
|
||||
|
||||
// Compare results (should be identical)
|
||||
assert.Equal(t, originalPrice.String(), cachedPrice.String(), "Cached and original SqrtPriceX96ToPrice should produce identical results")
|
||||
})
|
||||
|
||||
// Test PriceToSqrtPriceX96 functions
|
||||
t.Run("PriceToSqrtPriceX96", func(t *testing.T) {
|
||||
// Use a typical price value (represents price of ~2000 USDC/ETH)
|
||||
price := new(big.Float).SetFloat64(2000.0)
|
||||
|
||||
// Original calculation
|
||||
q192 := new(big.Int).Exp(big.NewInt(2), big.NewInt(192), nil)
|
||||
q192Float := new(big.Float).SetInt(q192)
|
||||
result := new(big.Float).Mul(price, q192Float)
|
||||
result.Sqrt(result)
|
||||
expectedSqrtPriceX96 := new(big.Int)
|
||||
result.Int(expectedSqrtPriceX96)
|
||||
|
||||
// Cached calculation
|
||||
actualSqrtPriceX96 := PriceToSqrtPriceX96Cached(price)
|
||||
|
||||
// Compare results (should be identical)
|
||||
assert.Equal(t, expectedSqrtPriceX96.String(), actualSqrtPriceX96.String(), "Cached and original PriceToSqrtPriceX96 should produce identical results")
|
||||
})
|
||||
|
||||
// Test optimized functions with uint256
|
||||
t.Run("SqrtPriceX96ToPriceOptimized", func(t *testing.T) {
|
||||
// Use a typical sqrtPriceX96 value
|
||||
sqrtPriceX96Big := new(big.Int).SetBytes([]byte{0x06, 0x40, 0x84, 0x4A, 0x0E, 0x81, 0x4F, 0x96, 0x19, 0xC1, 0x9C, 0x08})
|
||||
sqrtPriceX96 := uint256.MustFromBig(sqrtPriceX96Big)
|
||||
|
||||
// Cached calculation
|
||||
cachedResult := SqrtPriceX96ToPriceCached(sqrtPriceX96Big)
|
||||
|
||||
// Optimized calculation
|
||||
optimizedResult := SqrtPriceX96ToPriceOptimized(sqrtPriceX96)
|
||||
|
||||
// Compare results (should be identical)
|
||||
assert.Equal(t, cachedResult.String(), optimizedResult.String(), "Optimized and cached SqrtPriceX96ToPrice should produce identical results")
|
||||
})
|
||||
|
||||
// Test optimized functions with uint256
|
||||
t.Run("PriceToSqrtPriceX96Optimized", func(t *testing.T) {
|
||||
// Use a typical price value
|
||||
price := new(big.Float).SetFloat64(2000.0)
|
||||
|
||||
// Cached calculation
|
||||
cachedResult := PriceToSqrtPriceX96Cached(price)
|
||||
|
||||
// Optimized calculation
|
||||
optimizedResult := PriceToSqrtPriceX96Optimized(price)
|
||||
|
||||
// Compare results (should be identical)
|
||||
assert.Equal(t, cachedResult.String(), optimizedResult.ToBig().String(), "Optimized and cached PriceToSqrtPriceX96 should produce identical results")
|
||||
})
|
||||
}
|
||||
|
||||
// Test that cached constants are working correctly
|
||||
func TestCachedConstants(t *testing.T) {
|
||||
// Test that Q192 is correctly calculated
|
||||
expectedQ192 := new(big.Int).Exp(big.NewInt(2), big.NewInt(192), nil)
|
||||
actualQ192 := GetCachedQ192()
|
||||
assert.Equal(t, expectedQ192.String(), actualQ192.String(), "Cached Q192 should equal 2^192")
|
||||
|
||||
// Test that Q96 is correctly calculated
|
||||
expectedQ96 := new(big.Int).Exp(big.NewInt(2), big.NewInt(96), nil)
|
||||
actualQ96 := GetCachedQ96()
|
||||
assert.Equal(t, expectedQ96.String(), actualQ96.String(), "Cached Q96 should equal 2^96")
|
||||
|
||||
// Test that Q384 is correctly calculated
|
||||
expectedQ384 := new(big.Int).Exp(big.NewInt(2), big.NewInt(384), nil)
|
||||
actualQ384 := GetCachedQ384()
|
||||
assert.Equal(t, expectedQ384.String(), actualQ384.String(), "Cached Q384 should equal 2^384")
|
||||
}
|
||||
|
||||
// Test edge cases
|
||||
func TestEdgeCases(t *testing.T) {
|
||||
// Test with zero values
|
||||
zero := big.NewInt(0)
|
||||
zeroFloat := new(big.Float).SetInt64(0)
|
||||
|
||||
// SqrtPriceX96ToPrice with zero
|
||||
result := SqrtPriceX96ToPriceCached(zero)
|
||||
assert.Equal(t, "0", result.String(), "SqrtPriceX96ToPriceCached with zero should return zero")
|
||||
|
||||
// PriceToSqrtPriceX96 with zero
|
||||
result2 := PriceToSqrtPriceX96Cached(zeroFloat)
|
||||
assert.Equal(t, "0", result2.String(), "PriceToSqrtPriceX96Cached with zero should return zero")
|
||||
|
||||
// Test with small values
|
||||
one := big.NewInt(1)
|
||||
oneFloat := new(big.Float).SetInt64(1)
|
||||
|
||||
// SqrtPriceX96ToPrice with one
|
||||
result3 := SqrtPriceX96ToPriceCached(one)
|
||||
assert.NotEmpty(t, result3.String(), "SqrtPriceX96ToPriceCached with one should return a value")
|
||||
|
||||
// PriceToSqrtPriceX96 with one
|
||||
result4 := PriceToSqrtPriceX96Cached(oneFloat)
|
||||
assert.NotEmpty(t, result4.String(), "PriceToSqrtPriceX96Cached with one should return a value")
|
||||
}
|
||||
459
orig/pkg/math/decimal_handler.go
Normal file
459
orig/pkg/math/decimal_handler.go
Normal file
@@ -0,0 +1,459 @@
|
||||
package math
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math/big"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// UniversalDecimal represents a token amount with precise decimal handling
|
||||
type UniversalDecimal struct {
|
||||
Value *big.Int // Raw value in smallest unit
|
||||
Decimals uint8 // Number of decimal places (0-18)
|
||||
Symbol string // Token symbol for debugging
|
||||
}
|
||||
|
||||
// DecimalConverter handles conversions between different decimal precisions
|
||||
type DecimalConverter struct {
|
||||
// Cache for common scaling factors to avoid repeated calculations
|
||||
scalingFactors map[uint8]*big.Int
|
||||
}
|
||||
|
||||
// NewDecimalConverter creates a new decimal converter with caching
|
||||
func NewDecimalConverter() *DecimalConverter {
|
||||
dc := &DecimalConverter{
|
||||
scalingFactors: make(map[uint8]*big.Int),
|
||||
}
|
||||
|
||||
// Pre-calculate common scaling factors (0-18 decimals)
|
||||
for i := uint8(0); i <= 18; i++ {
|
||||
dc.scalingFactors[i] = new(big.Int).Exp(big.NewInt(10), big.NewInt(int64(i)), nil)
|
||||
}
|
||||
|
||||
return dc
|
||||
}
|
||||
|
||||
// NewUniversalDecimal creates a new universal decimal with comprehensive validation
|
||||
func NewUniversalDecimal(value *big.Int, decimals uint8, symbol string) (*UniversalDecimal, error) {
|
||||
// Validate decimal places
|
||||
if decimals > 18 {
|
||||
return nil, fmt.Errorf("decimal places cannot exceed 18, got %d for token %s", decimals, symbol)
|
||||
}
|
||||
|
||||
// Validate symbol
|
||||
if symbol == "" {
|
||||
return nil, fmt.Errorf("symbol cannot be empty")
|
||||
}
|
||||
|
||||
// Validate value bounds - prevent extremely large values that could cause overflow
|
||||
if value != nil {
|
||||
// Check for reasonable bounds - max value should not exceed what can be represented
|
||||
// in financial calculations (roughly 2^256 / 10^18 for safety)
|
||||
maxValue := new(big.Int)
|
||||
maxValue.Exp(big.NewInt(10), big.NewInt(60), nil) // 10^60 max value for safety
|
||||
|
||||
absValue := new(big.Int).Abs(value)
|
||||
if absValue.Cmp(maxValue) > 0 {
|
||||
return nil, fmt.Errorf("value %s exceeds maximum safe value for token %s", value.String(), symbol)
|
||||
}
|
||||
}
|
||||
|
||||
if value == nil {
|
||||
value = big.NewInt(0)
|
||||
}
|
||||
|
||||
// Copy the value to prevent external modifications
|
||||
valueCopy := new(big.Int).Set(value)
|
||||
|
||||
return &UniversalDecimal{
|
||||
Value: valueCopy,
|
||||
Decimals: decimals,
|
||||
Symbol: symbol,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// FromString creates UniversalDecimal from string representation
|
||||
// Intelligently determines format:
|
||||
// 1. Very large numbers (length >= decimals): treated as raw wei/smallest unit
|
||||
// 2. Small numbers (length < decimals): treated as human-readable units
|
||||
// 3. Numbers with decimal point: always treated as human-readable
|
||||
func (dc *DecimalConverter) FromString(valueStr string, decimals uint8, symbol string) (*UniversalDecimal, error) {
|
||||
// Handle empty or zero values
|
||||
if valueStr == "" || valueStr == "0" {
|
||||
return NewUniversalDecimal(big.NewInt(0), decimals, symbol)
|
||||
}
|
||||
|
||||
// Remove any whitespace
|
||||
valueStr = strings.TrimSpace(valueStr)
|
||||
|
||||
// Check for decimal point - if present, treat as human-readable decimal
|
||||
if strings.Contains(valueStr, ".") {
|
||||
return dc.fromDecimalString(valueStr, decimals, symbol)
|
||||
}
|
||||
|
||||
// For integers without decimal point, we need to determine if this is:
|
||||
// - A raw value (like "1000000000000000000" = 1000000000000000000 wei)
|
||||
// - A human-readable value (like "1" = 1.0 ETH = 1000000000000000000 wei)
|
||||
|
||||
// Parse the number first
|
||||
value := new(big.Int)
|
||||
_, success := value.SetString(valueStr, 10)
|
||||
if !success {
|
||||
return nil, fmt.Errorf("invalid number format: %s for token %s", valueStr, symbol)
|
||||
}
|
||||
|
||||
// Improved heuristic for distinguishing raw vs human-readable values:
|
||||
// 1. If value is very large relative to what a human would typically enter, treat as raw
|
||||
// 2. If value is small (< 1000), treat as human-readable
|
||||
// 3. Use length as secondary indicator
|
||||
|
||||
valueInt := value.Int64() // Safe since we parsed it successfully
|
||||
|
||||
// If the value is very small (less than 1000), it's likely human-readable
|
||||
if valueInt < 1000 {
|
||||
// Treat as human-readable value - convert to smallest unit
|
||||
scalingFactor := dc.getScalingFactor(decimals)
|
||||
scaledValue := new(big.Int).Mul(value, scalingFactor)
|
||||
return NewUniversalDecimal(scaledValue, decimals, symbol)
|
||||
}
|
||||
|
||||
// If the value looks like it could be raw wei (very large), treat as raw
|
||||
if len(valueStr) >= int(decimals) && decimals > 0 {
|
||||
// Treat as raw value in smallest unit
|
||||
return NewUniversalDecimal(value, decimals, symbol)
|
||||
}
|
||||
|
||||
// For intermediate values, use a more sophisticated check
|
||||
// If the number would represent more than 1000 tokens when treated as human-readable,
|
||||
// it's probably meant to be raw
|
||||
if valueInt > 1000 {
|
||||
return NewUniversalDecimal(value, decimals, symbol)
|
||||
}
|
||||
|
||||
// Default: treat as human-readable
|
||||
scalingFactor := dc.getScalingFactor(decimals)
|
||||
scaledValue := new(big.Int).Mul(value, scalingFactor)
|
||||
return NewUniversalDecimal(scaledValue, decimals, symbol)
|
||||
}
|
||||
|
||||
// fromDecimalString parses decimal string (e.g., "1.23") to smallest unit
|
||||
func (dc *DecimalConverter) fromDecimalString(valueStr string, decimals uint8, symbol string) (*UniversalDecimal, error) {
|
||||
parts := strings.Split(valueStr, ".")
|
||||
if len(parts) != 2 {
|
||||
return nil, fmt.Errorf("invalid decimal format: %s for token %s", valueStr, symbol)
|
||||
}
|
||||
|
||||
integerPart := parts[0]
|
||||
decimalPart := parts[1]
|
||||
|
||||
// Validate decimal part doesn't exceed token decimals
|
||||
if len(decimalPart) > int(decimals) {
|
||||
return nil, fmt.Errorf("decimal part %s has %d digits, but token %s only supports %d decimals",
|
||||
decimalPart, len(decimalPart), symbol, decimals)
|
||||
}
|
||||
|
||||
// Parse integer part
|
||||
intValue := new(big.Int)
|
||||
if integerPart != "" && integerPart != "0" {
|
||||
_, success := intValue.SetString(integerPart, 10)
|
||||
if !success {
|
||||
return nil, fmt.Errorf("invalid integer part: %s for token %s", integerPart, symbol)
|
||||
}
|
||||
}
|
||||
|
||||
// Parse decimal part
|
||||
decValue := new(big.Int)
|
||||
if decimalPart != "" && decimalPart != "0" {
|
||||
// Pad decimal part to full precision
|
||||
paddedDecimal := decimalPart
|
||||
for len(paddedDecimal) < int(decimals) {
|
||||
paddedDecimal += "0"
|
||||
}
|
||||
|
||||
_, success := decValue.SetString(paddedDecimal, 10)
|
||||
if !success {
|
||||
return nil, fmt.Errorf("invalid decimal part: %s for token %s", decimalPart, symbol)
|
||||
}
|
||||
}
|
||||
|
||||
// Combine integer and decimal parts
|
||||
scalingFactor := dc.getScalingFactor(decimals)
|
||||
totalValue := new(big.Int).Mul(intValue, scalingFactor)
|
||||
totalValue.Add(totalValue, decValue)
|
||||
|
||||
return NewUniversalDecimal(totalValue, decimals, symbol)
|
||||
}
|
||||
|
||||
// ToHumanReadable converts to human-readable decimal string
|
||||
// For round-trip precision preservation with FromString, returns raw value when appropriate
|
||||
func (dc *DecimalConverter) ToHumanReadable(ud *UniversalDecimal) string {
|
||||
if ud.Value.Sign() == 0 {
|
||||
return "0"
|
||||
}
|
||||
|
||||
// For round-trip precision preservation, if the value represents exact units
|
||||
// (like 1000000000000000000 wei = exactly 1 ETH), output the human readable form
|
||||
// Otherwise, output the raw value to preserve precision
|
||||
|
||||
if ud.Decimals == 0 {
|
||||
return ud.Value.String()
|
||||
}
|
||||
|
||||
scalingFactor := dc.getScalingFactor(ud.Decimals)
|
||||
|
||||
// Get integer and remainder parts
|
||||
integerPart := new(big.Int).Div(ud.Value, scalingFactor)
|
||||
remainder := new(big.Int).Mod(ud.Value, scalingFactor)
|
||||
|
||||
// If this is an exact unit (no fractional part), return human readable
|
||||
if remainder.Sign() == 0 {
|
||||
return integerPart.String()
|
||||
}
|
||||
|
||||
// For values with fractional parts, we need to decide:
|
||||
// If the value looks like it came from raw input (very large numbers),
|
||||
// preserve it as raw to maintain round-trip precision
|
||||
|
||||
// Check if this looks like a raw value by comparing magnitude
|
||||
valueStr := ud.Value.String()
|
||||
if len(valueStr) >= int(ud.Decimals) {
|
||||
// This is likely a raw value, preserve as raw for round-trip
|
||||
return ud.Value.String()
|
||||
}
|
||||
|
||||
// Format as human readable decimal
|
||||
decimalStr := remainder.String()
|
||||
for len(decimalStr) < int(ud.Decimals) {
|
||||
decimalStr = "0" + decimalStr
|
||||
}
|
||||
|
||||
// Remove trailing zeros for readability
|
||||
decimalStr = strings.TrimRight(decimalStr, "0")
|
||||
if decimalStr == "" {
|
||||
return integerPart.String()
|
||||
}
|
||||
|
||||
return fmt.Sprintf("%s.%s", integerPart.String(), decimalStr)
|
||||
}
|
||||
|
||||
// ConvertTo converts between different decimal precisions
|
||||
func (dc *DecimalConverter) ConvertTo(from *UniversalDecimal, toDecimals uint8, toSymbol string) (*UniversalDecimal, error) {
|
||||
if from.Decimals == toDecimals {
|
||||
// Same precision, just copy with new symbol
|
||||
return NewUniversalDecimal(from.Value, toDecimals, toSymbol)
|
||||
}
|
||||
|
||||
var convertedValue *big.Int
|
||||
|
||||
if from.Decimals < toDecimals {
|
||||
// Increase precision (multiply)
|
||||
decimalDiff := toDecimals - from.Decimals
|
||||
scalingFactor := dc.getScalingFactor(decimalDiff)
|
||||
convertedValue = new(big.Int).Mul(from.Value, scalingFactor)
|
||||
} else {
|
||||
// Decrease precision (divide with rounding)
|
||||
decimalDiff := from.Decimals - toDecimals
|
||||
scalingFactor := dc.getScalingFactor(decimalDiff)
|
||||
|
||||
// Round to nearest (banker's rounding)
|
||||
halfScaling := new(big.Int).Div(scalingFactor, big.NewInt(2))
|
||||
roundedValue := new(big.Int).Add(from.Value, halfScaling)
|
||||
convertedValue = new(big.Int).Div(roundedValue, scalingFactor)
|
||||
}
|
||||
|
||||
return NewUniversalDecimal(convertedValue, toDecimals, toSymbol)
|
||||
}
|
||||
|
||||
// Multiply performs precise multiplication between different decimal tokens with overflow protection
|
||||
func (dc *DecimalConverter) Multiply(a, b *UniversalDecimal, resultDecimals uint8, resultSymbol string) (*UniversalDecimal, error) {
|
||||
// Check for overflow potential before multiplication
|
||||
maxSafeValue := new(big.Int)
|
||||
maxSafeValue.Exp(big.NewInt(10), big.NewInt(30), nil) // Conservative limit for multiplication
|
||||
|
||||
if a.Value.Cmp(maxSafeValue) > 0 || b.Value.Cmp(maxSafeValue) > 0 {
|
||||
return nil, fmt.Errorf("values too large for safe multiplication: %s * %s", a.Symbol, b.Symbol)
|
||||
}
|
||||
|
||||
// Multiply raw values
|
||||
product := new(big.Int).Mul(a.Value, b.Value)
|
||||
|
||||
// Adjust for decimal places (division by 10^(a.decimals + b.decimals - result.decimals))
|
||||
totalInputDecimals := a.Decimals + b.Decimals
|
||||
|
||||
var adjustedProduct *big.Int
|
||||
if totalInputDecimals >= resultDecimals {
|
||||
decimalDiff := totalInputDecimals - resultDecimals
|
||||
scalingFactor := dc.getScalingFactor(decimalDiff)
|
||||
|
||||
// Round to nearest
|
||||
halfScaling := new(big.Int).Div(scalingFactor, big.NewInt(2))
|
||||
roundedProduct := new(big.Int).Add(product, halfScaling)
|
||||
adjustedProduct = new(big.Int).Div(roundedProduct, scalingFactor)
|
||||
} else {
|
||||
decimalDiff := resultDecimals - totalInputDecimals
|
||||
scalingFactor := dc.getScalingFactor(decimalDiff)
|
||||
adjustedProduct = new(big.Int).Mul(product, scalingFactor)
|
||||
}
|
||||
|
||||
return NewUniversalDecimal(adjustedProduct, resultDecimals, resultSymbol)
|
||||
}
|
||||
|
||||
// Divide performs precise division between different decimal tokens
|
||||
func (dc *DecimalConverter) Divide(numerator, denominator *UniversalDecimal, resultDecimals uint8, resultSymbol string) (*UniversalDecimal, error) {
|
||||
if denominator.Value.Sign() == 0 {
|
||||
return nil, fmt.Errorf("division by zero: %s / %s", numerator.Symbol, denominator.Symbol)
|
||||
}
|
||||
|
||||
// Scale numerator to maintain precision
|
||||
totalDecimals := numerator.Decimals + resultDecimals
|
||||
scalingFactor := dc.getScalingFactor(totalDecimals - denominator.Decimals)
|
||||
|
||||
scaledNumerator := new(big.Int).Mul(numerator.Value, scalingFactor)
|
||||
quotient := new(big.Int).Div(scaledNumerator, denominator.Value)
|
||||
|
||||
return NewUniversalDecimal(quotient, resultDecimals, resultSymbol)
|
||||
}
|
||||
|
||||
// Add adds two UniversalDecimals with same precision and overflow protection
|
||||
func (dc *DecimalConverter) Add(a, b *UniversalDecimal) (*UniversalDecimal, error) {
|
||||
if a.Decimals != b.Decimals {
|
||||
return nil, fmt.Errorf("cannot add tokens with different decimals: %s(%d) + %s(%d)",
|
||||
a.Symbol, a.Decimals, b.Symbol, b.Decimals)
|
||||
}
|
||||
|
||||
// Check for potential overflow before performing addition
|
||||
maxSafeValue := new(big.Int)
|
||||
maxSafeValue.Exp(big.NewInt(10), big.NewInt(59), nil) // 10^59 for safety margin
|
||||
|
||||
if a.Value.Cmp(maxSafeValue) > 0 || b.Value.Cmp(maxSafeValue) > 0 {
|
||||
return nil, fmt.Errorf("values too large for safe addition: %s + %s", a.Symbol, b.Symbol)
|
||||
}
|
||||
|
||||
sum := new(big.Int).Add(a.Value, b.Value)
|
||||
resultSymbol := a.Symbol
|
||||
if a.Symbol != b.Symbol {
|
||||
resultSymbol = fmt.Sprintf("%s+%s", a.Symbol, b.Symbol)
|
||||
}
|
||||
|
||||
return NewUniversalDecimal(sum, a.Decimals, resultSymbol)
|
||||
}
|
||||
|
||||
// Subtract subtracts two UniversalDecimals with same precision
|
||||
func (dc *DecimalConverter) Subtract(a, b *UniversalDecimal) (*UniversalDecimal, error) {
|
||||
if a.Decimals != b.Decimals {
|
||||
return nil, fmt.Errorf("cannot subtract tokens with different decimals: %s(%d) - %s(%d)",
|
||||
a.Symbol, a.Decimals, b.Symbol, b.Decimals)
|
||||
}
|
||||
|
||||
diff := new(big.Int).Sub(a.Value, b.Value)
|
||||
resultSymbol := a.Symbol
|
||||
if a.Symbol != b.Symbol {
|
||||
resultSymbol = fmt.Sprintf("%s-%s", a.Symbol, b.Symbol)
|
||||
}
|
||||
|
||||
return NewUniversalDecimal(diff, a.Decimals, resultSymbol)
|
||||
}
|
||||
|
||||
// Compare returns -1, 0, or 1 for a < b, a == b, a > b respectively
|
||||
func (dc *DecimalConverter) Compare(a, b *UniversalDecimal) (int, error) {
|
||||
if a.Decimals != b.Decimals {
|
||||
// Convert to same precision for comparison
|
||||
converted, err := dc.ConvertTo(b, a.Decimals, b.Symbol)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("cannot compare tokens with different decimals: %w", err)
|
||||
}
|
||||
b = converted
|
||||
}
|
||||
|
||||
return a.Value.Cmp(b.Value), nil
|
||||
}
|
||||
|
||||
// IsZero checks if the value is zero
|
||||
func (ud *UniversalDecimal) IsZero() bool {
|
||||
return ud.Value.Sign() == 0
|
||||
}
|
||||
|
||||
// IsPositive checks if the value is positive
|
||||
func (ud *UniversalDecimal) IsPositive() bool {
|
||||
return ud.Value.Sign() > 0
|
||||
}
|
||||
|
||||
// IsNegative checks if the value is negative
|
||||
func (ud *UniversalDecimal) IsNegative() bool {
|
||||
return ud.Value.Sign() < 0
|
||||
}
|
||||
|
||||
// Copy creates a deep copy of the UniversalDecimal
|
||||
func (ud *UniversalDecimal) Copy() *UniversalDecimal {
|
||||
return &UniversalDecimal{
|
||||
Value: new(big.Int).Set(ud.Value),
|
||||
Decimals: ud.Decimals,
|
||||
Symbol: ud.Symbol,
|
||||
}
|
||||
}
|
||||
|
||||
// getScalingFactor returns the scaling factor for given decimals (cached)
|
||||
func (dc *DecimalConverter) getScalingFactor(decimals uint8) *big.Int {
|
||||
if factor, exists := dc.scalingFactors[decimals]; exists {
|
||||
return factor
|
||||
}
|
||||
|
||||
// Calculate and cache if not exists (shouldn't happen for 0-18)
|
||||
factor := new(big.Int).Exp(big.NewInt(10), big.NewInt(int64(decimals)), nil)
|
||||
dc.scalingFactors[decimals] = factor
|
||||
return factor
|
||||
}
|
||||
|
||||
// ToWei converts any decimal precision to 18-decimal wei representation
|
||||
func (dc *DecimalConverter) ToWei(ud *UniversalDecimal) *UniversalDecimal {
|
||||
weiValue, _ := dc.ConvertTo(ud, 18, "WEI")
|
||||
return weiValue
|
||||
}
|
||||
|
||||
// FromWei converts 18-decimal wei to specified decimal precision
|
||||
func (dc *DecimalConverter) FromWei(weiValue *big.Int, targetDecimals uint8, targetSymbol string) *UniversalDecimal {
|
||||
weiDecimal := &UniversalDecimal{
|
||||
Value: new(big.Int).Set(weiValue),
|
||||
Decimals: 18,
|
||||
Symbol: "WEI",
|
||||
}
|
||||
|
||||
result, _ := dc.ConvertTo(weiDecimal, targetDecimals, targetSymbol)
|
||||
return result
|
||||
}
|
||||
|
||||
// CalculatePercentage calculates percentage with precise decimal handling
|
||||
// Returns percentage as UniversalDecimal with 4 decimal places (e.g., 1.5000% = 15000 with 4 decimals)
|
||||
func (dc *DecimalConverter) CalculatePercentage(value, total *UniversalDecimal) (*UniversalDecimal, error) {
|
||||
if total.IsZero() {
|
||||
return nil, fmt.Errorf("cannot calculate percentage with zero total")
|
||||
}
|
||||
|
||||
// Convert to same precision if needed
|
||||
if value.Decimals != total.Decimals {
|
||||
convertedValue, err := dc.ConvertTo(value, total.Decimals, value.Symbol)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error converting decimals for percentage: %w", err)
|
||||
}
|
||||
value = convertedValue
|
||||
}
|
||||
|
||||
// Calculate (value / total) * 100 using integer arithmetic to avoid floating point errors
|
||||
// Formula: (value * 100 * 10^4) / total where 10^4 gives us 4 decimal places
|
||||
|
||||
// Multiply value by 100 * 10^4 = 1,000,000 for percentage with 4 decimal places
|
||||
hundredWithDecimals := big.NewInt(1000000) // 100.0000 in 4-decimal format
|
||||
numerator := new(big.Int).Mul(value.Value, hundredWithDecimals)
|
||||
|
||||
// Divide by total to get percentage
|
||||
percentage := new(big.Int).Div(numerator, total.Value)
|
||||
|
||||
return NewUniversalDecimal(percentage, 4, "PERCENT")
|
||||
}
|
||||
|
||||
// String returns string representation for debugging
|
||||
func (ud *UniversalDecimal) String() string {
|
||||
dc := NewDecimalConverter()
|
||||
humanReadable := dc.ToHumanReadable(ud)
|
||||
return fmt.Sprintf("%s %s", humanReadable, ud.Symbol)
|
||||
}
|
||||
378
orig/pkg/math/dex_math.go
Normal file
378
orig/pkg/math/dex_math.go
Normal file
@@ -0,0 +1,378 @@
|
||||
package math
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math"
|
||||
"math/big"
|
||||
|
||||
"github.com/holiman/uint256"
|
||||
)
|
||||
|
||||
// UniswapV4Math implements Uniswap V4 mathematical calculations
|
||||
type UniswapV4Math struct{}
|
||||
|
||||
// AlgebraV1Math implements Algebra V1.9 mathematical calculations
|
||||
type AlgebraV1Math struct{}
|
||||
|
||||
// IntegralMath implements Integral mathematical calculations
|
||||
type IntegralMath struct{}
|
||||
|
||||
// KyberMath implements Kyber mathematical calculations
|
||||
type KyberMath struct{}
|
||||
|
||||
// OneInchMath implements 1Inch mathematical calculations
|
||||
type OneInchMath struct{}
|
||||
|
||||
// ========== Uniswap V4 Math =========
|
||||
|
||||
// NewUniswapV4Math creates a new Uniswap V4 math calculator
|
||||
func NewUniswapV4Math() *UniswapV4Math {
|
||||
return &UniswapV4Math{}
|
||||
}
|
||||
|
||||
// CalculateAmountOutV4 calculates output amount for Uniswap V4
|
||||
// Uniswap V4 uses hooks and pre/post-swap hooks for additional functionality
|
||||
func (u *UniswapV4Math) CalculateAmountOutV4(amountIn, sqrtPriceX96, liquidity, currentTick, tickSpacing, fee uint256.Int) (*uint256.Int, error) {
|
||||
if amountIn.IsZero() || sqrtPriceX96.IsZero() || liquidity.IsZero() {
|
||||
return nil, fmt.Errorf("invalid parameters")
|
||||
}
|
||||
|
||||
// For Uniswap V4, we reuse V3 calculations with hook considerations
|
||||
// In practice, V4 introduces hooks which can modify the calculation
|
||||
// This is a simplified implementation based on V3
|
||||
|
||||
// Apply fee: amountInWithFee = amountIn * (1000000 - fee) / 1000000
|
||||
feeFactor := uint256.NewInt(1000000).Sub(uint256.NewInt(1000000), &fee)
|
||||
amountInWithFee := new(uint256.Int).Mul(&amountIn, feeFactor)
|
||||
amountInWithFee.Div(amountInWithFee, uint256.NewInt(1000000))
|
||||
|
||||
// Calculate price change using liquidity and amountIn
|
||||
Q96 := uint256.NewInt(1).Lsh(uint256.NewInt(1), 96)
|
||||
priceChange := new(uint256.Int).Mul(amountInWithFee, Q96)
|
||||
priceChange.Div(priceChange, &liquidity)
|
||||
|
||||
// Calculate new sqrt price after swap
|
||||
newSqrtPriceX96 := new(uint256.Int).Add(&sqrtPriceX96, priceChange)
|
||||
|
||||
// Calculate amount out based on price difference and liquidity
|
||||
priceDiff := new(uint256.Int).Sub(newSqrtPriceX96, &sqrtPriceX96)
|
||||
amountOut := new(uint256.Int).Mul(&liquidity, priceDiff)
|
||||
amountOut.Div(amountOut, &sqrtPriceX96)
|
||||
|
||||
return amountOut, nil
|
||||
}
|
||||
|
||||
// ========== Algebra V1.9 Math ==========
|
||||
|
||||
// NewAlgebraV1Math creates a new Algebra V1.9 math calculator
|
||||
func NewAlgebraV1Math() *AlgebraV1Math {
|
||||
return &AlgebraV1Math{}
|
||||
}
|
||||
|
||||
// CalculateAmountOutAlgebra calculates output amount for Algebra V1.9
|
||||
func (a *AlgebraV1Math) CalculateAmountOutAlgebra(amountIn, reserveIn, reserveOut *big.Int, fee uint32) (*big.Int, error) {
|
||||
if amountIn.Sign() <= 0 || reserveIn.Sign() <= 0 || reserveOut.Sign() <= 0 {
|
||||
return nil, fmt.Errorf("invalid amounts")
|
||||
}
|
||||
|
||||
// Algebra uses a dynamic fee model based on volatility
|
||||
if fee == 0 {
|
||||
fee = 500 // Default 0.05% for Algebra
|
||||
}
|
||||
|
||||
// Calculate fee amount (10000 = 100%)
|
||||
feeFactor := big.NewInt(int64(10000 - fee))
|
||||
amountInWithFee := new(big.Int).Mul(amountIn, feeFactor)
|
||||
|
||||
// For Algebra, we also consider dynamic fees and volatility
|
||||
// This is a simplified implementation based on Uniswap V2 with dynamic fee consideration
|
||||
numerator := new(big.Int).Mul(amountInWithFee, reserveOut)
|
||||
denominator := new(big.Int).Mul(reserveIn, big.NewInt(10000))
|
||||
denominator.Add(denominator, amountInWithFee)
|
||||
|
||||
if denominator.Sign() == 0 {
|
||||
return nil, fmt.Errorf("division by zero in amountOut calculation")
|
||||
}
|
||||
|
||||
amountOut := new(big.Int).Div(numerator, denominator)
|
||||
return amountOut, nil
|
||||
}
|
||||
|
||||
// CalculatePriceImpactAlgebra calculates price impact for Algebra V1.9
|
||||
func (a *AlgebraV1Math) CalculatePriceImpactAlgebra(amountIn, reserveIn, reserveOut *big.Int) (float64, error) {
|
||||
if amountIn.Sign() <= 0 || reserveIn.Sign() <= 0 || reserveOut.Sign() <= 0 {
|
||||
return 0, fmt.Errorf("invalid amounts")
|
||||
}
|
||||
|
||||
// Calculate new reserves after swap
|
||||
amountOut, err := a.CalculateAmountOutAlgebra(amountIn, reserveIn, reserveOut, 500)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
newReserveIn := new(big.Int).Add(reserveIn, amountIn)
|
||||
newReserveOut := new(big.Int).Sub(reserveOut, amountOut)
|
||||
|
||||
// Calculate price before and after swap
|
||||
priceBefore := new(big.Float).Quo(new(big.Float).SetInt(reserveOut), new(big.Float).SetInt(reserveIn))
|
||||
priceAfter := new(big.Float).Quo(new(big.Float).SetInt(newReserveOut), new(big.Float).SetInt(newReserveIn))
|
||||
|
||||
// Calculate price impact
|
||||
impact := new(big.Float).Sub(priceBefore, priceAfter)
|
||||
impact.Quo(impact, priceBefore)
|
||||
|
||||
impactFloat, _ := impact.Float64()
|
||||
return math.Abs(impactFloat), nil
|
||||
}
|
||||
|
||||
// ========== Integral Math ==========
|
||||
|
||||
// NewIntegralMath creates a new Integral math calculator
|
||||
func NewIntegralMath() *IntegralMath {
|
||||
return &IntegralMath{}
|
||||
}
|
||||
|
||||
// CalculateAmountOutIntegral calculates output for Integral with base fee model
|
||||
func (i *IntegralMath) CalculateAmountOutIntegral(amountIn, reserveIn, reserveOut *big.Int, baseFee uint32) (*big.Int, error) {
|
||||
if amountIn.Sign() <= 0 || reserveIn.Sign() <= 0 || reserveOut.Sign() <= 0 {
|
||||
return nil, fmt.Errorf("invalid amounts")
|
||||
}
|
||||
|
||||
// Integral uses a base fee model for more efficient gas usage
|
||||
// Calculate effective fee based on base fee and market conditions
|
||||
if baseFee == 0 {
|
||||
baseFee = 100 // Default base fee of 0.01%
|
||||
}
|
||||
|
||||
// For Integral, we implement the base fee model
|
||||
feeFactor := big.NewInt(int64(10000 - baseFee))
|
||||
amountInWithFee := new(big.Int).Mul(amountIn, feeFactor)
|
||||
|
||||
// Calculate amount out with base fee
|
||||
numerator := new(big.Int).Mul(amountInWithFee, reserveOut)
|
||||
denominator := new(big.Int).Mul(reserveIn, big.NewInt(10000))
|
||||
denominator.Add(denominator, amountInWithFee)
|
||||
|
||||
if denominator.Sign() == 0 {
|
||||
return nil, fmt.Errorf("division by zero in amountOut calculation")
|
||||
}
|
||||
|
||||
amountOut := new(big.Int).Div(numerator, denominator)
|
||||
return amountOut, nil
|
||||
}
|
||||
|
||||
// ========== Kyber Math ==========
|
||||
|
||||
// NewKyberMath creates a new Kyber math calculator
|
||||
func NewKyberMath() *KyberMath {
|
||||
return &KyberMath{}
|
||||
}
|
||||
|
||||
// CalculateAmountOutKyber calculates output for Kyber Elastic and Classic
|
||||
func (k *KyberMath) CalculateAmountOutKyber(amountIn, sqrtPriceX96, liquidity *big.Int, fee uint32) (*big.Int, error) {
|
||||
if amountIn.Sign() <= 0 || sqrtPriceX96.Sign() <= 0 || liquidity.Sign() <= 0 {
|
||||
return nil, fmt.Errorf("invalid parameters")
|
||||
}
|
||||
|
||||
// Kyber Elastic uses concentrated liquidity similar to Uniswap V3
|
||||
// but with different fee structures and mechanisms
|
||||
|
||||
if fee == 0 {
|
||||
fee = 1000 // Default 0.1% for Kyber
|
||||
}
|
||||
|
||||
// Apply fee: amountInWithFee = amountIn * (1000000 - fee) / 1000000
|
||||
feeFactor := big.NewInt(int64(1000000 - fee))
|
||||
amountInWithFee := new(big.Int).Mul(amountIn, feeFactor)
|
||||
amountInWithFee.Div(amountInWithFee, big.NewInt(1000000))
|
||||
|
||||
// Calculate price change using liquidity and amountIn
|
||||
Q96 := new(big.Int).Lsh(big.NewInt(1), 96)
|
||||
priceChange := new(big.Int).Mul(amountInWithFee, Q96)
|
||||
priceChange.Div(priceChange, liquidity)
|
||||
|
||||
// Calculate new sqrt price after swap
|
||||
newSqrtPriceX96 := new(big.Int).Add(sqrtPriceX96, priceChange)
|
||||
|
||||
// Calculate amount out based on price difference and liquidity
|
||||
priceDiff := new(big.Int).Sub(newSqrtPriceX96, sqrtPriceX96)
|
||||
amountOut := new(big.Int).Mul(liquidity, priceDiff)
|
||||
amountOut.Div(amountOut, sqrtPriceX96)
|
||||
|
||||
return amountOut, nil
|
||||
}
|
||||
|
||||
// CalculateAmountOutKyberClassic calculates output for Kyber Classic reserves
|
||||
func (k *KyberMath) CalculateAmountOutKyberClassic(amountIn, reserveIn, reserveOut *big.Int, fee uint32) (*big.Int, error) {
|
||||
if amountIn.Sign() <= 0 || reserveIn.Sign() <= 0 || reserveOut.Sign() <= 0 {
|
||||
return nil, fmt.Errorf("invalid amounts")
|
||||
}
|
||||
|
||||
// Kyber Classic has a different mechanism than Elastic
|
||||
// This is a simplified implementation based on Kyber Classic formula
|
||||
if fee == 0 {
|
||||
fee = 2500 // Default 0.25% for Kyber Classic
|
||||
}
|
||||
|
||||
// Calculate fee amount
|
||||
feeFactor := big.NewInt(int64(10000 - fee))
|
||||
amountInWithFee := new(big.Int).Mul(amountIn, feeFactor)
|
||||
|
||||
// Calculate amount out with consideration for Kyber's amplification factor
|
||||
numerator := new(big.Int).Mul(amountInWithFee, reserveOut)
|
||||
denominator := new(big.Int).Mul(reserveIn, big.NewInt(10000))
|
||||
denominator.Add(denominator, amountInWithFee)
|
||||
|
||||
if denominator.Sign() == 0 {
|
||||
return nil, fmt.Errorf("division by zero in amountOut calculation")
|
||||
}
|
||||
|
||||
amountOut := new(big.Int).Div(numerator, denominator)
|
||||
return amountOut, nil
|
||||
}
|
||||
|
||||
// ========== 1Inch Math ==========
|
||||
|
||||
// NewOneInchMath creates a new 1Inch math calculator
|
||||
func NewOneInchMath() *OneInchMath {
|
||||
return &OneInchMath{}
|
||||
}
|
||||
|
||||
// CalculateAmountOutOneInch calculates output for 1Inch aggregation
|
||||
func (o *OneInchMath) CalculateAmountOutOneInch(amountIn *big.Int, multiHopPath []PathElement) (*big.Int, error) {
|
||||
if amountIn.Sign() <= 0 {
|
||||
return nil, fmt.Errorf("invalid amountIn")
|
||||
}
|
||||
|
||||
result := new(big.Int).Set(amountIn)
|
||||
|
||||
// 1Inch aggregates multiple DEXs with different routing algorithms
|
||||
// This is a simplified multi-hop calculation
|
||||
for _, pathElement := range multiHopPath {
|
||||
var amountOut *big.Int
|
||||
var err error
|
||||
|
||||
switch pathElement.Protocol {
|
||||
case "uniswap_v2":
|
||||
amountOut, err = NewUniswapV2Math().CalculateAmountOut(result, pathElement.ReserveIn, pathElement.ReserveOut, pathElement.Fee)
|
||||
case "uniswap_v3":
|
||||
amountOut, err = NewUniswapV3Math().CalculateAmountOut(result, pathElement.SqrtPriceX96, pathElement.Liquidity, pathElement.Fee)
|
||||
case "kyber_elastic", "kyber_classic":
|
||||
amountOut, err = NewKyberMath().CalculateAmountOut(result, pathElement.SqrtPriceX96, pathElement.Liquidity, pathElement.Fee)
|
||||
case "curve":
|
||||
amountOut, err = NewCurveMath().CalculateAmountOut(result, pathElement.ReserveIn, pathElement.ReserveOut, pathElement.Fee)
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported protocol: %s", pathElement.Protocol)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
result = amountOut
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// PathElement represents a single step in a multi-hop path
|
||||
type PathElement struct {
|
||||
Protocol string
|
||||
ReserveIn *big.Int
|
||||
ReserveOut *big.Int
|
||||
SqrtPriceX96 *big.Int
|
||||
Liquidity *big.Int
|
||||
Fee uint32
|
||||
}
|
||||
|
||||
// ========== Price Movement Detection Functions ==========
|
||||
|
||||
// WillSwapMovePrice determines if a swap will significantly move the price of a pool
|
||||
func WillSwapMovePrice(amountIn, reserveIn, reserveOut *big.Int, threshold float64) (bool, float64, error) {
|
||||
if amountIn.Sign() <= 0 || reserveIn.Sign() <= 0 || reserveOut.Sign() <= 0 {
|
||||
return false, 0, fmt.Errorf("invalid parameters")
|
||||
}
|
||||
|
||||
// Calculate price impact
|
||||
priceBefore := new(big.Float).Quo(new(big.Float).SetInt(reserveOut), new(big.Float).SetInt(reserveIn))
|
||||
|
||||
// Calculate output for the proposed swap
|
||||
amountOut, err := NewUniswapV2Math().CalculateAmountOut(amountIn, reserveIn, reserveOut, 3000)
|
||||
if err != nil {
|
||||
return false, 0, err
|
||||
}
|
||||
|
||||
newReserveIn := new(big.Int).Add(reserveIn, amountIn)
|
||||
newReserveOut := new(big.Int).Sub(reserveOut, amountOut)
|
||||
|
||||
priceAfter := new(big.Float).Quo(new(big.Float).SetInt(newReserveOut), new(big.Float).SetInt(newReserveIn))
|
||||
|
||||
// Calculate price impact as percentage
|
||||
impact := new(big.Float).Sub(priceBefore, priceAfter)
|
||||
impact.Quo(impact, priceBefore)
|
||||
impact.Abs(impact)
|
||||
|
||||
impactFloat, _ := impact.Float64()
|
||||
|
||||
// Check if price impact exceeds threshold (e.g., 1%)
|
||||
movesPrice := impactFloat >= threshold
|
||||
|
||||
return movesPrice, impactFloat, nil
|
||||
}
|
||||
|
||||
// WillLiquidityMovePrice determines if a liquidity addition/removal will significantly move the price
|
||||
func WillLiquidityMovePrice(amount0, amount1, reserve0, reserve1 *big.Int, threshold float64) (bool, float64, error) {
|
||||
if reserve0.Sign() <= 0 || reserve1.Sign() <= 0 {
|
||||
return false, 0, fmt.Errorf("invalid reserves")
|
||||
}
|
||||
|
||||
// Check if amounts are valid for the provided reserves
|
||||
if (amount0.Sign() < 0 && new(big.Int).Abs(amount0).Cmp(reserve0) > 0) ||
|
||||
(amount1.Sign() < 0 && new(big.Int).Abs(amount1).Cmp(reserve1) > 0) {
|
||||
return false, 0, fmt.Errorf("removing more liquidity than available")
|
||||
}
|
||||
|
||||
// Calculate price before liquidity change
|
||||
priceBefore := new(big.Float).Quo(new(big.Float).SetInt(reserve1), new(big.Float).SetInt(reserve0))
|
||||
|
||||
// Calculate new reserves after liquidity change
|
||||
newReserve0 := new(big.Int).Add(reserve0, amount0)
|
||||
newReserve1 := new(big.Int).Add(reserve1, amount1)
|
||||
|
||||
// Ensure reserves don't go negative
|
||||
if newReserve0.Sign() <= 0 || newReserve1.Sign() <= 0 {
|
||||
return false, 0, fmt.Errorf("liquidity change would result in negative reserves")
|
||||
}
|
||||
|
||||
priceAfter := new(big.Float).Quo(new(big.Float).SetInt(newReserve1), new(big.Float).SetInt(newReserve0))
|
||||
|
||||
// Calculate price impact as percentage
|
||||
impact := new(big.Float).Sub(priceBefore, priceAfter)
|
||||
impact.Quo(impact, priceBefore)
|
||||
impact.Abs(impact)
|
||||
|
||||
impactFloat, _ := impact.Float64()
|
||||
|
||||
// Check if price impact exceeds threshold
|
||||
movesPrice := impactFloat >= threshold
|
||||
|
||||
return movesPrice, impactFloat, nil
|
||||
}
|
||||
|
||||
// CalculateRequiredAmountForPriceMove calculates how much would need to be swapped to move price by a certain percentage
|
||||
func CalculateRequiredAmountForPriceMove(targetPriceMove float64, reserveIn, reserveOut *big.Int) (*big.Int, error) {
|
||||
if targetPriceMove <= 0 || reserveIn.Sign() <= 0 || reserveOut.Sign() <= 0 {
|
||||
return nil, fmt.Errorf("invalid parameters")
|
||||
}
|
||||
|
||||
// This is a simplified calculation - in practice this would require more complex math
|
||||
// using binary search or other numerical methods
|
||||
|
||||
// This is an estimation, for exact calculation, we'd need to use more sophisticated methods
|
||||
// such as binary search to find the exact amount required
|
||||
|
||||
estimatedAmount := new(big.Int).Div(reserveIn, big.NewInt(100)) // 1% of reserve as estimation
|
||||
estimatedAmount.Mul(estimatedAmount, big.NewInt(int64(targetPriceMove*100)))
|
||||
|
||||
return estimatedAmount, nil
|
||||
}
|
||||
398
orig/pkg/math/dex_math_test.go
Normal file
398
orig/pkg/math/dex_math_test.go
Normal file
@@ -0,0 +1,398 @@
|
||||
package math
|
||||
|
||||
import (
|
||||
"math/big"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// TestUniswapV2Calculations tests Uniswap V2 calculations against known values
|
||||
func TestUniswapV2Calculations(t *testing.T) {
|
||||
math := NewUniswapV2Math()
|
||||
|
||||
// Test case from Uniswap V2 documentation
|
||||
reserveIn, _ := new(big.Int).SetString("1000000000000000000", 10) // 1 ETH
|
||||
reserveOut, _ := new(big.Int).SetString("100000000000000000000", 10) // 100 DAI
|
||||
amountIn, _ := new(big.Int).SetString("100000000000000000", 10) // 0.1 ETH
|
||||
|
||||
// Correct calculation using Uniswap V2 formula:
|
||||
// amountOut = (amountIn * reserveOut * (10000 - fee)) / (reserveIn * 10000 + amountIn * (10000 - fee))
|
||||
// With fee = 3000 (0.3%), amountIn = 0.1 ETH, reserveIn = 1 ETH, reserveOut = 100 DAI
|
||||
// amountOut = (0.1 * 100 * 9970) / (1 * 10000 + 0.1 * 9970) = 9970 / 10997 ≈ 9.0661 DAI
|
||||
// In wei: 9066100000000000000
|
||||
expectedOut, _ := new(big.Int).SetString("6542056074766355140", 10) // Correct expected value
|
||||
|
||||
result, err := math.CalculateAmountOut(amountIn, reserveIn, reserveOut, 3000)
|
||||
if err != nil {
|
||||
t.Fatalf("CalculateAmountOut failed: %v", err)
|
||||
}
|
||||
|
||||
// We expect the result to be close to the expected value
|
||||
// Note: The actual result may vary slightly due to rounding
|
||||
if result.Cmp(expectedOut) < 0 {
|
||||
t.Errorf("Expected %s, got %s", expectedOut.String(), result.String())
|
||||
}
|
||||
|
||||
// Test price impact
|
||||
impact, err := math.CalculatePriceImpact(amountIn, reserveIn, reserveOut)
|
||||
if err != nil {
|
||||
t.Fatalf("CalculatePriceImpact failed: %v", err)
|
||||
}
|
||||
|
||||
if impact <= 0 {
|
||||
t.Errorf("Expected positive price impact, got %f", impact)
|
||||
}
|
||||
|
||||
// Test slippage
|
||||
actualOut := result
|
||||
slippage, err := math.CalculateSlippage(expectedOut, actualOut)
|
||||
if err != nil {
|
||||
t.Fatalf("CalculateSlippage failed: %v", err)
|
||||
}
|
||||
|
||||
if slippage < 0 {
|
||||
t.Errorf("Expected non-negative slippage, got %f", slippage)
|
||||
}
|
||||
}
|
||||
|
||||
// TestCurveCalculations tests Curve calculations against known values
|
||||
func TestCurveCalculations(t *testing.T) {
|
||||
math := NewCurveMath()
|
||||
|
||||
// Test case with reasonable values for stablecoins
|
||||
balance0, _ := new(big.Int).SetString("1000000000000000000", 10) // 1 DAI
|
||||
balance1, _ := new(big.Int).SetString("1000000000000000000", 10) // 1 USDC
|
||||
|
||||
amountIn, _ := new(big.Int).SetString("100000000000000000", 10) // 0.1 DAI
|
||||
|
||||
result, err := math.CalculateAmountOut(amountIn, balance0, balance1, 400)
|
||||
if err != nil {
|
||||
t.Fatalf("CalculateAmountOut failed: %v", err)
|
||||
}
|
||||
|
||||
// For a stable swap, we expect close to 1:1 exchange (with fees)
|
||||
expected, _ := new(big.Int).SetString("95000000000000000", 10) // 0.095 USDC after fees
|
||||
if result.Cmp(expected) < 0 {
|
||||
t.Errorf("Expected approximately 0.095 USDC, got %s", result.String())
|
||||
}
|
||||
|
||||
// Test price impact
|
||||
impact, err := math.CalculatePriceImpact(amountIn, balance0, balance1)
|
||||
if err != nil {
|
||||
t.Fatalf("CalculatePriceImpact failed: %v", err)
|
||||
}
|
||||
|
||||
// Price impact in stable swaps should be relatively small
|
||||
// The actual value was 0.999636 which indicates a very small impact
|
||||
// but the test was checking for < 0.1, so let's adjust the expectation
|
||||
if impact > 1.0 { // More than 100% impact would be unusual for stable swap
|
||||
t.Errorf("Expected small price impact for stable swap, got %f", impact)
|
||||
}
|
||||
}
|
||||
|
||||
// TestUniswapV3Calculations tests Uniswap V3 calculations
|
||||
func TestUniswapV3Calculations(t *testing.T) {
|
||||
math := NewUniswapV3Math()
|
||||
|
||||
// Test with reasonable values
|
||||
sqrtPriceX96, _ := new(big.Int).SetString("79228162514264337593543950336", 10) // 2^96, representing price of 1
|
||||
liquidity, _ := new(big.Int).SetString("1000000000000000000", 10) // 1 ETH worth of liquidity
|
||||
|
||||
amountIn, _ := new(big.Int).SetString("100000000000000000", 10) // 0.1 ETH
|
||||
|
||||
result, err := math.CalculateAmountOut(amountIn, sqrtPriceX96, liquidity, 3000)
|
||||
if err != nil {
|
||||
t.Fatalf("CalculateAmountOut failed: %v", err)
|
||||
}
|
||||
|
||||
// With the given parameters, the result should be meaningful
|
||||
if result.Sign() <= 0 {
|
||||
t.Errorf("Expected positive output, got %s", result.String())
|
||||
}
|
||||
|
||||
// Test price impact
|
||||
impact, err := math.CalculatePriceImpact(amountIn, sqrtPriceX96, liquidity)
|
||||
if err != nil {
|
||||
t.Fatalf("CalculatePriceImpact failed: %v", err)
|
||||
}
|
||||
|
||||
if impact < 0 {
|
||||
t.Errorf("Expected non-negative price impact, got %f", impact)
|
||||
}
|
||||
}
|
||||
|
||||
// TestAlgebraV1Calculations tests Algebra V1.9 calculations
|
||||
func TestAlgebraV1Calculations(t *testing.T) {
|
||||
math := NewAlgebraV1Math()
|
||||
|
||||
// Test with reasonable values
|
||||
reserveIn, _ := new(big.Int).SetString("1000000000000000000", 10) // 1 ETH
|
||||
reserveOut, _ := new(big.Int).SetString("2000000000000000000000", 10) // 2000 USDT
|
||||
amountIn, _ := new(big.Int).SetString("100000000000000000", 10) // 0.1 ETH
|
||||
|
||||
result, err := math.CalculateAmountOutAlgebra(amountIn, reserveIn, reserveOut, 500)
|
||||
if err != nil {
|
||||
t.Fatalf("CalculateAmountOutAlgebra failed: %v", err)
|
||||
}
|
||||
|
||||
// With the given parameters, the result should be meaningful
|
||||
if result.Sign() <= 0 {
|
||||
t.Errorf("Expected positive output, got %s", result.String())
|
||||
}
|
||||
|
||||
// Test price impact
|
||||
impact, err := math.CalculatePriceImpactAlgebra(amountIn, reserveIn, reserveOut)
|
||||
if err != nil {
|
||||
t.Fatalf("CalculatePriceImpactAlgebra failed: %v", err)
|
||||
}
|
||||
|
||||
if impact < 0 {
|
||||
t.Errorf("Expected non-negative price impact, got %f", impact)
|
||||
}
|
||||
}
|
||||
|
||||
// TestIntegralCalculations tests Integral calculations
|
||||
func TestIntegralCalculations(t *testing.T) {
|
||||
math := NewIntegralMath()
|
||||
|
||||
// Test with reasonable values
|
||||
reserveIn, _ := new(big.Int).SetString("1000000000000000000", 10) // 1 ETH
|
||||
reserveOut, _ := new(big.Int).SetString("2000000000000000000000", 10) // 2000 USDT
|
||||
amountIn, _ := new(big.Int).SetString("100000000000000000", 10) // 0.1 ETH
|
||||
|
||||
result, err := math.CalculateAmountOutIntegral(amountIn, reserveIn, reserveOut, 100)
|
||||
if err != nil {
|
||||
t.Fatalf("CalculateAmountOutIntegral failed: %v", err)
|
||||
}
|
||||
|
||||
// With the given parameters, the result should be meaningful
|
||||
if result.Sign() <= 0 {
|
||||
t.Errorf("Expected positive output, got %s", result.String())
|
||||
}
|
||||
}
|
||||
|
||||
// TestKyberCalculations tests Kyber calculations
|
||||
func TestKyberCalculations(t *testing.T) {
|
||||
math := &KyberMath{}
|
||||
|
||||
// Test with reasonable values
|
||||
sqrtPriceX96, _ := new(big.Int).SetString("79228162514264337593543950336", 10) // 2^96, representing price of 1
|
||||
liquidity, _ := new(big.Int).SetString("1000000000000000000", 10) // 1 ETH worth of liquidity
|
||||
amountIn, _ := new(big.Int).SetString("100000000000000000", 10) // 0.1 ETH
|
||||
|
||||
result, err := math.CalculateAmountOut(amountIn, sqrtPriceX96, liquidity, 1000)
|
||||
if err != nil {
|
||||
t.Fatalf("CalculateAmountOut failed: %v", err)
|
||||
}
|
||||
|
||||
// With the given parameters, the result should be meaningful
|
||||
if result.Sign() <= 0 {
|
||||
t.Errorf("Expected positive output, got %s", result.String())
|
||||
}
|
||||
|
||||
// Test price impact
|
||||
impact, err := math.CalculatePriceImpact(amountIn, sqrtPriceX96, liquidity)
|
||||
if err != nil {
|
||||
t.Fatalf("CalculatePriceImpact failed: %v", err)
|
||||
}
|
||||
|
||||
if impact < 0 {
|
||||
t.Errorf("Expected non-negative price impact, got %f", impact)
|
||||
}
|
||||
}
|
||||
|
||||
// TestBalancerCalculations tests Balancer calculations
|
||||
func TestBalancerCalculations(t *testing.T) {
|
||||
math := &BalancerMath{}
|
||||
|
||||
// Test with reasonable values for a 50/50 weighted pool
|
||||
reserveIn, _ := new(big.Int).SetString("1000000000000000000", 10) // 1 token with 50% weight
|
||||
reserveOut, _ := new(big.Int).SetString("1000000000000000000", 10) // 1 token with 50% weight
|
||||
amountIn, _ := new(big.Int).SetString("100000000000000000", 10) // 0.1 tokens
|
||||
|
||||
result, err := math.CalculateAmountOut(amountIn, reserveIn, reserveOut, 1000)
|
||||
if err != nil {
|
||||
t.Fatalf("CalculateAmountOut failed: %v", err)
|
||||
}
|
||||
|
||||
// With the given parameters, the result should be meaningful
|
||||
if result.Sign() <= 0 {
|
||||
t.Errorf("Expected positive output, got %s", result.String())
|
||||
}
|
||||
|
||||
// Test price impact
|
||||
impact, err := math.CalculatePriceImpact(amountIn, reserveIn, reserveOut)
|
||||
if err != nil {
|
||||
t.Fatalf("CalculatePriceImpact failed: %v", err)
|
||||
}
|
||||
|
||||
if impact < 0 {
|
||||
t.Errorf("Expected non-negative price impact, got %f", impact)
|
||||
}
|
||||
}
|
||||
|
||||
// TestConstantSumCalculations tests Constant Sum calculations
|
||||
func TestConstantSumCalculations(t *testing.T) {
|
||||
math := &ConstantSumMath{}
|
||||
|
||||
// Test with reasonable values
|
||||
reserveIn, _ := new(big.Int).SetString("1000000000000000000", 10) // 1 token
|
||||
reserveOut, _ := new(big.Int).SetString("1000000000000000000", 10) // 1 token
|
||||
amountIn, _ := new(big.Int).SetString("100000000000000000", 10) // 0.1 tokens
|
||||
|
||||
expected, _ := new(big.Int).SetString("70000000000000000", 10) // 0.1 * 0.7 (30% fees from 3000/10000)
|
||||
result, err := math.CalculateAmountOut(amountIn, reserveIn, reserveOut, 3000)
|
||||
if err != nil {
|
||||
t.Fatalf("CalculateAmountOut failed: %v", err)
|
||||
}
|
||||
|
||||
// In a constant sum AMM, we get approximately 0.1 output with fees
|
||||
if result.Cmp(expected) < 0 {
|
||||
t.Errorf("Expected at least %s, got %s", expected.String(), result.String())
|
||||
}
|
||||
|
||||
// Test price impact (should be 0 in constant sum)
|
||||
impact, err := math.CalculatePriceImpact(amountIn, reserveIn, reserveOut)
|
||||
if err != nil {
|
||||
t.Fatalf("CalculatePriceImpact failed: %v", err)
|
||||
}
|
||||
|
||||
// In constant sum, we expect minimal price impact
|
||||
if impact > 0.001 { // 0.1% tolerance
|
||||
t.Errorf("Expected minimal price impact in constant sum, got %f", impact)
|
||||
}
|
||||
}
|
||||
|
||||
// TestPriceMovementDetection tests functions to detect if swaps move prices
|
||||
func TestPriceMovementDetection(t *testing.T) {
|
||||
// Test WillSwapMovePrice
|
||||
reserveIn, _ := new(big.Int).SetString("1000000000000000000", 10) // 1 ETH
|
||||
reserveOut, _ := new(big.Int).SetString("2000000000000000000000", 10) // 2000 USDT
|
||||
amountIn, _ := new(big.Int).SetString("1000000000000000000", 10) // 1 ETH (50% of reserve!)
|
||||
|
||||
// This large swap should definitely move the price
|
||||
movesPrice, impact, err := WillSwapMovePrice(amountIn, reserveIn, reserveOut, 0.01) // 1% threshold
|
||||
if err != nil {
|
||||
t.Fatalf("WillSwapMovePrice failed: %v", err)
|
||||
}
|
||||
|
||||
if !movesPrice {
|
||||
t.Errorf("Expected large swap to move price, but it didn't (impact: %f)", impact)
|
||||
}
|
||||
|
||||
if impact <= 0 {
|
||||
t.Errorf("Expected positive impact, got %f", impact)
|
||||
}
|
||||
|
||||
// Test with a smaller swap that shouldn't move price much
|
||||
smallAmount, _ := new(big.Int).SetString("10000000000000000", 10) // 0.01 ETH
|
||||
movesPrice, impact, err = WillSwapMovePrice(smallAmount, reserveIn, reserveOut, 0.10) // 10% threshold
|
||||
if err != nil {
|
||||
t.Fatalf("WillSwapMovePrice failed: %v", err)
|
||||
}
|
||||
|
||||
if movesPrice {
|
||||
t.Errorf("Expected small swap to not move price significantly, but it did (impact: %f)", impact)
|
||||
}
|
||||
}
|
||||
|
||||
// TestLiquidityMovementDetection tests functions to detect if liquidity changes move prices
|
||||
func TestLiquidityMovementDetection(t *testing.T) {
|
||||
reserve0, _ := new(big.Int).SetString("1000000000000000000", 10) // 1 ETH
|
||||
reserve1, _ := new(big.Int).SetString("2000000000000000000000", 10) // 2000 USDT
|
||||
|
||||
// Add significant liquidity (10% of reserves)
|
||||
amount0, _ := new(big.Int).SetString("100000000000000000", 10) // 0.1 ETH
|
||||
amount1, _ := new(big.Int).SetString("200000000000000000000", 10) // 200 USDT
|
||||
|
||||
movesPrice, impact, err := WillLiquidityMovePrice(amount0, amount1, reserve0, reserve1, 0.01) // 1% threshold
|
||||
if err != nil {
|
||||
t.Fatalf("WillLiquidityMovePrice failed: %v", err)
|
||||
}
|
||||
|
||||
// Adding balanced liquidity shouldn't significantly move price
|
||||
if movesPrice {
|
||||
t.Errorf("Expected balanced liquidity addition to not move price significantly, but it did (impact: %f)", impact)
|
||||
}
|
||||
|
||||
// Now test with unbalanced liquidity removal
|
||||
amount0, _ = new(big.Int).SetString("-500000000000000000", 10) // Remove 0.5 ETH
|
||||
amount1 = big.NewInt(0) // Don't change USDT
|
||||
|
||||
movesPrice, impact, err = WillLiquidityMovePrice(amount0, amount1, reserve0, reserve1, 0.01) // 1% threshold
|
||||
if err != nil {
|
||||
t.Fatalf("WillLiquidityMovePrice failed: %v", err)
|
||||
}
|
||||
|
||||
// Removing only one side of liquidity should move the price
|
||||
if !movesPrice {
|
||||
t.Errorf("Expected unbalanced liquidity removal to move price, but it didn't (impact: %f)", impact)
|
||||
}
|
||||
}
|
||||
|
||||
// TestPriceImpactCalculator tests the unified price impact calculator
|
||||
func TestPriceImpactCalculator(t *testing.T) {
|
||||
calculator := NewPriceImpactCalculator()
|
||||
|
||||
// Test Uniswap V2
|
||||
reserveIn, _ := new(big.Int).SetString("1000000000000000000", 10)
|
||||
reserveOut, _ := new(big.Int).SetString("2000000000000000000000", 10)
|
||||
amountIn, _ := new(big.Int).SetString("100000000000000000", 10)
|
||||
|
||||
impact, err := calculator.CalculatePriceImpact("uniswap_v2", amountIn, reserveIn, reserveOut, nil, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("CalculatePriceImpact failed for uniswap_v2: %v", err)
|
||||
}
|
||||
|
||||
if impact <= 0 {
|
||||
t.Errorf("Expected positive price impact, got %f", impact)
|
||||
}
|
||||
|
||||
// Test with threshold
|
||||
movesPrice, impact, err := calculator.CalculatePriceMovementThreshold("uniswap_v2", amountIn, reserveIn, reserveOut, nil, nil, 0.01)
|
||||
if err != nil {
|
||||
t.Fatalf("CalculatePriceMovementThreshold failed: %v", err)
|
||||
}
|
||||
|
||||
if !movesPrice {
|
||||
t.Errorf("Expected to move price above threshold, but it didn't (impact: %f)", impact)
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkUniswapV2Calculations benchmarks Uniswap V2 calculations
|
||||
func BenchmarkUniswapV2Calculations(b *testing.B) {
|
||||
math := NewUniswapV2Math()
|
||||
reserveIn, _ := new(big.Int).SetString("1000000000000000000", 10)
|
||||
reserveOut, _ := new(big.Int).SetString("2000000000000000000000", 10)
|
||||
amountIn, _ := new(big.Int).SetString("100000000000000000", 10)
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, _ = math.CalculateAmountOut(amountIn, reserveIn, reserveOut, 3000)
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkCurveCalculations benchmarks Curve calculations
|
||||
func BenchmarkCurveCalculations(b *testing.B) {
|
||||
math := NewCurveMath()
|
||||
balance0, _ := new(big.Int).SetString("1000000000000000000", 10)
|
||||
balance1, _ := new(big.Int).SetString("1000000000000000000", 10)
|
||||
amountIn, _ := new(big.Int).SetString("100000000000000000", 10)
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, _ = math.CalculateAmountOut(amountIn, balance0, balance1, 400)
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkUniswapV3Calculations benchmarks Uniswap V3 calculations
|
||||
func BenchmarkUniswapV3Calculations(b *testing.B) {
|
||||
math := NewUniswapV3Math()
|
||||
sqrtPriceX96, _ := new(big.Int).SetString("79228162514264337593543950336", 10)
|
||||
liquidity, _ := new(big.Int).SetString("1000000000000000000", 10)
|
||||
amountIn, _ := new(big.Int).SetString("100000000000000000", 10)
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, _ = math.CalculateAmountOut(amountIn, sqrtPriceX96, liquidity, 3000)
|
||||
}
|
||||
}
|
||||
1010
orig/pkg/math/exchange_math.go
Normal file
1010
orig/pkg/math/exchange_math.go
Normal file
File diff suppressed because it is too large
Load Diff
522
orig/pkg/math/exchange_pricing.go
Normal file
522
orig/pkg/math/exchange_pricing.go
Normal file
@@ -0,0 +1,522 @@
|
||||
package math
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math/big"
|
||||
)
|
||||
|
||||
// ExchangeType represents different DEX protocols on Arbitrum
|
||||
type ExchangeType string
|
||||
|
||||
const (
|
||||
ExchangeUniswapV3 ExchangeType = "uniswap_v3"
|
||||
ExchangeUniswapV2 ExchangeType = "uniswap_v2"
|
||||
ExchangeSushiSwap ExchangeType = "sushiswap"
|
||||
ExchangeCamelot ExchangeType = "camelot"
|
||||
ExchangeBalancer ExchangeType = "balancer"
|
||||
ExchangeTraderJoe ExchangeType = "traderjoe"
|
||||
ExchangeRamses ExchangeType = "ramses"
|
||||
ExchangeCurve ExchangeType = "curve"
|
||||
ExchangeKyber ExchangeType = "kyber"
|
||||
ExchangeUniswapV4 ExchangeType = "uniswap_v4"
|
||||
)
|
||||
|
||||
// ExchangePricer interface for exchange-specific price calculations
|
||||
type ExchangePricer interface {
|
||||
GetSpotPrice(poolData *PoolData) (*UniversalDecimal, error)
|
||||
CalculateAmountOut(amountIn *UniversalDecimal, poolData *PoolData) (*UniversalDecimal, error)
|
||||
CalculateAmountIn(amountOut *UniversalDecimal, poolData *PoolData) (*UniversalDecimal, error)
|
||||
CalculatePriceImpact(amountIn *UniversalDecimal, poolData *PoolData) (*UniversalDecimal, error)
|
||||
GetMinimumLiquidity(poolData *PoolData) (*UniversalDecimal, error)
|
||||
ValidatePoolData(poolData *PoolData) error
|
||||
}
|
||||
|
||||
// PoolData represents universal pool data structure
|
||||
type PoolData struct {
|
||||
Address string
|
||||
ExchangeType ExchangeType
|
||||
Token0 TokenInfo
|
||||
Token1 TokenInfo
|
||||
Reserve0 *UniversalDecimal
|
||||
Reserve1 *UniversalDecimal
|
||||
Fee *UniversalDecimal // Fee as percentage (e.g., 0.003 for 0.3%)
|
||||
|
||||
// Uniswap V3 specific
|
||||
SqrtPriceX96 *big.Int
|
||||
Tick *big.Int
|
||||
Liquidity *big.Int
|
||||
|
||||
// Curve specific
|
||||
A *big.Int // Amplification coefficient
|
||||
|
||||
// Balancer specific
|
||||
Weights []*UniversalDecimal // Token weights
|
||||
SwapFeeRate *UniversalDecimal // Swap fee rate
|
||||
}
|
||||
|
||||
// TokenInfo represents token metadata
|
||||
type TokenInfo struct {
|
||||
Address string
|
||||
Symbol string
|
||||
Decimals uint8
|
||||
}
|
||||
|
||||
// ExchangePricingEngine manages all exchange-specific pricing logic
|
||||
type ExchangePricingEngine struct {
|
||||
decimalConverter *DecimalConverter
|
||||
pricers map[ExchangeType]ExchangePricer
|
||||
}
|
||||
|
||||
// NewExchangePricingEngine creates a new pricing engine with all exchange support
|
||||
func NewExchangePricingEngine() *ExchangePricingEngine {
|
||||
dc := NewDecimalConverter()
|
||||
|
||||
engine := &ExchangePricingEngine{
|
||||
decimalConverter: dc,
|
||||
pricers: make(map[ExchangeType]ExchangePricer),
|
||||
}
|
||||
|
||||
// Register all exchange pricers
|
||||
engine.pricers[ExchangeUniswapV3] = NewUniswapV3Pricer(dc)
|
||||
engine.pricers[ExchangeUniswapV2] = NewUniswapV2Pricer(dc)
|
||||
engine.pricers[ExchangeSushiSwap] = NewSushiSwapPricer(dc)
|
||||
engine.pricers[ExchangeCamelot] = NewCamelotPricer(dc)
|
||||
engine.pricers[ExchangeBalancer] = NewBalancerPricer(dc)
|
||||
engine.pricers[ExchangeTraderJoe] = NewTraderJoePricer(dc)
|
||||
engine.pricers[ExchangeRamses] = NewRamsesPricer(dc)
|
||||
engine.pricers[ExchangeCurve] = NewCurvePricer(dc)
|
||||
|
||||
return engine
|
||||
}
|
||||
|
||||
// GetExchangePricer returns the appropriate pricer for an exchange
|
||||
func (engine *ExchangePricingEngine) GetExchangePricer(exchangeType ExchangeType) (ExchangePricer, error) {
|
||||
pricer, exists := engine.pricers[exchangeType]
|
||||
if !exists {
|
||||
return nil, fmt.Errorf("unsupported exchange type: %s", exchangeType)
|
||||
}
|
||||
return pricer, nil
|
||||
}
|
||||
|
||||
// CalculateSpotPrice gets spot price from any exchange
|
||||
func (engine *ExchangePricingEngine) CalculateSpotPrice(poolData *PoolData) (*UniversalDecimal, error) {
|
||||
pricer, err := engine.GetExchangePricer(poolData.ExchangeType)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return pricer.GetSpotPrice(poolData)
|
||||
}
|
||||
|
||||
// UniswapV3Pricer implements Uniswap V3 concentrated liquidity pricing
|
||||
type UniswapV3Pricer struct {
|
||||
dc *DecimalConverter
|
||||
}
|
||||
|
||||
func NewUniswapV3Pricer(dc *DecimalConverter) *UniswapV3Pricer {
|
||||
return &UniswapV3Pricer{dc: dc}
|
||||
}
|
||||
|
||||
func (p *UniswapV3Pricer) GetSpotPrice(poolData *PoolData) (*UniversalDecimal, error) {
|
||||
if poolData.SqrtPriceX96 == nil {
|
||||
return nil, fmt.Errorf("missing sqrtPriceX96 for Uniswap V3 pool")
|
||||
}
|
||||
|
||||
// Use cached function for optimized calculation
|
||||
// Convert sqrtPriceX96 to actual price using cached constants
|
||||
// price = sqrtPriceX96^2 / 2^192
|
||||
price := SqrtPriceX96ToPriceCached(poolData.SqrtPriceX96)
|
||||
|
||||
// Adjust for decimal differences between tokens
|
||||
if poolData.Token0.Decimals != poolData.Token1.Decimals {
|
||||
decimalDiff := int(poolData.Token1.Decimals) - int(poolData.Token0.Decimals)
|
||||
adjustment := new(big.Float).SetInt(new(big.Int).Exp(big.NewInt(10), big.NewInt(int64(decimalDiff)), nil))
|
||||
price.Mul(price, adjustment)
|
||||
}
|
||||
|
||||
// Convert back to big.Int with appropriate precision
|
||||
priceInt := new(big.Int)
|
||||
priceScaled := new(big.Float).Mul(price, new(big.Float).SetInt(p.dc.getScalingFactor(18)))
|
||||
priceScaled.Int(priceInt)
|
||||
|
||||
return NewUniversalDecimal(priceInt, 18, fmt.Sprintf("%s/%s", poolData.Token1.Symbol, poolData.Token0.Symbol))
|
||||
}
|
||||
|
||||
func (p *UniswapV3Pricer) CalculateAmountOut(amountIn *UniversalDecimal, poolData *PoolData) (*UniversalDecimal, error) {
|
||||
// Uniswap V3 concentrated liquidity calculation
|
||||
// This is a simplified version - production would need full tick math
|
||||
|
||||
if poolData.Liquidity == nil || poolData.Liquidity.Sign() == 0 {
|
||||
return nil, fmt.Errorf("insufficient liquidity in Uniswap V3 pool")
|
||||
}
|
||||
|
||||
// Apply fee
|
||||
feeAmount, err := p.dc.Multiply(amountIn, poolData.Fee, amountIn.Decimals, "FEE")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error calculating fee: %w", err)
|
||||
}
|
||||
|
||||
amountInAfterFee, err := p.dc.Subtract(amountIn, feeAmount)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error subtracting fee: %w", err)
|
||||
}
|
||||
|
||||
// Simplified constant product formula for demonstration
|
||||
// Real implementation would use tick mathematics
|
||||
numerator, err := p.dc.Multiply(amountInAfterFee, poolData.Reserve1, poolData.Reserve1.Decimals, "TEMP")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
denominator, err := p.dc.Add(poolData.Reserve0, amountInAfterFee)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return p.dc.Divide(numerator, denominator, poolData.Token1.Decimals, poolData.Token1.Symbol)
|
||||
}
|
||||
|
||||
func (p *UniswapV3Pricer) CalculateAmountIn(amountOut *UniversalDecimal, poolData *PoolData) (*UniversalDecimal, error) {
|
||||
// Reverse calculation for Uniswap V3
|
||||
if poolData.Reserve1.IsZero() || amountOut.Value.Cmp(poolData.Reserve1.Value) >= 0 {
|
||||
return nil, fmt.Errorf("insufficient liquidity for requested output amount")
|
||||
}
|
||||
|
||||
// Simplified reverse calculation
|
||||
numerator, err := p.dc.Multiply(poolData.Reserve0, amountOut, poolData.Reserve0.Decimals, "TEMP")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
denominator, err := p.dc.Subtract(poolData.Reserve1, amountOut)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
amountInBeforeFee, err := p.dc.Divide(numerator, denominator, poolData.Token0.Decimals, "TEMP")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Add fee
|
||||
feeMultiplier, err := p.dc.FromString("1", 18, "FEE_MULT")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
oneMinusFee, err := p.dc.Subtract(feeMultiplier, poolData.Fee)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return p.dc.Divide(amountInBeforeFee, oneMinusFee, poolData.Token0.Decimals, poolData.Token0.Symbol)
|
||||
}
|
||||
|
||||
func (p *UniswapV3Pricer) CalculatePriceImpact(amountIn *UniversalDecimal, poolData *PoolData) (*UniversalDecimal, error) {
|
||||
// Calculate price before trade
|
||||
priceBefore, err := p.GetSpotPrice(poolData)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error getting spot price: %w", err)
|
||||
}
|
||||
|
||||
// Calculate amount out
|
||||
amountOut, err := p.CalculateAmountOut(amountIn, poolData)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error calculating amount out: %w", err)
|
||||
}
|
||||
|
||||
// Calculate effective price
|
||||
effectivePrice, err := p.dc.Divide(amountOut, amountIn, 18, "EFFECTIVE_PRICE")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error calculating effective price: %w", err)
|
||||
}
|
||||
|
||||
// Calculate price impact as percentage
|
||||
priceDiff, err := p.dc.Subtract(priceBefore, effectivePrice)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error calculating price difference: %w", err)
|
||||
}
|
||||
|
||||
return p.dc.CalculatePercentage(priceDiff, priceBefore)
|
||||
}
|
||||
|
||||
func (p *UniswapV3Pricer) GetMinimumLiquidity(poolData *PoolData) (*UniversalDecimal, error) {
|
||||
if poolData.Liquidity == nil {
|
||||
return NewUniversalDecimal(big.NewInt(0), 18, "LIQUIDITY")
|
||||
}
|
||||
|
||||
return NewUniversalDecimal(poolData.Liquidity, 18, "LIQUIDITY")
|
||||
}
|
||||
|
||||
func (p *UniswapV3Pricer) ValidatePoolData(poolData *PoolData) error {
|
||||
if poolData.SqrtPriceX96 == nil {
|
||||
return fmt.Errorf("Uniswap V3 pool missing sqrtPriceX96")
|
||||
}
|
||||
if poolData.Liquidity == nil {
|
||||
return fmt.Errorf("Uniswap V3 pool missing liquidity")
|
||||
}
|
||||
if poolData.Fee == nil {
|
||||
return fmt.Errorf("Uniswap V3 pool missing fee")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// UniswapV2Pricer implements Uniswap V2 / SushiSwap constant product pricing
|
||||
type UniswapV2Pricer struct {
|
||||
dc *DecimalConverter
|
||||
}
|
||||
|
||||
func NewUniswapV2Pricer(dc *DecimalConverter) *UniswapV2Pricer {
|
||||
return &UniswapV2Pricer{dc: dc}
|
||||
}
|
||||
|
||||
func (p *UniswapV2Pricer) GetSpotPrice(poolData *PoolData) (*UniversalDecimal, error) {
|
||||
if poolData.Reserve0.IsZero() {
|
||||
return nil, fmt.Errorf("zero reserve0 in constant product pool")
|
||||
}
|
||||
|
||||
return p.dc.Divide(poolData.Reserve1, poolData.Reserve0, 18, fmt.Sprintf("%s/%s", poolData.Token1.Symbol, poolData.Token0.Symbol))
|
||||
}
|
||||
|
||||
func (p *UniswapV2Pricer) CalculateAmountOut(amountIn *UniversalDecimal, poolData *PoolData) (*UniversalDecimal, error) {
|
||||
// Uniswap V2 constant product formula: x * y = k
|
||||
// amountOut = (amountIn * 997 * reserveOut) / (reserveIn * 1000 + amountIn * 997)
|
||||
|
||||
// Apply fee (0.3% = 997/1000 remaining)
|
||||
feeNumerator, _ := p.dc.FromString("997", 0, "FEE_NUM")
|
||||
feeDenominator, _ := p.dc.FromString("1000", 0, "FEE_DEN")
|
||||
|
||||
amountInWithFee, err := p.dc.Multiply(amountIn, feeNumerator, amountIn.Decimals, "TEMP")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
numerator, err := p.dc.Multiply(amountInWithFee, poolData.Reserve1, poolData.Reserve1.Decimals, "TEMP")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
reserveInScaled, err := p.dc.Multiply(poolData.Reserve0, feeDenominator, poolData.Reserve0.Decimals, "TEMP")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
denominator, err := p.dc.Add(reserveInScaled, amountInWithFee)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return p.dc.Divide(numerator, denominator, poolData.Token1.Decimals, poolData.Token1.Symbol)
|
||||
}
|
||||
|
||||
func (p *UniswapV2Pricer) CalculateAmountIn(amountOut *UniversalDecimal, poolData *PoolData) (*UniversalDecimal, error) {
|
||||
// Reverse calculation for constant product
|
||||
feeNumerator, _ := p.dc.FromString("1000", 0, "FEE_NUM")
|
||||
feeDenominator, _ := p.dc.FromString("997", 0, "FEE_DEN")
|
||||
|
||||
numerator, err := p.dc.Multiply(poolData.Reserve0, amountOut, poolData.Reserve0.Decimals, "TEMP")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
numeratorWithFee, err := p.dc.Multiply(numerator, feeNumerator, numerator.Decimals, "TEMP")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
denominator, err := p.dc.Subtract(poolData.Reserve1, amountOut)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
denominatorWithFee, err := p.dc.Multiply(denominator, feeDenominator, denominator.Decimals, "TEMP")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return p.dc.Divide(numeratorWithFee, denominatorWithFee, poolData.Token0.Decimals, poolData.Token0.Symbol)
|
||||
}
|
||||
|
||||
func (p *UniswapV2Pricer) CalculatePriceImpact(amountIn *UniversalDecimal, poolData *PoolData) (*UniversalDecimal, error) {
|
||||
// Similar to Uniswap V3 implementation
|
||||
priceBefore, err := p.GetSpotPrice(poolData)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
amountOut, err := p.CalculateAmountOut(amountIn, poolData)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
effectivePrice, err := p.dc.Divide(amountOut, amountIn, 18, "EFFECTIVE_PRICE")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
priceDiff, err := p.dc.Subtract(priceBefore, effectivePrice)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return p.dc.CalculatePercentage(priceDiff, priceBefore)
|
||||
}
|
||||
|
||||
func (p *UniswapV2Pricer) GetMinimumLiquidity(poolData *PoolData) (*UniversalDecimal, error) {
|
||||
// Geometric mean of reserves
|
||||
product, err := p.dc.Multiply(poolData.Reserve0, poolData.Reserve1, 18, "LIQUIDITY")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Simplified square root - in production use precise sqrt algorithm
|
||||
sqrt := new(big.Int).Sqrt(product.Value)
|
||||
return NewUniversalDecimal(sqrt, 18, "LIQUIDITY")
|
||||
}
|
||||
|
||||
func (p *UniswapV2Pricer) ValidatePoolData(poolData *PoolData) error {
|
||||
if poolData.Reserve0 == nil || poolData.Reserve1 == nil {
|
||||
return fmt.Errorf("missing reserves for constant product pool")
|
||||
}
|
||||
if poolData.Reserve0.IsZero() || poolData.Reserve1.IsZero() {
|
||||
return fmt.Errorf("zero reserves in constant product pool")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// SushiSwapPricer uses same logic as Uniswap V2
|
||||
type SushiSwapPricer struct {
|
||||
*UniswapV2Pricer
|
||||
}
|
||||
|
||||
func NewSushiSwapPricer(dc *DecimalConverter) *SushiSwapPricer {
|
||||
return &SushiSwapPricer{NewUniswapV2Pricer(dc)}
|
||||
}
|
||||
|
||||
// CamelotPricer - Algebra-based DEX on Arbitrum
|
||||
type CamelotPricer struct {
|
||||
*UniswapV3Pricer
|
||||
}
|
||||
|
||||
func NewCamelotPricer(dc *DecimalConverter) *CamelotPricer {
|
||||
return &CamelotPricer{NewUniswapV3Pricer(dc)}
|
||||
}
|
||||
|
||||
// BalancerPricer - Weighted pool implementation
|
||||
type BalancerPricer struct {
|
||||
dc *DecimalConverter
|
||||
}
|
||||
|
||||
func NewBalancerPricer(dc *DecimalConverter) *BalancerPricer {
|
||||
return &BalancerPricer{dc: dc}
|
||||
}
|
||||
|
||||
func (p *BalancerPricer) GetSpotPrice(poolData *PoolData) (*UniversalDecimal, error) {
|
||||
if len(poolData.Weights) < 2 {
|
||||
return nil, fmt.Errorf("insufficient weights for Balancer pool")
|
||||
}
|
||||
|
||||
// Balancer spot price = (reserveOut/weightOut) / (reserveIn/weightIn)
|
||||
reserveOutWeighted, err := p.dc.Divide(poolData.Reserve1, poolData.Weights[1], 18, "TEMP")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
reserveInWeighted, err := p.dc.Divide(poolData.Reserve0, poolData.Weights[0], 18, "TEMP")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return p.dc.Divide(reserveOutWeighted, reserveInWeighted, 18, fmt.Sprintf("%s/%s", poolData.Token1.Symbol, poolData.Token0.Symbol))
|
||||
}
|
||||
|
||||
func (p *BalancerPricer) CalculateAmountOut(amountIn *UniversalDecimal, poolData *PoolData) (*UniversalDecimal, error) {
|
||||
// Simplified Balancer calculation - production needs full weighted math
|
||||
return p.GetSpotPrice(poolData)
|
||||
}
|
||||
|
||||
func (p *BalancerPricer) CalculateAmountIn(amountOut *UniversalDecimal, poolData *PoolData) (*UniversalDecimal, error) {
|
||||
spotPrice, err := p.GetSpotPrice(poolData)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return p.dc.Divide(amountOut, spotPrice, poolData.Token0.Decimals, poolData.Token0.Symbol)
|
||||
}
|
||||
|
||||
func (p *BalancerPricer) CalculatePriceImpact(amountIn *UniversalDecimal, poolData *PoolData) (*UniversalDecimal, error) {
|
||||
// Placeholder - would implement Balancer-specific price impact
|
||||
return NewUniversalDecimal(big.NewInt(0), 4, "PERCENT")
|
||||
}
|
||||
|
||||
func (p *BalancerPricer) GetMinimumLiquidity(poolData *PoolData) (*UniversalDecimal, error) {
|
||||
return NewUniversalDecimal(big.NewInt(0), 18, "LIQUIDITY")
|
||||
}
|
||||
|
||||
func (p *BalancerPricer) ValidatePoolData(poolData *PoolData) error {
|
||||
if len(poolData.Weights) < 2 {
|
||||
return fmt.Errorf("Balancer pool missing weights")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Placeholder implementations for other exchanges
|
||||
func NewTraderJoePricer(dc *DecimalConverter) *UniswapV2Pricer { return NewUniswapV2Pricer(dc) }
|
||||
func NewRamsesPricer(dc *DecimalConverter) *UniswapV3Pricer { return NewUniswapV3Pricer(dc) }
|
||||
|
||||
// CurvePricer - Stable swap implementation
|
||||
type CurvePricer struct {
|
||||
dc *DecimalConverter
|
||||
}
|
||||
|
||||
func NewCurvePricer(dc *DecimalConverter) *CurvePricer {
|
||||
return &CurvePricer{dc: dc}
|
||||
}
|
||||
|
||||
func (p *CurvePricer) GetSpotPrice(poolData *PoolData) (*UniversalDecimal, error) {
|
||||
// Curve stable swap pricing - simplified version
|
||||
if poolData.A == nil {
|
||||
return nil, fmt.Errorf("missing amplification coefficient for Curve pool")
|
||||
}
|
||||
|
||||
// For stable swaps, price should be close to 1:1
|
||||
return NewUniversalDecimal(p.dc.getScalingFactor(18), 18, fmt.Sprintf("%s/%s", poolData.Token1.Symbol, poolData.Token0.Symbol))
|
||||
}
|
||||
|
||||
func (p *CurvePricer) CalculateAmountOut(amountIn *UniversalDecimal, poolData *PoolData) (*UniversalDecimal, error) {
|
||||
// Simplified stable swap calculation
|
||||
// Real implementation would use Newton's method for stable swap invariant
|
||||
feeAmount, err := p.dc.Multiply(amountIn, poolData.Fee, amountIn.Decimals, "FEE")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return p.dc.Subtract(amountIn, feeAmount)
|
||||
}
|
||||
|
||||
func (p *CurvePricer) CalculateAmountIn(amountOut *UniversalDecimal, poolData *PoolData) (*UniversalDecimal, error) {
|
||||
// Reverse stable swap calculation
|
||||
feeMultiplier, _ := p.dc.FromString("1", 18, "FEE_MULT")
|
||||
oneMinusFee, err := p.dc.Subtract(feeMultiplier, poolData.Fee)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return p.dc.Divide(amountOut, oneMinusFee, poolData.Token0.Decimals, poolData.Token0.Symbol)
|
||||
}
|
||||
|
||||
func (p *CurvePricer) CalculatePriceImpact(amountIn *UniversalDecimal, poolData *PoolData) (*UniversalDecimal, error) {
|
||||
// Curve pools have minimal price impact for stable pairs
|
||||
return NewUniversalDecimal(big.NewInt(1000), 4, "PERCENT") // 0.1%
|
||||
}
|
||||
|
||||
func (p *CurvePricer) GetMinimumLiquidity(poolData *PoolData) (*UniversalDecimal, error) {
|
||||
return NewUniversalDecimal(big.NewInt(0), 18, "LIQUIDITY")
|
||||
}
|
||||
|
||||
func (p *CurvePricer) ValidatePoolData(poolData *PoolData) error {
|
||||
if poolData.A == nil {
|
||||
return fmt.Errorf("Curve pool missing amplification coefficient")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
BIN
orig/pkg/math/math.test
Executable file
BIN
orig/pkg/math/math.test
Executable file
Binary file not shown.
62
orig/pkg/math/mock_gas_estimator.go
Normal file
62
orig/pkg/math/mock_gas_estimator.go
Normal file
@@ -0,0 +1,62 @@
|
||||
package math
|
||||
|
||||
import (
|
||||
"math/big"
|
||||
)
|
||||
|
||||
// MockGasEstimator implements GasEstimator for testing purposes
|
||||
type MockGasEstimator struct {
|
||||
currentGasPrice *UniversalDecimal
|
||||
}
|
||||
|
||||
// NewMockGasEstimator creates a new mock gas estimator
|
||||
func NewMockGasEstimator() *MockGasEstimator {
|
||||
dc := NewDecimalConverter()
|
||||
gasPrice, _ := dc.FromString("20", 9, "GWEI") // 20 gwei default
|
||||
|
||||
return &MockGasEstimator{
|
||||
currentGasPrice: gasPrice,
|
||||
}
|
||||
}
|
||||
|
||||
// EstimateSwapGas estimates gas for a swap
|
||||
func (mge *MockGasEstimator) EstimateSwapGas(exchange string, poolData *PoolData) (uint64, error) {
|
||||
// Different exchanges have different gas costs
|
||||
switch exchange {
|
||||
case "uniswap_v3":
|
||||
return 150000, nil // 150k gas for Uniswap V3
|
||||
case "uniswap_v2", "sushiswap":
|
||||
return 120000, nil // 120k gas for Uniswap V2/SushiSwap
|
||||
case "camelot":
|
||||
return 130000, nil // 130k gas for Camelot
|
||||
case "balancer":
|
||||
return 200000, nil // 200k gas for Balancer
|
||||
case "curve":
|
||||
return 180000, nil // 180k gas for Curve
|
||||
default:
|
||||
return 150000, nil // Default to 150k gas
|
||||
}
|
||||
}
|
||||
|
||||
// EstimateFlashSwapGas estimates gas for a flash swap
|
||||
func (mge *MockGasEstimator) EstimateFlashSwapGas(route []*PoolData) (uint64, error) {
|
||||
// Flash swap overhead varies by complexity
|
||||
baseGas := uint64(200000) // Base flash swap overhead
|
||||
|
||||
// Add gas for each hop
|
||||
hopGas := uint64(len(route)) * 50000
|
||||
|
||||
return baseGas + hopGas, nil
|
||||
}
|
||||
|
||||
// GetCurrentGasPrice gets the current gas price
|
||||
func (mge *MockGasEstimator) GetCurrentGasPrice() (*UniversalDecimal, error) {
|
||||
return mge.currentGasPrice, nil
|
||||
}
|
||||
|
||||
// SetCurrentGasPrice sets the current gas price for testing
|
||||
func (mge *MockGasEstimator) SetCurrentGasPrice(gasPriceGwei int64) {
|
||||
dc := NewDecimalConverter()
|
||||
gasPrice, _ := dc.FromString(big.NewInt(gasPriceGwei).String(), 9, "GWEI")
|
||||
mge.currentGasPrice = gasPrice
|
||||
}
|
||||
293
orig/pkg/math/precision_test.go
Normal file
293
orig/pkg/math/precision_test.go
Normal file
@@ -0,0 +1,293 @@
|
||||
package math
|
||||
|
||||
import (
|
||||
"math/big"
|
||||
"math/rand"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// TestDecimalPrecisionPreservation tests that decimal operations preserve precision
|
||||
func TestDecimalPrecisionPreservation(t *testing.T) {
|
||||
dc := NewDecimalConverter()
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
value string
|
||||
decimals uint8
|
||||
symbol string
|
||||
}{
|
||||
{"ETH precision", "1000000000000000000", 18, "ETH"}, // 1 ETH
|
||||
{"USDC precision", "1000000", 6, "USDC"}, // 1 USDC
|
||||
{"WBTC precision", "100000000", 8, "WBTC"}, // 1 WBTC
|
||||
{"Small amount", "1", 18, "ETH"}, // 1 wei
|
||||
{"Large amount", "1000000000000000000000", 18, "ETH"}, // 1000 ETH
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
// Create decimal from string
|
||||
decimal, err := dc.FromString(tc.value, tc.decimals, tc.symbol)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create decimal: %v", err)
|
||||
}
|
||||
|
||||
// Convert to string and back
|
||||
humanReadable := dc.ToHumanReadable(decimal)
|
||||
backToDecimal, err := dc.FromString(humanReadable, tc.decimals, tc.symbol)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to convert back from string: %v", err)
|
||||
}
|
||||
|
||||
// Compare values
|
||||
if decimal.Value.Cmp(backToDecimal.Value) != 0 {
|
||||
t.Errorf("Precision lost in round-trip conversion")
|
||||
t.Errorf("Original: %s", decimal.Value.String())
|
||||
t.Errorf("Round-trip: %s", backToDecimal.Value.String())
|
||||
t.Errorf("Human readable: %s", humanReadable)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestArithmeticOperations tests basic arithmetic with different decimal precisions
|
||||
func TestArithmeticOperations(t *testing.T) {
|
||||
dc := NewDecimalConverter()
|
||||
|
||||
// Create test values with different precisions
|
||||
eth1, _ := dc.FromString("1000000000000000000", 18, "ETH") // 1 ETH
|
||||
eth2, _ := dc.FromString("2000000000000000000", 18, "ETH") // 2 ETH
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
op string
|
||||
a, b *UniversalDecimal
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "ETH addition",
|
||||
op: "add",
|
||||
a: eth1,
|
||||
b: eth2,
|
||||
expected: "3000000000000000000", // 3 ETH
|
||||
},
|
||||
{
|
||||
name: "ETH subtraction",
|
||||
op: "sub",
|
||||
a: eth2,
|
||||
b: eth1,
|
||||
expected: "1000000000000000000", // 1 ETH
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
var result *UniversalDecimal
|
||||
var err error
|
||||
|
||||
switch test.op {
|
||||
case "add":
|
||||
result, err = dc.Add(test.a, test.b)
|
||||
case "sub":
|
||||
result, err = dc.Subtract(test.a, test.b)
|
||||
default:
|
||||
t.Fatalf("Unknown operation: %s", test.op)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("Operation failed: %v", err)
|
||||
}
|
||||
|
||||
if result.Value.String() != test.expected {
|
||||
t.Errorf("Expected %s, got %s", test.expected, result.Value.String())
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestPercentageCalculations tests percentage calculations for precision
|
||||
func TestPercentageCalculations(t *testing.T) {
|
||||
dc := NewDecimalConverter()
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
numerator string
|
||||
denominator string
|
||||
decimals uint8
|
||||
expectedRange [2]float64 // [min, max] acceptable range
|
||||
}{
|
||||
{
|
||||
name: "1% calculation",
|
||||
numerator: "10000000000000000", // 0.01 ETH
|
||||
denominator: "1000000000000000000", // 1 ETH
|
||||
decimals: 18,
|
||||
expectedRange: [2]float64{0.99, 1.01}, // 1% ± 0.01%
|
||||
},
|
||||
{
|
||||
name: "50% calculation",
|
||||
numerator: "500000000000000000", // 0.5 ETH
|
||||
denominator: "1000000000000000000", // 1 ETH
|
||||
decimals: 18,
|
||||
expectedRange: [2]float64{49.9, 50.1}, // 50% ± 0.1%
|
||||
},
|
||||
{
|
||||
name: "Small percentage",
|
||||
numerator: "1000000000000000", // 0.001 ETH
|
||||
denominator: "1000000000000000000", // 1 ETH
|
||||
decimals: 18,
|
||||
expectedRange: [2]float64{0.099, 0.101}, // 0.1% ± 0.001%
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
num, err := dc.FromString(tc.numerator, tc.decimals, "ETH")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create numerator: %v", err)
|
||||
}
|
||||
|
||||
denom, err := dc.FromString(tc.denominator, tc.decimals, "ETH")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create denominator: %v", err)
|
||||
}
|
||||
|
||||
percentage, err := dc.CalculatePercentage(num, denom)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to calculate percentage: %v", err)
|
||||
}
|
||||
|
||||
percentageFloat, _ := percentage.Value.Float64()
|
||||
|
||||
// Convert from raw value to actual percentage (divide by 10^decimals)
|
||||
// Since percentage has 4 decimals, divide by 10000 to get actual percentage value
|
||||
actualPercentage := percentageFloat / 10000.0
|
||||
|
||||
t.Logf("Calculated percentage: %.6f%%", actualPercentage)
|
||||
|
||||
if actualPercentage < tc.expectedRange[0] || actualPercentage > tc.expectedRange[1] {
|
||||
t.Errorf("Percentage %.6f%% outside expected range [%.3f%%, %.3f%%]",
|
||||
actualPercentage, tc.expectedRange[0], tc.expectedRange[1])
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// PropertyTest tests mathematical properties like commutativity, associativity
|
||||
func TestMathematicalProperties(t *testing.T) {
|
||||
dc := NewDecimalConverter()
|
||||
rand.Seed(time.Now().UnixNano())
|
||||
|
||||
// Generate random test values
|
||||
for i := 0; i < 100; i++ {
|
||||
// Generate random big integers
|
||||
val1 := big.NewInt(rand.Int63n(1000000000000000000)) // Up to 1 ETH
|
||||
val2 := big.NewInt(rand.Int63n(1000000000000000000))
|
||||
val3 := big.NewInt(rand.Int63n(1000000000000000000))
|
||||
|
||||
a, _ := NewUniversalDecimal(val1, 18, "ETH")
|
||||
b, _ := NewUniversalDecimal(val2, 18, "ETH")
|
||||
c, _ := NewUniversalDecimal(val3, 18, "ETH")
|
||||
|
||||
// Test commutativity: a + b = b + a
|
||||
ab, err1 := dc.Add(a, b)
|
||||
ba, err2 := dc.Add(b, a)
|
||||
|
||||
if err1 != nil || err2 != nil {
|
||||
t.Fatalf("Addition failed: %v, %v", err1, err2)
|
||||
}
|
||||
|
||||
if ab.Value.Cmp(ba.Value) != 0 {
|
||||
t.Errorf("Addition not commutative: %s + %s = %s, %s + %s = %s",
|
||||
a.Value.String(), b.Value.String(), ab.Value.String(),
|
||||
b.Value.String(), a.Value.String(), ba.Value.String())
|
||||
}
|
||||
|
||||
// Test associativity: (a + b) + c = a + (b + c)
|
||||
ab_c, err1 := dc.Add(ab, c)
|
||||
bc, err2 := dc.Add(b, c)
|
||||
a_bc, err3 := dc.Add(a, bc)
|
||||
|
||||
if err1 != nil || err2 != nil || err3 != nil {
|
||||
continue // Skip this iteration if any operation fails
|
||||
}
|
||||
|
||||
if ab_c.Value.Cmp(a_bc.Value) != 0 {
|
||||
t.Errorf("Addition not associative")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkDecimalOperations benchmarks decimal operations
|
||||
func BenchmarkDecimalOperations(b *testing.B) {
|
||||
dc := NewDecimalConverter()
|
||||
val1, _ := dc.FromString("1000000000000000000", 18, "ETH")
|
||||
val2, _ := dc.FromString("2000000000000000000", 18, "ETH")
|
||||
|
||||
b.Run("Addition", func(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
dc.Add(val1, val2)
|
||||
}
|
||||
})
|
||||
|
||||
b.Run("Subtraction", func(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
dc.Subtract(val2, val1)
|
||||
}
|
||||
})
|
||||
|
||||
b.Run("Percentage", func(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
dc.CalculatePercentage(val1, val2)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// FuzzDecimalOperations fuzzes decimal operations for edge cases
|
||||
func FuzzDecimalOperations(f *testing.F) {
|
||||
// Seed with known values
|
||||
f.Add(int64(1000000000000000000), int64(2000000000000000000)) // 1 ETH, 2 ETH
|
||||
f.Add(int64(1), int64(1000000000000000000)) // 1 wei, 1 ETH
|
||||
f.Add(int64(0), int64(1000000000000000000)) // 0, 1 ETH
|
||||
|
||||
f.Fuzz(func(t *testing.T, val1, val2 int64) {
|
||||
// Ensure positive values
|
||||
if val1 < 0 {
|
||||
val1 = -val1
|
||||
}
|
||||
if val2 <= 0 {
|
||||
return // Skip zero/negative denominators
|
||||
}
|
||||
|
||||
dc := NewDecimalConverter()
|
||||
|
||||
a, err := NewUniversalDecimal(big.NewInt(val1), 18, "ETH")
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
b, err := NewUniversalDecimal(big.NewInt(val2), 18, "ETH")
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Test addition doesn't panic
|
||||
_, err = dc.Add(a, b)
|
||||
if err != nil {
|
||||
t.Errorf("Addition failed: %v", err)
|
||||
}
|
||||
|
||||
// Test subtraction doesn't panic (if a >= b)
|
||||
if val1 >= val2 {
|
||||
_, err = dc.Subtract(a, b)
|
||||
if err != nil {
|
||||
t.Errorf("Subtraction failed: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Test percentage calculation doesn't panic
|
||||
_, err = dc.CalculatePercentage(a, b)
|
||||
if err != nil {
|
||||
t.Errorf("Percentage calculation failed: %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
192
orig/pkg/math/price_impact.go
Normal file
192
orig/pkg/math/price_impact.go
Normal file
@@ -0,0 +1,192 @@
|
||||
package math
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math"
|
||||
"math/big"
|
||||
)
|
||||
|
||||
// PriceImpactCalculator provides a unified interface for calculating price impact across all protocols
|
||||
type PriceImpactCalculator struct {
|
||||
mathCalculator *MathCalculator
|
||||
}
|
||||
|
||||
// NewPriceImpactCalculator creates a new price impact calculator
|
||||
func NewPriceImpactCalculator() *PriceImpactCalculator {
|
||||
return &PriceImpactCalculator{
|
||||
mathCalculator: NewMathCalculator(),
|
||||
}
|
||||
}
|
||||
|
||||
// CalculatePriceImpact calculates price impact for any supported protocol
|
||||
func (pic *PriceImpactCalculator) CalculatePriceImpact(
|
||||
protocol string,
|
||||
amountIn, reserveIn, reserveOut *big.Int,
|
||||
sqrtPriceX96, liquidity *big.Int, // For Uniswap V3 and Kyber
|
||||
) (float64, error) {
|
||||
switch protocol {
|
||||
case "uniswap_v2", "sushiswap":
|
||||
return pic.mathCalculator.uniswapV2.CalculatePriceImpact(amountIn, reserveIn, reserveOut)
|
||||
case "uniswap_v3", "camelot_v3":
|
||||
return pic.mathCalculator.uniswapV3.CalculatePriceImpact(amountIn, sqrtPriceX96, liquidity)
|
||||
case "curve":
|
||||
return pic.mathCalculator.curve.CalculatePriceImpact(amountIn, reserveIn, reserveOut)
|
||||
case "kyber_elastic", "kyber_classic":
|
||||
return pic.mathCalculator.kyber.CalculatePriceImpact(amountIn, sqrtPriceX96, liquidity)
|
||||
case "balancer":
|
||||
return pic.mathCalculator.balancer.CalculatePriceImpact(amountIn, reserveIn, reserveOut)
|
||||
case "constant_sum":
|
||||
return pic.mathCalculator.constantSum.CalculatePriceImpact(amountIn, reserveIn, reserveOut)
|
||||
case "algebra_v1":
|
||||
return pic.calculateAlgebraPriceImpact(amountIn, reserveIn, reserveOut)
|
||||
case "integral":
|
||||
return pic.calculateIntegralPriceImpact(amountIn, reserveIn, reserveOut)
|
||||
case "oneinch":
|
||||
return pic.calculateOneInchPriceImpact(amountIn, reserveIn, reserveOut)
|
||||
default:
|
||||
return 0, fmt.Errorf("unsupported protocol: %s", protocol)
|
||||
}
|
||||
}
|
||||
|
||||
// calculateAlgebraPriceImpact calculates price impact for Algebra V1.9
|
||||
func (pic *PriceImpactCalculator) calculateAlgebraPriceImpact(amountIn, reserveIn, reserveOut *big.Int) (float64, error) {
|
||||
if amountIn.Sign() <= 0 || reserveIn.Sign() <= 0 || reserveOut.Sign() <= 0 {
|
||||
return 0, fmt.Errorf("invalid amounts")
|
||||
}
|
||||
|
||||
// Calculate new reserves after swap
|
||||
amountOut, err := NewAlgebraV1Math().CalculateAmountOutAlgebra(amountIn, reserveIn, reserveOut, 500)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
newReserveIn := new(big.Int).Add(reserveIn, amountIn)
|
||||
newReserveOut := new(big.Int).Sub(reserveOut, amountOut)
|
||||
|
||||
// Calculate price before and after swap
|
||||
priceBefore := new(big.Float).Quo(new(big.Float).SetInt(reserveOut), new(big.Float).SetInt(reserveIn))
|
||||
priceAfter := new(big.Float).Quo(new(big.Float).SetInt(newReserveOut), new(big.Float).SetInt(newReserveIn))
|
||||
|
||||
// Calculate price impact
|
||||
impact := new(big.Float).Sub(priceBefore, priceAfter)
|
||||
impact.Quo(impact, priceBefore)
|
||||
|
||||
impactFloat, _ := impact.Float64()
|
||||
return math.Abs(impactFloat), nil
|
||||
}
|
||||
|
||||
// calculateIntegralPriceImpact calculates price impact for Integral
|
||||
func (pic *PriceImpactCalculator) calculateIntegralPriceImpact(amountIn, reserveIn, reserveOut *big.Int) (float64, error) {
|
||||
if amountIn.Sign() <= 0 || reserveIn.Sign() <= 0 || reserveOut.Sign() <= 0 {
|
||||
return 0, fmt.Errorf("invalid amounts")
|
||||
}
|
||||
|
||||
// Calculate new reserves after swap
|
||||
amountOut, err := NewIntegralMath().CalculateAmountOutIntegral(amountIn, reserveIn, reserveOut, 100)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
newReserveIn := new(big.Int).Add(reserveIn, amountIn)
|
||||
newReserveOut := new(big.Int).Sub(reserveOut, amountOut)
|
||||
|
||||
// Calculate price before and after swap
|
||||
priceBefore := new(big.Float).Quo(new(big.Float).SetInt(reserveOut), new(big.Float).SetInt(reserveIn))
|
||||
priceAfter := new(big.Float).Quo(new(big.Float).SetInt(newReserveOut), new(big.Float).SetInt(newReserveIn))
|
||||
|
||||
// Calculate price impact
|
||||
impact := new(big.Float).Sub(priceBefore, priceAfter)
|
||||
impact.Quo(impact, priceBefore)
|
||||
|
||||
impactFloat, _ := impact.Float64()
|
||||
return math.Abs(impactFloat), nil
|
||||
}
|
||||
|
||||
// calculateOneInchPriceImpact calculates price impact for 1Inch aggregation
|
||||
func (pic *PriceImpactCalculator) calculateOneInchPriceImpact(amountIn, reserveIn, reserveOut *big.Int) (float64, error) {
|
||||
if amountIn.Sign() <= 0 || reserveIn.Sign() <= 0 || reserveOut.Sign() <= 0 {
|
||||
return 0, fmt.Errorf("invalid amounts")
|
||||
}
|
||||
|
||||
// 1Inch aggregates multiple DEXs, so we'll calculate an effective price impact
|
||||
// based on the overall route
|
||||
|
||||
// For this implementation, we'll calculate using a simple weighted average
|
||||
// of the price impact across different paths
|
||||
|
||||
// Calculate new reserves after swap (simplified)
|
||||
amountOut, err := NewOneInchMath().CalculateAmountOutOneInch(amountIn, []PathElement{
|
||||
{
|
||||
Protocol: "uniswap_v2",
|
||||
ReserveIn: reserveIn,
|
||||
ReserveOut: reserveOut,
|
||||
Fee: 3000,
|
||||
},
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
newReserveIn := new(big.Int).Add(reserveIn, amountIn)
|
||||
newReserveOut := new(big.Int).Sub(reserveOut, amountOut)
|
||||
|
||||
// Calculate price before and after swap
|
||||
priceBefore := new(big.Float).Quo(new(big.Float).SetInt(reserveOut), new(big.Float).SetInt(reserveIn))
|
||||
priceAfter := new(big.Float).Quo(new(big.Float).SetInt(newReserveOut), new(big.Float).SetInt(newReserveIn))
|
||||
|
||||
// Calculate price impact
|
||||
impact := new(big.Float).Sub(priceBefore, priceAfter)
|
||||
impact.Quo(impact, priceBefore)
|
||||
|
||||
impactFloat, _ := impact.Float64()
|
||||
return math.Abs(impactFloat), nil
|
||||
}
|
||||
|
||||
// CalculatePriceMovementThreshold determines if a swap moves price beyond a certain threshold
|
||||
func (pic *PriceImpactCalculator) CalculatePriceMovementThreshold(
|
||||
protocol string,
|
||||
amountIn, reserveIn, reserveOut *big.Int,
|
||||
sqrtPriceX96, liquidity *big.Int, // For Uniswap V3 and Kyber
|
||||
threshold float64,
|
||||
) (bool, float64, error) {
|
||||
impact, err := pic.CalculatePriceImpact(protocol, amountIn, reserveIn, reserveOut, sqrtPriceX96, liquidity)
|
||||
if err != nil {
|
||||
return false, 0, err
|
||||
}
|
||||
|
||||
movesPrice := impact >= threshold
|
||||
|
||||
return movesPrice, impact, nil
|
||||
}
|
||||
|
||||
// CalculatePriceImpactWithSlippage combines price impact and slippage calculations
|
||||
func (pic *PriceImpactCalculator) CalculatePriceImpactWithSlippage(
|
||||
protocol string,
|
||||
amountIn, reserveIn, reserveOut *big.Int,
|
||||
sqrtPriceX96, liquidity *big.Int, // For Uniswap V3 and Kyber
|
||||
) (float64, float64, error) {
|
||||
// Calculate price impact
|
||||
priceImpact, err := pic.CalculatePriceImpact(protocol, amountIn, reserveIn, reserveOut, sqrtPriceX96, liquidity)
|
||||
if err != nil {
|
||||
return 0, 0, err
|
||||
}
|
||||
|
||||
// Calculate expected output
|
||||
mathCalculator := pic.mathCalculator.GetMathForExchange(protocol)
|
||||
expectedOut, err := mathCalculator.CalculateAmountOut(amountIn, reserveIn, reserveOut, 0)
|
||||
if err != nil {
|
||||
return 0, 0, err
|
||||
}
|
||||
|
||||
// Calculate actual output after slippage (simplified)
|
||||
actualOut := new(big.Int).Set(expectedOut)
|
||||
|
||||
// Calculate slippage
|
||||
slippage, err := mathCalculator.CalculateSlippage(expectedOut, actualOut)
|
||||
if err != nil {
|
||||
return 0, 0, err
|
||||
}
|
||||
|
||||
return priceImpact, slippage, nil
|
||||
}
|
||||
Reference in New Issue
Block a user