package transport

import (
	"context"
	"encoding/json"
	"fmt"
	"net"
	"os"
	"sync"
	"time"
)

const (
	// unixSocketReceiveBuffer is the capacity of the inbound message channel.
	unixSocketReceiveBuffer = 1000
	// unixSocketReadPoll bounds each blocking read so a connection handler
	// re-checks for cancellation at least this often.
	unixSocketReadPoll = time.Second
	// unixSocketWriteTimeout is the write deadline applied to every frame.
	unixSocketWriteTimeout = 5 * time.Second
	// unixSocketAcceptRetry is the pause after a failed Accept so that a
	// persistent error cannot turn the accept loop into a busy loop.
	unixSocketAcceptRetry = 10 * time.Millisecond
)

// UnixSocketTransport implements Unix socket transport for local IPC.
//
// In server mode it listens on socketPath and broadcasts every sent message
// to all accepted connections; in client mode it holds a single connection
// to the server. Inbound messages from every connection are funnelled into
// one buffered channel exposed by Receive.
type UnixSocketTransport struct {
	socketPath  string
	listener    net.Listener
	connections map[string]net.Conn
	metrics     TransportMetrics
	connected   bool
	isServer    bool
	receiveChan chan *Message

	// mu guards listener, connections, metrics, connected and receiveChan.
	mu sync.RWMutex
	// lifecycleMu serialises Connect and Disconnect so that a reconnect can
	// never overlap a teardown that is still waiting for goroutines to exit.
	// Lock order is always lifecycleMu before mu.
	lifecycleMu sync.Mutex
	// wg tracks the accept loop and every connection handler, so Disconnect
	// can wait for all senders before closing receiveChan.
	wg sync.WaitGroup

	// ctx is cancelled by Disconnect to stop background goroutines. It is
	// replaced by Connect when the transport is connected again.
	ctx    context.Context
	cancel context.CancelFunc
}

// NewUnixSocketTransport creates a new Unix socket transport bound to
// socketPath. When isServer is true, Connect listens on the socket;
// otherwise Connect dials it. No I/O happens until Connect is called.
func NewUnixSocketTransport(socketPath string, isServer bool) *UnixSocketTransport {
	ctx, cancel := context.WithCancel(context.Background())
	return &UnixSocketTransport{
		socketPath:  socketPath,
		connections: make(map[string]net.Conn),
		metrics:     TransportMetrics{},
		isServer:    isServer,
		receiveChan: make(chan *Message, unixSocketReceiveBuffer),
		ctx:         ctx,
		cancel:      cancel,
	}
}

// Connect establishes the Unix socket connection: it starts listening in
// server mode or dials the server in client mode. Calling Connect on an
// already connected transport is a no-op.
//
// Connect may be called again after Disconnect; in that case a fresh
// receive channel is created and must be obtained through Receive.
// The ctx argument is not consulted by the current implementation.
func (ut *UnixSocketTransport) Connect(ctx context.Context) error {
	ut.lifecycleMu.Lock()
	defer ut.lifecycleMu.Unlock()

	ut.mu.Lock()
	defer ut.mu.Unlock()

	if ut.connected {
		return nil
	}

	// A cancelled context means Disconnect ran earlier: the old channel is
	// closed and the old context would stop new goroutines immediately, so
	// both are replaced. Disconnect has already waited for every goroutine
	// of the previous session, so nothing still reads the old values.
	if ut.ctx.Err() != nil {
		ut.ctx, ut.cancel = context.WithCancel(context.Background())
		ut.receiveChan = make(chan *Message, unixSocketReceiveBuffer)
	}

	if ut.isServer {
		return ut.startServer()
	}
	return ut.connectToServer()
}

