Files
mev-beta/pkg/transport/unix_transport.go
Krypto Kajun 850223a953 fix(multicall): resolve critical multicall parsing corruption issues
- Added comprehensive bounds checking to prevent buffer overruns in multicall parsing
- Implemented graduated validation system (Strict/Moderate/Permissive) to reduce false positives
- Added LRU caching system for address validation with 10-minute TTL
- Enhanced ABI decoder with missing Universal Router and Arbitrum-specific DEX signatures
- Fixed duplicate function declarations and import conflicts across multiple files
- Added error recovery mechanisms with multiple fallback strategies
- Updated tests to handle new validation behavior for suspicious addresses
- Fixed parser test expectations for improved validation system
- Applied gofmt formatting fixes to ensure code style compliance
- Fixed mutex copying issues in monitoring package by introducing MetricsSnapshot
- Resolved critical security vulnerabilities in heuristic address extraction
- Progress: Updated TODO audit from 10% to 35% complete

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
2025-10-17 00:12:55 -05:00

360 lines
7.9 KiB
Go

package transport
import (
"context"
"encoding/json"
"fmt"
"net"
"os"
"sync"
"time"
)
// UnixSocketTransport implements Unix socket transport for local IPC.
//
// A transport is either a server (it listens on socketPath and accepts any
// number of peers) or a client (it dials socketPath once). Outgoing messages
// are written to every tracked connection; incoming messages from all
// connections are delivered on a single buffered channel.
type UnixSocketTransport struct {
	socketPath  string              // filesystem path passed to net.Listen / net.Dial
	listener    net.Listener        // set by startServer; unused in client mode
	connections map[string]net.Conn // live peers keyed by a generated "server_<ns>" / "client_<ns>" ID
	metrics     TransportMetrics    // counters reported by GetMetrics and Health
	connected   bool                // true between a successful Connect and Disconnect
	isServer    bool                // true: listen and accept; false: dial
	receiveChan chan *Message       // inbound messages from all connections; closed by Disconnect
	mu          sync.RWMutex        // protects connections, connected and metrics
	ctx         context.Context     // cancelled by Disconnect to stop the accept/read goroutines
	cancel      context.CancelFunc  // cancels ctx
}
// NewUnixSocketTransport returns an unconnected transport bound to socketPath.
//
// When isServer is true, Connect will listen on the path; otherwise it will
// dial it. The inbound message channel is buffered for 1000 messages. The
// transport owns a cancellable context that Disconnect uses to stop its
// background goroutines.
func NewUnixSocketTransport(socketPath string, isServer bool) *UnixSocketTransport {
	t := &UnixSocketTransport{
		socketPath: socketPath,
		isServer:   isServer,
	}
	// metrics is left at its zero value.
	t.connections = make(map[string]net.Conn)
	t.receiveChan = make(chan *Message, 1000)
	t.ctx, t.cancel = context.WithCancel(context.Background())
	return t
}
// Connect establishes the Unix socket connection.
//
// In server mode it starts listening on socketPath and accepting peers. In
// client mode it dials socketPath. Calling Connect on a transport that is
// already connected is a no-op. The ctx argument is not used.
//
// Disconnect cancels the transport's internal context and closes its receive
// channel, and neither is recreated. A transport therefore cannot be reused
// after Disconnect. Connect reports an error in that case, rather than
// returning a transport whose background goroutines exit immediately and
// whose next Disconnect would close an already-closed channel.
func (ut *UnixSocketTransport) Connect(ctx context.Context) error {
	ut.mu.Lock()
	defer ut.mu.Unlock()

	if ut.connected {
		return nil
	}
	if ut.ctx.Err() != nil {
		return fmt.Errorf("transport has been disconnected and cannot be reconnected")
	}

	if ut.isServer {
		return ut.startServer()
	}
	return ut.connectToServer()
}
// Disconnect shuts the transport down. It takes these steps in order:
//   - cancels the internal context;
//   - in server mode, closes the listener and deletes the socket file;
//   - closes and forgets every tracked connection;
//   - resets the connection count;
//   - closes the receive channel.
//
// Calling it on a transport that is not connected is a no-op. Close/remove
// errors are ignored and the method always returns nil. The ctx argument is
// not used.
//
// NOTE(review): closing receiveChan is only safe if no connection goroutine
// can still send on it after this critical section. Confirm that the
// receive path re-checks ut.ctx while holding ut.mu before sending.
func (ut *UnixSocketTransport) Disconnect(ctx context.Context) error {
	ut.mu.Lock()
	defer ut.mu.Unlock()

	if !ut.connected {
		return nil
	}

	ut.cancel()

	if ut.isServer && ut.listener != nil {
		ut.listener.Close()
		os.Remove(ut.socketPath) // best effort; error deliberately ignored
	}

	// Deleting map entries while ranging over the map is permitted in Go.
	for id := range ut.connections {
		ut.connections[id].Close()
		delete(ut.connections, id)
	}

	close(ut.receiveChan)
	ut.connected = false
	ut.metrics.Connections = 0
	return nil
}
// Send transmits msg to every active connection.
//
// The message is JSON-encoded and framed as "<length>\n<payload>". Send
// writes to a snapshot of the connection set without holding ut.mu. Each
// write may block for up to the write deadline in sendToConnection, so
// holding the lock would stall Disconnect, Health and the receive path
// behind a slow peer.
//
// Connections whose write fails are closed and removed. If several writes
// fail, only the last failure is returned. Every failure path increments
// metrics.Errors under the lock. The ctx argument is not used.
func (ut *UnixSocketTransport) Send(ctx context.Context, msg *Message) error {
	start := time.Now()

	recordError := func() {
		ut.mu.Lock()
		ut.metrics.Errors++
		ut.mu.Unlock()
	}

	type target struct {
		id   string
		conn net.Conn
	}

	ut.mu.RLock()
	connected := ut.connected
	targets := make([]target, 0, len(ut.connections))
	for id, conn := range ut.connections {
		targets = append(targets, target{id: id, conn: conn})
	}
	ut.mu.RUnlock()

	if !connected {
		recordError()
		return fmt.Errorf("transport not connected")
	}

	data, err := json.Marshal(msg)
	if err != nil {
		recordError()
		return fmt.Errorf("failed to marshal message: %w", err)
	}
	// Length prefix so the receiver can find message boundaries in the stream.
	frame := []byte(fmt.Sprintf("%d\n%s", len(data), data))

	if len(targets) == 0 {
		recordError()
		return fmt.Errorf("no active connections")
	}

	var sendErr error
	for _, tg := range targets {
		if err := ut.sendToConnection(tg.conn, frame); err != nil {
			sendErr = fmt.Errorf("failed to send to connection %s: %w", tg.id, err)
			// The lock is not held here, so the dead connection can be
			// removed synchronously.
			ut.removeConnection(tg.id)
		}
	}

	if sendErr != nil {
		recordError()
		return sendErr
	}
	ut.updateSendMetrics(msg, time.Since(start))
	return nil
}
// Receive returns the channel on which messages from all connections are
// delivered. It fails if the transport is not connected. The same channel
// is returned on every call, and Disconnect closes it. The ctx argument is
// not used.
func (ut *UnixSocketTransport) Receive(ctx context.Context) (<-chan *Message, error) {
	ut.mu.RLock()
	connected, ch := ut.connected, ut.receiveChan
	ut.mu.RUnlock()

	if connected {
		return ch, nil
	}
	return nil, fmt.Errorf("transport not connected")
}
// Health returns the health status of the transport
func (ut *UnixSocketTransport) Health() ComponentHealth {
ut.mu.RLock()
defer ut.mu.RUnlock()
status := "unhealthy"
if ut.connected {
if len(ut.connections) > 0 {
status = "healthy"
} else {
status = "degraded" // Connected but no active connections
}
}
return ComponentHealth{
Status: status,
LastCheck: time.Now(),
ResponseTime: time.Millisecond, // Fast for local sockets
ErrorCount: ut.metrics.Errors,
}
}
// GetMetrics returns a copy of the transport's counters, taken under the
// read lock. Connections is recomputed from the live connection map instead
// of being read from the stored metrics.
func (ut *UnixSocketTransport) GetMetrics() TransportMetrics {
	ut.mu.RLock()
	defer ut.mu.RUnlock()

	m := &ut.metrics
	return TransportMetrics{
		Connections:      len(ut.connections),
		Errors:           m.Errors,
		Latency:          m.Latency,
		MessagesSent:     m.MessagesSent,
		MessagesReceived: m.MessagesReceived,
		BytesSent:        m.BytesSent,
		BytesReceived:    m.BytesReceived,
	}
}
// Private helper methods
// startServer listens on socketPath, marks the transport connected and
// starts the accept loop in a background goroutine. The caller (Connect)
// holds ut.mu for writing.
//
// A stale socket file left by an earlier run would make Listen fail, so one
// found at socketPath is removed first. Only a path whose file mode is a
// socket is removed. A misconfigured path that names a regular file or
// directory is left untouched, and Listen reports the conflict instead of
// the file being silently deleted.
// NOTE(review): a socket still served by a live process is also removed
// here; detect that (e.g. by dialing first) if it matters.
func (ut *UnixSocketTransport) startServer() error {
	if info, err := os.Lstat(ut.socketPath); err == nil && info.Mode()&os.ModeSocket != 0 {
		_ = os.Remove(ut.socketPath) // best effort; Listen reports any remaining conflict
	}

	listener, err := net.Listen("unix", ut.socketPath)
	if err != nil {
		return fmt.Errorf("failed to listen on socket %s: %w", ut.socketPath, err)
	}

	ut.listener = listener
	ut.connected = true

	go ut.acceptConnections()
	return nil
}
// connectToServer dials the server socket at socketPath and registers the
// connection under a "client_<unix-nanos>" ID. It then marks the transport
// connected and starts a goroutine that reads messages from the server.
// The caller (Connect) holds ut.mu for writing.
func (ut *UnixSocketTransport) connectToServer() error {
	conn, dialErr := net.Dial("unix", ut.socketPath)
	if dialErr != nil {
		return fmt.Errorf("failed to connect to socket %s: %w", ut.socketPath, dialErr)
	}

	id := fmt.Sprintf("client_%d", time.Now().UnixNano())
	ut.connected = true
	ut.metrics.Connections = 1
	ut.connections[id] = conn

	go ut.handleConnection(id, conn)
	return nil
}
func (ut *UnixSocketTransport) acceptConnections() {
for {
select {
case <-ut.ctx.Done():
return
default:
conn, err := ut.listener.Accept()
if err != nil {
if ut.ctx.Err() != nil {
return // Context cancelled
}
ut.metrics.Errors++
continue
}
connID := fmt.Sprintf("server_%d", time.Now().UnixNano())
ut.mu.Lock()
ut.connections[connID] = conn
ut.metrics.Connections = len(ut.connections)
ut.mu.Unlock()
go ut.handleConnection(connID, conn)
}
}
}
// handleConnection reads framed messages from conn until one of these
// happens: the transport is shut down, the peer closes the connection, a
// read fails, or a frame cannot be parsed. The connection is then removed
// from the transport.
//
// Reads use a one-second deadline so the loop re-checks ut.ctx regularly
// even when the peer is idle. Received bytes accumulate in a buffer from
// which ExtractMessage repeatedly peels off complete messages.
func (ut *UnixSocketTransport) handleConnection(connID string, conn net.Conn) {
	defer ut.removeConnection(connID)

	buffer := make([]byte, 4096)
	var messageBuffer []byte

	for ut.ctx.Err() == nil {
		conn.SetReadDeadline(time.Now().Add(time.Second))
		n, readErr := conn.Read(buffer)

		// Per the io.Reader contract, bytes returned alongside an error are
		// still valid, so process them before looking at readErr.
		if n > 0 {
			messageBuffer = append(messageBuffer, buffer[:n]...)
			var keepGoing bool
			messageBuffer, keepGoing = ut.drainMessages(messageBuffer)
			if !keepGoing {
				return
			}
		}

		if readErr != nil {
			if netErr, ok := readErr.(net.Error); ok && netErr.Timeout() {
				continue // deadline expired; loop to re-check ut.ctx
			}
			return // connection closed or error
		}
	}
}

