diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..e69de29 diff --git a/pkg/parsers/factory.go b/pkg/parsers/factory.go new file mode 100644 index 0000000..4249e05 --- /dev/null +++ b/pkg/parsers/factory.go @@ -0,0 +1,104 @@ +package parsers + +import ( + "context" + "fmt" + "sync" + + "github.com/ethereum/go-ethereum/core/types" + + mevtypes "github.com/your-org/mev-bot/pkg/types" +) + +// factory implements the Factory interface +type factory struct { + parsers map[mevtypes.ProtocolType]Parser + mu sync.RWMutex +} + +// NewFactory creates a new parser factory +func NewFactory() Factory { + return &factory{ + parsers: make(map[mevtypes.ProtocolType]Parser), + } +} + +// RegisterParser registers a parser for a protocol +func (f *factory) RegisterParser(protocol mevtypes.ProtocolType, parser Parser) error { + if protocol == mevtypes.ProtocolUnknown { + return fmt.Errorf("cannot register parser for unknown protocol") + } + + if parser == nil { + return fmt.Errorf("parser cannot be nil") + } + + f.mu.Lock() + defer f.mu.Unlock() + + if _, exists := f.parsers[protocol]; exists { + return fmt.Errorf("parser for protocol %s already registered", protocol) + } + + f.parsers[protocol] = parser + return nil +} + +// GetParser returns a parser for the given protocol +func (f *factory) GetParser(protocol mevtypes.ProtocolType) (Parser, error) { + f.mu.RLock() + defer f.mu.RUnlock() + + parser, exists := f.parsers[protocol] + if !exists { + return nil, fmt.Errorf("no parser registered for protocol %s", protocol) + } + + return parser, nil +} + +// ParseLog routes a log to the appropriate parser +func (f *factory) ParseLog(ctx context.Context, log types.Log, tx *types.Transaction) (*mevtypes.SwapEvent, error) { + f.mu.RLock() + defer f.mu.RUnlock() + + // Try each registered parser + for _, parser := range f.parsers { + if parser.SupportsLog(log) { + return parser.ParseLog(ctx, log, tx) + } + } + + return nil, mevtypes.ErrUnsupportedProtocol +} + +// ParseTransaction parses all swap events from a transaction +func (f *factory) ParseTransaction(ctx context.Context, tx *types.Transaction, receipt *types.Receipt) ([]*mevtypes.SwapEvent, error) { + if receipt == nil { + return nil, fmt.Errorf("receipt cannot be nil") + } + + f.mu.RLock() + defer f.mu.RUnlock() + + var allEvents []*mevtypes.SwapEvent + + // Try each log with all parsers + for _, log := range receipt.Logs { + for _, parser := range f.parsers { + if parser.SupportsLog(*log) { + event, err := parser.ParseLog(ctx, *log, tx) + if err != nil { + // Log error but continue with other parsers + continue + } + if event != nil { + allEvents = append(allEvents, event) + } + break // Found parser for this log, move to next log + } + } + } + + return allEvents, nil +} diff --git a/pkg/parsers/factory_test.go b/pkg/parsers/factory_test.go new file mode 100644 index 0000000..17b4c72 --- /dev/null +++ b/pkg/parsers/factory_test.go @@ -0,0 +1,407 @@ +package parsers + +import ( + "context" + "math/big" + "testing" + + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/core/types" + + mevtypes "github.com/your-org/mev-bot/pkg/types" +) + +// mockParser is a mock implementation of Parser for testing +type mockParser struct { + protocol mevtypes.ProtocolType + supportsLog func(types.Log) bool + parseLog func(context.Context, types.Log, *types.Transaction) (*mevtypes.SwapEvent, error) + parseReceipt func(context.Context, *types.Receipt, *types.Transaction) ([]*mevtypes.SwapEvent, error) +} + +func (m *mockParser) ParseLog(ctx context.Context, log types.Log, tx *types.Transaction) (*mevtypes.SwapEvent, error) { + if m.parseLog != nil { + return m.parseLog(ctx, log, tx) + } + return nil, nil +} + +func (m *mockParser) ParseReceipt(ctx context.Context, receipt *types.Receipt, tx *types.Transaction) ([]*mevtypes.SwapEvent, error) { + if m.parseReceipt != nil { + return m.parseReceipt(ctx, receipt, tx) + } + return nil, nil +} + +func (m *mockParser) SupportsLog(log types.Log) bool { + if m.supportsLog != nil { + return m.supportsLog(log) + } + return false +} + +func (m *mockParser) Protocol() mevtypes.ProtocolType { + return m.protocol +} + +func TestNewFactory(t *testing.T) { + factory := NewFactory() + if factory == nil { + t.Fatal("NewFactory returned nil") + } +} + +func TestFactory_RegisterParser(t *testing.T) { + tests := []struct { + name string + protocol mevtypes.ProtocolType + parser Parser + wantErr bool + errString string + }{ + { + name: "valid registration", + protocol: mevtypes.ProtocolUniswapV2, + parser: &mockParser{protocol: mevtypes.ProtocolUniswapV2}, + wantErr: false, + }, + { + name: "unknown protocol", + protocol: mevtypes.ProtocolUnknown, + parser: &mockParser{protocol: mevtypes.ProtocolUnknown}, + wantErr: true, + errString: "cannot register parser for unknown protocol", + }, + { + name: "nil parser", + protocol: mevtypes.ProtocolUniswapV2, + parser: nil, + wantErr: true, + errString: "parser cannot be nil", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + factory := NewFactory() + err := factory.RegisterParser(tt.protocol, tt.parser) + + if tt.wantErr { + if err == nil { + t.Errorf("RegisterParser() expected error, got nil") + return + } + if tt.errString != "" && err.Error() != tt.errString { + t.Errorf("RegisterParser() error = %v, want %v", err.Error(), tt.errString) + } + } else { + if err != nil { + t.Errorf("RegisterParser() unexpected error: %v", err) + } + } + }) + } +} + +func TestFactory_RegisterParser_Duplicate(t *testing.T) { + factory := NewFactory() + parser := &mockParser{protocol: mevtypes.ProtocolUniswapV2} + + // First registration should succeed + err := factory.RegisterParser(mevtypes.ProtocolUniswapV2, parser) + if err != nil { + t.Fatalf("First RegisterParser() failed: %v", err) + } + + // Second registration should fail + err = factory.RegisterParser(mevtypes.ProtocolUniswapV2, parser) + if err == nil { + t.Error("RegisterParser() expected error for duplicate registration, got nil") + } +} + +func TestFactory_GetParser(t *testing.T) { + factory := NewFactory() + parser := &mockParser{protocol: mevtypes.ProtocolUniswapV2} + + // Register parser + err := factory.RegisterParser(mevtypes.ProtocolUniswapV2, parser) + if err != nil { + t.Fatalf("RegisterParser() failed: %v", err) + } + + tests := []struct { + name string + protocol mevtypes.ProtocolType + wantErr bool + }{ + { + name: "get registered parser", + protocol: mevtypes.ProtocolUniswapV2, + wantErr: false, + }, + { + name: "get unregistered parser", + protocol: mevtypes.ProtocolUniswapV3, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := factory.GetParser(tt.protocol) + + if tt.wantErr { + if err == nil { + t.Error("GetParser() expected error, got nil") + } + if got != nil { + t.Error("GetParser() expected nil parser on error") + } + } else { + if err != nil { + t.Errorf("GetParser() unexpected error: %v", err) + } + if got == nil { + t.Error("GetParser() returned nil parser") + } + } + }) + } +} + +func TestFactory_ParseLog(t *testing.T) { + ctx := context.Background() + + // Create test log + testLog := types.Log{ + Address: common.HexToAddress("0x1234"), + Topics: []common.Hash{common.HexToHash("0xabcd")}, + Data: []byte{}, + } + + testTx := types.NewTransaction( + 0, + common.HexToAddress("0x1234"), + big.NewInt(0), + 21000, + big.NewInt(1000000000), + nil, + ) + + tests := []struct { + name string + setupFactory func() Factory + log types.Log + tx *types.Transaction + wantErr bool + wantEvent bool + }{ + { + name: "parser supports log", + setupFactory: func() Factory { + f := NewFactory() + parser := &mockParser{ + protocol: mevtypes.ProtocolUniswapV2, + supportsLog: func(log types.Log) bool { + return true + }, + parseLog: func(ctx context.Context, log types.Log, tx *types.Transaction) (*mevtypes.SwapEvent, error) { + return &mevtypes.SwapEvent{ + Protocol: mevtypes.ProtocolUniswapV2, + }, nil + }, + } + f.RegisterParser(mevtypes.ProtocolUniswapV2, parser) + return f + }, + log: testLog, + tx: testTx, + wantErr: false, + wantEvent: true, + }, + { + name: "no parser supports log", + setupFactory: func() Factory { + f := NewFactory() + parser := &mockParser{ + protocol: mevtypes.ProtocolUniswapV2, + supportsLog: func(log types.Log) bool { + return false + }, + } + f.RegisterParser(mevtypes.ProtocolUniswapV2, parser) + return f + }, + log: testLog, + tx: testTx, + wantErr: true, + wantEvent: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + factory := tt.setupFactory() + event, err := factory.ParseLog(ctx, tt.log, tt.tx) + + if tt.wantErr { + if err == nil { + t.Error("ParseLog() expected error, got nil") + } + } else { + if err != nil { + t.Errorf("ParseLog() unexpected error: %v", err) + } + } + + if tt.wantEvent { + if event == nil { + t.Error("ParseLog() expected event, got nil") + } + } else { + if event != nil && !tt.wantErr { + t.Error("ParseLog() expected nil event") + } + } + }) + } +} + +func TestFactory_ParseTransaction(t *testing.T) { + ctx := context.Background() + + testTx := types.NewTransaction( + 0, + common.HexToAddress("0x1234"), + big.NewInt(0), + 21000, + big.NewInt(1000000000), + nil, + ) + + testLog := &types.Log{ + Address: common.HexToAddress("0x1234"), + Topics: []common.Hash{common.HexToHash("0xabcd")}, + Data: []byte{}, + } + + testReceipt := &types.Receipt{ + Logs: []*types.Log{testLog}, + } + + tests := []struct { + name string + setupFactory func() Factory + tx *types.Transaction + receipt *types.Receipt + wantErr bool + wantEvents int + }{ + { + name: "parse transaction with events", + setupFactory: func() Factory { + f := NewFactory() + parser := &mockParser{ + protocol: mevtypes.ProtocolUniswapV2, + supportsLog: func(log types.Log) bool { + return true + }, + parseLog: func(ctx context.Context, log types.Log, tx *types.Transaction) (*mevtypes.SwapEvent, error) { + return &mevtypes.SwapEvent{ + Protocol: mevtypes.ProtocolUniswapV2, + }, nil + }, + } + f.RegisterParser(mevtypes.ProtocolUniswapV2, parser) + return f + }, + tx: testTx, + receipt: testReceipt, + wantErr: false, + wantEvents: 1, + }, + { + name: "parse transaction with no matching parsers", + setupFactory: func() Factory { + f := NewFactory() + parser := &mockParser{ + protocol: mevtypes.ProtocolUniswapV2, + supportsLog: func(log types.Log) bool { + return false + }, + } + f.RegisterParser(mevtypes.ProtocolUniswapV2, parser) + return f + }, + tx: testTx, + receipt: testReceipt, + wantErr: false, + wantEvents: 0, + }, + { + name: "nil receipt", + setupFactory: func() Factory { + return NewFactory() + }, + tx: testTx, + receipt: nil, + wantErr: true, + wantEvents: 0, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + factory := tt.setupFactory() + events, err := factory.ParseTransaction(ctx, tt.tx, tt.receipt) + + if tt.wantErr { + if err == nil { + t.Error("ParseTransaction() expected error, got nil") + } + } else { + if err != nil { + t.Errorf("ParseTransaction() unexpected error: %v", err) + } + } + + if len(events) != tt.wantEvents { + t.Errorf("ParseTransaction() got %d events, want %d", len(events), tt.wantEvents) + } + }) + } +} + +func TestFactory_ConcurrentAccess(t *testing.T) { + factory := NewFactory() + + // Test concurrent registration + done := make(chan bool) + + for i := 0; i < 10; i++ { + go func(n int) { + protocol := mevtypes.ProtocolType(fmt.Sprintf("protocol-%d", n)) + parser := &mockParser{protocol: protocol} + factory.RegisterParser(protocol, parser) + done <- true + }(i) + } + + for i := 0; i < 10; i++ { + <-done + } + + // Test concurrent reads + for i := 0; i < 10; i++ { + go func(n int) { + protocol := mevtypes.ProtocolType(fmt.Sprintf("protocol-%d", n)) + factory.GetParser(protocol) + done <- true + }(i) + } + + for i := 0; i < 10; i++ { + <-done + } +} diff --git a/pkg/types/pool_test.go b/pkg/types/pool_test.go new file mode 100644 index 0000000..f16449d --- /dev/null +++ b/pkg/types/pool_test.go @@ -0,0 +1,286 @@ +package types + +import ( + "math/big" + "testing" + + "github.com/ethereum/go-ethereum/common" +) + +func TestPoolInfo_Validate(t *testing.T) { + validPool := &PoolInfo{ + Address: common.HexToAddress("0x1111"), + Protocol: ProtocolUniswapV2, + PoolType: "constant-product", + Token0: common.HexToAddress("0x2222"), + Token1: common.HexToAddress("0x3333"), + Token0Decimals: 18, + Token1Decimals: 6, + Reserve0: big.NewInt(1000000), + Reserve1: big.NewInt(500000), + IsActive: true, + } + + tests := []struct { + name string + pool *PoolInfo + wantErr error + }{ + { + name: "valid pool", + pool: validPool, + wantErr: nil, + }, + { + name: "invalid pool address", + pool: &PoolInfo{ + Address: common.Address{}, + Protocol: ProtocolUniswapV2, + Token0: common.HexToAddress("0x2222"), + Token1: common.HexToAddress("0x3333"), + Token0Decimals: 18, + Token1Decimals: 6, + }, + wantErr: ErrInvalidPoolAddress, + }, + { + name: "invalid token0 address", + pool: &PoolInfo{ + Address: common.HexToAddress("0x1111"), + Protocol: ProtocolUniswapV2, + Token0: common.Address{}, + Token1: common.HexToAddress("0x3333"), + Token0Decimals: 18, + Token1Decimals: 6, + }, + wantErr: ErrInvalidToken0Address, + }, + { + name: "invalid token1 address", + pool: &PoolInfo{ + Address: common.HexToAddress("0x1111"), + Protocol: ProtocolUniswapV2, + Token0: common.HexToAddress("0x2222"), + Token1: common.Address{}, + Token0Decimals: 18, + Token1Decimals: 6, + }, + wantErr: ErrInvalidToken1Address, + }, + { + name: "invalid token0 decimals - zero", + pool: &PoolInfo{ + Address: common.HexToAddress("0x1111"), + Protocol: ProtocolUniswapV2, + Token0: common.HexToAddress("0x2222"), + Token1: common.HexToAddress("0x3333"), + Token0Decimals: 0, + Token1Decimals: 6, + }, + wantErr: ErrInvalidToken0Decimals, + }, + { + name: "invalid token0 decimals - too high", + pool: &PoolInfo{ + Address: common.HexToAddress("0x1111"), + Protocol: ProtocolUniswapV2, + Token0: common.HexToAddress("0x2222"), + Token1: common.HexToAddress("0x3333"), + Token0Decimals: 19, + Token1Decimals: 6, + }, + wantErr: ErrInvalidToken0Decimals, + }, + { + name: "invalid token1 decimals - zero", + pool: &PoolInfo{ + Address: common.HexToAddress("0x1111"), + Protocol: ProtocolUniswapV2, + Token0: common.HexToAddress("0x2222"), + Token1: common.HexToAddress("0x3333"), + Token0Decimals: 18, + Token1Decimals: 0, + }, + wantErr: ErrInvalidToken1Decimals, + }, + { + name: "invalid token1 decimals - too high", + pool: &PoolInfo{ + Address: common.HexToAddress("0x1111"), + Protocol: ProtocolUniswapV2, + Token0: common.HexToAddress("0x2222"), + Token1: common.HexToAddress("0x3333"), + Token0Decimals: 18, + Token1Decimals: 19, + }, + wantErr: ErrInvalidToken1Decimals, + }, + { + name: "unknown protocol", + pool: &PoolInfo{ + Address: common.HexToAddress("0x1111"), + Protocol: ProtocolUnknown, + Token0: common.HexToAddress("0x2222"), + Token1: common.HexToAddress("0x3333"), + Token0Decimals: 18, + Token1Decimals: 6, + }, + wantErr: ErrUnknownProtocol, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := tt.pool.Validate() + if err != tt.wantErr { + t.Errorf("Validate() error = %v, want %v", err, tt.wantErr) + } + }) + } +} + +func TestPoolInfo_GetTokenPair(t *testing.T) { + tests := []struct { + name string + pool *PoolInfo + wantToken0 common.Address + wantToken1 common.Address + }{ + { + name: "token0 < token1", + pool: &PoolInfo{ + Token0: common.HexToAddress("0x1111"), + Token1: common.HexToAddress("0x2222"), + }, + wantToken0: common.HexToAddress("0x1111"), + wantToken1: common.HexToAddress("0x2222"), + }, + { + name: "token1 < token0", + pool: &PoolInfo{ + Token0: common.HexToAddress("0x2222"), + Token1: common.HexToAddress("0x1111"), + }, + wantToken0: common.HexToAddress("0x1111"), + wantToken1: common.HexToAddress("0x2222"), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + token0, token1 := tt.pool.GetTokenPair() + if token0 != tt.wantToken0 { + t.Errorf("GetTokenPair() token0 = %v, want %v", token0, tt.wantToken0) + } + if token1 != tt.wantToken1 { + t.Errorf("GetTokenPair() token1 = %v, want %v", token1, tt.wantToken1) + } + }) + } +} + +func TestPoolInfo_CalculatePrice(t *testing.T) { + tests := []struct { + name string + pool *PoolInfo + wantPrice string // String representation for comparison + }{ + { + name: "equal decimals", + pool: &PoolInfo{ + Token0Decimals: 18, + Token1Decimals: 18, + Reserve0: big.NewInt(1000000000000000000), // 1e18 + Reserve1: big.NewInt(2000000000000000000), // 2e18 + }, + wantPrice: "2", + }, + { + name: "different decimals - USDC/WETH", + pool: &PoolInfo{ + Token0Decimals: 6, // USDC + Token1Decimals: 18, // WETH + Reserve0: big.NewInt(1000000), // 1 USDC + Reserve1: big.NewInt(1000000000000000000), // 1 WETH + }, + wantPrice: "1000000000000", // 1 WETH = 1,000,000,000,000 scaled USDC + }, + { + name: "zero reserve0", + pool: &PoolInfo{ + Token0Decimals: 18, + Token1Decimals: 18, + Reserve0: big.NewInt(0), + Reserve1: big.NewInt(1000), + }, + wantPrice: "0", + }, + { + name: "nil reserves", + pool: &PoolInfo{ + Token0Decimals: 18, + Token1Decimals: 18, + Reserve0: nil, + Reserve1: nil, + }, + wantPrice: "0", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + price := tt.pool.CalculatePrice() + if price.String() != tt.wantPrice { + t.Errorf("CalculatePrice() = %v, want %v", price.String(), tt.wantPrice) + } + }) + } +} + +func Test_scaleToDecimals(t *testing.T) { + tests := []struct { + name string + amount *big.Int + fromDecimals uint8 + toDecimals uint8 + want *big.Int + }{ + { + name: "same decimals", + amount: big.NewInt(1000), + fromDecimals: 18, + toDecimals: 18, + want: big.NewInt(1000), + }, + { + name: "scale up - 6 to 18 decimals", + amount: big.NewInt(1000000), // 1 USDC (6 decimals) + fromDecimals: 6, + toDecimals: 18, + want: new(big.Int).Mul(big.NewInt(1000000), new(big.Int).Exp(big.NewInt(10), big.NewInt(12), nil)), + }, + { + name: "scale down - 18 to 6 decimals", + amount: new(big.Int).Mul(big.NewInt(1), new(big.Int).Exp(big.NewInt(10), big.NewInt(18), nil)), // 1 ETH + fromDecimals: 18, + toDecimals: 6, + want: big.NewInt(1000000), + }, + { + name: "scale up - 8 to 18 decimals (WBTC to ETH)", + amount: big.NewInt(100000000), // 1 WBTC (8 decimals) + fromDecimals: 8, + toDecimals: 18, + want: new(big.Int).Mul(big.NewInt(100000000), new(big.Int).Exp(big.NewInt(10), big.NewInt(10), nil)), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := scaleToDecimals(tt.amount, tt.fromDecimals, tt.toDecimals) + if got.Cmp(tt.want) != 0 { + t.Errorf("scaleToDecimals() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/pkg/types/swap_test.go b/pkg/types/swap_test.go new file mode 100644 index 0000000..2daaf40 --- /dev/null +++ b/pkg/types/swap_test.go @@ -0,0 +1,259 @@ +package types + +import ( + "math/big" + "testing" + "time" + + "github.com/ethereum/go-ethereum/common" +) + +func TestSwapEvent_Validate(t *testing.T) { + validEvent := &SwapEvent{ + TxHash: common.HexToHash("0x1234"), + BlockNumber: 1000, + LogIndex: 0, + Timestamp: time.Now(), + PoolAddress: common.HexToAddress("0x1111"), + Protocol: ProtocolUniswapV2, + Token0: common.HexToAddress("0x2222"), + Token1: common.HexToAddress("0x3333"), + Token0Decimals: 18, + Token1Decimals: 6, + Amount0In: big.NewInt(1000), + Amount1In: big.NewInt(0), + Amount0Out: big.NewInt(0), + Amount1Out: big.NewInt(500), + Sender: common.HexToAddress("0x4444"), + Recipient: common.HexToAddress("0x5555"), + } + + tests := []struct { + name string + event *SwapEvent + wantErr error + }{ + { + name: "valid event", + event: validEvent, + wantErr: nil, + }, + { + name: "invalid tx hash", + event: &SwapEvent{ + TxHash: common.Hash{}, + PoolAddress: common.HexToAddress("0x1111"), + Token0: common.HexToAddress("0x2222"), + Token1: common.HexToAddress("0x3333"), + Protocol: ProtocolUniswapV2, + Amount0In: big.NewInt(1000), + }, + wantErr: ErrInvalidTxHash, + }, + { + name: "invalid pool address", + event: &SwapEvent{ + TxHash: common.HexToHash("0x1234"), + PoolAddress: common.Address{}, + Token0: common.HexToAddress("0x2222"), + Token1: common.HexToAddress("0x3333"), + Protocol: ProtocolUniswapV2, + Amount0In: big.NewInt(1000), + }, + wantErr: ErrInvalidPoolAddress, + }, + { + name: "invalid token0 address", + event: &SwapEvent{ + TxHash: common.HexToHash("0x1234"), + PoolAddress: common.HexToAddress("0x1111"), + Token0: common.Address{}, + Token1: common.HexToAddress("0x3333"), + Protocol: ProtocolUniswapV2, + Amount0In: big.NewInt(1000), + }, + wantErr: ErrInvalidToken0Address, + }, + { + name: "invalid token1 address", + event: &SwapEvent{ + TxHash: common.HexToHash("0x1234"), + PoolAddress: common.HexToAddress("0x1111"), + Token0: common.HexToAddress("0x2222"), + Token1: common.Address{}, + Protocol: ProtocolUniswapV2, + Amount0In: big.NewInt(1000), + }, + wantErr: ErrInvalidToken1Address, + }, + { + name: "unknown protocol", + event: &SwapEvent{ + TxHash: common.HexToHash("0x1234"), + PoolAddress: common.HexToAddress("0x1111"), + Token0: common.HexToAddress("0x2222"), + Token1: common.HexToAddress("0x3333"), + Protocol: ProtocolUnknown, + Amount0In: big.NewInt(1000), + }, + wantErr: ErrUnknownProtocol, + }, + { + name: "zero amounts", + event: &SwapEvent{ + TxHash: common.HexToHash("0x1234"), + PoolAddress: common.HexToAddress("0x1111"), + Token0: common.HexToAddress("0x2222"), + Token1: common.HexToAddress("0x3333"), + Protocol: ProtocolUniswapV2, + Amount0In: big.NewInt(0), + Amount1In: big.NewInt(0), + Amount0Out: big.NewInt(0), + Amount1Out: big.NewInt(0), + }, + wantErr: ErrZeroAmounts, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := tt.event.Validate() + if err != tt.wantErr { + t.Errorf("Validate() error = %v, want %v", err, tt.wantErr) + } + }) + } +} + +func TestSwapEvent_GetInputToken(t *testing.T) { + tests := []struct { + name string + event *SwapEvent + wantToken common.Address + wantAmount *big.Int + }{ + { + name: "token0 input", + event: &SwapEvent{ + Token0: common.HexToAddress("0x1111"), + Token1: common.HexToAddress("0x2222"), + Amount0In: big.NewInt(1000), + Amount1In: big.NewInt(0), + Amount0Out: big.NewInt(0), + Amount1Out: big.NewInt(500), + }, + wantToken: common.HexToAddress("0x1111"), + wantAmount: big.NewInt(1000), + }, + { + name: "token1 input", + event: &SwapEvent{ + Token0: common.HexToAddress("0x1111"), + Token1: common.HexToAddress("0x2222"), + Amount0In: big.NewInt(0), + Amount1In: big.NewInt(500), + Amount0Out: big.NewInt(1000), + Amount1Out: big.NewInt(0), + }, + wantToken: common.HexToAddress("0x2222"), + wantAmount: big.NewInt(500), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + token, amount := tt.event.GetInputToken() + if token != tt.wantToken { + t.Errorf("GetInputToken() token = %v, want %v", token, tt.wantToken) + } + if amount.Cmp(tt.wantAmount) != 0 { + t.Errorf("GetInputToken() amount = %v, want %v", amount, tt.wantAmount) + } + }) + } +} + +func TestSwapEvent_GetOutputToken(t *testing.T) { + tests := []struct { + name string + event *SwapEvent + wantToken common.Address + wantAmount *big.Int + }{ + { + name: "token0 output", + event: &SwapEvent{ + Token0: common.HexToAddress("0x1111"), + Token1: common.HexToAddress("0x2222"), + Amount0In: big.NewInt(0), + Amount1In: big.NewInt(500), + Amount0Out: big.NewInt(1000), + Amount1Out: big.NewInt(0), + }, + wantToken: common.HexToAddress("0x1111"), + wantAmount: big.NewInt(1000), + }, + { + name: "token1 output", + event: &SwapEvent{ + Token0: common.HexToAddress("0x1111"), + Token1: common.HexToAddress("0x2222"), + Amount0In: big.NewInt(1000), + Amount1In: big.NewInt(0), + Amount0Out: big.NewInt(0), + Amount1Out: big.NewInt(500), + }, + wantToken: common.HexToAddress("0x2222"), + wantAmount: big.NewInt(500), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + token, amount := tt.event.GetOutputToken() + if token != tt.wantToken { + t.Errorf("GetOutputToken() token = %v, want %v", token, tt.wantToken) + } + if amount.Cmp(tt.wantAmount) != 0 { + t.Errorf("GetOutputToken() amount = %v, want %v", amount, tt.wantAmount) + } + }) + } +} + +func Test_isZero(t *testing.T) { + tests := []struct { + name string + n *big.Int + want bool + }{ + { + name: "nil is zero", + n: nil, + want: true, + }, + { + name: "zero value is zero", + n: big.NewInt(0), + want: true, + }, + { + name: "positive value is not zero", + n: big.NewInt(100), + want: false, + }, + { + name: "negative value is not zero", + n: big.NewInt(-100), + want: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := isZero(tt.n); got != tt.want { + t.Errorf("isZero() = %v, want %v", got, tt.want) + } + }) + } +}