// Disconnect closes the Unix socket connection. It stops the background
// goroutines, closes the listener (server mode) and every connection, waits
// for the goroutines to exit and only then closes the receive channel, so
// no goroutine can send on a closed channel. Calling Disconnect on a
// transport that is not connected is a no-op.
//
// Close errors are ignored: teardown is best-effort. The ctx argument is
// not consulted by the current implementation.
func (ut *UnixSocketTransport) Disconnect(ctx context.Context) error {
	ut.lifecycleMu.Lock()
	defer ut.lifecycleMu.Unlock()

	ut.mu.Lock()
	if !ut.connected {
		ut.mu.Unlock()
		return nil
	}

	ut.cancel()

	if ut.isServer && ut.listener != nil {
		_ = ut.listener.Close()
		// Remove socket file; it may already be gone, so the error is ignored.
		_ = os.Remove(ut.socketPath)
	}

	// Close all connections; this also unblocks handlers waiting in Read.
	for id, conn := range ut.connections {
		_ = conn.Close()
		delete(ut.connections, id)
	}

	ut.connected = false
	ut.metrics.Connections = 0
	out := ut.receiveChan
	ut.mu.Unlock()

	// The lock must be released before waiting: handlers take ut.mu while
	// updating metrics and while removing themselves from the map.
	ut.wg.Wait()
	close(out)

	return nil
}

// Send transmits a message through the Unix socket. The message is encoded
// as JSON and framed as "<decimal length>\n<payload>", then written to
// every active connection.
//
// Send returns an error when msg is nil, the transport is not connected,
// the message cannot be marshalled, there are no active connections, or a
// write fails. A connection whose write fails is closed and removed; when
// several writes fail, the error of the last failure is returned. Send
// metrics are updated only when every write succeeded. The ctx argument is
// not consulted by the current implementation.
func (ut *UnixSocketTransport) Send(ctx context.Context, msg *Message) error {
	start := time.Now()

	if msg == nil {
		ut.recordError()
		return fmt.Errorf("cannot send nil message")
	}

	// Snapshot the connections so the potentially slow writes below do not
	// hold the lock and stall accept, removal and metrics readers.
	type target struct {
		id   string
		conn net.Conn
	}
	ut.mu.RLock()
	connected := ut.connected
	targets := make([]target, 0, len(ut.connections))
	for id, conn := range ut.connections {
		targets = append(targets, target{id: id, conn: conn})
	}
	ut.mu.RUnlock()

	if !connected {
		ut.recordError()
		return fmt.Errorf("transport not connected")
	}

	// Serialize message
	data, err := json.Marshal(msg)
	if err != nil {
		ut.recordError()
		return fmt.Errorf("failed to marshal message: %w", err)
	}

	if len(targets) == 0 {
		ut.recordError()
		return fmt.Errorf("no active connections")
	}

	// Add length prefix for framing
	frame := []byte(fmt.Sprintf("%d\n%s", len(data), data))

	// Send to all connections
	var sendErr error
	for _, t := range targets {
		if err := ut.sendToConnection(t.conn, frame); err != nil {
			sendErr = fmt.Errorf("failed to send to connection %s: %w", t.id, err)
			// Remove failed connection
			ut.removeConnection(t.id)
		}
	}

	if sendErr != nil {
		ut.recordError()
		return sendErr
	}

	ut.updateSendMetrics(msg, time.Since(start))
	return nil
}

// Receive returns a channel for receiving messages. The channel is closed
// by Disconnect. It returns an error when the transport is not connected.
// The ctx argument is not consulted by the current implementation.
func (ut *UnixSocketTransport) Receive(ctx context.Context) (<-chan *Message, error) {
	ut.mu.RLock()
	defer ut.mu.RUnlock()

	if !ut.connected {
		return nil, fmt.Errorf("transport not connected")
	}

	return ut.receiveChan, nil
}

// Health returns the health status of the transport: "unhealthy" when not
// connected, "degraded" when connected without any active connection and
// "healthy" otherwise. ResponseTime is a fixed one-millisecond placeholder,
// not a measurement.
func (ut *UnixSocketTransport) Health() ComponentHealth {
	ut.mu.RLock()
	defer ut.mu.RUnlock()

	status := "unhealthy"
	switch {
	case ut.connected && len(ut.connections) > 0:
		status = "healthy"
	case ut.connected:
		status = "degraded" // Connected but no active connections
	}

	return ComponentHealth{
		Status:       status,
		LastCheck:    time.Now(),
		ResponseTime: time.Millisecond, // Fast for local sockets
		ErrorCount:   ut.metrics.Errors,
	}
}