// drainMessages extracts every complete message from buf, delivers each
// one, and returns the unconsumed tail. It reports false when the connection
// should be dropped, either because buf holds a malformed frame or because
// the transport was shut down.
func (ut *UnixSocketTransport) drainMessages(buf []byte) ([]byte, bool) {
	for {
		msg, remaining, err := ExtractMessage(buf)
		if err != nil {
			return buf, false // invalid message format
		}
		if msg == nil {
			return buf, true // no complete message yet
		}
		if !ut.deliverMessage(msg) {
			return buf, false
		}
		buf = remaining
	}
}

// deliverMessage hands msg to receiveChan without blocking. If the channel
// is full, the message is dropped and counted as an error. It reports false
// once the transport has been shut down.
//
// Disconnect cancels ut.ctx and closes receiveChan inside one ut.mu write
// critical section. Checking ut.ctx and sending while holding the read lock
// therefore guarantees the send never targets a closed channel, which would
// panic.
func (ut *UnixSocketTransport) deliverMessage(msg *Message) bool {
	ut.mu.RLock()
	if ut.ctx.Err() != nil {
		ut.mu.RUnlock()
		return false
	}
	delivered := false
	select {
	case ut.receiveChan <- msg:
		delivered = true
	default:
		// Channel full, drop message.
	}
	ut.mu.RUnlock()

	// Metric updates take the write lock, so they must happen after RUnlock.
	if delivered {
		ut.updateReceiveMetrics(msg)
	} else {
		ut.mu.Lock()
		ut.metrics.Errors++
		ut.mu.Unlock()
	}
	return true
}
// sendToConnection writes one already-framed message to conn. A five-second
// write deadline stops a stalled peer from blocking the caller forever. Any
// write error, including a deadline timeout, is returned unchanged.
func (ut *UnixSocketTransport) sendToConnection(conn net.Conn, data []byte) error {
	deadline := time.Now().Add(5 * time.Second)
	conn.SetWriteDeadline(deadline)

	if _, err := conn.Write(data); err != nil {
		return err
	}
	return nil
}
// removeConnection closes and forgets the connection registered as connID,
// then refreshes the connection count. An unknown ID is ignored, so calling
// it twice for the same connection is harmless.
func (ut *UnixSocketTransport) removeConnection(connID string) {
	ut.mu.Lock()
	defer ut.mu.Unlock()

	conn, ok := ut.connections[connID]
	if !ok {
		return
	}

	conn.Close()
	delete(ut.connections, connID)
	ut.metrics.Connections = len(ut.connections)
}
// updateSendMetrics records one successfully sent message. It stores latency
// as the most recent send latency and adds an estimate of the message size
// to BytesSent.
func (ut *UnixSocketTransport) updateSendMetrics(msg *Message, latency time.Duration) {
	// The size estimate formats msg.Data via reflection, so compute it before
	// taking the lock rather than while holding it.
	size := ut.estimateMessageSize(msg)

	ut.mu.Lock()
	defer ut.mu.Unlock()
	ut.metrics.MessagesSent++
	ut.metrics.Latency = latency
	ut.metrics.BytesSent += size
}

