Files
mev-beta/pkg/transport/tcp_transport.go
Krypto Kajun 850223a953 fix(multicall): resolve critical multicall parsing corruption issues
- Added comprehensive bounds checking to prevent buffer overruns in multicall parsing
- Implemented graduated validation system (Strict/Moderate/Permissive) to reduce false positives
- Added LRU caching system for address validation with 10-minute TTL
- Enhanced ABI decoder with missing Universal Router and Arbitrum-specific DEX signatures
- Fixed duplicate function declarations and import conflicts across multiple files
- Added error recovery mechanisms with multiple fallback strategies
- Updated tests to handle new validation behavior for suspicious addresses
- Fixed parser test expectations for improved validation system
- Applied gofmt formatting fixes to ensure code style compliance
- Fixed mutex copying issues in monitoring package by introducing MetricsSnapshot
- Resolved critical security vulnerabilities in heuristic address extraction
- Progress: Updated TODO audit from 10% to 35% complete

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
2025-10-17 00:12:55 -05:00

452 lines
10 KiB
Go

package transport
import (
	"context"
	"crypto/tls"
	"encoding/json"
	"fmt"
	"math/rand"
	"net"
	"sync"
	"time"
)
// TCPTransport implements TCP transport for remote communication.
//
// Depending on isServer, Connect either listens for inbound connections or
// dials a single remote endpoint. Every established connection is tracked in
// connections and read by its own handleConnection goroutine, which decodes
// frames and forwards the resulting messages to receiveChan.
type TCPTransport struct {
	// address and port identify the endpoint to listen on (server mode) or
	// dial (client mode).
	address string
	port    int
	// listener is the accepting socket; only set in server mode.
	listener net.Listener
	// connections maps a generated ID ("server_<nanos>" / "client_<nanos>")
	// to its live connection.
	connections map[string]net.Conn
	metrics     TransportMetrics
	connected   bool
	// isServer selects listening (true) or dialing (false) in Connect.
	isServer bool
	// receiveChan carries decoded inbound messages; Disconnect closes it.
	receiveChan chan *Message
	// tlsConfig, when non-nil, makes Connect listen/dial over TLS.
	tlsConfig *tls.Config
	// mu is the lock taken around access to connections, connected and
	// metrics.
	mu sync.RWMutex
	// ctx bounds the accept and read loops; cancel is invoked by Disconnect.
	ctx    context.Context
	cancel context.CancelFunc
	// retryConfig controls the back-off used when dialing and writing.
	retryConfig RetryConfig
}
// NewTCPTransport returns a not-yet-connected transport for address:port.
// When isServer is true, Connect listens on that endpoint; otherwise it dials
// it. Inbound messages are buffered up to 1000 deep. The default retry policy
// allows 3 retries, starting at 1s and doubling up to a 30s cap, with jitter
// enabled.
func NewTCPTransport(address string, port int, isServer bool) *TCPTransport {
	defaultRetry := RetryConfig{
		MaxRetries:    3,
		InitialDelay:  time.Second,
		MaxDelay:      30 * time.Second,
		BackoffFactor: 2.0,
		Jitter:        true,
	}

	tt := &TCPTransport{
		address:     address,
		port:        port,
		isServer:    isServer,
		connections: make(map[string]net.Conn),
		receiveChan: make(chan *Message, 1000),
		retryConfig: defaultRetry,
	}
	// The lifecycle context is cancelled by Disconnect.
	tt.ctx, tt.cancel = context.WithCancel(context.Background())
	return tt
}
// SetTLSConfig configures TLS for secure communication. A non-nil config makes
// subsequent Connect calls listen or dial over TLS; nil selects plain TCP.
// Already-established listeners and connections are not affected.
//
// The field is written under mu because Connect reads it while holding that
// lock; an unlocked write would race with a concurrent Connect.
func (tt *TCPTransport) SetTLSConfig(config *tls.Config) {
	tt.mu.Lock()
	defer tt.mu.Unlock()
	tt.tlsConfig = config
}
// SetRetryConfig configures retry behavior for both dialing (client mode) and
// per-connection writes in Send.
//
// The field is written under mu so it does not race with Connect, which reads
// it while holding that lock.
func (tt *TCPTransport) SetRetryConfig(config RetryConfig) {
	tt.mu.Lock()
	defer tt.mu.Unlock()
	tt.retryConfig = config
}
// Connect establishes the TCP connection. In server mode it starts listening
// and accepting peers. In client mode it dials the remote endpoint, honouring
// ctx while retrying. Calling Connect on an already-connected transport is a
// no-op.
//
// A transport can be reconnected after Disconnect. Disconnect cancels the
// lifecycle context and closes the receive channel, so both are recreated
// here. Otherwise the new accept/read loops would exit immediately, and the
// next Disconnect would close an already-closed channel. After reconnecting,
// callers must call Receive again to obtain the new channel.
//
// The transport lock is held for the whole call, including any dial retries.
func (tt *TCPTransport) Connect(ctx context.Context) error {
	tt.mu.Lock()
	defer tt.mu.Unlock()

	if tt.connected {
		return nil
	}

	if tt.ctx.Err() != nil {
		tt.ctx, tt.cancel = context.WithCancel(context.Background())
		tt.receiveChan = make(chan *Message, cap(tt.receiveChan))
	}

	if tt.isServer {
		return tt.startServer()
	}
	return tt.connectToServer(ctx)
}
// Disconnect shuts the transport down. It does the following:
//   - cancels the lifecycle context, stopping the accept and read loops;
//   - closes the listener in server mode;
//   - closes and forgets every connection;
//   - closes the receive channel and resets the connection state.
//
// It does nothing when the transport is not connected. ctx is currently
// unused.
func (tt *TCPTransport) Disconnect(ctx context.Context) error {
	tt.mu.Lock()
	defer tt.mu.Unlock()

	if !tt.connected {
		return nil
	}

	tt.cancel()

	if tt.listener != nil && tt.isServer {
		tt.listener.Close()
	}

	ids := make([]string, 0, len(tt.connections))
	for id := range tt.connections {
		ids = append(ids, id)
	}
	for _, id := range ids {
		tt.connections[id].Close()
		delete(tt.connections, id)
	}

	close(tt.receiveChan)
	tt.connected = false
	tt.metrics.Connections = 0
	return nil
}
// Send serializes msg as JSON and writes it as a length-prefixed frame
// ("<len>\n<json>") to every active connection, retrying each write per the
// retry policy. A connection whose write ultimately fails is closed and
// removed, and the last such failure is returned. Send metrics are recorded
// only when every write succeeds. Every failure path increments the error
// counter.
//
// The connection set is snapshotted under the read lock. The writes, which may
// be slow or retried, run without holding it, so a stalled peer cannot block
// Connect, Disconnect, the accept loop or metric readers. All metric updates
// take the write lock.
func (tt *TCPTransport) Send(ctx context.Context, msg *Message) error {
	start := time.Now()

	countError := func() {
		tt.mu.Lock()
		tt.metrics.Errors++
		tt.mu.Unlock()
	}

	type target struct {
		id   string
		conn net.Conn
	}

	tt.mu.RLock()
	connected := tt.connected
	targets := make([]target, 0, len(tt.connections))
	for id, conn := range tt.connections {
		targets = append(targets, target{id: id, conn: conn})
	}
	tt.mu.RUnlock()

	if !connected {
		countError()
		return fmt.Errorf("transport not connected")
	}

	data, err := json.Marshal(msg)
	if err != nil {
		countError()
		return fmt.Errorf("failed to marshal message: %w", err)
	}
	// Length prefix for framing: "<len>\n<payload>".
	frame := []byte(fmt.Sprintf("%d\n%s", len(data), data))

	if len(targets) == 0 {
		countError()
		return fmt.Errorf("no active connections")
	}

	var sendErr error
	for _, t := range targets {
		if err := tt.sendWithRetry(ctx, t.conn, frame); err != nil {
			sendErr = fmt.Errorf("failed to send to connection %s: %w", t.id, err)
			// No lock is held here, so the failed connection can be removed inline.
			tt.removeConnection(t.id)
		}
	}

	if sendErr != nil {
		countError()
		return sendErr
	}
	tt.updateSendMetrics(msg, time.Since(start))
	return nil
}
// Receive returns the channel on which decoded inbound messages are delivered,
// or an error when the transport is not connected. Disconnect closes the
// channel. ctx is currently unused.
func (tt *TCPTransport) Receive(ctx context.Context) (<-chan *Message, error) {
	tt.mu.RLock()
	connected, ch := tt.connected, tt.receiveChan
	tt.mu.RUnlock()

	if connected {
		return ch, nil
	}
	return nil, fmt.Errorf("transport not connected")
}
// Health reports the transport's status:
//   - "healthy" when connected with at least one live connection, with a
//     fixed 10ms response-time estimate;
//   - "degraded" when connected but without any connection;
//   - "unhealthy" when not connected.
//
// ErrorCount mirrors the accumulated error metric.
func (tt *TCPTransport) Health() ComponentHealth {
	tt.mu.RLock()
	defer tt.mu.RUnlock()

	status, responseTime := "unhealthy", time.Duration(0)
	switch {
	case !tt.connected:
		// Stays unhealthy.
	case len(tt.connections) == 0:
		status = "degraded" // Connected but no active connections
	default:
		status = "healthy"
		responseTime = 10 * time.Millisecond // Estimate for TCP
	}

	return ComponentHealth{
		Status:       status,
		LastCheck:    time.Now(),
		ResponseTime: responseTime,
		ErrorCount:   tt.metrics.Errors,
	}
}
// GetMetrics returns a snapshot of the transport's counters. Connections is
// taken from the live connection map rather than the stored counter.
func (tt *TCPTransport) GetMetrics() TransportMetrics {
	tt.mu.RLock()
	defer tt.mu.RUnlock()

	var snapshot TransportMetrics
	snapshot.BytesSent = tt.metrics.BytesSent
	snapshot.BytesReceived = tt.metrics.BytesReceived
	snapshot.MessagesSent = tt.metrics.MessagesSent
	snapshot.MessagesReceived = tt.metrics.MessagesReceived
	snapshot.Connections = len(tt.connections)
	snapshot.Errors = tt.metrics.Errors
	snapshot.Latency = tt.metrics.Latency
	return snapshot
}
// Private helper methods

