Files
mev-beta/pkg/validation/pool_validator_test.go
2025-09-16 11:05:47 -05:00

1151 lines
39 KiB
Go

package validation
import (
"context"
"math/big"
"testing"
"time"
"github.com/ethereum/go-ethereum"
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/core/types"
"github.com/fraktal/mev-beta/internal/logger"
"github.com/fraktal/mev-beta/pkg/pools"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
)
// MockEthClient is a testify-based mock of the ethclient.Client methods used in
// these tests. Each method records its call via m.Called and returns the values
// registered with On(...).Return(...).
//
// For methods whose first result is a slice, pointer or interface, an untyped
// nil may be registered (e.g. Return(nil, assert.AnError)); it is returned as
// the zero value instead of panicking on the type assertion. A non-nil value of
// the wrong type still panics, so misconfigured expectations remain visible.
type MockEthClient struct {
	mock.Mock
}

// CodeAt mocks the CodeAt method.
func (m *MockEthClient) CodeAt(ctx context.Context, contract common.Address, blockNumber *big.Int) ([]byte, error) {
	args := m.Called(ctx, contract, blockNumber)
	var code []byte
	if v := args.Get(0); v != nil {
		code = v.([]byte)
	}
	return code, args.Error(1)
}

// CallContract mocks the CallContract method.
func (m *MockEthClient) CallContract(ctx context.Context, call ethereum.CallMsg, blockNumber *big.Int) ([]byte, error) {
	args := m.Called(ctx, call, blockNumber)
	var out []byte
	if v := args.Get(0); v != nil {
		out = v.([]byte)
	}
	return out, args.Error(1)
}

// PendingCodeAt mocks the PendingCodeAt method.
func (m *MockEthClient) PendingCodeAt(ctx context.Context, account common.Address) ([]byte, error) {
	args := m.Called(ctx, account)
	var code []byte
	if v := args.Get(0); v != nil {
		code = v.([]byte)
	}
	return code, args.Error(1)
}

// PendingNonceAt mocks the PendingNonceAt method.
func (m *MockEthClient) PendingNonceAt(ctx context.Context, account common.Address) (uint64, error) {
	args := m.Called(ctx, account)
	return args.Get(0).(uint64), args.Error(1)
}

// SuggestGasPrice mocks the SuggestGasPrice method.
func (m *MockEthClient) SuggestGasPrice(ctx context.Context) (*big.Int, error) {
	args := m.Called(ctx)
	var price *big.Int
	if v := args.Get(0); v != nil {
		price = v.(*big.Int)
	}
	return price, args.Error(1)
}

// SuggestGasTipCap mocks the SuggestGasTipCap method.
func (m *MockEthClient) SuggestGasTipCap(ctx context.Context) (*big.Int, error) {
	args := m.Called(ctx)
	var tip *big.Int
	if v := args.Get(0); v != nil {
		tip = v.(*big.Int)
	}
	return tip, args.Error(1)
}

// EstimateGas mocks the EstimateGas method.
func (m *MockEthClient) EstimateGas(ctx context.Context, call ethereum.CallMsg) (uint64, error) {
	args := m.Called(ctx, call)
	return args.Get(0).(uint64), args.Error(1)
}

// SendTransaction mocks the SendTransaction method.
func (m *MockEthClient) SendTransaction(ctx context.Context, tx *types.Transaction) error {
	args := m.Called(ctx, tx)
	return args.Error(0)
}

// TransactionReceipt mocks the TransactionReceipt method.
func (m *MockEthClient) TransactionReceipt(ctx context.Context, txHash common.Hash) (*types.Receipt, error) {
	args := m.Called(ctx, txHash)
	var receipt *types.Receipt
	if v := args.Get(0); v != nil {
		receipt = v.(*types.Receipt)
	}
	return receipt, args.Error(1)
}

// TransactionByHash mocks the TransactionByHash method.
func (m *MockEthClient) TransactionByHash(ctx context.Context, hash common.Hash) (*types.Transaction, bool, error) {
	args := m.Called(ctx, hash)
	var tx *types.Transaction
	if v := args.Get(0); v != nil {
		tx = v.(*types.Transaction)
	}
	return tx, args.Get(1).(bool), args.Error(2)
}

// BlockByNumber mocks the BlockByNumber method.
func (m *MockEthClient) BlockByNumber(ctx context.Context, number *big.Int) (*types.Block, error) {
	args := m.Called(ctx, number)
	var block *types.Block
	if v := args.Get(0); v != nil {
		block = v.(*types.Block)
	}
	return block, args.Error(1)
}

// HeaderByNumber mocks the HeaderByNumber method.
func (m *MockEthClient) HeaderByNumber(ctx context.Context, number *big.Int) (*types.Header, error) {
	args := m.Called(ctx, number)
	var header *types.Header
	if v := args.Get(0); v != nil {
		header = v.(*types.Header)
	}
	return header, args.Error(1)
}

// BalanceAt mocks the BalanceAt method.
func (m *MockEthClient) BalanceAt(ctx context.Context, account common.Address, blockNumber *big.Int) (*big.Int, error) {
	args := m.Called(ctx, account, blockNumber)
	var balance *big.Int
	if v := args.Get(0); v != nil {
		balance = v.(*big.Int)
	}
	return balance, args.Error(1)
}

// StorageAt mocks the StorageAt method.
func (m *MockEthClient) StorageAt(ctx context.Context, account common.Address, key common.Hash, blockNumber *big.Int) ([]byte, error) {
	args := m.Called(ctx, account, key, blockNumber)
	var value []byte
	if v := args.Get(0); v != nil {
		value = v.([]byte)
	}
	return value, args.Error(1)
}

// FilterLogs mocks the FilterLogs method.
func (m *MockEthClient) FilterLogs(ctx context.Context, q ethereum.FilterQuery) ([]types.Log, error) {
	args := m.Called(ctx, q)
	var logs []types.Log
	if v := args.Get(0); v != nil {
		logs = v.([]types.Log)
	}
	return logs, args.Error(1)
}

