Files
mev-beta/orig/pkg/transport/websocket_transport.go
Administrator 803de231ba feat: create v2-prep branch with comprehensive planning
Restructured project for V2 refactor:

**Structure Changes:**
- Moved all V1 code to orig/ folder (preserved with git mv)
- Created docs/planning/ directory
- Added orig/README_V1.md explaining V1 preservation

**Planning Documents:**
- 00_V2_MASTER_PLAN.md: Complete architecture overview
  - Executive summary of critical V1 issues
  - High-level component architecture diagrams
  - 5-phase implementation roadmap
  - Success metrics and risk mitigation

- 07_TASK_BREAKDOWN.md: Atomic task breakdown
  - 99+ hours of detailed tasks
  - Every task < 2 hours (atomic)
  - Clear dependencies and success criteria
  - Organized by implementation phase

**V2 Key Improvements:**
- Per-exchange parsers (factory pattern)
- Multi-layer strict validation
- Multi-index pool cache
- Background validation pipeline
- Comprehensive observability

**Critical Issues Addressed:**
- Zero address tokens (strict validation + cache enrichment)
- Parsing accuracy (protocol-specific parsers)
- No audit trail (background validation channel)
- Inefficient lookups (multi-index cache)
- Stats disconnection (event-driven metrics)

Next Steps:
1. Review planning documents
2. Begin Phase 1: Foundation (P1-001 through P1-010)
3. Implement parsers in Phase 2
4. Build cache system in Phase 3
5. Add validation pipeline in Phase 4
6. Migrate and test in Phase 5

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
2025-11-10 10:14:26 +01:00

429 lines
10 KiB
Go