// startServer binds the listener (TLS when tlsConfig is set), marks the
// transport connected and starts the accept loop. The caller must hold tt.mu
// (Connect does).
//
// The listen address is built with net.JoinHostPort so IPv6 literals are
// bracketed. For example "::1" becomes "[::1]:8080", whereas "%s:%d" would
// yield the unparsable "::1:8080".
func (tt *TCPTransport) startServer() error {
	addr := net.JoinHostPort(tt.address, fmt.Sprintf("%d", tt.port))

	var (
		listener net.Listener
		err      error
	)
	if tt.tlsConfig != nil {
		listener, err = tls.Listen("tcp", addr, tt.tlsConfig)
	} else {
		listener, err = net.Listen("tcp", addr)
	}
	if err != nil {
		return fmt.Errorf("failed to listen on %s: %w", addr, err)
	}

	tt.listener = listener
	tt.connected = true

	// Start accepting connections
	go tt.acceptConnections()
	return nil
}
// connectToServer dials the configured endpoint (over TLS when tlsConfig is
// set), retrying with exponential back-off per retryConfig. It then registers
// the connection under a "client_<nanos>" ID, marks the transport connected
// and starts its read loop. The caller must hold tt.mu (Connect does).
//
// Dials use a net.Dialer / tls.Dialer bound to ctx. Cancellation therefore
// aborts an in-flight dial or TLS handshake, not just the back-off wait.
//
// With jitter enabled, each wait is randomized uniformly within ±10% of the
// current back-off step. The un-jittered step drives the exponential growth,
// so jitter does not accumulate across attempts.
func (tt *TCPTransport) connectToServer(ctx context.Context) error {
	addr := net.JoinHostPort(tt.address, fmt.Sprintf("%d", tt.port))
	cfg := tt.retryConfig

	dial := func() (net.Conn, error) {
		if tt.tlsConfig != nil {
			d := tls.Dialer{Config: tt.tlsConfig}
			return d.DialContext(ctx, "tcp", addr)
		}
		var d net.Dialer
		return d.DialContext(ctx, "tcp", addr)
	}

	var conn net.Conn
	delay := cfg.InitialDelay
	for attempt := 0; attempt <= cfg.MaxRetries; attempt++ {
		c, err := dial()
		if err == nil {
			conn = c
			break
		}
		if attempt == cfg.MaxRetries {
			return fmt.Errorf("failed to connect to %s after %d attempts: %w", addr, attempt+1, err)
		}

		wait := delay
		if cfg.Jitter {
			// Uniform factor in [-1, 1): wait lands within ±10% of delay.
			wait += time.Duration(float64(delay) * 0.1 * (2*rand.Float64() - 1))
		}
		select {
		case <-time.After(wait):
		case <-ctx.Done():
			return ctx.Err()
		}

		// Grow the back-off step exponentially, capped at MaxDelay.
		delay = time.Duration(float64(delay) * cfg.BackoffFactor)
		if delay > cfg.MaxDelay {
			delay = cfg.MaxDelay
		}
	}
	if conn == nil {
		// Only reachable when MaxRetries < 0, i.e. no attempt was made.
		return fmt.Errorf("failed to connect to %s: no connection attempts configured", addr)
	}

	connID := fmt.Sprintf("client_%d", time.Now().UnixNano())
	tt.connections[connID] = conn
	tt.connected = true
	tt.metrics.Connections = len(tt.connections)

	// Start receiving from server
	go tt.handleConnection(connID, conn)
	return nil
}
// acceptConnections is the server accept loop. It registers each accepted
// connection under a unique "server_<nanos>" ID and serves it on its own
// handleConnection goroutine. The loop runs until the transport's context is
// cancelled; Disconnect cancels it and then closes the listener, which
// unblocks Accept.
//
// Accept failures are counted under the lock and followed by a pause that
// doubles from 5ms up to 1s. This keeps a persistent error (e.g. file
// descriptor exhaustion) from spinning the CPU.
func (tt *TCPTransport) acceptConnections() {
	// Read the loop's inputs once, under the lock, instead of accessing the
	// fields unsynchronized from this goroutine.
	tt.mu.RLock()
	ctx, listener := tt.ctx, tt.listener
	tt.mu.RUnlock()

	const (
		minBackoff = 5 * time.Millisecond
		maxBackoff = time.Second
	)
	var backoff time.Duration

	for ctx.Err() == nil {
		conn, err := listener.Accept()
		if err != nil {
			if ctx.Err() != nil {
				return // Context cancelled
			}
			tt.mu.Lock()
			tt.metrics.Errors++
			tt.mu.Unlock()

			backoff *= 2
			if backoff == 0 {
				backoff = minBackoff
			}
			if backoff > maxBackoff {
				backoff = maxBackoff
			}
			select {
			case <-time.After(backoff):
			case <-ctx.Done():
				return
			}
			continue
		}
		backoff = 0

		tt.mu.Lock()
		if ctx.Err() != nil {
			// The transport is shutting down and its connection map has
			// already been torn down: close conn instead of registering it.
			tt.mu.Unlock()
			conn.Close()
			return
		}
		// Timestamps can repeat on coarse clocks; disambiguate so an existing
		// entry is never overwritten (which would orphan its connection).
		base := fmt.Sprintf("server_%d", time.Now().UnixNano())
		connID := base
		for n := 1; ; n++ {
			if _, taken := tt.connections[connID]; !taken {
				break
			}
			connID = fmt.Sprintf("%s_%d", base, n)
		}
		tt.connections[connID] = conn
		tt.metrics.Connections = len(tt.connections)
		tt.mu.Unlock()

		go tt.handleConnection(connID, conn)
	}
}
// handleConnection reads length-prefixed frames from conn and forwards the
// decoded messages to the receive channel. It returns, removing the
// connection, when any of these happens:
//   - the transport's context is cancelled;
//   - the peer closes the connection or a non-timeout read error occurs;
//   - ExtractMessage reports a malformed frame.
//
// Read timeouts (30s of inactivity) just re-arm the deadline.
//
// Messages are delivered without blocking. If the channel is full, the
// message is dropped and counted as an error. Delivery re-checks the context
// while holding the read lock. Disconnect cancels the context and closes the
// channel under the write lock, so a send that passes that check can never
// target a closed channel (which would panic).
func (tt *TCPTransport) handleConnection(connID string, conn net.Conn) {
	defer tt.removeConnection(connID)

	// Read the lifecycle context and output channel once, under the lock,
	// instead of accessing the fields unsynchronized from this goroutine.
	tt.mu.RLock()
	ctx, out := tt.ctx, tt.receiveChan
	tt.mu.RUnlock()

	// deliver hands msg to out without blocking. It returns false when the
	// transport is shutting down and the read loop should stop.
	deliver := func(msg *Message) bool {
		tt.mu.RLock()
		if ctx.Err() != nil {
			tt.mu.RUnlock()
			return false
		}
		sent := false
		select {
		case out <- msg:
			sent = true
		default:
		}
		tt.mu.RUnlock()

		// Metric updates take the write lock, so they run after RUnlock.
		if sent {
			tt.updateReceiveMetrics(msg)
		} else {
			tt.mu.Lock()
			tt.metrics.Errors++ // Channel full, message dropped
			tt.mu.Unlock()
		}
		return true
	}

	buffer := make([]byte, 4096)
	var messageBuffer []byte
	for ctx.Err() == nil {
		conn.SetReadDeadline(time.Now().Add(30 * time.Second))
		n, readErr := conn.Read(buffer)

		// Per the io.Reader contract, consume any bytes read before
		// looking at the error.
		if n > 0 {
			messageBuffer = append(messageBuffer, buffer[:n]...)
			for {
				msg, remaining, err := ExtractMessage(messageBuffer)
				if err != nil {
					return // Invalid message format
				}
				if msg == nil {
					break // No complete message yet
				}
				messageBuffer = remaining
				if !deliver(msg) {
					return
				}
			}
		}

		if readErr != nil {
			if netErr, ok := readErr.(net.Error); ok && netErr.Timeout() {
				continue // Idle peer: re-arm the deadline and re-check ctx.
			}
			return // Connection closed or error
		}
	}
}
// sendWithRetry writes data to conn, retrying per retryConfig with exponential
// back-off (no jitter). Each attempt gets a 10s write deadline. It returns nil
// on success, the last write error once retries are exhausted, or ctx.Err() if
// ctx is cancelled during a back-off wait.
//
// A failed Write may still have transmitted a prefix of data. Retries resume
// from the first unsent byte. Resending the whole frame would duplicate that
// prefix on the wire and corrupt the length-prefixed framing for the rest of
// the stream.
func (tt *TCPTransport) sendWithRetry(ctx context.Context, conn net.Conn, data []byte) error {
	cfg := tt.retryConfig
	delay := cfg.InitialDelay
	for attempt := 0; attempt <= cfg.MaxRetries; attempt++ {
		conn.SetWriteDeadline(time.Now().Add(10 * time.Second))
		n, err := conn.Write(data)
		if err == nil {
			return nil
		}
		// Reslicing the local view leaves the caller's frame untouched.
		data = data[n:]

		if attempt == cfg.MaxRetries {
			return err
		}

		// Wait with exponential backoff
		select {
		case <-time.After(delay):
			delay = time.Duration(float64(delay) * cfg.BackoffFactor)
			if delay > cfg.MaxDelay {
				delay = cfg.MaxDelay
			}
		case <-ctx.Done():
			return ctx.Err()
		}
	}
	// Only reachable when MaxRetries < 0.
	return fmt.Errorf("max retries exceeded")
}
// removeConnection closes and forgets the connection registered under connID,
// then refreshes the connection gauge. Unknown IDs are ignored, so calling it
// more than once for the same connection is harmless.
func (tt *TCPTransport) removeConnection(connID string) {
	tt.mu.Lock()
	defer tt.mu.Unlock()

	conn, exists := tt.connections[connID]
	if !exists {
		return
	}
	conn.Close()
	delete(tt.connections, connID)
	tt.metrics.Connections = len(tt.connections)
}
// updateSendMetrics records one sent message. It increments MessagesSent,
// stores latency as the most recent Latency, and adds a rough size estimate to
// BytesSent. The estimate is the ID, Topic and Source lengths plus the %v
// rendering of Data; it is not the bytes actually written.
func (tt *TCPTransport) updateSendMetrics(msg *Message, latency time.Duration) {
	size := len(msg.ID) + len(msg.Topic) + len(msg.Source)
	if msg.Data != nil {
		size += len(fmt.Sprint(msg.Data))
	}

	tt.mu.Lock()
	defer tt.mu.Unlock()
	tt.metrics.MessagesSent++
	tt.metrics.Latency = latency
	tt.metrics.BytesSent += int64(size)
}
// updateReceiveMetrics records one received message. It increments
// MessagesReceived and adds a rough size estimate to BytesReceived: the ID,
// Topic and Source lengths plus the %v rendering of Data.
func (tt *TCPTransport) updateReceiveMetrics(msg *Message) {
	size := len(msg.ID) + len(msg.Topic) + len(msg.Source)
	if msg.Data != nil {
		size += len(fmt.Sprint(msg.Data))
	}

	tt.mu.Lock()
	defer tt.mu.Unlock()
	tt.metrics.MessagesReceived++
	tt.metrics.BytesReceived += int64(size)
}
// GetAddress returns the transport's host:port endpoint (for testing and
// debugging). It is built with net.JoinHostPort, the same form
// connectToServer dials, so IPv6 hosts are bracketed (e.g. "[::1]:9000").
// Plain hostnames and IPv4 addresses are unaffected.
func (tt *TCPTransport) GetAddress() string {
	return net.JoinHostPort(tt.address, fmt.Sprintf("%d", tt.port))
}
// GetConnectionCount reports how many connections are currently registered
// with the transport.
func (tt *TCPTransport) GetConnectionCount() int {
	tt.mu.RLock()
	count := len(tt.connections)
	tt.mu.RUnlock()
	return count
}
// IsSecure reports whether a TLS configuration is set, i.e. whether Connect
// will listen or dial over TLS. The field is read under the read lock,
// consistent with the transport's other accessors.
func (tt *TCPTransport) IsSecure() bool {
	tt.mu.RLock()
	defer tt.mu.RUnlock()
	return tt.tlsConfig != nil
}