package audit

import (
	"fmt"
	"math"
	"math/big"
	"strings"
	"time"

	mmath "github.com/fraktal/mev-beta/pkg/math"
	"github.com/fraktal/mev-beta/tools/math-audit/internal/models"
)

// TestResult captures the outcome of a single assertion.
//
// The same type is used both for per-vector assertions and for the property
// checks handed to Runner.Run; the JSON tags define how it appears in the
// audit report.
type TestResult struct {
	// Name and Type are copied from the test case that produced this result.
	Name string `json:"name"`
	Type string `json:"type"`

	// Passed reports whether the assertion succeeded.
	Passed bool `json:"passed"`

	// DeltaBPS is the difference between actual and expected relative to the
	// expected value, in basis points (1e4 * |actual-expected| / expected).
	// It stays at zero when the assertion failed before any comparison was
	// made (e.g. a parse or pricer error).
	// NOTE(review): encoding/json cannot encode NaN or ±Inf, so this must be
	// finite for the report to marshal.
	DeltaBPS float64 `json:"delta_bps"`

	// Expected and Actual are human-readable renderings of the compared
	// values; Actual is rescaled to the expected value's decimals first.
	Expected string `json:"expected"`
	Actual   string `json:"actual"`

	// Details explains why the assertion failed; it is empty on success.
	Details string `json:"details,omitempty"`

	// Annotations carries extra notes, such as the tolerance that was applied.
	Annotations []string `json:"annotations,omitempty"`
}

// VectorResult summarises the results for a single pool vector.
type VectorResult struct {
	// Name, Description and Exchange are copied from the source vector.
	Name        string `json:"name"`
	Description string `json:"description"`
	Exchange    string `json:"exchange"`

	// Passed is false when the pool could not be built, no pricer could be
	// resolved for its exchange type, or any assertion in Tests failed.
	Passed bool `json:"passed"`

	// Tests holds one result per test case of the vector, in input order.
	Tests []TestResult `json:"tests"`

	// Errors records setup failures that prevented the tests from running.
	Errors []string `json:"errors,omitempty"`
}

// Summary aggregates overall audit statistics.
type Summary struct {
	// GeneratedAt is set by Run to the current time in UTC.
	GeneratedAt time.Time `json:"generated_at"`

	// TotalVectors counts evaluated vectors; VectorsPassed counts those whose
	// setup succeeded and whose assertions all passed.
	TotalVectors  int `json:"total_vectors"`
	VectorsPassed int `json:"vectors_passed"`

	// TotalAssertions and AssertionsPassed count individual test results
	// across all vectors.
	TotalAssertions  int `json:"total_assertions"`
	AssertionsPassed int `json:"assertions_passed"`

	// PropertyChecks counts the property checks supplied to Run;
	// PropertySucceeded counts how many of them are marked as passed.
	PropertyChecks    int `json:"property_checks"`
	PropertySucceeded int `json:"property_succeeded"`
}

// Result is the top-level audit payload.
//
// Vectors holds one entry per input vector in input order; PropertyChecks
// are the property-check results supplied to Run.
type Result struct {
	Summary        Summary        `json:"summary"`
	Vectors        []VectorResult `json:"vectors"`
	PropertyChecks []TestResult   `json:"property_checks"`
}

// Runner executes vector assertions using the math pricing engine.
type Runner struct {
	dc     *mmath.DecimalConverter      // rescales and renders decimal values
	engine *mmath.ExchangePricingEngine // resolves a pricer per exchange type
}

// NewRunner returns a Runner backed by a fresh DecimalConverter and
// ExchangePricingEngine.
func NewRunner() *Runner {
	return &Runner{
		dc:     mmath.NewDecimalConverter(),
		engine: mmath.NewExchangePricingEngine(),
	}
}

