package transport

import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"net"
	"net/http"
	"sync"
	"time"

	"github.com/gorilla/websocket"
)

// WebSocketTransport implements WebSocket transport for real-time monitoring.
//
// In server mode it listens on address:port and accepts any number of
// clients on path; in client mode it dials a single server at the same URL.
// Send/Broadcast write every outgoing message to all active connections, and
// messages read from any connection are delivered on the single channel
// returned by Receive.
//
// All mutable state is guarded by mu. Data-frame writes are additionally
// serialized by writeMu because a gorilla/websocket Conn supports at most one
// concurrent writer.
type WebSocketTransport struct {
	address      string
	port         int
	path         string
	upgrader     websocket.Upgrader
	connections  map[string]*websocket.Conn
	connSeq      uint64 // makes connection IDs unique; guarded by mu
	metrics      TransportMetrics
	connected    bool
	isServer     bool
	receiveChan  chan *Message
	server       *http.Server
	mu           sync.RWMutex
	writeMu      sync.Mutex // serializes data-frame writes across concurrent Send calls
	ctx          context.Context
	cancel       context.CancelFunc
	pingInterval time.Duration
	pongTimeout  time.Duration
}

// NewWebSocketTransport creates a new WebSocket transport for the URL
// ws://address:port followed by path.
//
// When isServer is true, Connect serves the WebSocket endpoint at path plus
// /health and /metrics on address:port; otherwise Connect dials that URL.
// All origins are accepted until SetAllowedOrigins is called. Connections are
// pinged every 30s with a 10s grace period for the pong, and up to 1000
// received messages are buffered.
func NewWebSocketTransport(address string, port int, path string, isServer bool) *WebSocketTransport {
	ctx, cancel := context.WithCancel(context.Background())
	return &WebSocketTransport{
		address:     address,
		port:        port,
		path:        path,
		connections: make(map[string]*websocket.Conn),
		metrics:     TransportMetrics{},
		isServer:    isServer,
		receiveChan: make(chan *Message, 1000),
		ctx:         ctx,
		cancel:      cancel,
		upgrader: websocket.Upgrader{
			CheckOrigin: func(r *http.Request) bool {
				return true // Allow all origins for now
			},
			ReadBufferSize:  1024,
			WriteBufferSize: 1024,
		},
		pingInterval: 30 * time.Second,
		pongTimeout:  10 * time.Second,
	}
}

// SetPingPongSettings configures WebSocket keep-alive.
//
// Each connection is pinged every pingInterval. A connection from which
// nothing (no pong, no message) has been read for pingInterval+pongTimeout is
// closed. A non-positive pingInterval disables both pings and the read
// deadline. The settings apply to connections established after the call.
func (wt *WebSocketTransport) SetPingPongSettings(pingInterval, pongTimeout time.Duration) {
	wt.mu.Lock()
	defer wt.mu.Unlock()
	wt.pingInterval = pingInterval
	wt.pongTimeout = pongTimeout
}

// Connect starts the server (server mode) or dials the server (client mode).
// It is a no-op if the transport is already connected. After a Disconnect,
// Connect starts a fresh session with a new receive channel, so callers must
// call Receive again to obtain it.
func (wt *WebSocketTransport) Connect(ctx context.Context) error {
	wt.mu.Lock()
	defer wt.mu.Unlock()

	if wt.connected {
		return nil
	}

	// Disconnect cancels the session context and closes receiveChan. Both
	// must be replaced, or the new session's readers would exit immediately
	// and Receive would hand out a closed channel.
	if wt.ctx.Err() != nil {
		wt.ctx, wt.cancel = context.WithCancel(context.Background())
		wt.receiveChan = make(chan *Message, cap(wt.receiveChan))
	}

	if wt.isServer {
		return wt.startServer()
	}
	return wt.connectToServer(ctx)
}

