Files
mev-beta/pkg/pools/create2_test.go
Krypto Kajun 850223a953 fix(multicall): resolve critical multicall parsing corruption issues
- Added comprehensive bounds checking to prevent buffer overruns in multicall parsing
- Implemented graduated validation system (Strict/Moderate/Permissive) to reduce false positives
- Added LRU caching system for address validation with 10-minute TTL
- Enhanced ABI decoder with missing Universal Router and Arbitrum-specific DEX signatures
- Fixed duplicate function declarations and import conflicts across multiple files
- Added error recovery mechanisms with multiple fallback strategies
- Updated tests to handle new validation behavior for suspicious addresses
- Fixed parser test expectations for improved validation system
- Applied gofmt formatting fixes to ensure code style compliance
- Fixed mutex copying issues in monitoring package by introducing MetricsSnapshot
- Resolved critical security vulnerabilities in heuristic address extraction
- Progress: Updated TODO audit from 10% to 35% complete

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
2025-10-17 00:12:55 -05:00

374 lines
13 KiB
Go

//go:build legacy_pools
// +build legacy_pools
package pools
import (
"testing"
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/ethclient"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/fraktal/mev-beta/internal/logger"
)
// TestNewCREATE2Calculator tests the creation of a new CREATE2 calculator
func TestNewCREATE2Calculator(t *testing.T) {
logger := logger.New("info", "text", "")
var ethClient *ethclient.Client // nil for testing
calc := NewCREATE2Calculator(logger, ethClient)
require.NotNil(t, calc)
assert.NotNil(t, calc.logger)
assert.NotNil(t, calc.factories)
assert.NotEmpty(t, calc.factories)
// Check that key factories are initialized
assert.Contains(t, calc.factories, "uniswap_v3")
assert.Contains(t, calc.factories, "uniswap_v2")
assert.Contains(t, calc.factories, "sushiswap")
assert.Contains(t, calc.factories, "camelot_v3")
assert.Contains(t, calc.factories, "curve")
}
// TestInitializeFactories checks the built-in configuration of the Uniswap V3,
// Uniswap V2 and SushiSwap factories: name, factory address, init code hash,
// fee structure and token sorting.
//
// Each factory runs as its own subtest. Registration is checked with require
// so a missing factory fails cleanly instead of dereferencing a missing config.
func TestInitializeFactories(t *testing.T) {
	lg := logger.New("info", "text", "")
	var ethClient *ethclient.Client // nil for testing
	calc := NewCREATE2Calculator(lg, ethClient)

	cases := []struct {
		name         string
		address      string
		initCodeHash string
		hasFee       bool
		defaultFees  []uint32
	}{
		{
			name:         "uniswap_v3",
			address:      "0x1F98431c8aD98523631AE4a59f267346ea31F984",
			initCodeHash: "0xe34f199b19b2b4f47f68442619d555527d244f78a3297ea89325f843f87b8b54",
			hasFee:       true,
			defaultFees:  []uint32{500, 3000, 10000},
		},
		{
			name:         "uniswap_v2",
			address:      "0x5C69bEe701ef814a2B6a3EDD4B1652CB9cc5aA6f",
			initCodeHash: "0x96e8ac4277198ff8b6f785478aa9a39f403cb768dd02cbee326c3e7da348845f",
			hasFee:       false,
			defaultFees:  []uint32{3000},
		},
		{
			name:         "sushiswap",
			address:      "0xC0AEe478e3658e2610c5F7A4A2E1777cE9e4f2Ac",
			initCodeHash: "0xe18a34eb0e04b04f7a0ac29a6e80748dca96319b42c54d679cb821dca90c6303",
			hasFee:       false,
			defaultFees:  []uint32{3000},
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			cfg, exists := calc.factories[tc.name]
			// Abort this subtest: reading fields of a missing config would panic.
			require.True(t, exists, "factory %q is not registered", tc.name)

			assert.Equal(t, tc.name, cfg.Name)
			assert.Equal(t, tc.address, cfg.Address.Hex())
			assert.Equal(t, tc.initCodeHash, cfg.InitCodeHash.Hex())
			assert.Equal(t, tc.hasFee, cfg.FeeStructure.HasFee)
			assert.Equal(t, tc.defaultFees, cfg.FeeStructure.DefaultFees)
			assert.True(t, cfg.SortTokens)
		})
	}
}
// TestCalculatePoolAddress checks CREATE2 pool address derivation.
//
// It verifies that:
//   - unknown factories are rejected and yield the zero address;
//   - each supported factory yields a non-zero address;
//   - Uniswap V3 derivation does not depend on the caller's token order.
//
// Every factory's result is kept in its own variable. This way the
// token-order check compares V3 against V3, not against an address computed
// for a different factory.
func TestCalculatePoolAddress(t *testing.T) {
	lg := logger.New("info", "text", "")
	calc := NewCREATE2Calculator(lg, nil)

	// Unknown factory: error plus zero address. require stops the test
	// before err.Error() could dereference a nil error.
	addr, err := calc.CalculatePoolAddress("unknown_factory", common.Address{}, common.Address{}, 3000)
	require.Error(t, err)
	assert.Equal(t, common.Address{}, addr)
	assert.Contains(t, err.Error(), "unknown factory")

	token0 := common.HexToAddress("0xA0b86991c6218b36c1d19D4a2e9Eb0cE3606eB48") // USDC
	token1 := common.HexToAddress("0xC02aaA39b223FE8D0A0e5C4F27eAD9083C756Cc2") // WETH
	fee := uint32(3000)

	// Uniswap V3.
	v3Addr, err := calc.CalculatePoolAddress("uniswap_v3", token0, token1, fee)
	assert.NoError(t, err)
	assert.NotEqual(t, common.Address{}, v3Addr)

	// Uniswap V2.
	v2Addr, err := calc.CalculatePoolAddress("uniswap_v2", token0, token1, fee)
	assert.NoError(t, err)
	assert.NotEqual(t, common.Address{}, v2Addr)

	// Token order: the V3 factory sorts tokens internally, so passing them
	// swapped must reproduce exactly the V3 address computed above.
	v3Swapped, err := calc.CalculatePoolAddress("uniswap_v3", token1, token0, fee)
	assert.NoError(t, err)
	assert.Equal(t, v3Addr.Hex(), v3Swapped.Hex())

	// SushiSwap.
	sushiAddr, err := calc.CalculatePoolAddress("sushiswap", token0, token1, fee)
	assert.NoError(t, err)
	assert.NotEqual(t, common.Address{}, sushiAddr)

	// Different deployers must produce different pool addresses for the same pair.
	assert.NotEqual(t, v2Addr, sushiAddr, "uniswap_v2 and sushiswap pools must differ")
}
// TestCalculateSalt checks that calculateSalt yields a 32-byte salt for each
// built-in factory kind exercised here: Uniswap V3, Uniswap V2, and SushiSwap.
// SushiSwap is expected to exercise the generic salt path.
func TestCalculateSalt(t *testing.T) {
	lg := logger.New("info", "text", "")
	calc := NewCREATE2Calculator(lg, nil)
	token0 := common.HexToAddress("0xA0b86991c6218b36c1d19D4a2e9Eb0cE3606eB48") // USDC
	token1 := common.HexToAddress("0xC02aaA39b223FE8D0A0e5C4F27eAD9083C756Cc2") // WETH
	fee := uint32(3000)

	for _, name := range []string{"uniswap_v3", "uniswap_v2", "sushiswap"} {
		factory, ok := calc.factories[name]
		// Fail with a clear message rather than handing a missing (zero or
		// nil) config to calculateSalt.
		require.True(t, ok, "factory %q is not registered", name)

		salt, err := calc.calculateSalt(factory, token0, token1, fee)
		assert.NoError(t, err, name)
		assert.NotNil(t, salt, name)
		assert.Len(t, salt, 32, name)
	}
}
// TestCalculateUniswapV3Salt tests Uniswap V3 specific salt calculation
func TestCalculateUniswapV3Salt(t *testing.T) {
logger := logger.New("info", "text", "")
calc := NewCREATE2Calculator(logger, nil)
token0 := common.HexToAddress("0xA0b86991c6218b36c1d19D4a2e9Eb0cE3606eB48") // USDC
token1 := common.HexToAddress("0xC02aaA39b223FE8D0A0e5C4F27eAD9083C756Cc2") // WETH
fee := uint32(3000)
salt, err := calc.calculateUniswapV3Salt(token0, token1, fee)
assert.NoError(t, err)
assert.NotNil(t, salt)
assert.Len(t, salt, 32)
// Test with different order (should produce different salt)
salt2, err := calc.calculateUniswapV3Salt(token1, token0, fee)
assert.NoError(t, err)
assert.NotEqual(t, salt, salt2)
// Test with different fee (should produce different salt)
salt3, err := calc.calculateUniswapV3Salt(token0, token1, 500)
assert.NoError(t, err)
assert.NotEqual(t, salt, salt3)
}
// TestCalculateUniswapV2Salt tests Uniswap V2 specific salt calculation
func TestCalculateUniswapV2Salt(t *testing.T) {
logger := logger.New("info", "text", "")
calc := NewCREATE2Calculator(logger, nil)
token0 := common.HexToAddress("0xA0b86991c6218b36c1d19D4a2e9Eb0cE3606eB48") // USDC
token1 := common.HexToAddress("0xC02aaA39b223FE8D0A0e5C4F27eAD9083C756Cc2") // WETH
salt, err := calc.calculateUniswapV2Salt(token0, token1)
assert.NoError(t, err)
assert.NotNil(t, salt)
assert.Len(t, salt, 32)
// Test with different order (should produce different salt)
salt2, err := calc.calculateUniswapV2Salt(token1, token0)
assert.NoError(t, err)
assert.NotEqual(t, salt, salt2)
}
// TestFindPoolsForTokenPair checks that FindPoolsForTokenPair returns
// candidate pools from several factories. It also checks that every
// candidate has its factory, tokens, fee and pool address populated.
func TestFindPoolsForTokenPair(t *testing.T) {
	lg := logger.New("info", "text", "")
	calc := NewCREATE2Calculator(lg, nil)
	token0 := common.HexToAddress("0xA0b86991c6218b36c1d19D4a2e9Eb0cE3606eB48") // USDC
	token1 := common.HexToAddress("0xC02aaA39b223FE8D0A0e5C4F27eAD9083C756Cc2") // WETH

	pools, err := calc.FindPoolsForTokenPair(token0, token1)
	require.NoError(t, err)
	require.NotEmpty(t, pools)

	// At least Uniswap V2, Uniswap V3 and SushiSwap should contribute a pool.
	// GreaterOrEqual reports the actual count on failure, which a plain
	// boolean assertion would not.
	assert.GreaterOrEqual(t, len(pools), 3)

	// Every candidate needs all identifying fields populated. The index is
	// included so a failure points at the offending pool.
	for i, pool := range pools {
		assert.NotEmpty(t, pool.Factory, "pool %v: empty factory", i)
		assert.NotEqual(t, common.Address{}, pool.Token0, "pool %v: zero token0", i)
		assert.NotEqual(t, common.Address{}, pool.Token1, "pool %v: zero token1", i)
		assert.NotEqual(t, uint32(0), pool.Fee, "pool %v: zero fee", i)
		assert.NotEqual(t, common.Address{}, pool.PoolAddr, "pool %v: zero pool address", i)
	}
}
// TestValidatePoolAddress tests pool address validation
func TestValidatePoolAddress(t *testing.T) {
logger := logger.New("info", "text", "")
calc := NewCREATE2Calculator(logger, nil)
token0 := common.HexToAddress("0xA0b86991c6218b36c1d19D4a2e9Eb0cE3606eB48") // USDC
token1 := common.HexToAddress("0xC02aaA39b223FE8D0A0e5C4F27eAD9083C756Cc2") // WETH
fee := uint32(3000)
// Calculate an expected address
expectedAddr, err := calc.CalculatePoolAddress("uniswap_v3", token0, token1, fee)
assert.NoError(t, err)
assert.NotEqual(t, common.Address{}, expectedAddr)
// Validate the address
isValid := calc.ValidatePoolAddress("uniswap_v3", token0, token1, fee, expectedAddr)
assert.True(t, isValid)
// Test with incorrect address
wrongAddr := common.HexToAddress("0x1234567890123456789012345678901234567890")
isValid = calc.ValidatePoolAddress("uniswap_v3", token0, token1, fee, wrongAddr)
assert.False(t, isValid)
// Test with unknown factory
isValid = calc.ValidatePoolAddress("unknown_factory", token0, token1, fee, expectedAddr)
assert.False(t, isValid)
}
// TestGetFactoryConfig checks that GetFactoryConfig returns the registered
// configuration for a known factory, and an "unknown factory" error with a
// nil config otherwise.
//
// require guards the dereferences (config.Name, err.Error()) so a regression
// is reported as a failure, not a nil-pointer panic.
func TestGetFactoryConfig(t *testing.T) {
	lg := logger.New("info", "text", "")
	calc := NewCREATE2Calculator(lg, nil)

	// Existing factory.
	config, err := calc.GetFactoryConfig("uniswap_v3")
	require.NoError(t, err)
	require.NotNil(t, config)
	assert.Equal(t, "uniswap_v3", config.Name)
	assert.Equal(t, "0x1F98431c8aD98523631AE4a59f267346ea31F984", config.Address.Hex())

	// Non-existent factory.
	config, err = calc.GetFactoryConfig("unknown_factory")
	require.Error(t, err)
	assert.Nil(t, config)
	assert.Contains(t, err.Error(), "unknown factory")
}
// TestAddCustomFactory checks that AddCustomFactory rejects configs with an
// empty name or a zero factory address. It also checks that a rejected
// config is not registered, and that a valid config becomes retrievable
// through GetFactoryConfig.
func TestAddCustomFactory(t *testing.T) {
	lg := logger.New("info", "text", "")
	calc := NewCREATE2Calculator(lg, nil)

	factoryAddr := common.HexToAddress("0x1234567890123456789012345678901234567890")

	// newConfig returns a fresh config on every call. The cases therefore
	// differ only in name and address and never share slices.
	newConfig := func(name string, addr common.Address) *FactoryConfig {
		return &FactoryConfig{
			Name:         name,
			Address:      addr,
			InitCodeHash: common.HexToHash("0x1234567890123456789012345678901234567890123456789012345678901234"),
			FeeStructure: FeeStructure{
				HasFee:      true,
				DefaultFees: []uint32{1000},
			},
			SortTokens: true,
		}
	}

	// Empty name is rejected. require stops before err.Error() on a nil error.
	err := calc.AddCustomFactory(newConfig("", factoryAddr))
	require.Error(t, err)
	assert.Contains(t, err.Error(), "factory name cannot be empty")

	// Zero address is rejected, and the rejected config must not be registered.
	err = calc.AddCustomFactory(newConfig("test_factory", common.Address{}))
	require.Error(t, err)
	assert.Contains(t, err.Error(), "factory address cannot be zero")
	_, err = calc.GetFactoryConfig("test_factory")
	assert.Error(t, err, "a rejected factory config must not be registered")

	// Valid config is accepted.
	require.NoError(t, calc.AddCustomFactory(newConfig("test_factory", factoryAddr)))

	// ...and can be looked up afterwards.
	config, err := calc.GetFactoryConfig("test_factory")
	require.NoError(t, err)
	require.NotNil(t, config)
	assert.Equal(t, "test_factory", config.Name)
	assert.Equal(t, "0x1234567890123456789012345678901234567890", config.Address.Hex())
}
// TestListFactories checks that ListFactories returns every built-in factory
// name, in sorted order.
func TestListFactories(t *testing.T) {
	lg := logger.New("info", "text", "")
	calc := NewCREATE2Calculator(lg, nil)

	factories := calc.ListFactories()
	assert.NotEmpty(t, factories)
	for _, name := range []string{"uniswap_v3", "uniswap_v2", "sushiswap", "camelot_v3", "curve"} {
		assert.Contains(t, factories, name)
	}

	// Factories should be sorted. IsNonDecreasing performs the same pairwise
	// check as a hand-rolled loop, but reports the out-of-order elements.
	assert.IsNonDecreasing(t, factories)
}
// TestCalculateInitCodeHash tests init code hash calculation
func TestCalculateInitCodeHash(t *testing.T) {
// Test with empty init code
hash := CalculateInitCodeHash([]byte{})
assert.Equal(t, "0xc5d2460186f7233c927e7db2dcc703c0e500b653ca82273b7bfad8045d85a470", hash.Hex())
// Test with sample init code
sampleCode := []byte("hello world")
hash = CalculateInitCodeHash(sampleCode)
assert.Equal(t, "0x47173285a8d7341e5e972fc677286384f802f8ef42a5ec5f03bbfa254cb01fad", hash.Hex())
}
// TestVerifyFactorySupport checks that VerifyFactorySupport fails for an
// unconfigured factory and succeeds for configured ones, including Curve.
func TestVerifyFactorySupport(t *testing.T) {
	lg := logger.New("info", "text", "")
	calc := NewCREATE2Calculator(lg, nil)

	// Non-existent factory. require stops before err.Error() on a nil error.
	err := calc.VerifyFactorySupport("unknown_factory")
	require.Error(t, err)
	assert.Contains(t, err.Error(), "factory unknown_factory not configured")

	// Configured factory.
	assert.NoError(t, calc.VerifyFactorySupport("uniswap_v3"))

	// Curve is checked explicitly as a special case.
	assert.NoError(t, calc.VerifyFactorySupport("curve"))
}