// Run executes the provided vectors and property checks.
func (r *Runner) Run(vectors []models.Vector, propertyChecks []TestResult) Result { var ( vectorResults []VectorResult totalAssertions int assertionsPassed int vectorsPassed int ) for _, vec := range vectors { vr := r.evaluateVector(vec) vectorResults = append(vectorResults, vr) allPassed := vr.Passed for _, tr := range vr.Tests { totalAssertions++ if tr.Passed { assertionsPassed++ } else { allPassed = false } } if allPassed { vectorsPassed++ } } propPassed := 0 for _, check := range propertyChecks { if check.Passed { propPassed++ } } summary := Summary{ GeneratedAt: time.Now().UTC(), TotalVectors: len(vectorResults), VectorsPassed: vectorsPassed, TotalAssertions: totalAssertions, AssertionsPassed: assertionsPassed, PropertyChecks: len(propertyChecks), PropertySucceeded: propPassed, } return Result{ Summary: summary, Vectors: vectorResults, PropertyChecks: propertyChecks, } } func (r *Runner) evaluateVector(vec models.Vector) VectorResult { poolData, err := r.buildPool(vec.Pool) if err != nil { return VectorResult{ Name: vec.Name, Description: vec.Description, Exchange: vec.Pool.Exchange, Passed: false, Errors: []string{fmt.Sprintf("build pool: %v", err)}, } } pricer, err := r.engine.GetExchangePricer(poolData.ExchangeType) if err != nil { return VectorResult{ Name: vec.Name, Description: vec.Description, Exchange: vec.Pool.Exchange, Passed: false, Errors: []string{fmt.Sprintf("get pricer: %v", err)}, } } vr := VectorResult{ Name: vec.Name, Description: vec.Description, Exchange: vec.Pool.Exchange, Passed: true, } for _, test := range vec.Tests { tr := r.executeTest(test, pricer, poolData) vr.Tests = append(vr.Tests, tr) if !tr.Passed { vr.Passed = false } } return vr } func (r *Runner) executeTest(test models.TestCase, pricer mmath.ExchangePricer, pool *mmath.PoolData) TestResult { result := TestResult{Name: test.Name, Type: test.Type} expected, err := r.toUniversalDecimal(test.Expected) if err != nil { result.Passed = false result.Details = fmt.Sprintf("parse 
expected: %v", err) return result } result.Expected = r.dc.ToHumanReadable(expected) tolerance := test.ToleranceBPS if tolerance <= 0 { tolerance = 1 // default tolerance of 1 bp } switch strings.ToLower(test.Type) { case "spot_price": actual, err := pricer.GetSpotPrice(pool) if err != nil { return failure(result, fmt.Sprintf("spot price: %v", err)) } return r.compareDecimals(result, expected, actual, tolerance) case "amount_out": if test.AmountIn == nil { return failure(result, "amount_in required for amount_out test") } amountIn, err := r.toUniversalDecimal(*test.AmountIn) if err != nil { return failure(result, fmt.Sprintf("parse amount_in: %v", err)) } actual, err := pricer.CalculateAmountOut(amountIn, pool) if err != nil { return failure(result, fmt.Sprintf("calculate amount_out: %v", err)) } return r.compareDecimals(result, expected, actual, tolerance) case "amount_in": if test.AmountOut == nil { return failure(result, "amount_out required for amount_in test") } amountOut, err := r.toUniversalDecimal(*test.AmountOut) if err != nil { return failure(result, fmt.Sprintf("parse amount_out: %v", err)) } actual, err := pricer.CalculateAmountIn(amountOut, pool) if err != nil { return failure(result, fmt.Sprintf("calculate amount_in: %v", err)) } return r.compareDecimals(result, expected, actual, tolerance) default: return failure(result, fmt.Sprintf("unsupported test type %q", test.Type)) } } func (r *Runner) compareDecimals(result TestResult, expected, actual *mmath.UniversalDecimal, tolerance float64) TestResult { convertedActual, err := r.dc.ConvertTo(actual, expected.Decimals, expected.Symbol) if err != nil { return failure(result, fmt.Sprintf("rescale actual: %v", err)) } diff := new(big.Int).Sub(convertedActual.Value, expected.Value) absDiff := new(big.Int).Abs(diff) deltaBPS := math.Inf(1) if expected.Value.Sign() == 0 { if convertedActual.Value.Sign() == 0 { deltaBPS = 0 } } else { // delta_bps = |actual - expected| / expected * 1e4 numerator := 
new(big.Float).SetInt(absDiff) denominator := new(big.Float).SetInt(expected.Value) if denominator.Cmp(big.NewFloat(0)) != 0 { ratio := new(big.Float).Quo(numerator, denominator) bps := new(big.Float).Mul(ratio, big.NewFloat(10000)) val, _ := bps.Float64() deltaBPS = val } } result.DeltaBPS = deltaBPS result.Expected = r.dc.ToHumanReadable(expected) result.Actual = r.dc.ToHumanReadable(convertedActual) result.Passed = deltaBPS <= tolerance result.Annotations = append(result.Annotations, fmt.Sprintf("tolerance %.4f bps", tolerance)) if !result.Passed { result.Details = fmt.Sprintf("delta %.4f bps exceeds tolerance %.4f", deltaBPS, tolerance) } return result } func failure(result TestResult, msg string) TestResult { result.Passed = false result.Details = msg return result } func (r *Runner) toUniversalDecimal(dec models.DecimalValue) (*mmath.UniversalDecimal, error) { if err := dec.Validate(); err != nil { return nil, err } value, ok := new(big.Int).SetString(dec.Value, 10) if !ok { return nil, fmt.Errorf("invalid integer %s", dec.Value) } return mmath.NewUniversalDecimal(value, dec.Decimals, dec.Symbol) } func (r *Runner) buildPool(pool models.Pool) (*mmath.PoolData, error) { reserve0, err := r.toUniversalDecimal(pool.Reserve0) if err != nil { return nil, fmt.Errorf("reserve0: %w", err) } reserve1, err := r.toUniversalDecimal(pool.Reserve1) if err != nil { return nil, fmt.Errorf("reserve1: %w", err) } pd := &mmath.PoolData{ Address: pool.Address, ExchangeType: mmath.ExchangeType(pool.Exchange), Token0: mmath.TokenInfo{ Address: pool.Token0.Address, Symbol: pool.Token0.Symbol, Decimals: pool.Token0.Decimals, }, Token1: mmath.TokenInfo{ Address: pool.Token1.Address, Symbol: pool.Token1.Symbol, Decimals: pool.Token1.Decimals, }, Reserve0: reserve0, Reserve1: reserve1, } if pool.Fee != nil { fee, err := r.toUniversalDecimal(*pool.Fee) if err != nil { return nil, fmt.Errorf("fee: %w", err) } pd.Fee = fee } if pool.SqrtPriceX96 != "" { val, ok := 
new(big.Int).SetString(pool.SqrtPriceX96, 10) if !ok { return nil, fmt.Errorf("sqrt_price_x96 invalid") } pd.SqrtPriceX96 = val } if pool.Tick != "" { val, ok := new(big.Int).SetString(pool.Tick, 10) if !ok { return nil, fmt.Errorf("tick invalid") } pd.Tick = val } if pool.Liquidity != "" { val, ok := new(big.Int).SetString(pool.Liquidity, 10) if !ok { return nil, fmt.Errorf("liquidity invalid") } pd.Liquidity = val } if pool.Amplification != "" { val, ok := new(big.Int).SetString(pool.Amplification, 10) if !ok { return nil, fmt.Errorf("amplification invalid") } pd.A = val } if len(pool.Weights) > 0 { for _, w := range pool.Weights { ud, err := r.toUniversalDecimal(w) if err != nil { return nil, fmt.Errorf("weight: %w", err) } pd.Weights = append(pd.Weights, ud) } } return pd, nil }