// Disconnect closes every connection, closes the channel returned by Receive
// and, in server mode, shuts the HTTP server down (waiting at most 5 seconds
// for in-flight requests). It is a no-op if the transport is not connected.
// A server shutdown failure is returned as an error.
func (wt *WebSocketTransport) Disconnect(ctx context.Context) error {
	wt.mu.Lock()
	if !wt.connected {
		wt.mu.Unlock()
		return nil
	}

	wt.cancel()
	for id, conn := range wt.connections {
		conn.Close()
		delete(wt.connections, id)
	}
	// Safe to close: readers only send on receiveChan while holding
	// mu.RLock and after confirming their connection is still registered,
	// and all connections were just unregistered under the write lock.
	close(wt.receiveChan)
	wt.connected = false
	wt.metrics.Connections = 0

	srv := wt.server
	wt.server = nil
	wt.mu.Unlock()

	// Shut down without holding mu: the /health and /metrics handlers need
	// mu, so Shutdown would otherwise wait on them until its timeout.
	if wt.isServer && srv != nil {
		shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
		defer cancel()
		if err := srv.Shutdown(shutdownCtx); err != nil {
			return fmt.Errorf("failed to shut down WebSocket server: %w", err)
		}
	}
	return nil
}

// Send marshals msg as JSON and writes it as a text frame to every active
// connection.
//
// It fails if the transport is not connected, msg cannot be marshaled, or
// there are no active connections. A connection whose write fails is closed
// and removed; the remaining connections still receive the message and the
// last write failure is returned. Send metrics are only updated when every
// write succeeds.
func (wt *WebSocketTransport) Send(ctx context.Context, msg *Message) error {
	start := time.Now()

	// Snapshot the connections so network writes happen without holding mu.
	wt.mu.RLock()
	connected := wt.connected
	targets := make(map[string]*websocket.Conn, len(wt.connections))
	for id, conn := range wt.connections {
		targets[id] = conn
	}
	wt.mu.RUnlock()

	if !connected {
		wt.recordError()
		return fmt.Errorf("transport not connected")
	}

	data, err := json.Marshal(msg)
	if err != nil {
		wt.recordError()
		return fmt.Errorf("failed to marshal message: %w", err)
	}

	if len(targets) == 0 {
		wt.recordError()
		return fmt.Errorf("no active connections")
	}

	// gorilla/websocket permits only one concurrent writer per Conn, so
	// concurrent Send calls must not interleave their writes.
	var sendErr error
	var failed []string
	wt.writeMu.Lock()
	for id, conn := range targets {
		if err := wt.sendToConnection(conn, data); err != nil {
			sendErr = fmt.Errorf("failed to send to connection %s: %w", id, err)
			failed = append(failed, id)
		}
	}
	wt.writeMu.Unlock()

	for _, id := range failed {
		wt.removeConnection(id)
	}

	if sendErr != nil {
		wt.recordError()
		return sendErr
	}
	wt.updateSendMetrics(msg, time.Since(start))
	return nil
}

// Receive returns the channel on which messages read from all connections
// are delivered. The channel is closed by Disconnect. It fails if the
// transport is not connected.
func (wt *WebSocketTransport) Receive(ctx context.Context) (<-chan *Message, error) {
	wt.mu.RLock()
	defer wt.mu.RUnlock()

	if !wt.connected {
		return nil, fmt.Errorf("transport not connected")
	}

	return wt.receiveChan, nil
}

// Health reports "healthy" when connected with at least one active
// connection, "degraded" when connected with none, and "unhealthy" when not
// connected. ResponseTime is a fixed nominal value, not a measurement.
func (wt *WebSocketTransport) Health() ComponentHealth {
	wt.mu.RLock()
	defer wt.mu.RUnlock()

	status := "unhealthy"
	if wt.connected {
		if len(wt.connections) > 0 {
			status = "healthy"
		} else {
			status = "degraded" // Connected but no active connections
		}
	}

	return ComponentHealth{
		Status:       status,
		LastCheck:    time.Now(),
		ResponseTime: time.Millisecond * 5, // Very fast for WebSocket
		ErrorCount:   wt.metrics.Errors,
	}
}

