package transport import ( "context" "fmt" "sync" "time" ) // MemoryTransport implements in-memory message transport for local communication type MemoryTransport struct { channels map[string]chan *Message metrics TransportMetrics connected bool mu sync.RWMutex } // NewMemoryTransport creates a new in-memory transport func NewMemoryTransport() *MemoryTransport { return &MemoryTransport{ channels: make(map[string]chan *Message), metrics: TransportMetrics{}, } } // Connect establishes the transport connection func (mt *MemoryTransport) Connect(ctx context.Context) error { mt.mu.Lock() defer mt.mu.Unlock() if mt.connected { return nil } mt.connected = true mt.metrics.Connections = 1 return nil } // Disconnect closes the transport connection func (mt *MemoryTransport) Disconnect(ctx context.Context) error { mt.mu.Lock() defer mt.mu.Unlock() if !mt.connected { return nil } // Close all channels for _, ch := range mt.channels { close(ch) } mt.channels = make(map[string]chan *Message) mt.connected = false mt.metrics.Connections = 0 return nil } // Send transmits a message through the memory transport func (mt *MemoryTransport) Send(ctx context.Context, msg *Message) error { start := time.Now() mt.mu.RLock() if !mt.connected { mt.mu.RUnlock() mt.metrics.Errors++ return fmt.Errorf("transport not connected") } // Get or create channel for topic ch, exists := mt.channels[msg.Topic] if !exists { mt.mu.RUnlock() mt.mu.Lock() // Double-check after acquiring write lock if ch, exists = mt.channels[msg.Topic]; !exists { ch = make(chan *Message, 1000) // Buffered channel mt.channels[msg.Topic] = ch } mt.mu.Unlock() } else { mt.mu.RUnlock() } // Send message select { case ch <- msg: mt.updateSendMetrics(msg, time.Since(start)) return nil case <-ctx.Done(): mt.metrics.Errors++ return ctx.Err() default: mt.metrics.Errors++ return fmt.Errorf("channel full for topic: %s", msg.Topic) } } // Receive returns a channel for receiving messages func (mt *MemoryTransport) Receive(ctx context.Context) (<-chan *Message, error) { mt.mu.RLock() defer mt.mu.RUnlock() if !mt.connected { return nil, fmt.Errorf("transport not connected") } // Create a merged channel that receives from all topic channels merged := make(chan *Message, 1000) go func() { defer close(merged) // Use a wait group to handle multiple topic channels var wg sync.WaitGroup mt.mu.RLock() for topic, ch := range mt.channels { wg.Add(1) go func(topicCh <-chan *Message, topicName string) { defer wg.Done() for { select { case msg, ok := <-topicCh: if !ok { return } select { case merged <- msg: mt.updateReceiveMetrics(msg) case <-ctx.Done(): return } case <-ctx.Done(): return } } }(ch, topic) } mt.mu.RUnlock() wg.Wait() }() return merged, nil } // Health returns the health status of the transport func (mt *MemoryTransport) Health() ComponentHealth { mt.mu.RLock() defer mt.mu.RUnlock() status := "unhealthy" if mt.connected { status = "healthy" } return ComponentHealth{ Status: status, LastCheck: time.Now(), ResponseTime: time.Microsecond, // Very fast for memory transport ErrorCount: mt.metrics.Errors, } } // GetMetrics returns transport-specific metrics func (mt *MemoryTransport) GetMetrics() TransportMetrics { mt.mu.RLock() defer mt.mu.RUnlock() // Create a copy to avoid race conditions return TransportMetrics{ BytesSent: mt.metrics.BytesSent, BytesReceived: mt.metrics.BytesReceived, MessagesSent: mt.metrics.MessagesSent, MessagesReceived: mt.metrics.MessagesReceived, Connections: mt.metrics.Connections, Errors: mt.metrics.Errors, Latency: mt.metrics.Latency, } } // Private helper methods func (mt *MemoryTransport) updateSendMetrics(msg *Message, latency time.Duration) { mt.mu.Lock() defer mt.mu.Unlock() mt.metrics.MessagesSent++ mt.metrics.Latency = latency // Estimate message size (simplified) messageSize := int64(len(msg.ID) + len(msg.Topic) + len(msg.Source)) if msg.Data != nil { messageSize += int64(len(fmt.Sprintf("%v", msg.Data))) } mt.metrics.BytesSent += messageSize } func (mt *MemoryTransport) updateReceiveMetrics(msg *Message) { mt.mu.Lock() defer mt.mu.Unlock() mt.metrics.MessagesReceived++ // Estimate message size (simplified) messageSize := int64(len(msg.ID) + len(msg.Topic) + len(msg.Source)) if msg.Data != nil { messageSize += int64(len(fmt.Sprintf("%v", msg.Data))) } mt.metrics.BytesReceived += messageSize } // GetChannelForTopic returns the channel for a specific topic (for testing/debugging) func (mt *MemoryTransport) GetChannelForTopic(topic string) (<-chan *Message, bool) { mt.mu.RLock() defer mt.mu.RUnlock() ch, exists := mt.channels[topic] return ch, exists } // GetTopicCount returns the number of active topic channels func (mt *MemoryTransport) GetTopicCount() int { mt.mu.RLock() defer mt.mu.RUnlock() return len(mt.channels) }