package transport
import (
	"context"
	"encoding/json"
	"fmt"
	"net"
	"net/http"
	"sync"
	"time"

	"github.com/gorilla/websocket"
)
// WebSocketTransport implements WebSocket transport for real-time monitoring
//
// An instance runs either as a server (listening on address:port and
// upgrading requests on path) or as a client (dialing a single server),
// selected by isServer. JSON messages read from every connection are fanned
// into receiveChan; Send writes each outgoing message to every open connection.
type WebSocketTransport struct {
	// Endpoint configuration: listen address in server mode, dial target in client mode.
	address string
	port    int
	path    string
	// upgrader performs the HTTP -> WebSocket handshake in server mode.
	upgrader websocket.Upgrader
	// connections maps a generated ID ("server_<nanos>" / "client_<nanos>")
	// to its open connection; mutated only while holding mu.
	connections map[string]*websocket.Conn
	// metrics accumulates transport counters. NOTE(review): most updates hold
	// mu, but some Errors increments in this file do not — treat as racy.
	metrics TransportMetrics
	// connected is set by Connect and cleared by Disconnect.
	connected bool
	// isServer selects server (listen/upgrade) vs client (dial) mode.
	isServer bool
	// receiveChan buffers inbound messages handed out by Receive; closed by Disconnect.
	receiveChan chan *Message
	// server is the HTTP server in server mode; nil in client mode.
	server *http.Server
	// mu guards connections, connected and metrics.
	mu sync.RWMutex
	// ctx/cancel scope the connection goroutines. cancel is called by
	// Disconnect; nothing in this file re-creates them afterwards.
	ctx    context.Context
	cancel context.CancelFunc
	// pingInterval is the period between keepalive pings; pongTimeout bounds
	// how long a connection may wait for a pong before its read deadline expires.
	pingInterval time.Duration
	pongTimeout  time.Duration
}
// NewWebSocketTransport builds an unconnected transport. With isServer set,
// Connect listens on address:port and upgrades requests on path; otherwise
// Connect dials ws://<address>:<port><path>.
//
// Defaults: a 1000-message receive buffer, 1 KiB read/write buffers, a 30s
// ping interval, a 10s pong timeout, and an origin check that accepts every
// origin (narrow it with SetAllowedOrigins).
func NewWebSocketTransport(address string, port int, path string, isServer bool) *WebSocketTransport {
	lifetimeCtx, stop := context.WithCancel(context.Background())

	wt := &WebSocketTransport{
		address:      address,
		port:         port,
		path:         path,
		isServer:     isServer,
		connections:  make(map[string]*websocket.Conn),
		receiveChan:  make(chan *Message, 1000),
		metrics:      TransportMetrics{},
		ctx:          lifetimeCtx,
		cancel:       stop,
		pingInterval: 30 * time.Second,
		pongTimeout:  10 * time.Second,
	}
	wt.upgrader = websocket.Upgrader{
		ReadBufferSize:  1024,
		WriteBufferSize: 1024,
		// Allow all origins for now.
		CheckOrigin: func(*http.Request) bool { return true },
	}
	return wt
}
// SetPingPongSettings overrides the keepalive timing: pingInterval is the
// period between pings sent on each connection, and pongTimeout bounds how
// long a connection may go without a pong before its read deadline expires.
//
// NOTE(review): these fields are written here without holding mu and are
// read by connection goroutines without it, so call this before Connect.
func (wt *WebSocketTransport) SetPingPongSettings(pingInterval, pongTimeout time.Duration) {
	wt.pingInterval, wt.pongTimeout = pingInterval, pongTimeout
}
// Connect starts the transport. In server mode it begins serving HTTP and
// upgrading WebSocket requests; in client mode it dials the configured server
// using ctx. Calling Connect on an already-connected transport is a no-op.
// The mutex is held for the whole call, including the client dial.
//
// NOTE(review): Disconnect cancels the internal context and closes the
// receive channel, and Connect does not recreate them, so reconnecting the
// same instance after Disconnect is not expected to work.
func (wt *WebSocketTransport) Connect(ctx context.Context) error {
	wt.mu.Lock()
	defer wt.mu.Unlock()

	switch {
	case wt.connected:
		return nil
	case wt.isServer:
		return wt.startServer()
	default:
		return wt.connectToServer(ctx)
	}
}
// Disconnect tears the transport down. It cancels the connection goroutines'
// context, closes and forgets every open connection, closes the receive
// channel and, in server mode, gracefully shuts the HTTP server down with a
// 5s timeout. Disconnect on a transport that is not connected is a no-op.
//
// The server shutdown runs after the mutex is released. The /health and
// /metrics handlers take a read lock, and Shutdown waits for in-flight
// requests, so shutting down while holding the write lock would stall until
// the timeout whenever such a request was being served.
//
// Any shutdown error is returned, but the transport is disconnected
// regardless. Reconnecting the same instance afterwards is not supported: the
// internal context stays cancelled and the receive channel stays closed.
func (wt *WebSocketTransport) Disconnect(ctx context.Context) error {
	wt.mu.Lock()
	if !wt.connected {
		wt.mu.Unlock()
		return nil
	}

	// Cancel first so connection goroutines observe shutdown, and keep the
	// cancel and the channel close inside the same critical section.
	wt.cancel()
	for id, conn := range wt.connections {
		conn.Close() // best-effort; the connection is being discarded
		delete(wt.connections, id)
	}
	close(wt.receiveChan)
	wt.connected = false
	wt.metrics.Connections = 0

	var server *http.Server
	if wt.isServer {
		server = wt.server
	}
	wt.mu.Unlock()

	if server == nil {
		return nil
	}
	shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()
	if err := server.Shutdown(shutdownCtx); err != nil {
		return fmt.Errorf("failed to shut down server: %w", err)
	}
	return nil
}
// Send marshals msg to JSON and writes it as a text frame to every open
// connection. It fails when the transport is not connected, when msg cannot
// be marshalled, or when there are no connections. If writing to a connection
// fails, that connection is scheduled for removal and delivery to the others
// continues. The returned error then describes one of the failed connections.
// Every failure path increments metrics.Errors.
//
// The write lock is held for the whole fan-out because gorilla/websocket
// allows at most one concurrent writer per connection. With only a read lock,
// concurrent Send calls could write to the same connection simultaneously.
func (wt *WebSocketTransport) Send(ctx context.Context, msg *Message) error {
	start := time.Now()

	// fail records an error in the metrics; it must be called without wt.mu held.
	fail := func(err error) error {
		wt.mu.Lock()
		wt.metrics.Errors++
		wt.mu.Unlock()
		return err
	}

	wt.mu.RLock()
	connected := wt.connected
	wt.mu.RUnlock()
	if !connected {
		return fail(fmt.Errorf("transport not connected"))
	}

	data, err := json.Marshal(msg)
	if err != nil {
		return fail(fmt.Errorf("failed to marshal message: %w", err))
	}

	wt.mu.Lock()
	if len(wt.connections) == 0 {
		wt.metrics.Errors++
		wt.mu.Unlock()
		return fmt.Errorf("no active connections")
	}
	var sendErr error
	for connID, conn := range wt.connections {
		if err := wt.sendToConnection(conn, data); err != nil {
			sendErr = fmt.Errorf("failed to send to connection %s: %w", connID, err)
			// removeConnection takes wt.mu, which is held here, so run it asynchronously.
			go wt.removeConnection(connID)
		}
	}
	if sendErr != nil {
		wt.metrics.Errors++
	}
	wt.mu.Unlock()

	if sendErr != nil {
		return sendErr
	}
	// updateSendMetrics acquires wt.mu itself, so it must run after Unlock.
	wt.updateSendMetrics(msg, time.Since(start))
	return nil
}
// Receive returns the transport's inbound message channel, or an error when
// the transport is not connected. Every caller gets the same channel, so
// concurrent consumers compete for messages. Disconnect closes the channel.
func (wt *WebSocketTransport) Receive(ctx context.Context) (<-chan *Message, error) {
	wt.mu.RLock()
	connected, inbox := wt.connected, wt.receiveChan
	wt.mu.RUnlock()

	if !connected {
		return nil, fmt.Errorf("transport not connected")
	}
	return inbox, nil
}
// Health reports the transport status:
//   - "healthy" when connected with at least one open connection;
//   - "degraded" when connected but with no open connections;
//   - "unhealthy" when not connected.
//
// ResponseTime is a fixed 5ms value, not a measurement. ErrorCount mirrors
// metrics.Errors.
func (wt *WebSocketTransport) Health() ComponentHealth {
	wt.mu.RLock()
	defer wt.mu.RUnlock()

	var status string
	switch {
	case !wt.connected:
		status = "unhealthy"
	case len(wt.connections) == 0:
		status = "degraded" // Connected but no active connections
	default:
		status = "healthy"
	}

	return ComponentHealth{
		Status:       status,
		LastCheck:    time.Now(),
		ResponseTime: 5 * time.Millisecond, // Very fast for WebSocket
		ErrorCount:   wt.metrics.Errors,
	}
}
// GetMetrics returns a point-in-time copy of the transport counters, taken
// under the read lock. Connections is computed from the live connection map
// rather than read from the stored counter.
func (wt *WebSocketTransport) GetMetrics() TransportMetrics {
	wt.mu.RLock()
	defer wt.mu.RUnlock()

	current := wt.metrics
	var snapshot TransportMetrics
	snapshot.BytesSent = current.BytesSent
	snapshot.BytesReceived = current.BytesReceived
	snapshot.MessagesSent = current.MessagesSent
	snapshot.MessagesReceived = current.MessagesReceived
	snapshot.Connections = len(wt.connections)
	snapshot.Errors = current.Errors
	snapshot.Latency = current.Latency
	return snapshot
}
// Private helper methods