// GetMetrics returns a snapshot of the transport metrics; Connections is the
// current number of registered connections.
func (wt *WebSocketTransport) GetMetrics() TransportMetrics {
	wt.mu.RLock()
	defer wt.mu.RUnlock()

	return TransportMetrics{
		BytesSent:        wt.metrics.BytesSent,
		BytesReceived:    wt.metrics.BytesReceived,
		MessagesSent:     wt.metrics.MessagesSent,
		MessagesReceived: wt.metrics.MessagesReceived,
		Connections:      len(wt.connections),
		Errors:           wt.metrics.Errors,
		Latency:          wt.metrics.Latency,
	}
}

// Private helper methods

// startServer binds address:port and serves the WebSocket endpoint plus
// /health and /metrics in a background goroutine. The caller must hold mu.
// Binding is done synchronously so that errors such as "address already in
// use" are returned from Connect instead of being silently counted.
func (wt *WebSocketTransport) startServer() error {
	mux := http.NewServeMux()
	mux.HandleFunc(wt.path, wt.handleWebSocket)

	// Add health check endpoint
	mux.HandleFunc("/health", func(w http.ResponseWriter, r *http.Request) {
		health := wt.Health()
		w.Header().Set("Content-Type", "application/json")
		// The response has already started; a failed write cannot be reported.
		_ = json.NewEncoder(w).Encode(health)
	})

	// Add metrics endpoint
	mux.HandleFunc("/metrics", func(w http.ResponseWriter, r *http.Request) {
		metrics := wt.GetMetrics()
		w.Header().Set("Content-Type", "application/json")
		// The response has already started; a failed write cannot be reported.
		_ = json.NewEncoder(w).Encode(metrics)
	})

	addr := fmt.Sprintf("%s:%d", wt.address, wt.port)
	ln, err := net.Listen("tcp", addr)
	if err != nil {
		return fmt.Errorf("failed to listen on %s: %w", addr, err)
	}

	srv := &http.Server{
		Addr:         addr,
		Handler:      mux,
		ReadTimeout:  60 * time.Second,
		WriteTimeout: 60 * time.Second,
		IdleTimeout:  120 * time.Second,
	}
	wt.server = srv
	wt.connected = true

	go func() {
		if err := srv.Serve(ln); err != nil && !errors.Is(err, http.ErrServerClosed) {
			wt.recordError()
		}
	}()

	return nil
}

// connectToServer dials the configured URL and registers the resulting
// connection. The caller must hold mu.
func (wt *WebSocketTransport) connectToServer(ctx context.Context) error {
	target := wt.GetURL()
	conn, _, err := websocket.DefaultDialer.DialContext(ctx, target, nil)
	if err != nil {
		return fmt.Errorf("failed to connect to WebSocket server: %w", err)
	}

	connID := wt.newConnID("client")
	wt.connections[connID] = conn
	wt.connected = true
	wt.metrics.Connections = len(wt.connections)

	go wt.handleConnection(connID, conn)
	return nil
}

// handleWebSocket upgrades an incoming HTTP request and registers the
// resulting connection, unless the transport was disconnected meanwhile.
func (wt *WebSocketTransport) handleWebSocket(w http.ResponseWriter, r *http.Request) {
	// Copy the upgrader so SetAllowedOrigins may change it concurrently.
	wt.mu.RLock()
	upgrader := wt.upgrader
	wt.mu.RUnlock()

	conn, err := upgrader.Upgrade(w, r, nil)
	if err != nil {
		wt.recordError()
		return
	}

	wt.mu.Lock()
	if !wt.connected {
		// Disconnect ran while this request was being upgraded; registering
		// the connection now would leak it past teardown.
		wt.mu.Unlock()
		conn.Close()
		return
	}
	connID := wt.newConnID("server")
	wt.connections[connID] = conn
	wt.metrics.Connections = len(wt.connections)
	wt.mu.Unlock()

	go wt.handleConnection(connID, conn)
}