// SubscribeFilterLogs mocks the SubscribeFilterLogs method.
func (m *MockEthClient) SubscribeFilterLogs(ctx context.Context, q ethereum.FilterQuery, ch chan<- types.Log) (ethereum.Subscription, error) {
	args := m.Called(ctx, q, ch)
	var sub ethereum.Subscription
	if v := args.Get(0); v != nil {
		sub = v.(ethereum.Subscription)
	}
	return sub, args.Error(1)
}

// ChainID mocks the ChainID method.
func (m *MockEthClient) ChainID(ctx context.Context) (*big.Int, error) {
	args := m.Called(ctx)
	var id *big.Int
	if v := args.Get(0); v != nil {
		id = v.(*big.Int)
	}
	return id, args.Error(1)
}

// Close mocks the Close method.
func (m *MockEthClient) Close() {
	m.Called()
}
// TestNewPoolValidator checks that NewPoolValidator stores the given client and
// logger, allocates its internal maps and CREATE2 calculator, uses a
// five-minute cache timeout and pre-registers the well-known DEX factories.
func TestNewPoolValidator(t *testing.T) {
	mockClient := &MockEthClient{}
	testLogger := logger.New("info", "text", "")

	v := NewPoolValidator(mockClient, testLogger)

	require.NotNil(t, v)
	assert.Equal(t, mockClient, v.client)
	assert.Equal(t, testLogger, v.logger)
	assert.NotNil(t, v.create2Calculator)
	assert.NotNil(t, v.trustedFactories)
	assert.NotNil(t, v.bannedAddresses)
	assert.NotNil(t, v.validationCache)
	assert.Equal(t, 5*time.Minute, v.cacheTimeout)

	// Factories that must be trusted out of the box.
	defaultFactories := []string{
		"0x1F98431c8aD98523631AE4a59f267346ea31F984", // Uniswap V3
		"0x5C69bEe701ef814a2B6a3EDD4B1652CB9cc5aA6f", // Uniswap V2
		"0xC0AEe478e3658e2610c5F7A4A2E1777cE9e4f2Ac", // SushiSwap
		"0x1a3c9B1d2F0529D97f2afC5136Cc23e58f1FD35B", // Camelot V3
	}
	assert.NotEmpty(t, v.trustedFactories)
	for _, hexAddr := range defaultFactories {
		assert.Contains(t, v.trustedFactories, common.HexToAddress(hexAddr))
	}
}
// TestPoolValidator_GetDefaultConfig pins the defaults returned by
// getDefaultConfig when called on a zero-value PoolValidator.
func TestPoolValidator_GetDefaultConfig(t *testing.T) {
	cfg := (&PoolValidator{}).getDefaultConfig()

	assert.True(t, cfg.RequireFactoryVerification)
	assert.Equal(t, 70, cfg.MinSecurityScore)
	assert.Equal(t, 10*time.Second, cfg.MaxValidationTime)
	assert.False(t, cfg.AllowUnknownFactories)
	assert.True(t, cfg.RequireTokenValidation)
}
// TestPoolValidator_InitializeTrustedFactories checks that
// initializeTrustedFactories registers each built-in factory address under its
// protocol name.
func TestPoolValidator_InitializeTrustedFactories(t *testing.T) {
	v := &PoolValidator{trustedFactories: map[common.Address]string{}}

	v.initializeTrustedFactories()

	expected := []struct {
		hexAddr  string
		protocol string
	}{
		{hexAddr: "0x1F98431c8aD98523631AE4a59f267346ea31F984", protocol: "uniswap_v3"},
		{hexAddr: "0x5C69bEe701ef814a2B6a3EDD4B1652CB9cc5aA6f", protocol: "uniswap_v2"},
		{hexAddr: "0xC0AEe478e3658e2610c5F7A4A2E1777cE9e4f2Ac", protocol: "sushiswap"},
		{hexAddr: "0x1a3c9B1d2F0529D97f2afC5136Cc23e58f1FD35B", protocol: "camelot_v3"},
	}
	for _, want := range expected {
		addr := common.HexToAddress(want.hexAddr)
		assert.Contains(t, v.trustedFactories, addr)
		assert.Equal(t, want.protocol, v.trustedFactories[addr])
	}
}
// TestPoolValidator_InitializeBannedAddresses checks that
// initializeBannedAddresses leaves the ban list empty by default.
func TestPoolValidator_InitializeBannedAddresses(t *testing.T) {
	v := &PoolValidator{bannedAddresses: map[common.Address]string{}}

	v.initializeBannedAddresses()

	// No addresses are banned out of the box; in production the list would be
	// filled with known malicious addresses.
	assert.Empty(t, v.bannedAddresses)
}
// TestPoolValidator_AddTrustedFactory tests adding a trusted factory
func TestPoolValidator_AddTrustedFactory(t *testing.T) {
log := logger.New("info", "text", "")
validator := &PoolValidator{
trustedFactories: make(map[common.Address]string),
logger: log,
}
factoryAddr := common.HexToAddress("0x1234567890123456789012345678901234567890")
factoryName := "test_factory"
validator.AddTrustedFactory(factoryAddr, factoryName)
assert.Contains(t, validator.trustedFactories, factoryAddr)
assert.Equal(t, factoryName, validator.trustedFactories[factoryAddr])
}
// TestPoolValidator_BanAddress tests banning an address
func TestPoolValidator_BanAddress(t *testing.T) {
log := logger.New("info", "text", "")
validator := &PoolValidator{
bannedAddresses: make(map[common.Address]string),
logger: log,
}
addr := common.HexToAddress("0x1234567890123456789012345678901234567890")
reason := "test_ban_reason"
validator.BanAddress(addr, reason)
assert.Contains(t, validator.bannedAddresses, addr)
assert.Equal(t, reason, validator.bannedAddresses[addr])
}
// TestPoolValidator_ValidateBasicExistence exercises validateBasicExistence
// against four CodeAt outcomes: normal-sized code, no code, suspiciously small
// code, and an RPC error.
//
// Every CodeAt expectation is registered with Once(). testify picks the first
// registered expectation whose arguments match and that still has calls left;
// an expectation without Once() never runs out, so it would shadow every later
// CodeAt stub for the same pool address and the later cases would silently
// receive the first case's bytecode.
func TestPoolValidator_ValidateBasicExistence(t *testing.T) {
	client := &MockEthClient{}
	log := logger.New("info", "text", "")
	validator := &PoolValidator{
		client: client,
		logger: log,
	}
	ctx := context.Background()
	poolAddr := common.HexToAddress("0x88e6A0c2dDD26FEEb64F039a2c41296FcB3f5640")

	// newResult returns a fresh, empty result so cases do not leak state.
	newResult := func() *ValidationResult {
		return &ValidationResult{
			Warnings: make([]string, 0),
			Errors:   make([]string, 0),
		}
	}

	// Normal-sized contract code (200 bytes) passes without errors.
	validCode := make([]byte, 200)
	client.On("CodeAt", mock.Anything, poolAddr, mock.Anything).Return(validCode, nil).Once()
	result := newResult()
	err := validator.validateBasicExistence(ctx, poolAddr, result)
	assert.NoError(t, err)
	assert.Empty(t, result.Errors)

	// Empty code means there is no contract at the address.
	client.On("CodeAt", mock.Anything, poolAddr, mock.Anything).Return([]byte{}, nil).Once()
	err = validator.validateBasicExistence(ctx, poolAddr, newResult())
	if assert.Error(t, err) {
		assert.Contains(t, err.Error(), "no contract code at address")
	}

	// Very small code (50 bytes) is accepted but produces a warning and a
	// security-score penalty.
	smallCode := make([]byte, 50)
	client.On("CodeAt", mock.Anything, poolAddr, mock.Anything).Return(smallCode, nil).Once()
	result = newResult()
	err = validator.validateBasicExistence(ctx, poolAddr, result)
	assert.NoError(t, err)
	if assert.Len(t, result.Warnings, 1) {
		assert.Contains(t, result.Warnings[0], "Contract has very small code size")
	}
	assert.Equal(t, -10, result.SecurityScore)

	// A CodeAt RPC failure is surfaced as an error.
	client.On("CodeAt", mock.Anything, poolAddr, mock.Anything).Return([]byte{}, assert.AnError).Once()
	err = validator.validateBasicExistence(ctx, poolAddr, newResult())
	if assert.Error(t, err) {
		assert.Contains(t, err.Error(), "failed to get contract code")
	}
}
// TestPoolValidator_CheckBannedAddresses checks that checkBannedAddresses
// accepts an unlisted pool and rejects a banned one with an error that names
// the ban reason.
func TestPoolValidator_CheckBannedAddresses(t *testing.T) {
	log := logger.New("info", "text", "")
	validator := &PoolValidator{
		bannedAddresses: make(map[common.Address]string),
		logger:          log,
	}

	// A pool that is not on the ban list passes.
	poolAddr := common.HexToAddress("0x1234567890123456789012345678901234567890")
	err := validator.checkBannedAddresses(poolAddr, &ValidationResult{})
	assert.NoError(t, err)

	// A banned pool is rejected. The literal must be valid hex: HexToAddress
	// silently keeps only the bytes decoded before the first non-hex character,
	// so a literal such as "0xbanned..." would collapse to 0x00...00ba.
	bannedAddr := common.HexToAddress("0xdeadbeefdeadbeefdeadbeefdeadbeefdeadbeef")
	banReason := "known malicious pool"
	validator.bannedAddresses[bannedAddr] = banReason
	err = validator.checkBannedAddresses(bannedAddr, &ValidationResult{})
	if assert.Error(t, err) {
		assert.Contains(t, err.Error(), "pool")
		assert.Contains(t, err.Error(), "is banned")
		assert.Contains(t, err.Error(), banReason)
	}
}
// TestPoolValidator_ValidatePoolInterface is a smoke test: validatePoolInterface
// must complete without error when every on-chain probe fails.
//
// The mock gets catch-all CallContract/CodeAt expectations that return an
// error. A testify mock with no matching expectation panics on any call, so
// without these stubs the test would crash instead of exercising the
// "interface detection fails" path it is meant to cover.
func TestPoolValidator_ValidatePoolInterface(t *testing.T) {
	client := &MockEthClient{}
	client.On("CallContract", mock.Anything, mock.Anything, mock.Anything).Return([]byte{}, assert.AnError)
	client.On("CodeAt", mock.Anything, mock.Anything, mock.Anything).Return([]byte{}, assert.AnError)
	log := logger.New("info", "text", "")
	validator := &PoolValidator{
		client: client,
		logger: log,
	}
	poolAddr := common.HexToAddress("0x88e6A0c2dDD26FEEb64F039a2c41296FcB3f5640")
	result := &ValidationResult{
		Warnings: make([]string, 0),
		Errors:   make([]string, 0),
	}

	err := validator.validatePoolInterface(context.Background(), poolAddr, result)

	// NOTE(review): assumes detection failures are reported through result
	// rather than the returned error — confirm against the implementation.
	assert.NoError(t, err)
}
// TestPoolValidator_IsUniswapV3Pool checks that isUniswapV3Pool reports true
// when the slot0 probe call succeeds and false when it fails.
func TestPoolValidator_IsUniswapV3Pool(t *testing.T) {
	mockClient := &MockEthClient{}
	v := &PoolValidator{
		client: mockClient,
		logger: logger.New("info", "text", ""),
	}
	pool := common.HexToAddress("0x88e6A0c2dDD26FEEb64F039a2c41296FcB3f5640")

	steps := []struct {
		name   string
		reply  []byte
		rpcErr error
		wantV3 bool
	}{
		{name: "slot0 succeeds", reply: make([]byte, 100), wantV3: true},
		{name: "slot0 fails", reply: []byte{}, rpcErr: assert.AnError, wantV3: false},
	}
	for _, step := range steps {
		// Each stub answers exactly one call before the next case is queued.
		mockClient.On("CallContract", mock.Anything, mock.Anything, mock.Anything).Return(step.reply, step.rpcErr).Once()
		assert.Equal(t, step.wantV3, v.isUniswapV3Pool(context.Background(), pool), step.name)
	}
}
// TestPoolValidator_IsUniswapV2Pool checks that isUniswapV2Pool reports true
// when the getReserves probe call succeeds and false when it fails.
func TestPoolValidator_IsUniswapV2Pool(t *testing.T) {
	mockClient := &MockEthClient{}
	v := &PoolValidator{
		client: mockClient,
		logger: logger.New("info", "text", ""),
	}
	pair := common.HexToAddress("0xB4e16d0168e52d35CaCD2c6185b44281Ec28C9Dc")

	steps := []struct {
		name   string
		reply  []byte
		rpcErr error
		wantV2 bool
	}{
		{name: "getReserves succeeds", reply: make([]byte, 100), wantV2: true},
		{name: "getReserves fails", reply: []byte{}, rpcErr: assert.AnError, wantV2: false},
	}
	for _, step := range steps {
		// Each stub answers exactly one call before the next case is queued.
		mockClient.On("CallContract", mock.Anything, mock.Anything, mock.Anything).Return(step.reply, step.rpcErr).Once()
		assert.Equal(t, step.wantV2, v.isUniswapV2Pool(context.Background(), pair), step.name)
	}
}
// TestPoolValidator_ValidateUniswapV3Interface is a smoke test:
// validateUniswapV3Interface must complete without error when every on-chain
// call fails.
//
// Catch-all failing CallContract/CodeAt expectations are registered because a
// testify mock with no matching expectation panics on any call.
func TestPoolValidator_ValidateUniswapV3Interface(t *testing.T) {
	client := &MockEthClient{}
	client.On("CallContract", mock.Anything, mock.Anything, mock.Anything).Return([]byte{}, assert.AnError)
	client.On("CodeAt", mock.Anything, mock.Anything, mock.Anything).Return([]byte{}, assert.AnError)
	log := logger.New("info", "text", "")
	validator := &PoolValidator{
		client: client,
		logger: log,
	}
	poolAddr := common.HexToAddress("0x88e6A0c2dDD26FEEb64F039a2c41296FcB3f5640")
	result := &ValidationResult{
		Warnings: make([]string, 0),
		Errors:   make([]string, 0),
	}

	err := validator.validateUniswapV3Interface(context.Background(), poolAddr, result)

	// NOTE(review): assumes validation failures are recorded in result rather
	// than returned — confirm against the implementation.
	assert.NoError(t, err)
}
// TestPoolValidator_ValidateUniswapV2Interface is a smoke test:
// validateUniswapV2Interface must complete without error when every on-chain
// call fails.
//
// Catch-all failing CallContract/CodeAt expectations are registered because a
// testify mock with no matching expectation panics on any call.
func TestPoolValidator_ValidateUniswapV2Interface(t *testing.T) {
	client := &MockEthClient{}
	client.On("CallContract", mock.Anything, mock.Anything, mock.Anything).Return([]byte{}, assert.AnError)
	client.On("CodeAt", mock.Anything, mock.Anything, mock.Anything).Return([]byte{}, assert.AnError)
	log := logger.New("info", "text", "")
	validator := &PoolValidator{
		client: client,
		logger: log,
	}
	poolAddr := common.HexToAddress("0xB4e16d0168e52d35CaCD2c6185b44281Ec28C9Dc")
	result := &ValidationResult{
		Warnings: make([]string, 0),
		Errors:   make([]string, 0),
	}

	err := validator.validateUniswapV2Interface(context.Background(), poolAddr, result)

	// NOTE(review): assumes validation failures are recorded in result rather
	// than returned — confirm against the implementation.
	assert.NoError(t, err)
}
// TestPoolValidator_ValidateFactoryDeployment is a smoke test:
// validateFactoryDeployment must complete without error for a pool whose
// token/fee parameters are filled in and a trusted Uniswap V3 factory is
// registered.
//
// Catch-all failing CallContract/CodeAt expectations are registered because a
// testify mock with no matching expectation panics on any call.
func TestPoolValidator_ValidateFactoryDeployment(t *testing.T) {
	client := &MockEthClient{}
	client.On("CallContract", mock.Anything, mock.Anything, mock.Anything).Return([]byte{}, assert.AnError)
	client.On("CodeAt", mock.Anything, mock.Anything, mock.Anything).Return([]byte{}, assert.AnError)
	log := logger.New("info", "text", "")
	validator := &PoolValidator{
		client:            client,
		logger:            log,
		trustedFactories:  make(map[common.Address]string),
		create2Calculator: pools.NewCREATE2Calculator(log),
	}

	factoryAddr := common.HexToAddress("0x1F98431c8aD98523631AE4a59f267346ea31F984")
	validator.trustedFactories[factoryAddr] = "uniswap_v3"

	poolAddr := common.HexToAddress("0x88e6A0c2dDD26FEEb64F039a2c41296FcB3f5640")
	result := &ValidationResult{
		Warnings: make([]string, 0),
		Errors:   make([]string, 0),
		Token0:   common.HexToAddress("0xA0b86991c6218b36c1d19D4a2e9Eb0cE3606eB48"), // USDC
		Token1:   common.HexToAddress("0xC02aaA39b223FE8D0A0e5C4F27eAD9083C756Cc2"), // WETH
		Fee:      3000,
	}

	err := validator.validateFactoryDeployment(context.Background(), poolAddr, result)

	// NOTE(review): assumes a non-matching deployment is reported through
	// result rather than the returned error — confirm against the implementation.
	assert.NoError(t, err)
}
// TestPoolValidator_VerifyFactoryDeployment checks two negative cases of
// verifyFactoryDeployment: a pool address paired with token/fee parameters it
// is not expected to match, and parameters whose token0 is the zero address.
func TestPoolValidator_VerifyFactoryDeployment(t *testing.T) {
	sharedLogger := logger.New("info", "text", "")
	v := &PoolValidator{
		logger:            sharedLogger,
		create2Calculator: pools.NewCREATE2Calculator(sharedLogger),
		trustedFactories:  map[common.Address]string{},
	}
	factory := common.HexToAddress("0x1F98431c8aD98523631AE4a59f267346ea31F984")
	pool := common.HexToAddress("0x88e6A0c2dDD26FEEb64F039a2c41296FcB3f5640")
	params := &ValidationResult{
		Token0: common.HexToAddress("0xA0b86991c6218b36c1d19D4a2e9Eb0cE3606eB48"), // USDC
		Token1: common.HexToAddress("0xC02aaA39b223FE8D0A0e5C4F27eAD9083C756Cc2"), // WETH
		Fee:    3000,
	}

	// NOTE(review): the expected false presumably reflects the CREATE2 address
	// derived for (USDC, WETH, 3000) not equalling pool — confirm against the
	// calculator before relying on this as a negative case.
	assert.False(t, v.verifyFactoryDeployment(factory, "uniswap_v3", pool, params))

	// A zero token0 must never verify.
	params.Token0 = common.Address{}
	assert.False(t, v.verifyFactoryDeployment(factory, "uniswap_v3", pool, params))
}
// TestPoolValidator_ValidateTokenContracts checks that validateTokenContracts
// accepts two deployed ERC20 tokens and rejects a result whose token addresses
// are unset.
func TestPoolValidator_ValidateTokenContracts(t *testing.T) {
	mockClient := &MockEthClient{}
	v := &PoolValidator{
		client: mockClient,
		logger: logger.New("info", "text", ""),
	}
	withTokens := &ValidationResult{
		Token0: common.HexToAddress("0xA0b86991c6218b36c1d19D4a2e9Eb0cE3606eB48"), // USDC
		Token1: common.HexToAddress("0xC02aaA39b223FE8D0A0e5C4F27eAD9083C756Cc2"), // WETH
	}

	// Both tokens have 200 bytes of code and answer every totalSupply call.
	deployed := make([]byte, 200)
	for _, token := range []common.Address{withTokens.Token0, withTokens.Token1} {
		mockClient.On("CodeAt", mock.Anything, token, mock.Anything).Return(deployed, nil).Once()
	}
	mockClient.On("CallContract", mock.Anything, mock.AnythingOfType("ethereum.CallMsg"), mock.Anything).Return(make([]byte, 32), nil)

	ctx := context.Background()
	assert.NoError(t, v.validateTokenContracts(ctx, withTokens))

	// Zero addresses for both tokens are rejected.
	err := v.validateTokenContracts(ctx, &ValidationResult{
		Token0: common.Address{},
		Token1: common.Address{},
	})
	assert.Error(t, err)
	assert.Contains(t, err.Error(), "token addresses not available")
}
// TestPoolValidator_ValidateTokenContract checks validateTokenContract for a
// deployed ERC20 token, an address without code and a CodeAt RPC failure.
func TestPoolValidator_ValidateTokenContract(t *testing.T) {
	mockClient := &MockEthClient{}
	v := &PoolValidator{
		client: mockClient,
		logger: logger.New("info", "text", ""),
	}
	usdc := common.HexToAddress("0xA0b86991c6218b36c1d19D4a2e9Eb0cE3606eB48")
	ctx := context.Background()

	// stubCode queues exactly one CodeAt reply for the token address.
	stubCode := func(code []byte, err error) {
		mockClient.On("CodeAt", mock.Anything, usdc, mock.Anything).Return(code, err).Once()
	}

	// Deployed code plus a successful totalSupply call: valid token.
	stubCode(make([]byte, 200), nil)
	mockClient.On("CallContract", mock.Anything, mock.AnythingOfType("ethereum.CallMsg"), mock.Anything).Return(make([]byte, 32), nil).Once()
	assert.NoError(t, v.validateTokenContract(ctx, usdc))

	// No code at the address.
	stubCode([]byte{}, nil)
	err := v.validateTokenContract(ctx, usdc)
	assert.Error(t, err)
	assert.Contains(t, err.Error(), "no contract code at token address")

	// RPC failure while fetching the code.
	stubCode([]byte{}, assert.AnError)
	err = v.validateTokenContract(ctx, usdc)
	assert.Error(t, err)
	assert.Contains(t, err.Error(), "failed to get token contract code")
}
// TestPoolValidator_ValidateERC20Interface checks that validateERC20Interface
// succeeds when totalSupply answers and reports an error when the call fails.
func TestPoolValidator_ValidateERC20Interface(t *testing.T) {
	mockClient := &MockEthClient{}
	v := &PoolValidator{
		client: mockClient,
		logger: logger.New("info", "text", ""),
	}
	usdc := common.HexToAddress("0xA0b86991c6218b36c1d19D4a2e9Eb0cE3606eB48")
	ctx := context.Background()

	// stubTotalSupply queues exactly one CallContract reply.
	stubTotalSupply := func(reply []byte, err error) {
		mockClient.On("CallContract", mock.Anything, mock.AnythingOfType("ethereum.CallMsg"), mock.Anything).Return(reply, err).Once()
	}

	stubTotalSupply(make([]byte, 32), nil)
	assert.NoError(t, v.validateERC20Interface(ctx, usdc))

	stubTotalSupply([]byte{}, assert.AnError)
	err := v.validateERC20Interface(ctx, usdc)
	assert.Error(t, err)
	assert.Contains(t, err.Error(), "totalSupply call failed")
}
// TestPoolValidator_PerformAdditionalSecurityChecks is a smoke test:
// performAdditionalSecurityChecks must complete when every CodeAt/CallContract
// lookup fails.
//
// Catch-all failing CodeAt/CallContract expectations are registered because a
// testify mock with no matching expectation panics on any call. Other client
// methods are still unstubbed; add expectations if the checks start using them.
func TestPoolValidator_PerformAdditionalSecurityChecks(t *testing.T) {
	client := &MockEthClient{}
	client.On("CodeAt", mock.Anything, mock.Anything, mock.Anything).Return([]byte{}, assert.AnError)
	client.On("CallContract", mock.Anything, mock.Anything, mock.Anything).Return([]byte{}, assert.AnError)
	log := logger.New("info", "text", "")
	validator := &PoolValidator{
		client: client,
		logger: log,
	}
	poolAddr := common.HexToAddress("0x88e6A0c2dDD26FEEb64F039a2c41296FcB3f5640")
	result := &ValidationResult{
		Warnings:      make([]string, 0),
		Errors:        make([]string, 0),
		SecurityScore: 0,
	}

	validator.performAdditionalSecurityChecks(context.Background(), poolAddr, result)

	// NOTE(review): no assertions on result yet; add them once the scoring
	// rules of the individual checks are pinned down.
}
// TestPoolValidator_GetContractCreationBlock tests contract creation block retrieval
func TestPoolValidator_GetContractCreationBlock(t *testing.T) {
// Create a mock client
client := &MockEthClient{}
// Create a logger
log := logger.New("info", "text", "")
// Create validator
validator := &PoolValidator{
client: client,
logger: log,
}
poolAddr := common.HexToAddress("0x88e6A0c2dDD26FEEb64F039a2c41296FcB3f5640")
// Test that the function returns 0 (placeholder implementation)
ctx := context.Background()
blockNumber := validator.getContractCreationBlock(ctx, poolAddr)
assert.Equal(t, uint64(0), blockNumber) // Should return 0 for now
}
// TestPoolValidator_CheckForAttackPatterns is a smoke test:
// checkForAttackPatterns must complete when every CodeAt/CallContract lookup
// fails.
//
// Catch-all failing CodeAt/CallContract expectations are registered because a
// testify mock with no matching expectation panics on any call. Other client
// methods are still unstubbed; add expectations if the checks start using them.
func TestPoolValidator_CheckForAttackPatterns(t *testing.T) {
	client := &MockEthClient{}
	client.On("CodeAt", mock.Anything, mock.Anything, mock.Anything).Return([]byte{}, assert.AnError)
	client.On("CallContract", mock.Anything, mock.Anything, mock.Anything).Return([]byte{}, assert.AnError)
	log := logger.New("info", "text", "")
	validator := &PoolValidator{
		client: client,
		logger: log,
	}
	poolAddr := common.HexToAddress("0x88e6A0c2dDD26FEEb64F039a2c41296FcB3f5640")
	result := &ValidationResult{
		Warnings:      make([]string, 0),
		Errors:        make([]string, 0),
		SecurityScore: 0,
	}

	validator.checkForAttackPatterns(context.Background(), poolAddr, result)

	// NOTE(review): no assertions on result yet; add them once the detected
	// patterns and their penalties are pinned down.
}
// TestPoolValidator_IsProxyContract checks isProxyContract against bytecode
// with and without the DELEGATECALL opcode (0xf4), a CodeAt error and empty code.
func TestPoolValidator_IsProxyContract(t *testing.T) {
	mockClient := &MockEthClient{}
	v := &PoolValidator{
		client: mockClient,
		logger: logger.New("info", "text", ""),
	}
	pool := common.HexToAddress("0x88e6A0c2dDD26FEEb64F039a2c41296FcB3f5640")
	ctx := context.Background()

	cases := []struct {
		name      string
		code      []byte
		rpcErr    error
		wantProxy bool
	}{
		{name: "contains delegatecall", code: []byte{0x00, 0x01, 0xf4, 0x02, 0x03}, wantProxy: true},
		{name: "no delegatecall", code: []byte{0x00, 0x01, 0x02, 0x03, 0x04}, wantProxy: false},
		{name: "CodeAt error", code: []byte{}, rpcErr: assert.AnError, wantProxy: false},
		{name: "empty code", code: []byte{}, wantProxy: false},
	}
	for _, tc := range cases {
		// Each stub answers exactly one CodeAt call before the next is queued.
		mockClient.On("CodeAt", mock.Anything, pool, mock.Anything).Return(tc.code, tc.rpcErr).Once()
		assert.Equal(t, tc.wantProxy, v.isProxyContract(ctx, pool), tc.name)
	}
}
// TestPoolValidator_HasUnusualBytecode checks hasUnusualBytecode for
// reasonably sized code, very large code, a CodeAt error and empty code.
func TestPoolValidator_HasUnusualBytecode(t *testing.T) {
	mockClient := &MockEthClient{}
	v := &PoolValidator{
		client: mockClient,
		logger: logger.New("info", "text", ""),
	}
	pool := common.HexToAddress("0x88e6A0c2dDD26FEEb64F039a2c41296FcB3f5640")
	ctx := context.Background()

	cases := []struct {
		name        string
		code        []byte
		rpcErr      error
		wantUnusual bool
	}{
		{name: "normal size (1000 bytes)", code: make([]byte, 1000), wantUnusual: false},
		{name: "very large (60000 bytes)", code: make([]byte, 60000), wantUnusual: true},
		{name: "CodeAt error", code: []byte{}, rpcErr: assert.AnError, wantUnusual: false},
		{name: "empty code", code: []byte{}, wantUnusual: false},
	}
	for _, tc := range cases {
		// Each stub answers exactly one CodeAt call before the next is queued.
		mockClient.On("CodeAt", mock.Anything, pool, mock.Anything).Return(tc.code, tc.rpcErr).Once()
		assert.Equal(t, tc.wantUnusual, v.hasUnusualBytecode(ctx, pool), tc.name)
	}
}
// TestPoolValidator_CalculateEntropy checks calculateEntropy on empty,
// single-symbol and all-distinct byte inputs.
//
// The all-distinct input reaches the theoretical maximum exactly, so the upper
// bound is compared with a small tolerance: an implementation that computes
// log2 as math.Log(x)/math.Log(2) can land one rounding error above it.
func TestPoolValidator_CalculateEntropy(t *testing.T) {
	const epsilon = 1e-9
	validator := &PoolValidator{}

	// Empty data has zero entropy.
	entropy := validator.calculateEntropy([]byte{})
	assert.Equal(t, 0.0, entropy)

	// A single repeated byte value has zero entropy.
	entropy = validator.calculateEntropy([]byte{0x00, 0x00, 0x00, 0x00, 0x00})
	assert.InDelta(t, 0.0, entropy, epsilon)

	// Eight distinct byte values give positive entropy. Assuming the result is
	// measured in bits (as the original bound implies), it cannot exceed
	// log2(min(len(data), 256)) = log2(8) = 3.
	entropy = validator.calculateEntropy([]byte{0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07})
	const maxEntropy = 3.0
	assert.Greater(t, entropy, 0.0)
	assert.LessOrEqual(t, entropy, maxEntropy+epsilon)
}
// TestPoolValidator_GetUniswapV3PoolInfo checks that getUniswapV3PoolInfo
// decodes token0, token1 and fee from three successive CallContract replies.
func TestPoolValidator_GetUniswapV3PoolInfo(t *testing.T) {
	mockClient := &MockEthClient{}
	v := &PoolValidator{
		client: mockClient,
		logger: logger.New("info", "text", ""),
	}
	pool := common.HexToAddress("0x88e6A0c2dDD26FEEb64F039a2c41296FcB3f5640")
	usdc := common.HexToAddress("0xA0b86991c6218b36c1d19D4a2e9Eb0cE3606eB48")
	weth := common.HexToAddress("0xC02aaA39b223FE8D0A0e5C4F27eAD9083C756Cc2")

	// addressWord encodes an address as a right-aligned 32-byte ABI word.
	addressWord := func(a common.Address) []byte {
		word := make([]byte, 32)
		copy(word[12:], a.Bytes())
		return word
	}
	feeWord := make([]byte, 32)
	feeWord[30], feeWord[31] = 0x0b, 0xb8 // 3000 == 0x0bb8

	// Replies are consumed in registration order: token0, token1, fee.
	for _, reply := range [][]byte{addressWord(usdc), addressWord(weth), feeWord} {
		mockClient.On("CallContract", mock.Anything, mock.AnythingOfType("ethereum.CallMsg"), mock.Anything).Return(reply, nil).Once()
	}

	token0, token1, fee, err := v.getUniswapV3PoolInfo(context.Background(), pool)
	assert.NoError(t, err)
	assert.Equal(t, usdc, token0)
	assert.Equal(t, weth, token1)
	assert.Equal(t, uint32(3000), fee)
}
// TestPoolValidator_GetUniswapV2PoolInfo checks that getUniswapV2PoolInfo
// decodes token0 and token1 from two successive CallContract replies.
func TestPoolValidator_GetUniswapV2PoolInfo(t *testing.T) {
	mockClient := &MockEthClient{}
	v := &PoolValidator{
		client: mockClient,
		logger: logger.New("info", "text", ""),
	}
	pair := common.HexToAddress("0xB4e16d0168e52d35CaCD2c6185b44281Ec28C9Dc")
	usdc := common.HexToAddress("0xA0b86991c6218b36c1d19D4a2e9Eb0cE3606eB48")
	weth := common.HexToAddress("0xC02aaA39b223FE8D0A0e5C4F27eAD9083C756Cc2")

	// Replies are consumed in registration order: token0, then token1. Each is
	// a 32-byte ABI word with the address right-aligned.
	for _, token := range []common.Address{usdc, weth} {
		word := make([]byte, 32)
		copy(word[12:], token.Bytes())
		mockClient.On("CallContract", mock.Anything, mock.AnythingOfType("ethereum.CallMsg"), mock.Anything).Return(word, nil).Once()
	}

	token0, token1, err := v.getUniswapV2PoolInfo(context.Background(), pair)
	assert.NoError(t, err)
	assert.Equal(t, usdc, token0)
	assert.Equal(t, weth, token1)
}
// TestPoolValidator_GetCachedResult tests cached result retrieval
func TestPoolValidator_GetCachedResult(t *testing.T) {
log := logger.New("info", "text", "")
validator := &PoolValidator{
validationCache: make(map[common.Address]*ValidationResult),
cacheTimeout: 5 * time.Minute,
logger: log,
}
poolAddr := common.HexToAddress("0x88e6A0c2dDD26FEEb64F039a2c41296FcB3f5640")
// Test with no cached result
result := validator.getCachedResult(poolAddr)
assert.Nil(t, result)
// Test with valid cached result
cachedResult := &ValidationResult{
IsValid: true,
SecurityScore: 90,
ValidatedAt: time.Now(),
}
validator.validationCache[poolAddr] = cachedResult
result = validator.getCachedResult(poolAddr)
assert.Equal(t, cachedResult, result)
// Test with expired cached result
expiredResult := &ValidationResult{
IsValid: true,
SecurityScore: 90,
ValidatedAt: time.Now().Add(-10 * time.Minute), // 10 minutes ago (expired)
}
validator.validationCache[poolAddr] = expiredResult
result = validator.getCachedResult(poolAddr)
assert.Nil(t, result) // Should be nil because it's expired
// Verify that expired result was removed from cache
_, exists := validator.validationCache[poolAddr]
assert.False(t, exists)
}
// TestPoolValidator_CacheResult tests result caching
func TestPoolValidator_CacheResult(t *testing.T) {
log := logger.New("info", "text", "")
validator := &PoolValidator{
validationCache: make(map[common.Address]*ValidationResult),
cacheTimeout: 5 * time.Minute,
logger: log,
}
poolAddr := common.HexToAddress("0x88e6A0c2dDD26FEEb64F039a2c41296FcB3f5640")
result := &ValidationResult{
IsValid: true,
SecurityScore: 90,
ValidatedAt: time.Now(),
}
// Test caching a result
validator.cacheResult(poolAddr, result)
// Verify it was cached
cachedResult, exists := validator.validationCache[poolAddr]
assert.True(t, exists)
assert.Equal(t, result, cachedResult)
}
// TestValidationResult tests the ValidationResult struct
func TestValidationResult(t *testing.T) {
result := &ValidationResult{
IsValid: true,
SecurityScore: 95,
Warnings: []string{"Test warning"},
Errors: []string{},
PoolType: "uniswap_v3",
Factory: "0x1F98431c8aD98523631AE4a59f267346ea31F984",
Token0: common.HexToAddress("0xA0b86991c6218b36c1d19D4a2e9Eb0cE3606eB48"),
Token1: common.HexToAddress("0xC02aaA39b223FE8D0A0e5C4F27eAD9083C756Cc2"),
Fee: 3000,
CreationBlock: 12345678,
ValidatedAt: time.Now(),
FactoryVerified: true,
InterfaceValid: true,
TokensValid: true,
}
assert.True(t, result.IsValid)
assert.Equal(t, 95, result.SecurityScore)
assert.Len(t, result.Warnings, 1)
assert.Equal(t, "Test warning", result.Warnings[0])
assert.Empty(t, result.Errors)
assert.Equal(t, "uniswap_v3", result.PoolType)
assert.Equal(t, "0x1F98431c8aD98523631AE4a59f267346ea31F984", result.Factory)
assert.Equal(t, common.HexToAddress("0xA0b86991c6218b36c1d19D4a2e9Eb0cE3606eB48"), result.Token0)
assert.Equal(t, common.HexToAddress("0xC02aaA39b223FE8D0A0e5C4F27eAD9083C756Cc2"), result.Token1)
assert.Equal(t, uint32(3000), result.Fee)
assert.Equal(t, uint64(12345678), result.CreationBlock)
assert.True(t, result.FactoryVerified)
assert.True(t, result.InterfaceValid)
assert.True(t, result.TokensValid)
assert.WithinDuration(t, time.Now(), result.ValidatedAt, time.Second)
}
// TestValidationConfig round-trips each ValidationConfig field through a
// struct literal.
func TestValidationConfig(t *testing.T) {
	const minScore = 80
	maxWait := 30 * time.Second

	cfg := &ValidationConfig{
		AllowUnknownFactories:      false,
		MaxValidationTime:          maxWait,
		MinSecurityScore:           minScore,
		RequireFactoryVerification: true,
		RequireTokenValidation:     true,
	}

	// Boolean policy switches.
	assert.True(t, cfg.RequireFactoryVerification)
	assert.True(t, cfg.RequireTokenValidation)
	assert.False(t, cfg.AllowUnknownFactories)

	// Numeric thresholds.
	assert.Equal(t, minScore, cfg.MinSecurityScore)
	assert.Equal(t, maxWait, cfg.MaxValidationTime)
}
// TestPoolInfo tests the PoolInfo struct
func TestPoolInfo(t *testing.T) {
now := time.Now()
poolInfo := &PoolInfo{
Address: common.HexToAddress("0x88e6A0c2dDD26FEEb64F039a2c41296FcB3f5640"),
Token0: common.HexToAddress("0xA0b86991c6218b36c1d19D4a2e9Eb0cE3606eB48"),
Token1: common.HexToAddress("0xC02aaA39b223FE8D0A0e5C4F27eAD9083C756Cc2"),
Protocol: "UniswapV3",
Fee: 3000,
Liquidity: big.NewInt(1000000000000000000),
SqrtPriceX96: big.NewInt(79228162514264337593543950336),
LastUpdated: now,
}
assert.Equal(t, common.HexToAddress("0x88e6A0c2dDD26FEEb64F039a2c41296FcB3f5640"), poolInfo.Address)
assert.Equal(t, common.HexToAddress("0xA0b86991c6218b36c1d19D4a2e9Eb0cE3606eB48"), poolInfo.Token0)
assert.Equal(t, common.HexToAddress("0xC02aaA39b223FE8D0A0e5C4F27eAD9083C756Cc2"), poolInfo.Token1)
assert.Equal(t, "UniswapV3", poolInfo.Protocol)
assert.Equal(t, uint32(3000), poolInfo.Fee)
assert.Equal(t, int64(1000000000000000000), poolInfo.Liquidity.Int64())
assert.Equal(t, int64(79228162514264337593543950336), poolInfo.SqrtPriceX96.Int64())
assert.Equal(t, now, poolInfo.LastUpdated)
}
// TestTokenGraph checks that NewTokenGraph returns an initialized graph and
// that GetAdjacentTokens reports a pool inserted directly into the adjacency
// list.
func TestTokenGraph(t *testing.T) {
	graph := NewTokenGraph()
	// Both are dereferenced or written below; a nil value would panic rather
	// than fail cleanly, so stop immediately.
	require.NotNil(t, graph)
	require.NotNil(t, graph.adjacencyList)

	tokenA := common.HexToAddress("0xA0b86991c6218b36c1d19D4a2e9Eb0cE3606eB48") // USDC
	tokenB := common.HexToAddress("0xC02aaA39b223FE8D0A0e5C4F27eAD9083C756Cc2") // WETH

	// 2^96 exceeds the int64 range, so it cannot be passed to big.NewInt as a
	// constant (that is a compile-time overflow). Build it with a shift instead.
	sqrtPriceX96 := new(big.Int).Lsh(big.NewInt(1), 96)

	pool := &PoolInfo{
		Address:      common.HexToAddress("0x88e6A0c2dDD26FEEb64F039a2c41296FcB3f5640"),
		Token0:       tokenA,
		Token1:       tokenB,
		Protocol:     "UniswapV3",
		Fee:          3000,
		Liquidity:    big.NewInt(1000000000000000000),
		SqrtPriceX96: sqrtPriceX96,
		LastUpdated:  time.Now(),
	}

	// Register the pool as an edge in both directions while holding the
	// graph's lock.
	graph.mutex.Lock()
	if graph.adjacencyList[tokenA] == nil {
		graph.adjacencyList[tokenA] = make(map[common.Address][]*PoolInfo)
	}
	graph.adjacencyList[tokenA][tokenB] = append(graph.adjacencyList[tokenA][tokenB], pool)
	if graph.adjacencyList[tokenB] == nil {
		graph.adjacencyList[tokenB] = make(map[common.Address][]*PoolInfo)
	}
	graph.adjacencyList[tokenB][tokenA] = append(graph.adjacencyList[tokenB][tokenA], pool)
	graph.mutex.Unlock()

	adjacent := graph.GetAdjacentTokens(tokenA)
	assert.Len(t, adjacent, 1)
	assert.Contains(t, adjacent, tokenB)
	// require rather than assert: indexing [0] below panics on an empty slice.
	require.Len(t, adjacent[tokenB], 1)
	assert.Equal(t, pool, adjacent[tokenB][0])
}