Files
mev-beta/pkg/pools/create2.go
2025-09-16 11:05:47 -05:00

372 lines
12 KiB
Go

package pools
import (
"fmt"
"math/big"
"sort"
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/crypto"
"github.com/fraktal/mev-beta/internal/logger"
)
// CREATE2Calculator handles CREATE2 address calculations for various DEX factories.
//
// Factory configurations live in a plain map keyed by factory name. The type
// holds no mutex, so NOTE(review): concurrent AddCustomFactory calls alongside
// any other method are presumably a data race — confirm callers only register
// factories during setup.
type CREATE2Calculator struct {
	logger    *logger.Logger            // Destination for debug/info/warn diagnostics
	factories map[string]*FactoryConfig // Factory name -> configuration
}
// FactoryConfig contains the configuration for a DEX factory.
type FactoryConfig struct {
	// Factory name (e.g., "uniswap_v3", "sushiswap"); the calculator registers
	// the config under this name.
	Name string
	// Factory contract address; used as the deployer address in the CREATE2
	// derivation.
	Address common.Address
	// Init code hash for CREATE2 calculation. The built-in "curve" entry uses
	// the zero hash because it does not go through CREATE2.
	InitCodeHash common.Hash
	FeeStructure FeeStructure // How fees are encoded
	SortTokens   bool         // Whether tokens should be put in ascending address order first
}
// FeeStructure defines how fees are handled in address calculation.
type FeeStructure struct {
	HasFee bool // Whether fee is part of salt
	// Byte positions where fee is encoded.
	// NOTE(review): not read by any code visible in this file.
	FeePositions []int
	// Default fee tiers, in hundredths of a basis point (500 = 0.05%,
	// 3000 = 0.3%, 10000 = 1%, per the built-in configurations).
	DefaultFees []uint32
}
// PoolIdentifier uniquely identifies a candidate pool: a factory, an ordered
// token pair, a fee tier, and the address derived from them.
type PoolIdentifier struct {
	Factory  string         // Factory name
	Token0   common.Address // First token (lower address if the factory sorts tokens)
	Token1   common.Address // Second token (higher address if the factory sorts tokens)
	Fee      uint32         // Fee tier, same units as FeeStructure.DefaultFees
	PoolAddr common.Address // Calculated (not on-chain verified) pool address
}
// NewCREATE2Calculator returns a calculator that already knows the built-in
// factory configurations (registered by initializeFactories). The supplied
// logger receives all of the calculator's diagnostic output.
func NewCREATE2Calculator(logger *logger.Logger) *CREATE2Calculator {
	calc := new(CREATE2Calculator)
	calc.logger = logger
	calc.factories = map[string]*FactoryConfig{}
	calc.initializeFactories()
	return calc
}
// initializeFactories registers the built-in DEX factory configurations,
// each under its own Name. An existing entry with one of these names is
// overwritten.
//
// NOTE(review): the Uniswap V2 and SushiSwap factory addresses below appear to
// be Ethereum-mainnet deployments, while Camelot V3 is labelled
// Arbitrum-specific. Confirm every address matches the chain this calculator
// is actually used against.
func (c *CREATE2Calculator) initializeFactories() {
	builtins := []*FactoryConfig{
		{
			// Uniswap V3 Factory
			Name:         "uniswap_v3",
			Address:      common.HexToAddress("0x1F98431c8aD98523631AE4a59f267346ea31F984"),
			InitCodeHash: common.HexToHash("0xe34f199b19b2b4f47f68442619d555527d244f78a3297ea89325f843f87b8b54"),
			FeeStructure: FeeStructure{
				HasFee:      true,
				DefaultFees: []uint32{500, 3000, 10000}, // 0.05%, 0.3%, 1%
			},
			SortTokens: true,
		},
		{
			// Uniswap V2 Factory
			Name:         "uniswap_v2",
			Address:      common.HexToAddress("0x5C69bEe701ef814a2B6a3EDD4B1652CB9cc5aA6f"),
			InitCodeHash: common.HexToHash("0x96e8ac4277198ff8b6f785478aa9a39f403cb768dd02cbee326c3e7da348845f"),
			FeeStructure: FeeStructure{
				HasFee:      false,
				DefaultFees: []uint32{3000}, // Fixed 0.3%
			},
			SortTokens: true,
		},
		{
			// SushiSwap Factory (same scheme as Uniswap V2, different address)
			Name:         "sushiswap",
			Address:      common.HexToAddress("0xC0AEe478e3658e2610c5F7A4A2E1777cE9e4f2Ac"),
			InitCodeHash: common.HexToHash("0xe18a34eb0e04b04f7a0ac29a6e80748dca96319b42c54d679cb821dca90c6303"),
			FeeStructure: FeeStructure{
				HasFee:      false,
				DefaultFees: []uint32{3000}, // Fixed 0.3%
			},
			SortTokens: true,
		},
		{
			// Camelot V3 (Arbitrum-specific)
			Name:         "camelot_v3",
			Address:      common.HexToAddress("0x1a3c9B1d2F0529D97f2afC5136Cc23e58f1FD35B"),
			InitCodeHash: common.HexToHash("0xa856464ae65f7619087bc369daaf7e387dae1e5af69cfa7935850ebf754b04c1"),
			FeeStructure: FeeStructure{
				HasFee:      true,
				DefaultFees: []uint32{500, 3000, 10000}, // Similar to Uniswap V3
			},
			SortTokens: true,
		},
		{
			// Curve Factory (simplified - Curve does not use standard CREATE2,
			// so the init code hash is deliberately zero)
			Name:         "curve",
			Address:      common.HexToAddress("0xF18056Bbd320E96A48e3Fbf8bC061322531aac99"),
			InitCodeHash: common.HexToHash("0x00"),
			FeeStructure: FeeStructure{
				HasFee:      true,
				DefaultFees: []uint32{400}, // 0.04% typical
			},
			SortTokens: false, // Curve maintains token order
		},
	}

	for _, cfg := range builtins {
		c.factories[cfg.Name] = cfg
	}
}
// CalculatePoolAddress derives the address of the pool for token0/token1 (and
// fee, when the factory's salt includes it) on the factory registered as
// factoryName.
//
// For CREATE2 factories the result is
// keccak256(0xff ++ factory ++ salt ++ initCodeHash)[12:] (EIP-1014). If the
// factory has SortTokens set, the tokens are first put in ascending address
// order, so callers may pass them in either order. The "curve" factory is
// delegated to calculateCurvePoolAddress, which yields a placeholder rather
// than a real pool address.
//
// An error is returned when:
//   - the factory is unknown;
//   - token0 and token1 are the same address (a token cannot be paired with
//     itself);
//   - the factory's InitCodeHash is zero (no init code hashes to zero, so any
//     derived address would be meaningless);
//   - the salt cannot be computed.
func (c *CREATE2Calculator) CalculatePoolAddress(factoryName string, token0, token1 common.Address, fee uint32) (common.Address, error) {
	factory, exists := c.factories[factoryName]
	if !exists {
		return common.Address{}, fmt.Errorf("unknown factory: %s", factoryName)
	}
	if token0 == token1 {
		return common.Address{}, fmt.Errorf("identical token addresses: %s", token0.Hex())
	}

	// Sort tokens if required by the factory.
	if factory.SortTokens && token0.Big().Cmp(token1.Big()) > 0 {
		token0, token1 = token1, token0
	}

	// Curve does not use CREATE2; branch off before any CREATE2-specific work
	// (the salt would be discarded anyway).
	if factoryName == "curve" {
		return c.calculateCurvePoolAddress(token0, token1, fee)
	}

	// Custom factories can be registered without an init code hash; refuse to
	// fabricate an address from it.
	if factory.InitCodeHash == (common.Hash{}) {
		return common.Address{}, fmt.Errorf("factory %s has zero init code hash", factoryName)
	}

	salt, err := c.calculateSalt(factory, token0, token1, fee)
	if err != nil {
		return common.Address{}, fmt.Errorf("failed to calculate salt: %w", err)
	}
	// CREATE2 salts are exactly one 32-byte word; anything else would be
	// silently truncated or zero-padded below.
	if len(salt) != 32 {
		return common.Address{}, fmt.Errorf("invalid salt length %d for factory %s", len(salt), factoryName)
	}
	var saltWord [32]byte
	copy(saltWord[:], salt)

	// keccak256(0xff ++ factory ++ salt ++ initCodeHash)[12:]
	poolAddr := crypto.CreateAddress2(factory.Address, saltWord, factory.InitCodeHash.Bytes())

	c.logger.Debug(fmt.Sprintf("Calculated %s pool address: %s for tokens %s/%s fee %d",
		factoryName, poolAddr.Hex(), token0.Hex(), token1.Hex(), fee))
	return poolAddr, nil
}
// calculateSalt generates the CREATE2 salt for factory and the already-ordered
// token pair.
//
// Factories known by name use their fixed layout:
//   - "uniswap_v3", "camelot_v3": keccak256(abi.encode(token0, token1, fee))
//   - "uniswap_v2", "sushiswap":  keccak256(abi.encodePacked(token0, token1))
//
// For any other factory (for example one added through AddCustomFactory), the
// layout follows FeeStructure.HasFee. If the fee is part of the salt, the V3
// layout is used. If it is not, the fee must not be hashed in, so the V2
// packed layout is used.
//
// NOTE(review): Camelot V3 is believed to be Algebra-based. Algebra derives
// pool addresses from a separate pool-deployer contract and a salt of
// abi.encode(token0, token1) with no fee. Confirm that the V3 layout and the
// factory address used for "camelot_v3" are correct.
func (c *CREATE2Calculator) calculateSalt(factory *FactoryConfig, token0, token1 common.Address, fee uint32) ([]byte, error) {
	switch factory.Name {
	case "uniswap_v3", "camelot_v3":
		return c.calculateUniswapV3Salt(token0, token1, fee)
	case "uniswap_v2", "sushiswap":
		return c.calculateUniswapV2Salt(token0, token1)
	}

	if !factory.FeeStructure.HasFee {
		// Config says the fee is not part of the salt: use the packed pair only.
		return c.calculateUniswapV2Salt(token0, token1)
	}
	// Generic fee-bearing salt: keccak256(abi.encode(token0, token1, fee)).
	return c.calculateGenericSalt(token0, token1, fee)
}
// calculateUniswapV3Salt returns keccak256(abi.encode(token0, token1, fee)).
// The encoding is three 32-byte words. Each value is right-aligned in its
// word, with zero padding on the left.
func (c *CREATE2Calculator) calculateUniswapV3Salt(token0, token1 common.Address, fee uint32) ([]byte, error) {
	const word = 32
	encoded := make([]byte, 3*word)

	// 20-byte addresses sit at the low-order end of their words.
	copy(encoded[word-common.AddressLength:word], token0.Bytes())
	copy(encoded[2*word-common.AddressLength:2*word], token1.Bytes())

	// uint32 fee, big-endian, in the last four bytes of the third word.
	encoded[3*word-4] = byte(fee >> 24)
	encoded[3*word-3] = byte(fee >> 16)
	encoded[3*word-2] = byte(fee >> 8)
	encoded[3*word-1] = byte(fee)

	return crypto.Keccak256(encoded), nil
}
// calculateUniswapV2Salt returns keccak256(abi.encodePacked(token0, token1)).
// The two raw 20-byte addresses are hashed back to back, with no padding.
func (c *CREATE2Calculator) calculateUniswapV2Salt(token0, token1 common.Address) ([]byte, error) {
	// Keccak256 hashes the concatenation of its arguments, which is exactly
	// the packed encoding.
	return crypto.Keccak256(token0.Bytes(), token1.Bytes()), nil
}
// calculateGenericSalt computes the salt for factories without a dedicated
// layout. It currently uses the Uniswap V3 encoding,
// keccak256(abi.encode(token0, token1, fee)).
func (c *CREATE2Calculator) calculateGenericSalt(token0, token1 common.Address, fee uint32) ([]byte, error) {
	salt, err := c.calculateUniswapV3Salt(token0, token1, fee)
	return salt, err
}
// calculateCurvePoolAddress stands in for Curve pool discovery, which does not
// follow the CREATE2 scheme used by the other factories.
//
// WARNING: the result is a placeholder. It is the last 20 bytes of
// keccak256(token0 ++ token1 ++ fee), with fee in minimal big-endian form,
// and must not be treated as a deployed Curve pool. A real implementation
// would need to:
//  1. Query the Curve registry.
//  2. Use Curve's specific pool creation logic.
//  3. Handle different Curve pool types (stable, crypto, etc.).
//
// The error result is always nil. A warning is logged on every call.
func (c *CREATE2Calculator) calculateCurvePoolAddress(token0, token1 common.Address, fee uint32) (common.Address, error) {
	c.logger.Warn("Curve pool address calculation not fully implemented - using placeholder")

	// big.Int.Bytes is the minimal big-endian form: no bytes at all for fee 0.
	feeBytes := big.NewInt(int64(fee)).Bytes()
	digest := crypto.Keccak256(token0.Bytes(), token1.Bytes(), feeBytes)
	return common.BytesToAddress(digest[12:]), nil
}
// FindPoolsForTokenPair computes a candidate pool for the token pair on every
// configured factory, once for each fee tier in that factory's DefaultFees.
//
// The addresses are derived, not looked up on-chain. A candidate may not
// actually be deployed, and Curve entries are placeholders. Results are
// ordered by factory name, then by DefaultFees order, so repeated calls give
// identical output. Token0/Token1 in each result are in ascending address
// order when the factory sorts tokens, and in the caller's order otherwise.
// Factories whose calculation fails are logged at debug level and skipped.
// The returned error is currently always nil.
func (c *CREATE2Calculator) FindPoolsForTokenPair(token0, token1 common.Address) ([]*PoolIdentifier, error) {
	// Go map iteration order is random; walk the factories in name order so the
	// result slice is deterministic.
	names := make([]string, 0, len(c.factories))
	for name := range c.factories {
		names = append(names, name)
	}
	sort.Strings(names)

	// Ascending-address order, computed once for all sorting factories.
	lowToken, highToken := token0, token1
	if lowToken.Big().Cmp(highToken.Big()) > 0 {
		lowToken, highToken = highToken, lowToken
	}

	pools := make([]*PoolIdentifier, 0)
	for _, factoryName := range names {
		factory := c.factories[factoryName]

		pairToken0, pairToken1 := token0, token1
		if factory.SortTokens {
			pairToken0, pairToken1 = lowToken, highToken
		}

		for _, fee := range factory.FeeStructure.DefaultFees {
			poolAddr, err := c.CalculatePoolAddress(factoryName, pairToken0, pairToken1, fee)
			if err != nil {
				c.logger.Debug(fmt.Sprintf("Failed to calculate pool address for %s: %v", factoryName, err))
				continue
			}
			pools = append(pools, &PoolIdentifier{
				Factory:  factoryName,
				Token0:   pairToken0,
				Token1:   pairToken1,
				Fee:      fee,
				PoolAddr: poolAddr,
			})
		}
	}

	c.logger.Debug(fmt.Sprintf("Found %d potential pools for tokens %s/%s",
		len(pools), token0.Hex(), token1.Hex()))
	return pools, nil
}
// ValidatePoolAddress reports whether the address CalculatePoolAddress derives
// for the given factory, tokens and fee equals expectedAddr. A calculation
// error is logged at debug level and reported as a mismatch (false).
func (c *CREATE2Calculator) ValidatePoolAddress(factoryName string, token0, token1 common.Address, fee uint32, expectedAddr common.Address) bool {
	derived, err := c.CalculatePoolAddress(factoryName, token0, token1, fee)
	if err == nil {
		same := derived == expectedAddr
		c.logger.Debug(fmt.Sprintf("Pool address validation: calculated=%s, expected=%s, match=%v",
			derived.Hex(), expectedAddr.Hex(), same))
		return same
	}
	c.logger.Debug(fmt.Sprintf("Validation failed - calculation error: %v", err))
	return false
}
// GetFactoryConfig returns a copy of the configuration registered under
// factoryName, or an error if no such factory is configured.
//
// The fee slices (FeeStructure.FeePositions and DefaultFees) are cloned too,
// so modifying anything in the returned config never changes the
// calculator's stored configuration.
func (c *CREATE2Calculator) GetFactoryConfig(factoryName string) (*FactoryConfig, error) {
	factory, exists := c.factories[factoryName]
	if !exists {
		return nil, fmt.Errorf("unknown factory: %s", factoryName)
	}

	configCopy := *factory
	// A struct copy still shares slice backing arrays with the stored config;
	// clone them (preserving nil-ness) so callers cannot mutate internal state.
	fees := &configCopy.FeeStructure
	if fees.FeePositions != nil {
		fees.FeePositions = append(make([]int, 0, len(fees.FeePositions)), fees.FeePositions...)
	}
	if fees.DefaultFees != nil {
		fees.DefaultFees = append(make([]uint32, 0, len(fees.DefaultFees)), fees.DefaultFees...)
	}
	return &configCopy, nil
}
// AddCustomFactory registers config under config.Name so it can be used by
// CalculatePoolAddress, FindPoolsForTokenPair and the other lookups.
//
// Any existing factory with the same name, including a built-in one, is
// replaced. The pointer itself is stored, so later changes the caller makes
// through it are visible to the calculator. An error is returned if config is
// nil, its Name is empty, or its Address is zero.
func (c *CREATE2Calculator) AddCustomFactory(config *FactoryConfig) error {
	// Guard before dereferencing; a nil config previously panicked here.
	if config == nil {
		return fmt.Errorf("factory config cannot be nil")
	}
	if config.Name == "" {
		return fmt.Errorf("factory name cannot be empty")
	}
	if config.Address == (common.Address{}) {
		return fmt.Errorf("factory address cannot be zero")
	}
	c.factories[config.Name] = config
	c.logger.Info(fmt.Sprintf("Added custom factory: %s at %s", config.Name, config.Address.Hex()))
	return nil
}
// ListFactories returns the name of every configured factory, in ascending
// lexical order. The result is a fresh, non-nil slice, possibly empty.
func (c *CREATE2Calculator) ListFactories() []string {
	names := make([]string, len(c.factories))
	i := 0
	for name := range c.factories {
		names[i] = name
		i++
	}
	sort.StringSlice(names).Sort()
	return names
}
// CalculateInitCodeHash returns keccak256(initCode). This is the value a
// FactoryConfig needs in InitCodeHash when registering a new factory. Pass the
// pool contract's creation (init) code, which is what CREATE2 hashes.
func CalculateInitCodeHash(initCode []byte) common.Hash {
	digest := crypto.Keccak256(initCode)
	return common.BytesToHash(digest)
}
// VerifyFactorySupport checks that factoryName is configured and that its
// configuration is usable. The factory address must be non-zero. The init
// code hash must be non-zero, except for "curve", which has no CREATE2 hash.
// At least one default fee tier must be set. The first failed check is
// returned as an error; nil means every check passed.
func (c *CREATE2Calculator) VerifyFactorySupport(factoryName string) error {
	factory, ok := c.factories[factoryName]
	if !ok {
		return fmt.Errorf("factory %s not configured", factoryName)
	}

	switch {
	case factory.Address == (common.Address{}):
		return fmt.Errorf("factory %s has zero address", factoryName)
	case factory.InitCodeHash == (common.Hash{}) && factoryName != "curve":
		return fmt.Errorf("factory %s has zero init code hash", factoryName)
	case len(factory.FeeStructure.DefaultFees) == 0:
		return fmt.Errorf("factory %s has no default fees configured", factoryName)
	default:
		return nil
	}
}