package pools

import (
	"bytes"
	"errors"
	"fmt"
	"math/big"
	"sort"

	"github.com/ethereum/go-ethereum/common"
	"github.com/ethereum/go-ethereum/crypto"
	"github.com/fraktal/mev-beta/internal/logger"
)

// ErrCREATE2Unsupported is returned (wrapped) when a pool address cannot be
// derived with the CREATE2 formula for the requested factory, e.g. Curve or
// a factory configured without an init code hash.
var ErrCREATE2Unsupported = errors.New("pool address derivation via CREATE2 not supported for this factory")

// CREATE2Calculator derives DEX pool addresses offline with the CREATE2
// formula, using a registry of per-factory configurations keyed by name.
//
// NOTE(review): factories is a plain map with no locking, so
// AddCustomFactory must not run concurrently with any other method.
type CREATE2Calculator struct {
	logger    *logger.Logger
	factories map[string]*FactoryConfig
}

// FactoryConfig contains the configuration for a DEX factory.
type FactoryConfig struct {
	Name         string         // Factory name (e.g., "uniswap_v3", "sushiswap")
	Address      common.Address // Factory contract address, used as the CREATE2 deployer
	InitCodeHash common.Hash    // Init code hash for CREATE2 calculation (zero = not derivable)
	FeeStructure FeeStructure   // How fees are encoded
	SortTokens   bool           // Whether tokens should be sorted ascending before hashing
}

// FeeStructure defines how fees are handled in address calculation.
type FeeStructure struct {
	HasFee       bool     // Whether fee is part of salt (decides the salt scheme for custom factories)
	FeePositions []int    // Byte positions where fee is encoded (not currently read by CREATE2Calculator)
	DefaultFees  []uint32 // Fee tiers tried by FindPoolsForTokenPair, in hundredths of a bip (3000 = 0.3%)
}

// PoolIdentifier uniquely identifies a pool.
type PoolIdentifier struct {
	Factory  string         // Factory name
	Token0   common.Address // First token (lower address if sorted)
	Token1   common.Address // Second token (higher address if sorted)
	Fee      uint32         // Fee tier
	PoolAddr common.Address // Calculated pool address
}

// NewCREATE2Calculator creates a calculator pre-populated with the built-in
// factory configurations.
func NewCREATE2Calculator(logger *logger.Logger) *CREATE2Calculator {
	calc := &CREATE2Calculator{
		logger:    logger,
		factories: make(map[string]*FactoryConfig),
	}

	// Initialize with known factory configurations
	calc.initializeFactories()

	return calc
}

