Files
mev-beta/tools/math-audit/internal/audit/runner.go
Krypto Kajun 850223a953 fix(multicall): resolve critical multicall parsing corruption issues
- Added comprehensive bounds checking to prevent buffer overruns in multicall parsing
- Implemented graduated validation system (Strict/Moderate/Permissive) to reduce false positives
- Added LRU caching system for address validation with 10-minute TTL
- Enhanced ABI decoder with missing Universal Router and Arbitrum-specific DEX signatures
- Fixed duplicate function declarations and import conflicts across multiple files
- Added error recovery mechanisms with multiple fallback strategies
- Updated tests to handle new validation behavior for suspicious addresses
- Fixed parser test expectations for improved validation system
- Applied gofmt formatting fixes to ensure code style compliance
- Fixed mutex copying issues in monitoring package by introducing MetricsSnapshot
- Resolved critical security vulnerabilities in heuristic address extraction
- Progress: Updated TODO audit from 10% to 35% complete

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
2025-10-17 00:12:55 -05:00

350 lines
9.1 KiB
Go

package audit
import (
"fmt"
"math"
"math/big"
"strings"
"time"
mmath "github.com/fraktal/mev-beta/pkg/math"
"github.com/fraktal/mev-beta/tools/math-audit/internal/models"
)
// TestResult captures the outcome of a single assertion: one test case of a
// vector (built by executeTest/compareDecimals) or a property check result
// handed to Run by the caller.
type TestResult struct {
	// Name identifies the assertion; for vector tests it is the test case name.
	Name string `json:"name"`
	// Type is the test kind as written in the vector; executeTest matches it
	// case-insensitively against "spot_price", "amount_out" and "amount_in".
	Type string `json:"type"`
	// Passed reports whether the assertion held.
	Passed bool `json:"passed"`
	// DeltaBPS is the relative deviation of actual from expected in basis
	// points, set by compareDecimals. It stays zero when the test failed before
	// a comparison could be made (parse, pricer or rescale errors).
	DeltaBPS float64 `json:"delta_bps"`
	// Expected is the expected value rendered via DecimalConverter.ToHumanReadable.
	Expected string `json:"expected"`
	// Actual is the pricer's value rescaled to Expected's decimals and symbol,
	// rendered via DecimalConverter.ToHumanReadable.
	Actual string `json:"actual"`
	// Details explains a failure (parse error, pricer error or tolerance breach).
	Details string `json:"details,omitempty"`
	// Annotations carries extra notes, such as the tolerance that was applied.
	Annotations []string `json:"annotations,omitempty"`
}
// VectorResult summarises the results for a single pool vector.
// evaluateVector fills Errors (leaving Tests empty) when the pool cannot be
// built or no pricer is available for the pool's exchange type.
type VectorResult struct {
	// Name is the vector's name.
	Name string `json:"name"`
	// Description is copied from the vector definition.
	Description string `json:"description"`
	// Exchange is the pool's exchange identifier as written in the vector.
	Exchange string `json:"exchange"`
	// Passed is true only when setup succeeded and every test case passed
	// (a vector with no test cases therefore passes).
	Passed bool `json:"passed"`
	// Tests holds one entry per test case, in vector order.
	Tests []TestResult `json:"tests"`
	// Errors lists setup failures that prevented the tests from running.
	Errors []string `json:"errors,omitempty"`
}
// Summary aggregates overall audit statistics computed by Runner.Run.
type Summary struct {
	// GeneratedAt is the UTC time at which Run assembled the summary.
	GeneratedAt time.Time `json:"generated_at"`
	// TotalVectors is the number of vectors evaluated.
	TotalVectors int `json:"total_vectors"`
	// VectorsPassed counts vectors that passed with no failing assertion.
	VectorsPassed int `json:"vectors_passed"`
	// TotalAssertions is the number of vector test cases executed.
	TotalAssertions int `json:"total_assertions"`
	// AssertionsPassed counts vector test cases that passed.
	AssertionsPassed int `json:"assertions_passed"`
	// PropertyChecks is the number of property check results supplied to Run.
	PropertyChecks int `json:"property_checks"`
	// PropertySucceeded counts property checks whose Passed flag is set.
	PropertySucceeded int `json:"property_succeeded"`
}
// Result is the top-level audit payload returned by Runner.Run.
type Result struct {
	// Summary holds the aggregate counts.
	Summary Summary `json:"summary"`
	// Vectors holds one entry per input vector, in input order.
	Vectors []VectorResult `json:"vectors"`
	// PropertyChecks echoes the property check results that were passed to Run.
	PropertyChecks []TestResult `json:"property_checks"`
}
// Runner executes vector assertions using the math pricing engine.
// Construct it with NewRunner; both dependencies are pointers and are nil in
// the zero value.
type Runner struct {
	// dc rescales (ConvertTo) and formats (ToHumanReadable) decimal values.
	dc *mmath.DecimalConverter
	// engine resolves the pricer for each pool's exchange type.
	engine *mmath.ExchangePricingEngine
}
// NewRunner returns a Runner wired to a freshly constructed decimal converter
// and a freshly constructed exchange pricing engine.
func NewRunner() *Runner {
	runner := new(Runner)
	runner.dc = mmath.NewDecimalConverter()
	runner.engine = mmath.NewExchangePricingEngine()
	return runner
}
// Run evaluates every vector with the pricing engine, folds in the
// caller-supplied property check results, and returns the aggregated report.
//
// A vector counts towards VectorsPassed only when its own Passed flag is set
// and none of its individual assertions failed. propertyChecks is returned
// unchanged in Result.PropertyChecks.
func (r *Runner) Run(vectors []models.Vector, propertyChecks []TestResult) Result {
	var results []VectorResult
	assertions, assertionsOK, vectorsOK := 0, 0, 0

	for _, v := range vectors {
		res := r.evaluateVector(v)
		results = append(results, res)

		failures := 0
		for _, t := range res.Tests {
			if t.Passed {
				assertionsOK++
				continue
			}
			failures++
		}
		assertions += len(res.Tests)

		if res.Passed && failures == 0 {
			vectorsOK++
		}
	}

	propsOK := 0
	for i := range propertyChecks {
		if propertyChecks[i].Passed {
			propsOK++
		}
	}

	return Result{
		Summary: Summary{
			GeneratedAt:       time.Now().UTC(),
			TotalVectors:      len(results),
			VectorsPassed:     vectorsOK,
			TotalAssertions:   assertions,
			AssertionsPassed:  assertionsOK,
			PropertyChecks:    len(propertyChecks),
			PropertySucceeded: propsOK,
		},
		Vectors:        results,
		PropertyChecks: propertyChecks,
	}
}
// evaluateVector builds the vector's pool, resolves a pricer for its exchange
// type, and runs every test case against them.
//
// Setup failures are reported in Errors with Passed left false and no tests
// run. Otherwise Passed starts true and is cleared by the first failing test.
func (r *Runner) evaluateVector(vec models.Vector) VectorResult {
	out := VectorResult{
		Name:        vec.Name,
		Description: vec.Description,
		Exchange:    vec.Pool.Exchange,
	}

	poolData, err := r.buildPool(vec.Pool)
	if err != nil {
		out.Errors = []string{fmt.Sprintf("build pool: %v", err)}
		return out
	}

	pricer, err := r.engine.GetExchangePricer(poolData.ExchangeType)
	if err != nil {
		out.Errors = []string{fmt.Sprintf("get pricer: %v", err)}
		return out
	}

	out.Passed = true
	for _, tc := range vec.Tests {
		tr := r.executeTest(tc, pricer, poolData)
		out.Passed = out.Passed && tr.Passed
		out.Tests = append(out.Tests, tr)
	}
	return out
}
// executeTest runs one assertion from a vector against pricer and pool.
//
// The expected value is parsed first. Depending on the case-insensitive test
// type, the pricer is then asked for one of three values:
//   - a spot price ("spot_price");
//   - the output amount for test.AmountIn ("amount_out");
//   - the input amount for test.AmountOut ("amount_in").
//
// The value is checked against the expectation with a tolerance of
// test.ToleranceBPS basis points. A configured tolerance of zero or less
// falls back to 1 bp. Any parse or pricer error yields a failed TestResult
// whose Details carries the reason.
func (r *Runner) executeTest(test models.TestCase, pricer mmath.ExchangePricer, pool *mmath.PoolData) TestResult {
	res := TestResult{Name: test.Name, Type: test.Type}

	want, err := r.toUniversalDecimal(test.Expected)
	if err != nil {
		return failure(res, fmt.Sprintf("parse expected: %v", err))
	}
	res.Expected = r.dc.ToHumanReadable(want)

	tol := test.ToleranceBPS
	if tol <= 0 {
		tol = 1 // default tolerance of 1 bp
	}

	var got *mmath.UniversalDecimal
	switch strings.ToLower(test.Type) {
	case "spot_price":
		if got, err = pricer.GetSpotPrice(pool); err != nil {
			return failure(res, fmt.Sprintf("spot price: %v", err))
		}
	case "amount_out":
		if test.AmountIn == nil {
			return failure(res, "amount_in required for amount_out test")
		}
		amountIn, perr := r.toUniversalDecimal(*test.AmountIn)
		if perr != nil {
			return failure(res, fmt.Sprintf("parse amount_in: %v", perr))
		}
		if got, err = pricer.CalculateAmountOut(amountIn, pool); err != nil {
			return failure(res, fmt.Sprintf("calculate amount_out: %v", err))
		}
	case "amount_in":
		if test.AmountOut == nil {
			return failure(res, "amount_out required for amount_in test")
		}
		amountOut, perr := r.toUniversalDecimal(*test.AmountOut)
		if perr != nil {
			return failure(res, fmt.Sprintf("parse amount_out: %v", perr))
		}
		if got, err = pricer.CalculateAmountIn(amountOut, pool); err != nil {
			return failure(res, fmt.Sprintf("calculate amount_in: %v", err))
		}
	default:
		return failure(res, fmt.Sprintf("unsupported test type %q", test.Type))
	}

	return r.compareDecimals(res, want, got, tol)
}
// compareDecimals rescales actual to expected's decimals and symbol, measures
// the relative deviation in basis points, and records the outcome on result.
//
// The assertion passes when the delta is within tolerance (inclusive). A zero
// expected value only matches a zero actual value; any other actual yields an
// unbounded delta. A rescale error, or a nil actual, is reported as a failed
// assertion.
func (r *Runner) compareDecimals(result TestResult, expected, actual *mmath.UniversalDecimal, tolerance float64) TestResult {
	// Guard against a pricer returning (nil, nil) so a misbehaving pricer
	// yields a failed assertion rather than a possible nil dereference.
	if actual == nil {
		return failure(result, "rescale actual: pricer returned a nil value")
	}
	convertedActual, err := r.dc.ConvertTo(actual, expected.Decimals, expected.Symbol)
	if err != nil {
		return failure(result, fmt.Sprintf("rescale actual: %v", err))
	}

	deltaBPS := deltaInBasisPoints(expected.Value, convertedActual.Value)

	result.Expected = r.dc.ToHumanReadable(expected)
	result.Actual = r.dc.ToHumanReadable(convertedActual)
	result.Passed = deltaBPS <= tolerance
	result.Annotations = append(result.Annotations, fmt.Sprintf("tolerance %.4f bps", tolerance))

	// encoding/json refuses to marshal ±Inf, which would make the whole report
	// unserialisable. Store the largest finite float64 instead and flag it so
	// readers know the delta is unbounded.
	result.DeltaBPS = deltaBPS
	if math.IsInf(deltaBPS, 1) {
		result.DeltaBPS = math.MaxFloat64
		result.Annotations = append(result.Annotations, "delta unbounded (expected is zero or ratio overflowed float64)")
	}

	if !result.Passed {
		result.Details = fmt.Sprintf("delta %.4f bps exceeds tolerance %.4f", deltaBPS, tolerance)
	}
	return result
}