// handleConnection reads JSON messages from conn and delivers them to the
// receive channel until the read fails, the connection is unregistered, or
// the session is cancelled. The connection is removed when it returns.
func (wt *WebSocketTransport) handleConnection(connID string, conn *websocket.Conn) {
	defer wt.removeConnection(connID)

	wt.mu.RLock()
	ctx := wt.ctx
	pingInterval, pongTimeout := wt.pingInterval, wt.pongTimeout
	wt.mu.RUnlock()

	// The peer gets a full ping period plus the pong grace period. Using
	// pongTimeout alone (shorter than pingInterval by default) would drop
	// every idle connection before its first ping was even sent.
	var readWait time.Duration
	if pingInterval > 0 {
		readWait = pingInterval + pongTimeout
	}
	extendDeadline := func() {
		if readWait > 0 {
			conn.SetReadDeadline(time.Now().Add(readWait))
		}
	}
	extendDeadline()
	conn.SetPongHandler(func(string) error {
		extendDeadline()
		return nil
	})

	go wt.pingRoutine(connID, conn)

	for ctx.Err() == nil {
		var msg Message
		if err := conn.ReadJSON(&msg); err != nil {
			if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) {
				wt.recordError()
			}
			return
		}
		extendDeadline() // any received frame proves the peer is alive

		delivered, active := wt.deliver(connID, conn, &msg)
		switch {
		case !active:
			return
		case delivered:
			wt.updateReceiveMetrics(&msg)
		default:
			wt.recordError() // Channel full, message dropped
		}
	}
}

// deliver offers msg to the receive channel without blocking. active reports
// whether conn is still registered; if it is not, Disconnect may already have
// closed the channel, so nothing is sent. Holding mu.RLock across the
// non-blocking send is what makes Disconnect's close(receiveChan) safe.
func (wt *WebSocketTransport) deliver(connID string, conn *websocket.Conn, msg *Message) (delivered, active bool) {
	wt.mu.RLock()
	defer wt.mu.RUnlock()

	if wt.connections[connID] != conn {
		return false, false
	}
	select {
	case wt.receiveChan <- msg:
		return true, true
	default:
		return false, true
	}
}

// pingRoutine sends a ping every ping interval until the connection is
// unregistered, a ping write fails, or the session is cancelled. It returns
// immediately when pings are disabled (non-positive interval), since
// time.NewTicker panics on non-positive durations.
func (wt *WebSocketTransport) pingRoutine(connID string, conn *websocket.Conn) {
	wt.mu.RLock()
	ctx := wt.ctx
	interval := wt.pingInterval
	wt.mu.RUnlock()

	if interval <= 0 {
		return
	}

	ticker := time.NewTicker(interval)
	defer ticker.Stop()

	for {
		select {
		case <-ticker.C:
			if !wt.hasConnection(connID, conn) {
				return
			}
			// WriteControl may run concurrently with Send's data writes.
			if err := conn.WriteControl(websocket.PingMessage, []byte{}, time.Now().Add(10*time.Second)); err != nil {
				return // Connection is likely closed
			}
		case <-ctx.Done():
			return
		}
	}
}

// hasConnection reports whether conn is still registered under connID.
func (wt *WebSocketTransport) hasConnection(connID string, conn *websocket.Conn) bool {
	wt.mu.RLock()
	defer wt.mu.RUnlock()
	return wt.connections[connID] == conn
}

// sendToConnection writes data as a single text frame with a 10s deadline.
// Callers must serialize calls for the same conn (Send uses writeMu).
func (wt *WebSocketTransport) sendToConnection(conn *websocket.Conn, data []byte) error {
	conn.SetWriteDeadline(time.Now().Add(10 * time.Second))
	return conn.WriteMessage(websocket.TextMessage, data)
}