// initializeFactories registers the built-in DEX factory configurations.
func (c *CREATE2Calculator) initializeFactories() {
	// Uniswap V3 Factory
	c.factories["uniswap_v3"] = &FactoryConfig{
		Name:         "uniswap_v3",
		Address:      common.HexToAddress("0x1F98431c8aD98523631AE4a59f267346ea31F984"),
		InitCodeHash: common.HexToHash("0xe34f199b19b2b4f47f68442619d555527d244f78a3297ea89325f843f87b8b54"),
		FeeStructure: FeeStructure{
			HasFee:      true,
			DefaultFees: []uint32{100, 500, 3000, 10000}, // 0.01%, 0.05%, 0.3%, 1%
		},
		SortTokens: true,
	}

	// Uniswap V2 Factory
	// NOTE(review): this and the SushiSwap address below look like the
	// Ethereum mainnet factory deployments, while Camelot is Arbitrum-specific;
	// confirm they match the chain this calculator is used on.
	c.factories["uniswap_v2"] = &FactoryConfig{
		Name:         "uniswap_v2",
		Address:      common.HexToAddress("0x5C69bEe701ef814a2B6a3EDD4B1652CB9cc5aA6f"),
		InitCodeHash: common.HexToHash("0x96e8ac4277198ff8b6f785478aa9a39f403cb768dd02cbee326c3e7da348845f"),
		FeeStructure: FeeStructure{
			HasFee:      false,
			DefaultFees: []uint32{3000}, // Fixed 0.3%
		},
		SortTokens: true,
	}

	// SushiSwap Factory (same as Uniswap V2 but different address)
	c.factories["sushiswap"] = &FactoryConfig{
		Name:         "sushiswap",
		Address:      common.HexToAddress("0xC0AEe478e3658e2610c5F7A4A2E1777cE9e4f2Ac"),
		InitCodeHash: common.HexToHash("0xe18a34eb0e04b04f7a0ac29a6e80748dca96319b42c54d679cb821dca90c6303"),
		FeeStructure: FeeStructure{
			HasFee:      false,
			DefaultFees: []uint32{3000}, // Fixed 0.3%
		},
		SortTokens: true,
	}

	// Camelot V3 (Arbitrum-specific)
	// NOTE(review): Camelot V3 is reportedly Algebra-based, where pools are
	// deployed by a separate pool-deployer contract, salted without a fee, and
	// use dynamic fees. Verify the deployer address, init code hash and salt
	// scheme before trusting these results.
	c.factories["camelot_v3"] = &FactoryConfig{
		Name:         "camelot_v3",
		Address:      common.HexToAddress("0x1a3c9B1d2F0529D97f2afC5136Cc23e58f1FD35B"),
		InitCodeHash: common.HexToHash("0xa856464ae65f7619087bc369daaf7e387dae1e5af69cfa7935850ebf754b04c1"),
		FeeStructure: FeeStructure{
			HasFee:      true,
			DefaultFees: []uint32{500, 3000, 10000}, // Similar to Uniswap V3
		},
		SortTokens: true,
	}

	// Curve Factory (simplified - Curve uses different math). The zero init
	// code hash marks it as not derivable via the standard CREATE2 formula.
	c.factories["curve"] = &FactoryConfig{
		Name:         "curve",
		Address:      common.HexToAddress("0xF18056Bbd320E96A48e3Fbf8bC061322531aac99"),
		InitCodeHash: common.Hash{}, // Curve doesn't use standard CREATE2
		FeeStructure: FeeStructure{
			HasFee:      true,
			DefaultFees: []uint32{400}, // 0.04% typical
		},
		SortTokens: false, // Curve maintains token order
	}
}

// sortTokens returns the pair in ascending address order. This is the
// token0 < token1 ordering that Uniswap-style factories use. Comparing the
// fixed-width big-endian bytes is equivalent to comparing them as integers.
func sortTokens(a, b common.Address) (common.Address, common.Address) {
	if bytes.Compare(a[:], b[:]) > 0 {
		return b, a
	}
	return a, b
}

// CalculatePoolAddress calculates the pool address using CREATE2:
//
//	address = keccak256(0xff ++ factory ++ salt ++ initCodeHash)[12:]
//
// It returns an error for:
//   - unknown factories;
//   - identical or zero token addresses, which cannot form a pool;
//   - factories whose pools cannot be derived this way (wrapping
//     ErrCREATE2Unsupported).
func (c *CREATE2Calculator) CalculatePoolAddress(factoryName string, token0, token1 common.Address, fee uint32) (common.Address, error) {
	factory, exists := c.factories[factoryName]
	if !exists {
		return common.Address{}, fmt.Errorf("unknown factory: %s", factoryName)
	}

	// Pool factories reject these pairs on creation, so any derived address
	// would never hold a pool.
	if token0 == token1 {
		return common.Address{}, fmt.Errorf("identical token addresses: %s", token0.Hex())
	}
	if token0 == (common.Address{}) || token1 == (common.Address{}) {
		return common.Address{}, fmt.Errorf("zero token address in pair %s/%s", token0.Hex(), token1.Hex())
	}

	// Sort tokens if required by the factory
	if factory.SortTokens {
		token0, token1 = sortTokens(token0, token1)
	}

	// Factories that don't use standard CREATE2 are handled before any salt work.
	if factoryName == "curve" {
		return c.calculateCurvePoolAddress(token0, token1, fee)
	}
	if factory.InitCodeHash == (common.Hash{}) {
		return common.Address{}, fmt.Errorf("factory %s has zero init code hash: %w", factoryName, ErrCREATE2Unsupported)
	}

	salt, err := c.calculateSalt(factory, token0, token1, fee)
	if err != nil {
		return common.Address{}, fmt.Errorf("failed to calculate salt: %w", err)
	}

	var salt32 [32]byte
	copy(salt32[:], salt)
	poolAddr := crypto.CreateAddress2(factory.Address, salt32, factory.InitCodeHash.Bytes())

	c.logger.Debug(fmt.Sprintf("Calculated %s pool address: %s for tokens %s/%s fee %d",
		factoryName, poolAddr.Hex(), token0.Hex(), token1.Hex(), fee))

	return poolAddr, nil
}

