package transport

import (
	"context"
	"crypto/tls"
	"encoding/json"
	"fmt"
	"math/rand"
	"net"
	"sync"
	"time"
)

// Tunables used by TCPTransport. The values match the literals the transport
// has always used; they are named here so each appears in exactly one place.
const (
	// tcpReceiveBuffer is the capacity of the inbound message channel.
	tcpReceiveBuffer = 1000
	// tcpReadChunkSize is the size of the per-connection read buffer.
	tcpReadChunkSize = 4096
	// tcpReadTimeout bounds a single Read so the reader loop can observe
	// cancellation even on an idle connection.
	tcpReadTimeout = 30 * time.Second
	// tcpWriteTimeout bounds a single Write attempt.
	tcpWriteTimeout = 10 * time.Second
	// tcpAcceptRetryDelay is the pause after a failed Accept, so a persistent
	// listener error does not turn the accept loop into a busy loop.
	tcpAcceptRetryDelay = 50 * time.Millisecond
	// tcpJitterFraction is the maximum relative jitter (+/-) applied to a
	// backoff delay when RetryConfig.Jitter is enabled.
	tcpJitterFraction = 0.1
)

// TCPTransport implements TCP transport for remote communication.
//
// In server mode it listens and accepts any number of peers; in client mode
// it dials a single server. Send broadcasts to every active connection and
// inbound messages from all connections are multiplexed onto one channel.
type TCPTransport struct {
	address     string
	port        int
	listener    net.Listener
	connections map[string]net.Conn
	metrics     TransportMetrics
	connected   bool
	isServer    bool
	receiveChan chan *Message
	tlsConfig   *tls.Config
	mu          sync.RWMutex // guards connections, metrics, connected, tlsConfig, retryConfig, connSeq
	ctx         context.Context
	cancel      context.CancelFunc
	retryConfig RetryConfig

	// lifecycleMu serialises Connect and Disconnect so that a reconnect can
	// never overlap with the teardown of the previous session.
	// Lock order: lifecycleMu before mu.
	lifecycleMu sync.Mutex
	// workers tracks the accept loop and every per-connection reader so that
	// Disconnect can wait for them before closing receiveChan.
	workers sync.WaitGroup
	// connSeq generates unique connection IDs; guarded by mu.
	connSeq uint64
}

// NewTCPTransport creates a new TCP transport.
//
// address and port identify the endpoint to listen on (isServer == true) or
// to dial (isServer == false). The transport is created disconnected; call
// Connect to start it.
func NewTCPTransport(address string, port int, isServer bool) *TCPTransport {
	ctx, cancel := context.WithCancel(context.Background())

	return &TCPTransport{
		address:     address,
		port:        port,
		connections: make(map[string]net.Conn),
		metrics:     TransportMetrics{},
		isServer:    isServer,
		receiveChan: make(chan *Message, tcpReceiveBuffer),
		ctx:         ctx,
		cancel:      cancel,
		retryConfig: RetryConfig{
			MaxRetries:    3,
			InitialDelay:  time.Second,
			MaxDelay:      30 * time.Second,
			BackoffFactor: 2.0,
			Jitter:        true,
		},
	}
}

// SetTLSConfig configures TLS for secure communication. It takes effect on
// the next Connect; a nil config disables TLS.
func (tt *TCPTransport) SetTLSConfig(config *tls.Config) {
	tt.mu.Lock()
	defer tt.mu.Unlock()
	tt.tlsConfig = config
}

// SetRetryConfig configures retry behavior for both dialing and sending.
func (tt *TCPTransport) SetRetryConfig(config RetryConfig) {
	tt.mu.Lock()
	defer tt.mu.Unlock()
	tt.retryConfig = config
}