// removeConnection closes and unregisters the connection with the given ID,
// if it is still registered. The caller must not hold mu.
func (wt *WebSocketTransport) removeConnection(connID string) {
	wt.mu.Lock()
	defer wt.mu.Unlock()

	if conn, exists := wt.connections[connID]; exists {
		conn.Close()
		delete(wt.connections, connID)
		wt.metrics.Connections = len(wt.connections)
	}
}

// newConnID returns a unique ID for a new connection. The caller must hold
// mu for writing. The sequence number keeps IDs unique even when two
// connections register within the same clock tick, which would otherwise
// overwrite (and leak) the earlier connection in the map.
func (wt *WebSocketTransport) newConnID(prefix string) string {
	wt.connSeq++
	return fmt.Sprintf("%s_%d_%d", prefix, time.Now().UnixNano(), wt.connSeq)
}

// recordError increments the error counter. The caller must not hold mu.
func (wt *WebSocketTransport) recordError() {
	wt.mu.Lock()
	wt.metrics.Errors++
	wt.mu.Unlock()
}

// updateSendMetrics records one successfully sent message and its latency.
func (wt *WebSocketTransport) updateSendMetrics(msg *Message, latency time.Duration) {
	wt.mu.Lock()
	defer wt.mu.Unlock()

	wt.metrics.MessagesSent++
	wt.metrics.Latency = latency
	wt.metrics.BytesSent += wt.estimateMessageSize(msg)
}

// updateReceiveMetrics records one message delivered to the receive channel.
func (wt *WebSocketTransport) updateReceiveMetrics(msg *Message) {
	wt.mu.Lock()
	defer wt.mu.Unlock()

	wt.metrics.MessagesReceived++
	wt.metrics.BytesReceived += wt.estimateMessageSize(msg)
}

// estimateMessageSize approximates a message's size as the combined length of
// ID, Topic and Source plus the length of Data's %v rendering. It is an
// estimate, not the encoded wire size.
func (wt *WebSocketTransport) estimateMessageSize(msg *Message) int64 {
	size := int64(len(msg.ID) + len(msg.Topic) + len(msg.Source))
	if msg.Data != nil {
		size += int64(len(fmt.Sprintf("%v", msg.Data)))
	}
	return size
}

// Broadcast sends a message to all connected clients (server mode only).
func (wt *WebSocketTransport) Broadcast(ctx context.Context, msg *Message) error {
	if !wt.isServer {
		return fmt.Errorf("broadcast only available in server mode")
	}
	return wt.Send(ctx, msg)
}

// GetURL returns the WebSocket URL (for testing/debugging).
func (wt *WebSocketTransport) GetURL() string {
	return fmt.Sprintf("ws://%s:%d%s", wt.address, wt.port, wt.path)
}

// GetConnectionCount returns the number of active connections.
func (wt *WebSocketTransport) GetConnectionCount() int {
	wt.mu.RLock()
	defer wt.mu.RUnlock()
	return len(wt.connections)
}

// SetAllowedOrigins restricts WebSocket upgrades to requests whose Origin
// header exactly matches one of origins. An empty list allows all origins.
// The list is copied, so later changes to the slice have no effect.
func (wt *WebSocketTransport) SetAllowedOrigins(origins []string) {
	checkOrigin := func(r *http.Request) bool {
		return true // Allow all origins
	}
	if len(origins) > 0 {
		allowed := make(map[string]bool, len(origins))
		for _, origin := range origins {
			allowed[origin] = true
		}
		checkOrigin = func(r *http.Request) bool {
			return allowed[r.Header.Get("Origin")]
		}
	}

	wt.mu.Lock()
	wt.upgrader.CheckOrigin = checkOrigin
	wt.mu.Unlock()
}