// calculateSalt generates the CREATE2 salt for the given factory.
//
// Built-in factories are dispatched by name. Any other (custom) factory uses
// FeeStructure.HasFee: with a fee, the salt is keccak256(abi.encode(token0,
// token1, fee)); without a fee, it is keccak256(abi.encodePacked(token0,
// token1)).
func (c *CREATE2Calculator) calculateSalt(factory *FactoryConfig, token0, token1 common.Address, fee uint32) ([]byte, error) {
	switch factory.Name {
	case "uniswap_v3", "camelot_v3":
		// Uniswap V3 salt: keccak256(abi.encode(token0, token1, fee))
		return c.calculateUniswapV3Salt(token0, token1, fee)
	case "uniswap_v2", "sushiswap":
		// Uniswap V2 salt: keccak256(abi.encodePacked(token0, token1))
		return c.calculateUniswapV2Salt(token0, token1)
	}

	if factory.FeeStructure.HasFee {
		return c.calculateGenericSalt(token0, token1, fee)
	}
	return c.calculateUniswapV2Salt(token0, token1)
}

// calculateUniswapV3Salt returns keccak256(abi.encode(token0, token1, fee)).
// abi.encode left-pads each static value into its own 32-byte word.
func (c *CREATE2Calculator) calculateUniswapV3Salt(token0, token1 common.Address, fee uint32) ([]byte, error) {
	var data [96]byte
	copy(data[12:32], token0.Bytes())
	copy(data[44:64], token1.Bytes())
	// FillBytes writes fee big-endian, right-aligned in the final word.
	new(big.Int).SetUint64(uint64(fee)).FillBytes(data[64:96])

	return crypto.Keccak256(data[:]), nil
}

// calculateUniswapV2Salt returns keccak256(abi.encodePacked(token0, token1)):
// the two raw 20-byte addresses hashed back to back.
func (c *CREATE2Calculator) calculateUniswapV2Salt(token0, token1 common.Address) ([]byte, error) {
	return crypto.Keccak256(token0.Bytes(), token1.Bytes()), nil
}

// calculateGenericSalt calculates the salt for fee-salted custom factories.
// It currently uses the Uniswap V3 encoding.
func (c *CREATE2Calculator) calculateGenericSalt(token0, token1 common.Address, fee uint32) ([]byte, error) {
	return c.calculateUniswapV3Salt(token0, token1, fee)
}

// calculateCurvePoolAddress reports that Curve pool addresses cannot be
// derived here.
//
// Curve pools are not located with the standard CREATE2 formula. A real
// implementation would need to query the Curve registry and handle Curve's
// different pool types (stable, crypto, etc.). Returning an error, rather
// than a hash-derived placeholder, keeps addresses that correspond to no
// deployed pool out of callers' results.
func (c *CREATE2Calculator) calculateCurvePoolAddress(token0, token1 common.Address, fee uint32) (common.Address, error) {
	return common.Address{}, fmt.Errorf("curve pool %s/%s fee %d: %w",
		token0.Hex(), token1.Hex(), fee, ErrCREATE2Unsupported)
}

// FindPoolsForTokenPair computes a candidate pool address for every
// configured factory and each of its default fee tiers.
//
// Results are ordered by factory name, then by fee-tier order. Factories
// whose addresses cannot be derived are logged at debug level and skipped.
// Identical tokens are rejected with an error.
func (c *CREATE2Calculator) FindPoolsForTokenPair(token0, token1 common.Address) ([]*PoolIdentifier, error) {
	if token0 == token1 {
		return nil, fmt.Errorf("identical token addresses: %s", token0.Hex())
	}

	pools := make([]*PoolIdentifier, 0)

	// ListFactories is sorted, so the output order is deterministic
	// (ranging over the map directly would not be).
	for _, factoryName := range c.ListFactories() {
		factory := c.factories[factoryName]

		sortedToken0, sortedToken1 := token0, token1
		if factory.SortTokens {
			sortedToken0, sortedToken1 = sortTokens(token0, token1)
		}

		for _, fee := range factory.FeeStructure.DefaultFees {
			poolAddr, err := c.CalculatePoolAddress(factoryName, sortedToken0, sortedToken1, fee)
			if err != nil {
				c.logger.Debug(fmt.Sprintf("Failed to calculate pool address for %s: %v", factoryName, err))
				continue
			}

			pools = append(pools, &PoolIdentifier{
				Factory:  factoryName,
				Token0:   sortedToken0,
				Token1:   sortedToken1,
				Fee:      fee,
				PoolAddr: poolAddr,
			})
		}
	}

	c.logger.Debug(fmt.Sprintf("Found %d potential pools for tokens %s/%s",
		len(pools), token0.Hex(), token1.Hex()))

	return pools, nil
}