// Connect establishes the TCP connection: it starts listening in server mode
// or dials the server (with retries) in client mode. Calling Connect on an
// already connected transport is a no-op.
//
// The transport may be connected again after Disconnect; in that case a fresh
// receive channel is created, so callers must call Receive again.
//
// Note: the state lock is held for the whole call, including client-side dial
// retries, so other methods block until Connect returns.
func (tt *TCPTransport) Connect(ctx context.Context) error {
	tt.lifecycleMu.Lock()
	defer tt.lifecycleMu.Unlock()

	tt.mu.Lock()
	defer tt.mu.Unlock()

	if tt.connected {
		return nil
	}

	// A previous Disconnect cancelled the context and closed the channel.
	// Re-create both, otherwise the new workers would exit immediately and the
	// next Disconnect would close an already closed channel. This is safe
	// because lifecycleMu guarantees the old workers have all finished.
	if tt.ctx.Err() != nil {
		tt.ctx, tt.cancel = context.WithCancel(context.Background())
		tt.receiveChan = make(chan *Message, tcpReceiveBuffer)
	}

	if tt.isServer {
		return tt.startServer()
	}
	return tt.connectToServer(ctx)
}

// Disconnect closes the TCP connection: it stops the accept loop, closes all
// peer connections, waits for the background goroutines to finish and then
// closes the receive channel. Calling it on a disconnected transport is a
// no-op. The ctx argument is currently unused.
func (tt *TCPTransport) Disconnect(ctx context.Context) error {
	tt.lifecycleMu.Lock()
	defer tt.lifecycleMu.Unlock()

	tt.mu.Lock()
	if !tt.connected {
		tt.mu.Unlock()
		return nil
	}

	tt.cancel()

	if tt.isServer && tt.listener != nil {
		tt.listener.Close()
	}

	// Close all connections; this also unblocks readers parked in Read.
	for id, conn := range tt.connections {
		conn.Close()
		delete(tt.connections, id)
	}

	tt.connected = false
	tt.metrics.Connections = 0
	inbox := tt.receiveChan
	tt.mu.Unlock()

	// The readers send on the channel, so it may only be closed once they are
	// all gone; closing earlier can panic with "send on closed channel". The
	// state lock must be released first because the workers need it to
	// deregister themselves.
	tt.workers.Wait()
	close(inbox)

	return nil
}

// Send transmits a message through TCP to every active connection.
//
// The message is JSON-encoded and framed as "<length>\n<payload>". A
// connection whose write ultimately fails is closed and removed. The returned
// error describes the last failed connection, or is nil if every write
// succeeded. It also fails when the transport is not connected, msg is nil,
// msg cannot be marshalled, or there are no active connections.
func (tt *TCPTransport) Send(ctx context.Context, msg *Message) error {
	start := time.Now()

	// Snapshot the peers so no lock is held during network I/O and backoff
	// sleeps, which would stall every writer of the transport state.
	tt.mu.RLock()
	connected := tt.connected
	targets := make(map[string]net.Conn, len(tt.connections))
	for id, conn := range tt.connections {
		targets[id] = conn
	}
	tt.mu.RUnlock()

	if !connected {
		tt.recordError()
		return fmt.Errorf("transport not connected")
	}

	// A nil message would marshal to "null", be sent to every peer and then
	// crash in the metrics update; reject it up front.
	if msg == nil {
		tt.recordError()
		return fmt.Errorf("cannot send nil message")
	}

	// Serialize message
	data, err := json.Marshal(msg)
	if err != nil {
		tt.recordError()
		return fmt.Errorf("failed to marshal message: %w", err)
	}

	if len(targets) == 0 {
		tt.recordError()
		return fmt.Errorf("no active connections")
	}

	// Add length prefix for framing
	frame := append([]byte(fmt.Sprintf("%d\n", len(data))), data...)

	// Send to all connections with retry
	var sendErr error
	for connID, conn := range targets {
		if err := tt.sendWithRetry(ctx, conn, frame); err != nil {
			sendErr = fmt.Errorf("failed to send to connection %s: %w", connID, err)
			// Remove failed connection
			tt.removeConnection(connID)
		}
	}

	if sendErr == nil {
		tt.updateSendMetrics(msg, time.Since(start))
	} else {
		tt.recordError()
	}

	return sendErr
}

// Receive returns a channel for receiving messages. The channel is closed by
// Disconnect. It returns an error if the transport is not connected. The ctx
// argument is currently unused.
func (tt *TCPTransport) Receive(ctx context.Context) (<-chan *Message, error) {
	tt.mu.RLock()
	defer tt.mu.RUnlock()

	if !tt.connected {
		return nil, fmt.Errorf("transport not connected")
	}

	return tt.receiveChan, nil
}