// updateReceiveMetrics records one delivered inbound message and adds an
// estimate of its size to BytesReceived.
func (ut *UnixSocketTransport) updateReceiveMetrics(msg *Message) {
	size := ut.estimateMessageSize(msg)

	ut.mu.Lock()
	defer ut.mu.Unlock()
	ut.metrics.MessagesReceived++
	ut.metrics.BytesReceived += size
}

// estimateMessageSize approximates msg's size in bytes. It sums the lengths
// of ID, Topic and Source, plus the length of Data's %v rendering when Data
// is non-nil. This is an estimate, not the encoded wire size.
func (ut *UnixSocketTransport) estimateMessageSize(msg *Message) int64 {
	size := int64(len(msg.ID) + len(msg.Topic) + len(msg.Source))
	if msg.Data != nil {
		size += int64(len(fmt.Sprintf("%v", msg.Data)))
	}
	return size
}
// GetSocketPath reports the filesystem path this transport listens on or
// dials (for testing/debugging). The path is set at construction and never
// reassigned in this file, so no lock is taken.
func (ut *UnixSocketTransport) GetSocketPath() string {
	path := ut.socketPath
	return path
}
// GetConnectionCount reports how many connections are currently tracked,
// read under the transport's read lock.
func (ut *UnixSocketTransport) GetConnectionCount() int {
	ut.mu.RLock()
	n := len(ut.connections)
	ut.mu.RUnlock()
	return n
}