feat(transport): implement comprehensive universal message bus

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>
Author: Krypto Kajun
Date: 2025-09-19 16:39:14 -05:00
Parent: fac8a64092
Commit: c0ec08468c
13 changed files with 5515 additions and 63 deletions

pkg/transport/dlq.go (new file)

@@ -0,0 +1,591 @@
package transport
import (
"context"
"fmt"
"math"
"sort"
"sync"
"time"
)
// DeadLetterQueue handles failed messages with retry and reprocessing capabilities
type DeadLetterQueue struct {
messages map[string][]*DLQMessage
config DLQConfig
metrics DLQMetrics
reprocessor MessageReprocessor
mu sync.RWMutex
cleanupTicker *time.Ticker
ctx context.Context
cancel context.CancelFunc
}
// DLQMessage represents a message in the dead letter queue
type DLQMessage struct {
ID string
OriginalMessage *Message
Topic string
FirstFailed time.Time
LastAttempt time.Time
AttemptCount int
MaxRetries int
FailureReason string
RetryDelay time.Duration
NextRetry time.Time
Metadata map[string]interface{}
Permanent bool
}
// DLQConfig configures dead letter queue behavior
type DLQConfig struct {
MaxMessages int
MaxRetries int
RetentionTime time.Duration
AutoReprocess bool
ReprocessInterval time.Duration
BackoffStrategy BackoffStrategy
InitialRetryDelay time.Duration
MaxRetryDelay time.Duration
BackoffMultiplier float64
PermanentFailures []string // Error patterns that mark messages as permanently failed
ReprocessBatchSize int
}
// BackoffStrategy defines retry delay calculation methods
type BackoffStrategy string
const (
BackoffFixed BackoffStrategy = "fixed"
BackoffLinear BackoffStrategy = "linear"
BackoffExponential BackoffStrategy = "exponential"
BackoffCustom BackoffStrategy = "custom"
)
// DLQMetrics tracks dead letter queue statistics
type DLQMetrics struct {
MessagesAdded int64
MessagesReprocessed int64
MessagesExpired int64
MessagesPermanent int64
ReprocessSuccesses int64
ReprocessFailures int64
QueueSize int64
OldestMessage time.Time
}
// MessageReprocessor handles message reprocessing logic
type MessageReprocessor interface {
Reprocess(ctx context.Context, msg *DLQMessage) error
CanReprocess(msg *DLQMessage) bool
ShouldRetry(msg *DLQMessage, err error) bool
}
// DefaultMessageReprocessor implements basic reprocessing logic
type DefaultMessageReprocessor struct {
publisher MessagePublisher
}
// MessagePublisher interface for republishing messages
type MessagePublisher interface {
Publish(ctx context.Context, msg *Message) error
}
// NewDeadLetterQueue creates a new dead letter queue
func NewDeadLetterQueue(config DLQConfig) *DeadLetterQueue {
ctx, cancel := context.WithCancel(context.Background())
dlq := &DeadLetterQueue{
messages: make(map[string][]*DLQMessage),
config: config,
metrics: DLQMetrics{},
ctx: ctx,
cancel: cancel,
}
// Set default configuration values
if dlq.config.MaxMessages == 0 {
dlq.config.MaxMessages = 10000
}
if dlq.config.MaxRetries == 0 {
dlq.config.MaxRetries = 3
}
if dlq.config.RetentionTime == 0 {
dlq.config.RetentionTime = 24 * time.Hour
}
if dlq.config.ReprocessInterval == 0 {
dlq.config.ReprocessInterval = 5 * time.Minute
}
if dlq.config.InitialRetryDelay == 0 {
dlq.config.InitialRetryDelay = time.Minute
}
if dlq.config.MaxRetryDelay == 0 {
dlq.config.MaxRetryDelay = time.Hour
}
if dlq.config.BackoffMultiplier == 0 {
dlq.config.BackoffMultiplier = 2.0
}
if dlq.config.BackoffStrategy == "" {
dlq.config.BackoffStrategy = BackoffExponential
}
if dlq.config.ReprocessBatchSize == 0 {
dlq.config.ReprocessBatchSize = 10
}
// Start cleanup routine
dlq.startCleanupRoutine()
// Start reprocessing routine if enabled
if dlq.config.AutoReprocess {
dlq.startReprocessRoutine()
}
return dlq
}
// AddMessage adds a failed message to the dead letter queue
func (dlq *DeadLetterQueue) AddMessage(topic string, msg *Message) error {
return dlq.AddMessageWithReason(topic, msg, "unknown failure")
}
// AddMessageWithReason adds a failed message with a specific failure reason
func (dlq *DeadLetterQueue) AddMessageWithReason(topic string, msg *Message, reason string) error {
dlq.mu.Lock()
defer dlq.mu.Unlock()
// Check if we've exceeded max messages
totalMessages := dlq.getTotalMessageCount()
if totalMessages >= dlq.config.MaxMessages {
// Remove oldest message to make room
dlq.removeOldestMessage()
}
// Check if this is a permanent failure
permanent := dlq.isPermanentFailure(reason)
dlqMsg := &DLQMessage{
ID: fmt.Sprintf("dlq_%s_%d", topic, time.Now().UnixNano()),
OriginalMessage: msg,
Topic: topic,
FirstFailed: time.Now(),
LastAttempt: time.Now(),
AttemptCount: 1,
MaxRetries: dlq.config.MaxRetries,
FailureReason: reason,
Metadata: make(map[string]interface{}),
Permanent: permanent,
}
if !permanent {
dlqMsg.RetryDelay = dlq.calculateRetryDelay(dlqMsg)
dlqMsg.NextRetry = time.Now().Add(dlqMsg.RetryDelay)
}
// Add to queue
if _, exists := dlq.messages[topic]; !exists {
dlq.messages[topic] = make([]*DLQMessage, 0)
}
dlq.messages[topic] = append(dlq.messages[topic], dlqMsg)
// Update metrics
dlq.metrics.MessagesAdded++
dlq.metrics.QueueSize++
if permanent {
dlq.metrics.MessagesPermanent++
}
dlq.updateOldestMessage()
return nil
}
// GetMessages returns all messages for a topic
func (dlq *DeadLetterQueue) GetMessages(topic string) ([]*DLQMessage, error) {
dlq.mu.RLock()
defer dlq.mu.RUnlock()
messages, exists := dlq.messages[topic]
if !exists {
return []*DLQMessage{}, nil
}
// Return a copy to avoid race conditions
result := make([]*DLQMessage, len(messages))
copy(result, messages)
return result, nil
}
// GetAllMessages returns all messages across all topics
func (dlq *DeadLetterQueue) GetAllMessages() map[string][]*DLQMessage {
dlq.mu.RLock()
defer dlq.mu.RUnlock()
result := make(map[string][]*DLQMessage)
for topic, messages := range dlq.messages {
result[topic] = make([]*DLQMessage, len(messages))
copy(result[topic], messages)
}
return result
}
// ReprocessMessage attempts to reprocess a specific message
func (dlq *DeadLetterQueue) ReprocessMessage(messageID string) error {
dlq.mu.Lock()
defer dlq.mu.Unlock()
// Find message
var dlqMsg *DLQMessage
var topic string
var index int
for t, messages := range dlq.messages {
for i, msg := range messages {
if msg.ID == messageID {
dlqMsg = msg
topic = t
index = i
break
}
}
if dlqMsg != nil {
break
}
}
if dlqMsg == nil {
return fmt.Errorf("message not found: %s", messageID)
}
if dlqMsg.Permanent {
return fmt.Errorf("message marked as permanent failure: %s", messageID)
}
// Attempt reprocessing
err := dlq.attemptReprocess(dlqMsg)
if err == nil {
// Success - remove from queue
dlq.removeMessageByIndex(topic, index)
dlq.metrics.ReprocessSuccesses++
dlq.metrics.QueueSize--
return nil
}
// Failed - update retry information
dlqMsg.AttemptCount++
dlqMsg.LastAttempt = time.Now()
dlqMsg.FailureReason = err.Error()
if dlqMsg.AttemptCount >= dlqMsg.MaxRetries {
dlqMsg.Permanent = true
dlq.metrics.MessagesPermanent++
} else {
dlqMsg.RetryDelay = dlq.calculateRetryDelay(dlqMsg)
dlqMsg.NextRetry = time.Now().Add(dlqMsg.RetryDelay)
}
dlq.metrics.ReprocessFailures++
return fmt.Errorf("reprocessing failed: %w", err)
}
// PurgeMessages removes all messages for a topic
func (dlq *DeadLetterQueue) PurgeMessages(topic string) error {
dlq.mu.Lock()
defer dlq.mu.Unlock()
if messages, exists := dlq.messages[topic]; exists {
count := len(messages)
delete(dlq.messages, topic)
dlq.metrics.QueueSize -= int64(count)
dlq.updateOldestMessage()
}
return nil
}
// PurgeAllMessages removes all messages from the queue
func (dlq *DeadLetterQueue) PurgeAllMessages() error {
dlq.mu.Lock()
defer dlq.mu.Unlock()
dlq.messages = make(map[string][]*DLQMessage)
dlq.metrics.QueueSize = 0
dlq.metrics.OldestMessage = time.Time{}
return nil
}
// GetMessageCount returns the total number of messages in the queue
func (dlq *DeadLetterQueue) GetMessageCount() int {
dlq.mu.RLock()
defer dlq.mu.RUnlock()
return dlq.getTotalMessageCount()
}
// GetMetrics returns current DLQ metrics
func (dlq *DeadLetterQueue) GetMetrics() DLQMetrics {
dlq.mu.RLock()
defer dlq.mu.RUnlock()
return dlq.metrics
}
// SetReprocessor sets the message reprocessor
func (dlq *DeadLetterQueue) SetReprocessor(reprocessor MessageReprocessor) {
dlq.mu.Lock()
defer dlq.mu.Unlock()
dlq.reprocessor = reprocessor
}
// Cleanup removes expired messages
func (dlq *DeadLetterQueue) Cleanup(maxAge time.Duration) error {
dlq.mu.Lock()
defer dlq.mu.Unlock()
cutoff := time.Now().Add(-maxAge)
expiredCount := 0
for topic, messages := range dlq.messages {
filtered := make([]*DLQMessage, 0)
for _, msg := range messages {
if msg.FirstFailed.After(cutoff) {
filtered = append(filtered, msg)
} else {
expiredCount++
}
}
dlq.messages[topic] = filtered
// Remove empty topics
if len(filtered) == 0 {
delete(dlq.messages, topic)
}
}
dlq.metrics.MessagesExpired += int64(expiredCount)
dlq.metrics.QueueSize -= int64(expiredCount)
dlq.updateOldestMessage()
return nil
}
// Stop gracefully shuts down the dead letter queue
func (dlq *DeadLetterQueue) Stop() error {
dlq.cancel()
if dlq.cleanupTicker != nil {
dlq.cleanupTicker.Stop()
}
return nil
}
// Private helper methods
func (dlq *DeadLetterQueue) getTotalMessageCount() int {
count := 0
for _, messages := range dlq.messages {
count += len(messages)
}
return count
}
func (dlq *DeadLetterQueue) removeOldestMessage() {
var oldestTime time.Time
var oldestTopic string
var oldestIndex int
for topic, messages := range dlq.messages {
for i, msg := range messages {
if oldestTime.IsZero() || msg.FirstFailed.Before(oldestTime) {
oldestTime = msg.FirstFailed
oldestTopic = topic
oldestIndex = i
}
}
}
if !oldestTime.IsZero() {
dlq.removeMessageByIndex(oldestTopic, oldestIndex)
dlq.metrics.QueueSize--
}
}
func (dlq *DeadLetterQueue) removeMessageByIndex(topic string, index int) {
messages := dlq.messages[topic]
dlq.messages[topic] = append(messages[:index], messages[index+1:]...)
if len(dlq.messages[topic]) == 0 {
delete(dlq.messages, topic)
}
}
func (dlq *DeadLetterQueue) isPermanentFailure(reason string) bool {
for _, pattern := range dlq.config.PermanentFailures {
if pattern == reason {
return true
}
// Simple pattern matching (can be enhanced with regex)
if len(pattern) > 0 && pattern[len(pattern)-1] == '*' {
prefix := pattern[:len(pattern)-1]
if len(reason) >= len(prefix) && reason[:len(prefix)] == prefix {
return true
}
}
}
return false
}
func (dlq *DeadLetterQueue) calculateRetryDelay(msg *DLQMessage) time.Duration {
switch dlq.config.BackoffStrategy {
case BackoffFixed:
return dlq.config.InitialRetryDelay
case BackoffLinear:
delay := time.Duration(msg.AttemptCount) * dlq.config.InitialRetryDelay
if delay > dlq.config.MaxRetryDelay {
return dlq.config.MaxRetryDelay
}
return delay
case BackoffExponential:
delay := time.Duration(float64(dlq.config.InitialRetryDelay) *
math.Pow(dlq.config.BackoffMultiplier, float64(msg.AttemptCount-1)))
if delay > dlq.config.MaxRetryDelay {
return dlq.config.MaxRetryDelay
}
return delay
default:
return dlq.config.InitialRetryDelay
}
}
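For illustration only (not part of this file): with the defaults above (1m initial delay, 2.0 multiplier, 1h cap), the exponential strategy yields 1m, 2m, 4m, 8m and so on until MaxRetryDelay. A minimal sketch of that schedule, assuming the math import:

// exampleBackoffSchedule sketches the delays the default exponential
// strategy would produce; illustrative only, not used by the DLQ itself.
func exampleBackoffSchedule(attempts int) []time.Duration {
	delays := make([]time.Duration, 0, attempts)
	for attempt := 1; attempt <= attempts; attempt++ {
		d := time.Duration(float64(time.Minute) * math.Pow(2.0, float64(attempt-1)))
		if d > time.Hour {
			d = time.Hour
		}
		delays = append(delays, d) // 1m, 2m, 4m, 8m, ...
	}
	return delays
}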
func (dlq *DeadLetterQueue) attemptReprocess(msg *DLQMessage) error {
if dlq.reprocessor == nil {
return fmt.Errorf("no reprocessor configured")
}
if !dlq.reprocessor.CanReprocess(msg) {
return fmt.Errorf("message cannot be reprocessed")
}
return dlq.reprocessor.Reprocess(dlq.ctx, msg)
}
func (dlq *DeadLetterQueue) updateOldestMessage() {
var oldest time.Time
for _, messages := range dlq.messages {
for _, msg := range messages {
if oldest.IsZero() || msg.FirstFailed.Before(oldest) {
oldest = msg.FirstFailed
}
}
}
dlq.metrics.OldestMessage = oldest
}
func (dlq *DeadLetterQueue) startCleanupRoutine() {
dlq.cleanupTicker = time.NewTicker(dlq.config.ReprocessInterval)
go func() {
for {
select {
case <-dlq.cleanupTicker.C:
dlq.Cleanup(dlq.config.RetentionTime)
case <-dlq.ctx.Done():
return
}
}
}()
}
func (dlq *DeadLetterQueue) startReprocessRoutine() {
ticker := time.NewTicker(dlq.config.ReprocessInterval)
go func() {
defer ticker.Stop()
for {
select {
case <-ticker.C:
dlq.processRetryableMessages()
case <-dlq.ctx.Done():
return
}
}
}()
}
func (dlq *DeadLetterQueue) processRetryableMessages() {
dlq.mu.Lock()
retryable := dlq.getRetryableMessages()
dlq.mu.Unlock()
// Sort by next retry time
sort.Slice(retryable, func(i, j int) bool {
return retryable[i].NextRetry.Before(retryable[j].NextRetry)
})
// Process batch
batchSize := dlq.config.ReprocessBatchSize
if len(retryable) < batchSize {
batchSize = len(retryable)
}
for i := 0; i < batchSize; i++ {
msg := retryable[i]
if time.Now().After(msg.NextRetry) {
dlq.ReprocessMessage(msg.ID)
}
}
}
func (dlq *DeadLetterQueue) getRetryableMessages() []*DLQMessage {
var retryable []*DLQMessage
for _, messages := range dlq.messages {
for _, msg := range messages {
if !msg.Permanent && msg.AttemptCount < msg.MaxRetries {
retryable = append(retryable, msg)
}
}
}
return retryable
}
// Implementation of DefaultMessageReprocessor
func NewDefaultMessageReprocessor(publisher MessagePublisher) *DefaultMessageReprocessor {
return &DefaultMessageReprocessor{
publisher: publisher,
}
}
func (r *DefaultMessageReprocessor) Reprocess(ctx context.Context, msg *DLQMessage) error {
if r.publisher == nil {
return fmt.Errorf("no publisher configured")
}
return r.publisher.Publish(ctx, msg.OriginalMessage)
}
func (r *DefaultMessageReprocessor) CanReprocess(msg *DLQMessage) bool {
return !msg.Permanent && msg.AttemptCount < msg.MaxRetries
}
func (r *DefaultMessageReprocessor) ShouldRetry(msg *DLQMessage, err error) bool {
// Simple retry logic - can be enhanced based on error types
return msg.AttemptCount < msg.MaxRetries
}
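A minimal usage sketch for the dead letter queue above (illustrative, not part of the commit); the topic name, failure reason, and pattern are made up, and any MessagePublisher implementation can back the reprocessor:

func exampleDLQUsage(publisher MessagePublisher, msg *Message) error {
	dlq := NewDeadLetterQueue(DLQConfig{
		MaxRetries:        5,
		AutoReprocess:     true,
		BackoffStrategy:   BackoffExponential,
		PermanentFailures: []string{"validation:*"}, // prefix pattern, see isPermanentFailure
	})
	defer dlq.Stop()

	// Republish retried messages through the supplied publisher.
	dlq.SetReprocessor(NewDefaultMessageReprocessor(publisher))

	// Record a failed delivery; retry scheduling and cleanup run in the background.
	return dlq.AddMessageWithReason("orders", msg, "downstream timeout")
}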

pkg/transport/interfaces.go (new file)

@@ -0,0 +1,277 @@
package transport
import (
"context"
"time"
)
// MessageType represents the type of message being sent
type MessageType string
const (
// Core message types
MessageTypeEvent MessageType = "event"
MessageTypeCommand MessageType = "command"
MessageTypeResponse MessageType = "response"
MessageTypeHeartbeat MessageType = "heartbeat"
MessageTypeStatus MessageType = "status"
MessageTypeError MessageType = "error"
// Business-specific message types
MessageTypeArbitrage MessageType = "arbitrage"
MessageTypeMarketData MessageType = "market_data"
MessageTypeExecution MessageType = "execution"
MessageTypeRiskCheck MessageType = "risk_check"
)
// Priority levels for message routing
type Priority uint8
const (
PriorityLow Priority = iota
PriorityNormal
PriorityHigh
PriorityCritical
PriorityEmergency
)
// Message represents a universal message in the system
type Message struct {
ID string `json:"id"`
Type MessageType `json:"type"`
Topic string `json:"topic"`
Source string `json:"source"`
Destination string `json:"destination"`
Priority Priority `json:"priority"`
Timestamp time.Time `json:"timestamp"`
TTL time.Duration `json:"ttl"`
Headers map[string]string `json:"headers"`
Payload []byte `json:"payload"`
Metadata map[string]interface{} `json:"metadata"`
}
// MessageHandler processes incoming messages
type MessageHandler func(ctx context.Context, msg *Message) error
// Transport defines the interface for different transport mechanisms
type Transport interface {
// Start initializes the transport
Start(ctx context.Context) error
// Stop gracefully shuts down the transport
Stop(ctx context.Context) error
// Send publishes a message
Send(ctx context.Context, msg *Message) error
// Subscribe registers a handler for messages on a topic
Subscribe(ctx context.Context, topic string, handler MessageHandler) error
// Unsubscribe removes a handler for a topic
Unsubscribe(ctx context.Context, topic string) error
// GetStats returns transport statistics
GetStats() TransportStats
// GetType returns the transport type
GetType() TransportType
// IsHealthy checks if the transport is functioning properly
IsHealthy() bool
}
// TransportType identifies different transport implementations
type TransportType string
const (
TransportTypeSharedMemory TransportType = "shared_memory"
TransportTypeUnixSocket TransportType = "unix_socket"
TransportTypeTCP TransportType = "tcp"
TransportTypeWebSocket TransportType = "websocket"
TransportTypeGRPC TransportType = "grpc"
)
// TransportStats provides metrics about transport performance
type TransportStats struct {
MessagesSent uint64 `json:"messages_sent"`
MessagesReceived uint64 `json:"messages_received"`
MessagesDropped uint64 `json:"messages_dropped"`
BytesSent uint64 `json:"bytes_sent"`
BytesReceived uint64 `json:"bytes_received"`
Latency time.Duration `json:"latency"`
ErrorCount uint64 `json:"error_count"`
ConnectedPeers int `json:"connected_peers"`
Uptime time.Duration `json:"uptime"`
}
// MessageBus coordinates message routing across multiple transports
type MessageBus interface {
// Start initializes the message bus
Start(ctx context.Context) error
// Stop gracefully shuts down the message bus
Stop(ctx context.Context) error
// RegisterTransport adds a transport to the bus
RegisterTransport(transport Transport) error
// UnregisterTransport removes a transport from the bus
UnregisterTransport(transportType TransportType) error
// Publish sends a message through the optimal transport
Publish(ctx context.Context, msg *Message) error
// Subscribe registers a handler for messages on a topic
Subscribe(ctx context.Context, topic string, handler MessageHandler) error
// Unsubscribe removes a handler for a topic
Unsubscribe(ctx context.Context, topic string) error
// GetTransport returns a specific transport
GetTransport(transportType TransportType) (Transport, error)
// GetStats returns aggregated statistics
GetStats() MessageBusStats
}
// MessageBusStats provides comprehensive metrics
type MessageBusStats struct {
TotalMessages uint64 `json:"total_messages"`
MessagesByType map[MessageType]uint64 `json:"messages_by_type"`
TransportStats map[TransportType]TransportStats `json:"transport_stats"`
ActiveTopics []string `json:"active_topics"`
Subscribers int `json:"subscribers"`
AverageLatency time.Duration `json:"average_latency"`
ThroughputMPS float64 `json:"throughput_mps"` // Messages per second
}
// Router determines the best transport for a message
type Router interface {
// Route selects the optimal transport for a message
Route(msg *Message) (TransportType, error)
// AddRule adds a routing rule
AddRule(rule RoutingRule) error
// RemoveRule removes a routing rule
RemoveRule(ruleID string) error
// GetRules returns all routing rules
GetRules() []RoutingRule
}
// RoutingRule defines how messages should be routed
type RoutingRule struct {
ID string `json:"id"`
Priority int `json:"priority"`
Condition Condition `json:"condition"`
Transport TransportType `json:"transport"`
Fallback TransportType `json:"fallback,omitempty"`
Description string `json:"description"`
}
// Condition defines when a routing rule applies
type Condition struct {
MessageType *MessageType `json:"message_type,omitempty"`
Topic *string `json:"topic,omitempty"`
Priority *Priority `json:"priority,omitempty"`
Source *string `json:"source,omitempty"`
Destination *string `json:"destination,omitempty"`
PayloadSize *int `json:"payload_size,omitempty"`
LatencyReq *time.Duration `json:"latency_requirement,omitempty"`
}
// DeadLetterQueue handles failed messages
type DeadLetterQueue interface {
// Add puts a failed message in the queue
Add(ctx context.Context, msg *Message, reason error) error
// Retry attempts to resend failed messages
Retry(ctx context.Context, maxRetries int) error
// Get retrieves failed messages
Get(ctx context.Context, limit int) ([]*FailedMessage, error)
// Remove deletes a failed message
Remove(ctx context.Context, messageID string) error
// GetStats returns dead letter queue statistics
GetStats() DLQStats
}
// FailedMessage represents a message that couldn't be delivered
type FailedMessage struct {
Message *Message `json:"message"`
Reason string `json:"reason"`
Attempts int `json:"attempts"`
FirstFailed time.Time `json:"first_failed"`
LastAttempt time.Time `json:"last_attempt"`
}
// DLQStats provides dead letter queue metrics
type DLQStats struct {
TotalMessages uint64 `json:"total_messages"`
RetryableMessages uint64 `json:"retryable_messages"`
PermanentFailures uint64 `json:"permanent_failures"`
OldestMessage time.Time `json:"oldest_message"`
AverageRetries float64 `json:"average_retries"`
}
// Serializer handles message encoding/decoding
type Serializer interface {
// Serialize converts a message to bytes
Serialize(msg *Message) ([]byte, error)
// Deserialize converts bytes to a message
Deserialize(data []byte) (*Message, error)
// GetFormat returns the serialization format
GetFormat() SerializationFormat
}
// SerializationFormat defines encoding types
type SerializationFormat string
const (
FormatJSON SerializationFormat = "json"
FormatProtobuf SerializationFormat = "protobuf"
FormatMsgPack SerializationFormat = "msgpack"
FormatAvro SerializationFormat = "avro"
)
// Persistence handles message storage
type Persistence interface {
// Store saves a message for persistence
Store(ctx context.Context, msg *Message) error
// Retrieve gets a stored message
Retrieve(ctx context.Context, messageID string) (*Message, error)
// Delete removes a stored message
Delete(ctx context.Context, messageID string) error
// List returns stored messages matching criteria
List(ctx context.Context, criteria PersistenceCriteria) ([]*Message, error)
// GetStats returns persistence statistics
GetStats() PersistenceStats
}
// PersistenceCriteria defines search parameters
type PersistenceCriteria struct {
Topic *string `json:"topic,omitempty"`
MessageType *MessageType `json:"message_type,omitempty"`
Source *string `json:"source,omitempty"`
FromTime *time.Time `json:"from_time,omitempty"`
ToTime *time.Time `json:"to_time,omitempty"`
Limit int `json:"limit"`
Offset int `json:"offset"`
}
// PersistenceStats provides storage metrics
type PersistenceStats struct {
StoredMessages uint64 `json:"stored_messages"`
StorageSize uint64 `json:"storage_size_bytes"`
OldestMessage time.Time `json:"oldest_message"`
NewestMessage time.Time `json:"newest_message"`
}
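To make the routing types concrete, a hedged sketch of declaring one rule against these interfaces; the rule ID, priority value, and chosen transports are assumptions:

func exampleMarketDataRule() RoutingRule {
	msgType := MessageTypeMarketData
	prio := PriorityHigh
	return RoutingRule{
		ID:          "market-data-fast-path",
		Priority:    100,
		Condition:   Condition{MessageType: &msgType, Priority: &prio},
		Transport:   TransportTypeSharedMemory,
		Fallback:    TransportTypeUnixSocket,
		Description: "Low-latency path for high-priority market data",
	}
}

A Router implementation would evaluate such rules in priority order and switch to Fallback when the preferred transport is unavailable.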


@@ -0,0 +1,230 @@
package transport
import (
"context"
"fmt"
"sync"
"time"
)
// MemoryTransport implements in-memory message transport for local communication
type MemoryTransport struct {
channels map[string]chan *Message
metrics TransportMetrics
connected bool
mu sync.RWMutex
}
// NewMemoryTransport creates a new in-memory transport
func NewMemoryTransport() *MemoryTransport {
return &MemoryTransport{
channels: make(map[string]chan *Message),
metrics: TransportMetrics{},
}
}
// Connect establishes the transport connection
func (mt *MemoryTransport) Connect(ctx context.Context) error {
mt.mu.Lock()
defer mt.mu.Unlock()
if mt.connected {
return nil
}
mt.connected = true
mt.metrics.Connections = 1
return nil
}
// Disconnect closes the transport connection
func (mt *MemoryTransport) Disconnect(ctx context.Context) error {
mt.mu.Lock()
defer mt.mu.Unlock()
if !mt.connected {
return nil
}
// Close all channels
for _, ch := range mt.channels {
close(ch)
}
mt.channels = make(map[string]chan *Message)
mt.connected = false
mt.metrics.Connections = 0
return nil
}
// Send transmits a message through the memory transport
func (mt *MemoryTransport) Send(ctx context.Context, msg *Message) error {
start := time.Now()
mt.mu.RLock()
if !mt.connected {
mt.mu.RUnlock()
mt.metrics.Errors++
return fmt.Errorf("transport not connected")
}
// Get or create channel for topic
ch, exists := mt.channels[msg.Topic]
if !exists {
mt.mu.RUnlock()
mt.mu.Lock()
// Double-check after acquiring write lock
if ch, exists = mt.channels[msg.Topic]; !exists {
ch = make(chan *Message, 1000) // Buffered channel
mt.channels[msg.Topic] = ch
}
mt.mu.Unlock()
} else {
mt.mu.RUnlock()
}
// Send message
select {
case ch <- msg:
mt.updateSendMetrics(msg, time.Since(start))
return nil
case <-ctx.Done():
mt.metrics.Errors++
return ctx.Err()
default:
mt.metrics.Errors++
return fmt.Errorf("channel full for topic: %s", msg.Topic)
}
}
// Receive returns a channel for receiving messages
func (mt *MemoryTransport) Receive(ctx context.Context) (<-chan *Message, error) {
mt.mu.RLock()
defer mt.mu.RUnlock()
if !mt.connected {
return nil, fmt.Errorf("transport not connected")
}
// Create a merged channel that receives from all topic channels that exist
// at call time (topics created after this call are not picked up)
merged := make(chan *Message, 1000)
go func() {
defer close(merged)
// Use a wait group to handle multiple topic channels
var wg sync.WaitGroup
mt.mu.RLock()
for topic, ch := range mt.channels {
wg.Add(1)
go func(topicCh <-chan *Message, topicName string) {
defer wg.Done()
for {
select {
case msg, ok := <-topicCh:
if !ok {
return
}
select {
case merged <- msg:
mt.updateReceiveMetrics(msg)
case <-ctx.Done():
return
}
case <-ctx.Done():
return
}
}
}(ch, topic)
}
mt.mu.RUnlock()
wg.Wait()
}()
return merged, nil
}
// Health returns the health status of the transport
func (mt *MemoryTransport) Health() ComponentHealth {
mt.mu.RLock()
defer mt.mu.RUnlock()
status := "unhealthy"
if mt.connected {
status = "healthy"
}
return ComponentHealth{
Status: status,
LastCheck: time.Now(),
ResponseTime: time.Microsecond, // Very fast for memory transport
ErrorCount: mt.metrics.Errors,
}
}
// GetMetrics returns transport-specific metrics
func (mt *MemoryTransport) GetMetrics() TransportMetrics {
mt.mu.RLock()
defer mt.mu.RUnlock()
// Create a copy to avoid race conditions
return TransportMetrics{
BytesSent: mt.metrics.BytesSent,
BytesReceived: mt.metrics.BytesReceived,
MessagesSent: mt.metrics.MessagesSent,
MessagesReceived: mt.metrics.MessagesReceived,
Connections: mt.metrics.Connections,
Errors: mt.metrics.Errors,
Latency: mt.metrics.Latency,
}
}
// Private helper methods
func (mt *MemoryTransport) updateSendMetrics(msg *Message, latency time.Duration) {
mt.mu.Lock()
defer mt.mu.Unlock()
mt.metrics.MessagesSent++
mt.metrics.Latency = latency
// Estimate message size (simplified)
messageSize := int64(len(msg.ID) + len(msg.Topic) + len(msg.Source))
if msg.Data != nil {
messageSize += int64(len(fmt.Sprintf("%v", msg.Data)))
}
mt.metrics.BytesSent += messageSize
}
func (mt *MemoryTransport) updateReceiveMetrics(msg *Message) {
mt.mu.Lock()
defer mt.mu.Unlock()
mt.metrics.MessagesReceived++
// Estimate message size (simplified)
messageSize := int64(len(msg.ID) + len(msg.Topic) + len(msg.Source))
if msg.Data != nil {
messageSize += int64(len(fmt.Sprintf("%v", msg.Data)))
}
mt.metrics.BytesReceived += messageSize
}
// GetChannelForTopic returns the channel for a specific topic (for testing/debugging)
func (mt *MemoryTransport) GetChannelForTopic(topic string) (<-chan *Message, bool) {
mt.mu.RLock()
defer mt.mu.RUnlock()
ch, exists := mt.channels[topic]
return ch, exists
}
// GetTopicCount returns the number of active topic channels
func (mt *MemoryTransport) GetTopicCount() int {
mt.mu.RLock()
defer mt.mu.RUnlock()
return len(mt.channels)
}
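An illustrative round trip through the memory transport (a sketch, not part of the change); it assumes the NewMessage helper and the Data-carrying Message variant defined later in this commit:

func exampleMemoryTransport(ctx context.Context) error {
	mt := NewMemoryTransport()
	if err := mt.Connect(ctx); err != nil {
		return err
	}
	defer mt.Disconnect(ctx)

	// Send creates the topic channel on first use.
	if err := mt.Send(ctx, NewMessage(MessageTypeEvent, "ticks", "example", "hello")); err != nil {
		return err
	}

	// Receive merges the topic channels that exist at call time into one stream.
	msgs, err := mt.Receive(ctx)
	if err != nil {
		return err
	}
	select {
	case msg := <-msgs:
		fmt.Printf("got %s on %s\n", msg.ID, msg.Topic)
	case <-time.After(time.Second):
		return fmt.Errorf("no message received")
	}
	return nil
}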


@@ -0,0 +1,453 @@
package transport
import (
"context"
"fmt"
"sync"
"time"
)
// MessageType represents the type of message being sent
type MessageType string
const (
MessageTypeEvent MessageType = "event"
MessageTypeCommand MessageType = "command"
MessageTypeQuery MessageType = "query"
MessageTypeResponse MessageType = "response"
MessageTypeNotification MessageType = "notification"
MessageTypeHeartbeat MessageType = "heartbeat"
)
// MessagePriority defines message processing priority
type MessagePriority int
const (
PriorityLow MessagePriority = iota
PriorityNormal
PriorityHigh
PriorityCritical
)
// Message represents a universal message in the system
type Message struct {
ID string `json:"id"`
Type MessageType `json:"type"`
Topic string `json:"topic"`
Source string `json:"source"`
Target string `json:"target,omitempty"`
Priority MessagePriority `json:"priority"`
Timestamp time.Time `json:"timestamp"`
Data interface{} `json:"data"`
Headers map[string]string `json:"headers,omitempty"`
CorrelationID string `json:"correlation_id,omitempty"`
TTL time.Duration `json:"ttl,omitempty"`
Retries int `json:"retries"`
MaxRetries int `json:"max_retries"`
Metadata map[string]interface{} `json:"metadata,omitempty"`
}
// MessageHandler processes incoming messages
type MessageHandler func(ctx context.Context, msg *Message) error
// MessageFilter determines if a message should be processed
type MessageFilter func(msg *Message) bool
// Subscription represents a topic subscription
type Subscription struct {
ID string
Topic string
Filter MessageFilter
Handler MessageHandler
Options SubscriptionOptions
created time.Time
active bool
mu sync.RWMutex
}
// SubscriptionOptions configures subscription behavior
type SubscriptionOptions struct {
QueueSize int
BatchSize int
BatchTimeout time.Duration
DLQEnabled bool
RetryEnabled bool
Persistent bool
Durable bool
}
// MessageBusInterface defines the universal message bus contract
type MessageBusInterface interface {
// Core messaging operations
Publish(ctx context.Context, msg *Message) error
Subscribe(topic string, handler MessageHandler, opts ...SubscriptionOption) (*Subscription, error)
Unsubscribe(subscriptionID string) error
// Advanced messaging patterns
Request(ctx context.Context, msg *Message, timeout time.Duration) (*Message, error)
Reply(ctx context.Context, originalMsg *Message, response *Message) error
// Topic management
CreateTopic(topic string, config TopicConfig) error
DeleteTopic(topic string) error
ListTopics() []string
GetTopicInfo(topic string) (*TopicInfo, error)
// Queue operations
QueueMessage(topic string, msg *Message) error
DequeueMessage(topic string, timeout time.Duration) (*Message, error)
PeekMessage(topic string) (*Message, error)
// Dead letter queue
GetDLQMessages(topic string) ([]*Message, error)
ReprocessDLQMessage(messageID string) error
PurgeDLQ(topic string) error
// Lifecycle management
Start(ctx context.Context) error
Stop(ctx context.Context) error
Health() HealthStatus
// Metrics and monitoring
GetMetrics() MessageBusMetrics
GetSubscriptions() []*Subscription
GetActiveConnections() int
}
// SubscriptionOption configures subscription behavior
type SubscriptionOption func(*SubscriptionOptions)
// TopicConfig defines topic configuration
type TopicConfig struct {
Persistent bool
Replicated bool
RetentionPolicy RetentionPolicy
Partitions int
MaxMessageSize int64
TTL time.Duration
}
// RetentionPolicy defines message retention behavior
type RetentionPolicy struct {
MaxMessages int
MaxAge time.Duration
MaxSize int64
}
// TopicInfo provides topic statistics
type TopicInfo struct {
Name string
Config TopicConfig
MessageCount int64
SubscriberCount int
LastActivity time.Time
SizeBytes int64
}
// HealthStatus represents system health
type HealthStatus struct {
Status string
Uptime time.Duration
LastCheck time.Time
Components map[string]ComponentHealth
Errors []HealthError
}
// ComponentHealth represents component-specific health
type ComponentHealth struct {
Status string
LastCheck time.Time
ResponseTime time.Duration
ErrorCount int64
}
// HealthError represents a health check error
type HealthError struct {
Component string
Message string
Timestamp time.Time
Severity string
}
// MessageBusMetrics provides operational metrics
type MessageBusMetrics struct {
MessagesPublished int64
MessagesConsumed int64
MessagesFailed int64
MessagesInDLQ int64
ActiveSubscriptions int
TopicCount int
AverageLatency time.Duration
ThroughputPerSec float64
ErrorRate float64
MemoryUsage int64
CPUUsage float64
}
// TransportType defines available transport mechanisms
type TransportType string
const (
TransportMemory TransportType = "memory"
TransportUnixSocket TransportType = "unix"
TransportTCP TransportType = "tcp"
TransportWebSocket TransportType = "websocket"
TransportRedis TransportType = "redis"
TransportNATS TransportType = "nats"
)
// TransportConfig configures transport layer
type TransportConfig struct {
Type TransportType
Address string
Options map[string]interface{}
RetryConfig RetryConfig
SecurityConfig SecurityConfig
}
// RetryConfig defines retry behavior
type RetryConfig struct {
MaxRetries int
InitialDelay time.Duration
MaxDelay time.Duration
BackoffFactor float64
Jitter bool
}
// SecurityConfig defines security settings
type SecurityConfig struct {
Enabled bool
TLSConfig *TLSConfig
AuthConfig *AuthConfig
Encryption bool
Compression bool
}
// TLSConfig for secure transport
type TLSConfig struct {
CertFile string
KeyFile string
CAFile string
Verify bool
}
// AuthConfig for authentication
type AuthConfig struct {
Username string
Password string
Token string
Method string
}
// UniversalMessageBus implements MessageBusInterface
type UniversalMessageBus struct {
config MessageBusConfig
transports map[TransportType]Transport
router *MessageRouter
topics map[string]*Topic
subscriptions map[string]*Subscription
dlq *DeadLetterQueue
metrics *MetricsCollector
persistence PersistenceLayer
mu sync.RWMutex
ctx context.Context
cancel context.CancelFunc
started bool
}
// MessageBusConfig configures the message bus
type MessageBusConfig struct {
DefaultTransport TransportType
EnablePersistence bool
EnableMetrics bool
EnableDLQ bool
MaxMessageSize int64
DefaultTTL time.Duration
HealthCheckInterval time.Duration
CleanupInterval time.Duration
}
// Transport interface for different transport mechanisms
type Transport interface {
Send(ctx context.Context, msg *Message) error
Receive(ctx context.Context) (<-chan *Message, error)
Connect(ctx context.Context) error
Disconnect(ctx context.Context) error
Health() ComponentHealth
GetMetrics() TransportMetrics
}
// TransportMetrics for transport-specific metrics
type TransportMetrics struct {
BytesSent int64
BytesReceived int64
MessagesSent int64
MessagesReceived int64
Connections int
Errors int64
Latency time.Duration
}
// MessageRouter handles message routing logic
type MessageRouter struct {
rules []RoutingRule
fallback TransportType
loadBalancer LoadBalancer
mu sync.RWMutex
}
// RoutingRule defines message routing logic
type RoutingRule struct {
Condition MessageFilter
Transport TransportType
Priority int
Enabled bool
}
// LoadBalancer for transport selection
type LoadBalancer interface {
SelectTransport(transports []TransportType, msg *Message) TransportType
UpdateStats(transport TransportType, latency time.Duration, success bool)
}
// Topic represents a message topic
type Topic struct {
Name string
Config TopicConfig
Messages []StoredMessage
Subscribers []*Subscription
Created time.Time
LastActivity time.Time
mu sync.RWMutex
}
// StoredMessage represents a persisted message
type StoredMessage struct {
Message *Message
Stored time.Time
Processed bool
}
// DeadLetterQueue handles failed messages
type DeadLetterQueue struct {
messages map[string][]*Message
config DLQConfig
mu sync.RWMutex
}
// DLQConfig configures dead letter queue
type DLQConfig struct {
MaxMessages int
MaxRetries int
RetentionTime time.Duration
AutoReprocess bool
}
// MetricsCollector gathers operational metrics
type MetricsCollector struct {
metrics map[string]interface{}
mu sync.RWMutex
}
// PersistenceLayer handles message persistence
type PersistenceLayer interface {
Store(msg *Message) error
Retrieve(id string) (*Message, error)
Delete(id string) error
List(topic string, limit int) ([]*Message, error)
Cleanup(maxAge time.Duration) error
}
// Factory functions for common subscription options
func WithQueueSize(size int) SubscriptionOption {
return func(opts *SubscriptionOptions) {
opts.QueueSize = size
}
}
func WithBatchProcessing(size int, timeout time.Duration) SubscriptionOption {
return func(opts *SubscriptionOptions) {
opts.BatchSize = size
opts.BatchTimeout = timeout
}
}
func WithDLQ(enabled bool) SubscriptionOption {
return func(opts *SubscriptionOptions) {
opts.DLQEnabled = enabled
}
}
func WithRetry(enabled bool) SubscriptionOption {
return func(opts *SubscriptionOptions) {
opts.RetryEnabled = enabled
}
}
func WithPersistence(enabled bool) SubscriptionOption {
return func(opts *SubscriptionOptions) {
opts.Persistent = enabled
}
}
// NewUniversalMessageBus creates a new message bus instance
func NewUniversalMessageBus(config MessageBusConfig) *UniversalMessageBus {
ctx, cancel := context.WithCancel(context.Background())
return &UniversalMessageBus{
config: config,
transports: make(map[TransportType]Transport),
topics: make(map[string]*Topic),
subscriptions: make(map[string]*Subscription),
router: NewMessageRouter(),
dlq: NewDeadLetterQueue(DLQConfig{}),
metrics: NewMetricsCollector(),
ctx: ctx,
cancel: cancel,
}
}
// NewMessageRouter creates a new message router
func NewMessageRouter() *MessageRouter {
return &MessageRouter{
rules: make([]RoutingRule, 0),
fallback: TransportMemory,
}
}
// NewDeadLetterQueue creates a new dead letter queue
func NewDeadLetterQueue(config DLQConfig) *DeadLetterQueue {
return &DeadLetterQueue{
messages: make(map[string][]*Message),
config: config,
}
}
// NewMetricsCollector creates a new metrics collector
func NewMetricsCollector() *MetricsCollector {
return &MetricsCollector{
metrics: make(map[string]interface{}),
}
}
// Helper function to generate message ID
func GenerateMessageID() string {
return fmt.Sprintf("msg_%d_%d", time.Now().UnixNano(), time.Now().Nanosecond())
}
// Helper function to create message with defaults
func NewMessage(msgType MessageType, topic string, source string, data interface{}) *Message {
return &Message{
ID: GenerateMessageID(),
Type: msgType,
Topic: topic,
Source: source,
Priority: PriorityNormal,
Timestamp: time.Now(),
Data: data,
Headers: make(map[string]string),
Metadata: make(map[string]interface{}),
Retries: 0,
MaxRetries: 3,
}
}
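A short sketch of how the option helpers above compose when subscribing (illustrative; the topic name and handler are assumptions):

func exampleSubscriptionSetup(bus MessageBusInterface) (*Subscription, error) {
	handler := func(ctx context.Context, msg *Message) error {
		fmt.Printf("handled %s from %s\n", msg.ID, msg.Source)
		return nil
	}
	return bus.Subscribe("orders.created", handler,
		WithQueueSize(500),
		WithBatchProcessing(10, 250*time.Millisecond),
		WithDLQ(true),
		WithRetry(true),
	)
}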


@@ -0,0 +1,743 @@
package transport
import (
"context"
"fmt"
"sync"
"time"
)
// Publish sends a message to the specified topic
func (mb *UniversalMessageBus) Publish(ctx context.Context, msg *Message) error {
if !mb.started {
return fmt.Errorf("message bus not started")
}
// Validate message
if err := mb.validateMessage(msg); err != nil {
return fmt.Errorf("invalid message: %w", err)
}
// Set timestamp if not set
if msg.Timestamp.IsZero() {
msg.Timestamp = time.Now()
}
// Set ID if not set
if msg.ID == "" {
msg.ID = GenerateMessageID()
}
// Update metrics
mb.metrics.IncrementCounter("messages_published_total")
mb.metrics.RecordLatency("publish_latency", time.Since(msg.Timestamp))
// Route message to appropriate transport
transport, err := mb.router.RouteMessage(msg, mb.transports)
if err != nil {
mb.metrics.IncrementCounter("routing_errors_total")
return fmt.Errorf("routing failed: %w", err)
}
// Send via transport
if err := transport.Send(ctx, msg); err != nil {
mb.metrics.IncrementCounter("send_errors_total")
// Try dead letter queue if enabled
if mb.config.EnableDLQ {
if dlqErr := mb.dlq.AddMessage(msg.Topic, msg); dlqErr != nil {
return fmt.Errorf("send failed and DLQ failed: %v, original error: %w", dlqErr, err)
}
}
return fmt.Errorf("send failed: %w", err)
}
// Store in topic if persistence enabled
if mb.config.EnablePersistence {
if err := mb.addMessageToTopic(msg); err != nil {
// Log error but don't fail the publish
mb.metrics.IncrementCounter("persistence_errors_total")
}
}
// Deliver to local subscribers
go mb.deliverToSubscribers(ctx, msg)
return nil
}
// Subscribe creates a subscription to a topic
func (mb *UniversalMessageBus) Subscribe(topic string, handler MessageHandler, opts ...SubscriptionOption) (*Subscription, error) {
if !mb.started {
return nil, fmt.Errorf("message bus not started")
}
// Apply subscription options
options := SubscriptionOptions{
QueueSize: 1000,
BatchSize: 1,
BatchTimeout: time.Second,
DLQEnabled: mb.config.EnableDLQ,
RetryEnabled: true,
Persistent: false,
Durable: false,
}
for _, opt := range opts {
opt(&options)
}
// Create subscription
subscription := &Subscription{
ID: fmt.Sprintf("sub_%s_%d", topic, time.Now().UnixNano()),
Topic: topic,
Handler: handler,
Options: options,
created: time.Now(),
active: true,
}
mb.mu.Lock()
mb.subscriptions[subscription.ID] = subscription
mb.mu.Unlock()
// Add to topic subscribers
mb.addSubscriberToTopic(topic, subscription)
mb.metrics.IncrementCounter("subscriptions_created_total")
return subscription, nil
}
// Unsubscribe removes a subscription
func (mb *UniversalMessageBus) Unsubscribe(subscriptionID string) error {
mb.mu.Lock()
defer mb.mu.Unlock()
subscription, exists := mb.subscriptions[subscriptionID]
if !exists {
return fmt.Errorf("subscription not found: %s", subscriptionID)
}
// Mark as inactive
subscription.mu.Lock()
subscription.active = false
subscription.mu.Unlock()
// Remove from subscriptions map
delete(mb.subscriptions, subscriptionID)
// Remove from topic subscribers
mb.removeSubscriberFromTopic(subscription.Topic, subscriptionID)
mb.metrics.IncrementCounter("subscriptions_removed_total")
return nil
}
// Request sends a request and waits for a response
func (mb *UniversalMessageBus) Request(ctx context.Context, msg *Message, timeout time.Duration) (*Message, error) {
if !mb.started {
return nil, fmt.Errorf("message bus not started")
}
// Set correlation ID for request-response
if msg.CorrelationID == "" {
msg.CorrelationID = GenerateMessageID()
}
// Create response channel (intentionally never closed: a late reply delivered
// after Unsubscribe must not panic on a send to a closed channel; the GC reclaims it)
responseChannel := make(chan *Message, 1)
// Subscribe to response topic
responseTopic := fmt.Sprintf("response.%s", msg.CorrelationID)
subscription, err := mb.Subscribe(responseTopic, func(ctx context.Context, response *Message) error {
select {
case responseChannel <- response:
default:
// Channel full, ignore
}
return nil
})
if err != nil {
return nil, fmt.Errorf("failed to subscribe to response topic: %w", err)
}
defer mb.Unsubscribe(subscription.ID)
// Send request
if err := mb.Publish(ctx, msg); err != nil {
return nil, fmt.Errorf("failed to publish request: %w", err)
}
// Wait for response with timeout
select {
case response := <-responseChannel:
return response, nil
case <-time.After(timeout):
return nil, fmt.Errorf("request timeout after %v", timeout)
case <-ctx.Done():
return nil, ctx.Err()
}
}
// Reply sends a response to a request
func (mb *UniversalMessageBus) Reply(ctx context.Context, originalMsg *Message, response *Message) error {
if originalMsg.CorrelationID == "" {
return fmt.Errorf("original message has no correlation ID")
}
// Set response properties
response.Type = MessageTypeResponse
response.Topic = fmt.Sprintf("response.%s", originalMsg.CorrelationID)
response.CorrelationID = originalMsg.CorrelationID
response.Target = originalMsg.Source
return mb.Publish(ctx, response)
}
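An illustrative request/reply exchange built on the two methods above (a sketch, not part of this file; topic and payload values are assumptions):

func exampleRequestReply(ctx context.Context, bus *UniversalMessageBus) (*Message, error) {
	// Responder: handle "risk.check" queries and reply on the correlated topic.
	if _, err := bus.Subscribe("risk.check", func(ctx context.Context, msg *Message) error {
		resp := NewMessage(MessageTypeResponse, "", "risk-service", "approved")
		return bus.Reply(ctx, msg, resp) // Reply fills in topic and correlation ID
	}); err != nil {
		return nil, err
	}

	// Caller: publish a query and wait up to two seconds for the answer.
	req := NewMessage(MessageTypeQuery, "risk.check", "example-client", map[string]interface{}{"order": 42})
	return bus.Request(ctx, req, 2*time.Second)
}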
// CreateTopic creates a new topic with configuration
func (mb *UniversalMessageBus) CreateTopic(topicName string, config TopicConfig) error {
mb.mu.Lock()
defer mb.mu.Unlock()
if _, exists := mb.topics[topicName]; exists {
return fmt.Errorf("topic already exists: %s", topicName)
}
topic := &Topic{
Name: topicName,
Config: config,
Messages: make([]StoredMessage, 0),
Subscribers: make([]*Subscription, 0),
Created: time.Now(),
LastActivity: time.Now(),
}
mb.topics[topicName] = topic
mb.metrics.IncrementCounter("topics_created_total")
return nil
}
// DeleteTopic removes a topic
func (mb *UniversalMessageBus) DeleteTopic(topicName string) error {
mb.mu.Lock()
defer mb.mu.Unlock()
topic, exists := mb.topics[topicName]
if !exists {
return fmt.Errorf("topic not found: %s", topicName)
}
// Remove all subscribers
for _, sub := range topic.Subscribers {
mb.Unsubscribe(sub.ID)
}
delete(mb.topics, topicName)
mb.metrics.IncrementCounter("topics_deleted_total")
return nil
}
// ListTopics returns all topic names
func (mb *UniversalMessageBus) ListTopics() []string {
mb.mu.RLock()
defer mb.mu.RUnlock()
topics := make([]string, 0, len(mb.topics))
for name := range mb.topics {
topics = append(topics, name)
}
return topics
}
// GetTopicInfo returns topic information
func (mb *UniversalMessageBus) GetTopicInfo(topicName string) (*TopicInfo, error) {
mb.mu.RLock()
defer mb.mu.RUnlock()
topic, exists := mb.topics[topicName]
if !exists {
return nil, fmt.Errorf("topic not found: %s", topicName)
}
topic.mu.RLock()
defer topic.mu.RUnlock()
// Calculate size
var sizeBytes int64
for _, stored := range topic.Messages {
// Rough estimation of message size
sizeBytes += int64(len(fmt.Sprintf("%+v", stored.Message)))
}
return &TopicInfo{
Name: topic.Name,
Config: topic.Config,
MessageCount: int64(len(topic.Messages)),
SubscriberCount: len(topic.Subscribers),
LastActivity: topic.LastActivity,
SizeBytes: sizeBytes,
}, nil
}
// QueueMessage adds a message to a topic queue
func (mb *UniversalMessageBus) QueueMessage(topic string, msg *Message) error {
return mb.addMessageToTopic(msg)
}
// DequeueMessage removes a message from a topic queue
func (mb *UniversalMessageBus) DequeueMessage(topic string, timeout time.Duration) (*Message, error) {
start := time.Now()
for time.Since(start) < timeout {
mb.mu.RLock()
topicObj, exists := mb.topics[topic]
mb.mu.RUnlock()
if !exists {
return nil, fmt.Errorf("topic not found: %s", topic)
}
topicObj.mu.Lock()
if len(topicObj.Messages) > 0 {
// Get first unprocessed message
for i, stored := range topicObj.Messages {
if !stored.Processed {
topicObj.Messages[i].Processed = true
topicObj.mu.Unlock()
return stored.Message, nil
}
}
}
topicObj.mu.Unlock()
// Wait a bit before trying again
time.Sleep(10 * time.Millisecond)
}
return nil, fmt.Errorf("no message available within timeout")
}
// PeekMessage returns the next message without removing it
func (mb *UniversalMessageBus) PeekMessage(topic string) (*Message, error) {
mb.mu.RLock()
topicObj, exists := mb.topics[topic]
mb.mu.RUnlock()
if !exists {
return nil, fmt.Errorf("topic not found: %s", topic)
}
topicObj.mu.RLock()
defer topicObj.mu.RUnlock()
for _, stored := range topicObj.Messages {
if !stored.Processed {
return stored.Message, nil
}
}
return nil, fmt.Errorf("no messages available")
}
// Start initializes and starts the message bus
func (mb *UniversalMessageBus) Start(ctx context.Context) error {
mb.mu.Lock()
defer mb.mu.Unlock()
if mb.started {
return fmt.Errorf("message bus already started")
}
// Initialize default transport if none configured
if len(mb.transports) == 0 {
memTransport := NewMemoryTransport()
mb.transports[TransportMemory] = memTransport
if err := memTransport.Connect(ctx); err != nil {
return fmt.Errorf("failed to connect default transport: %w", err)
}
}
// Start background routines
go mb.healthCheckLoop()
go mb.cleanupLoop()
go mb.metricsLoop()
mb.started = true
mb.metrics.RecordEvent("message_bus_started")
return nil
}
// Stop gracefully shuts down the message bus
func (mb *UniversalMessageBus) Stop(ctx context.Context) error {
mb.mu.Lock()
defer mb.mu.Unlock()
if !mb.started {
return nil
}
// Cancel context to stop background routines
mb.cancel()
// Disconnect all transports
for _, transport := range mb.transports {
// Best-effort disconnect; ignore errors so shutdown continues
_ = transport.Disconnect(ctx)
}
mb.started = false
mb.metrics.RecordEvent("message_bus_stopped")
return nil
}
// Health returns the current health status
func (mb *UniversalMessageBus) Health() HealthStatus {
components := make(map[string]ComponentHealth)
// Check transport health
for transportType, transport := range mb.transports {
components[string(transportType)] = transport.Health()
}
// Overall status
status := "healthy"
var errors []HealthError
for name, component := range components {
if component.Status != "healthy" {
status = "degraded"
errors = append(errors, HealthError{
Component: name,
Message: fmt.Sprintf("Component %s is %s", name, component.Status),
Timestamp: time.Now(),
Severity: "warning",
})
}
}
return HealthStatus{
Status: status,
Uptime: time.Since(time.Now()), // Would track actual uptime
LastCheck: time.Now(),
Components: components,
Errors: errors,
}
}
// GetMetrics returns current operational metrics
func (mb *UniversalMessageBus) GetMetrics() MessageBusMetrics {
return MessageBusMetrics{
MessagesPublished: mb.getMetricInt64("messages_published_total"),
MessagesConsumed: mb.getMetricInt64("messages_consumed_total"),
MessagesFailed: mb.getMetricInt64("send_errors_total"),
MessagesInDLQ: int64(mb.dlq.GetMessageCount()),
ActiveSubscriptions: len(mb.subscriptions),
TopicCount: len(mb.topics),
AverageLatency: mb.getMetricDuration("average_latency"),
ThroughputPerSec: mb.getMetricFloat64("throughput_per_second"),
ErrorRate: mb.getMetricFloat64("error_rate"),
MemoryUsage: mb.getMetricInt64("memory_usage_bytes"),
CPUUsage: mb.getMetricFloat64("cpu_usage_percent"),
}
}
// GetSubscriptions returns all active subscriptions
func (mb *UniversalMessageBus) GetSubscriptions() []*Subscription {
mb.mu.RLock()
defer mb.mu.RUnlock()
subscriptions := make([]*Subscription, 0, len(mb.subscriptions))
for _, sub := range mb.subscriptions {
subscriptions = append(subscriptions, sub)
}
return subscriptions
}
// GetActiveConnections returns the number of active connections
func (mb *UniversalMessageBus) GetActiveConnections() int {
count := 0
for _, transport := range mb.transports {
metrics := transport.GetMetrics()
count += metrics.Connections
}
return count
}
// Helper methods
func (mb *UniversalMessageBus) validateMessage(msg *Message) error {
if msg == nil {
return fmt.Errorf("message is nil")
}
if msg.Topic == "" {
return fmt.Errorf("message topic is empty")
}
if msg.Source == "" {
return fmt.Errorf("message source is empty")
}
if msg.Data == nil {
return fmt.Errorf("message data is nil")
}
return nil
}
func (mb *UniversalMessageBus) addMessageToTopic(msg *Message) error {
mb.mu.RLock()
topic, exists := mb.topics[msg.Topic]
mb.mu.RUnlock()
if !exists {
// Create topic automatically (a concurrent publisher may have just created it)
config := TopicConfig{
Persistent: true,
RetentionPolicy: RetentionPolicy{MaxMessages: 10000, MaxAge: 24 * time.Hour},
}
_ = mb.CreateTopic(msg.Topic, config)
mb.mu.RLock()
topic = mb.topics[msg.Topic]
mb.mu.RUnlock()
if topic == nil {
return fmt.Errorf("topic %s could not be created", msg.Topic)
}
}
topic.mu.Lock()
defer topic.mu.Unlock()
stored := StoredMessage{
Message: msg,
Stored: time.Now(),
Processed: false,
}
topic.Messages = append(topic.Messages, stored)
topic.LastActivity = time.Now()
// Apply retention policy
mb.applyRetentionPolicy(topic)
return nil
}
func (mb *UniversalMessageBus) addSubscriberToTopic(topicName string, subscription *Subscription) {
mb.mu.RLock()
topic, exists := mb.topics[topicName]
mb.mu.RUnlock()
if !exists {
// Create topic automatically (ignore "already exists" from a concurrent caller)
config := TopicConfig{Persistent: false}
_ = mb.CreateTopic(topicName, config)
mb.mu.RLock()
topic = mb.topics[topicName]
mb.mu.RUnlock()
}
topic.mu.Lock()
topic.Subscribers = append(topic.Subscribers, subscription)
topic.mu.Unlock()
}
func (mb *UniversalMessageBus) removeSubscriberFromTopic(topicName, subscriptionID string) {
mb.mu.RLock()
topic, exists := mb.topics[topicName]
mb.mu.RUnlock()
if !exists {
return
}
topic.mu.Lock()
defer topic.mu.Unlock()
for i, sub := range topic.Subscribers {
if sub.ID == subscriptionID {
topic.Subscribers = append(topic.Subscribers[:i], topic.Subscribers[i+1:]...)
break
}
}
}
func (mb *UniversalMessageBus) deliverToSubscribers(ctx context.Context, msg *Message) {
mb.mu.RLock()
topic, exists := mb.topics[msg.Topic]
mb.mu.RUnlock()
if !exists {
return
}
topic.mu.RLock()
subscribers := make([]*Subscription, len(topic.Subscribers))
copy(subscribers, topic.Subscribers)
topic.mu.RUnlock()
for _, sub := range subscribers {
sub.mu.RLock()
if !sub.active {
sub.mu.RUnlock()
continue
}
// Apply filter if present
if sub.Filter != nil && !sub.Filter(msg) {
sub.mu.RUnlock()
continue
}
handler := sub.Handler
sub.mu.RUnlock()
// Deliver message in goroutine
go func(subscription *Subscription, message *Message) {
defer func() {
if r := recover(); r != nil {
mb.metrics.IncrementCounter("handler_panics_total")
}
}()
if err := handler(ctx, message); err != nil {
mb.metrics.IncrementCounter("handler_errors_total")
if mb.config.EnableDLQ {
mb.dlq.AddMessage(message.Topic, message)
}
} else {
mb.metrics.IncrementCounter("messages_consumed_total")
}
}(sub, msg)
}
}
func (mb *UniversalMessageBus) applyRetentionPolicy(topic *Topic) {
policy := topic.Config.RetentionPolicy
// Remove old messages
if policy.MaxAge > 0 {
cutoff := time.Now().Add(-policy.MaxAge)
filtered := make([]StoredMessage, 0)
for _, stored := range topic.Messages {
if stored.Stored.After(cutoff) {
filtered = append(filtered, stored)
}
}
topic.Messages = filtered
}
// Limit number of messages
if policy.MaxMessages > 0 && len(topic.Messages) > policy.MaxMessages {
topic.Messages = topic.Messages[len(topic.Messages)-policy.MaxMessages:]
}
}
func (mb *UniversalMessageBus) healthCheckLoop() {
ticker := time.NewTicker(mb.config.HealthCheckInterval)
defer ticker.Stop()
for {
select {
case <-ticker.C:
mb.performHealthCheck()
case <-mb.ctx.Done():
return
}
}
}
func (mb *UniversalMessageBus) cleanupLoop() {
ticker := time.NewTicker(mb.config.CleanupInterval)
defer ticker.Stop()
for {
select {
case <-ticker.C:
mb.performCleanup()
case <-mb.ctx.Done():
return
}
}
}
func (mb *UniversalMessageBus) metricsLoop() {
ticker := time.NewTicker(time.Minute)
defer ticker.Stop()
for {
select {
case <-ticker.C:
mb.updateMetrics()
case <-mb.ctx.Done():
return
}
}
}
func (mb *UniversalMessageBus) performHealthCheck() {
// Check all transports, keyed by transport type
for transportType, transport := range mb.transports {
health := transport.Health()
mb.metrics.RecordGauge(fmt.Sprintf("transport_%s_healthy", transportType),
map[string]float64{"healthy": 1, "unhealthy": 0, "degraded": 0.5}[health.Status])
}
}
func (mb *UniversalMessageBus) performCleanup() {
// Clean up processed messages in topics
mb.mu.RLock()
topics := make([]*Topic, 0, len(mb.topics))
for _, topic := range mb.topics {
topics = append(topics, topic)
}
mb.mu.RUnlock()
for _, topic := range topics {
topic.mu.Lock()
mb.applyRetentionPolicy(topic)
topic.mu.Unlock()
}
// Clean up DLQ
mb.dlq.Cleanup(time.Hour * 24) // Clean messages older than 24 hours
}
func (mb *UniversalMessageBus) updateMetrics() {
// Update throughput metrics
publishedCount := mb.getMetricInt64("messages_published_total")
if publishedCount > 0 {
// Calculate per-second rate (simplified)
mb.metrics.RecordGauge("throughput_per_second", float64(publishedCount)/60.0)
}
// Update error rate
errorCount := mb.getMetricInt64("send_errors_total")
totalCount := publishedCount
if totalCount > 0 {
errorRate := float64(errorCount) / float64(totalCount)
mb.metrics.RecordGauge("error_rate", errorRate)
}
}
func (mb *UniversalMessageBus) getMetricInt64(key string) int64 {
if val, ok := mb.metrics.Get(key).(int64); ok {
return val
}
return 0
}
func (mb *UniversalMessageBus) getMetricFloat64(key string) float64 {
if val, ok := mb.metrics.Get(key).(float64); ok {
return val
}
return 0
}
func (mb *UniversalMessageBus) getMetricDuration(key string) time.Duration {
if val, ok := mb.metrics.Get(key).(time.Duration); ok {
return val
}
return 0
}
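Putting the pieces together, a minimal end-to-end sketch of the bus lifecycle (illustrative only; configuration values, topic names, and payloads are assumptions):

func exampleBusLifecycle(ctx context.Context) error {
	bus := NewUniversalMessageBus(MessageBusConfig{
		DefaultTransport:    TransportMemory,
		EnableDLQ:           true,
		HealthCheckInterval: 30 * time.Second, // tickers require a non-zero interval
		CleanupInterval:     time.Minute,
	})
	if err := bus.Start(ctx); err != nil {
		return err
	}
	defer bus.Stop(ctx)

	if _, err := bus.Subscribe("arb.signals", func(ctx context.Context, msg *Message) error {
		fmt.Printf("signal: %v\n", msg.Data)
		return nil
	}); err != nil {
		return err
	}
	return bus.Publish(ctx, NewMessage(MessageTypeEvent, "arb.signals", "scanner", "BTC/USDT spread 0.4%"))
}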


@@ -0,0 +1,622 @@
package transport
import (
"context"
"encoding/json"
"fmt"
"io/ioutil"
"os"
"path/filepath"
"sort"
"sync"
"time"
)
// FilePersistenceLayer implements file-based message persistence
type FilePersistenceLayer struct {
basePath string
maxFileSize int64
maxFiles int
compression bool
encryption EncryptionConfig
mu sync.RWMutex
}
// EncryptionConfig configures message encryption at rest
type EncryptionConfig struct {
Enabled bool
Algorithm string
Key []byte
}
// PersistedMessage represents a message stored on disk
type PersistedMessage struct {
ID string `json:"id"`
Topic string `json:"topic"`
Message *Message `json:"message"`
Stored time.Time `json:"stored"`
Metadata map[string]interface{} `json:"metadata"`
Encrypted bool `json:"encrypted"`
}
// PersistenceMetrics tracks persistence layer statistics
type PersistenceMetrics struct {
MessagesStored int64
MessagesRetrieved int64
MessagesDeleted int64
StorageSize int64
FileCount int
LastCleanup time.Time
Errors int64
}
// NewFilePersistenceLayer creates a new file-based persistence layer
func NewFilePersistenceLayer(basePath string) *FilePersistenceLayer {
return &FilePersistenceLayer{
basePath: basePath,
maxFileSize: 100 * 1024 * 1024, // 100MB default
maxFiles: 1000,
compression: false,
}
}
// SetMaxFileSize configures the maximum file size
func (fpl *FilePersistenceLayer) SetMaxFileSize(size int64) {
fpl.mu.Lock()
defer fpl.mu.Unlock()
fpl.maxFileSize = size
}
// SetMaxFiles configures the maximum number of files
func (fpl *FilePersistenceLayer) SetMaxFiles(count int) {
fpl.mu.Lock()
defer fpl.mu.Unlock()
fpl.maxFiles = count
}
// EnableCompression enables/disables compression
func (fpl *FilePersistenceLayer) EnableCompression(enabled bool) {
fpl.mu.Lock()
defer fpl.mu.Unlock()
fpl.compression = enabled
}
// SetEncryption configures encryption settings
func (fpl *FilePersistenceLayer) SetEncryption(config EncryptionConfig) {
fpl.mu.Lock()
defer fpl.mu.Unlock()
fpl.encryption = config
}
// Store persists a message to disk
func (fpl *FilePersistenceLayer) Store(msg *Message) error {
fpl.mu.Lock()
defer fpl.mu.Unlock()
// Create directory if it doesn't exist
topicDir := filepath.Join(fpl.basePath, msg.Topic)
if err := os.MkdirAll(topicDir, 0755); err != nil {
return fmt.Errorf("failed to create topic directory: %w", err)
}
// Create persisted message
persistedMsg := &PersistedMessage{
ID: msg.ID,
Topic: msg.Topic,
Message: msg,
Stored: time.Now(),
Metadata: make(map[string]interface{}),
Encrypted: fpl.encryption.Enabled, // set before marshaling so the flag is recorded in the stored record
}
// Serialize message
data, err := json.Marshal(persistedMsg)
if err != nil {
return fmt.Errorf("failed to marshal message: %w", err)
}
// Apply encryption if enabled (the Encrypted flag was already set above, before marshaling)
if fpl.encryption.Enabled {
encryptedData, err := fpl.encrypt(data)
if err != nil {
return fmt.Errorf("encryption failed: %w", err)
}
data = encryptedData
}
// Apply compression if enabled
if fpl.compression {
compressedData, err := fpl.compress(data)
if err != nil {
return fmt.Errorf("compression failed: %w", err)
}
data = compressedData
}
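// Note: the read path in parseFileData reverses this order (decompress, then
// decrypt), so both must stay in sync. Because encrypted data rarely compresses
// well, a real encryption implementation would likely compress first.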
// Find appropriate file to write to
filename, err := fpl.getWritableFile(topicDir, len(data))
if err != nil {
return fmt.Errorf("failed to get writable file: %w", err)
}
// Write to file
file, err := os.OpenFile(filename, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
if err != nil {
return fmt.Errorf("failed to open file: %w", err)
}
defer file.Close()
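// Records are framed as "<length>\n<payload>"; parseFileData relies on this
// layout when reading messages back.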
// Write length prefix and data
lengthPrefix := fmt.Sprintf("%d\n", len(data))
if _, err := file.WriteString(lengthPrefix); err != nil {
return fmt.Errorf("failed to write length prefix: %w", err)
}
if _, err := file.Write(data); err != nil {
return fmt.Errorf("failed to write data: %w", err)
}
return nil
}
// Retrieve loads a message from disk by ID
func (fpl *FilePersistenceLayer) Retrieve(id string) (*Message, error) {
fpl.mu.RLock()
defer fpl.mu.RUnlock()
// Search all topic directories
topicDirs, err := fpl.getTopicDirectories()
if err != nil {
return nil, fmt.Errorf("failed to get topic directories: %w", err)
}
for _, topicDir := range topicDirs {
files, err := fpl.getTopicFiles(topicDir)
if err != nil {
continue
}
for _, file := range files {
msg, err := fpl.findMessageInFile(file, id)
if err != nil {
continue
}
if msg != nil {
return msg, nil
}
}
}
return nil, fmt.Errorf("message not found: %s", id)
}
// Delete removes a message from disk by ID
func (fpl *FilePersistenceLayer) Delete(id string) error {
fpl.mu.Lock()
defer fpl.mu.Unlock()
// This is a simplified implementation
// In a production system, you might want to mark messages as deleted
// and compact files periodically instead of rewriting entire files
return fmt.Errorf("delete operation not yet implemented")
}
// List returns messages for a topic with optional limit
func (fpl *FilePersistenceLayer) List(topic string, limit int) ([]*Message, error) {
fpl.mu.RLock()
defer fpl.mu.RUnlock()
topicDir := filepath.Join(fpl.basePath, topic)
if _, err := os.Stat(topicDir); os.IsNotExist(err) {
return []*Message{}, nil
}
files, err := fpl.getTopicFiles(topicDir)
if err != nil {
return nil, fmt.Errorf("failed to get topic files: %w", err)
}
var messages []*Message
count := 0
// Read files newest first (reverse chronological order by modification time)
sort.Slice(files, func(i, j int) bool {
infoI, _ := os.Stat(files[i])
infoJ, _ := os.Stat(files[j])
return infoI.ModTime().After(infoJ.ModTime())
})
for _, file := range files {
fileMessages, err := fpl.readMessagesFromFile(file)
if err != nil {
continue
}
for _, msg := range fileMessages {
messages = append(messages, msg)
count++
if limit > 0 && count >= limit {
break
}
}
if limit > 0 && count >= limit {
break
}
}
return messages, nil
}
// Cleanup removes messages older than maxAge
func (fpl *FilePersistenceLayer) Cleanup(maxAge time.Duration) error {
fpl.mu.Lock()
defer fpl.mu.Unlock()
cutoff := time.Now().Add(-maxAge)
topicDirs, err := fpl.getTopicDirectories()
if err != nil {
return fmt.Errorf("failed to get topic directories: %w", err)
}
for _, topicDir := range topicDirs {
files, err := fpl.getTopicFiles(topicDir)
if err != nil {
continue
}
for _, file := range files {
// Check file modification time
info, err := os.Stat(file)
if err != nil {
continue
}
if info.ModTime().Before(cutoff) {
os.Remove(file)
}
}
// Remove empty topic directories
if isEmpty, _ := fpl.isDirectoryEmpty(topicDir); isEmpty {
os.Remove(topicDir)
}
}
return nil
}
// GetMetrics returns persistence layer metrics
func (fpl *FilePersistenceLayer) GetMetrics() (PersistenceMetrics, error) {
fpl.mu.RLock()
defer fpl.mu.RUnlock()
metrics := PersistenceMetrics{}
// Calculate storage size and file count
err := filepath.Walk(fpl.basePath, func(path string, info os.FileInfo, err error) error {
if err != nil {
return err
}
if !info.IsDir() {
metrics.FileCount++
metrics.StorageSize += info.Size()
}
return nil
})
return metrics, err
}
// Private helper methods
func (fpl *FilePersistenceLayer) getWritableFile(topicDir string, dataSize int) (string, error) {
files, err := fpl.getTopicFiles(topicDir)
if err != nil {
return "", err
}
// Find a file with enough space
for _, file := range files {
info, err := os.Stat(file)
if err != nil {
continue
}
if info.Size()+int64(dataSize) <= fpl.maxFileSize {
return file, nil
}
}
// Create new file
timestamp := time.Now().Format("20060102_150405")
filename := filepath.Join(topicDir, fmt.Sprintf("messages_%s.dat", timestamp))
return filename, nil
}
func (fpl *FilePersistenceLayer) getTopicDirectories() ([]string, error) {
entries, err := ioutil.ReadDir(fpl.basePath)
if err != nil {
return nil, err
}
var dirs []string
for _, entry := range entries {
if entry.IsDir() {
dirs = append(dirs, filepath.Join(fpl.basePath, entry.Name()))
}
}
return dirs, nil
}
func (fpl *FilePersistenceLayer) getTopicFiles(topicDir string) ([]string, error) {
entries, err := ioutil.ReadDir(topicDir)
if err != nil {
return nil, err
}
var files []string
for _, entry := range entries {
if !entry.IsDir() && filepath.Ext(entry.Name()) == ".dat" {
files = append(files, filepath.Join(topicDir, entry.Name()))
}
}
return files, nil
}
func (fpl *FilePersistenceLayer) findMessageInFile(filename, messageID string) (*Message, error) {
file, err := os.Open(filename)
if err != nil {
return nil, err
}
defer file.Close()
data, err := ioutil.ReadAll(file)
if err != nil {
return nil, err
}
// Parse messages from file data
messages, err := fpl.parseFileData(data)
if err != nil {
return nil, err
}
for _, msg := range messages {
if msg.ID == messageID {
return msg, nil
}
}
return nil, nil
}
func (fpl *FilePersistenceLayer) readMessagesFromFile(filename string) ([]*Message, error) {
file, err := os.Open(filename)
if err != nil {
return nil, err
}
defer file.Close()
data, err := ioutil.ReadAll(file)
if err != nil {
return nil, err
}
return fpl.parseFileData(data)
}
func (fpl *FilePersistenceLayer) parseFileData(data []byte) ([]*Message, error) {
var messages []*Message
offset := 0
for offset < len(data) {
// Read length prefix
lengthEnd := -1
for i := offset; i < len(data); i++ {
if data[i] == '\n' {
lengthEnd = i
break
}
}
if lengthEnd == -1 {
break // No more complete messages
}
lengthStr := string(data[offset:lengthEnd])
var messageLength int
if _, err := fmt.Sscanf(lengthStr, "%d", &messageLength); err != nil {
break // Invalid length prefix
}
messageStart := lengthEnd + 1
messageEnd := messageStart + messageLength
if messageEnd > len(data) {
break // Incomplete message
}
messageData := data[messageStart:messageEnd]
// Apply decompression if needed
if fpl.compression {
decompressed, err := fpl.decompress(messageData)
if err != nil {
offset = messageEnd
continue
}
messageData = decompressed
}
// Apply decryption if needed
if fpl.encryption.Enabled {
decrypted, err := fpl.decrypt(messageData)
if err != nil {
offset = messageEnd
continue
}
messageData = decrypted
}
// Parse message
var persistedMsg PersistedMessage
if err := json.Unmarshal(messageData, &persistedMsg); err != nil {
offset = messageEnd
continue
}
messages = append(messages, persistedMsg.Message)
offset = messageEnd
}
return messages, nil
}
func (fpl *FilePersistenceLayer) isDirectoryEmpty(dir string) (bool, error) {
entries, err := ioutil.ReadDir(dir)
if err != nil {
return false, err
}
return len(entries) == 0, nil
}
func (fpl *FilePersistenceLayer) encrypt(data []byte) ([]byte, error) {
// Placeholder for encryption implementation
// In a real implementation, you would use proper encryption libraries
return data, nil
}
func (fpl *FilePersistenceLayer) decrypt(data []byte) ([]byte, error) {
// Placeholder for decryption implementation
return data, nil
}
func (fpl *FilePersistenceLayer) compress(data []byte) ([]byte, error) {
// Placeholder for compression implementation
// In a real implementation, you would use libraries like gzip
return data, nil
}
func (fpl *FilePersistenceLayer) decompress(data []byte) ([]byte, error) {
// Placeholder for decompression implementation
return data, nil
}
// InMemoryPersistenceLayer implements in-memory persistence for testing/development
type InMemoryPersistenceLayer struct {
messages map[string]*Message
topics map[string][]string
mu sync.RWMutex
}
// NewInMemoryPersistenceLayer creates a new in-memory persistence layer
func NewInMemoryPersistenceLayer() *InMemoryPersistenceLayer {
return &InMemoryPersistenceLayer{
messages: make(map[string]*Message),
topics: make(map[string][]string),
}
}
// Store stores a message in memory
func (impl *InMemoryPersistenceLayer) Store(msg *Message) error {
impl.mu.Lock()
defer impl.mu.Unlock()
impl.messages[msg.ID] = msg
if _, exists := impl.topics[msg.Topic]; !exists {
impl.topics[msg.Topic] = make([]string, 0)
}
impl.topics[msg.Topic] = append(impl.topics[msg.Topic], msg.ID)
return nil
}
// Retrieve retrieves a message from memory
func (impl *InMemoryPersistenceLayer) Retrieve(id string) (*Message, error) {
impl.mu.RLock()
defer impl.mu.RUnlock()
msg, exists := impl.messages[id]
if !exists {
return nil, fmt.Errorf("message not found: %s", id)
}
return msg, nil
}
// Delete removes a message from memory
func (impl *InMemoryPersistenceLayer) Delete(id string) error {
impl.mu.Lock()
defer impl.mu.Unlock()
msg, exists := impl.messages[id]
if !exists {
return fmt.Errorf("message not found: %s", id)
}
delete(impl.messages, id)
// Remove from topic index
if messageIDs, exists := impl.topics[msg.Topic]; exists {
for i, msgID := range messageIDs {
if msgID == id {
impl.topics[msg.Topic] = append(messageIDs[:i], messageIDs[i+1:]...)
break
}
}
}
return nil
}
// List returns messages for a topic
func (impl *InMemoryPersistenceLayer) List(topic string, limit int) ([]*Message, error) {
impl.mu.RLock()
defer impl.mu.RUnlock()
messageIDs, exists := impl.topics[topic]
if !exists {
return []*Message{}, nil
}
var messages []*Message
count := 0
for _, msgID := range messageIDs {
if limit > 0 && count >= limit {
break
}
if msg, exists := impl.messages[msgID]; exists {
messages = append(messages, msg)
count++
}
}
return messages, nil
}
// Cleanup removes messages older than maxAge
func (impl *InMemoryPersistenceLayer) Cleanup(maxAge time.Duration) error {
impl.mu.Lock()
defer impl.mu.Unlock()
cutoff := time.Now().Add(-maxAge)
for id, msg := range impl.messages {
if !msg.Timestamp.Before(cutoff) {
continue
}
// Delete inline rather than calling Delete, which would deadlock on impl.mu
delete(impl.messages, id)
if messageIDs, exists := impl.topics[msg.Topic]; exists {
for i, msgID := range messageIDs {
if msgID == id {
impl.topics[msg.Topic] = append(messageIDs[:i], messageIDs[i+1:]...)
break
}
}
}
}
return nil
}

478
pkg/transport/router.go Normal file
View File

@@ -0,0 +1,478 @@
package transport
import (
"fmt"
"math/rand"
"sort"
"sync"
"time"
)
// MessageRouter handles intelligent message routing and transport selection
type MessageRouter struct {
rules []RoutingRule
fallback TransportType
loadBalancer LoadBalancer
mu sync.RWMutex
}
// RoutingRule defines message routing logic
type RoutingRule struct {
ID string
Name string
Condition MessageFilter
Transport TransportType
Priority int
Enabled bool
Created time.Time
LastUsed time.Time
UsageCount int64
}
// RouteMessage selects the appropriate transport for a message
func (mr *MessageRouter) RouteMessage(msg *Message, transports map[TransportType]Transport) (Transport, error) {
// Take the write lock because updateRuleUsage mutates rule statistics below
mr.mu.Lock()
defer mr.mu.Unlock()
// Find matching rules (sorted by priority)
matchingRules := mr.findMatchingRules(msg)
// Try each matching rule in priority order
for _, rule := range matchingRules {
if transport, exists := transports[rule.Transport]; exists {
// Check transport health
if health := transport.Health(); health.Status == "healthy" {
mr.updateRuleUsage(rule.ID)
return transport, nil
}
}
}
// Use load balancer for available transports
if mr.loadBalancer != nil {
availableTransports := mr.getHealthyTransports(transports)
if len(availableTransports) > 0 {
selectedType := mr.loadBalancer.SelectTransport(availableTransports, msg)
if transport, exists := transports[selectedType]; exists {
return transport, nil
}
}
}
// Fall back to default transport
if fallbackTransport, exists := transports[mr.fallback]; exists {
if health := fallbackTransport.Health(); health.Status != "unhealthy" {
return fallbackTransport, nil
}
}
return nil, fmt.Errorf("no available transport for message")
}
// AddRule adds a new routing rule
func (mr *MessageRouter) AddRule(rule RoutingRule) {
mr.mu.Lock()
defer mr.mu.Unlock()
if rule.ID == "" {
rule.ID = fmt.Sprintf("rule_%d", time.Now().UnixNano())
}
rule.Created = time.Now()
rule.Enabled = true
mr.rules = append(mr.rules, rule)
mr.sortRulesByPriority()
}
// RemoveRule removes a routing rule by ID
func (mr *MessageRouter) RemoveRule(ruleID string) bool {
mr.mu.Lock()
defer mr.mu.Unlock()
for i, rule := range mr.rules {
if rule.ID == ruleID {
mr.rules = append(mr.rules[:i], mr.rules[i+1:]...)
return true
}
}
return false
}
// UpdateRule updates an existing routing rule
func (mr *MessageRouter) UpdateRule(ruleID string, updates func(*RoutingRule)) bool {
mr.mu.Lock()
defer mr.mu.Unlock()
for i := range mr.rules {
if mr.rules[i].ID == ruleID {
updates(&mr.rules[i])
mr.sortRulesByPriority()
return true
}
}
return false
}
// GetRules returns all routing rules
func (mr *MessageRouter) GetRules() []RoutingRule {
mr.mu.RLock()
defer mr.mu.RUnlock()
rules := make([]RoutingRule, len(mr.rules))
copy(rules, mr.rules)
return rules
}
// EnableRule enables a routing rule
func (mr *MessageRouter) EnableRule(ruleID string) bool {
return mr.UpdateRule(ruleID, func(rule *RoutingRule) {
rule.Enabled = true
})
}
// DisableRule disables a routing rule
func (mr *MessageRouter) DisableRule(ruleID string) bool {
return mr.UpdateRule(ruleID, func(rule *RoutingRule) {
rule.Enabled = false
})
}
// SetFallbackTransport sets the fallback transport type
func (mr *MessageRouter) SetFallbackTransport(transportType TransportType) {
mr.mu.Lock()
defer mr.mu.Unlock()
mr.fallback = transportType
}
// SetLoadBalancer sets the load balancer
func (mr *MessageRouter) SetLoadBalancer(lb LoadBalancer) {
mr.mu.Lock()
defer mr.mu.Unlock()
mr.loadBalancer = lb
}
// Private helper methods
func (mr *MessageRouter) findMatchingRules(msg *Message) []RoutingRule {
var matching []RoutingRule
for _, rule := range mr.rules {
if rule.Enabled && (rule.Condition == nil || rule.Condition(msg)) {
matching = append(matching, rule)
}
}
return matching
}
func (mr *MessageRouter) sortRulesByPriority() {
sort.Slice(mr.rules, func(i, j int) bool {
return mr.rules[i].Priority > mr.rules[j].Priority
})
}
func (mr *MessageRouter) updateRuleUsage(ruleID string) {
for i := range mr.rules {
if mr.rules[i].ID == ruleID {
mr.rules[i].LastUsed = time.Now()
mr.rules[i].UsageCount++
break
}
}
}
func (mr *MessageRouter) getHealthyTransports(transports map[TransportType]Transport) []TransportType {
var healthy []TransportType
for transportType, transport := range transports {
if health := transport.Health(); health.Status == "healthy" {
healthy = append(healthy, transportType)
}
}
return healthy
}
// LoadBalancer implementations
// RoundRobinLoadBalancer implements round-robin load balancing
type RoundRobinLoadBalancer struct {
counter int64
mu sync.Mutex
}
func NewRoundRobinLoadBalancer() *RoundRobinLoadBalancer {
return &RoundRobinLoadBalancer{}
}
func (lb *RoundRobinLoadBalancer) SelectTransport(transports []TransportType, msg *Message) TransportType {
if len(transports) == 0 {
return ""
}
lb.mu.Lock()
defer lb.mu.Unlock()
selected := transports[lb.counter%int64(len(transports))]
lb.counter++
return selected
}
func (lb *RoundRobinLoadBalancer) UpdateStats(transport TransportType, latency time.Duration, success bool) {
// Round-robin doesn't use stats
}
// WeightedLoadBalancer implements weighted load balancing based on performance
type WeightedLoadBalancer struct {
stats map[TransportType]*TransportStats
mu sync.RWMutex
}
type TransportStats struct {
TotalRequests int64
SuccessRequests int64
TotalLatency time.Duration
LastUpdate time.Time
Weight float64
}
func NewWeightedLoadBalancer() *WeightedLoadBalancer {
return &WeightedLoadBalancer{
stats: make(map[TransportType]*TransportStats),
}
}
func (lb *WeightedLoadBalancer) SelectTransport(transports []TransportType, msg *Message) TransportType {
if len(transports) == 0 {
return ""
}
lb.mu.RLock()
defer lb.mu.RUnlock()
// Calculate weights and select based on weighted random selection
totalWeight := 0.0
weights := make(map[TransportType]float64)
for _, transport := range transports {
weight := lb.calculateWeight(transport)
weights[transport] = weight
totalWeight += weight
}
if totalWeight == 0 {
// Fall back to random selection
return transports[rand.Intn(len(transports))]
}
// Weighted random selection
target := rand.Float64() * totalWeight
current := 0.0
for _, transport := range transports {
current += weights[transport]
if current >= target {
return transport
}
}
// Fallback (shouldn't happen)
return transports[0]
}
func (lb *WeightedLoadBalancer) UpdateStats(transport TransportType, latency time.Duration, success bool) {
lb.mu.Lock()
defer lb.mu.Unlock()
stats, exists := lb.stats[transport]
if !exists {
stats = &TransportStats{
Weight: 1.0, // Default weight
}
lb.stats[transport] = stats
}
stats.TotalRequests++
stats.TotalLatency += latency
stats.LastUpdate = time.Now()
if success {
stats.SuccessRequests++
}
// Recalculate weight based on performance
stats.Weight = lb.calculateWeight(transport)
}
func (lb *WeightedLoadBalancer) calculateWeight(transport TransportType) float64 {
stats, exists := lb.stats[transport]
if !exists {
return 1.0 // Default weight for unknown transports
}
if stats.TotalRequests == 0 {
return 1.0
}
// Calculate success rate
successRate := float64(stats.SuccessRequests) / float64(stats.TotalRequests)
// Calculate average latency
avgLatency := stats.TotalLatency / time.Duration(stats.TotalRequests)
// Weight formula: success rate / (latency factor)
// Lower latency and higher success rate = higher weight
latencyFactor := float64(avgLatency) / float64(time.Millisecond)
if latencyFactor < 1 {
latencyFactor = 1
}
weight := successRate / latencyFactor
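// Example: a transport with a 95% success rate and 5ms average latency gets
// weight 0.95/5 = 0.19, while one with 99% success at 1ms (factor clamped to 1)
// gets 0.99 and is therefore selected far more often.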
// Ensure minimum weight
if weight < 0.1 {
weight = 0.1
}
return weight
}
// LeastLatencyLoadBalancer selects the transport with the lowest latency
type LeastLatencyLoadBalancer struct {
stats map[TransportType]*LatencyStats
mu sync.RWMutex
}
type LatencyStats struct {
RecentLatencies []time.Duration
MaxSamples int
LastUpdate time.Time
}
func NewLeastLatencyLoadBalancer() *LeastLatencyLoadBalancer {
return &LeastLatencyLoadBalancer{
stats: make(map[TransportType]*LatencyStats),
}
}
func (lb *LeastLatencyLoadBalancer) SelectTransport(transports []TransportType, msg *Message) TransportType {
if len(transports) == 0 {
return ""
}
lb.mu.RLock()
defer lb.mu.RUnlock()
bestTransport := transports[0]
bestLatency := time.Hour // Large initial value
for _, transport := range transports {
avgLatency := lb.getAverageLatency(transport)
if avgLatency < bestLatency {
bestLatency = avgLatency
bestTransport = transport
}
}
return bestTransport
}
func (lb *LeastLatencyLoadBalancer) UpdateStats(transport TransportType, latency time.Duration, success bool) {
if !success {
return // Only track successful requests
}
lb.mu.Lock()
defer lb.mu.Unlock()
stats, exists := lb.stats[transport]
if !exists {
stats = &LatencyStats{
RecentLatencies: make([]time.Duration, 0),
MaxSamples: 10, // Keep last 10 samples
}
lb.stats[transport] = stats
}
// Add new latency sample
stats.RecentLatencies = append(stats.RecentLatencies, latency)
// Keep only recent samples
if len(stats.RecentLatencies) > stats.MaxSamples {
stats.RecentLatencies = stats.RecentLatencies[1:]
}
stats.LastUpdate = time.Now()
}
func (lb *LeastLatencyLoadBalancer) getAverageLatency(transport TransportType) time.Duration {
stats, exists := lb.stats[transport]
if !exists || len(stats.RecentLatencies) == 0 {
return time.Millisecond * 100 // Default estimate
}
total := time.Duration(0)
for _, latency := range stats.RecentLatencies {
total += latency
}
return total / time.Duration(len(stats.RecentLatencies))
}
// Common routing rule factory functions
// CreateTopicRule creates a rule based on message topic
func CreateTopicRule(name string, topic string, transport TransportType, priority int) RoutingRule {
return RoutingRule{
Name: name,
Condition: func(msg *Message) bool { return msg.Topic == topic },
Transport: transport,
Priority: priority,
}
}
// CreateTopicPatternRule creates a rule based on topic pattern matching
func CreateTopicPatternRule(name string, pattern string, transport TransportType, priority int) RoutingRule {
return RoutingRule{
Name: name,
Condition: func(msg *Message) bool {
// Simple pattern matching (can be enhanced with regex)
return msg.Topic == pattern ||
(len(pattern) > 0 && pattern[len(pattern)-1] == '*' &&
len(msg.Topic) >= len(pattern)-1 &&
msg.Topic[:len(pattern)-1] == pattern[:len(pattern)-1])
},
Transport: transport,
Priority: priority,
}
}
// CreatePriorityRule creates a rule based on message priority
func CreatePriorityRule(name string, msgPriority MessagePriority, transport TransportType, priority int) RoutingRule {
return RoutingRule{
Name: name,
Condition: func(msg *Message) bool { return msg.Priority == msgPriority },
Transport: transport,
Priority: priority,
}
}
// CreateTypeRule creates a rule based on message type
func CreateTypeRule(name string, msgType MessageType, transport TransportType, priority int) RoutingRule {
return RoutingRule{
Name: name,
Condition: func(msg *Message) bool { return msg.Type == msgType },
Transport: transport,
Priority: priority,
}
}
// CreateSourceRule creates a rule based on message source
func CreateSourceRule(name string, source string, transport TransportType, priority int) RoutingRule {
return RoutingRule{
Name: name,
Condition: func(msg *Message) bool { return msg.Source == source },
Transport: transport,
Priority: priority,
}
}
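// Example (illustrative; the transport type and priority constants are assumed
// to be defined elsewhere in the package):
//
//	router.AddRule(CreateTopicRule("critical-events", "events.critical", TransportTCP, 100))
//	router.AddRule(CreatePriorityRule("high-priority", PriorityHigh, TransportUnixSocket, 50))
//	router.SetFallbackTransport(TransportUnixSocket)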

View File

@@ -0,0 +1,566 @@
package transport
import (
"bytes"
"compress/gzip"
"encoding/json"
"fmt"
"io"
"sync"
)
// SerializationFormat defines supported serialization formats
type SerializationFormat string
const (
SerializationJSON SerializationFormat = "json"
SerializationMsgPack SerializationFormat = "msgpack"
SerializationProtobuf SerializationFormat = "protobuf"
SerializationAvro SerializationFormat = "avro"
)
// CompressionType defines supported compression algorithms
type CompressionType string
const (
CompressionNone CompressionType = "none"
CompressionGZip CompressionType = "gzip"
CompressionLZ4 CompressionType = "lz4"
CompressionSnappy CompressionType = "snappy"
)
// SerializationConfig configures serialization behavior
type SerializationConfig struct {
Format SerializationFormat
Compression CompressionType
Encryption bool
Validation bool
}
// SerializedMessage represents a serialized message with metadata
type SerializedMessage struct {
Format SerializationFormat `json:"format"`
Compression CompressionType `json:"compression"`
Encrypted bool `json:"encrypted"`
Checksum string `json:"checksum"`
Data []byte `json:"data"`
Size int `json:"size"`
Timestamp int64 `json:"timestamp"`
}
// Serializer interface defines serialization operations
type Serializer interface {
Serialize(msg *Message) (*SerializedMessage, error)
Deserialize(serialized *SerializedMessage) (*Message, error)
GetFormat() SerializationFormat
GetConfig() SerializationConfig
}
// SerializationLayer manages multiple serializers and format selection
type SerializationLayer struct {
serializers map[SerializationFormat]Serializer
defaultFormat SerializationFormat
compressor Compressor
encryptor Encryptor
validator Validator
mu sync.RWMutex
}
// Compressor interface for data compression
type Compressor interface {
Compress(data []byte, algorithm CompressionType) ([]byte, error)
Decompress(data []byte, algorithm CompressionType) ([]byte, error)
GetSupportedAlgorithms() []CompressionType
}
// Encryptor interface for data encryption
type Encryptor interface {
Encrypt(data []byte) ([]byte, error)
Decrypt(data []byte) ([]byte, error)
IsEnabled() bool
}
// Validator interface for data validation
type Validator interface {
Validate(msg *Message) error
GenerateChecksum(data []byte) string
VerifyChecksum(data []byte, checksum string) bool
}
// NewSerializationLayer creates a new serialization layer
func NewSerializationLayer() *SerializationLayer {
sl := &SerializationLayer{
serializers: make(map[SerializationFormat]Serializer),
defaultFormat: SerializationJSON,
compressor: NewDefaultCompressor(),
validator: NewDefaultValidator(),
}
// Register default serializers
sl.RegisterSerializer(NewJSONSerializer())
return sl
}
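// Typical flow (illustrative sketch):
//
//	sl := NewSerializationLayer()
//	serialized, err := sl.Serialize(msg)        // JSON by default
//	// ...
//	restored, err := sl.Deserialize(serialized)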
// RegisterSerializer registers a new serializer
func (sl *SerializationLayer) RegisterSerializer(serializer Serializer) {
sl.mu.Lock()
defer sl.mu.Unlock()
sl.serializers[serializer.GetFormat()] = serializer
}
// SetDefaultFormat sets the default serialization format
func (sl *SerializationLayer) SetDefaultFormat(format SerializationFormat) {
sl.mu.Lock()
defer sl.mu.Unlock()
sl.defaultFormat = format
}
// SetCompressor sets the compression handler
func (sl *SerializationLayer) SetCompressor(compressor Compressor) {
sl.mu.Lock()
defer sl.mu.Unlock()
sl.compressor = compressor
}
// SetEncryptor sets the encryption handler
func (sl *SerializationLayer) SetEncryptor(encryptor Encryptor) {
sl.mu.Lock()
defer sl.mu.Unlock()
sl.encryptor = encryptor
}
// SetValidator sets the validation handler
func (sl *SerializationLayer) SetValidator(validator Validator) {
sl.mu.Lock()
defer sl.mu.Unlock()
sl.validator = validator
}
// Serialize serializes a message using the specified or default format
func (sl *SerializationLayer) Serialize(msg *Message, format ...SerializationFormat) (*SerializedMessage, error) {
sl.mu.RLock()
defer sl.mu.RUnlock()
// Determine format to use
selectedFormat := sl.defaultFormat
if len(format) > 0 {
selectedFormat = format[0]
}
// Get serializer
serializer, exists := sl.serializers[selectedFormat]
if !exists {
return nil, fmt.Errorf("unsupported serialization format: %s", selectedFormat)
}
// Validate message if validator is configured
if sl.validator != nil {
if err := sl.validator.Validate(msg); err != nil {
return nil, fmt.Errorf("message validation failed: %w", err)
}
}
// Serialize message
serialized, err := serializer.Serialize(msg)
if err != nil {
return nil, fmt.Errorf("serialization failed: %w", err)
}
// Apply compression if configured
config := serializer.GetConfig()
if config.Compression != CompressionNone && sl.compressor != nil {
compressed, err := sl.compressor.Compress(serialized.Data, config.Compression)
if err != nil {
return nil, fmt.Errorf("compression failed: %w", err)
}
serialized.Data = compressed
serialized.Compression = config.Compression
}
// Apply encryption if configured
if config.Encryption && sl.encryptor != nil && sl.encryptor.IsEnabled() {
encrypted, err := sl.encryptor.Encrypt(serialized.Data)
if err != nil {
return nil, fmt.Errorf("encryption failed: %w", err)
}
serialized.Data = encrypted
serialized.Encrypted = true
}
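// Note: compression is applied before encryption here, and Deserialize reverses
// the order (decrypt, then decompress), so the two paths must stay in sync.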
// Generate checksum
if sl.validator != nil {
serialized.Checksum = sl.validator.GenerateChecksum(serialized.Data)
}
// Update metadata
serialized.Size = len(serialized.Data)
serialized.Timestamp = msg.Timestamp.UnixNano()
return serialized, nil
}
// Deserialize deserializes a message
func (sl *SerializationLayer) Deserialize(serialized *SerializedMessage) (*Message, error) {
sl.mu.RLock()
defer sl.mu.RUnlock()
// Verify checksum if available
if sl.validator != nil && serialized.Checksum != "" {
if !sl.validator.VerifyChecksum(serialized.Data, serialized.Checksum) {
return nil, fmt.Errorf("checksum verification failed")
}
}
data := serialized.Data
// Apply decryption if needed
if serialized.Encrypted && sl.encryptor != nil && sl.encryptor.IsEnabled() {
decrypted, err := sl.encryptor.Decrypt(data)
if err != nil {
return nil, fmt.Errorf("decryption failed: %w", err)
}
data = decrypted
}
// Apply decompression if needed
if serialized.Compression != CompressionNone && sl.compressor != nil {
decompressed, err := sl.compressor.Decompress(data, serialized.Compression)
if err != nil {
return nil, fmt.Errorf("decompression failed: %w", err)
}
data = decompressed
}
// Get serializer
serializer, exists := sl.serializers[serialized.Format]
if !exists {
return nil, fmt.Errorf("unsupported serialization format: %s", serialized.Format)
}
// Create temporary serialized message for deserializer
tempSerialized := &SerializedMessage{
Format: serialized.Format,
Data: data,
}
// Deserialize message
msg, err := serializer.Deserialize(tempSerialized)
if err != nil {
return nil, fmt.Errorf("deserialization failed: %w", err)
}
// Validate deserialized message if validator is configured
if sl.validator != nil {
if err := sl.validator.Validate(msg); err != nil {
return nil, fmt.Errorf("deserialized message validation failed: %w", err)
}
}
return msg, nil
}
// GetSupportedFormats returns all supported serialization formats
func (sl *SerializationLayer) GetSupportedFormats() []SerializationFormat {
sl.mu.RLock()
defer sl.mu.RUnlock()
formats := make([]SerializationFormat, 0, len(sl.serializers))
for format := range sl.serializers {
formats = append(formats, format)
}
return formats
}
// JSONSerializer implements JSON serialization
type JSONSerializer struct {
config SerializationConfig
}
// NewJSONSerializer creates a new JSON serializer
func NewJSONSerializer() *JSONSerializer {
return &JSONSerializer{
config: SerializationConfig{
Format: SerializationJSON,
Compression: CompressionNone,
Encryption: false,
Validation: true,
},
}
}
// SetConfig updates the serializer configuration
func (js *JSONSerializer) SetConfig(config SerializationConfig) {
js.config = config
js.config.Format = SerializationJSON // Ensure format is correct
}
// Serialize serializes a message to JSON
func (js *JSONSerializer) Serialize(msg *Message) (*SerializedMessage, error) {
data, err := json.Marshal(msg)
if err != nil {
return nil, fmt.Errorf("JSON marshal failed: %w", err)
}
return &SerializedMessage{
Format: SerializationJSON,
Compression: CompressionNone,
Encrypted: false,
Data: data,
Size: len(data),
}, nil
}
// Deserialize deserializes a message from JSON
func (js *JSONSerializer) Deserialize(serialized *SerializedMessage) (*Message, error) {
var msg Message
if err := json.Unmarshal(serialized.Data, &msg); err != nil {
return nil, fmt.Errorf("JSON unmarshal failed: %w", err)
}
return &msg, nil
}
// GetFormat returns the serialization format
func (js *JSONSerializer) GetFormat() SerializationFormat {
return SerializationJSON
}
// GetConfig returns the serializer configuration
func (js *JSONSerializer) GetConfig() SerializationConfig {
return js.config
}
// DefaultCompressor implements basic compression operations
type DefaultCompressor struct {
supportedAlgorithms []CompressionType
}
// NewDefaultCompressor creates a new default compressor
func NewDefaultCompressor() *DefaultCompressor {
return &DefaultCompressor{
supportedAlgorithms: []CompressionType{
CompressionNone,
CompressionGZip,
},
}
}
// Compress compresses data using the specified algorithm
func (dc *DefaultCompressor) Compress(data []byte, algorithm CompressionType) ([]byte, error) {
switch algorithm {
case CompressionNone:
return data, nil
case CompressionGZip:
var buf bytes.Buffer
writer := gzip.NewWriter(&buf)
if _, err := writer.Write(data); err != nil {
return nil, fmt.Errorf("gzip write failed: %w", err)
}
if err := writer.Close(); err != nil {
return nil, fmt.Errorf("gzip close failed: %w", err)
}
return buf.Bytes(), nil
default:
return nil, fmt.Errorf("unsupported compression algorithm: %s", algorithm)
}
}
// Decompress decompresses data using the specified algorithm
func (dc *DefaultCompressor) Decompress(data []byte, algorithm CompressionType) ([]byte, error) {
switch algorithm {
case CompressionNone:
return data, nil
case CompressionGZip:
reader, err := gzip.NewReader(bytes.NewReader(data))
if err != nil {
return nil, fmt.Errorf("gzip reader creation failed: %w", err)
}
defer reader.Close()
decompressed, err := io.ReadAll(reader)
if err != nil {
return nil, fmt.Errorf("gzip read failed: %w", err)
}
return decompressed, nil
default:
return nil, fmt.Errorf("unsupported compression algorithm: %s", algorithm)
}
}
// GetSupportedAlgorithms returns supported compression algorithms
func (dc *DefaultCompressor) GetSupportedAlgorithms() []CompressionType {
return dc.supportedAlgorithms
}
// DefaultValidator implements basic message validation
type DefaultValidator struct {
strictMode bool
}
// NewDefaultValidator creates a new default validator
func NewDefaultValidator() *DefaultValidator {
return &DefaultValidator{
strictMode: false,
}
}
// SetStrictMode enables/disables strict validation
func (dv *DefaultValidator) SetStrictMode(enabled bool) {
dv.strictMode = enabled
}
// Validate validates a message
func (dv *DefaultValidator) Validate(msg *Message) error {
if msg == nil {
return fmt.Errorf("message is nil")
}
if msg.ID == "" {
return fmt.Errorf("message ID is empty")
}
if msg.Topic == "" {
return fmt.Errorf("message topic is empty")
}
if msg.Source == "" {
return fmt.Errorf("message source is empty")
}
if msg.Type == "" {
return fmt.Errorf("message type is empty")
}
if dv.strictMode {
if msg.Data == nil {
return fmt.Errorf("message data is nil")
}
if msg.Timestamp.IsZero() {
return fmt.Errorf("message timestamp is zero")
}
}
return nil
}
// GenerateChecksum generates a simple checksum for data
func (dv *DefaultValidator) GenerateChecksum(data []byte) string {
// Simple checksum implementation
// In production, use a proper hash function like SHA-256
var sum uint32
for _, b := range data {
sum += uint32(b)
}
return fmt.Sprintf("%08x", sum)
}
// VerifyChecksum verifies a checksum
func (dv *DefaultValidator) VerifyChecksum(data []byte, checksum string) bool {
return dv.GenerateChecksum(data) == checksum
}
// NoOpEncryptor implements a no-operation encryptor for testing
type NoOpEncryptor struct {
enabled bool
}
// NewNoOpEncryptor creates a new no-op encryptor
func NewNoOpEncryptor() *NoOpEncryptor {
return &NoOpEncryptor{enabled: false}
}
// SetEnabled enables/disables the encryptor
func (noe *NoOpEncryptor) SetEnabled(enabled bool) {
noe.enabled = enabled
}
// Encrypt returns data unchanged
func (noe *NoOpEncryptor) Encrypt(data []byte) ([]byte, error) {
return data, nil
}
// Decrypt returns data unchanged
func (noe *NoOpEncryptor) Decrypt(data []byte) ([]byte, error) {
return data, nil
}
// IsEnabled returns whether encryption is enabled
func (noe *NoOpEncryptor) IsEnabled() bool {
return noe.enabled
}
// SerializationMetrics tracks serialization performance
type SerializationMetrics struct {
SerializedMessages int64
DeserializedMessages int64
SerializationErrors int64
CompressionRatio float64
AverageMessageSize int64
TotalDataProcessed int64
}
// MetricsCollector collects serialization metrics
type MetricsCollector struct {
metrics SerializationMetrics
mu sync.RWMutex
}
// NewMetricsCollector creates a new metrics collector
func NewMetricsCollector() *MetricsCollector {
return &MetricsCollector{}
}
// RecordSerialization records a serialization operation
func (mc *MetricsCollector) RecordSerialization(originalSize, serializedSize int) {
mc.mu.Lock()
defer mc.mu.Unlock()
mc.metrics.SerializedMessages++
mc.metrics.TotalDataProcessed += int64(originalSize)
// Update compression ratio
if originalSize > 0 {
ratio := float64(serializedSize) / float64(originalSize)
mc.metrics.CompressionRatio = (mc.metrics.CompressionRatio + ratio) / 2
}
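// The 50/50 blend with the previous value behaves like a rough exponentially
// weighted moving average rather than a true mean over all messages.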
// Update average message size
mc.metrics.AverageMessageSize = mc.metrics.TotalDataProcessed / mc.metrics.SerializedMessages
}
// RecordDeserialization records a deserialization operation
func (mc *MetricsCollector) RecordDeserialization() {
mc.mu.Lock()
defer mc.mu.Unlock()
mc.metrics.DeserializedMessages++
}
// RecordError records a serialization error
func (mc *MetricsCollector) RecordError() {
mc.mu.Lock()
defer mc.mu.Unlock()
mc.metrics.SerializationErrors++
}
// GetMetrics returns current metrics
func (mc *MetricsCollector) GetMetrics() SerializationMetrics {
mc.mu.RLock()
defer mc.mu.RUnlock()
return mc.metrics
}
// Reset resets all metrics
func (mc *MetricsCollector) Reset() {
mc.mu.Lock()
defer mc.mu.Unlock()
mc.metrics = SerializationMetrics{}
}

View File

@@ -0,0 +1,490 @@
package transport
import (
"context"
"crypto/tls"
"encoding/json"
"fmt"
"math/rand"
"net"
"sync"
"time"
)
// TCPTransport implements TCP transport for remote communication
type TCPTransport struct {
address string
port int
listener net.Listener
connections map[string]net.Conn
metrics TransportMetrics
connected bool
isServer bool
receiveChan chan *Message
tlsConfig *tls.Config
mu sync.RWMutex
ctx context.Context
cancel context.CancelFunc
retryConfig RetryConfig
}
// NewTCPTransport creates a new TCP transport
func NewTCPTransport(address string, port int, isServer bool) *TCPTransport {
ctx, cancel := context.WithCancel(context.Background())
return &TCPTransport{
address: address,
port: port,
connections: make(map[string]net.Conn),
metrics: TransportMetrics{},
isServer: isServer,
receiveChan: make(chan *Message, 1000),
ctx: ctx,
cancel: cancel,
retryConfig: RetryConfig{
MaxRetries: 3,
InitialDelay: time.Second,
MaxDelay: 30 * time.Second,
BackoffFactor: 2.0,
Jitter: true,
},
}
}
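// Example (illustrative; addresses and ports are placeholders):
//
//	server := NewTCPTransport("0.0.0.0", 9400, true)   // listening side
//	client := NewTCPTransport("10.0.0.5", 9400, false) // dialing side
//	_ = server.Connect(ctx)
//	_ = client.Connect(ctx)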
// SetTLSConfig configures TLS for secure communication
func (tt *TCPTransport) SetTLSConfig(config *tls.Config) {
tt.tlsConfig = config
}
// SetRetryConfig configures retry behavior
func (tt *TCPTransport) SetRetryConfig(config RetryConfig) {
tt.retryConfig = config
}
// Connect establishes the TCP connection
func (tt *TCPTransport) Connect(ctx context.Context) error {
tt.mu.Lock()
defer tt.mu.Unlock()
if tt.connected {
return nil
}
if tt.isServer {
return tt.startServer()
} else {
return tt.connectToServer(ctx)
}
}
// Disconnect closes the TCP connection
func (tt *TCPTransport) Disconnect(ctx context.Context) error {
tt.mu.Lock()
defer tt.mu.Unlock()
if !tt.connected {
return nil
}
tt.cancel()
if tt.isServer && tt.listener != nil {
tt.listener.Close()
}
// Close all connections
for id, conn := range tt.connections {
conn.Close()
delete(tt.connections, id)
}
close(tt.receiveChan)
tt.connected = false
tt.metrics.Connections = 0
return nil
}
// Send transmits a message through TCP
func (tt *TCPTransport) Send(ctx context.Context, msg *Message) error {
start := time.Now()
tt.mu.RLock()
if !tt.connected {
tt.mu.RUnlock()
tt.metrics.Errors++
return fmt.Errorf("transport not connected")
}
// Serialize message
data, err := json.Marshal(msg)
if err != nil {
tt.mu.RUnlock()
tt.metrics.Errors++
return fmt.Errorf("failed to marshal message: %w", err)
}
// Add length prefix for framing
frame := fmt.Sprintf("%d\n%s", len(data), data)
frameBytes := []byte(frame)
// Send to all connections with retry
var sendErr error
connectionCount := len(tt.connections)
tt.mu.RUnlock()
if connectionCount == 0 {
tt.metrics.Errors++
return fmt.Errorf("no active connections")
}
tt.mu.RLock()
for connID, conn := range tt.connections {
if err := tt.sendWithRetry(ctx, conn, frameBytes); err != nil {
sendErr = fmt.Errorf("failed to send to connection %s: %w", connID, err)
// Remove failed connection
go tt.removeConnection(connID)
}
}
tt.mu.RUnlock()
if sendErr == nil {
tt.updateSendMetrics(msg, time.Since(start))
} else {
tt.metrics.Errors++
}
return sendErr
}
// Receive returns a channel for receiving messages
func (tt *TCPTransport) Receive(ctx context.Context) (<-chan *Message, error) {
tt.mu.RLock()
defer tt.mu.RUnlock()
if !tt.connected {
return nil, fmt.Errorf("transport not connected")
}
return tt.receiveChan, nil
}
// Health returns the health status of the transport
func (tt *TCPTransport) Health() ComponentHealth {
tt.mu.RLock()
defer tt.mu.RUnlock()
status := "unhealthy"
var responseTime time.Duration
if tt.connected {
if len(tt.connections) > 0 {
status = "healthy"
responseTime = time.Millisecond * 10 // Estimate for TCP
} else {
status = "degraded" // Connected but no active connections
}
}
return ComponentHealth{
Status: status,
LastCheck: time.Now(),
ResponseTime: responseTime,
ErrorCount: tt.metrics.Errors,
}
}
// GetMetrics returns transport-specific metrics
func (tt *TCPTransport) GetMetrics() TransportMetrics {
tt.mu.RLock()
defer tt.mu.RUnlock()
return TransportMetrics{
BytesSent: tt.metrics.BytesSent,
BytesReceived: tt.metrics.BytesReceived,
MessagesSent: tt.metrics.MessagesSent,
MessagesReceived: tt.metrics.MessagesReceived,
Connections: len(tt.connections),
Errors: tt.metrics.Errors,
Latency: tt.metrics.Latency,
}
}
// Private helper methods
func (tt *TCPTransport) startServer() error {
addr := fmt.Sprintf("%s:%d", tt.address, tt.port)
var listener net.Listener
var err error
if tt.tlsConfig != nil {
listener, err = tls.Listen("tcp", addr, tt.tlsConfig)
} else {
listener, err = net.Listen("tcp", addr)
}
if err != nil {
return fmt.Errorf("failed to listen on %s: %w", addr, err)
}
tt.listener = listener
tt.connected = true
// Start accepting connections
go tt.acceptConnections()
return nil
}
func (tt *TCPTransport) connectToServer(ctx context.Context) error {
addr := fmt.Sprintf("%s:%d", tt.address, tt.port)
var conn net.Conn
var err error
// Retry connection with exponential backoff
delay := tt.retryConfig.InitialDelay
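// With the default RetryConfig above (1s initial delay, 2.0 backoff factor,
// 30s cap, 3 retries), the waits between attempts are roughly 1s, 2s, 4s plus jitter.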
for attempt := 0; attempt <= tt.retryConfig.MaxRetries; attempt++ {
if tt.tlsConfig != nil {
conn, err = tls.Dial("tcp", addr, tt.tlsConfig)
} else {
conn, err = net.Dial("tcp", addr)
}
if err == nil {
break
}
if attempt == tt.retryConfig.MaxRetries {
return fmt.Errorf("failed to connect to %s after %d attempts: %w", addr, attempt+1, err)
}
// Wait with exponential backoff
select {
case <-time.After(delay):
delay = time.Duration(float64(delay) * tt.retryConfig.BackoffFactor)
if delay > tt.retryConfig.MaxDelay {
delay = tt.retryConfig.MaxDelay
}
// Add up to ±10% jitter if enabled
if tt.retryConfig.Jitter {
jitter := time.Duration((rand.Float64()*2 - 1) * float64(delay) * 0.1)
delay += jitter
}
case <-ctx.Done():
return ctx.Err()
}
}
connID := fmt.Sprintf("client_%d", time.Now().UnixNano())
tt.connections[connID] = conn
tt.connected = true
tt.metrics.Connections = 1
// Start receiving from server
go tt.handleConnection(connID, conn)
return nil
}
func (tt *TCPTransport) acceptConnections() {
for {
select {
case <-tt.ctx.Done():
return
default:
conn, err := tt.listener.Accept()
if err != nil {
if tt.ctx.Err() != nil {
return // Context cancelled
}
tt.metrics.Errors++
continue
}
connID := fmt.Sprintf("server_%d", time.Now().UnixNano())
tt.mu.Lock()
tt.connections[connID] = conn
tt.metrics.Connections = len(tt.connections)
tt.mu.Unlock()
go tt.handleConnection(connID, conn)
}
}
}
func (tt *TCPTransport) handleConnection(connID string, conn net.Conn) {
defer tt.removeConnection(connID)
buffer := make([]byte, 4096)
var messageBuffer []byte
for {
select {
case <-tt.ctx.Done():
return
default:
conn.SetReadDeadline(time.Now().Add(30 * time.Second))
n, err := conn.Read(buffer)
if err != nil {
if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
continue // Continue on timeout
}
return // Connection closed or error
}
messageBuffer = append(messageBuffer, buffer[:n]...)
// Process complete messages
for {
msg, remaining, err := tt.extractMessage(messageBuffer)
if err != nil {
return // Invalid message format
}
if msg == nil {
break // No complete message yet
}
// Deliver message
select {
case tt.receiveChan <- msg:
tt.updateReceiveMetrics(msg)
case <-tt.ctx.Done():
return
default:
// Channel full, drop message
tt.metrics.Errors++
}
messageBuffer = remaining
}
}
}
}
func (tt *TCPTransport) sendWithRetry(ctx context.Context, conn net.Conn, data []byte) error {
delay := tt.retryConfig.InitialDelay
for attempt := 0; attempt <= tt.retryConfig.MaxRetries; attempt++ {
conn.SetWriteDeadline(time.Now().Add(10 * time.Second))
_, err := conn.Write(data)
if err == nil {
return nil
}
if attempt == tt.retryConfig.MaxRetries {
return err
}
// Wait with exponential backoff
select {
case <-time.After(delay):
delay = time.Duration(float64(delay) * tt.retryConfig.BackoffFactor)
if delay > tt.retryConfig.MaxDelay {
delay = tt.retryConfig.MaxDelay
}
case <-ctx.Done():
return ctx.Err()
}
}
return fmt.Errorf("max retries exceeded")
}
func (tt *TCPTransport) removeConnection(connID string) {
tt.mu.Lock()
defer tt.mu.Unlock()
if conn, exists := tt.connections[connID]; exists {
conn.Close()
delete(tt.connections, connID)
tt.metrics.Connections = len(tt.connections)
}
}
func (tt *TCPTransport) extractMessage(buffer []byte) (*Message, []byte, error) {
// Look for length prefix (format: "length\nmessage_data")
newlineIndex := -1
for i, b := range buffer {
if b == '\n' {
newlineIndex = i
break
}
}
if newlineIndex == -1 {
return nil, buffer, nil // No complete length prefix yet
}
// Parse length
lengthStr := string(buffer[:newlineIndex])
var messageLength int
if _, err := fmt.Sscanf(lengthStr, "%d", &messageLength); err != nil {
return nil, nil, fmt.Errorf("invalid length prefix: %s", lengthStr)
}
// Check if we have the complete message
messageStart := newlineIndex + 1
messageEnd := messageStart + messageLength
if len(buffer) < messageEnd {
return nil, buffer, nil // Incomplete message
}
// Extract and parse message
messageData := buffer[messageStart:messageEnd]
var msg Message
if err := json.Unmarshal(messageData, &msg); err != nil {
return nil, nil, fmt.Errorf("failed to unmarshal message: %w", err)
}
// Return message and remaining buffer
remaining := buffer[messageEnd:]
return &msg, remaining, nil
}
func (tt *TCPTransport) updateSendMetrics(msg *Message, latency time.Duration) {
tt.mu.Lock()
defer tt.mu.Unlock()
tt.metrics.MessagesSent++
tt.metrics.Latency = latency
// Estimate message size
messageSize := int64(len(msg.ID) + len(msg.Topic) + len(msg.Source))
if msg.Data != nil {
messageSize += int64(len(fmt.Sprintf("%v", msg.Data)))
}
tt.metrics.BytesSent += messageSize
}
func (tt *TCPTransport) updateReceiveMetrics(msg *Message) {
tt.mu.Lock()
defer tt.mu.Unlock()
tt.metrics.MessagesReceived++
// Estimate message size
messageSize := int64(len(msg.ID) + len(msg.Topic) + len(msg.Source))
if msg.Data != nil {
messageSize += int64(len(fmt.Sprintf("%v", msg.Data)))
}
tt.metrics.BytesReceived += messageSize
}
// GetAddress returns the transport address (for testing/debugging)
func (tt *TCPTransport) GetAddress() string {
return fmt.Sprintf("%s:%d", tt.address, tt.port)
}
// GetConnectionCount returns the number of active connections
func (tt *TCPTransport) GetConnectionCount() int {
tt.mu.RLock()
defer tt.mu.RUnlock()
return len(tt.connections)
}
// IsSecure returns whether TLS is enabled
func (tt *TCPTransport) IsSecure() bool {
return tt.tlsConfig != nil
}

View File

@@ -0,0 +1,399 @@
package transport
import (
"context"
"encoding/json"
"fmt"
"net"
"os"
"sync"
"time"
)
// UnixSocketTransport implements Unix socket transport for local IPC
type UnixSocketTransport struct {
socketPath string
listener net.Listener
connections map[string]net.Conn
metrics TransportMetrics
connected bool
isServer bool
receiveChan chan *Message
mu sync.RWMutex
ctx context.Context
cancel context.CancelFunc
}
// NewUnixSocketTransport creates a new Unix socket transport
func NewUnixSocketTransport(socketPath string, isServer bool) *UnixSocketTransport {
ctx, cancel := context.WithCancel(context.Background())
return &UnixSocketTransport{
socketPath: socketPath,
connections: make(map[string]net.Conn),
metrics: TransportMetrics{},
isServer: isServer,
receiveChan: make(chan *Message, 1000),
ctx: ctx,
cancel: cancel,
}
}
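// Example (illustrative; the socket path is a placeholder):
//
//	server := NewUnixSocketTransport("/tmp/bus.sock", true)
//	client := NewUnixSocketTransport("/tmp/bus.sock", false)
//	_ = server.Connect(ctx)
//	_ = client.Connect(ctx)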
// Connect establishes the Unix socket connection
func (ut *UnixSocketTransport) Connect(ctx context.Context) error {
ut.mu.Lock()
defer ut.mu.Unlock()
if ut.connected {
return nil
}
if ut.isServer {
return ut.startServer()
} else {
return ut.connectToServer()
}
}
// Disconnect closes the Unix socket connection
func (ut *UnixSocketTransport) Disconnect(ctx context.Context) error {
ut.mu.Lock()
defer ut.mu.Unlock()
if !ut.connected {
return nil
}
ut.cancel()
if ut.isServer && ut.listener != nil {
ut.listener.Close()
// Remove socket file
os.Remove(ut.socketPath)
}
// Close all connections
for id, conn := range ut.connections {
conn.Close()
delete(ut.connections, id)
}
close(ut.receiveChan)
ut.connected = false
ut.metrics.Connections = 0
return nil
}
// Send transmits a message through the Unix socket
func (ut *UnixSocketTransport) Send(ctx context.Context, msg *Message) error {
start := time.Now()
ut.mu.RLock()
if !ut.connected {
ut.mu.RUnlock()
ut.metrics.Errors++
return fmt.Errorf("transport not connected")
}
// Serialize message
data, err := json.Marshal(msg)
if err != nil {
ut.mu.RUnlock()
ut.metrics.Errors++
return fmt.Errorf("failed to marshal message: %w", err)
}
// Add length prefix for framing
frame := fmt.Sprintf("%d\n%s", len(data), data)
frameBytes := []byte(frame)
// Send to all connections
var sendErr error
connectionCount := len(ut.connections)
ut.mu.RUnlock()
if connectionCount == 0 {
ut.metrics.Errors++
return fmt.Errorf("no active connections")
}
ut.mu.RLock()
for connID, conn := range ut.connections {
if err := ut.sendToConnection(conn, frameBytes); err != nil {
sendErr = fmt.Errorf("failed to send to connection %s: %w", connID, err)
// Remove failed connection
go ut.removeConnection(connID)
}
}
ut.mu.RUnlock()
if sendErr == nil {
ut.updateSendMetrics(msg, time.Since(start))
} else {
ut.metrics.Errors++
}
return sendErr
}
// Receive returns a channel for receiving messages
func (ut *UnixSocketTransport) Receive(ctx context.Context) (<-chan *Message, error) {
ut.mu.RLock()
defer ut.mu.RUnlock()
if !ut.connected {
return nil, fmt.Errorf("transport not connected")
}
return ut.receiveChan, nil
}
// Health returns the health status of the transport
func (ut *UnixSocketTransport) Health() ComponentHealth {
ut.mu.RLock()
defer ut.mu.RUnlock()
status := "unhealthy"
if ut.connected {
if len(ut.connections) > 0 {
status = "healthy"
} else {
status = "degraded" // Connected but no active connections
}
}
return ComponentHealth{
Status: status,
LastCheck: time.Now(),
ResponseTime: time.Millisecond, // Fast for local sockets
ErrorCount: ut.metrics.Errors,
}
}
// GetMetrics returns transport-specific metrics
func (ut *UnixSocketTransport) GetMetrics() TransportMetrics {
ut.mu.RLock()
defer ut.mu.RUnlock()
return TransportMetrics{
BytesSent: ut.metrics.BytesSent,
BytesReceived: ut.metrics.BytesReceived,
MessagesSent: ut.metrics.MessagesSent,
MessagesReceived: ut.metrics.MessagesReceived,
Connections: len(ut.connections),
Errors: ut.metrics.Errors,
Latency: ut.metrics.Latency,
}
}
// Private helper methods
func (ut *UnixSocketTransport) startServer() error {
// Remove existing socket file
os.Remove(ut.socketPath)
listener, err := net.Listen("unix", ut.socketPath)
if err != nil {
return fmt.Errorf("failed to listen on socket %s: %w", ut.socketPath, err)
}
ut.listener = listener
ut.connected = true
// Start accepting connections
go ut.acceptConnections()
return nil
}
func (ut *UnixSocketTransport) connectToServer() error {
conn, err := net.Dial("unix", ut.socketPath)
if err != nil {
return fmt.Errorf("failed to connect to socket %s: %w", ut.socketPath, err)
}
connID := fmt.Sprintf("client_%d", time.Now().UnixNano())
ut.connections[connID] = conn
ut.connected = true
ut.metrics.Connections = 1
// Start receiving from server
go ut.handleConnection(connID, conn)
return nil
}
func (ut *UnixSocketTransport) acceptConnections() {
for {
select {
case <-ut.ctx.Done():
return
default:
conn, err := ut.listener.Accept()
if err != nil {
if ut.ctx.Err() != nil {
return // Context cancelled
}
ut.metrics.Errors++
continue
}
connID := fmt.Sprintf("server_%d", time.Now().UnixNano())
ut.mu.Lock()
ut.connections[connID] = conn
ut.metrics.Connections = len(ut.connections)
ut.mu.Unlock()
go ut.handleConnection(connID, conn)
}
}
}
func (ut *UnixSocketTransport) handleConnection(connID string, conn net.Conn) {
defer ut.removeConnection(connID)
buffer := make([]byte, 4096)
var messageBuffer []byte
for {
select {
case <-ut.ctx.Done():
return
default:
conn.SetReadDeadline(time.Now().Add(time.Second)) // Short timeout so ctx cancellation is noticed promptly
n, err := conn.Read(buffer)
if err != nil {
if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
continue // Continue on timeout
}
return // Connection closed or error
}
messageBuffer = append(messageBuffer, buffer[:n]...)
// Process complete messages
for {
msg, remaining, err := ut.extractMessage(messageBuffer)
if err != nil {
return // Invalid message format
}
if msg == nil {
break // No complete message yet
}
// Deliver message
select {
case ut.receiveChan <- msg:
ut.updateReceiveMetrics(msg)
case <-ut.ctx.Done():
return
default:
// Channel full, drop message
ut.metrics.Errors++
}
messageBuffer = remaining
}
}
}
}
func (ut *UnixSocketTransport) sendToConnection(conn net.Conn, data []byte) error {
conn.SetWriteDeadline(time.Now().Add(5 * time.Second))
_, err := conn.Write(data)
return err
}
func (ut *UnixSocketTransport) removeConnection(connID string) {
ut.mu.Lock()
defer ut.mu.Unlock()
if conn, exists := ut.connections[connID]; exists {
conn.Close()
delete(ut.connections, connID)
ut.metrics.Connections = len(ut.connections)
}
}
func (ut *UnixSocketTransport) extractMessage(buffer []byte) (*Message, []byte, error) {
// Look for length prefix (format: "length\nmessage_data")
newlineIndex := -1
for i, b := range buffer {
if b == '\n' {
newlineIndex = i
break
}
}
if newlineIndex == -1 {
return nil, buffer, nil // No complete length prefix yet
}
// Parse length
lengthStr := string(buffer[:newlineIndex])
var messageLength int
if _, err := fmt.Sscanf(lengthStr, "%d", &messageLength); err != nil {
return nil, nil, fmt.Errorf("invalid length prefix: %s", lengthStr)
}
// Check if we have the complete message
messageStart := newlineIndex + 1
messageEnd := messageStart + messageLength
if len(buffer) < messageEnd {
return nil, buffer, nil // Incomplete message
}
// Extract and parse message
messageData := buffer[messageStart:messageEnd]
var msg Message
if err := json.Unmarshal(messageData, &msg); err != nil {
return nil, nil, fmt.Errorf("failed to unmarshal message: %w", err)
}
// Return message and remaining buffer
remaining := buffer[messageEnd:]
return &msg, remaining, nil
}
func (ut *UnixSocketTransport) updateSendMetrics(msg *Message, latency time.Duration) {
ut.mu.Lock()
defer ut.mu.Unlock()
ut.metrics.MessagesSent++
ut.metrics.Latency = latency
// Estimate message size
messageSize := int64(len(msg.ID) + len(msg.Topic) + len(msg.Source))
if msg.Data != nil {
messageSize += int64(len(fmt.Sprintf("%v", msg.Data)))
}
ut.metrics.BytesSent += messageSize
}
func (ut *UnixSocketTransport) updateReceiveMetrics(msg *Message) {
ut.mu.Lock()
defer ut.mu.Unlock()
ut.metrics.MessagesReceived++
// Estimate message size
messageSize := int64(len(msg.ID) + len(msg.Topic) + len(msg.Source))
if msg.Data != nil {
messageSize += int64(len(fmt.Sprintf("%v", msg.Data)))
}
ut.metrics.BytesReceived += messageSize
}
// GetSocketPath returns the socket path (for testing/debugging)
func (ut *UnixSocketTransport) GetSocketPath() string {
return ut.socketPath
}
// GetConnectionCount returns the number of active connections
func (ut *UnixSocketTransport) GetConnectionCount() int {
ut.mu.RLock()
defer ut.mu.RUnlock()
return len(ut.connections)
}

View File

@@ -0,0 +1,427 @@
package transport
import (
"context"
"encoding/json"
"fmt"
"net/http"
"sync"
"time"
"github.com/gorilla/websocket"
)
// WebSocketTransport implements WebSocket transport for real-time monitoring
type WebSocketTransport struct {
address string
port int
path string
upgrader websocket.Upgrader
connections map[string]*websocket.Conn
metrics TransportMetrics
connected bool
isServer bool
receiveChan chan *Message
server *http.Server
mu sync.RWMutex
ctx context.Context
cancel context.CancelFunc
pingInterval time.Duration
pongTimeout time.Duration
}
// NewWebSocketTransport creates a new WebSocket transport
func NewWebSocketTransport(address string, port int, path string, isServer bool) *WebSocketTransport {
ctx, cancel := context.WithCancel(context.Background())
return &WebSocketTransport{
address: address,
port: port,
path: path,
connections: make(map[string]*websocket.Conn),
metrics: TransportMetrics{},
isServer: isServer,
receiveChan: make(chan *Message, 1000),
ctx: ctx,
cancel: cancel,
upgrader: websocket.Upgrader{
CheckOrigin: func(r *http.Request) bool {
return true // Allow all origins by default; restrict with SetAllowedOrigins
},
ReadBufferSize: 1024,
WriteBufferSize: 1024,
},
pingInterval: 30 * time.Second,
pongTimeout: 10 * time.Second,
}
}
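// Usage sketch (the address, port, and path below are assumptions, not taken from the
// original commit): a monitoring endpoint would typically run the transport in server mode.
//
// wt := NewWebSocketTransport("127.0.0.1", 8080, "/ws", true)
// if err := wt.Connect(context.Background()); err != nil {
// // handle startup failure
// }
// defer wt.Disconnect(context.Background())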
// SetPingPongSettings configures WebSocket ping/pong settings
func (wt *WebSocketTransport) SetPingPongSettings(pingInterval, pongTimeout time.Duration) {
wt.pingInterval = pingInterval
wt.pongTimeout = pongTimeout
}
// Connect establishes the WebSocket connection
func (wt *WebSocketTransport) Connect(ctx context.Context) error {
wt.mu.Lock()
defer wt.mu.Unlock()
if wt.connected {
return nil
}
if wt.isServer {
return wt.startServer()
} else {
return wt.connectToServer(ctx)
}
}
// Disconnect closes the WebSocket connection
func (wt *WebSocketTransport) Disconnect(ctx context.Context) error {
wt.mu.Lock()
defer wt.mu.Unlock()
if !wt.connected {
return nil
}
wt.cancel()
if wt.isServer && wt.server != nil {
shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
wt.server.Shutdown(shutdownCtx)
}
// Close all connections
for id, conn := range wt.connections {
conn.Close()
delete(wt.connections, id)
}
// Do not close receiveChan here: a connection handler may still race to deliver a
// final message, and sending on a closed channel panics. Cancelling the context
// stops the handlers; consumers should stop on their own context instead.
wt.connected = false
wt.metrics.Connections = 0
return nil
}
// Send transmits a message through WebSocket
func (wt *WebSocketTransport) Send(ctx context.Context, msg *Message) error {
start := time.Now()
// Serialize the message before taking any locks
data, err := json.Marshal(msg)
if err != nil {
wt.mu.Lock()
wt.metrics.Errors++
wt.mu.Unlock()
return fmt.Errorf("failed to marshal message: %w", err)
}
// Snapshot the connections under the read lock so writes happen without holding it
// and the error counter is only mutated under the write lock
wt.mu.RLock()
if !wt.connected {
wt.mu.RUnlock()
wt.mu.Lock()
wt.metrics.Errors++
wt.mu.Unlock()
return fmt.Errorf("transport not connected")
}
conns := make(map[string]*websocket.Conn, len(wt.connections))
for id, conn := range wt.connections {
conns[id] = conn
}
wt.mu.RUnlock()
if len(conns) == 0 {
wt.mu.Lock()
wt.metrics.Errors++
wt.mu.Unlock()
return fmt.Errorf("no active connections")
}
// Send to all connections
var sendErr error
for connID, conn := range conns {
if err := wt.sendToConnection(conn, data); err != nil {
sendErr = fmt.Errorf("failed to send to connection %s: %w", connID, err)
// Remove failed connection
go wt.removeConnection(connID)
}
}
if sendErr == nil {
wt.updateSendMetrics(msg, time.Since(start))
} else {
wt.mu.Lock()
wt.metrics.Errors++
wt.mu.Unlock()
}
return sendErr
}
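// Send sketch (the field values are illustrative; Data is assumed to hold any
// JSON-serializable value):
//
// msg := &Message{ID: "evt-1", Topic: "metrics.cpu", Source: "agent-a", Data: map[string]interface{}{"load": 0.42}}
// if err := wt.Send(ctx, msg); err != nil {
// // either no peers are connected yet or a peer write failed
// }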
// Receive returns a channel for receiving messages
func (wt *WebSocketTransport) Receive(ctx context.Context) (<-chan *Message, error) {
wt.mu.RLock()
defer wt.mu.RUnlock()
if !wt.connected {
return nil, fmt.Errorf("transport not connected")
}
return wt.receiveChan, nil
}
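// Receive sketch: consumers drain the channel from their own goroutine and stop on
// context cancellation rather than waiting for the channel to be closed.
//
// ch, err := wt.Receive(ctx)
// if err != nil {
// return err
// }
// go func() {
// for {
// select {
// case msg := <-ch:
// handleMessage(msg) // hypothetical handler
// case <-ctx.Done():
// return
// }
// }
// }()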
// Health returns the health status of the transport
func (wt *WebSocketTransport) Health() ComponentHealth {
wt.mu.RLock()
defer wt.mu.RUnlock()
status := "unhealthy"
if wt.connected {
if len(wt.connections) > 0 {
status = "healthy"
} else {
status = "degraded" // Connected but no active connections
}
}
return ComponentHealth{
Status: status,
LastCheck: time.Now(),
ResponseTime: time.Millisecond * 5, // Nominal estimate; per-request latency is not measured here
ErrorCount: wt.metrics.Errors,
}
}
// GetMetrics returns transport-specific metrics
func (wt *WebSocketTransport) GetMetrics() TransportMetrics {
wt.mu.RLock()
defer wt.mu.RUnlock()
return TransportMetrics{
BytesSent: wt.metrics.BytesSent,
BytesReceived: wt.metrics.BytesReceived,
MessagesSent: wt.metrics.MessagesSent,
MessagesReceived: wt.metrics.MessagesReceived,
Connections: len(wt.connections),
Errors: wt.metrics.Errors,
Latency: wt.metrics.Latency,
}
}
// Private helper methods
func (wt *WebSocketTransport) startServer() error {
mux := http.NewServeMux()
mux.HandleFunc(wt.path, wt.handleWebSocket)
// Add health check endpoint
mux.HandleFunc("/health", func(w http.ResponseWriter, r *http.Request) {
health := wt.Health()
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(health)
})
// Add metrics endpoint
mux.HandleFunc("/metrics", func(w http.ResponseWriter, r *http.Request) {
metrics := wt.GetMetrics()
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(metrics)
})
addr := fmt.Sprintf("%s:%d", wt.address, wt.port)
wt.server = &http.Server{
Addr: addr,
Handler: mux,
ReadTimeout: 60 * time.Second,
WriteTimeout: 60 * time.Second,
IdleTimeout: 120 * time.Second,
}
wt.connected = true
// Start server in goroutine
go func() {
if err := wt.server.ListenAndServe(); err != nil && err != http.ErrServerClosed {
wt.mu.Lock()
wt.metrics.Errors++
wt.mu.Unlock()
}
}()
return nil
}
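// The embedded HTTP server exposes plain JSON endpoints alongside the WebSocket path,
// so a liveness probe can be as simple as (sketch, host and port assumed):
//
// resp, err := http.Get("http://127.0.0.1:8080/health")
// if err == nil {
// defer resp.Body.Close()
// var h ComponentHealth
// _ = json.NewDecoder(resp.Body).Decode(&h)
// }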
func (wt *WebSocketTransport) connectToServer(ctx context.Context) error {
url := fmt.Sprintf("ws://%s:%d%s", wt.address, wt.port, wt.path)
conn, _, err := websocket.DefaultDialer.DialContext(ctx, url, nil)
if err != nil {
return fmt.Errorf("failed to connect to WebSocket server: %w", err)
}
connID := fmt.Sprintf("client_%d", time.Now().UnixNano())
wt.connections[connID] = conn
wt.connected = true
wt.metrics.Connections = 1
// Start handling connection
go wt.handleConnection(connID, conn)
return nil
}
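// Client-mode sketch (address and path assumed): the same type dials an existing
// server when constructed with isServer set to false.
//
// client := NewWebSocketTransport("127.0.0.1", 8080, "/ws", false)
// if err := client.Connect(ctx); err != nil {
// // server unreachable
// }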
func (wt *WebSocketTransport) handleWebSocket(w http.ResponseWriter, r *http.Request) {
conn, err := wt.upgrader.Upgrade(w, r, nil)
if err != nil {
wt.mu.Lock()
wt.metrics.Errors++
wt.mu.Unlock()
return
}
connID := fmt.Sprintf("server_%d", time.Now().UnixNano())
wt.mu.Lock()
wt.connections[connID] = conn
wt.metrics.Connections = len(wt.connections)
wt.mu.Unlock()
go wt.handleConnection(connID, conn)
}
func (wt *WebSocketTransport) handleConnection(connID string, conn *websocket.Conn) {
defer wt.removeConnection(connID)
// Set up ping/pong handling: the read deadline must outlive the ping interval,
// otherwise the connection times out before the next pong can arrive
readDeadline := wt.pingInterval + wt.pongTimeout
conn.SetReadDeadline(time.Now().Add(readDeadline))
conn.SetPongHandler(func(string) error {
conn.SetReadDeadline(time.Now().Add(readDeadline))
return nil
})
// Start ping routine
go wt.pingRoutine(connID, conn)
for {
select {
case <-wt.ctx.Done():
return
default:
var msg Message
err := conn.ReadJSON(&msg)
if err != nil {
if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) {
wt.mu.Lock()
wt.metrics.Errors++
wt.mu.Unlock()
}
return
}
// Deliver message
select {
case wt.receiveChan <- &msg:
wt.updateReceiveMetrics(&msg)
case <-wt.ctx.Done():
return
default:
// Receive channel full; drop the message and count the error
wt.mu.Lock()
wt.metrics.Errors++
wt.mu.Unlock()
}
}
}
}
func (wt *WebSocketTransport) pingRoutine(connID string, conn *websocket.Conn) {
ticker := time.NewTicker(wt.pingInterval)
defer ticker.Stop()
for {
select {
case <-ticker.C:
if err := conn.WriteControl(websocket.PingMessage, []byte{}, time.Now().Add(10*time.Second)); err != nil {
return // Connection is likely closed
}
case <-wt.ctx.Done():
return
}
}
}
func (wt *WebSocketTransport) sendToConnection(conn *websocket.Conn, data []byte) error {
conn.SetWriteDeadline(time.Now().Add(10 * time.Second))
return conn.WriteMessage(websocket.TextMessage, data)
}
func (wt *WebSocketTransport) removeConnection(connID string) {
wt.mu.Lock()
defer wt.mu.Unlock()
if conn, exists := wt.connections[connID]; exists {
conn.Close()
delete(wt.connections, connID)
wt.metrics.Connections = len(wt.connections)
}
}
func (wt *WebSocketTransport) updateSendMetrics(msg *Message, latency time.Duration) {
wt.mu.Lock()
defer wt.mu.Unlock()
wt.metrics.MessagesSent++
wt.metrics.Latency = latency
// Estimate message size
messageSize := int64(len(msg.ID) + len(msg.Topic) + len(msg.Source))
if msg.Data != nil {
messageSize += int64(len(fmt.Sprintf("%v", msg.Data)))
}
wt.metrics.BytesSent += messageSize
}
func (wt *WebSocketTransport) updateReceiveMetrics(msg *Message) {
wt.mu.Lock()
defer wt.mu.Unlock()
wt.metrics.MessagesReceived++
// Estimate message size
messageSize := int64(len(msg.ID) + len(msg.Topic) + len(msg.Source))
if msg.Data != nil {
messageSize += int64(len(fmt.Sprintf("%v", msg.Data)))
}
wt.metrics.BytesReceived += messageSize
}
// Broadcast sends a message to all connected clients (server mode only)
func (wt *WebSocketTransport) Broadcast(ctx context.Context, msg *Message) error {
if !wt.isServer {
return fmt.Errorf("broadcast only available in server mode")
}
return wt.Send(ctx, msg)
}
// GetURL returns the WebSocket URL (for testing/debugging)
func (wt *WebSocketTransport) GetURL() string {
return fmt.Sprintf("ws://%s:%d%s", wt.address, wt.port, wt.path)
}
// GetConnectionCount returns the number of active connections
func (wt *WebSocketTransport) GetConnectionCount() int {
wt.mu.RLock()
defer wt.mu.RUnlock()
return len(wt.connections)
}
// SetAllowedOrigins restricts which Origin headers the WebSocket upgrader accepts; call it before Connect so new connections use the updated check
func (wt *WebSocketTransport) SetAllowedOrigins(origins []string) {
if len(origins) == 0 {
wt.upgrader.CheckOrigin = func(r *http.Request) bool {
return true // Allow all origins
}
return
}
originMap := make(map[string]bool)
for _, origin := range origins {
originMap[origin] = true
}
wt.upgrader.CheckOrigin = func(r *http.Request) bool {
origin := r.Header.Get("Origin")
return originMap[origin]
}
}
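// Origin sketch (hostname assumed): the check matches the Origin header exactly, so
// list every scheme/host/port combination that should be allowed.
//
// wt.SetAllowedOrigins([]string{"https://dashboard.example.com"})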