// startServer binds the listening socket and serves HTTP in the background.
// It registers three routes: wt.path upgrades to WebSocket, while /health and
// /metrics return JSON snapshots of Health and GetMetrics. The caller
// (Connect) holds wt.mu.
//
// The socket is bound synchronously, so bind failures such as a port already
// in use or a bad address are returned to Connect. Previously they were
// swallowed by ListenAndServe in a goroutine while the transport reported
// itself connected. Errors from Serve after a successful bind are still only
// counted in metrics.Errors.
func (wt *WebSocketTransport) startServer() error {
	mux := http.NewServeMux()
	mux.HandleFunc(wt.path, wt.handleWebSocket)
	// Add health check endpoint
	mux.HandleFunc("/health", func(w http.ResponseWriter, r *http.Request) {
		w.Header().Set("Content-Type", "application/json")
		// An encode error means the client went away; nothing useful to report.
		_ = json.NewEncoder(w).Encode(wt.Health())
	})
	// Add metrics endpoint
	mux.HandleFunc("/metrics", func(w http.ResponseWriter, r *http.Request) {
		w.Header().Set("Content-Type", "application/json")
		_ = json.NewEncoder(w).Encode(wt.GetMetrics())
	})

	// JoinHostPort brackets IPv6 literals ("::1" -> "[::1]:port").
	addr := net.JoinHostPort(wt.address, fmt.Sprint(wt.port))
	listener, err := net.Listen("tcp", addr)
	if err != nil {
		return fmt.Errorf("failed to listen on %s: %w", addr, err)
	}

	server := &http.Server{
		Addr:              addr,
		Handler:           mux,
		ReadHeaderTimeout: 5 * time.Second, // Prevent Slowloris attacks
		ReadTimeout:       60 * time.Second,
		WriteTimeout:      60 * time.Second,
		IdleTimeout:       120 * time.Second,
	}
	wt.server = server
	wt.connected = true

	go func() {
		if err := server.Serve(listener); err != nil && err != http.ErrServerClosed {
			wt.mu.Lock()
			wt.metrics.Errors++
			wt.mu.Unlock()
		}
	}()
	return nil
}
// connectToServer dials ws://<address>:<port><path> with ctx, registers the
// resulting connection under a "client_<nanos>" ID, marks the transport
// connected and starts the connection's read loop. The caller (Connect) holds
// wt.mu.
func (wt *WebSocketTransport) connectToServer(ctx context.Context) error {
	// JoinHostPort brackets IPv6 literals, which a bare "%s:%d" would mangle.
	endpoint := "ws://" + net.JoinHostPort(wt.address, fmt.Sprint(wt.port)) + wt.path
	// The handshake *http.Response is only informative on failure; gorilla
	// does not require callers to close its body.
	conn, _, err := websocket.DefaultDialer.DialContext(ctx, endpoint, nil)
	if err != nil {
		return fmt.Errorf("failed to connect to WebSocket server: %w", err)
	}

	connID := fmt.Sprintf("client_%d", time.Now().UnixNano())
	wt.connections[connID] = conn
	wt.connected = true
	// Derive the count from the map, consistent with handleWebSocket/removeConnection.
	wt.metrics.Connections = len(wt.connections)

	go wt.handleConnection(connID, conn)
	return nil
}
// handleWebSocket upgrades an HTTP request to a WebSocket connection,
// registers it under a unique "server_<nanos>" ID and starts its read loop.
// A failed upgrade is counted in metrics.Errors.
func (wt *WebSocketTransport) handleWebSocket(w http.ResponseWriter, r *http.Request) {
	conn, err := wt.upgrader.Upgrade(w, r, nil)
	if err != nil {
		wt.mu.Lock()
		wt.metrics.Errors++
		wt.mu.Unlock()
		return
	}

	wt.mu.Lock()
	// Timestamps can collide when upgrades happen concurrently or the clock is
	// coarse. A collision would overwrite the earlier entry and leak its
	// connection, so disambiguate with a numeric suffix.
	base := fmt.Sprintf("server_%d", time.Now().UnixNano())
	connID := base
	for n := 1; ; n++ {
		if _, taken := wt.connections[connID]; !taken {
			break
		}
		connID = fmt.Sprintf("%s_%d", base, n)
	}
	wt.connections[connID] = conn
	wt.metrics.Connections = len(wt.connections)
	wt.mu.Unlock()

	go wt.handleConnection(connID, conn)
}
// handleConnection runs the read loop for one connection. Each JSON message
// read from conn is handed to receiveChan without blocking. If the channel is
// full, the message is dropped and counted in metrics.Errors. Any read error
// ends the loop, but only close frames whose code is not going-away or
// abnormal-closure are counted as errors. The loop also stops once the
// transport context is cancelled, and the connection is removed on return.
//
// Keepalive: pingRoutine sends a ping every pingInterval, and each pong pushes
// the read deadline forward. The deadline must therefore cover a whole ping
// period plus the pong wait (pingInterval + pongTimeout). The previous
// pongTimeout-only deadline was 10s by default against a 30s ping period, so
// it expired before the first ping was sent and dropped every connection that
// stayed idle that long.
func (wt *WebSocketTransport) handleConnection(connID string, conn *websocket.Conn) {
	defer wt.removeConnection(connID)

	readWindow := wt.pingInterval + wt.pongTimeout
	if err := conn.SetReadDeadline(time.Now().Add(readWindow)); err != nil {
		return
	}
	conn.SetPongHandler(func(string) error {
		conn.SetReadDeadline(time.Now().Add(readWindow))
		return nil
	})

	// Start ping routine
	go wt.pingRoutine(connID, conn)

	for {
		select {
		case <-wt.ctx.Done():
			return
		default:
		}

		var msg Message
		if err := conn.ReadJSON(&msg); err != nil {
			if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) {
				wt.mu.Lock()
				wt.metrics.Errors++
				wt.mu.Unlock()
			}
			return
		}

		// Deliver under the read lock after re-checking cancellation.
		// Disconnect cancels the context and closes receiveChan while holding
		// the write lock, so this guarantees we never send on a closed channel
		// (which panics). The send is non-blocking, so the lock is held briefly.
		wt.mu.RLock()
		if wt.ctx.Err() != nil {
			wt.mu.RUnlock()
			return
		}
		var delivered bool
		select {
		case wt.receiveChan <- &msg:
			delivered = true
		default:
			// Channel full, drop message.
		}
		wt.mu.RUnlock()

		// The metrics updates take the write lock, so they run after RUnlock.
		if delivered {
			wt.updateReceiveMetrics(&msg)
		} else {
			wt.mu.Lock()
			wt.metrics.Errors++
			wt.mu.Unlock()
		}
	}
}
// pingRoutine sends a WebSocket ping on conn every pingInterval. Each ping
// write has a 10s deadline. It stops when a ping write fails (likely because
// the connection was closed) or when the transport context is cancelled.
// connID is currently unused.
func (wt *WebSocketTransport) pingRoutine(connID string, conn *websocket.Conn) {
	ticker := time.NewTicker(wt.pingInterval)
	defer ticker.Stop()

	stopped := wt.ctx.Done()
	for {
		select {
		case <-stopped:
			return
		case <-ticker.C:
			deadline := time.Now().Add(10 * time.Second)
			if conn.WriteControl(websocket.PingMessage, []byte{}, deadline) != nil {
				return // Connection is likely closed
			}
		}
	}
}
// sendToConnection writes data to conn as a single text frame, with a 10s
// write deadline. gorilla/websocket permits only one concurrent writer per
// connection, so callers must serialize calls for the same conn.
// If the deadline cannot be set, it returns that error without writing.
func (wt *WebSocketTransport) sendToConnection(conn *websocket.Conn, data []byte) error {
	if err := conn.SetWriteDeadline(time.Now().Add(10 * time.Second)); err != nil {
		return fmt.Errorf("failed to set write deadline: %w", err)
	}
	return conn.WriteMessage(websocket.TextMessage, data)
}
// removeConnection closes the connection registered under connID, forgets it
// and refreshes the connection counter, all under the write lock. Unknown IDs
// are ignored, so calling it more than once for the same ID is harmless.
func (wt *WebSocketTransport) removeConnection(connID string) {
	wt.mu.Lock()
	defer wt.mu.Unlock()

	conn, ok := wt.connections[connID]
	if !ok {
		return
	}
	conn.Close()
	delete(wt.connections, connID)
	wt.metrics.Connections = len(wt.connections)
}
// updateSendMetrics records one successfully sent message. It increments
// MessagesSent, stores latency as the most recent send latency, and adds an
// estimate of the message size to BytesSent.
func (wt *WebSocketTransport) updateSendMetrics(msg *Message, latency time.Duration) {
	// Estimate before locking: formatting msg.Data with %v uses reflection and
	// should not run while every other transport operation waits on wt.mu.
	size := wt.estimateMessageSize(msg)

	wt.mu.Lock()
	defer wt.mu.Unlock()
	wt.metrics.MessagesSent++
	wt.metrics.Latency = latency
	wt.metrics.BytesSent += size
}

