Files
mev-beta/pkg/uniswap/pool_detector.go

274 lines
7.3 KiB
Go

package uniswap
import (
"context"
"errors"
"fmt"
"math/big"
"strings"
"github.com/ethereum/go-ethereum"
"github.com/ethereum/go-ethereum/accounts/abi"
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/ethclient"
)
// PoolVersion identifies which DEX contract interface a pool exposes.
// The zero value is PoolVersionUnknown.
type PoolVersion int

const (
	// PoolVersionUnknown means no supported interface was recognized.
	PoolVersionUnknown PoolVersion = iota
	// PoolVersionV2 is a Uniswap V2 style pool (exposes getReserves).
	PoolVersionV2
	// PoolVersionV3 is a Uniswap V3 style pool (exposes slot0).
	PoolVersionV3
	// PoolVersionBalancer is a Balancer pool (exposes getPoolId).
	PoolVersionBalancer
	// PoolVersionCurve is a Curve pool.
	// NOTE(review): DetectPoolVersion does not probe for Curve, so this value
	// is never produced by detection — confirm where it is assigned, if anywhere.
	PoolVersionCurve
)

// String returns a human-readable DEX name for pv. PoolVersionUnknown and
// any value outside the declared constants both render as "Unknown".
func (pv PoolVersion) String() string {
	names := [...]string{
		PoolVersionUnknown:  "Unknown",
		PoolVersionV2:       "UniswapV2",
		PoolVersionV3:       "UniswapV3",
		PoolVersionBalancer: "Balancer",
		PoolVersionCurve:    "Curve",
	}
	// Guard both ends: PoolVersion is a plain int, so callers may hold
	// negative or not-yet-declared values.
	if pv < 0 || int(pv) >= len(names) {
		return "Unknown"
	}
	return names[pv]
}
// PoolDetector detects the version of a DEX pool by issuing eth_call probes
// for characteristic view functions, and remembers the result per address.
//
// NOTE(review): versionCache is a plain map and no locking is visible around
// it, so a PoolDetector presumably must not be used from multiple goroutines
// without external synchronization — confirm against callers.
type PoolDetector struct {
	// client is the RPC connection used for all CallContract probes.
	client *ethclient.Client
	// Cache of detected pool versions, keyed by pool contract address.
	versionCache map[common.Address]PoolVersion
}
// NewPoolDetector returns a PoolDetector that sends its probe calls through
// client and starts with an empty, ready-to-use version cache.
func NewPoolDetector(client *ethclient.Client) *PoolDetector {
	detector := new(PoolDetector)
	detector.client = client
	detector.versionCache = map[common.Address]PoolVersion{}
	return detector
}
// DetectPoolVersion detects the version of a pool by checking which functions
// it supports. Probes run in order, and the first one that succeeds wins:
// slot0 (Uniswap V3), then getReserves (Uniswap V2), then getPoolId (Balancer).
//
// Only successful detections are cached. The probe helpers report any RPC
// failure (network error, cancelled context) as "function not present".
// Caching a failed detection would therefore pin a pool to PoolVersionUnknown
// forever after a single transient error. It would also make later calls
// return (PoolVersionUnknown, nil) where the first call returned an error.
//
// If detection fails because ctx is done, the context's error is returned
// (wrapped, so errors.Is(err, context.Canceled) still works). Otherwise the
// error is "unable to detect pool version".
func (pd *PoolDetector) DetectPoolVersion(ctx context.Context, poolAddress common.Address) (PoolVersion, error) {
	if version, exists := pd.versionCache[poolAddress]; exists {
		return version, nil
	}

	var detected PoolVersion
	// switch cases are evaluated top to bottom and stop at the first true one,
	// so later probes only run when earlier ones fail.
	switch {
	case pd.hasSlot0(ctx, poolAddress):
		detected = PoolVersionV3
	case pd.hasGetReserves(ctx, poolAddress):
		detected = PoolVersionV2
	case pd.hasGetPoolId(ctx, poolAddress):
		detected = PoolVersionBalancer
	default:
		// Distinguish "caller gave up" from "contract matched nothing".
		if err := ctx.Err(); err != nil {
			return PoolVersionUnknown, fmt.Errorf("detecting pool version of %s: %w", poolAddress.Hex(), err)
		}
		return PoolVersionUnknown, errors.New("unable to detect pool version")
	}

	pd.versionCache[poolAddress] = detected
	return detected, nil
}
// hasSlot0 reports whether the contract at poolAddress answers a slot0() call
// (the Uniswap V3 pool state getter) with a full-length result.
// Any ABI, packing or RPC failure is treated as "not supported".
func (pd *PoolDetector) hasSlot0(ctx context.Context, poolAddress common.Address) bool {
	// Minimal ABI containing only slot0.
	const slot0Definition = `[{
"inputs": [],
"name": "slot0",
"outputs": [
{"internalType": "uint160", "name": "sqrtPriceX96", "type": "uint160"},
{"internalType": "int24", "name": "tick", "type": "int24"},
{"internalType": "uint16", "name": "observationIndex", "type": "uint16"},
{"internalType": "uint16", "name": "observationCardinality", "type": "uint16"},
{"internalType": "uint16", "name": "observationCardinalityNext", "type": "uint16"},
{"internalType": "uint8", "name": "feeProtocol", "type": "uint8"},
{"internalType": "bool", "name": "unlocked", "type": "bool"}
],
"stateMutability": "view",
"type": "function"
}]`
	// slot0 has 7 static outputs, each occupying one 32-byte ABI word.
	const minSlot0Len = 7 * 32

	contractABI, parseErr := abi.JSON(strings.NewReader(slot0Definition))
	if parseErr != nil {
		return false
	}
	calldata, packErr := contractABI.Pack("slot0")
	if packErr != nil {
		return false
	}
	output, callErr := pd.client.CallContract(ctx, ethereum.CallMsg{To: &poolAddress, Data: calldata}, nil)
	return callErr == nil && len(output) >= minSlot0Len
}
// hasGetReserves reports whether the contract at poolAddress answers a
// getReserves() call (the Uniswap V2 pair reserves getter) with at least the
// three expected ABI words. Any ABI, packing or RPC failure counts as "no".
func (pd *PoolDetector) hasGetReserves(ctx context.Context, poolAddress common.Address) bool {
	const reservesDefinition = `[{
"inputs": [],
"name": "getReserves",
"outputs": [
{"internalType": "uint112", "name": "_reserve0", "type": "uint112"},
{"internalType": "uint112", "name": "_reserve1", "type": "uint112"},
{"internalType": "uint32", "name": "_blockTimestampLast", "type": "uint32"}
],
"stateMutability": "view",
"type": "function"
}]`

	pairABI, err := abi.JSON(strings.NewReader(reservesDefinition))
	if err != nil {
		return false
	}
	input, err := pairABI.Pack("getReserves")
	if err != nil {
		return false
	}

	request := ethereum.CallMsg{Data: input, To: &poolAddress}
	reply, err := pd.client.CallContract(ctx, request, nil)
	if err != nil {
		return false
	}
	// (uint112, uint112, uint32) encodes as three 32-byte words = 96 bytes.
	const wordsExpected = 3
	return len(reply) >= wordsExpected*32
}
// hasGetPoolId reports whether the contract at poolAddress answers a
// getPoolId() call (the Balancer pool identifier getter) with exactly one
// bytes32 word. Any ABI, packing or RPC failure counts as "no".
func (pd *PoolDetector) hasGetPoolId(ctx context.Context, poolAddress common.Address) bool {
	const poolIDDefinition = `[{
"inputs": [],
"name": "getPoolId",
"outputs": [{"internalType": "bytes32", "name": "", "type": "bytes32"}],
"stateMutability": "view",
"type": "function"
}]`

	poolABI, err := abi.JSON(strings.NewReader(poolIDDefinition))
	if err != nil {
		return false
	}
	payload, err := poolABI.Pack("getPoolId")
	if err != nil {
		return false
	}
	raw, err := pd.client.CallContract(ctx, ethereum.CallMsg{To: &poolAddress, Data: payload}, nil)
	// A bytes32 return value is ABI-encoded as exactly one 32-byte word.
	return err == nil && len(raw) == 32
}
// GetReservesV2 fetches reserves from a Uniswap V2 style pool by calling
// getReserves() with a nil block number. Per the ethclient documentation, a
// nil block number means the latest known block.
//
// It returns (reserve0, reserve1). The third getReserves output,
// _blockTimestampLast, is decoded but discarded. Every error wraps the
// underlying ABI or RPC failure with %w.
func (pd *PoolDetector) GetReservesV2(ctx context.Context, poolAddress common.Address) (*big.Int, *big.Int, error) {
	getReservesABI := `[{
"inputs": [],
"name": "getReserves",
"outputs": [
{"internalType": "uint112", "name": "_reserve0", "type": "uint112"},
{"internalType": "uint112", "name": "_reserve1", "type": "uint112"},
{"internalType": "uint32", "name": "_blockTimestampLast", "type": "uint32"}
],
"stateMutability": "view",
"type": "function"
}]`
	parsedABI, err := abi.JSON(strings.NewReader(getReservesABI))
	if err != nil {
		return nil, nil, fmt.Errorf("failed to parse getReserves ABI: %w", err)
	}
	data, err := parsedABI.Pack("getReserves")
	if err != nil {
		return nil, nil, fmt.Errorf("failed to pack getReserves call: %w", err)
	}
	msg := ethereum.CallMsg{
		To:   &poolAddress,
		Data: data,
	}
	result, err := pd.client.CallContract(ctx, msg, nil)
	if err != nil {
		return nil, nil, fmt.Errorf("failed to call getReserves: %w", err)
	}
	unpacked, err := parsedABI.Unpack("getReserves", result)
	if err != nil {
		return nil, nil, fmt.Errorf("failed to unpack getReserves result: %w", err)
	}
	// The ABI above declares three outputs. Require all three so the guard
	// matches both the ABI and the error message.
	if len(unpacked) < 3 {
		return nil, nil, fmt.Errorf("unexpected number of return values from getReserves: got %d, expected 3", len(unpacked))
	}
	reserve0, ok := unpacked[0].(*big.Int)
	if !ok {
		return nil, nil, fmt.Errorf("failed to convert reserve0 to *big.Int")
	}
	reserve1, ok := unpacked[1].(*big.Int)
	if !ok {
		return nil, nil, fmt.Errorf("failed to convert reserve1 to *big.Int")
	}
	return reserve0, reserve1, nil
}
// ClearCache discards every cached detection result, so the next
// DetectPoolVersion call for any address probes the chain again. It installs
// a freshly allocated map instead of emptying the old one, which also makes a
// zero-value PoolDetector's cache usable.
func (pd *PoolDetector) ClearCache() {
	fresh := map[common.Address]PoolVersion{}
	pd.versionCache = fresh
}
// GetCachedVersion looks up poolAddress in the detection cache without making
// any RPC calls. ok is false (and version is PoolVersionUnknown) when the
// address has no cache entry.
func (pd *PoolDetector) GetCachedVersion(poolAddress common.Address) (version PoolVersion, ok bool) {
	version, ok = pd.versionCache[poolAddress]
	return version, ok
}