// ValidatePoolAddress reports whether the address calculated for the given
// factory, tokens and fee equals expectedAddr. A calculation error yields false.
func (c *CREATE2Calculator) ValidatePoolAddress(factoryName string, token0, token1 common.Address, fee uint32, expectedAddr common.Address) bool {
	calculatedAddr, err := c.CalculatePoolAddress(factoryName, token0, token1, fee)
	if err != nil {
		c.logger.Debug(fmt.Sprintf("Validation failed - calculation error: %v", err))
		return false
	}

	match := calculatedAddr == expectedAddr

	c.logger.Debug(fmt.Sprintf("Pool address validation: calculated=%s, expected=%s, match=%v",
		calculatedAddr.Hex(), expectedAddr.Hex(), match))

	return match
}

// cloneFactoryConfig returns a deep copy of cfg. The fee slices are copied so
// that the clone shares no backing arrays with cfg.
func cloneFactoryConfig(cfg *FactoryConfig) *FactoryConfig {
	clone := *cfg
	clone.FeeStructure.FeePositions = append([]int(nil), cfg.FeeStructure.FeePositions...)
	clone.FeeStructure.DefaultFees = append([]uint32(nil), cfg.FeeStructure.DefaultFees...)
	return &clone
}

// GetFactoryConfig returns a deep copy of the named factory's configuration.
// Mutating the result does not affect the calculator.
func (c *CREATE2Calculator) GetFactoryConfig(factoryName string) (*FactoryConfig, error) {
	factory, exists := c.factories[factoryName]
	if !exists {
		return nil, fmt.Errorf("unknown factory: %s", factoryName)
	}

	// A shallow struct copy would still share the fee slices, so deep-copy.
	return cloneFactoryConfig(factory), nil
}

// AddCustomFactory registers a factory configuration under config.Name,
// replacing any existing entry with that name. A copy is stored, so later
// changes to config by the caller have no effect.
func (c *CREATE2Calculator) AddCustomFactory(config *FactoryConfig) error {
	if config == nil {
		return fmt.Errorf("factory config cannot be nil")
	}
	if config.Name == "" {
		return fmt.Errorf("factory name cannot be empty")
	}
	if config.Address == (common.Address{}) {
		return fmt.Errorf("factory address cannot be zero")
	}

	c.factories[config.Name] = cloneFactoryConfig(config)
	c.logger.Info(fmt.Sprintf("Added custom factory: %s at %s", config.Name, config.Address.Hex()))

	return nil
}

// ListFactories returns the names of all configured factories, sorted.
func (c *CREATE2Calculator) ListFactories() []string {
	names := make([]string, 0, len(c.factories))
	for name := range c.factories {
		names = append(names, name)
	}
	sort.Strings(names)
	return names
}

// CalculateInitCodeHash returns keccak256(initCode). This is useful when
// adding new factories.
func CalculateInitCodeHash(initCode []byte) common.Hash {
	return crypto.Keccak256Hash(initCode)
}

// VerifyFactorySupport checks that a factory is configured with a non-zero
// address, a non-zero init code hash (except curve) and at least one default fee.
func (c *CREATE2Calculator) VerifyFactorySupport(factoryName string) error {
	factory, exists := c.factories[factoryName]
	if !exists {
		return fmt.Errorf("factory %s not configured", factoryName)
	}

	// Basic validation
	if factory.Address == (common.Address{}) {
		return fmt.Errorf("factory %s has zero address", factoryName)
	}

	if factory.InitCodeHash == (common.Hash{}) && factoryName != "curve" {
		return fmt.Errorf("factory %s has zero init code hash", factoryName)
	}

	if len(factory.FeeStructure.DefaultFees) == 0 {
		return fmt.Errorf("factory %s has no default fees configured", factoryName)
	}

	return nil
}