// updateReceiveMetrics records one delivered inbound message. It increments
// MessagesReceived and adds an estimate of the message size to BytesReceived.
func (wt *WebSocketTransport) updateReceiveMetrics(msg *Message) {
	size := wt.estimateMessageSize(msg)

	wt.mu.Lock()
	defer wt.mu.Unlock()
	wt.metrics.MessagesReceived++
	wt.metrics.BytesReceived += size
}

// estimateMessageSize approximates the size of msg as the byte lengths of its
// ID, Topic and Source plus the length of Data rendered with %v. It is an
// estimate, not the encoded wire size. A nil message counts as zero bytes
// rather than panicking.
func (wt *WebSocketTransport) estimateMessageSize(msg *Message) int64 {
	if msg == nil {
		return 0
	}
	size := int64(len(msg.ID) + len(msg.Topic) + len(msg.Source))
	if msg.Data != nil {
		size += int64(len(fmt.Sprintf("%v", msg.Data)))
	}
	return size
}
// Broadcast sends msg to every connected client using exactly the same
// semantics as Send. It is only valid in server mode; in client mode it
// returns an error without sending anything.
func (wt *WebSocketTransport) Broadcast(ctx context.Context, msg *Message) error {
	if wt.isServer {
		return wt.Send(ctx, msg)
	}
	return fmt.Errorf("broadcast only available in server mode")
}
// GetURL returns the WebSocket URL built from the configured address, port
// and path (for testing/debugging), e.g. "ws://localhost:8080/ws". IPv6
// literal addresses are bracketed ("ws://[::1]:8080/ws"); the previous
// "%s:%d" formatting produced an unusable URL for them.
func (wt *WebSocketTransport) GetURL() string {
	return "ws://" + net.JoinHostPort(wt.address, fmt.Sprint(wt.port)) + wt.path
}
// GetConnectionCount reports how many connections are currently registered,
// read under the read lock.
func (wt *WebSocketTransport) GetConnectionCount() int {
	wt.mu.RLock()
	count := len(wt.connections)
	wt.mu.RUnlock()
	return count
}
// SetAllowedOrigins replaces the upgrader's origin check. An empty or nil list
// accepts every origin. Otherwise a request is accepted only when its Origin
// header exactly matches one of origins. Requests without an Origin header
// are therefore rejected unless "" is in the list.
//
// NOTE(review): this writes wt.upgrader without holding mu, while
// handleWebSocket uses the upgrader without it; call this before Connect.
func (wt *WebSocketTransport) SetAllowedOrigins(origins []string) {
	var check func(*http.Request) bool
	if len(origins) == 0 {
		check = func(*http.Request) bool {
			return true // Allow all origins
		}
	} else {
		allowed := make(map[string]struct{}, len(origins))
		for _, o := range origins {
			allowed[o] = struct{}{}
		}
		check = func(r *http.Request) bool {
			_, ok := allowed[r.Header.Get("Origin")]
			return ok
		}
	}
	wt.upgrader.CheckOrigin = check
}