saving in place
This commit is contained in:
627
pkg/math/arbitrage_calculator.go
Normal file
627
pkg/math/arbitrage_calculator.go
Normal file
@@ -0,0 +1,627 @@
|
||||
package math
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"math/big"
|
||||
"sort"
|
||||
"time"
|
||||
|
||||
"github.com/ethereum/go-ethereum/common"
|
||||
"github.com/fraktal/mev-beta/pkg/types"
|
||||
)
|
||||
|
||||
// Use the canonical ArbitrageOpportunity from types package
|
||||
// Extended fields for advanced calculations can be added as needed
|
||||
|
||||
// ExchangeStep represents one step in the arbitrage execution
|
||||
type ExchangeStep struct {
|
||||
Exchange ExchangeType
|
||||
Pool *PoolData
|
||||
TokenIn TokenInfo
|
||||
TokenOut TokenInfo
|
||||
AmountIn *UniversalDecimal
|
||||
AmountOut *UniversalDecimal
|
||||
PriceImpact *UniversalDecimal
|
||||
EstimatedGas uint64
|
||||
}
|
||||
|
||||
// RiskAssessment evaluates the risk level of an arbitrage opportunity
|
||||
type RiskAssessment struct {
|
||||
Overall RiskLevel
|
||||
Liquidity RiskLevel
|
||||
PriceImpact RiskLevel
|
||||
Competition RiskLevel
|
||||
Slippage RiskLevel
|
||||
GasPrice RiskLevel
|
||||
Warnings []string
|
||||
OverallRisk float64 // Numeric representation of overall risk (0.0 to 1.0)
|
||||
}
|
||||
|
||||
// RiskLevel represents different risk categories
|
||||
type RiskLevel string
|
||||
|
||||
const (
|
||||
RiskLow RiskLevel = "low"
|
||||
RiskMedium RiskLevel = "medium"
|
||||
RiskHigh RiskLevel = "high"
|
||||
RiskCritical RiskLevel = "critical"
|
||||
)
|
||||
|
||||
// ArbitrageCalculator performs precise arbitrage calculations
|
||||
type ArbitrageCalculator struct {
|
||||
pricingEngine *ExchangePricingEngine
|
||||
decimalConverter *DecimalConverter
|
||||
gasEstimator GasEstimator
|
||||
|
||||
// Configuration
|
||||
minProfitThreshold *UniversalDecimal
|
||||
maxPriceImpact *UniversalDecimal
|
||||
maxSlippage *UniversalDecimal
|
||||
maxGasPriceGwei *UniversalDecimal
|
||||
}
|
||||
|
||||
// GasEstimator interface for gas cost calculations
|
||||
type GasEstimator interface {
|
||||
EstimateSwapGas(exchange ExchangeType, poolData *PoolData) (uint64, error)
|
||||
EstimateFlashSwapGas(route []*PoolData) (uint64, error)
|
||||
GetCurrentGasPrice() (*UniversalDecimal, error)
|
||||
}
|
||||
|
||||
// NewArbitrageCalculator creates a new arbitrage calculator
|
||||
func NewArbitrageCalculator(gasEstimator GasEstimator) *ArbitrageCalculator {
|
||||
dc := NewDecimalConverter()
|
||||
|
||||
// Default configuration
|
||||
minProfit, _ := dc.FromString("0.01", 18, "ETH") // 0.01 ETH minimum
|
||||
maxImpact, _ := dc.FromString("2", 4, "PERCENT") // 2% max price impact
|
||||
maxSlip, _ := dc.FromString("1", 4, "PERCENT") // 1% max slippage
|
||||
maxGas, _ := dc.FromString("50", 9, "GWEI") // 50 gwei max gas
|
||||
|
||||
return &ArbitrageCalculator{
|
||||
pricingEngine: NewExchangePricingEngine(),
|
||||
decimalConverter: dc,
|
||||
gasEstimator: gasEstimator,
|
||||
minProfitThreshold: minProfit,
|
||||
maxPriceImpact: maxImpact,
|
||||
maxSlippage: maxSlip,
|
||||
maxGasPriceGwei: maxGas,
|
||||
}
|
||||
}
|
||||
|
||||
// CalculateArbitrageOpportunity performs comprehensive arbitrage analysis
|
||||
func (calc *ArbitrageCalculator) CalculateArbitrageOpportunity(
|
||||
path []*PoolData,
|
||||
inputAmount *UniversalDecimal,
|
||||
inputToken TokenInfo,
|
||||
outputToken TokenInfo,
|
||||
) (*types.ArbitrageOpportunity, error) {
|
||||
|
||||
if len(path) == 0 {
|
||||
return nil, fmt.Errorf("empty arbitrage path")
|
||||
}
|
||||
|
||||
// Step 1: Calculate execution route with amounts
|
||||
route, err := calc.calculateExecutionRoute(path, inputAmount, inputToken)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error calculating execution route: %w", err)
|
||||
}
|
||||
|
||||
// Step 2: Get final output amount
|
||||
finalOutput := route[len(route)-1].AmountOut
|
||||
|
||||
// Step 3: Calculate gas costs
|
||||
totalGasCost, err := calc.calculateTotalGasCost(route)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error calculating gas cost: %w", err)
|
||||
}
|
||||
|
||||
// Step 4: Calculate profits (convert to common denomination - ETH)
|
||||
_, netProfit, profitPercentage, err := calc.calculateProfits(
|
||||
inputAmount, finalOutput, totalGasCost, inputToken, outputToken)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error calculating profits: %w", err)
|
||||
}
|
||||
|
||||
// Step 5: Calculate total price impact
|
||||
totalPriceImpact, err := calc.calculateTotalPriceImpact(route)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error calculating price impact: %w", err)
|
||||
}
|
||||
|
||||
// Step 6: Calculate minimum output with slippage (we don't use this in the final result)
|
||||
_, err = calc.calculateMinimumOutput(finalOutput)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error calculating minimum output: %w", err)
|
||||
}
|
||||
|
||||
// Step 7: Assess risks
|
||||
riskAssessment := calc.assessRisks(route, totalPriceImpact, netProfit)
|
||||
|
||||
// Step 8: Calculate confidence and execution time
|
||||
confidence := calc.calculateConfidence(riskAssessment, netProfit, totalPriceImpact)
|
||||
executionTime := calc.estimateExecutionTime(route)
|
||||
|
||||
// Convert path to string array
|
||||
pathStrings := make([]string, len(path))
|
||||
for i, pool := range path {
|
||||
pathStrings[i] = pool.Address // Address is already a string
|
||||
}
|
||||
|
||||
// Convert pools to string array
|
||||
poolStrings := make([]string, len(path))
|
||||
for i, pool := range path {
|
||||
poolStrings[i] = pool.Address // Address is already a string
|
||||
}
|
||||
|
||||
opportunity := &types.ArbitrageOpportunity{
|
||||
Path: pathStrings,
|
||||
Pools: poolStrings,
|
||||
AmountIn: inputAmount.Value,
|
||||
Profit: netProfit.Value,
|
||||
NetProfit: netProfit.Value,
|
||||
GasEstimate: totalGasCost.Value,
|
||||
ROI: func() float64 { f, _ := profitPercentage.Value.Float64(); return f }(),
|
||||
Protocol: "multi", // Default protocol for multi-step arbitrage
|
||||
ExecutionTime: executionTime,
|
||||
Confidence: confidence,
|
||||
PriceImpact: func() float64 { f, _ := totalPriceImpact.Value.Float64(); return f }(),
|
||||
MaxSlippage: 0.01, // Default 1% max slippage
|
||||
TokenIn: common.HexToAddress(inputToken.Address),
|
||||
TokenOut: common.HexToAddress(outputToken.Address),
|
||||
Timestamp: time.Now().Unix(),
|
||||
Risk: riskAssessment.OverallRisk,
|
||||
}
|
||||
|
||||
return opportunity, nil
|
||||
}
|
||||
|
||||
// calculateExecutionRoute calculates amounts through each step of the arbitrage
|
||||
func (calc *ArbitrageCalculator) calculateExecutionRoute(
|
||||
path []*PoolData,
|
||||
inputAmount *UniversalDecimal,
|
||||
inputToken TokenInfo,
|
||||
) ([]ExchangeStep, error) {
|
||||
|
||||
route := make([]ExchangeStep, len(path))
|
||||
currentAmount := inputAmount
|
||||
currentToken := inputToken
|
||||
|
||||
for i, pool := range path {
|
||||
// Determine output token for this step
|
||||
var outputToken TokenInfo
|
||||
if currentToken.Address == pool.Token0.Address {
|
||||
outputToken = TokenInfo{
|
||||
Address: pool.Token1.Address,
|
||||
Symbol: "TOKEN1", // In a real implementation, you'd fetch the actual symbol
|
||||
Decimals: 18,
|
||||
}
|
||||
} else if currentToken.Address == pool.Token1.Address {
|
||||
outputToken = TokenInfo{
|
||||
Address: pool.Token0.Address,
|
||||
Symbol: "TOKEN0", // In a real implementation, you'd fetch the actual symbol
|
||||
Decimals: 18,
|
||||
}
|
||||
} else {
|
||||
return nil, fmt.Errorf("token %s not found in pool %s", currentToken.Symbol, pool.Address)
|
||||
}
|
||||
|
||||
// For this simplified implementation, we'll calculate a mock amount out
|
||||
// In a real implementation, you'd use the pricer's CalculateAmountOut method
|
||||
amountOut := currentAmount // Simple 1:1 for this example
|
||||
priceImpact := &UniversalDecimal{Value: big.NewInt(0), Decimals: 4, Symbol: "PERCENT"} // No impact in mock
|
||||
|
||||
// Estimate gas for this step
|
||||
estimatedGas, err := calc.gasEstimator.EstimateSwapGas(ExchangeUniswapV3, pool) // Using a mock exchange type
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error estimating gas for pool %s: %w", pool.Address, err)
|
||||
}
|
||||
|
||||
// Create execution step
|
||||
route[i] = ExchangeStep{
|
||||
Exchange: ExchangeUniswapV3, // Using a mock exchange type
|
||||
Pool: pool,
|
||||
TokenIn: currentToken,
|
||||
TokenOut: outputToken,
|
||||
AmountIn: currentAmount,
|
||||
AmountOut: amountOut,
|
||||
PriceImpact: priceImpact,
|
||||
EstimatedGas: estimatedGas,
|
||||
}
|
||||
|
||||
// Update for next iteration
|
||||
currentAmount = amountOut
|
||||
currentToken = outputToken
|
||||
}
|
||||
|
||||
return route, nil
|
||||
}
|
||||
|
||||
// calculateTotalGasCost calculates the total gas cost for the entire route
|
||||
func (calc *ArbitrageCalculator) calculateTotalGasCost(route []ExchangeStep) (*UniversalDecimal, error) {
|
||||
// Get current gas price
|
||||
gasPrice, err := calc.gasEstimator.GetCurrentGasPrice()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error getting gas price: %w", err)
|
||||
}
|
||||
|
||||
// Sum up all gas estimates
|
||||
totalGas := uint64(0)
|
||||
for _, step := range route {
|
||||
totalGas += step.EstimatedGas
|
||||
}
|
||||
|
||||
// Add flash swap overhead if multi-step
|
||||
if len(route) > 1 {
|
||||
flashSwapGas, err := calc.gasEstimator.EstimateFlashSwapGas([]*PoolData{})
|
||||
if err == nil {
|
||||
totalGas += flashSwapGas
|
||||
}
|
||||
}
|
||||
|
||||
// Convert to gas cost in ETH
|
||||
totalGasBig := big.NewInt(int64(totalGas))
|
||||
totalGasDecimal, err := NewUniversalDecimal(totalGasBig, 0, "GAS")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return calc.decimalConverter.Multiply(totalGasDecimal, gasPrice, 18, "ETH")
|
||||
}
|
||||
|
||||
// calculateProfits calculates gross profit, net profit, and profit percentage
|
||||
func (calc *ArbitrageCalculator) calculateProfits(
|
||||
inputAmount, outputAmount, gasCost *UniversalDecimal,
|
||||
inputToken, outputToken TokenInfo,
|
||||
) (*UniversalDecimal, *UniversalDecimal, *UniversalDecimal, error) {
|
||||
|
||||
// Convert amounts to common denomination (ETH) for comparison
|
||||
inputETH := calc.convertToETH(inputAmount, inputToken)
|
||||
outputETH := calc.convertToETH(outputAmount, outputToken)
|
||||
|
||||
// Gross profit = output - input (in ETH terms)
|
||||
grossProfit, err := calc.decimalConverter.Subtract(outputETH, inputETH)
|
||||
if err != nil {
|
||||
return nil, nil, nil, fmt.Errorf("error calculating gross profit: %w", err)
|
||||
}
|
||||
|
||||
// Net profit = gross profit - gas cost
|
||||
netProfit, err := calc.decimalConverter.Subtract(grossProfit, gasCost)
|
||||
if err != nil {
|
||||
return nil, nil, nil, fmt.Errorf("error calculating net profit: %w", err)
|
||||
}
|
||||
|
||||
// Profit percentage = (net profit / input) * 100
|
||||
profitPercentage, err := calc.decimalConverter.CalculatePercentage(netProfit, inputETH)
|
||||
if err != nil {
|
||||
return nil, nil, nil, fmt.Errorf("error calculating profit percentage: %w", err)
|
||||
}
|
||||
|
||||
return grossProfit, netProfit, profitPercentage, nil
|
||||
}
|
||||
|
||||
// calculateTotalPriceImpact calculates cumulative price impact across all steps
|
||||
func (calc *ArbitrageCalculator) calculateTotalPriceImpact(route []ExchangeStep) (*UniversalDecimal, error) {
|
||||
if len(route) == 0 {
|
||||
return NewUniversalDecimal(big.NewInt(0), 4, "PERCENT")
|
||||
}
|
||||
|
||||
// Compound price impacts: (1 + impact1) * (1 + impact2) - 1
|
||||
compoundedImpact, err := calc.decimalConverter.FromString("1", 4, "COMPOUND")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, step := range route {
|
||||
// Convert price impact to factor (1 + impact)
|
||||
one, _ := calc.decimalConverter.FromString("1", 4, "ONE")
|
||||
impactFactor, err := calc.decimalConverter.Add(one, step.PriceImpact)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error calculating impact factor: %w", err)
|
||||
}
|
||||
|
||||
// Multiply with cumulative impact
|
||||
compoundedImpact, err = calc.decimalConverter.Multiply(compoundedImpact, impactFactor, 4, "COMPOUND")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error compounding impact: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Subtract 1 to get final impact percentage
|
||||
one, _ := calc.decimalConverter.FromString("1", 4, "ONE")
|
||||
totalImpact, err := calc.decimalConverter.Subtract(compoundedImpact, one)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error calculating total impact: %w", err)
|
||||
}
|
||||
|
||||
return totalImpact, nil
|
||||
}
|
||||
|
||||
// calculateMinimumOutput calculates minimum output accounting for slippage
|
||||
func (calc *ArbitrageCalculator) calculateMinimumOutput(expectedOutput *UniversalDecimal) (*UniversalDecimal, error) {
|
||||
// Apply slippage tolerance
|
||||
slippageFactor, err := calc.decimalConverter.Subtract(
|
||||
&UniversalDecimal{Value: big.NewInt(10000), Decimals: 4, Symbol: "ONE"},
|
||||
calc.maxSlippage,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return calc.decimalConverter.Multiply(expectedOutput, slippageFactor, 18, "TOKEN")
|
||||
}
|
||||
|
||||
// assessRisks performs comprehensive risk assessment
|
||||
func (calc *ArbitrageCalculator) assessRisks(route []ExchangeStep, priceImpact, netProfit *UniversalDecimal) RiskAssessment {
|
||||
assessment := RiskAssessment{
|
||||
Warnings: make([]string, 0),
|
||||
}
|
||||
|
||||
// Assess liquidity risk
|
||||
assessment.Liquidity = calc.assessLiquidityRisk(route)
|
||||
|
||||
// Assess price impact risk
|
||||
assessment.PriceImpact = calc.assessPriceImpactRisk(priceImpact)
|
||||
|
||||
// Assess profitability risk
|
||||
profitRisk := calc.assessProfitabilityRisk(netProfit)
|
||||
|
||||
// Assess gas price risk
|
||||
assessment.GasPrice = calc.assessGasPriceRisk()
|
||||
|
||||
// Calculate overall risk (worst of all categories)
|
||||
risks := []RiskLevel{assessment.Liquidity, assessment.PriceImpact, profitRisk, assessment.GasPrice}
|
||||
assessment.Overall = calc.calculateOverallRisk(risks)
|
||||
|
||||
// Calculate OverallRisk as a numeric value (0.0 to 1.0) based on the overall risk level
|
||||
switch assessment.Overall {
|
||||
case RiskLow:
|
||||
assessment.OverallRisk = 0.1
|
||||
case RiskMedium:
|
||||
assessment.OverallRisk = 0.4
|
||||
case RiskHigh:
|
||||
assessment.OverallRisk = 0.7
|
||||
case RiskCritical:
|
||||
assessment.OverallRisk = 0.95
|
||||
default:
|
||||
assessment.OverallRisk = 0.5 // Default to medium risk
|
||||
}
|
||||
|
||||
return assessment
|
||||
}
|
||||
|
||||
// Helper risk assessment methods
|
||||
func (calc *ArbitrageCalculator) assessLiquidityRisk(route []ExchangeStep) RiskLevel {
|
||||
for _, step := range route {
|
||||
// For this simplified implementation, assume a mock liquidity value
|
||||
// In a real implementation, you'd get this from the pricing engine
|
||||
mockLiquidity, _ := calc.decimalConverter.FromString("1000", 18, "TOKEN") // 1000 tokens
|
||||
if mockLiquidity.IsZero() {
|
||||
return RiskHigh
|
||||
}
|
||||
|
||||
// Check if trade size is significant portion of liquidity (>10%)
|
||||
tenPercent, _ := calc.decimalConverter.FromString("10", 4, "PERCENT")
|
||||
tradeSizePercent, _ := calc.decimalConverter.CalculatePercentage(step.AmountIn, mockLiquidity)
|
||||
|
||||
if comp, _ := calc.decimalConverter.Compare(tradeSizePercent, tenPercent); comp > 0 {
|
||||
return RiskMedium
|
||||
}
|
||||
}
|
||||
return RiskLow
|
||||
}
|
||||
|
||||
func (calc *ArbitrageCalculator) assessPriceImpactRisk(priceImpact *UniversalDecimal) RiskLevel {
|
||||
fivePercent, _ := calc.decimalConverter.FromString("5", 4, "PERCENT")
|
||||
twoPercent, _ := calc.decimalConverter.FromString("2", 4, "PERCENT")
|
||||
|
||||
if comp, _ := calc.decimalConverter.Compare(priceImpact, fivePercent); comp > 0 {
|
||||
return RiskHigh
|
||||
}
|
||||
if comp, _ := calc.decimalConverter.Compare(priceImpact, twoPercent); comp > 0 {
|
||||
return RiskMedium
|
||||
}
|
||||
return RiskLow
|
||||
}
|
||||
|
||||
func (calc *ArbitrageCalculator) assessProfitabilityRisk(netProfit *UniversalDecimal) RiskLevel {
|
||||
if netProfit.IsNegative() {
|
||||
return RiskCritical
|
||||
}
|
||||
|
||||
smallProfit, _ := calc.decimalConverter.FromString("0.001", 18, "ETH") // $1 at $1000/ETH
|
||||
mediumProfit, _ := calc.decimalConverter.FromString("0.01", 18, "ETH") // $10 at $1000/ETH
|
||||
|
||||
if comp, _ := calc.decimalConverter.Compare(netProfit, smallProfit); comp < 0 {
|
||||
return RiskHigh
|
||||
}
|
||||
if comp, _ := calc.decimalConverter.Compare(netProfit, mediumProfit); comp < 0 {
|
||||
return RiskMedium
|
||||
}
|
||||
return RiskLow
|
||||
}
|
||||
|
||||
func (calc *ArbitrageCalculator) assessGasPriceRisk() RiskLevel {
|
||||
currentGas, _ := calc.gasEstimator.GetCurrentGasPrice()
|
||||
|
||||
if comp, _ := calc.decimalConverter.Compare(currentGas, calc.maxGasPriceGwei); comp > 0 {
|
||||
return RiskHigh
|
||||
}
|
||||
|
||||
twentyGwei, _ := calc.decimalConverter.FromString("20", 9, "GWEI")
|
||||
if comp, _ := calc.decimalConverter.Compare(currentGas, twentyGwei); comp > 0 {
|
||||
return RiskMedium
|
||||
}
|
||||
|
||||
return RiskLow
|
||||
}
|
||||
|
||||
func (calc *ArbitrageCalculator) calculateOverallRisk(risks []RiskLevel) RiskLevel {
|
||||
riskScores := map[RiskLevel]int{
|
||||
RiskLow: 1,
|
||||
RiskMedium: 2,
|
||||
RiskHigh: 3,
|
||||
RiskCritical: 4,
|
||||
}
|
||||
|
||||
maxScore := 0
|
||||
for _, risk := range risks {
|
||||
if score := riskScores[risk]; score > maxScore {
|
||||
maxScore = score
|
||||
}
|
||||
}
|
||||
|
||||
for risk, score := range riskScores {
|
||||
if score == maxScore {
|
||||
return risk
|
||||
}
|
||||
}
|
||||
return RiskLow
|
||||
}
|
||||
|
||||
// calculateConfidence calculates confidence score based on risk and profit
|
||||
func (calc *ArbitrageCalculator) calculateConfidence(risk RiskAssessment, netProfit, priceImpact *UniversalDecimal) float64 {
|
||||
baseConfidence := 0.5
|
||||
|
||||
// Adjust for risk level
|
||||
switch risk.Overall {
|
||||
case RiskLow:
|
||||
baseConfidence += 0.3
|
||||
case RiskMedium:
|
||||
baseConfidence += 0.1
|
||||
case RiskHigh:
|
||||
baseConfidence -= 0.2
|
||||
case RiskCritical:
|
||||
baseConfidence -= 0.4
|
||||
}
|
||||
|
||||
// Adjust for profit magnitude
|
||||
if netProfit.IsPositive() {
|
||||
largeProfit, _ := calc.decimalConverter.FromString("0.1", 18, "ETH")
|
||||
if comp, _ := calc.decimalConverter.Compare(netProfit, largeProfit); comp > 0 {
|
||||
baseConfidence += 0.2
|
||||
}
|
||||
}
|
||||
|
||||
// Adjust for price impact
|
||||
lowImpact, _ := calc.decimalConverter.FromString("1", 4, "PERCENT")
|
||||
if comp, _ := calc.decimalConverter.Compare(priceImpact, lowImpact); comp < 0 {
|
||||
baseConfidence += 0.1
|
||||
}
|
||||
|
||||
if baseConfidence < 0 {
|
||||
baseConfidence = 0
|
||||
}
|
||||
if baseConfidence > 1 {
|
||||
baseConfidence = 1
|
||||
}
|
||||
|
||||
return baseConfidence
|
||||
}
|
||||
|
||||
// estimateExecutionTime estimates execution time in milliseconds
|
||||
func (calc *ArbitrageCalculator) estimateExecutionTime(route []ExchangeStep) int64 {
|
||||
baseTime := int64(500) // 500ms base
|
||||
|
||||
// Add time per hop
|
||||
hopTime := int64(len(route)) * 200
|
||||
|
||||
// Add time for complex exchanges
|
||||
complexTime := int64(0)
|
||||
for _, step := range route {
|
||||
switch ExchangeType(step.Exchange) {
|
||||
case ExchangeUniswapV3, ExchangeCamelot:
|
||||
complexTime += 300 // Concentrated liquidity is more complex
|
||||
case ExchangeBalancer, ExchangeCurve:
|
||||
complexTime += 400 // Weighted/stable pools are complex
|
||||
default:
|
||||
complexTime += 100 // Simple AMM
|
||||
}
|
||||
}
|
||||
|
||||
return baseTime + hopTime + complexTime
|
||||
}
|
||||
|
||||
// convertToETH converts any token amount to ETH for comparison (placeholder)
|
||||
func (calc *ArbitrageCalculator) convertToETH(amount *UniversalDecimal, token TokenInfo) *UniversalDecimal {
|
||||
// This is a placeholder - in production, this would query price oracles
|
||||
// For now, assume 1:1 conversion for demonstration
|
||||
ethAmount, _ := calc.decimalConverter.ConvertTo(amount, 18, "ETH")
|
||||
return ethAmount
|
||||
}
|
||||
|
||||
// IsOpportunityProfitable checks if opportunity meets minimum criteria
|
||||
func (calc *ArbitrageCalculator) IsOpportunityProfitable(opportunity *types.ArbitrageOpportunity) bool {
|
||||
// Check minimum profit threshold (simplified comparison)
|
||||
if opportunity.NetProfit.Cmp(big.NewInt(1000000000000000)) < 0 { // 0.001 ETH minimum
|
||||
return false
|
||||
}
|
||||
|
||||
// Check maximum price impact threshold (5% max)
|
||||
if opportunity.PriceImpact > 0.05 {
|
||||
return false
|
||||
}
|
||||
|
||||
// Check risk level
|
||||
if opportunity.Risk >= 0.8 { // High risk threshold
|
||||
return false
|
||||
}
|
||||
|
||||
// Check confidence threshold
|
||||
if opportunity.Confidence < 0.3 {
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// SortOpportunitiesByProfitability sorts opportunities by net profit descending
|
||||
func (calc *ArbitrageCalculator) SortOpportunitiesByProfitability(opportunities []*types.ArbitrageOpportunity) {
|
||||
sort.Slice(opportunities, func(i, j int) bool {
|
||||
// Simple comparison using big.Int.Cmp for sorting
|
||||
return opportunities[i].NetProfit.Cmp(opportunities[j].NetProfit) > 0 // Descending order (highest profit first)
|
||||
})
|
||||
}
|
||||
|
||||
// CalculateArbitrage calculates arbitrage opportunity for a given path and input amount
|
||||
func (calc *ArbitrageCalculator) CalculateArbitrage(ctx context.Context, inputAmount *UniversalDecimal, path []*PoolData) (*types.ArbitrageOpportunity, error) {
|
||||
if len(path) == 0 {
|
||||
return nil, fmt.Errorf("empty path provided")
|
||||
}
|
||||
|
||||
// Get the input and output tokens for the path
|
||||
inputToken := path[0].Token0
|
||||
outputToken := path[len(path)-1].Token1
|
||||
if path[len(path)-1].Token0.Address == inputToken.Address {
|
||||
outputToken = path[len(path)-1].Token0
|
||||
}
|
||||
|
||||
// Calculate the arbitrage opportunity for this path
|
||||
opportunity, err := calc.CalculateArbitrageOpportunity(path, inputAmount, inputToken, outputToken)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to calculate arbitrage opportunity: %w", err)
|
||||
}
|
||||
|
||||
return opportunity, nil
|
||||
}
|
||||
|
||||
// FindOptimalPath finds the most profitable arbitrage path between two tokens
|
||||
func (calc *ArbitrageCalculator) FindOptimalPath(ctx context.Context, tokenA, tokenB common.Address, amount *UniversalDecimal) (*types.ArbitrageOpportunity, error) {
|
||||
// In a real implementation, this would query for available paths between tokens
|
||||
// and calculate the most profitable path. For this implementation, we'll return an error
|
||||
// indicating no path is available since we don't have direct path-finding ability in the calculator
|
||||
return nil, fmt.Errorf("FindOptimalPath not implemented in calculator - use executor.CalculateOptimalPath instead")
|
||||
}
|
||||
|
||||
// FilterProfitableOpportunities returns only profitable opportunities
|
||||
func (calc *ArbitrageCalculator) FilterProfitableOpportunities(opportunities []*types.ArbitrageOpportunity) []*types.ArbitrageOpportunity {
|
||||
profitable := make([]*types.ArbitrageOpportunity, 0)
|
||||
|
||||
for _, opp := range opportunities {
|
||||
if calc.IsOpportunityProfitable(opp) {
|
||||
profitable = append(profitable, opp)
|
||||
}
|
||||
}
|
||||
|
||||
return profitable
|
||||
}
|
||||
96
pkg/math/cached_bench_test.go
Normal file
96
pkg/math/cached_bench_test.go
Normal file
@@ -0,0 +1,96 @@
|
||||
package math
|
||||
|
||||
import (
|
||||
"math/big"
|
||||
"testing"
|
||||
|
||||
"github.com/holiman/uint256"
|
||||
)
|
||||
|
||||
// Benchmark original vs cached SqrtPriceX96ToPrice conversion
|
||||
func BenchmarkSqrtPriceX96ToPriceOriginal(b *testing.B) {
|
||||
// Use a typical sqrtPriceX96 value (represents price of ~2000 USDC/ETH)
|
||||
sqrtPriceX96 := new(big.Int).SetBytes([]byte{0x06, 0x40, 0x84, 0x4A, 0x0E, 0x81, 0x4F, 0x96, 0x19, 0xC1, 0x9C, 0x08})
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
// Original calculation: price = sqrtPriceX96^2 / 2^192
|
||||
sqrtPriceFloat := new(big.Float).SetInt(sqrtPriceX96)
|
||||
price := new(big.Float).Mul(sqrtPriceFloat, sqrtPriceFloat)
|
||||
|
||||
// Calculate 2^192
|
||||
q192 := new(big.Int).Exp(big.NewInt(2), big.NewInt(192), nil)
|
||||
q192Float := new(big.Float).SetInt(q192)
|
||||
|
||||
// Divide by 2^192
|
||||
price.Quo(price, q192Float)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkSqrtPriceX96ToPriceCached(b *testing.B) {
|
||||
// Use a typical sqrtPriceX96 value (represents price of ~2000 USDC/ETH)
|
||||
sqrtPriceX96 := new(big.Int).SetBytes([]byte{0x06, 0x40, 0x84, 0x4A, 0x0E, 0x81, 0x4F, 0x96, 0x19, 0xC1, 0x9C, 0x08})
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
// Cached calculation using precomputed constants
|
||||
SqrtPriceX96ToPriceCached(sqrtPriceX96)
|
||||
}
|
||||
}
|
||||
|
||||
// Benchmark original vs cached PriceToSqrtPriceX96 conversion
|
||||
func BenchmarkPriceToSqrtPriceX96Original(b *testing.B) {
|
||||
// Use a typical price value (represents price of ~2000 USDC/ETH)
|
||||
price := new(big.Float).SetFloat64(2000.0)
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
// Original calculation: sqrtPriceX96 = sqrt(price * 2^192)
|
||||
|
||||
// Calculate 2^192
|
||||
q192 := new(big.Int).Exp(big.NewInt(2), big.NewInt(192), nil)
|
||||
q192Float := new(big.Float).SetInt(q192)
|
||||
|
||||
// Multiply price by 2^192
|
||||
result := new(big.Float).Mul(price, q192Float)
|
||||
|
||||
// Calculate square root
|
||||
result.Sqrt(result)
|
||||
|
||||
// Convert to big.Int
|
||||
sqrtPriceX96 := new(big.Int)
|
||||
result.Int(sqrtPriceX96)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkPriceToSqrtPriceX96Cached(b *testing.B) {
|
||||
// Use a typical price value (represents price of ~2000 USDC/ETH)
|
||||
price := new(big.Float).SetFloat64(2000.0)
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
// Cached calculation using precomputed constants
|
||||
PriceToSqrtPriceX96Cached(price)
|
||||
}
|
||||
}
|
||||
|
||||
// Benchmark optimized versions with uint256
|
||||
func BenchmarkSqrtPriceX96ToPriceOptimized(b *testing.B) {
|
||||
// Use a typical sqrtPriceX96 value
|
||||
sqrtPriceX96 := uint256.NewInt(0).SetBytes([]byte{0x06, 0x40, 0x84, 0x4A, 0x0E, 0x81, 0x4F, 0x96, 0x19, 0xC1, 0x9C, 0x08})
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
SqrtPriceX96ToPriceOptimized(sqrtPriceX96)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkPriceToSqrtPriceX96Optimized(b *testing.B) {
|
||||
// Use a typical price value
|
||||
price := new(big.Float).SetFloat64(2000.0)
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
PriceToSqrtPriceX96Optimized(price)
|
||||
}
|
||||
}
|
||||
125
pkg/math/cached_functions.go
Normal file
125
pkg/math/cached_functions.go
Normal file
@@ -0,0 +1,125 @@
|
||||
package math
|
||||
|
||||
import (
|
||||
"math/big"
|
||||
"sync"
|
||||
|
||||
"github.com/fraktal/mev-beta/pkg/uniswap"
|
||||
"github.com/holiman/uint256"
|
||||
)
|
||||
|
||||
// Cached mathematical constants to avoid recomputation
|
||||
var (
|
||||
cachedConstantsOnce sync.Once
|
||||
cachedQ192 *big.Int
|
||||
cachedQ96 *big.Int
|
||||
cachedQ384 *big.Int
|
||||
cachedTwoPower96 *big.Float
|
||||
cachedTwoPower192 *big.Float
|
||||
cachedTwoPower384 *big.Float
|
||||
)
|
||||
|
||||
// initCachedConstants initializes all cached constants once
|
||||
func initCachedConstants() {
|
||||
cachedConstantsOnce.Do(func() {
|
||||
// Calculate 2^96
|
||||
cachedQ96 = new(big.Int).Exp(big.NewInt(2), big.NewInt(96), nil)
|
||||
|
||||
// Calculate 2^192
|
||||
cachedQ192 = new(big.Int).Exp(big.NewInt(2), big.NewInt(192), nil)
|
||||
|
||||
// Calculate 2^384
|
||||
cachedQ384 = new(big.Int).Exp(big.NewInt(2), big.NewInt(384), nil)
|
||||
|
||||
// Convert to big.Float for division operations
|
||||
cachedTwoPower96 = new(big.Float).SetInt(cachedQ96)
|
||||
cachedTwoPower192 = new(big.Float).SetInt(cachedQ192)
|
||||
cachedTwoPower384 = new(big.Float).SetInt(cachedQ384)
|
||||
})
|
||||
}
|
||||
|
||||
// GetCachedQ192 returns the cached value of 2^192
|
||||
func GetCachedQ192() *big.Int {
|
||||
initCachedConstants()
|
||||
return cachedQ192
|
||||
}
|
||||
|
||||
// GetCachedQ96 returns the cached value of 2^96
|
||||
func GetCachedQ96() *big.Int {
|
||||
initCachedConstants()
|
||||
return cachedQ96
|
||||
}
|
||||
|
||||
// GetCachedQ384 returns the cached value of 2^384
|
||||
func GetCachedQ384() *big.Int {
|
||||
initCachedConstants()
|
||||
return cachedQ384
|
||||
}
|
||||
|
||||
// SqrtPriceX96ToPriceCached converts sqrtPriceX96 to a price using cached constants
|
||||
// Formula: price = sqrtPriceX96^2 / 2^192
|
||||
func SqrtPriceX96ToPriceCached(sqrtPriceX96 *big.Int) *big.Float {
|
||||
initCachedConstants()
|
||||
|
||||
// Convert to big.Float for precision
|
||||
sqrtPriceFloat := new(big.Float).SetInt(sqrtPriceX96)
|
||||
|
||||
// Calculate sqrtPrice^2
|
||||
price := new(big.Float).Mul(sqrtPriceFloat, sqrtPriceFloat)
|
||||
|
||||
// Divide by 2^192 using cached constant
|
||||
price.Quo(price, cachedTwoPower192)
|
||||
|
||||
return price
|
||||
}
|
||||
|
||||
// PriceToSqrtPriceX96Cached converts a price to sqrtPriceX96 using cached constants
|
||||
// Formula: sqrtPriceX96 = sqrt(price * 2^192)
|
||||
func PriceToSqrtPriceX96Cached(price *big.Float) *big.Int {
|
||||
initCachedConstants()
|
||||
|
||||
// Multiply price by 2^192
|
||||
result := new(big.Float).Mul(price, cachedTwoPower192)
|
||||
|
||||
// Calculate square root
|
||||
result.Sqrt(result)
|
||||
|
||||
// Convert to big.Int
|
||||
sqrtPriceX96 := new(big.Int)
|
||||
result.Int(sqrtPriceX96)
|
||||
|
||||
return sqrtPriceX96
|
||||
}
|
||||
|
||||
// SqrtPriceX96ToPriceOptimized converts sqrtPriceX96 to a price using optimized uint256 operations
|
||||
// Formula: price = sqrtPriceX96^2 / 2^192
|
||||
func SqrtPriceX96ToPriceOptimized(sqrtPriceX96 *uint256.Int) *big.Float {
|
||||
initCachedConstants()
|
||||
|
||||
// Convert to big.Int for calculation
|
||||
sqrtPriceBig := sqrtPriceX96.ToBig()
|
||||
|
||||
// Use cached function for consistency
|
||||
return SqrtPriceX96ToPriceCached(sqrtPriceBig)
|
||||
}
|
||||
|
||||
// PriceToSqrtPriceX96Optimized converts a price to sqrtPriceX96 using optimized operations
|
||||
// Formula: sqrtPriceX96 = sqrt(price * 2^192)
|
||||
func PriceToSqrtPriceX96Optimized(price *big.Float) *uint256.Int {
|
||||
initCachedConstants()
|
||||
|
||||
// Use cached function for consistency
|
||||
sqrtPriceBig := PriceToSqrtPriceX96Cached(price)
|
||||
|
||||
// Convert to uint256
|
||||
return uint256.MustFromBig(sqrtPriceBig)
|
||||
}
|
||||
|
||||
// TickToSqrtPriceX96Optimized calculates sqrtPriceX96 from a tick using optimized operations
|
||||
// Formula: sqrtPriceX96 = 1.0001^(tick/2)
|
||||
func TickToSqrtPriceX96Optimized(tick int) *uint256.Int {
|
||||
// For simplicity, we'll convert to big.Int and use existing implementation
|
||||
tickBig := big.NewInt(int64(tick))
|
||||
sqrtPriceBig := uniswap.TickToSqrtPriceX96(int(tickBig.Int64()))
|
||||
return uint256.MustFromBig(sqrtPriceBig)
|
||||
}
|
||||
127
pkg/math/cached_test.go
Normal file
127
pkg/math/cached_test.go
Normal file
@@ -0,0 +1,127 @@
|
||||
package math
|
||||
|
||||
import (
|
||||
"math/big"
|
||||
"testing"
|
||||
|
||||
"github.com/holiman/uint256"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
// Test that cached functions produce the same results as original implementations
|
||||
func TestCachedFunctionAccuracy(t *testing.T) {
|
||||
// Test SqrtPriceX96ToPrice functions
|
||||
t.Run("SqrtPriceX96ToPrice", func(t *testing.T) {
|
||||
// Use a typical sqrtPriceX96 value (represents price of ~2000 USDC/ETH)
|
||||
sqrtPriceX96 := new(big.Int).SetBytes([]byte{0x06, 0x40, 0x84, 0x4A, 0x0E, 0x81, 0x4F, 0x96, 0x19, 0xC1, 0x9C, 0x08})
|
||||
|
||||
// Original calculation
|
||||
sqrtPriceFloat := new(big.Float).SetInt(sqrtPriceX96)
|
||||
originalPrice := new(big.Float).Mul(sqrtPriceFloat, sqrtPriceFloat)
|
||||
q192 := new(big.Int).Exp(big.NewInt(2), big.NewInt(192), nil)
|
||||
q192Float := new(big.Float).SetInt(q192)
|
||||
originalPrice.Quo(originalPrice, q192Float)
|
||||
|
||||
// Cached calculation
|
||||
cachedPrice := SqrtPriceX96ToPriceCached(sqrtPriceX96)
|
||||
|
||||
// Compare results (should be identical)
|
||||
assert.Equal(t, originalPrice.String(), cachedPrice.String(), "Cached and original SqrtPriceX96ToPrice should produce identical results")
|
||||
})
|
||||
|
||||
// Test PriceToSqrtPriceX96 functions
|
||||
t.Run("PriceToSqrtPriceX96", func(t *testing.T) {
|
||||
// Use a typical price value (represents price of ~2000 USDC/ETH)
|
||||
price := new(big.Float).SetFloat64(2000.0)
|
||||
|
||||
// Original calculation
|
||||
q192 := new(big.Int).Exp(big.NewInt(2), big.NewInt(192), nil)
|
||||
q192Float := new(big.Float).SetInt(q192)
|
||||
result := new(big.Float).Mul(price, q192Float)
|
||||
result.Sqrt(result)
|
||||
expectedSqrtPriceX96 := new(big.Int)
|
||||
result.Int(expectedSqrtPriceX96)
|
||||
|
||||
// Cached calculation
|
||||
actualSqrtPriceX96 := PriceToSqrtPriceX96Cached(price)
|
||||
|
||||
// Compare results (should be identical)
|
||||
assert.Equal(t, expectedSqrtPriceX96.String(), actualSqrtPriceX96.String(), "Cached and original PriceToSqrtPriceX96 should produce identical results")
|
||||
})
|
||||
|
||||
// Test optimized functions with uint256
|
||||
t.Run("SqrtPriceX96ToPriceOptimized", func(t *testing.T) {
|
||||
// Use a typical sqrtPriceX96 value
|
||||
sqrtPriceX96Big := new(big.Int).SetBytes([]byte{0x06, 0x40, 0x84, 0x4A, 0x0E, 0x81, 0x4F, 0x96, 0x19, 0xC1, 0x9C, 0x08})
|
||||
sqrtPriceX96 := uint256.MustFromBig(sqrtPriceX96Big)
|
||||
|
||||
// Cached calculation
|
||||
cachedResult := SqrtPriceX96ToPriceCached(sqrtPriceX96Big)
|
||||
|
||||
// Optimized calculation
|
||||
optimizedResult := SqrtPriceX96ToPriceOptimized(sqrtPriceX96)
|
||||
|
||||
// Compare results (should be identical)
|
||||
assert.Equal(t, cachedResult.String(), optimizedResult.String(), "Optimized and cached SqrtPriceX96ToPrice should produce identical results")
|
||||
})
|
||||
|
||||
// Test optimized functions with uint256
|
||||
t.Run("PriceToSqrtPriceX96Optimized", func(t *testing.T) {
|
||||
// Use a typical price value
|
||||
price := new(big.Float).SetFloat64(2000.0)
|
||||
|
||||
// Cached calculation
|
||||
cachedResult := PriceToSqrtPriceX96Cached(price)
|
||||
|
||||
// Optimized calculation
|
||||
optimizedResult := PriceToSqrtPriceX96Optimized(price)
|
||||
|
||||
// Compare results (should be identical)
|
||||
assert.Equal(t, cachedResult.String(), optimizedResult.ToBig().String(), "Optimized and cached PriceToSqrtPriceX96 should produce identical results")
|
||||
})
|
||||
}
|
||||
|
||||
// Test that cached constants are working correctly
|
||||
func TestCachedConstants(t *testing.T) {
|
||||
// Test that Q192 is correctly calculated
|
||||
expectedQ192 := new(big.Int).Exp(big.NewInt(2), big.NewInt(192), nil)
|
||||
actualQ192 := GetCachedQ192()
|
||||
assert.Equal(t, expectedQ192.String(), actualQ192.String(), "Cached Q192 should equal 2^192")
|
||||
|
||||
// Test that Q96 is correctly calculated
|
||||
expectedQ96 := new(big.Int).Exp(big.NewInt(2), big.NewInt(96), nil)
|
||||
actualQ96 := GetCachedQ96()
|
||||
assert.Equal(t, expectedQ96.String(), actualQ96.String(), "Cached Q96 should equal 2^96")
|
||||
|
||||
// Test that Q384 is correctly calculated
|
||||
expectedQ384 := new(big.Int).Exp(big.NewInt(2), big.NewInt(384), nil)
|
||||
actualQ384 := GetCachedQ384()
|
||||
assert.Equal(t, expectedQ384.String(), actualQ384.String(), "Cached Q384 should equal 2^384")
|
||||
}
|
||||
|
||||
// Test edge cases
|
||||
func TestEdgeCases(t *testing.T) {
|
||||
// Test with zero values
|
||||
zero := big.NewInt(0)
|
||||
zeroFloat := new(big.Float).SetInt64(0)
|
||||
|
||||
// SqrtPriceX96ToPrice with zero
|
||||
result := SqrtPriceX96ToPriceCached(zero)
|
||||
assert.Equal(t, "0", result.String(), "SqrtPriceX96ToPriceCached with zero should return zero")
|
||||
|
||||
// PriceToSqrtPriceX96 with zero
|
||||
result2 := PriceToSqrtPriceX96Cached(zeroFloat)
|
||||
assert.Equal(t, "0", result2.String(), "PriceToSqrtPriceX96Cached with zero should return zero")
|
||||
|
||||
// Test with small values
|
||||
one := big.NewInt(1)
|
||||
oneFloat := new(big.Float).SetInt64(1)
|
||||
|
||||
// SqrtPriceX96ToPrice with one
|
||||
result3 := SqrtPriceX96ToPriceCached(one)
|
||||
assert.NotEmpty(t, result3.String(), "SqrtPriceX96ToPriceCached with one should return a value")
|
||||
|
||||
// PriceToSqrtPriceX96 with one
|
||||
result4 := PriceToSqrtPriceX96Cached(oneFloat)
|
||||
assert.NotEmpty(t, result4.String(), "PriceToSqrtPriceX96Cached with one should return a value")
|
||||
}
|
||||
405
pkg/math/decimal_handler.go
Normal file
405
pkg/math/decimal_handler.go
Normal file
@@ -0,0 +1,405 @@
|
||||
package math
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math/big"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// UniversalDecimal represents a token amount with precise decimal handling
|
||||
type UniversalDecimal struct {
|
||||
Value *big.Int // Raw value in smallest unit
|
||||
Decimals uint8 // Number of decimal places (0-18)
|
||||
Symbol string // Token symbol for debugging
|
||||
}
|
||||
|
||||
// DecimalConverter handles conversions between different decimal precisions
|
||||
type DecimalConverter struct {
|
||||
// Cache for common scaling factors to avoid repeated calculations
|
||||
scalingFactors map[uint8]*big.Int
|
||||
}
|
||||
|
||||
// NewDecimalConverter creates a new decimal converter with caching
|
||||
func NewDecimalConverter() *DecimalConverter {
|
||||
dc := &DecimalConverter{
|
||||
scalingFactors: make(map[uint8]*big.Int),
|
||||
}
|
||||
|
||||
// Pre-calculate common scaling factors (0-18 decimals)
|
||||
for i := uint8(0); i <= 18; i++ {
|
||||
dc.scalingFactors[i] = new(big.Int).Exp(big.NewInt(10), big.NewInt(int64(i)), nil)
|
||||
}
|
||||
|
||||
return dc
|
||||
}
|
||||
|
||||
// NewUniversalDecimal creates a new universal decimal with validation
|
||||
func NewUniversalDecimal(value *big.Int, decimals uint8, symbol string) (*UniversalDecimal, error) {
|
||||
if decimals > 18 {
|
||||
return nil, fmt.Errorf("decimal places cannot exceed 18, got %d for token %s", decimals, symbol)
|
||||
}
|
||||
|
||||
if value == nil {
|
||||
value = big.NewInt(0)
|
||||
}
|
||||
|
||||
// Copy the value to prevent external modifications
|
||||
valueCopy := new(big.Int).Set(value)
|
||||
|
||||
return &UniversalDecimal{
|
||||
Value: valueCopy,
|
||||
Decimals: decimals,
|
||||
Symbol: symbol,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// FromString creates UniversalDecimal from string representation
|
||||
// Intelligently determines format:
|
||||
// 1. Very large numbers (length >= decimals): treated as raw wei/smallest unit
|
||||
// 2. Small numbers (length < decimals): treated as human-readable units
|
||||
// 3. Numbers with decimal point: always treated as human-readable
|
||||
func (dc *DecimalConverter) FromString(valueStr string, decimals uint8, symbol string) (*UniversalDecimal, error) {
|
||||
// Handle empty or zero values
|
||||
if valueStr == "" || valueStr == "0" {
|
||||
return NewUniversalDecimal(big.NewInt(0), decimals, symbol)
|
||||
}
|
||||
|
||||
// Remove any whitespace
|
||||
valueStr = strings.TrimSpace(valueStr)
|
||||
|
||||
// Check for decimal point - if present, treat as human-readable decimal
|
||||
if strings.Contains(valueStr, ".") {
|
||||
return dc.fromDecimalString(valueStr, decimals, symbol)
|
||||
}
|
||||
|
||||
// For integers without decimal point, we need to determine if this is:
|
||||
// - A raw value (like "1000000000000000000" = 1000000000000000000 wei)
|
||||
// - A human-readable value (like "1" = 1.0 ETH = 1000000000000000000 wei)
|
||||
|
||||
// Parse the number first
|
||||
value := new(big.Int)
|
||||
_, success := value.SetString(valueStr, 10)
|
||||
if !success {
|
||||
return nil, fmt.Errorf("invalid number format: %s for token %s", valueStr, symbol)
|
||||
}
|
||||
|
||||
// Heuristic: if the string length is >= decimals, treat as raw value
|
||||
// This handles cases like "1000000000000000000" (18 chars, 18 decimals) as raw
|
||||
// But treats "1" (1 char, 18 decimals) as human-readable
|
||||
if len(valueStr) >= int(decimals) && decimals > 0 {
|
||||
// Treat as raw value in smallest unit
|
||||
return NewUniversalDecimal(value, decimals, symbol)
|
||||
}
|
||||
|
||||
// Treat as human-readable value - convert to smallest unit
|
||||
scalingFactor := dc.getScalingFactor(decimals)
|
||||
scaledValue := new(big.Int).Mul(value, scalingFactor)
|
||||
|
||||
return NewUniversalDecimal(scaledValue, decimals, symbol)
|
||||
}
|
||||
|
||||
// fromDecimalString parses decimal string (e.g., "1.23") to smallest unit
|
||||
func (dc *DecimalConverter) fromDecimalString(valueStr string, decimals uint8, symbol string) (*UniversalDecimal, error) {
|
||||
parts := strings.Split(valueStr, ".")
|
||||
if len(parts) != 2 {
|
||||
return nil, fmt.Errorf("invalid decimal format: %s for token %s", valueStr, symbol)
|
||||
}
|
||||
|
||||
integerPart := parts[0]
|
||||
decimalPart := parts[1]
|
||||
|
||||
// Validate decimal part doesn't exceed token decimals
|
||||
if len(decimalPart) > int(decimals) {
|
||||
return nil, fmt.Errorf("decimal part %s has %d digits, but token %s only supports %d decimals",
|
||||
decimalPart, len(decimalPart), symbol, decimals)
|
||||
}
|
||||
|
||||
// Parse integer part
|
||||
intValue := new(big.Int)
|
||||
if integerPart != "" && integerPart != "0" {
|
||||
_, success := intValue.SetString(integerPart, 10)
|
||||
if !success {
|
||||
return nil, fmt.Errorf("invalid integer part: %s for token %s", integerPart, symbol)
|
||||
}
|
||||
}
|
||||
|
||||
// Parse decimal part
|
||||
decValue := new(big.Int)
|
||||
if decimalPart != "" && decimalPart != "0" {
|
||||
// Pad decimal part to full precision
|
||||
paddedDecimal := decimalPart
|
||||
for len(paddedDecimal) < int(decimals) {
|
||||
paddedDecimal += "0"
|
||||
}
|
||||
|
||||
_, success := decValue.SetString(paddedDecimal, 10)
|
||||
if !success {
|
||||
return nil, fmt.Errorf("invalid decimal part: %s for token %s", decimalPart, symbol)
|
||||
}
|
||||
}
|
||||
|
||||
// Combine integer and decimal parts
|
||||
scalingFactor := dc.getScalingFactor(decimals)
|
||||
totalValue := new(big.Int).Mul(intValue, scalingFactor)
|
||||
totalValue.Add(totalValue, decValue)
|
||||
|
||||
return NewUniversalDecimal(totalValue, decimals, symbol)
|
||||
}
|
||||
|
||||
// ToHumanReadable converts to human-readable decimal string
|
||||
// For round-trip precision preservation with FromString, returns raw value when appropriate
|
||||
func (dc *DecimalConverter) ToHumanReadable(ud *UniversalDecimal) string {
|
||||
if ud.Value.Sign() == 0 {
|
||||
return "0"
|
||||
}
|
||||
|
||||
// For round-trip precision preservation, if the value represents exact units
|
||||
// (like 1000000000000000000 wei = exactly 1 ETH), output the human readable form
|
||||
// Otherwise, output the raw value to preserve precision
|
||||
|
||||
if ud.Decimals == 0 {
|
||||
return ud.Value.String()
|
||||
}
|
||||
|
||||
scalingFactor := dc.getScalingFactor(ud.Decimals)
|
||||
|
||||
// Get integer and remainder parts
|
||||
integerPart := new(big.Int).Div(ud.Value, scalingFactor)
|
||||
remainder := new(big.Int).Mod(ud.Value, scalingFactor)
|
||||
|
||||
// If this is an exact unit (no fractional part), return human readable
|
||||
if remainder.Sign() == 0 {
|
||||
return integerPart.String()
|
||||
}
|
||||
|
||||
// For values with fractional parts, we need to decide:
|
||||
// If the value looks like it came from raw input (very large numbers),
|
||||
// preserve it as raw to maintain round-trip precision
|
||||
|
||||
// Check if this looks like a raw value by comparing magnitude
|
||||
valueStr := ud.Value.String()
|
||||
if len(valueStr) >= int(ud.Decimals) {
|
||||
// This is likely a raw value, preserve as raw for round-trip
|
||||
return ud.Value.String()
|
||||
}
|
||||
|
||||
// Format as human readable decimal
|
||||
decimalStr := remainder.String()
|
||||
for len(decimalStr) < int(ud.Decimals) {
|
||||
decimalStr = "0" + decimalStr
|
||||
}
|
||||
|
||||
// Remove trailing zeros for readability
|
||||
decimalStr = strings.TrimRight(decimalStr, "0")
|
||||
if decimalStr == "" {
|
||||
return integerPart.String()
|
||||
}
|
||||
|
||||
return fmt.Sprintf("%s.%s", integerPart.String(), decimalStr)
|
||||
}
|
||||
|
||||
// ConvertTo converts between different decimal precisions
|
||||
func (dc *DecimalConverter) ConvertTo(from *UniversalDecimal, toDecimals uint8, toSymbol string) (*UniversalDecimal, error) {
|
||||
if from.Decimals == toDecimals {
|
||||
// Same precision, just copy with new symbol
|
||||
return NewUniversalDecimal(from.Value, toDecimals, toSymbol)
|
||||
}
|
||||
|
||||
var convertedValue *big.Int
|
||||
|
||||
if from.Decimals < toDecimals {
|
||||
// Increase precision (multiply)
|
||||
decimalDiff := toDecimals - from.Decimals
|
||||
scalingFactor := dc.getScalingFactor(decimalDiff)
|
||||
convertedValue = new(big.Int).Mul(from.Value, scalingFactor)
|
||||
} else {
|
||||
// Decrease precision (divide with rounding)
|
||||
decimalDiff := from.Decimals - toDecimals
|
||||
scalingFactor := dc.getScalingFactor(decimalDiff)
|
||||
|
||||
// Round to nearest (banker's rounding)
|
||||
halfScaling := new(big.Int).Div(scalingFactor, big.NewInt(2))
|
||||
roundedValue := new(big.Int).Add(from.Value, halfScaling)
|
||||
convertedValue = new(big.Int).Div(roundedValue, scalingFactor)
|
||||
}
|
||||
|
||||
return NewUniversalDecimal(convertedValue, toDecimals, toSymbol)
|
||||
}
|
||||
|
||||
// Multiply performs precise multiplication between different decimal tokens
|
||||
func (dc *DecimalConverter) Multiply(a, b *UniversalDecimal, resultDecimals uint8, resultSymbol string) (*UniversalDecimal, error) {
|
||||
// Multiply raw values
|
||||
product := new(big.Int).Mul(a.Value, b.Value)
|
||||
|
||||
// Adjust for decimal places (division by 10^(a.decimals + b.decimals - result.decimals))
|
||||
totalInputDecimals := a.Decimals + b.Decimals
|
||||
|
||||
var adjustedProduct *big.Int
|
||||
if totalInputDecimals >= resultDecimals {
|
||||
decimalDiff := totalInputDecimals - resultDecimals
|
||||
scalingFactor := dc.getScalingFactor(decimalDiff)
|
||||
|
||||
// Round to nearest
|
||||
halfScaling := new(big.Int).Div(scalingFactor, big.NewInt(2))
|
||||
roundedProduct := new(big.Int).Add(product, halfScaling)
|
||||
adjustedProduct = new(big.Int).Div(roundedProduct, scalingFactor)
|
||||
} else {
|
||||
decimalDiff := resultDecimals - totalInputDecimals
|
||||
scalingFactor := dc.getScalingFactor(decimalDiff)
|
||||
adjustedProduct = new(big.Int).Mul(product, scalingFactor)
|
||||
}
|
||||
|
||||
return NewUniversalDecimal(adjustedProduct, resultDecimals, resultSymbol)
|
||||
}
|
||||
|
||||
// Divide performs precise division between different decimal tokens
|
||||
func (dc *DecimalConverter) Divide(numerator, denominator *UniversalDecimal, resultDecimals uint8, resultSymbol string) (*UniversalDecimal, error) {
|
||||
if denominator.Value.Sign() == 0 {
|
||||
return nil, fmt.Errorf("division by zero: %s / %s", numerator.Symbol, denominator.Symbol)
|
||||
}
|
||||
|
||||
// Scale numerator to maintain precision
|
||||
totalDecimals := numerator.Decimals + resultDecimals
|
||||
scalingFactor := dc.getScalingFactor(totalDecimals - denominator.Decimals)
|
||||
|
||||
scaledNumerator := new(big.Int).Mul(numerator.Value, scalingFactor)
|
||||
quotient := new(big.Int).Div(scaledNumerator, denominator.Value)
|
||||
|
||||
return NewUniversalDecimal(quotient, resultDecimals, resultSymbol)
|
||||
}
|
||||
|
||||
// Add adds two UniversalDecimals with same precision
|
||||
func (dc *DecimalConverter) Add(a, b *UniversalDecimal) (*UniversalDecimal, error) {
|
||||
if a.Decimals != b.Decimals {
|
||||
return nil, fmt.Errorf("cannot add tokens with different decimals: %s(%d) + %s(%d)",
|
||||
a.Symbol, a.Decimals, b.Symbol, b.Decimals)
|
||||
}
|
||||
|
||||
sum := new(big.Int).Add(a.Value, b.Value)
|
||||
resultSymbol := a.Symbol
|
||||
if a.Symbol != b.Symbol {
|
||||
resultSymbol = fmt.Sprintf("%s+%s", a.Symbol, b.Symbol)
|
||||
}
|
||||
|
||||
return NewUniversalDecimal(sum, a.Decimals, resultSymbol)
|
||||
}
|
||||
|
||||
// Subtract subtracts two UniversalDecimals with same precision
|
||||
func (dc *DecimalConverter) Subtract(a, b *UniversalDecimal) (*UniversalDecimal, error) {
|
||||
if a.Decimals != b.Decimals {
|
||||
return nil, fmt.Errorf("cannot subtract tokens with different decimals: %s(%d) - %s(%d)",
|
||||
a.Symbol, a.Decimals, b.Symbol, b.Decimals)
|
||||
}
|
||||
|
||||
diff := new(big.Int).Sub(a.Value, b.Value)
|
||||
resultSymbol := a.Symbol
|
||||
if a.Symbol != b.Symbol {
|
||||
resultSymbol = fmt.Sprintf("%s-%s", a.Symbol, b.Symbol)
|
||||
}
|
||||
|
||||
return NewUniversalDecimal(diff, a.Decimals, resultSymbol)
|
||||
}
|
||||
|
||||
// Compare returns -1, 0, or 1 for a < b, a == b, a > b respectively
|
||||
func (dc *DecimalConverter) Compare(a, b *UniversalDecimal) (int, error) {
|
||||
if a.Decimals != b.Decimals {
|
||||
// Convert to same precision for comparison
|
||||
converted, err := dc.ConvertTo(b, a.Decimals, b.Symbol)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("cannot compare tokens with different decimals: %w", err)
|
||||
}
|
||||
b = converted
|
||||
}
|
||||
|
||||
return a.Value.Cmp(b.Value), nil
|
||||
}
|
||||
|
||||
// IsZero checks if the value is zero
|
||||
func (ud *UniversalDecimal) IsZero() bool {
|
||||
return ud.Value.Sign() == 0
|
||||
}
|
||||
|
||||
// IsPositive checks if the value is positive
|
||||
func (ud *UniversalDecimal) IsPositive() bool {
|
||||
return ud.Value.Sign() > 0
|
||||
}
|
||||
|
||||
// IsNegative checks if the value is negative
|
||||
func (ud *UniversalDecimal) IsNegative() bool {
|
||||
return ud.Value.Sign() < 0
|
||||
}
|
||||
|
||||
// Copy creates a deep copy of the UniversalDecimal
|
||||
func (ud *UniversalDecimal) Copy() *UniversalDecimal {
|
||||
return &UniversalDecimal{
|
||||
Value: new(big.Int).Set(ud.Value),
|
||||
Decimals: ud.Decimals,
|
||||
Symbol: ud.Symbol,
|
||||
}
|
||||
}
|
||||
|
||||
// getScalingFactor returns the scaling factor for given decimals (cached)
|
||||
func (dc *DecimalConverter) getScalingFactor(decimals uint8) *big.Int {
|
||||
if factor, exists := dc.scalingFactors[decimals]; exists {
|
||||
return factor
|
||||
}
|
||||
|
||||
// Calculate and cache if not exists (shouldn't happen for 0-18)
|
||||
factor := new(big.Int).Exp(big.NewInt(10), big.NewInt(int64(decimals)), nil)
|
||||
dc.scalingFactors[decimals] = factor
|
||||
return factor
|
||||
}
|
||||
|
||||
// ToWei converts any decimal precision to 18-decimal wei representation
|
||||
func (dc *DecimalConverter) ToWei(ud *UniversalDecimal) *UniversalDecimal {
|
||||
weiValue, _ := dc.ConvertTo(ud, 18, "WEI")
|
||||
return weiValue
|
||||
}
|
||||
|
||||
// FromWei converts 18-decimal wei to specified decimal precision
|
||||
func (dc *DecimalConverter) FromWei(weiValue *big.Int, targetDecimals uint8, targetSymbol string) *UniversalDecimal {
|
||||
weiDecimal := &UniversalDecimal{
|
||||
Value: new(big.Int).Set(weiValue),
|
||||
Decimals: 18,
|
||||
Symbol: "WEI",
|
||||
}
|
||||
|
||||
result, _ := dc.ConvertTo(weiDecimal, targetDecimals, targetSymbol)
|
||||
return result
|
||||
}
|
||||
|
||||
// CalculatePercentage calculates percentage with precise decimal handling
|
||||
// Returns percentage as UniversalDecimal with 4 decimal places (e.g., 1.5000% = 15000 with 4 decimals)
|
||||
func (dc *DecimalConverter) CalculatePercentage(value, total *UniversalDecimal) (*UniversalDecimal, error) {
|
||||
if total.IsZero() {
|
||||
return nil, fmt.Errorf("cannot calculate percentage with zero total")
|
||||
}
|
||||
|
||||
// Convert to same precision if needed
|
||||
if value.Decimals != total.Decimals {
|
||||
convertedValue, err := dc.ConvertTo(value, total.Decimals, value.Symbol)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error converting decimals for percentage: %w", err)
|
||||
}
|
||||
value = convertedValue
|
||||
}
|
||||
|
||||
// Calculate (value / total) * 100 using integer arithmetic to avoid floating point errors
|
||||
// Formula: (value * 100 * 10^4) / total where 10^4 gives us 4 decimal places
|
||||
|
||||
// Multiply value by 100 * 10^4 = 1,000,000 for percentage with 4 decimal places
|
||||
hundredWithDecimals := big.NewInt(1000000) // 100.0000 in 4-decimal format
|
||||
numerator := new(big.Int).Mul(value.Value, hundredWithDecimals)
|
||||
|
||||
// Divide by total to get percentage
|
||||
percentage := new(big.Int).Div(numerator, total.Value)
|
||||
|
||||
return NewUniversalDecimal(percentage, 4, "PERCENT")
|
||||
}
|
||||
|
||||
// String returns string representation for debugging
|
||||
func (ud *UniversalDecimal) String() string {
|
||||
dc := NewDecimalConverter()
|
||||
humanReadable := dc.ToHumanReadable(ud)
|
||||
return fmt.Sprintf("%s %s", humanReadable, ud.Symbol)
|
||||
}
|
||||
@@ -113,6 +113,11 @@ func (u *UniswapV2Math) CalculateAmountIn(amountOut, reserveIn, reserveOut *big.
|
||||
|
||||
// CalculatePriceImpact calculates price impact for Uniswap V2
|
||||
func (u *UniswapV2Math) CalculatePriceImpact(amountIn, reserveIn, reserveOut *big.Int) (float64, error) {
|
||||
// Check for nil pointers first
|
||||
if amountIn == nil || reserveIn == nil || reserveOut == nil {
|
||||
return 0, fmt.Errorf("nil pointer encountered")
|
||||
}
|
||||
|
||||
if amountIn.Sign() <= 0 || reserveIn.Sign() <= 0 || reserveOut.Sign() <= 0 {
|
||||
return 0, fmt.Errorf("invalid amounts")
|
||||
}
|
||||
@@ -236,6 +241,11 @@ func (u *UniswapV3Math) CalculateAmountIn(amountOut, sqrtPriceX96, liquidity *bi
|
||||
|
||||
// CalculatePriceImpact calculates price impact for Uniswap V3
|
||||
func (u *UniswapV3Math) CalculatePriceImpact(amountIn, sqrtPriceX96, liquidity *big.Int) (float64, error) {
|
||||
// Check for nil pointers first
|
||||
if amountIn == nil || sqrtPriceX96 == nil || liquidity == nil {
|
||||
return 0, fmt.Errorf("nil pointer encountered")
|
||||
}
|
||||
|
||||
if amountIn.Sign() <= 0 || sqrtPriceX96.Sign() <= 0 || liquidity.Sign() <= 0 {
|
||||
return 0, fmt.Errorf("invalid parameters")
|
||||
}
|
||||
@@ -428,6 +438,15 @@ func (c *CurveMath) CalculateAmountIn(amountOut, balance0, balance1 *big.Int, fe
|
||||
|
||||
// CalculatePriceImpact calculates price impact for Curve
|
||||
func (c *CurveMath) CalculatePriceImpact(amountIn, balance0, balance1 *big.Int) (float64, error) {
|
||||
// Check for nil pointers first
|
||||
if amountIn == nil || balance0 == nil || balance1 == nil {
|
||||
return 0, fmt.Errorf("nil pointer encountered")
|
||||
}
|
||||
|
||||
if amountIn.Sign() <= 0 || balance0.Sign() <= 0 || balance1.Sign() <= 0 {
|
||||
return 0, fmt.Errorf("invalid amounts")
|
||||
}
|
||||
|
||||
// Price before = balance1 / balance0
|
||||
priceBefore := new(big.Float).Quo(new(big.Float).SetInt(balance1), new(big.Float).SetInt(balance0))
|
||||
|
||||
@@ -709,6 +728,11 @@ func (b *BalancerMath) CalculateAmountIn(amountOut, reserveIn, reserveOut *big.I
|
||||
|
||||
// CalculatePriceImpact calculates price impact for Balancer
|
||||
func (b *BalancerMath) CalculatePriceImpact(amountIn, reserveIn, reserveOut *big.Int) (float64, error) {
|
||||
// Check for nil pointers first
|
||||
if amountIn == nil || reserveIn == nil || reserveOut == nil {
|
||||
return 0, fmt.Errorf("nil pointer encountered")
|
||||
}
|
||||
|
||||
if amountIn.Sign() <= 0 || reserveIn.Sign() <= 0 || reserveOut.Sign() <= 0 {
|
||||
return 0, fmt.Errorf("invalid amounts")
|
||||
}
|
||||
|
||||
520
pkg/math/exchange_pricing.go
Normal file
520
pkg/math/exchange_pricing.go
Normal file
@@ -0,0 +1,520 @@
|
||||
package math
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math/big"
|
||||
)
|
||||
|
||||
// ExchangeType represents different DEX protocols on Arbitrum
|
||||
type ExchangeType string
|
||||
|
||||
const (
|
||||
ExchangeUniswapV3 ExchangeType = "uniswap_v3"
|
||||
ExchangeUniswapV2 ExchangeType = "uniswap_v2"
|
||||
ExchangeSushiSwap ExchangeType = "sushiswap"
|
||||
ExchangeCamelot ExchangeType = "camelot"
|
||||
ExchangeBalancer ExchangeType = "balancer"
|
||||
ExchangeTraderJoe ExchangeType = "traderjoe"
|
||||
ExchangeRamses ExchangeType = "ramses"
|
||||
ExchangeCurve ExchangeType = "curve"
|
||||
)
|
||||
|
||||
// ExchangePricer interface for exchange-specific price calculations
|
||||
type ExchangePricer interface {
|
||||
GetSpotPrice(poolData *PoolData) (*UniversalDecimal, error)
|
||||
CalculateAmountOut(amountIn *UniversalDecimal, poolData *PoolData) (*UniversalDecimal, error)
|
||||
CalculateAmountIn(amountOut *UniversalDecimal, poolData *PoolData) (*UniversalDecimal, error)
|
||||
CalculatePriceImpact(amountIn *UniversalDecimal, poolData *PoolData) (*UniversalDecimal, error)
|
||||
GetMinimumLiquidity(poolData *PoolData) (*UniversalDecimal, error)
|
||||
ValidatePoolData(poolData *PoolData) error
|
||||
}
|
||||
|
||||
// PoolData represents universal pool data structure
|
||||
type PoolData struct {
|
||||
Address string
|
||||
ExchangeType ExchangeType
|
||||
Token0 TokenInfo
|
||||
Token1 TokenInfo
|
||||
Reserve0 *UniversalDecimal
|
||||
Reserve1 *UniversalDecimal
|
||||
Fee *UniversalDecimal // Fee as percentage (e.g., 0.003 for 0.3%)
|
||||
|
||||
// Uniswap V3 specific
|
||||
SqrtPriceX96 *big.Int
|
||||
Tick *big.Int
|
||||
Liquidity *big.Int
|
||||
|
||||
// Curve specific
|
||||
A *big.Int // Amplification coefficient
|
||||
|
||||
// Balancer specific
|
||||
Weights []*UniversalDecimal // Token weights
|
||||
SwapFeeRate *UniversalDecimal // Swap fee rate
|
||||
}
|
||||
|
||||
// TokenInfo represents token metadata
|
||||
type TokenInfo struct {
|
||||
Address string
|
||||
Symbol string
|
||||
Decimals uint8
|
||||
}
|
||||
|
||||
// ExchangePricingEngine manages all exchange-specific pricing logic
|
||||
type ExchangePricingEngine struct {
|
||||
decimalConverter *DecimalConverter
|
||||
pricers map[ExchangeType]ExchangePricer
|
||||
}
|
||||
|
||||
// NewExchangePricingEngine creates a new pricing engine with all exchange support
|
||||
func NewExchangePricingEngine() *ExchangePricingEngine {
|
||||
dc := NewDecimalConverter()
|
||||
|
||||
engine := &ExchangePricingEngine{
|
||||
decimalConverter: dc,
|
||||
pricers: make(map[ExchangeType]ExchangePricer),
|
||||
}
|
||||
|
||||
// Register all exchange pricers
|
||||
engine.pricers[ExchangeUniswapV3] = NewUniswapV3Pricer(dc)
|
||||
engine.pricers[ExchangeUniswapV2] = NewUniswapV2Pricer(dc)
|
||||
engine.pricers[ExchangeSushiSwap] = NewSushiSwapPricer(dc)
|
||||
engine.pricers[ExchangeCamelot] = NewCamelotPricer(dc)
|
||||
engine.pricers[ExchangeBalancer] = NewBalancerPricer(dc)
|
||||
engine.pricers[ExchangeTraderJoe] = NewTraderJoePricer(dc)
|
||||
engine.pricers[ExchangeRamses] = NewRamsesPricer(dc)
|
||||
engine.pricers[ExchangeCurve] = NewCurvePricer(dc)
|
||||
|
||||
return engine
|
||||
}
|
||||
|
||||
// GetExchangePricer returns the appropriate pricer for an exchange
|
||||
func (engine *ExchangePricingEngine) GetExchangePricer(exchangeType ExchangeType) (ExchangePricer, error) {
|
||||
pricer, exists := engine.pricers[exchangeType]
|
||||
if !exists {
|
||||
return nil, fmt.Errorf("unsupported exchange type: %s", exchangeType)
|
||||
}
|
||||
return pricer, nil
|
||||
}
|
||||
|
||||
// CalculateSpotPrice gets spot price from any exchange
|
||||
func (engine *ExchangePricingEngine) CalculateSpotPrice(poolData *PoolData) (*UniversalDecimal, error) {
|
||||
pricer, err := engine.GetExchangePricer(poolData.ExchangeType)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return pricer.GetSpotPrice(poolData)
|
||||
}
|
||||
|
||||
// UniswapV3Pricer implements Uniswap V3 concentrated liquidity pricing
|
||||
type UniswapV3Pricer struct {
|
||||
dc *DecimalConverter
|
||||
}
|
||||
|
||||
func NewUniswapV3Pricer(dc *DecimalConverter) *UniswapV3Pricer {
|
||||
return &UniswapV3Pricer{dc: dc}
|
||||
}
|
||||
|
||||
func (p *UniswapV3Pricer) GetSpotPrice(poolData *PoolData) (*UniversalDecimal, error) {
|
||||
if poolData.SqrtPriceX96 == nil {
|
||||
return nil, fmt.Errorf("missing sqrtPriceX96 for Uniswap V3 pool")
|
||||
}
|
||||
|
||||
// Use cached function for optimized calculation
|
||||
// Convert sqrtPriceX96 to actual price using cached constants
|
||||
// price = sqrtPriceX96^2 / 2^192
|
||||
price := SqrtPriceX96ToPriceCached(poolData.SqrtPriceX96)
|
||||
|
||||
// Adjust for decimal differences between tokens
|
||||
if poolData.Token0.Decimals != poolData.Token1.Decimals {
|
||||
decimalDiff := int(poolData.Token1.Decimals) - int(poolData.Token0.Decimals)
|
||||
adjustment := new(big.Float).SetInt(new(big.Int).Exp(big.NewInt(10), big.NewInt(int64(decimalDiff)), nil))
|
||||
price.Mul(price, adjustment)
|
||||
}
|
||||
|
||||
// Convert back to big.Int with appropriate precision
|
||||
priceInt := new(big.Int)
|
||||
priceScaled := new(big.Float).Mul(price, new(big.Float).SetInt(p.dc.getScalingFactor(18)))
|
||||
priceScaled.Int(priceInt)
|
||||
|
||||
return NewUniversalDecimal(priceInt, 18, fmt.Sprintf("%s/%s", poolData.Token1.Symbol, poolData.Token0.Symbol))
|
||||
}
|
||||
|
||||
func (p *UniswapV3Pricer) CalculateAmountOut(amountIn *UniversalDecimal, poolData *PoolData) (*UniversalDecimal, error) {
|
||||
// Uniswap V3 concentrated liquidity calculation
|
||||
// This is a simplified version - production would need full tick math
|
||||
|
||||
if poolData.Liquidity == nil || poolData.Liquidity.Sign() == 0 {
|
||||
return nil, fmt.Errorf("insufficient liquidity in Uniswap V3 pool")
|
||||
}
|
||||
|
||||
// Apply fee
|
||||
feeAmount, err := p.dc.Multiply(amountIn, poolData.Fee, amountIn.Decimals, "FEE")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error calculating fee: %w", err)
|
||||
}
|
||||
|
||||
amountInAfterFee, err := p.dc.Subtract(amountIn, feeAmount)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error subtracting fee: %w", err)
|
||||
}
|
||||
|
||||
// Simplified constant product formula for demonstration
|
||||
// Real implementation would use tick mathematics
|
||||
numerator, err := p.dc.Multiply(amountInAfterFee, poolData.Reserve1, poolData.Reserve1.Decimals, "TEMP")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
denominator, err := p.dc.Add(poolData.Reserve0, amountInAfterFee)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return p.dc.Divide(numerator, denominator, poolData.Token1.Decimals, poolData.Token1.Symbol)
|
||||
}
|
||||
|
||||
func (p *UniswapV3Pricer) CalculateAmountIn(amountOut *UniversalDecimal, poolData *PoolData) (*UniversalDecimal, error) {
|
||||
// Reverse calculation for Uniswap V3
|
||||
if poolData.Reserve1.IsZero() || amountOut.Value.Cmp(poolData.Reserve1.Value) >= 0 {
|
||||
return nil, fmt.Errorf("insufficient liquidity for requested output amount")
|
||||
}
|
||||
|
||||
// Simplified reverse calculation
|
||||
numerator, err := p.dc.Multiply(poolData.Reserve0, amountOut, poolData.Reserve0.Decimals, "TEMP")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
denominator, err := p.dc.Subtract(poolData.Reserve1, amountOut)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
amountInBeforeFee, err := p.dc.Divide(numerator, denominator, poolData.Token0.Decimals, "TEMP")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Add fee
|
||||
feeMultiplier, err := p.dc.FromString("1", 18, "FEE_MULT")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
oneMinusFee, err := p.dc.Subtract(feeMultiplier, poolData.Fee)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return p.dc.Divide(amountInBeforeFee, oneMinusFee, poolData.Token0.Decimals, poolData.Token0.Symbol)
|
||||
}
|
||||
|
||||
func (p *UniswapV3Pricer) CalculatePriceImpact(amountIn *UniversalDecimal, poolData *PoolData) (*UniversalDecimal, error) {
|
||||
// Calculate price before trade
|
||||
priceBefore, err := p.GetSpotPrice(poolData)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error getting spot price: %w", err)
|
||||
}
|
||||
|
||||
// Calculate amount out
|
||||
amountOut, err := p.CalculateAmountOut(amountIn, poolData)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error calculating amount out: %w", err)
|
||||
}
|
||||
|
||||
// Calculate effective price
|
||||
effectivePrice, err := p.dc.Divide(amountOut, amountIn, 18, "EFFECTIVE_PRICE")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error calculating effective price: %w", err)
|
||||
}
|
||||
|
||||
// Calculate price impact as percentage
|
||||
priceDiff, err := p.dc.Subtract(priceBefore, effectivePrice)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error calculating price difference: %w", err)
|
||||
}
|
||||
|
||||
return p.dc.CalculatePercentage(priceDiff, priceBefore)
|
||||
}
|
||||
|
||||
func (p *UniswapV3Pricer) GetMinimumLiquidity(poolData *PoolData) (*UniversalDecimal, error) {
|
||||
if poolData.Liquidity == nil {
|
||||
return NewUniversalDecimal(big.NewInt(0), 18, "LIQUIDITY")
|
||||
}
|
||||
|
||||
return NewUniversalDecimal(poolData.Liquidity, 18, "LIQUIDITY")
|
||||
}
|
||||
|
||||
func (p *UniswapV3Pricer) ValidatePoolData(poolData *PoolData) error {
|
||||
if poolData.SqrtPriceX96 == nil {
|
||||
return fmt.Errorf("Uniswap V3 pool missing sqrtPriceX96")
|
||||
}
|
||||
if poolData.Liquidity == nil {
|
||||
return fmt.Errorf("Uniswap V3 pool missing liquidity")
|
||||
}
|
||||
if poolData.Fee == nil {
|
||||
return fmt.Errorf("Uniswap V3 pool missing fee")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// UniswapV2Pricer implements Uniswap V2 / SushiSwap constant product pricing
|
||||
type UniswapV2Pricer struct {
|
||||
dc *DecimalConverter
|
||||
}
|
||||
|
||||
func NewUniswapV2Pricer(dc *DecimalConverter) *UniswapV2Pricer {
|
||||
return &UniswapV2Pricer{dc: dc}
|
||||
}
|
||||
|
||||
func (p *UniswapV2Pricer) GetSpotPrice(poolData *PoolData) (*UniversalDecimal, error) {
|
||||
if poolData.Reserve0.IsZero() {
|
||||
return nil, fmt.Errorf("zero reserve0 in constant product pool")
|
||||
}
|
||||
|
||||
return p.dc.Divide(poolData.Reserve1, poolData.Reserve0, 18, fmt.Sprintf("%s/%s", poolData.Token1.Symbol, poolData.Token0.Symbol))
|
||||
}
|
||||
|
||||
func (p *UniswapV2Pricer) CalculateAmountOut(amountIn *UniversalDecimal, poolData *PoolData) (*UniversalDecimal, error) {
|
||||
// Uniswap V2 constant product formula: x * y = k
|
||||
// amountOut = (amountIn * 997 * reserveOut) / (reserveIn * 1000 + amountIn * 997)
|
||||
|
||||
// Apply fee (0.3% = 997/1000 remaining)
|
||||
feeNumerator, _ := p.dc.FromString("997", 0, "FEE_NUM")
|
||||
feeDenominator, _ := p.dc.FromString("1000", 0, "FEE_DEN")
|
||||
|
||||
amountInWithFee, err := p.dc.Multiply(amountIn, feeNumerator, amountIn.Decimals, "TEMP")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
numerator, err := p.dc.Multiply(amountInWithFee, poolData.Reserve1, poolData.Reserve1.Decimals, "TEMP")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
reserveInScaled, err := p.dc.Multiply(poolData.Reserve0, feeDenominator, poolData.Reserve0.Decimals, "TEMP")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
denominator, err := p.dc.Add(reserveInScaled, amountInWithFee)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return p.dc.Divide(numerator, denominator, poolData.Token1.Decimals, poolData.Token1.Symbol)
|
||||
}
|
||||
|
||||
func (p *UniswapV2Pricer) CalculateAmountIn(amountOut *UniversalDecimal, poolData *PoolData) (*UniversalDecimal, error) {
|
||||
// Reverse calculation for constant product
|
||||
feeNumerator, _ := p.dc.FromString("1000", 0, "FEE_NUM")
|
||||
feeDenominator, _ := p.dc.FromString("997", 0, "FEE_DEN")
|
||||
|
||||
numerator, err := p.dc.Multiply(poolData.Reserve0, amountOut, poolData.Reserve0.Decimals, "TEMP")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
numeratorWithFee, err := p.dc.Multiply(numerator, feeNumerator, numerator.Decimals, "TEMP")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
denominator, err := p.dc.Subtract(poolData.Reserve1, amountOut)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
denominatorWithFee, err := p.dc.Multiply(denominator, feeDenominator, denominator.Decimals, "TEMP")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return p.dc.Divide(numeratorWithFee, denominatorWithFee, poolData.Token0.Decimals, poolData.Token0.Symbol)
|
||||
}
|
||||
|
||||
func (p *UniswapV2Pricer) CalculatePriceImpact(amountIn *UniversalDecimal, poolData *PoolData) (*UniversalDecimal, error) {
|
||||
// Similar to Uniswap V3 implementation
|
||||
priceBefore, err := p.GetSpotPrice(poolData)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
amountOut, err := p.CalculateAmountOut(amountIn, poolData)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
effectivePrice, err := p.dc.Divide(amountOut, amountIn, 18, "EFFECTIVE_PRICE")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
priceDiff, err := p.dc.Subtract(priceBefore, effectivePrice)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return p.dc.CalculatePercentage(priceDiff, priceBefore)
|
||||
}
|
||||
|
||||
func (p *UniswapV2Pricer) GetMinimumLiquidity(poolData *PoolData) (*UniversalDecimal, error) {
|
||||
// Geometric mean of reserves
|
||||
product, err := p.dc.Multiply(poolData.Reserve0, poolData.Reserve1, 18, "LIQUIDITY")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Simplified square root - in production use precise sqrt algorithm
|
||||
sqrt := new(big.Int).Sqrt(product.Value)
|
||||
return NewUniversalDecimal(sqrt, 18, "LIQUIDITY")
|
||||
}
|
||||
|
||||
func (p *UniswapV2Pricer) ValidatePoolData(poolData *PoolData) error {
|
||||
if poolData.Reserve0 == nil || poolData.Reserve1 == nil {
|
||||
return fmt.Errorf("missing reserves for constant product pool")
|
||||
}
|
||||
if poolData.Reserve0.IsZero() || poolData.Reserve1.IsZero() {
|
||||
return fmt.Errorf("zero reserves in constant product pool")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// SushiSwapPricer uses same logic as Uniswap V2
|
||||
type SushiSwapPricer struct {
|
||||
*UniswapV2Pricer
|
||||
}
|
||||
|
||||
func NewSushiSwapPricer(dc *DecimalConverter) *SushiSwapPricer {
|
||||
return &SushiSwapPricer{NewUniswapV2Pricer(dc)}
|
||||
}
|
||||
|
||||
// CamelotPricer - Algebra-based DEX on Arbitrum
|
||||
type CamelotPricer struct {
|
||||
*UniswapV3Pricer
|
||||
}
|
||||
|
||||
func NewCamelotPricer(dc *DecimalConverter) *CamelotPricer {
|
||||
return &CamelotPricer{NewUniswapV3Pricer(dc)}
|
||||
}
|
||||
|
||||
// BalancerPricer - Weighted pool implementation
|
||||
type BalancerPricer struct {
|
||||
dc *DecimalConverter
|
||||
}
|
||||
|
||||
func NewBalancerPricer(dc *DecimalConverter) *BalancerPricer {
|
||||
return &BalancerPricer{dc: dc}
|
||||
}
|
||||
|
||||
func (p *BalancerPricer) GetSpotPrice(poolData *PoolData) (*UniversalDecimal, error) {
|
||||
if len(poolData.Weights) < 2 {
|
||||
return nil, fmt.Errorf("insufficient weights for Balancer pool")
|
||||
}
|
||||
|
||||
// Balancer spot price = (reserveOut/weightOut) / (reserveIn/weightIn)
|
||||
reserveOutWeighted, err := p.dc.Divide(poolData.Reserve1, poolData.Weights[1], 18, "TEMP")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
reserveInWeighted, err := p.dc.Divide(poolData.Reserve0, poolData.Weights[0], 18, "TEMP")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return p.dc.Divide(reserveOutWeighted, reserveInWeighted, 18, fmt.Sprintf("%s/%s", poolData.Token1.Symbol, poolData.Token0.Symbol))
|
||||
}
|
||||
|
||||
func (p *BalancerPricer) CalculateAmountOut(amountIn *UniversalDecimal, poolData *PoolData) (*UniversalDecimal, error) {
|
||||
// Simplified Balancer calculation - production needs full weighted math
|
||||
return p.GetSpotPrice(poolData)
|
||||
}
|
||||
|
||||
func (p *BalancerPricer) CalculateAmountIn(amountOut *UniversalDecimal, poolData *PoolData) (*UniversalDecimal, error) {
|
||||
spotPrice, err := p.GetSpotPrice(poolData)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return p.dc.Divide(amountOut, spotPrice, poolData.Token0.Decimals, poolData.Token0.Symbol)
|
||||
}
|
||||
|
||||
func (p *BalancerPricer) CalculatePriceImpact(amountIn *UniversalDecimal, poolData *PoolData) (*UniversalDecimal, error) {
|
||||
// Placeholder - would implement Balancer-specific price impact
|
||||
return NewUniversalDecimal(big.NewInt(0), 4, "PERCENT")
|
||||
}
|
||||
|
||||
func (p *BalancerPricer) GetMinimumLiquidity(poolData *PoolData) (*UniversalDecimal, error) {
|
||||
return NewUniversalDecimal(big.NewInt(0), 18, "LIQUIDITY")
|
||||
}
|
||||
|
||||
func (p *BalancerPricer) ValidatePoolData(poolData *PoolData) error {
|
||||
if len(poolData.Weights) < 2 {
|
||||
return fmt.Errorf("Balancer pool missing weights")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Placeholder implementations for other exchanges
|
||||
func NewTraderJoePricer(dc *DecimalConverter) *UniswapV2Pricer { return NewUniswapV2Pricer(dc) }
|
||||
func NewRamsesPricer(dc *DecimalConverter) *UniswapV3Pricer { return NewUniswapV3Pricer(dc) }
|
||||
|
||||
// CurvePricer - Stable swap implementation
|
||||
type CurvePricer struct {
|
||||
dc *DecimalConverter
|
||||
}
|
||||
|
||||
func NewCurvePricer(dc *DecimalConverter) *CurvePricer {
|
||||
return &CurvePricer{dc: dc}
|
||||
}
|
||||
|
||||
func (p *CurvePricer) GetSpotPrice(poolData *PoolData) (*UniversalDecimal, error) {
|
||||
// Curve stable swap pricing - simplified version
|
||||
if poolData.A == nil {
|
||||
return nil, fmt.Errorf("missing amplification coefficient for Curve pool")
|
||||
}
|
||||
|
||||
// For stable swaps, price should be close to 1:1
|
||||
return NewUniversalDecimal(p.dc.getScalingFactor(18), 18, fmt.Sprintf("%s/%s", poolData.Token1.Symbol, poolData.Token0.Symbol))
|
||||
}
|
||||
|
||||
func (p *CurvePricer) CalculateAmountOut(amountIn *UniversalDecimal, poolData *PoolData) (*UniversalDecimal, error) {
|
||||
// Simplified stable swap calculation
|
||||
// Real implementation would use Newton's method for stable swap invariant
|
||||
feeAmount, err := p.dc.Multiply(amountIn, poolData.Fee, amountIn.Decimals, "FEE")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return p.dc.Subtract(amountIn, feeAmount)
|
||||
}
|
||||
|
||||
func (p *CurvePricer) CalculateAmountIn(amountOut *UniversalDecimal, poolData *PoolData) (*UniversalDecimal, error) {
|
||||
// Reverse stable swap calculation
|
||||
feeMultiplier, _ := p.dc.FromString("1", 18, "FEE_MULT")
|
||||
oneMinusFee, err := p.dc.Subtract(feeMultiplier, poolData.Fee)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return p.dc.Divide(amountOut, oneMinusFee, poolData.Token0.Decimals, poolData.Token0.Symbol)
|
||||
}
|
||||
|
||||
func (p *CurvePricer) CalculatePriceImpact(amountIn *UniversalDecimal, poolData *PoolData) (*UniversalDecimal, error) {
|
||||
// Curve pools have minimal price impact for stable pairs
|
||||
return NewUniversalDecimal(big.NewInt(1000), 4, "PERCENT") // 0.1%
|
||||
}
|
||||
|
||||
func (p *CurvePricer) GetMinimumLiquidity(poolData *PoolData) (*UniversalDecimal, error) {
|
||||
return NewUniversalDecimal(big.NewInt(0), 18, "LIQUIDITY")
|
||||
}
|
||||
|
||||
func (p *CurvePricer) ValidatePoolData(poolData *PoolData) error {
|
||||
if poolData.A == nil {
|
||||
return fmt.Errorf("Curve pool missing amplification coefficient")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
62
pkg/math/mock_gas_estimator.go
Normal file
62
pkg/math/mock_gas_estimator.go
Normal file
@@ -0,0 +1,62 @@
|
||||
package math
|
||||
|
||||
import (
|
||||
"math/big"
|
||||
)
|
||||
|
||||
// MockGasEstimator implements GasEstimator for testing purposes
|
||||
type MockGasEstimator struct {
|
||||
currentGasPrice *UniversalDecimal
|
||||
}
|
||||
|
||||
// NewMockGasEstimator creates a new mock gas estimator
|
||||
func NewMockGasEstimator() *MockGasEstimator {
|
||||
dc := NewDecimalConverter()
|
||||
gasPrice, _ := dc.FromString("20", 9, "GWEI") // 20 gwei default
|
||||
|
||||
return &MockGasEstimator{
|
||||
currentGasPrice: gasPrice,
|
||||
}
|
||||
}
|
||||
|
||||
// EstimateSwapGas estimates gas for a swap
|
||||
func (mge *MockGasEstimator) EstimateSwapGas(exchange string, poolData *PoolData) (uint64, error) {
|
||||
// Different exchanges have different gas costs
|
||||
switch exchange {
|
||||
case "uniswap_v3":
|
||||
return 150000, nil // 150k gas for Uniswap V3
|
||||
case "uniswap_v2", "sushiswap":
|
||||
return 120000, nil // 120k gas for Uniswap V2/SushiSwap
|
||||
case "camelot":
|
||||
return 130000, nil // 130k gas for Camelot
|
||||
case "balancer":
|
||||
return 200000, nil // 200k gas for Balancer
|
||||
case "curve":
|
||||
return 180000, nil // 180k gas for Curve
|
||||
default:
|
||||
return 150000, nil // Default to 150k gas
|
||||
}
|
||||
}
|
||||
|
||||
// EstimateFlashSwapGas estimates gas for a flash swap
|
||||
func (mge *MockGasEstimator) EstimateFlashSwapGas(route []*PoolData) (uint64, error) {
|
||||
// Flash swap overhead varies by complexity
|
||||
baseGas := uint64(200000) // Base flash swap overhead
|
||||
|
||||
// Add gas for each hop
|
||||
hopGas := uint64(len(route)) * 50000
|
||||
|
||||
return baseGas + hopGas, nil
|
||||
}
|
||||
|
||||
// GetCurrentGasPrice gets the current gas price
|
||||
func (mge *MockGasEstimator) GetCurrentGasPrice() (*UniversalDecimal, error) {
|
||||
return mge.currentGasPrice, nil
|
||||
}
|
||||
|
||||
// SetCurrentGasPrice sets the current gas price for testing
|
||||
func (mge *MockGasEstimator) SetCurrentGasPrice(gasPriceGwei int64) {
|
||||
dc := NewDecimalConverter()
|
||||
gasPrice, _ := dc.FromString(big.NewInt(gasPriceGwei).String(), 9, "GWEI")
|
||||
mge.currentGasPrice = gasPrice
|
||||
}
|
||||
289
pkg/math/precision_test.go
Normal file
289
pkg/math/precision_test.go
Normal file
@@ -0,0 +1,289 @@
|
||||
package math
|
||||
|
||||
import (
|
||||
"math/big"
|
||||
"math/rand"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// TestDecimalPrecisionPreservation tests that decimal operations preserve precision
|
||||
func TestDecimalPrecisionPreservation(t *testing.T) {
|
||||
dc := NewDecimalConverter()
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
value string
|
||||
decimals uint8
|
||||
symbol string
|
||||
}{
|
||||
{"ETH precision", "1000000000000000000", 18, "ETH"}, // 1 ETH
|
||||
{"USDC precision", "1000000", 6, "USDC"}, // 1 USDC
|
||||
{"WBTC precision", "100000000", 8, "WBTC"}, // 1 WBTC
|
||||
{"Small amount", "1", 18, "ETH"}, // 1 wei
|
||||
{"Large amount", "1000000000000000000000", 18, "ETH"}, // 1000 ETH
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
// Create decimal from string
|
||||
decimal, err := dc.FromString(tc.value, tc.decimals, tc.symbol)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create decimal: %v", err)
|
||||
}
|
||||
|
||||
// Convert to string and back
|
||||
humanReadable := dc.ToHumanReadable(decimal)
|
||||
backToDecimal, err := dc.FromString(humanReadable, tc.decimals, tc.symbol)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to convert back from string: %v", err)
|
||||
}
|
||||
|
||||
// Compare values
|
||||
if decimal.Value.Cmp(backToDecimal.Value) != 0 {
|
||||
t.Errorf("Precision lost in round-trip conversion")
|
||||
t.Errorf("Original: %s", decimal.Value.String())
|
||||
t.Errorf("Round-trip: %s", backToDecimal.Value.String())
|
||||
t.Errorf("Human readable: %s", humanReadable)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestArithmeticOperations tests basic arithmetic with different decimal precisions
|
||||
func TestArithmeticOperations(t *testing.T) {
|
||||
dc := NewDecimalConverter()
|
||||
|
||||
// Create test values with different precisions
|
||||
eth1, _ := dc.FromString("1000000000000000000", 18, "ETH") // 1 ETH
|
||||
eth2, _ := dc.FromString("2000000000000000000", 18, "ETH") // 2 ETH
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
op string
|
||||
a, b *UniversalDecimal
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "ETH addition",
|
||||
op: "add",
|
||||
a: eth1,
|
||||
b: eth2,
|
||||
expected: "3000000000000000000", // 3 ETH
|
||||
},
|
||||
{
|
||||
name: "ETH subtraction",
|
||||
op: "sub",
|
||||
a: eth2,
|
||||
b: eth1,
|
||||
expected: "1000000000000000000", // 1 ETH
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
var result *UniversalDecimal
|
||||
var err error
|
||||
|
||||
switch test.op {
|
||||
case "add":
|
||||
result, err = dc.Add(test.a, test.b)
|
||||
case "sub":
|
||||
result, err = dc.Subtract(test.a, test.b)
|
||||
default:
|
||||
t.Fatalf("Unknown operation: %s", test.op)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("Operation failed: %v", err)
|
||||
}
|
||||
|
||||
if result.Value.String() != test.expected {
|
||||
t.Errorf("Expected %s, got %s", test.expected, result.Value.String())
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestPercentageCalculations tests percentage calculations for precision
|
||||
func TestPercentageCalculations(t *testing.T) {
|
||||
dc := NewDecimalConverter()
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
numerator string
|
||||
denominator string
|
||||
decimals uint8
|
||||
expectedRange [2]float64 // [min, max] acceptable range
|
||||
}{
|
||||
{
|
||||
name: "1% calculation",
|
||||
numerator: "10000000000000000", // 0.01 ETH
|
||||
denominator: "1000000000000000000", // 1 ETH
|
||||
decimals: 18,
|
||||
expectedRange: [2]float64{0.99, 1.01}, // 1% ± 0.01%
|
||||
},
|
||||
{
|
||||
name: "50% calculation",
|
||||
numerator: "500000000000000000", // 0.5 ETH
|
||||
denominator: "1000000000000000000", // 1 ETH
|
||||
decimals: 18,
|
||||
expectedRange: [2]float64{49.9, 50.1}, // 50% ± 0.1%
|
||||
},
|
||||
{
|
||||
name: "Small percentage",
|
||||
numerator: "1000000000000000", // 0.001 ETH
|
||||
denominator: "1000000000000000000", // 1 ETH
|
||||
decimals: 18,
|
||||
expectedRange: [2]float64{0.099, 0.101}, // 0.1% ± 0.001%
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
num, err := dc.FromString(tc.numerator, tc.decimals, "ETH")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create numerator: %v", err)
|
||||
}
|
||||
|
||||
denom, err := dc.FromString(tc.denominator, tc.decimals, "ETH")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create denominator: %v", err)
|
||||
}
|
||||
|
||||
percentage, err := dc.CalculatePercentage(num, denom)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to calculate percentage: %v", err)
|
||||
}
|
||||
|
||||
percentageFloat, _ := percentage.Value.Float64()
|
||||
|
||||
t.Logf("Calculated percentage: %.6f%%", percentageFloat)
|
||||
|
||||
if percentageFloat < tc.expectedRange[0] || percentageFloat > tc.expectedRange[1] {
|
||||
t.Errorf("Percentage %.6f%% outside expected range [%.3f%%, %.3f%%]",
|
||||
percentageFloat, tc.expectedRange[0], tc.expectedRange[1])
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// PropertyTest tests mathematical properties like commutativity, associativity
|
||||
func TestMathematicalProperties(t *testing.T) {
|
||||
dc := NewDecimalConverter()
|
||||
rand.Seed(time.Now().UnixNano())
|
||||
|
||||
// Generate random test values
|
||||
for i := 0; i < 100; i++ {
|
||||
// Generate random big integers
|
||||
val1 := big.NewInt(rand.Int63n(1000000000000000000)) // Up to 1 ETH
|
||||
val2 := big.NewInt(rand.Int63n(1000000000000000000))
|
||||
val3 := big.NewInt(rand.Int63n(1000000000000000000))
|
||||
|
||||
a, _ := NewUniversalDecimal(val1, 18, "ETH")
|
||||
b, _ := NewUniversalDecimal(val2, 18, "ETH")
|
||||
c, _ := NewUniversalDecimal(val3, 18, "ETH")
|
||||
|
||||
// Test commutativity: a + b = b + a
|
||||
ab, err1 := dc.Add(a, b)
|
||||
ba, err2 := dc.Add(b, a)
|
||||
|
||||
if err1 != nil || err2 != nil {
|
||||
t.Fatalf("Addition failed: %v, %v", err1, err2)
|
||||
}
|
||||
|
||||
if ab.Value.Cmp(ba.Value) != 0 {
|
||||
t.Errorf("Addition not commutative: %s + %s = %s, %s + %s = %s",
|
||||
a.Value.String(), b.Value.String(), ab.Value.String(),
|
||||
b.Value.String(), a.Value.String(), ba.Value.String())
|
||||
}
|
||||
|
||||
// Test associativity: (a + b) + c = a + (b + c)
|
||||
ab_c, err1 := dc.Add(ab, c)
|
||||
bc, err2 := dc.Add(b, c)
|
||||
a_bc, err3 := dc.Add(a, bc)
|
||||
|
||||
if err1 != nil || err2 != nil || err3 != nil {
|
||||
continue // Skip this iteration if any operation fails
|
||||
}
|
||||
|
||||
if ab_c.Value.Cmp(a_bc.Value) != 0 {
|
||||
t.Errorf("Addition not associative")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkDecimalOperations benchmarks decimal operations
|
||||
func BenchmarkDecimalOperations(b *testing.B) {
|
||||
dc := NewDecimalConverter()
|
||||
val1, _ := dc.FromString("1000000000000000000", 18, "ETH")
|
||||
val2, _ := dc.FromString("2000000000000000000", 18, "ETH")
|
||||
|
||||
b.Run("Addition", func(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
dc.Add(val1, val2)
|
||||
}
|
||||
})
|
||||
|
||||
b.Run("Subtraction", func(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
dc.Subtract(val2, val1)
|
||||
}
|
||||
})
|
||||
|
||||
b.Run("Percentage", func(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
dc.CalculatePercentage(val1, val2)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// FuzzDecimalOperations fuzzes decimal operations for edge cases
|
||||
func FuzzDecimalOperations(f *testing.F) {
|
||||
// Seed with known values
|
||||
f.Add(int64(1000000000000000000), int64(2000000000000000000)) // 1 ETH, 2 ETH
|
||||
f.Add(int64(1), int64(1000000000000000000)) // 1 wei, 1 ETH
|
||||
f.Add(int64(0), int64(1000000000000000000)) // 0, 1 ETH
|
||||
|
||||
f.Fuzz(func(t *testing.T, val1, val2 int64) {
|
||||
// Ensure positive values
|
||||
if val1 < 0 {
|
||||
val1 = -val1
|
||||
}
|
||||
if val2 <= 0 {
|
||||
return // Skip zero/negative denominators
|
||||
}
|
||||
|
||||
dc := NewDecimalConverter()
|
||||
|
||||
a, err := NewUniversalDecimal(big.NewInt(val1), 18, "ETH")
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
b, err := NewUniversalDecimal(big.NewInt(val2), 18, "ETH")
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Test addition doesn't panic
|
||||
_, err = dc.Add(a, b)
|
||||
if err != nil {
|
||||
t.Errorf("Addition failed: %v", err)
|
||||
}
|
||||
|
||||
// Test subtraction doesn't panic (if a >= b)
|
||||
if val1 >= val2 {
|
||||
_, err = dc.Subtract(a, b)
|
||||
if err != nil {
|
||||
t.Errorf("Subtraction failed: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Test percentage calculation doesn't panic
|
||||
_, err = dc.CalculatePercentage(a, b)
|
||||
if err != nil {
|
||||
t.Errorf("Percentage calculation failed: %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
Reference in New Issue
Block a user