// GetMetrics returns transport-specific metrics. Connections reflects the
// current size of the connection table rather than the stored counter.
func (ut *UnixSocketTransport) GetMetrics() TransportMetrics {
	ut.mu.RLock()
	defer ut.mu.RUnlock()

	return TransportMetrics{
		BytesSent:        ut.metrics.BytesSent,
		BytesReceived:    ut.metrics.BytesReceived,
		MessagesSent:     ut.metrics.MessagesSent,
		MessagesReceived: ut.metrics.MessagesReceived,
		Connections:      len(ut.connections),
		Errors:           ut.metrics.Errors,
		Latency:          ut.metrics.Latency,
	}
}

// Private helper methods

// startServer binds the listening socket and launches the accept loop.
// The caller must hold ut.mu.
func (ut *UnixSocketTransport) startServer() error {
	// Remove existing socket file so a stale one does not make Listen fail.
	// A missing file is the normal case, so the error is ignored.
	_ = os.Remove(ut.socketPath)

	listener, err := net.Listen("unix", ut.socketPath)
	if err != nil {
		return fmt.Errorf("failed to listen on socket %s: %w", ut.socketPath, err)
	}

	ut.listener = listener
	ut.connected = true

	// Start accepting connections
	ut.wg.Add(1)
	go ut.acceptConnections()

	return nil
}

// connectToServer dials the server socket and launches the handler that
// reads inbound messages. The caller must hold ut.mu.
func (ut *UnixSocketTransport) connectToServer() error {
	conn, err := net.Dial("unix", ut.socketPath)
	if err != nil {
		return fmt.Errorf("failed to connect to socket %s: %w", ut.socketPath, err)
	}

	connID := ut.newConnIDLocked("client")
	ut.connections[connID] = conn
	ut.connected = true
	ut.metrics.Connections = 1

	// Start receiving from server
	ut.wg.Add(1)
	go ut.handleConnection(connID, conn)

	return nil
}

// newConnIDLocked returns a connection ID of the form "<prefix>_<unix nanos>"
// that is not yet present in the connection table; the timestamp is bumped
// on a collision so an existing entry is never overwritten.
// The caller must hold ut.mu.
func (ut *UnixSocketTransport) newConnIDLocked(prefix string) string {
	for stamp := time.Now().UnixNano(); ; stamp++ {
		id := fmt.Sprintf("%s_%d", prefix, stamp)
		if _, taken := ut.connections[id]; !taken {
			return id
		}
	}
}

// acceptConnections accepts incoming connections until the transport is
// disconnected, registering each one and starting a handler for it.
// It must be started with ut.wg already incremented.
func (ut *UnixSocketTransport) acceptConnections() {
	defer ut.wg.Done()

	// Captured once: Connect replaces these fields only after Disconnect has
	// waited for this goroutine to exit.
	ctx, listener := ut.ctx, ut.listener

	for {
		if ctx.Err() != nil {
			return
		}

		conn, err := listener.Accept()
		if err != nil {
			if ctx.Err() != nil {
				return // Context cancelled
			}
			ut.recordError()
			// Pause before retrying so a persistent failure cannot spin.
			select {
			case <-ctx.Done():
				return
			case <-time.After(unixSocketAcceptRetry):
			}
			continue
		}

		ut.mu.Lock()
		if ctx.Err() != nil {
			// Disconnect started while Accept was returning; do not register
			// a connection that nothing would ever close.
			ut.mu.Unlock()
			_ = conn.Close()
			return
		}
		connID := ut.newConnIDLocked("server")
		ut.connections[connID] = conn
		ut.metrics.Connections = len(ut.connections)
		// Safe while Disconnect may be waiting: this goroutine is itself
		// counted in wg, so the counter cannot be zero here.
		ut.wg.Add(1)
		ut.mu.Unlock()

		go ut.handleConnection(connID, conn)
	}
}

// handleConnection reads from conn until it fails, is closed, carries an
// invalid frame, or the transport is disconnected, delivering every
// complete message to the receive channel. On exit the connection is
// closed and removed from the connection table.
// It must be started with ut.wg already incremented.
func (ut *UnixSocketTransport) handleConnection(connID string, conn net.Conn) {
	// Deferred calls run last-in first-out: the connection is removed first,
	// then the WaitGroup is released.
	defer ut.wg.Done()
	defer ut.removeConnection(connID)

	// Captured once: Connect replaces these fields only after Disconnect has
	// waited for this goroutine to exit.
	ctx, out := ut.ctx, ut.receiveChan

	buffer := make([]byte, 4096)
	var pending []byte

	for {
		if ctx.Err() != nil {
			return
		}

		// Bounded read so cancellation is noticed even on an idle connection.
		_ = conn.SetReadDeadline(time.Now().Add(unixSocketReadPoll))
		n, err := conn.Read(buffer)

		// Process any bytes read before looking at the error, as the
		// io.Reader contract requires.
		if n > 0 {
			pending = append(pending, buffer[:n]...)
			var ok bool
			if pending, ok = ut.deliverMessages(ctx, out, pending); !ok {
				return
			}
		}

		if err != nil {
			if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
				continue // Continue on timeout
			}
			return // Connection closed or error
		}
	}
}