// Health returns the health status of the transport: "healthy" when connected
// with at least one peer, "degraded" when connected without peers, and
// "unhealthy" otherwise.
func (tt *TCPTransport) Health() ComponentHealth {
	tt.mu.RLock()
	defer tt.mu.RUnlock()

	status := "unhealthy"
	var responseTime time.Duration

	if tt.connected {
		if len(tt.connections) > 0 {
			status = "healthy"
			responseTime = time.Millisecond * 10 // Estimate for TCP
		} else {
			status = "degraded" // Connected but no active connections
		}
	}

	return ComponentHealth{
		Status:       status,
		LastCheck:    time.Now(),
		ResponseTime: responseTime,
		ErrorCount:   tt.metrics.Errors,
	}
}

// GetMetrics returns a snapshot of the transport-specific metrics. The
// connection count is taken from the live connection table.
func (tt *TCPTransport) GetMetrics() TransportMetrics {
	tt.mu.RLock()
	defer tt.mu.RUnlock()

	return TransportMetrics{
		BytesSent:        tt.metrics.BytesSent,
		BytesReceived:    tt.metrics.BytesReceived,
		MessagesSent:     tt.metrics.MessagesSent,
		MessagesReceived: tt.metrics.MessagesReceived,
		Connections:      len(tt.connections),
		Errors:           tt.metrics.Errors,
		Latency:          tt.metrics.Latency,
	}
}

// Private helper methods

// hostPort returns the endpoint in the form net.Listen and net.Dial expect.
// IPv6 literals are accepted with or without surrounding brackets.
func (tt *TCPTransport) hostPort() string {
	host := tt.address
	if n := len(host); n >= 2 && host[0] == '[' && host[n-1] == ']' {
		host = host[1 : n-1]
	}
	return net.JoinHostPort(host, fmt.Sprintf("%d", tt.port))
}

// recordError increments the error counter. It acquires mu, so it must not be
// called while mu is already held.
func (tt *TCPTransport) recordError() {
	tt.mu.Lock()
	tt.metrics.Errors++
	tt.mu.Unlock()
}

// tcpNextBackoff returns the delay to use after the given one: scaled by the
// backoff factor, capped at MaxDelay and, when enabled, perturbed by a
// symmetric jitter of up to tcpJitterFraction. The result never exceeds
// MaxDelay and is never negative.
func tcpNextBackoff(cfg RetryConfig, delay time.Duration) time.Duration {
	next := time.Duration(float64(delay) * cfg.BackoffFactor)
	if next > cfg.MaxDelay {
		next = cfg.MaxDelay
	}

	if cfg.Jitter {
		spread := float64(next) * tcpJitterFraction
		next += time.Duration(spread * (2*rand.Float64() - 1))
		if next > cfg.MaxDelay {
			next = cfg.MaxDelay
		}
	}

	if next < 0 {
		next = 0
	}
	return next
}

// tcpEstimateMessageSize approximates the size of a message for the byte
// counters from its ID, Topic, Source and the %v rendering of Data.
func tcpEstimateMessageSize(msg *Message) int64 {
	size := int64(len(msg.ID) + len(msg.Topic) + len(msg.Source))
	if msg.Data != nil {
		size += int64(len(fmt.Sprintf("%v", msg.Data)))
	}
	return size
}

// startServer opens the listener (TLS if configured) and starts the accept
// loop. It must be called with mu held.
func (tt *TCPTransport) startServer() error {
	addr := tt.hostPort()

	var listener net.Listener
	var err error

	if tt.tlsConfig != nil {
		listener, err = tls.Listen("tcp", addr, tt.tlsConfig)
	} else {
		listener, err = net.Listen("tcp", addr)
	}

	if err != nil {
		return fmt.Errorf("failed to listen on %s: %w", addr, err)
	}

	tt.listener = listener
	tt.connected = true

	// Start accepting connections
	tt.workers.Add(1)
	go func() {
		defer tt.workers.Done()
		tt.acceptConnections()
	}()

	return nil
}