// deltaInBasisPoints returns |actual-expected| / |expected| * 10000, the
// relative deviation of actual from expected in basis points. Both arguments
// are raw integers at the same scale. When expected is zero the relative error
// is undefined: the result is 0 if actual is also zero and +Inf otherwise.
func deltaInBasisPoints(expected, actual *big.Int) float64 {
	absDiff := new(big.Int).Sub(actual, expected)
	absDiff.Abs(absDiff)

	if expected.Sign() == 0 {
		if absDiff.Sign() == 0 {
			return 0
		}
		return math.Inf(1)
	}

	// Divide by |expected|: a signed denominator would make the delta negative
	// for negative expectations, so such assertions would always pass.
	ratio := new(big.Float).Quo(
		new(big.Float).SetInt(absDiff),
		new(big.Float).SetInt(new(big.Int).Abs(expected)),
	)
	// Float64 rounds to the nearest float64 (overflow becomes +Inf). That is
	// precise enough for a tolerance comparison, so the accuracy is ignored.
	bps, _ := ratio.Mul(ratio, big.NewFloat(10000)).Float64()
	return bps
}
// failure marks result as failed with msg as the human-readable reason and
// returns the updated copy. result is received by value, so the caller's
// variable is not modified.
func failure(result TestResult, msg string) TestResult {
	result.Details, result.Passed = msg, false
	return result
}
// toUniversalDecimal validates a fixture decimal and converts it to an
// mmath.UniversalDecimal. dec.Value must be a base-10 integer string holding
// the raw (already scaled) amount; dec.Decimals and dec.Symbol are passed
// through to mmath.NewUniversalDecimal.
func (r *Runner) toUniversalDecimal(dec models.DecimalValue) (*mmath.UniversalDecimal, error) {
	err := dec.Validate()
	if err != nil {
		return nil, err
	}

	raw := new(big.Int)
	if _, ok := raw.SetString(dec.Value, 10); !ok {
		return nil, fmt.Errorf("invalid integer %s", dec.Value)
	}
	return mmath.NewUniversalDecimal(raw, dec.Decimals, dec.Symbol)
}
// buildPool converts a declarative pool fixture into the mmath.PoolData that
// the pricers consume.
//
// Reserve0 and Reserve1 are mandatory. Fee is set only when present. The
// string fields SqrtPriceX96, Tick, Liquidity and Amplification are set only
// when non-empty. Every weight is converted in order. Each error names the
// offending field, and for integer fields it quotes the raw input.
func (r *Runner) buildPool(pool models.Pool) (*mmath.PoolData, error) {
	reserve0, err := r.toUniversalDecimal(pool.Reserve0)
	if err != nil {
		return nil, fmt.Errorf("reserve0: %w", err)
	}
	reserve1, err := r.toUniversalDecimal(pool.Reserve1)
	if err != nil {
		return nil, fmt.Errorf("reserve1: %w", err)
	}

	pd := &mmath.PoolData{
		Address:      pool.Address,
		ExchangeType: mmath.ExchangeType(pool.Exchange),
		Token0: mmath.TokenInfo{
			Address:  pool.Token0.Address,
			Symbol:   pool.Token0.Symbol,
			Decimals: pool.Token0.Decimals,
		},
		Token1: mmath.TokenInfo{
			Address:  pool.Token1.Address,
			Symbol:   pool.Token1.Symbol,
			Decimals: pool.Token1.Decimals,
		},
		Reserve0: reserve0,
		Reserve1: reserve1,
	}

	if pool.Fee != nil {
		fee, err := r.toUniversalDecimal(*pool.Fee)
		if err != nil {
			return nil, fmt.Errorf("fee: %w", err)
		}
		pd.Fee = fee
	}

	// Optional integer fields: an empty string means "not provided" and
	// leaves the corresponding PoolData field nil.
	if pd.SqrtPriceX96, err = parseOptionalBigInt("sqrt_price_x96", pool.SqrtPriceX96); err != nil {
		return nil, err
	}
	if pd.Tick, err = parseOptionalBigInt("tick", pool.Tick); err != nil {
		return nil, err
	}
	if pd.Liquidity, err = parseOptionalBigInt("liquidity", pool.Liquidity); err != nil {
		return nil, err
	}
	if pd.A, err = parseOptionalBigInt("amplification", pool.Amplification); err != nil {
		return nil, err
	}

	for i, w := range pool.Weights {
		ud, err := r.toUniversalDecimal(w)
		if err != nil {
			return nil, fmt.Errorf("weight[%d]: %w", i, err)
		}
		pd.Weights = append(pd.Weights, ud)
	}

	return pd, nil
}

// parseOptionalBigInt parses raw as a base-10 integer for the fixture field
// named field. An empty raw string means the field was omitted and yields
// (nil, nil). The error keeps the "<field> invalid" prefix and appends the
// quoted input.
func parseOptionalBigInt(field, raw string) (*big.Int, error) {
	if raw == "" {
		return nil, nil
	}
	v, ok := new(big.Int).SetString(raw, 10)
	if !ok {
		return nil, fmt.Errorf("%s invalid: %q", field, raw)
	}
	return v, nil
}