// deliverMessages extracts every complete message from pending and offers
// it to out without blocking; a message is dropped and counted as an error
// when the channel is full. It returns the unconsumed bytes and false when
// the handler must stop (invalid frame or cancelled context).
func (ut *UnixSocketTransport) deliverMessages(ctx context.Context, out chan<- *Message, pending []byte) ([]byte, bool) {
	for {
		msg, remaining, err := ExtractMessage(pending)
		if err != nil {
			return pending, false // Invalid message format
		}
		if msg == nil {
			return pending, true // No complete message yet
		}

		// Deliver message
		select {
		case out <- msg:
			ut.updateReceiveMetrics(msg)
		case <-ctx.Done():
			return pending, false
		default:
			// Channel full, drop message
			ut.recordError()
		}

		pending = remaining
	}
}

// sendToConnection writes one frame to conn under the write deadline
// unixSocketWriteTimeout and returns the first error encountered.
func (ut *UnixSocketTransport) sendToConnection(conn net.Conn, data []byte) error {
	if err := conn.SetWriteDeadline(time.Now().Add(unixSocketWriteTimeout)); err != nil {
		return err
	}
	_, err := conn.Write(data)
	return err
}

// removeConnection closes the connection registered under connID, if any,
// and deletes it from the connection table. It acquires ut.mu and must not
// be called with the lock held.
func (ut *UnixSocketTransport) removeConnection(connID string) {
	ut.mu.Lock()
	defer ut.mu.Unlock()

	if conn, exists := ut.connections[connID]; exists {
		_ = conn.Close()
		delete(ut.connections, connID)
		ut.metrics.Connections = len(ut.connections)
	}
}

// recordError increments the error counter under the lock. It acquires
// ut.mu and must not be called with the lock held.
func (ut *UnixSocketTransport) recordError() {
	ut.mu.Lock()
	ut.metrics.Errors++
	ut.mu.Unlock()
}

// updateSendMetrics records one sent message, its latency and an estimate
// of its size. The size is an approximation built from the lengths of the
// ID, Topic and Source fields plus the default formatting of Data; it is
// not the number of bytes written to the socket.
func (ut *UnixSocketTransport) updateSendMetrics(msg *Message, latency time.Duration) {
	ut.mu.Lock()
	defer ut.mu.Unlock()

	ut.metrics.MessagesSent++
	ut.metrics.Latency = latency

	// Estimate message size
	messageSize := int64(len(msg.ID) + len(msg.Topic) + len(msg.Source))
	if msg.Data != nil {
		messageSize += int64(len(fmt.Sprintf("%v", msg.Data)))
	}
	ut.metrics.BytesSent += messageSize
}

// updateReceiveMetrics records one received message and an estimate of its
// size, computed the same way as in updateSendMetrics.
func (ut *UnixSocketTransport) updateReceiveMetrics(msg *Message) {
	ut.mu.Lock()
	defer ut.mu.Unlock()

	ut.metrics.MessagesReceived++

	// Estimate message size
	messageSize := int64(len(msg.ID) + len(msg.Topic) + len(msg.Source))
	if msg.Data != nil {
		messageSize += int64(len(fmt.Sprintf("%v", msg.Data)))
	}
	ut.metrics.BytesReceived += messageSize
}

// GetSocketPath returns the socket path (for testing/debugging).
func (ut *UnixSocketTransport) GetSocketPath() string {
	return ut.socketPath
}

// GetConnectionCount returns the number of active connections.
func (ut *UnixSocketTransport) GetConnectionCount() int {
	ut.mu.RLock()
	defer ut.mu.RUnlock()
	return len(ut.connections)
}