// dialServer opens one connection to addr, using TLS if configured. The dial
// is abandoned when ctx is cancelled. It must be called with mu held.
func (tt *TCPTransport) dialServer(ctx context.Context, addr string) (net.Conn, error) {
	if tt.tlsConfig != nil {
		dialer := tls.Dialer{Config: tt.tlsConfig}
		return dialer.DialContext(ctx, "tcp", addr)
	}

	var dialer net.Dialer
	return dialer.DialContext(ctx, "tcp", addr)
}

// connectToServer dials the server, retrying with exponential backoff up to
// retryConfig.MaxRetries additional times, then registers the connection and
// starts its reader. At least one attempt is always made. It must be called
// with mu held.
func (tt *TCPTransport) connectToServer(ctx context.Context) error {
	if ctx == nil {
		ctx = context.Background()
	}

	addr := tt.hostPort()
	cfg := tt.retryConfig

	var conn net.Conn

	// Retry connection with exponential backoff
	delay := cfg.InitialDelay
	for attempt := 0; ; attempt++ {
		var err error
		conn, err = tt.dialServer(ctx, addr)
		if err == nil {
			break
		}

		if ctx.Err() != nil {
			return ctx.Err()
		}

		if attempt >= cfg.MaxRetries {
			return fmt.Errorf("failed to connect to %s after %d attempts: %w", addr, attempt+1, err)
		}

		// Wait with exponential backoff
		select {
		case <-time.After(delay):
			delay = tcpNextBackoff(cfg, delay)
		case <-ctx.Done():
			return ctx.Err()
		}
	}

	tt.connSeq++
	connID := fmt.Sprintf("client_%d", tt.connSeq)
	tt.connections[connID] = conn
	tt.connected = true
	tt.metrics.Connections = 1

	// Start receiving from server
	tt.workers.Add(1)
	go func() {
		defer tt.workers.Done()
		tt.handleConnection(connID, conn)
	}()

	return nil
}

// acceptConnections accepts peers until the transport context is cancelled,
// registering each one and starting a reader goroutine for it.
func (tt *TCPTransport) acceptConnections() {
	for {
		if tt.ctx.Err() != nil {
			return
		}

		conn, err := tt.listener.Accept()
		if err != nil {
			if tt.ctx.Err() != nil {
				return // Context cancelled
			}
			tt.recordError()

			// Back off briefly so a persistent error cannot spin the CPU.
			select {
			case <-time.After(tcpAcceptRetryDelay):
			case <-tt.ctx.Done():
				return
			}
			continue
		}

		tt.mu.Lock()
		// Disconnect may have run between Accept returning and this point; a
		// connection registered now would never be closed.
		if tt.ctx.Err() != nil {
			tt.mu.Unlock()
			conn.Close()
			return
		}
		tt.connSeq++
		connID := fmt.Sprintf("server_%d", tt.connSeq)
		tt.connections[connID] = conn
		tt.metrics.Connections = len(tt.connections)
		tt.workers.Add(1)
		tt.mu.Unlock()

		go func() {
			defer tt.workers.Done()
			tt.handleConnection(connID, conn)
		}()
	}
}

// handleConnection reads from conn until it fails, a malformed frame is seen
// or the transport context is cancelled, delivering every complete message to
// the receive channel. The connection is removed on return.
func (tt *TCPTransport) handleConnection(connID string, conn net.Conn) {
	defer tt.removeConnection(connID)

	buffer := make([]byte, tcpReadChunkSize)
	var pending []byte

	for {
		if tt.ctx.Err() != nil {
			return
		}

		conn.SetReadDeadline(time.Now().Add(tcpReadTimeout))
		n, err := conn.Read(buffer)

		// A Read may return data together with an error; consume the data
		// before looking at the error so the final bytes are not lost.
		if n > 0 {
			pending = append(pending, buffer[:n]...)

			var ok bool
			pending, ok = tt.deliverMessages(pending)
			if !ok {
				return
			}
		}

		if err != nil {
			if netErr, isNetErr := err.(net.Error); isNetErr && netErr.Timeout() {
				continue // Continue on timeout
			}
			return // Connection closed or error
		}
	}
}

// deliverMessages extracts every complete message from buf and offers it to
// the receive channel, dropping (and counting as an error) messages that do
// not fit. It returns the unconsumed remainder and false when the caller
// should stop reading: invalid framing or a cancelled transport context.
func (tt *TCPTransport) deliverMessages(buf []byte) ([]byte, bool) {
	for {
		msg, remaining, err := ExtractMessage(buf)
		if err != nil {
			return buf, false // Invalid message format
		}
		if msg == nil {
			return buf, true // No complete message yet
		}

		// Deliver message
		select {
		case tt.receiveChan <- msg:
			tt.updateReceiveMetrics(msg)
		case <-tt.ctx.Done():
			return buf, false
		default:
			// Channel full, drop message
			tt.recordError()
		}

		buf = remaining
	}
}

// sendWithRetry writes data to conn, retrying with exponential backoff up to
// retryConfig.MaxRetries additional times. At least one attempt is always
// made. It returns the last write error, or ctx.Err() if ctx is cancelled
// while waiting between attempts. It acquires mu briefly, so it must not be
// called with mu held.
func (tt *TCPTransport) sendWithRetry(ctx context.Context, conn net.Conn, data []byte) error {
	tt.mu.RLock()
	cfg := tt.retryConfig
	tt.mu.RUnlock()

	delay := cfg.InitialDelay
	written := 0

	for attempt := 0; ; attempt++ {
		conn.SetWriteDeadline(time.Now().Add(tcpWriteTimeout))

		// Resume after any bytes already on the wire: resending the whole
		// frame after a partial write would corrupt the length-prefixed
		// stream for the peer.
		n, err := conn.Write(data[written:])
		written += n
		if err == nil || written >= len(data) {
			return nil
		}

		if attempt >= cfg.MaxRetries {
			return err
		}

		// Wait with exponential backoff
		select {
		case <-time.After(delay):
			delay = tcpNextBackoff(cfg, delay)
		case <-ctx.Done():
			return ctx.Err()
		}
	}
}

// removeConnection closes and deregisters the connection with the given ID.
// It is a no-op if the ID is unknown. It acquires mu.
func (tt *TCPTransport) removeConnection(connID string) {
	tt.mu.Lock()
	defer tt.mu.Unlock()

	if conn, exists := tt.connections[connID]; exists {
		conn.Close()
		delete(tt.connections, connID)
		tt.metrics.Connections = len(tt.connections)
	}
}

// updateSendMetrics records one successfully sent message, its latency and
// its estimated size. It acquires mu.
func (tt *TCPTransport) updateSendMetrics(msg *Message, latency time.Duration) {
	// Estimate message size (outside the lock: formatting Data may be slow)
	messageSize := tcpEstimateMessageSize(msg)

	tt.mu.Lock()
	defer tt.mu.Unlock()

	tt.metrics.MessagesSent++
	tt.metrics.Latency = latency
	tt.metrics.BytesSent += messageSize
}

// updateReceiveMetrics records one received message and its estimated size.
// It acquires mu.
func (tt *TCPTransport) updateReceiveMetrics(msg *Message) {
	// Estimate message size (outside the lock: formatting Data may be slow)
	messageSize := tcpEstimateMessageSize(msg)

	tt.mu.Lock()
	defer tt.mu.Unlock()

	tt.metrics.MessagesReceived++
	tt.metrics.BytesReceived += messageSize
}

// GetAddress returns the transport address as "address:port" (for
// testing/debugging). The address is not bracketed for IPv6 literals.
func (tt *TCPTransport) GetAddress() string {
	return fmt.Sprintf("%s:%d", tt.address, tt.port)
}

// GetConnectionCount returns the number of active connections.
func (tt *TCPTransport) GetConnectionCount() int {
	tt.mu.RLock()
	defer tt.mu.RUnlock()
	return len(tt.connections)
}

// IsSecure returns whether TLS is enabled.
func (tt *TCPTransport) IsSecure() bool {
	tt.mu.RLock()
	defer tt.mu.RUnlock()
	return tt.tlsConfig != nil
}