feat(transport): implement comprehensive universal message bus
🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
591
pkg/transport/dlq.go
Normal file
591
pkg/transport/dlq.go
Normal file
@@ -0,0 +1,591 @@
|
||||
package transport
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sort"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// DeadLetterQueue handles failed messages with retry and reprocessing capabilities
|
||||
type DeadLetterQueue struct {
|
||||
messages map[string][]*DLQMessage
|
||||
config DLQConfig
|
||||
metrics DLQMetrics
|
||||
reprocessor MessageReprocessor
|
||||
mu sync.RWMutex
|
||||
cleanupTicker *time.Ticker
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
}
|
||||
|
||||
// DLQMessage represents a message in the dead letter queue
|
||||
type DLQMessage struct {
|
||||
ID string
|
||||
OriginalMessage *Message
|
||||
Topic string
|
||||
FirstFailed time.Time
|
||||
LastAttempt time.Time
|
||||
AttemptCount int
|
||||
MaxRetries int
|
||||
FailureReason string
|
||||
RetryDelay time.Duration
|
||||
NextRetry time.Time
|
||||
Metadata map[string]interface{}
|
||||
Permanent bool
|
||||
}
|
||||
|
||||
// DLQConfig configures dead letter queue behavior
|
||||
type DLQConfig struct {
|
||||
MaxMessages int
|
||||
MaxRetries int
|
||||
RetentionTime time.Duration
|
||||
AutoReprocess bool
|
||||
ReprocessInterval time.Duration
|
||||
BackoffStrategy BackoffStrategy
|
||||
InitialRetryDelay time.Duration
|
||||
MaxRetryDelay time.Duration
|
||||
BackoffMultiplier float64
|
||||
PermanentFailures []string // Error patterns that mark messages as permanently failed
|
||||
ReprocessBatchSize int
|
||||
}
|
||||
|
||||
// BackoffStrategy names the formula used to compute the delay before a retry.
type BackoffStrategy string

// Supported backoff strategies.
const (
	// BackoffFixed always waits the initial retry delay.
	BackoffFixed BackoffStrategy = "fixed"
	// BackoffLinear scales the initial delay by the attempt count.
	BackoffLinear BackoffStrategy = "linear"
	// BackoffExponential multiplies the delay by the backoff multiplier per attempt.
	BackoffExponential BackoffStrategy = "exponential"
	// BackoffCustom has no dedicated formula in calculateRetryDelay and falls
	// back to the initial retry delay there.
	BackoffCustom BackoffStrategy = "custom"
)
|
||||
|
||||
// DLQMetrics is the set of counters a DeadLetterQueue maintains.
type DLQMetrics struct {
	// MessagesAdded counts every message ever queued.
	MessagesAdded int64
	// MessagesReprocessed is declared but not updated by any code visible in
	// this file. NOTE(review): confirm whether it should mirror retry attempts.
	MessagesReprocessed int64
	// MessagesExpired counts messages dropped by Cleanup.
	MessagesExpired int64
	// MessagesPermanent counts messages marked as permanent failures.
	MessagesPermanent int64
	// ReprocessSuccesses counts retries that succeeded.
	ReprocessSuccesses int64
	// ReprocessFailures counts retries that failed.
	ReprocessFailures int64
	// QueueSize is the number of messages currently queued.
	QueueSize int64
	// OldestMessage is the earliest FirstFailed among queued messages, or the
	// zero time when the queue is empty.
	OldestMessage time.Time
}
|
||||
|
||||
// MessageReprocessor handles message reprocessing logic
|
||||
type MessageReprocessor interface {
|
||||
Reprocess(ctx context.Context, msg *DLQMessage) error
|
||||
CanReprocess(msg *DLQMessage) bool
|
||||
ShouldRetry(msg *DLQMessage, err error) bool
|
||||
}
|
||||
|
||||
// DefaultMessageReprocessor implements basic reprocessing logic
|
||||
type DefaultMessageReprocessor struct {
|
||||
publisher MessagePublisher
|
||||
}
|
||||
|
||||
// MessagePublisher interface for republishing messages
|
||||
type MessagePublisher interface {
|
||||
Publish(ctx context.Context, msg *Message) error
|
||||
}
|
||||
|
||||
// NewDeadLetterQueue creates a new dead letter queue
|
||||
func NewDeadLetterQueue(config DLQConfig) *DeadLetterQueue {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
dlq := &DeadLetterQueue{
|
||||
messages: make(map[string][]*DLQMessage),
|
||||
config: config,
|
||||
metrics: DLQMetrics{},
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
}
|
||||
|
||||
// Set default configuration values
|
||||
if dlq.config.MaxMessages == 0 {
|
||||
dlq.config.MaxMessages = 10000
|
||||
}
|
||||
if dlq.config.MaxRetries == 0 {
|
||||
dlq.config.MaxRetries = 3
|
||||
}
|
||||
if dlq.config.RetentionTime == 0 {
|
||||
dlq.config.RetentionTime = 24 * time.Hour
|
||||
}
|
||||
if dlq.config.ReprocessInterval == 0 {
|
||||
dlq.config.ReprocessInterval = 5 * time.Minute
|
||||
}
|
||||
if dlq.config.InitialRetryDelay == 0 {
|
||||
dlq.config.InitialRetryDelay = time.Minute
|
||||
}
|
||||
if dlq.config.MaxRetryDelay == 0 {
|
||||
dlq.config.MaxRetryDelay = time.Hour
|
||||
}
|
||||
if dlq.config.BackoffMultiplier == 0 {
|
||||
dlq.config.BackoffMultiplier = 2.0
|
||||
}
|
||||
if dlq.config.BackoffStrategy == "" {
|
||||
dlq.config.BackoffStrategy = BackoffExponential
|
||||
}
|
||||
if dlq.config.ReprocessBatchSize == 0 {
|
||||
dlq.config.ReprocessBatchSize = 10
|
||||
}
|
||||
|
||||
// Start cleanup routine
|
||||
dlq.startCleanupRoutine()
|
||||
|
||||
// Start reprocessing routine if enabled
|
||||
if dlq.config.AutoReprocess {
|
||||
dlq.startReprocessRoutine()
|
||||
}
|
||||
|
||||
return dlq
|
||||
}
|
||||
|
||||
// AddMessage adds a failed message to the dead letter queue
|
||||
func (dlq *DeadLetterQueue) AddMessage(topic string, msg *Message) error {
|
||||
return dlq.AddMessageWithReason(topic, msg, "unknown failure")
|
||||
}
|
||||
|
||||
// AddMessageWithReason adds a failed message with a specific failure reason
|
||||
func (dlq *DeadLetterQueue) AddMessageWithReason(topic string, msg *Message, reason string) error {
|
||||
dlq.mu.Lock()
|
||||
defer dlq.mu.Unlock()
|
||||
|
||||
// Check if we've exceeded max messages
|
||||
totalMessages := dlq.getTotalMessageCount()
|
||||
if totalMessages >= dlq.config.MaxMessages {
|
||||
// Remove oldest message to make room
|
||||
dlq.removeOldestMessage()
|
||||
}
|
||||
|
||||
// Check if this is a permanent failure
|
||||
permanent := dlq.isPermanentFailure(reason)
|
||||
|
||||
dlqMsg := &DLQMessage{
|
||||
ID: fmt.Sprintf("dlq_%s_%d", topic, time.Now().UnixNano()),
|
||||
OriginalMessage: msg,
|
||||
Topic: topic,
|
||||
FirstFailed: time.Now(),
|
||||
LastAttempt: time.Now(),
|
||||
AttemptCount: 1,
|
||||
MaxRetries: dlq.config.MaxRetries,
|
||||
FailureReason: reason,
|
||||
Metadata: make(map[string]interface{}),
|
||||
Permanent: permanent,
|
||||
}
|
||||
|
||||
if !permanent {
|
||||
dlqMsg.RetryDelay = dlq.calculateRetryDelay(dlqMsg)
|
||||
dlqMsg.NextRetry = time.Now().Add(dlqMsg.RetryDelay)
|
||||
}
|
||||
|
||||
// Add to queue
|
||||
if _, exists := dlq.messages[topic]; !exists {
|
||||
dlq.messages[topic] = make([]*DLQMessage, 0)
|
||||
}
|
||||
dlq.messages[topic] = append(dlq.messages[topic], dlqMsg)
|
||||
|
||||
// Update metrics
|
||||
dlq.metrics.MessagesAdded++
|
||||
dlq.metrics.QueueSize++
|
||||
if permanent {
|
||||
dlq.metrics.MessagesPermanent++
|
||||
}
|
||||
dlq.updateOldestMessage()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetMessages returns all messages for a topic
|
||||
func (dlq *DeadLetterQueue) GetMessages(topic string) ([]*DLQMessage, error) {
|
||||
dlq.mu.RLock()
|
||||
defer dlq.mu.RUnlock()
|
||||
|
||||
messages, exists := dlq.messages[topic]
|
||||
if !exists {
|
||||
return []*DLQMessage{}, nil
|
||||
}
|
||||
|
||||
// Return a copy to avoid race conditions
|
||||
result := make([]*DLQMessage, len(messages))
|
||||
copy(result, messages)
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// GetAllMessages returns all messages across all topics
|
||||
func (dlq *DeadLetterQueue) GetAllMessages() map[string][]*DLQMessage {
|
||||
dlq.mu.RLock()
|
||||
defer dlq.mu.RUnlock()
|
||||
|
||||
result := make(map[string][]*DLQMessage)
|
||||
for topic, messages := range dlq.messages {
|
||||
result[topic] = make([]*DLQMessage, len(messages))
|
||||
copy(result[topic], messages)
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// ReprocessMessage attempts to reprocess a specific message
|
||||
func (dlq *DeadLetterQueue) ReprocessMessage(messageID string) error {
|
||||
dlq.mu.Lock()
|
||||
defer dlq.mu.Unlock()
|
||||
|
||||
// Find message
|
||||
var dlqMsg *DLQMessage
|
||||
var topic string
|
||||
var index int
|
||||
|
||||
for t, messages := range dlq.messages {
|
||||
for i, msg := range messages {
|
||||
if msg.ID == messageID {
|
||||
dlqMsg = msg
|
||||
topic = t
|
||||
index = i
|
||||
break
|
||||
}
|
||||
}
|
||||
if dlqMsg != nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if dlqMsg == nil {
|
||||
return fmt.Errorf("message not found: %s", messageID)
|
||||
}
|
||||
|
||||
if dlqMsg.Permanent {
|
||||
return fmt.Errorf("message marked as permanent failure: %s", messageID)
|
||||
}
|
||||
|
||||
// Attempt reprocessing
|
||||
err := dlq.attemptReprocess(dlqMsg)
|
||||
if err == nil {
|
||||
// Success - remove from queue
|
||||
dlq.removeMessageByIndex(topic, index)
|
||||
dlq.metrics.ReprocessSuccesses++
|
||||
dlq.metrics.QueueSize--
|
||||
return nil
|
||||
}
|
||||
|
||||
// Failed - update retry information
|
||||
dlqMsg.AttemptCount++
|
||||
dlqMsg.LastAttempt = time.Now()
|
||||
dlqMsg.FailureReason = err.Error()
|
||||
|
||||
if dlqMsg.AttemptCount >= dlqMsg.MaxRetries {
|
||||
dlqMsg.Permanent = true
|
||||
dlq.metrics.MessagesPermanent++
|
||||
} else {
|
||||
dlqMsg.RetryDelay = dlq.calculateRetryDelay(dlqMsg)
|
||||
dlqMsg.NextRetry = time.Now().Add(dlqMsg.RetryDelay)
|
||||
}
|
||||
|
||||
dlq.metrics.ReprocessFailures++
|
||||
return fmt.Errorf("reprocessing failed: %w", err)
|
||||
}
|
||||
|
||||
// PurgeMessages removes all messages for a topic
|
||||
func (dlq *DeadLetterQueue) PurgeMessages(topic string) error {
|
||||
dlq.mu.Lock()
|
||||
defer dlq.mu.Unlock()
|
||||
|
||||
if messages, exists := dlq.messages[topic]; exists {
|
||||
count := len(messages)
|
||||
delete(dlq.messages, topic)
|
||||
dlq.metrics.QueueSize -= int64(count)
|
||||
dlq.updateOldestMessage()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// PurgeAllMessages removes all messages from the queue
|
||||
func (dlq *DeadLetterQueue) PurgeAllMessages() error {
|
||||
dlq.mu.Lock()
|
||||
defer dlq.mu.Unlock()
|
||||
|
||||
dlq.messages = make(map[string][]*DLQMessage)
|
||||
dlq.metrics.QueueSize = 0
|
||||
dlq.metrics.OldestMessage = time.Time{}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetMessageCount returns the total number of messages in the queue
|
||||
func (dlq *DeadLetterQueue) GetMessageCount() int {
|
||||
dlq.mu.RLock()
|
||||
defer dlq.mu.RUnlock()
|
||||
return dlq.getTotalMessageCount()
|
||||
}
|
||||
|
||||
// GetMetrics returns current DLQ metrics
|
||||
func (dlq *DeadLetterQueue) GetMetrics() DLQMetrics {
|
||||
dlq.mu.RLock()
|
||||
defer dlq.mu.RUnlock()
|
||||
return dlq.metrics
|
||||
}
|
||||
|
||||
// SetReprocessor sets the message reprocessor
|
||||
func (dlq *DeadLetterQueue) SetReprocessor(reprocessor MessageReprocessor) {
|
||||
dlq.mu.Lock()
|
||||
defer dlq.mu.Unlock()
|
||||
dlq.reprocessor = reprocessor
|
||||
}
|
||||
|
||||
// Cleanup removes expired messages
|
||||
func (dlq *DeadLetterQueue) Cleanup(maxAge time.Duration) error {
|
||||
dlq.mu.Lock()
|
||||
defer dlq.mu.Unlock()
|
||||
|
||||
cutoff := time.Now().Add(-maxAge)
|
||||
expiredCount := 0
|
||||
|
||||
for topic, messages := range dlq.messages {
|
||||
filtered := make([]*DLQMessage, 0)
|
||||
for _, msg := range messages {
|
||||
if msg.FirstFailed.After(cutoff) {
|
||||
filtered = append(filtered, msg)
|
||||
} else {
|
||||
expiredCount++
|
||||
}
|
||||
}
|
||||
dlq.messages[topic] = filtered
|
||||
|
||||
// Remove empty topics
|
||||
if len(filtered) == 0 {
|
||||
delete(dlq.messages, topic)
|
||||
}
|
||||
}
|
||||
|
||||
dlq.metrics.MessagesExpired += int64(expiredCount)
|
||||
dlq.metrics.QueueSize -= int64(expiredCount)
|
||||
dlq.updateOldestMessage()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Stop gracefully shuts down the dead letter queue
|
||||
func (dlq *DeadLetterQueue) Stop() error {
|
||||
dlq.cancel()
|
||||
|
||||
if dlq.cleanupTicker != nil {
|
||||
dlq.cleanupTicker.Stop()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Private helper methods
|
||||
|
||||
func (dlq *DeadLetterQueue) getTotalMessageCount() int {
|
||||
count := 0
|
||||
for _, messages := range dlq.messages {
|
||||
count += len(messages)
|
||||
}
|
||||
return count
|
||||
}
|
||||
|
||||
func (dlq *DeadLetterQueue) removeOldestMessage() {
|
||||
var oldestTime time.Time
|
||||
var oldestTopic string
|
||||
var oldestIndex int
|
||||
|
||||
for topic, messages := range dlq.messages {
|
||||
for i, msg := range messages {
|
||||
if oldestTime.IsZero() || msg.FirstFailed.Before(oldestTime) {
|
||||
oldestTime = msg.FirstFailed
|
||||
oldestTopic = topic
|
||||
oldestIndex = i
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !oldestTime.IsZero() {
|
||||
dlq.removeMessageByIndex(oldestTopic, oldestIndex)
|
||||
dlq.metrics.QueueSize--
|
||||
}
|
||||
}
|
||||
|
||||
func (dlq *DeadLetterQueue) removeMessageByIndex(topic string, index int) {
|
||||
messages := dlq.messages[topic]
|
||||
dlq.messages[topic] = append(messages[:index], messages[index+1:]...)
|
||||
|
||||
if len(dlq.messages[topic]) == 0 {
|
||||
delete(dlq.messages, topic)
|
||||
}
|
||||
}
|
||||
|
||||
func (dlq *DeadLetterQueue) isPermanentFailure(reason string) bool {
|
||||
for _, pattern := range dlq.config.PermanentFailures {
|
||||
if pattern == reason {
|
||||
return true
|
||||
}
|
||||
// Simple pattern matching (can be enhanced with regex)
|
||||
if len(pattern) > 0 && pattern[len(pattern)-1] == '*' {
|
||||
prefix := pattern[:len(pattern)-1]
|
||||
if len(reason) >= len(prefix) && reason[:len(prefix)] == prefix {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (dlq *DeadLetterQueue) calculateRetryDelay(msg *DLQMessage) time.Duration {
|
||||
switch dlq.config.BackoffStrategy {
|
||||
case BackoffFixed:
|
||||
return dlq.config.InitialRetryDelay
|
||||
|
||||
case BackoffLinear:
|
||||
delay := time.Duration(msg.AttemptCount) * dlq.config.InitialRetryDelay
|
||||
if delay > dlq.config.MaxRetryDelay {
|
||||
return dlq.config.MaxRetryDelay
|
||||
}
|
||||
return delay
|
||||
|
||||
case BackoffExponential:
|
||||
delay := time.Duration(float64(dlq.config.InitialRetryDelay) *
|
||||
pow(dlq.config.BackoffMultiplier, float64(msg.AttemptCount-1)))
|
||||
if delay > dlq.config.MaxRetryDelay {
|
||||
return dlq.config.MaxRetryDelay
|
||||
}
|
||||
return delay
|
||||
|
||||
default:
|
||||
return dlq.config.InitialRetryDelay
|
||||
}
|
||||
}
|
||||
|
||||
func (dlq *DeadLetterQueue) attemptReprocess(msg *DLQMessage) error {
|
||||
if dlq.reprocessor == nil {
|
||||
return fmt.Errorf("no reprocessor configured")
|
||||
}
|
||||
|
||||
if !dlq.reprocessor.CanReprocess(msg) {
|
||||
return fmt.Errorf("message cannot be reprocessed")
|
||||
}
|
||||
|
||||
return dlq.reprocessor.Reprocess(dlq.ctx, msg)
|
||||
}
|
||||
|
||||
func (dlq *DeadLetterQueue) updateOldestMessage() {
|
||||
var oldest time.Time
|
||||
|
||||
for _, messages := range dlq.messages {
|
||||
for _, msg := range messages {
|
||||
if oldest.IsZero() || msg.FirstFailed.Before(oldest) {
|
||||
oldest = msg.FirstFailed
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
dlq.metrics.OldestMessage = oldest
|
||||
}
|
||||
|
||||
func (dlq *DeadLetterQueue) startCleanupRoutine() {
|
||||
dlq.cleanupTicker = time.NewTicker(dlq.config.ReprocessInterval)
|
||||
|
||||
go func() {
|
||||
for {
|
||||
select {
|
||||
case <-dlq.cleanupTicker.C:
|
||||
dlq.Cleanup(dlq.config.RetentionTime)
|
||||
case <-dlq.ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func (dlq *DeadLetterQueue) startReprocessRoutine() {
|
||||
ticker := time.NewTicker(dlq.config.ReprocessInterval)
|
||||
|
||||
go func() {
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
dlq.processRetryableMessages()
|
||||
case <-dlq.ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func (dlq *DeadLetterQueue) processRetryableMessages() {
|
||||
dlq.mu.Lock()
|
||||
retryable := dlq.getRetryableMessages()
|
||||
dlq.mu.Unlock()
|
||||
|
||||
// Sort by next retry time
|
||||
sort.Slice(retryable, func(i, j int) bool {
|
||||
return retryable[i].NextRetry.Before(retryable[j].NextRetry)
|
||||
})
|
||||
|
||||
// Process batch
|
||||
batchSize := dlq.config.ReprocessBatchSize
|
||||
if len(retryable) < batchSize {
|
||||
batchSize = len(retryable)
|
||||
}
|
||||
|
||||
for i := 0; i < batchSize; i++ {
|
||||
msg := retryable[i]
|
||||
if time.Now().After(msg.NextRetry) {
|
||||
dlq.ReprocessMessage(msg.ID)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (dlq *DeadLetterQueue) getRetryableMessages() []*DLQMessage {
|
||||
var retryable []*DLQMessage
|
||||
|
||||
for _, messages := range dlq.messages {
|
||||
for _, msg := range messages {
|
||||
if !msg.Permanent && msg.AttemptCount < msg.MaxRetries {
|
||||
retryable = append(retryable, msg)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return retryable
|
||||
}
|
||||
|
||||
// Implementation of DefaultMessageReprocessor
|
||||
|
||||
func NewDefaultMessageReprocessor(publisher MessagePublisher) *DefaultMessageReprocessor {
|
||||
return &DefaultMessageReprocessor{
|
||||
publisher: publisher,
|
||||
}
|
||||
}
|
||||
|
||||
func (r *DefaultMessageReprocessor) Reprocess(ctx context.Context, msg *DLQMessage) error {
|
||||
if r.publisher == nil {
|
||||
return fmt.Errorf("no publisher configured")
|
||||
}
|
||||
|
||||
return r.publisher.Publish(ctx, msg.OriginalMessage)
|
||||
}
|
||||
|
||||
func (r *DefaultMessageReprocessor) CanReprocess(msg *DLQMessage) bool {
|
||||
return !msg.Permanent && msg.AttemptCount < msg.MaxRetries
|
||||
}
|
||||
|
||||
func (r *DefaultMessageReprocessor) ShouldRetry(msg *DLQMessage, err error) bool {
|
||||
// Simple retry logic - can be enhanced based on error types
|
||||
return msg.AttemptCount < msg.MaxRetries
|
||||
}
|
||||
|
||||
// pow returns base raised to the integer part of exp.
//
// The exponent is truncated toward zero, so pow(b, 0.5) is 1 and pow(b, 2.9)
// is b*b. A negative exponent yields the reciprocal of the matching positive
// power. Non-negative whole exponents are computed by repeated
// multiplication.
func pow(base, exp float64) float64 {
	n := int(exp)
	if n < 0 {
		return 1 / pow(base, float64(-n))
	}

	result := 1.0
	for i := 0; i < n; i++ {
		result *= base
	}
	return result
}
|
||||
277
pkg/transport/interfaces.go
Normal file
277
pkg/transport/interfaces.go
Normal file
@@ -0,0 +1,277 @@
|
||||
package transport
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
)
|
||||
|
||||
// MessageType classifies a Message by its role in the system.
type MessageType string

// Core message types.
const (
	MessageTypeEvent     MessageType = "event"
	MessageTypeCommand   MessageType = "command"
	MessageTypeResponse  MessageType = "response"
	MessageTypeHeartbeat MessageType = "heartbeat"
	MessageTypeStatus    MessageType = "status"
	MessageTypeError     MessageType = "error"
)

// Business-specific message types.
const (
	MessageTypeArbitrage  MessageType = "arbitrage"
	MessageTypeMarketData MessageType = "market_data"
	MessageTypeExecution  MessageType = "execution"
	MessageTypeRiskCheck  MessageType = "risk_check"
)
|
||||
|
||||
// Priority ranks a message for routing; larger values are listed later in the
// declaration below.
type Priority uint8

// Priority levels, numbered from 0 (PriorityLow) to 4 (PriorityEmergency).
const (
	PriorityLow Priority = iota
	PriorityNormal
	PriorityHigh
	PriorityCritical
	PriorityEmergency
)
|
||||
|
||||
// Message represents a universal message in the system
|
||||
type Message struct {
|
||||
ID string `json:"id"`
|
||||
Type MessageType `json:"type"`
|
||||
Topic string `json:"topic"`
|
||||
Source string `json:"source"`
|
||||
Destination string `json:"destination"`
|
||||
Priority Priority `json:"priority"`
|
||||
Timestamp time.Time `json:"timestamp"`
|
||||
TTL time.Duration `json:"ttl"`
|
||||
Headers map[string]string `json:"headers"`
|
||||
Payload []byte `json:"payload"`
|
||||
Metadata map[string]interface{} `json:"metadata"`
|
||||
}
|
||||
|
||||
// MessageHandler processes incoming messages
|
||||
type MessageHandler func(ctx context.Context, msg *Message) error
|
||||
|
||||
// Transport defines the interface for different transport mechanisms
|
||||
type Transport interface {
|
||||
// Start initializes the transport
|
||||
Start(ctx context.Context) error
|
||||
|
||||
// Stop gracefully shuts down the transport
|
||||
Stop(ctx context.Context) error
|
||||
|
||||
// Send publishes a message
|
||||
Send(ctx context.Context, msg *Message) error
|
||||
|
||||
// Subscribe registers a handler for messages on a topic
|
||||
Subscribe(ctx context.Context, topic string, handler MessageHandler) error
|
||||
|
||||
// Unsubscribe removes a handler for a topic
|
||||
Unsubscribe(ctx context.Context, topic string) error
|
||||
|
||||
// GetStats returns transport statistics
|
||||
GetStats() TransportStats
|
||||
|
||||
// GetType returns the transport type
|
||||
GetType() TransportType
|
||||
|
||||
// IsHealthy checks if the transport is functioning properly
|
||||
IsHealthy() bool
|
||||
}
|
||||
|
||||
// TransportType names a transport implementation.
type TransportType string

// Known transport implementations.
const (
	TransportTypeSharedMemory TransportType = "shared_memory"
	TransportTypeUnixSocket   TransportType = "unix_socket"
	TransportTypeTCP          TransportType = "tcp"
	TransportTypeWebSocket    TransportType = "websocket"
	TransportTypeGRPC         TransportType = "grpc"
)
|
||||
|
||||
// TransportStats is a snapshot of one transport's performance counters.
type TransportStats struct {
	// MessagesSent counts messages sent.
	MessagesSent uint64 `json:"messages_sent"`
	// MessagesReceived counts messages received.
	MessagesReceived uint64 `json:"messages_received"`
	// MessagesDropped counts messages dropped.
	MessagesDropped uint64 `json:"messages_dropped"`
	// BytesSent counts bytes sent.
	BytesSent uint64 `json:"bytes_sent"`
	// BytesReceived counts bytes received.
	BytesReceived uint64 `json:"bytes_received"`
	// Latency is the reported latency.
	Latency time.Duration `json:"latency"`
	// ErrorCount counts errors.
	ErrorCount uint64 `json:"error_count"`
	// ConnectedPeers is the number of connected peers.
	ConnectedPeers int `json:"connected_peers"`
	// Uptime is how long the transport has been running.
	Uptime time.Duration `json:"uptime"`
}
|
||||
|
||||
// MessageBus coordinates message routing across multiple transports
|
||||
type MessageBus interface {
|
||||
// Start initializes the message bus
|
||||
Start(ctx context.Context) error
|
||||
|
||||
// Stop gracefully shuts down the message bus
|
||||
Stop(ctx context.Context) error
|
||||
|
||||
// RegisterTransport adds a transport to the bus
|
||||
RegisterTransport(transport Transport) error
|
||||
|
||||
// UnregisterTransport removes a transport from the bus
|
||||
UnregisterTransport(transportType TransportType) error
|
||||
|
||||
// Publish sends a message through the optimal transport
|
||||
Publish(ctx context.Context, msg *Message) error
|
||||
|
||||
// Subscribe registers a handler for messages on a topic
|
||||
Subscribe(ctx context.Context, topic string, handler MessageHandler) error
|
||||
|
||||
// Unsubscribe removes a handler for a topic
|
||||
Unsubscribe(ctx context.Context, topic string) error
|
||||
|
||||
// GetTransport returns a specific transport
|
||||
GetTransport(transportType TransportType) (Transport, error)
|
||||
|
||||
// GetStats returns aggregated statistics
|
||||
GetStats() MessageBusStats
|
||||
}
|
||||
|
||||
// MessageBusStats provides comprehensive metrics
|
||||
type MessageBusStats struct {
|
||||
TotalMessages uint64 `json:"total_messages"`
|
||||
MessagesByType map[MessageType]uint64 `json:"messages_by_type"`
|
||||
TransportStats map[TransportType]TransportStats `json:"transport_stats"`
|
||||
ActiveTopics []string `json:"active_topics"`
|
||||
Subscribers int `json:"subscribers"`
|
||||
AverageLatency time.Duration `json:"average_latency"`
|
||||
ThroughputMPS float64 `json:"throughput_mps"` // Messages per second
|
||||
}
|
||||
|
||||
// Router determines the best transport for a message
|
||||
type Router interface {
|
||||
// Route selects the optimal transport for a message
|
||||
Route(msg *Message) (TransportType, error)
|
||||
|
||||
// AddRule adds a routing rule
|
||||
AddRule(rule RoutingRule) error
|
||||
|
||||
// RemoveRule removes a routing rule
|
||||
RemoveRule(ruleID string) error
|
||||
|
||||
// GetRules returns all routing rules
|
||||
GetRules() []RoutingRule
|
||||
}
|
||||
|
||||
// RoutingRule defines how messages should be routed
|
||||
type RoutingRule struct {
|
||||
ID string `json:"id"`
|
||||
Priority int `json:"priority"`
|
||||
Condition Condition `json:"condition"`
|
||||
Transport TransportType `json:"transport"`
|
||||
Fallback TransportType `json:"fallback,omitempty"`
|
||||
Description string `json:"description"`
|
||||
}
|
||||
|
||||
// Condition defines when a routing rule applies
|
||||
type Condition struct {
|
||||
MessageType *MessageType `json:"message_type,omitempty"`
|
||||
Topic *string `json:"topic,omitempty"`
|
||||
Priority *Priority `json:"priority,omitempty"`
|
||||
Source *string `json:"source,omitempty"`
|
||||
Destination *string `json:"destination,omitempty"`
|
||||
PayloadSize *int `json:"payload_size,omitempty"`
|
||||
LatencyReq *time.Duration `json:"latency_requirement,omitempty"`
|
||||
}
|
||||
|
||||
// DeadLetterQueue handles failed messages
|
||||
type DeadLetterQueue interface {
|
||||
// Add puts a failed message in the queue
|
||||
Add(ctx context.Context, msg *Message, reason error) error
|
||||
|
||||
// Retry attempts to resend failed messages
|
||||
Retry(ctx context.Context, maxRetries int) error
|
||||
|
||||
// Get retrieves failed messages
|
||||
Get(ctx context.Context, limit int) ([]*FailedMessage, error)
|
||||
|
||||
// Remove deletes a failed message
|
||||
Remove(ctx context.Context, messageID string) error
|
||||
|
||||
// GetStats returns dead letter queue statistics
|
||||
GetStats() DLQStats
|
||||
}
|
||||
|
||||
// FailedMessage represents a message that couldn't be delivered
|
||||
type FailedMessage struct {
|
||||
Message *Message `json:"message"`
|
||||
Reason string `json:"reason"`
|
||||
Attempts int `json:"attempts"`
|
||||
FirstFailed time.Time `json:"first_failed"`
|
||||
LastAttempt time.Time `json:"last_attempt"`
|
||||
}
|
||||
|
||||
// DLQStats is the metrics snapshot of a dead letter queue.
type DLQStats struct {
	// TotalMessages counts all queued messages.
	TotalMessages uint64 `json:"total_messages"`
	// RetryableMessages counts messages that can still be retried.
	RetryableMessages uint64 `json:"retryable_messages"`
	// PermanentFailures counts messages that failed permanently.
	PermanentFailures uint64 `json:"permanent_failures"`
	// OldestMessage is the time of the oldest queued message.
	OldestMessage time.Time `json:"oldest_message"`
	// AverageRetries is the mean number of retries per message.
	AverageRetries float64 `json:"average_retries"`
}
|
||||
|
||||
// Serializer handles message encoding/decoding
|
||||
type Serializer interface {
|
||||
// Serialize converts a message to bytes
|
||||
Serialize(msg *Message) ([]byte, error)
|
||||
|
||||
// Deserialize converts bytes to a message
|
||||
Deserialize(data []byte) (*Message, error)
|
||||
|
||||
// GetFormat returns the serialization format
|
||||
GetFormat() SerializationFormat
|
||||
}
|
||||
|
||||
// SerializationFormat names a message encoding.
type SerializationFormat string

// Known serialization formats.
const (
	FormatJSON     SerializationFormat = "json"
	FormatProtobuf SerializationFormat = "protobuf"
	FormatMsgPack  SerializationFormat = "msgpack"
	FormatAvro     SerializationFormat = "avro"
)
|
||||
|
||||
// Persistence handles message storage
|
||||
type Persistence interface {
|
||||
// Store saves a message for persistence
|
||||
Store(ctx context.Context, msg *Message) error
|
||||
|
||||
// Retrieve gets a stored message
|
||||
Retrieve(ctx context.Context, messageID string) (*Message, error)
|
||||
|
||||
// Delete removes a stored message
|
||||
Delete(ctx context.Context, messageID string) error
|
||||
|
||||
// List returns stored messages matching criteria
|
||||
List(ctx context.Context, criteria PersistenceCriteria) ([]*Message, error)
|
||||
|
||||
// GetStats returns persistence statistics
|
||||
GetStats() PersistenceStats
|
||||
}
|
||||
|
||||
// PersistenceCriteria defines search parameters
|
||||
type PersistenceCriteria struct {
|
||||
Topic *string `json:"topic,omitempty"`
|
||||
MessageType *MessageType `json:"message_type,omitempty"`
|
||||
Source *string `json:"source,omitempty"`
|
||||
FromTime *time.Time `json:"from_time,omitempty"`
|
||||
ToTime *time.Time `json:"to_time,omitempty"`
|
||||
Limit int `json:"limit"`
|
||||
Offset int `json:"offset"`
|
||||
}
|
||||
|
||||
// PersistenceStats is the metrics snapshot of a message store.
type PersistenceStats struct {
	// StoredMessages counts stored messages.
	StoredMessages uint64 `json:"stored_messages"`
	// StorageSize is the storage size in bytes.
	StorageSize uint64 `json:"storage_size_bytes"`
	// OldestMessage is the time of the oldest stored message.
	OldestMessage time.Time `json:"oldest_message"`
	// NewestMessage is the time of the newest stored message.
	NewestMessage time.Time `json:"newest_message"`
}
|
||||
230
pkg/transport/memory_transport.go
Normal file
230
pkg/transport/memory_transport.go
Normal file
@@ -0,0 +1,230 @@
|
||||
package transport
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// MemoryTransport implements in-memory message transport for local communication
|
||||
type MemoryTransport struct {
|
||||
channels map[string]chan *Message
|
||||
metrics TransportMetrics
|
||||
connected bool
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
// NewMemoryTransport creates a new in-memory transport
|
||||
func NewMemoryTransport() *MemoryTransport {
|
||||
return &MemoryTransport{
|
||||
channels: make(map[string]chan *Message),
|
||||
metrics: TransportMetrics{},
|
||||
}
|
||||
}
|
||||
|
||||
// Connect establishes the transport connection
|
||||
func (mt *MemoryTransport) Connect(ctx context.Context) error {
|
||||
mt.mu.Lock()
|
||||
defer mt.mu.Unlock()
|
||||
|
||||
if mt.connected {
|
||||
return nil
|
||||
}
|
||||
|
||||
mt.connected = true
|
||||
mt.metrics.Connections = 1
|
||||
return nil
|
||||
}
|
||||
|
||||
// Disconnect closes the transport connection
|
||||
func (mt *MemoryTransport) Disconnect(ctx context.Context) error {
|
||||
mt.mu.Lock()
|
||||
defer mt.mu.Unlock()
|
||||
|
||||
if !mt.connected {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Close all channels
|
||||
for _, ch := range mt.channels {
|
||||
close(ch)
|
||||
}
|
||||
mt.channels = make(map[string]chan *Message)
|
||||
mt.connected = false
|
||||
mt.metrics.Connections = 0
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Send transmits a message through the memory transport
|
||||
func (mt *MemoryTransport) Send(ctx context.Context, msg *Message) error {
|
||||
start := time.Now()
|
||||
|
||||
mt.mu.RLock()
|
||||
if !mt.connected {
|
||||
mt.mu.RUnlock()
|
||||
mt.metrics.Errors++
|
||||
return fmt.Errorf("transport not connected")
|
||||
}
|
||||
|
||||
// Get or create channel for topic
|
||||
ch, exists := mt.channels[msg.Topic]
|
||||
if !exists {
|
||||
mt.mu.RUnlock()
|
||||
mt.mu.Lock()
|
||||
// Double-check after acquiring write lock
|
||||
if ch, exists = mt.channels[msg.Topic]; !exists {
|
||||
ch = make(chan *Message, 1000) // Buffered channel
|
||||
mt.channels[msg.Topic] = ch
|
||||
}
|
||||
mt.mu.Unlock()
|
||||
} else {
|
||||
mt.mu.RUnlock()
|
||||
}
|
||||
|
||||
// Send message
|
||||
select {
|
||||
case ch <- msg:
|
||||
mt.updateSendMetrics(msg, time.Since(start))
|
||||
return nil
|
||||
case <-ctx.Done():
|
||||
mt.metrics.Errors++
|
||||
return ctx.Err()
|
||||
default:
|
||||
mt.metrics.Errors++
|
||||
return fmt.Errorf("channel full for topic: %s", msg.Topic)
|
||||
}
|
||||
}
|
||||
|
||||
// Receive returns a channel for receiving messages
|
||||
func (mt *MemoryTransport) Receive(ctx context.Context) (<-chan *Message, error) {
|
||||
mt.mu.RLock()
|
||||
defer mt.mu.RUnlock()
|
||||
|
||||
if !mt.connected {
|
||||
return nil, fmt.Errorf("transport not connected")
|
||||
}
|
||||
|
||||
// Create a merged channel that receives from all topic channels
|
||||
merged := make(chan *Message, 1000)
|
||||
|
||||
go func() {
|
||||
defer close(merged)
|
||||
|
||||
// Use a wait group to handle multiple topic channels
|
||||
var wg sync.WaitGroup
|
||||
|
||||
mt.mu.RLock()
|
||||
for topic, ch := range mt.channels {
|
||||
wg.Add(1)
|
||||
go func(topicCh <-chan *Message, topicName string) {
|
||||
defer wg.Done()
|
||||
for {
|
||||
select {
|
||||
case msg, ok := <-topicCh:
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
select {
|
||||
case merged <- msg:
|
||||
mt.updateReceiveMetrics(msg)
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}(ch, topic)
|
||||
}
|
||||
mt.mu.RUnlock()
|
||||
|
||||
wg.Wait()
|
||||
}()
|
||||
|
||||
return merged, nil
|
||||
}
|
||||
|
||||
// Health returns the health status of the transport
|
||||
func (mt *MemoryTransport) Health() ComponentHealth {
|
||||
mt.mu.RLock()
|
||||
defer mt.mu.RUnlock()
|
||||
|
||||
status := "unhealthy"
|
||||
if mt.connected {
|
||||
status = "healthy"
|
||||
}
|
||||
|
||||
return ComponentHealth{
|
||||
Status: status,
|
||||
LastCheck: time.Now(),
|
||||
ResponseTime: time.Microsecond, // Very fast for memory transport
|
||||
ErrorCount: mt.metrics.Errors,
|
||||
}
|
||||
}
|
||||
|
||||
// GetMetrics returns transport-specific metrics
|
||||
func (mt *MemoryTransport) GetMetrics() TransportMetrics {
|
||||
mt.mu.RLock()
|
||||
defer mt.mu.RUnlock()
|
||||
|
||||
// Create a copy to avoid race conditions
|
||||
return TransportMetrics{
|
||||
BytesSent: mt.metrics.BytesSent,
|
||||
BytesReceived: mt.metrics.BytesReceived,
|
||||
MessagesSent: mt.metrics.MessagesSent,
|
||||
MessagesReceived: mt.metrics.MessagesReceived,
|
||||
Connections: mt.metrics.Connections,
|
||||
Errors: mt.metrics.Errors,
|
||||
Latency: mt.metrics.Latency,
|
||||
}
|
||||
}
|
||||
|
||||
// Private helper methods
|
||||
|
||||
func (mt *MemoryTransport) updateSendMetrics(msg *Message, latency time.Duration) {
|
||||
mt.mu.Lock()
|
||||
defer mt.mu.Unlock()
|
||||
|
||||
mt.metrics.MessagesSent++
|
||||
mt.metrics.Latency = latency
|
||||
|
||||
// Estimate message size (simplified)
|
||||
messageSize := int64(len(msg.ID) + len(msg.Topic) + len(msg.Source))
|
||||
if msg.Data != nil {
|
||||
messageSize += int64(len(fmt.Sprintf("%v", msg.Data)))
|
||||
}
|
||||
mt.metrics.BytesSent += messageSize
|
||||
}
|
||||
|
||||
func (mt *MemoryTransport) updateReceiveMetrics(msg *Message) {
|
||||
mt.mu.Lock()
|
||||
defer mt.mu.Unlock()
|
||||
|
||||
mt.metrics.MessagesReceived++
|
||||
|
||||
// Estimate message size (simplified)
|
||||
messageSize := int64(len(msg.ID) + len(msg.Topic) + len(msg.Source))
|
||||
if msg.Data != nil {
|
||||
messageSize += int64(len(fmt.Sprintf("%v", msg.Data)))
|
||||
}
|
||||
mt.metrics.BytesReceived += messageSize
|
||||
}
|
||||
|
||||
// GetChannelForTopic returns the channel for a specific topic (for testing/debugging)
|
||||
func (mt *MemoryTransport) GetChannelForTopic(topic string) (<-chan *Message, bool) {
|
||||
mt.mu.RLock()
|
||||
defer mt.mu.RUnlock()
|
||||
|
||||
ch, exists := mt.channels[topic]
|
||||
return ch, exists
|
||||
}
|
||||
|
||||
// GetTopicCount returns the number of active topic channels
|
||||
func (mt *MemoryTransport) GetTopicCount() int {
|
||||
mt.mu.RLock()
|
||||
defer mt.mu.RUnlock()
|
||||
|
||||
return len(mt.channels)
|
||||
}
|
||||
453
pkg/transport/message_bus.go
Normal file
453
pkg/transport/message_bus.go
Normal file
@@ -0,0 +1,453 @@
|
||||
package transport
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// MessageType is the string tag that classifies a message.
type MessageType string

// Message type tags.
const (
	MessageTypeEvent        MessageType = "event"
	MessageTypeCommand      MessageType = "command"
	MessageTypeQuery        MessageType = "query"
	MessageTypeResponse     MessageType = "response"
	MessageTypeNotification MessageType = "notification"
	MessageTypeHeartbeat    MessageType = "heartbeat"
)
|
||||
|
||||
// MessagePriority ranks messages; larger values mean higher priority.
type MessagePriority int

// Priority levels, numbered 0 (low) through 3 (critical).
const (
	PriorityLow MessagePriority = iota
	PriorityNormal
	PriorityHigh
	PriorityCritical
)
|
||||
|
||||
// Message represents a universal message in the system
|
||||
type Message struct {
|
||||
ID string `json:"id"`
|
||||
Type MessageType `json:"type"`
|
||||
Topic string `json:"topic"`
|
||||
Source string `json:"source"`
|
||||
Target string `json:"target,omitempty"`
|
||||
Priority MessagePriority `json:"priority"`
|
||||
Timestamp time.Time `json:"timestamp"`
|
||||
Data interface{} `json:"data"`
|
||||
Headers map[string]string `json:"headers,omitempty"`
|
||||
CorrelationID string `json:"correlation_id,omitempty"`
|
||||
TTL time.Duration `json:"ttl,omitempty"`
|
||||
Retries int `json:"retries"`
|
||||
MaxRetries int `json:"max_retries"`
|
||||
Metadata map[string]interface{} `json:"metadata,omitempty"`
|
||||
}
|
||||
|
||||
// MessageHandler processes incoming messages
|
||||
type MessageHandler func(ctx context.Context, msg *Message) error
|
||||
|
||||
// MessageFilter determines if a message should be processed
|
||||
type MessageFilter func(msg *Message) bool
|
||||
|
||||
// Subscription represents a topic subscription
|
||||
type Subscription struct {
|
||||
ID string
|
||||
Topic string
|
||||
Filter MessageFilter
|
||||
Handler MessageHandler
|
||||
Options SubscriptionOptions
|
||||
created time.Time
|
||||
active bool
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
// SubscriptionOptions holds the tunables of a subscription. They are set
// through SubscriptionOption functions.
type SubscriptionOptions struct {
	QueueSize, BatchSize int
	BatchTimeout         time.Duration

	DLQEnabled, RetryEnabled, Persistent, Durable bool
}
|
||||
|
||||
// MessageBusInterface defines the universal message bus contract
|
||||
type MessageBusInterface interface {
|
||||
// Core messaging operations
|
||||
Publish(ctx context.Context, msg *Message) error
|
||||
Subscribe(topic string, handler MessageHandler, opts ...SubscriptionOption) (*Subscription, error)
|
||||
Unsubscribe(subscriptionID string) error
|
||||
|
||||
// Advanced messaging patterns
|
||||
Request(ctx context.Context, msg *Message, timeout time.Duration) (*Message, error)
|
||||
Reply(ctx context.Context, originalMsg *Message, response *Message) error
|
||||
|
||||
// Topic management
|
||||
CreateTopic(topic string, config TopicConfig) error
|
||||
DeleteTopic(topic string) error
|
||||
ListTopics() []string
|
||||
GetTopicInfo(topic string) (*TopicInfo, error)
|
||||
|
||||
// Queue operations
|
||||
QueueMessage(topic string, msg *Message) error
|
||||
DequeueMessage(topic string, timeout time.Duration) (*Message, error)
|
||||
PeekMessage(topic string) (*Message, error)
|
||||
|
||||
// Dead letter queue
|
||||
GetDLQMessages(topic string) ([]*Message, error)
|
||||
ReprocessDLQMessage(messageID string) error
|
||||
PurgeDLQ(topic string) error
|
||||
|
||||
// Lifecycle management
|
||||
Start(ctx context.Context) error
|
||||
Stop(ctx context.Context) error
|
||||
Health() HealthStatus
|
||||
|
||||
// Metrics and monitoring
|
||||
GetMetrics() MessageBusMetrics
|
||||
GetSubscriptions() []*Subscription
|
||||
GetActiveConnections() int
|
||||
}
|
||||
|
||||
// SubscriptionOption configures subscription behavior
|
||||
type SubscriptionOption func(*SubscriptionOptions)
|
||||
|
||||
// TopicConfig defines topic configuration
|
||||
type TopicConfig struct {
|
||||
Persistent bool
|
||||
Replicated bool
|
||||
RetentionPolicy RetentionPolicy
|
||||
Partitions int
|
||||
MaxMessageSize int64
|
||||
TTL time.Duration
|
||||
}
|
||||
|
||||
// RetentionPolicy bounds what a topic retains, expressed as a message
// count, an age and a size.
type RetentionPolicy struct {
	MaxMessages int           // count limit
	MaxAge      time.Duration // age limit
	MaxSize     int64         // size limit
}
|
||||
|
||||
// TopicInfo provides topic statistics
|
||||
type TopicInfo struct {
|
||||
Name string
|
||||
Config TopicConfig
|
||||
MessageCount int64
|
||||
SubscriberCount int
|
||||
LastActivity time.Time
|
||||
SizeBytes int64
|
||||
}
|
||||
|
||||
// HealthStatus represents system health
|
||||
type HealthStatus struct {
|
||||
Status string
|
||||
Uptime time.Duration
|
||||
LastCheck time.Time
|
||||
Components map[string]ComponentHealth
|
||||
Errors []HealthError
|
||||
}
|
||||
|
||||
// ComponentHealth is the health report of a single component.
type ComponentHealth struct {
	Status       string
	LastCheck    time.Time
	ResponseTime time.Duration
	ErrorCount   int64
}
|
||||
|
||||
// HealthError describes one problem found during a health check.
type HealthError struct {
	Component, Message string
	Timestamp          time.Time
	Severity           string
}
|
||||
|
||||
// MessageBusMetrics is the operational metrics snapshot returned by the bus.
type MessageBusMetrics struct {
	MessagesPublished, MessagesConsumed, MessagesFailed, MessagesInDLQ int64

	ActiveSubscriptions, TopicCount int

	AverageLatency              time.Duration
	ThroughputPerSec, ErrorRate float64
	MemoryUsage                 int64
	CPUUsage                    float64
}
|
||||
|
||||
// TransportType is the string key that names a transport mechanism.
type TransportType string

// Transport mechanism keys.
const (
	TransportMemory     TransportType = "memory"
	TransportUnixSocket TransportType = "unix"
	TransportTCP        TransportType = "tcp"
	TransportWebSocket  TransportType = "websocket"
	TransportRedis      TransportType = "redis"
	TransportNATS       TransportType = "nats"
)
|
||||
|
||||
// TransportConfig configures transport layer
|
||||
type TransportConfig struct {
|
||||
Type TransportType
|
||||
Address string
|
||||
Options map[string]interface{}
|
||||
RetryConfig RetryConfig
|
||||
SecurityConfig SecurityConfig
|
||||
}
|
||||
|
||||
// RetryConfig holds retry settings: attempt limit, delay bounds, backoff
// factor and a jitter switch.
type RetryConfig struct {
	MaxRetries             int
	InitialDelay, MaxDelay time.Duration
	BackoffFactor          float64
	Jitter                 bool
}
|
||||
|
||||
// SecurityConfig defines security settings
|
||||
type SecurityConfig struct {
|
||||
Enabled bool
|
||||
TLSConfig *TLSConfig
|
||||
AuthConfig *AuthConfig
|
||||
Encryption bool
|
||||
Compression bool
|
||||
}
|
||||
|
||||
// TLSConfig holds the TLS file locations and the verification switch for a
// secure transport.
type TLSConfig struct {
	CertFile, KeyFile, CAFile string
	Verify                    bool
}
|
||||
|
||||
// AuthConfig holds authentication credentials and the method name.
type AuthConfig struct {
	Username, Password, Token, Method string
}
|
||||
|
||||
// UniversalMessageBus implements MessageBusInterface
|
||||
type UniversalMessageBus struct {
|
||||
config MessageBusConfig
|
||||
transports map[TransportType]Transport
|
||||
router *MessageRouter
|
||||
topics map[string]*Topic
|
||||
subscriptions map[string]*Subscription
|
||||
dlq *DeadLetterQueue
|
||||
metrics *MetricsCollector
|
||||
persistence PersistenceLayer
|
||||
mu sync.RWMutex
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
started bool
|
||||
}
|
||||
|
||||
// MessageBusConfig configures the message bus
|
||||
type MessageBusConfig struct {
|
||||
DefaultTransport TransportType
|
||||
EnablePersistence bool
|
||||
EnableMetrics bool
|
||||
EnableDLQ bool
|
||||
MaxMessageSize int64
|
||||
DefaultTTL time.Duration
|
||||
HealthCheckInterval time.Duration
|
||||
CleanupInterval time.Duration
|
||||
}
|
||||
|
||||
// Transport interface for different transport mechanisms
|
||||
type Transport interface {
|
||||
Send(ctx context.Context, msg *Message) error
|
||||
Receive(ctx context.Context) (<-chan *Message, error)
|
||||
Connect(ctx context.Context) error
|
||||
Disconnect(ctx context.Context) error
|
||||
Health() ComponentHealth
|
||||
GetMetrics() TransportMetrics
|
||||
}
|
||||
|
||||
// TransportMetrics is the counter set reported by a transport.
type TransportMetrics struct {
	BytesSent, BytesReceived       int64
	MessagesSent, MessagesReceived int64
	Connections                    int
	Errors                         int64
	Latency                        time.Duration
}
|
||||
|
||||
// MessageRouter handles message routing logic
|
||||
type MessageRouter struct {
|
||||
rules []RoutingRule
|
||||
fallback TransportType
|
||||
loadBalancer LoadBalancer
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
// RoutingRule defines message routing logic
|
||||
type RoutingRule struct {
|
||||
Condition MessageFilter
|
||||
Transport TransportType
|
||||
Priority int
|
||||
Enabled bool
|
||||
}
|
||||
|
||||
// LoadBalancer for transport selection
|
||||
type LoadBalancer interface {
|
||||
SelectTransport(transports []TransportType, msg *Message) TransportType
|
||||
UpdateStats(transport TransportType, latency time.Duration, success bool)
|
||||
}
|
||||
|
||||
// Topic represents a message topic
|
||||
type Topic struct {
|
||||
Name string
|
||||
Config TopicConfig
|
||||
Messages []StoredMessage
|
||||
Subscribers []*Subscription
|
||||
Created time.Time
|
||||
LastActivity time.Time
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
// StoredMessage represents a persisted message
|
||||
type StoredMessage struct {
|
||||
Message *Message
|
||||
Stored time.Time
|
||||
Processed bool
|
||||
}
|
||||
|
||||
// DeadLetterQueue handles failed messages
|
||||
type DeadLetterQueue struct {
|
||||
messages map[string][]*Message
|
||||
config DLQConfig
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
// DLQConfig holds the dead letter queue settings.
//
// NOTE(review): pkg/transport/dlq.go in the same commit also declares a
// DLQConfig type in package transport with additional fields; confirm which
// definition is meant to survive.
type DLQConfig struct {
	MaxMessages, MaxRetries int
	RetentionTime           time.Duration
	AutoReprocess           bool
}
|
||||
|
||||
// MetricsCollector stores named metric values in a map alongside a
// read/write mutex.
type MetricsCollector struct {
	metrics map[string]interface{}
	mu      sync.RWMutex
}
|
||||
|
||||
// PersistenceLayer handles message persistence
|
||||
type PersistenceLayer interface {
|
||||
Store(msg *Message) error
|
||||
Retrieve(id string) (*Message, error)
|
||||
Delete(id string) error
|
||||
List(topic string, limit int) ([]*Message, error)
|
||||
Cleanup(maxAge time.Duration) error
|
||||
}
|
||||
|
||||
// Factory functions for common subscription options
|
||||
func WithQueueSize(size int) SubscriptionOption {
|
||||
return func(opts *SubscriptionOptions) {
|
||||
opts.QueueSize = size
|
||||
}
|
||||
}
|
||||
|
||||
func WithBatchProcessing(size int, timeout time.Duration) SubscriptionOption {
|
||||
return func(opts *SubscriptionOptions) {
|
||||
opts.BatchSize = size
|
||||
opts.BatchTimeout = timeout
|
||||
}
|
||||
}
|
||||
|
||||
func WithDLQ(enabled bool) SubscriptionOption {
|
||||
return func(opts *SubscriptionOptions) {
|
||||
opts.DLQEnabled = enabled
|
||||
}
|
||||
}
|
||||
|
||||
func WithRetry(enabled bool) SubscriptionOption {
|
||||
return func(opts *SubscriptionOptions) {
|
||||
opts.RetryEnabled = enabled
|
||||
}
|
||||
}
|
||||
|
||||
func WithPersistence(enabled bool) SubscriptionOption {
|
||||
return func(opts *SubscriptionOptions) {
|
||||
opts.Persistent = enabled
|
||||
}
|
||||
}
|
||||
|
||||
// NewUniversalMessageBus creates a new message bus instance
|
||||
func NewUniversalMessageBus(config MessageBusConfig) *UniversalMessageBus {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
return &UniversalMessageBus{
|
||||
config: config,
|
||||
transports: make(map[TransportType]Transport),
|
||||
topics: make(map[string]*Topic),
|
||||
subscriptions: make(map[string]*Subscription),
|
||||
router: NewMessageRouter(),
|
||||
dlq: NewDeadLetterQueue(DLQConfig{}),
|
||||
metrics: NewMetricsCollector(),
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
}
|
||||
}
|
||||
|
||||
// NewMessageRouter creates a new message router
|
||||
func NewMessageRouter() *MessageRouter {
|
||||
return &MessageRouter{
|
||||
rules: make([]RoutingRule, 0),
|
||||
fallback: TransportMemory,
|
||||
}
|
||||
}
|
||||
|
||||
// NewDeadLetterQueue creates a new dead letter queue
|
||||
func NewDeadLetterQueue(config DLQConfig) *DeadLetterQueue {
|
||||
return &DeadLetterQueue{
|
||||
messages: make(map[string][]*Message),
|
||||
config: config,
|
||||
}
|
||||
}
|
||||
|
||||
// NewMetricsCollector creates a new metrics collector
|
||||
func NewMetricsCollector() *MetricsCollector {
|
||||
return &MetricsCollector{
|
||||
metrics: make(map[string]interface{}),
|
||||
}
|
||||
}
|
||||
|
||||
// messageIDSeq is the process-wide sequence number appended to generated
// message IDs. The embedded mutex serialises increments.
var messageIDSeq struct {
	sync.Mutex
	next uint64
}

// GenerateMessageID returns an ID of the form "msg_<unix-nanos>_<sequence>".
//
// Both parts used to come from the clock, so two calls landing on the same
// clock reading (coarse timers, concurrent callers) produced the same ID.
// The second part is now a monotonically increasing counter, which keeps
// IDs distinct within a process regardless of clock resolution.
func GenerateMessageID() string {
	messageIDSeq.Lock()
	messageIDSeq.next++
	seq := messageIDSeq.next
	messageIDSeq.Unlock()

	return fmt.Sprintf("msg_%d_%d", time.Now().UnixNano(), seq)
}
|
||||
|
||||
// Helper function to create message with defaults
|
||||
func NewMessage(msgType MessageType, topic string, source string, data interface{}) *Message {
|
||||
return &Message{
|
||||
ID: GenerateMessageID(),
|
||||
Type: msgType,
|
||||
Topic: topic,
|
||||
Source: source,
|
||||
Priority: PriorityNormal,
|
||||
Timestamp: time.Now(),
|
||||
Data: data,
|
||||
Headers: make(map[string]string),
|
||||
Metadata: make(map[string]interface{}),
|
||||
Retries: 0,
|
||||
MaxRetries: 3,
|
||||
}
|
||||
}
|
||||
743
pkg/transport/message_bus_impl.go
Normal file
743
pkg/transport/message_bus_impl.go
Normal file
@@ -0,0 +1,743 @@
|
||||
package transport
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Publish sends a message to the specified topic
|
||||
func (mb *UniversalMessageBus) Publish(ctx context.Context, msg *Message) error {
|
||||
if !mb.started {
|
||||
return fmt.Errorf("message bus not started")
|
||||
}
|
||||
|
||||
// Validate message
|
||||
if err := mb.validateMessage(msg); err != nil {
|
||||
return fmt.Errorf("invalid message: %w", err)
|
||||
}
|
||||
|
||||
// Set timestamp if not set
|
||||
if msg.Timestamp.IsZero() {
|
||||
msg.Timestamp = time.Now()
|
||||
}
|
||||
|
||||
// Set ID if not set
|
||||
if msg.ID == "" {
|
||||
msg.ID = GenerateMessageID()
|
||||
}
|
||||
|
||||
// Update metrics
|
||||
mb.metrics.IncrementCounter("messages_published_total")
|
||||
mb.metrics.RecordLatency("publish_latency", time.Since(msg.Timestamp))
|
||||
|
||||
// Route message to appropriate transport
|
||||
transport, err := mb.router.RouteMessage(msg, mb.transports)
|
||||
if err != nil {
|
||||
mb.metrics.IncrementCounter("routing_errors_total")
|
||||
return fmt.Errorf("routing failed: %w", err)
|
||||
}
|
||||
|
||||
// Send via transport
|
||||
if err := transport.Send(ctx, msg); err != nil {
|
||||
mb.metrics.IncrementCounter("send_errors_total")
|
||||
// Try dead letter queue if enabled
|
||||
if mb.config.EnableDLQ {
|
||||
if dlqErr := mb.dlq.AddMessage(msg.Topic, msg); dlqErr != nil {
|
||||
return fmt.Errorf("send failed and DLQ failed: %v, original error: %w", dlqErr, err)
|
||||
}
|
||||
}
|
||||
return fmt.Errorf("send failed: %w", err)
|
||||
}
|
||||
|
||||
// Store in topic if persistence enabled
|
||||
if mb.config.EnablePersistence {
|
||||
if err := mb.addMessageToTopic(msg); err != nil {
|
||||
// Log error but don't fail the publish
|
||||
mb.metrics.IncrementCounter("persistence_errors_total")
|
||||
}
|
||||
}
|
||||
|
||||
// Deliver to local subscribers
|
||||
go mb.deliverToSubscribers(ctx, msg)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Subscribe creates a subscription to a topic
|
||||
func (mb *UniversalMessageBus) Subscribe(topic string, handler MessageHandler, opts ...SubscriptionOption) (*Subscription, error) {
|
||||
if !mb.started {
|
||||
return nil, fmt.Errorf("message bus not started")
|
||||
}
|
||||
|
||||
// Apply subscription options
|
||||
options := SubscriptionOptions{
|
||||
QueueSize: 1000,
|
||||
BatchSize: 1,
|
||||
BatchTimeout: time.Second,
|
||||
DLQEnabled: mb.config.EnableDLQ,
|
||||
RetryEnabled: true,
|
||||
Persistent: false,
|
||||
Durable: false,
|
||||
}
|
||||
|
||||
for _, opt := range opts {
|
||||
opt(&options)
|
||||
}
|
||||
|
||||
// Create subscription
|
||||
subscription := &Subscription{
|
||||
ID: fmt.Sprintf("sub_%s_%d", topic, time.Now().UnixNano()),
|
||||
Topic: topic,
|
||||
Handler: handler,
|
||||
Options: options,
|
||||
created: time.Now(),
|
||||
active: true,
|
||||
}
|
||||
|
||||
mb.mu.Lock()
|
||||
mb.subscriptions[subscription.ID] = subscription
|
||||
mb.mu.Unlock()
|
||||
|
||||
// Add to topic subscribers
|
||||
mb.addSubscriberToTopic(topic, subscription)
|
||||
|
||||
mb.metrics.IncrementCounter("subscriptions_created_total")
|
||||
|
||||
return subscription, nil
|
||||
}
|
||||
|
||||
// Unsubscribe removes a subscription
|
||||
func (mb *UniversalMessageBus) Unsubscribe(subscriptionID string) error {
|
||||
mb.mu.Lock()
|
||||
defer mb.mu.Unlock()
|
||||
|
||||
subscription, exists := mb.subscriptions[subscriptionID]
|
||||
if !exists {
|
||||
return fmt.Errorf("subscription not found: %s", subscriptionID)
|
||||
}
|
||||
|
||||
// Mark as inactive
|
||||
subscription.mu.Lock()
|
||||
subscription.active = false
|
||||
subscription.mu.Unlock()
|
||||
|
||||
// Remove from subscriptions map
|
||||
delete(mb.subscriptions, subscriptionID)
|
||||
|
||||
// Remove from topic subscribers
|
||||
mb.removeSubscriberFromTopic(subscription.Topic, subscriptionID)
|
||||
|
||||
mb.metrics.IncrementCounter("subscriptions_removed_total")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Request sends a request and waits for a response
|
||||
func (mb *UniversalMessageBus) Request(ctx context.Context, msg *Message, timeout time.Duration) (*Message, error) {
|
||||
if !mb.started {
|
||||
return nil, fmt.Errorf("message bus not started")
|
||||
}
|
||||
|
||||
// Set correlation ID for request-response
|
||||
if msg.CorrelationID == "" {
|
||||
msg.CorrelationID = GenerateMessageID()
|
||||
}
|
||||
|
||||
// Create response channel
|
||||
responseChannel := make(chan *Message, 1)
|
||||
defer close(responseChannel)
|
||||
|
||||
// Subscribe to response topic
|
||||
responseTopic := fmt.Sprintf("response.%s", msg.CorrelationID)
|
||||
subscription, err := mb.Subscribe(responseTopic, func(ctx context.Context, response *Message) error {
|
||||
select {
|
||||
case responseChannel <- response:
|
||||
default:
|
||||
// Channel full, ignore
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to subscribe to response topic: %w", err)
|
||||
}
|
||||
defer mb.Unsubscribe(subscription.ID)
|
||||
|
||||
// Send request
|
||||
if err := mb.Publish(ctx, msg); err != nil {
|
||||
return nil, fmt.Errorf("failed to publish request: %w", err)
|
||||
}
|
||||
|
||||
// Wait for response with timeout
|
||||
select {
|
||||
case response := <-responseChannel:
|
||||
return response, nil
|
||||
case <-time.After(timeout):
|
||||
return nil, fmt.Errorf("request timeout after %v", timeout)
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
}
|
||||
}
|
||||
|
||||
// Reply sends a response to a request
|
||||
func (mb *UniversalMessageBus) Reply(ctx context.Context, originalMsg *Message, response *Message) error {
|
||||
if originalMsg.CorrelationID == "" {
|
||||
return fmt.Errorf("original message has no correlation ID")
|
||||
}
|
||||
|
||||
// Set response properties
|
||||
response.Type = MessageTypeResponse
|
||||
response.Topic = fmt.Sprintf("response.%s", originalMsg.CorrelationID)
|
||||
response.CorrelationID = originalMsg.CorrelationID
|
||||
response.Target = originalMsg.Source
|
||||
|
||||
return mb.Publish(ctx, response)
|
||||
}
|
||||
|
||||
// CreateTopic creates a new topic with configuration
|
||||
func (mb *UniversalMessageBus) CreateTopic(topicName string, config TopicConfig) error {
|
||||
mb.mu.Lock()
|
||||
defer mb.mu.Unlock()
|
||||
|
||||
if _, exists := mb.topics[topicName]; exists {
|
||||
return fmt.Errorf("topic already exists: %s", topicName)
|
||||
}
|
||||
|
||||
topic := &Topic{
|
||||
Name: topicName,
|
||||
Config: config,
|
||||
Messages: make([]StoredMessage, 0),
|
||||
Subscribers: make([]*Subscription, 0),
|
||||
Created: time.Now(),
|
||||
LastActivity: time.Now(),
|
||||
}
|
||||
|
||||
mb.topics[topicName] = topic
|
||||
mb.metrics.IncrementCounter("topics_created_total")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeleteTopic removes a topic
|
||||
func (mb *UniversalMessageBus) DeleteTopic(topicName string) error {
|
||||
mb.mu.Lock()
|
||||
defer mb.mu.Unlock()
|
||||
|
||||
topic, exists := mb.topics[topicName]
|
||||
if !exists {
|
||||
return fmt.Errorf("topic not found: %s", topicName)
|
||||
}
|
||||
|
||||
// Remove all subscribers
|
||||
for _, sub := range topic.Subscribers {
|
||||
mb.Unsubscribe(sub.ID)
|
||||
}
|
||||
|
||||
delete(mb.topics, topicName)
|
||||
mb.metrics.IncrementCounter("topics_deleted_total")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ListTopics returns all topic names
|
||||
func (mb *UniversalMessageBus) ListTopics() []string {
|
||||
mb.mu.RLock()
|
||||
defer mb.mu.RUnlock()
|
||||
|
||||
topics := make([]string, 0, len(mb.topics))
|
||||
for name := range mb.topics {
|
||||
topics = append(topics, name)
|
||||
}
|
||||
|
||||
return topics
|
||||
}
|
||||
|
||||
// GetTopicInfo returns topic information
|
||||
func (mb *UniversalMessageBus) GetTopicInfo(topicName string) (*TopicInfo, error) {
|
||||
mb.mu.RLock()
|
||||
defer mb.mu.RUnlock()
|
||||
|
||||
topic, exists := mb.topics[topicName]
|
||||
if !exists {
|
||||
return nil, fmt.Errorf("topic not found: %s", topicName)
|
||||
}
|
||||
|
||||
topic.mu.RLock()
|
||||
defer topic.mu.RUnlock()
|
||||
|
||||
// Calculate size
|
||||
var sizeBytes int64
|
||||
for _, stored := range topic.Messages {
|
||||
// Rough estimation of message size
|
||||
sizeBytes += int64(len(fmt.Sprintf("%+v", stored.Message)))
|
||||
}
|
||||
|
||||
return &TopicInfo{
|
||||
Name: topic.Name,
|
||||
Config: topic.Config,
|
||||
MessageCount: int64(len(topic.Messages)),
|
||||
SubscriberCount: len(topic.Subscribers),
|
||||
LastActivity: topic.LastActivity,
|
||||
SizeBytes: sizeBytes,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// QueueMessage adds a message to a topic queue
|
||||
func (mb *UniversalMessageBus) QueueMessage(topic string, msg *Message) error {
|
||||
return mb.addMessageToTopic(msg)
|
||||
}
|
||||
|
||||
// DequeueMessage removes a message from a topic queue
|
||||
func (mb *UniversalMessageBus) DequeueMessage(topic string, timeout time.Duration) (*Message, error) {
|
||||
start := time.Now()
|
||||
|
||||
for time.Since(start) < timeout {
|
||||
mb.mu.RLock()
|
||||
topicObj, exists := mb.topics[topic]
|
||||
mb.mu.RUnlock()
|
||||
|
||||
if !exists {
|
||||
return nil, fmt.Errorf("topic not found: %s", topic)
|
||||
}
|
||||
|
||||
topicObj.mu.Lock()
|
||||
if len(topicObj.Messages) > 0 {
|
||||
// Get first unprocessed message
|
||||
for i, stored := range topicObj.Messages {
|
||||
if !stored.Processed {
|
||||
topicObj.Messages[i].Processed = true
|
||||
topicObj.mu.Unlock()
|
||||
return stored.Message, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
topicObj.mu.Unlock()
|
||||
|
||||
// Wait a bit before trying again
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("no message available within timeout")
|
||||
}
|
||||
|
||||
// PeekMessage returns the next message without removing it
|
||||
func (mb *UniversalMessageBus) PeekMessage(topic string) (*Message, error) {
|
||||
mb.mu.RLock()
|
||||
topicObj, exists := mb.topics[topic]
|
||||
mb.mu.RUnlock()
|
||||
|
||||
if !exists {
|
||||
return nil, fmt.Errorf("topic not found: %s", topic)
|
||||
}
|
||||
|
||||
topicObj.mu.RLock()
|
||||
defer topicObj.mu.RUnlock()
|
||||
|
||||
for _, stored := range topicObj.Messages {
|
||||
if !stored.Processed {
|
||||
return stored.Message, nil
|
||||
}
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("no messages available")
|
||||
}
|
||||
|
||||
// Start initializes and starts the message bus
|
||||
func (mb *UniversalMessageBus) Start(ctx context.Context) error {
|
||||
mb.mu.Lock()
|
||||
defer mb.mu.Unlock()
|
||||
|
||||
if mb.started {
|
||||
return fmt.Errorf("message bus already started")
|
||||
}
|
||||
|
||||
// Initialize default transport if none configured
|
||||
if len(mb.transports) == 0 {
|
||||
memTransport := NewMemoryTransport()
|
||||
mb.transports[TransportMemory] = memTransport
|
||||
if err := memTransport.Connect(ctx); err != nil {
|
||||
return fmt.Errorf("failed to connect default transport: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Start background routines
|
||||
go mb.healthCheckLoop()
|
||||
go mb.cleanupLoop()
|
||||
go mb.metricsLoop()
|
||||
|
||||
mb.started = true
|
||||
mb.metrics.RecordEvent("message_bus_started")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Stop gracefully shuts down the message bus
|
||||
func (mb *UniversalMessageBus) Stop(ctx context.Context) error {
|
||||
mb.mu.Lock()
|
||||
defer mb.mu.Unlock()
|
||||
|
||||
if !mb.started {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Cancel context to stop background routines
|
||||
mb.cancel()
|
||||
|
||||
// Disconnect all transports
|
||||
for _, transport := range mb.transports {
|
||||
if err := transport.Disconnect(ctx); err != nil {
|
||||
// Log error but continue shutdown
|
||||
}
|
||||
}
|
||||
|
||||
mb.started = false
|
||||
mb.metrics.RecordEvent("message_bus_stopped")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Health returns the current health status
|
||||
func (mb *UniversalMessageBus) Health() HealthStatus {
|
||||
components := make(map[string]ComponentHealth)
|
||||
|
||||
// Check transport health
|
||||
for transportType, transport := range mb.transports {
|
||||
components[string(transportType)] = transport.Health()
|
||||
}
|
||||
|
||||
// Overall status
|
||||
status := "healthy"
|
||||
var errors []HealthError
|
||||
|
||||
for name, component := range components {
|
||||
if component.Status != "healthy" {
|
||||
status = "degraded"
|
||||
errors = append(errors, HealthError{
|
||||
Component: name,
|
||||
Message: fmt.Sprintf("Component %s is %s", name, component.Status),
|
||||
Timestamp: time.Now(),
|
||||
Severity: "warning",
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
return HealthStatus{
|
||||
Status: status,
|
||||
Uptime: time.Since(time.Now()), // Would track actual uptime
|
||||
LastCheck: time.Now(),
|
||||
Components: components,
|
||||
Errors: errors,
|
||||
}
|
||||
}
|
||||
|
||||
// GetMetrics returns current operational metrics
|
||||
func (mb *UniversalMessageBus) GetMetrics() MessageBusMetrics {
|
||||
metrics := mb.metrics.GetAll()
|
||||
|
||||
return MessageBusMetrics{
|
||||
MessagesPublished: mb.getMetricInt64("messages_published_total"),
|
||||
MessagesConsumed: mb.getMetricInt64("messages_consumed_total"),
|
||||
MessagesFailed: mb.getMetricInt64("send_errors_total"),
|
||||
MessagesInDLQ: int64(mb.dlq.GetMessageCount()),
|
||||
ActiveSubscriptions: len(mb.subscriptions),
|
||||
TopicCount: len(mb.topics),
|
||||
AverageLatency: mb.getMetricDuration("average_latency"),
|
||||
ThroughputPerSec: mb.getMetricFloat64("throughput_per_second"),
|
||||
ErrorRate: mb.getMetricFloat64("error_rate"),
|
||||
MemoryUsage: mb.getMetricInt64("memory_usage_bytes"),
|
||||
CPUUsage: mb.getMetricFloat64("cpu_usage_percent"),
|
||||
}
|
||||
}
|
||||
|
||||
// GetSubscriptions returns all active subscriptions
|
||||
func (mb *UniversalMessageBus) GetSubscriptions() []*Subscription {
|
||||
mb.mu.RLock()
|
||||
defer mb.mu.RUnlock()
|
||||
|
||||
subscriptions := make([]*Subscription, 0, len(mb.subscriptions))
|
||||
for _, sub := range mb.subscriptions {
|
||||
subscriptions = append(subscriptions, sub)
|
||||
}
|
||||
|
||||
return subscriptions
|
||||
}
|
||||
|
||||
// GetActiveConnections returns the number of active connections
|
||||
func (mb *UniversalMessageBus) GetActiveConnections() int {
|
||||
count := 0
|
||||
for _, transport := range mb.transports {
|
||||
metrics := transport.GetMetrics()
|
||||
count += metrics.Connections
|
||||
}
|
||||
return count
|
||||
}
|
||||
|
||||
// Helper methods
|
||||
|
||||
func (mb *UniversalMessageBus) validateMessage(msg *Message) error {
|
||||
if msg == nil {
|
||||
return fmt.Errorf("message is nil")
|
||||
}
|
||||
if msg.Topic == "" {
|
||||
return fmt.Errorf("message topic is empty")
|
||||
}
|
||||
if msg.Source == "" {
|
||||
return fmt.Errorf("message source is empty")
|
||||
}
|
||||
if msg.Data == nil {
|
||||
return fmt.Errorf("message data is nil")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (mb *UniversalMessageBus) addMessageToTopic(msg *Message) error {
|
||||
mb.mu.RLock()
|
||||
topic, exists := mb.topics[msg.Topic]
|
||||
mb.mu.RUnlock()
|
||||
|
||||
if !exists {
|
||||
// Create topic automatically
|
||||
config := TopicConfig{
|
||||
Persistent: true,
|
||||
RetentionPolicy: RetentionPolicy{MaxMessages: 10000, MaxAge: 24 * time.Hour},
|
||||
}
|
||||
if err := mb.CreateTopic(msg.Topic, config); err != nil {
|
||||
return err
|
||||
}
|
||||
topic = mb.topics[msg.Topic]
|
||||
}
|
||||
|
||||
topic.mu.Lock()
|
||||
defer topic.mu.Unlock()
|
||||
|
||||
stored := StoredMessage{
|
||||
Message: msg,
|
||||
Stored: time.Now(),
|
||||
Processed: false,
|
||||
}
|
||||
|
||||
topic.Messages = append(topic.Messages, stored)
|
||||
topic.LastActivity = time.Now()
|
||||
|
||||
// Apply retention policy
|
||||
mb.applyRetentionPolicy(topic)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (mb *UniversalMessageBus) addSubscriberToTopic(topicName string, subscription *Subscription) {
|
||||
mb.mu.RLock()
|
||||
topic, exists := mb.topics[topicName]
|
||||
mb.mu.RUnlock()
|
||||
|
||||
if !exists {
|
||||
// Create topic automatically
|
||||
config := TopicConfig{Persistent: false}
|
||||
mb.CreateTopic(topicName, config)
|
||||
topic = mb.topics[topicName]
|
||||
}
|
||||
|
||||
topic.mu.Lock()
|
||||
topic.Subscribers = append(topic.Subscribers, subscription)
|
||||
topic.mu.Unlock()
|
||||
}
|
||||
|
||||
func (mb *UniversalMessageBus) removeSubscriberFromTopic(topicName, subscriptionID string) {
|
||||
mb.mu.RLock()
|
||||
topic, exists := mb.topics[topicName]
|
||||
mb.mu.RUnlock()
|
||||
|
||||
if !exists {
|
||||
return
|
||||
}
|
||||
|
||||
topic.mu.Lock()
|
||||
defer topic.mu.Unlock()
|
||||
|
||||
for i, sub := range topic.Subscribers {
|
||||
if sub.ID == subscriptionID {
|
||||
topic.Subscribers = append(topic.Subscribers[:i], topic.Subscribers[i+1:]...)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (mb *UniversalMessageBus) deliverToSubscribers(ctx context.Context, msg *Message) {
|
||||
mb.mu.RLock()
|
||||
topic, exists := mb.topics[msg.Topic]
|
||||
mb.mu.RUnlock()
|
||||
|
||||
if !exists {
|
||||
return
|
||||
}
|
||||
|
||||
topic.mu.RLock()
|
||||
subscribers := make([]*Subscription, len(topic.Subscribers))
|
||||
copy(subscribers, topic.Subscribers)
|
||||
topic.mu.RUnlock()
|
||||
|
||||
for _, sub := range subscribers {
|
||||
sub.mu.RLock()
|
||||
if !sub.active {
|
||||
sub.mu.RUnlock()
|
||||
continue
|
||||
}
|
||||
|
||||
// Apply filter if present
|
||||
if sub.Filter != nil && !sub.Filter(msg) {
|
||||
sub.mu.RUnlock()
|
||||
continue
|
||||
}
|
||||
|
||||
handler := sub.Handler
|
||||
sub.mu.RUnlock()
|
||||
|
||||
// Deliver message in goroutine
|
||||
go func(subscription *Subscription, message *Message) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
mb.metrics.IncrementCounter("handler_panics_total")
|
||||
}
|
||||
}()
|
||||
|
||||
if err := handler(ctx, message); err != nil {
|
||||
mb.metrics.IncrementCounter("handler_errors_total")
|
||||
if mb.config.EnableDLQ {
|
||||
mb.dlq.AddMessage(message.Topic, message)
|
||||
}
|
||||
} else {
|
||||
mb.metrics.IncrementCounter("messages_consumed_total")
|
||||
}
|
||||
}(sub, msg)
|
||||
}
|
||||
}
|
||||
|
||||
func (mb *UniversalMessageBus) applyRetentionPolicy(topic *Topic) {
|
||||
policy := topic.Config.RetentionPolicy
|
||||
|
||||
// Remove old messages
|
||||
if policy.MaxAge > 0 {
|
||||
cutoff := time.Now().Add(-policy.MaxAge)
|
||||
filtered := make([]StoredMessage, 0)
|
||||
for _, stored := range topic.Messages {
|
||||
if stored.Stored.After(cutoff) {
|
||||
filtered = append(filtered, stored)
|
||||
}
|
||||
}
|
||||
topic.Messages = filtered
|
||||
}
|
||||
|
||||
// Limit number of messages
|
||||
if policy.MaxMessages > 0 && len(topic.Messages) > policy.MaxMessages {
|
||||
topic.Messages = topic.Messages[len(topic.Messages)-policy.MaxMessages:]
|
||||
}
|
||||
}
|
||||
|
||||
func (mb *UniversalMessageBus) healthCheckLoop() {
|
||||
ticker := time.NewTicker(mb.config.HealthCheckInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
mb.performHealthCheck()
|
||||
case <-mb.ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (mb *UniversalMessageBus) cleanupLoop() {
|
||||
ticker := time.NewTicker(mb.config.CleanupInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
mb.performCleanup()
|
||||
case <-mb.ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (mb *UniversalMessageBus) metricsLoop() {
|
||||
ticker := time.NewTicker(time.Minute)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
mb.updateMetrics()
|
||||
case <-mb.ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (mb *UniversalMessageBus) performHealthCheck() {
|
||||
// Check all transports
|
||||
for _, transport := range mb.transports {
|
||||
health := transport.Health()
|
||||
mb.metrics.RecordGauge(fmt.Sprintf("transport_%s_healthy", health.Component),
|
||||
map[string]float64{"healthy": 1, "unhealthy": 0, "degraded": 0.5}[health.Status])
|
||||
}
|
||||
}
|
||||
|
||||
func (mb *UniversalMessageBus) performCleanup() {
|
||||
// Clean up processed messages in topics
|
||||
mb.mu.RLock()
|
||||
topics := make([]*Topic, 0, len(mb.topics))
|
||||
for _, topic := range mb.topics {
|
||||
topics = append(topics, topic)
|
||||
}
|
||||
mb.mu.RUnlock()
|
||||
|
||||
for _, topic := range topics {
|
||||
topic.mu.Lock()
|
||||
mb.applyRetentionPolicy(topic)
|
||||
topic.mu.Unlock()
|
||||
}
|
||||
|
||||
// Clean up DLQ
|
||||
mb.dlq.Cleanup(time.Hour * 24) // Clean messages older than 24 hours
|
||||
}
|
||||
|
||||
func (mb *UniversalMessageBus) updateMetrics() {
|
||||
// Update throughput metrics
|
||||
publishedCount := mb.getMetricInt64("messages_published_total")
|
||||
if publishedCount > 0 {
|
||||
// Calculate per-second rate (simplified)
|
||||
mb.metrics.RecordGauge("throughput_per_second", float64(publishedCount)/60.0)
|
||||
}
|
||||
|
||||
// Update error rate
|
||||
errorCount := mb.getMetricInt64("send_errors_total")
|
||||
totalCount := publishedCount
|
||||
if totalCount > 0 {
|
||||
errorRate := float64(errorCount) / float64(totalCount)
|
||||
mb.metrics.RecordGauge("error_rate", errorRate)
|
||||
}
|
||||
}
|
||||
|
||||
func (mb *UniversalMessageBus) getMetricInt64(key string) int64 {
|
||||
if val, ok := mb.metrics.Get(key).(int64); ok {
|
||||
return val
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func (mb *UniversalMessageBus) getMetricFloat64(key string) float64 {
|
||||
if val, ok := mb.metrics.Get(key).(float64); ok {
|
||||
return val
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func (mb *UniversalMessageBus) getMetricDuration(key string) time.Duration {
|
||||
if val, ok := mb.metrics.Get(key).(time.Duration); ok {
|
||||
return val
|
||||
}
|
||||
return 0
|
||||
}
|
||||
622
pkg/transport/persistence.go
Normal file
622
pkg/transport/persistence.go
Normal file
@@ -0,0 +1,622 @@
|
||||
package transport
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sort"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// FilePersistenceLayer implements file-based message persistence
|
||||
type FilePersistenceLayer struct {
|
||||
basePath string
|
||||
maxFileSize int64
|
||||
maxFiles int
|
||||
compression bool
|
||||
encryption EncryptionConfig
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
// EncryptionConfig describes how messages are encrypted at rest.
type EncryptionConfig struct {
	// Enabled switches the encryption hooks on.
	Enabled bool
	// Algorithm names the cipher to use.
	Algorithm string
	// Key holds the raw key material.
	Key []byte
}
|
||||
|
||||
// PersistedMessage represents a message stored on disk
|
||||
type PersistedMessage struct {
|
||||
ID string `json:"id"`
|
||||
Topic string `json:"topic"`
|
||||
Message *Message `json:"message"`
|
||||
Stored time.Time `json:"stored"`
|
||||
Metadata map[string]interface{} `json:"metadata"`
|
||||
Encrypted bool `json:"encrypted"`
|
||||
}
|
||||
|
||||
// PersistenceMetrics aggregates counters and sizes describing a persistence
// layer.
type PersistenceMetrics struct {
	// Operation counters.
	MessagesStored    int64
	MessagesRetrieved int64
	MessagesDeleted   int64
	// StorageSize is the total size in bytes of the files found on disk.
	StorageSize int64
	// FileCount is the number of files found on disk.
	FileCount int
	// LastCleanup is the time of the most recent cleanup.
	LastCleanup time.Time
	// Errors counts failed operations.
	Errors int64
}
|
||||
|
||||
// NewFilePersistenceLayer creates a new file-based persistence layer
|
||||
func NewFilePersistenceLayer(basePath string) *FilePersistenceLayer {
|
||||
return &FilePersistenceLayer{
|
||||
basePath: basePath,
|
||||
maxFileSize: 100 * 1024 * 1024, // 100MB default
|
||||
maxFiles: 1000,
|
||||
compression: false,
|
||||
}
|
||||
}
|
||||
|
||||
// SetMaxFileSize configures the maximum file size
|
||||
func (fpl *FilePersistenceLayer) SetMaxFileSize(size int64) {
|
||||
fpl.mu.Lock()
|
||||
defer fpl.mu.Unlock()
|
||||
fpl.maxFileSize = size
|
||||
}
|
||||
|
||||
// SetMaxFiles configures the maximum number of files
|
||||
func (fpl *FilePersistenceLayer) SetMaxFiles(count int) {
|
||||
fpl.mu.Lock()
|
||||
defer fpl.mu.Unlock()
|
||||
fpl.maxFiles = count
|
||||
}
|
||||
|
||||
// EnableCompression enables/disables compression
|
||||
func (fpl *FilePersistenceLayer) EnableCompression(enabled bool) {
|
||||
fpl.mu.Lock()
|
||||
defer fpl.mu.Unlock()
|
||||
fpl.compression = enabled
|
||||
}
|
||||
|
||||
// SetEncryption configures encryption settings
|
||||
func (fpl *FilePersistenceLayer) SetEncryption(config EncryptionConfig) {
|
||||
fpl.mu.Lock()
|
||||
defer fpl.mu.Unlock()
|
||||
fpl.encryption = config
|
||||
}
|
||||
|
||||
// Store persists a message to disk
|
||||
func (fpl *FilePersistenceLayer) Store(msg *Message) error {
|
||||
fpl.mu.Lock()
|
||||
defer fpl.mu.Unlock()
|
||||
|
||||
// Create directory if it doesn't exist
|
||||
topicDir := filepath.Join(fpl.basePath, msg.Topic)
|
||||
if err := os.MkdirAll(topicDir, 0755); err != nil {
|
||||
return fmt.Errorf("failed to create topic directory: %w", err)
|
||||
}
|
||||
|
||||
// Create persisted message
|
||||
persistedMsg := &PersistedMessage{
|
||||
ID: msg.ID,
|
||||
Topic: msg.Topic,
|
||||
Message: msg,
|
||||
Stored: time.Now(),
|
||||
Metadata: make(map[string]interface{}),
|
||||
}
|
||||
|
||||
// Serialize message
|
||||
data, err := json.Marshal(persistedMsg)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal message: %w", err)
|
||||
}
|
||||
|
||||
// Apply encryption if enabled
|
||||
if fpl.encryption.Enabled {
|
||||
encryptedData, err := fpl.encrypt(data)
|
||||
if err != nil {
|
||||
return fmt.Errorf("encryption failed: %w", err)
|
||||
}
|
||||
data = encryptedData
|
||||
persistedMsg.Encrypted = true
|
||||
}
|
||||
|
||||
// Apply compression if enabled
|
||||
if fpl.compression {
|
||||
compressedData, err := fpl.compress(data)
|
||||
if err != nil {
|
||||
return fmt.Errorf("compression failed: %w", err)
|
||||
}
|
||||
data = compressedData
|
||||
}
|
||||
|
||||
// Find appropriate file to write to
|
||||
filename, err := fpl.getWritableFile(topicDir, len(data))
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get writable file: %w", err)
|
||||
}
|
||||
|
||||
// Write to file
|
||||
file, err := os.OpenFile(filename, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to open file: %w", err)
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
// Write length prefix and data
|
||||
lengthPrefix := fmt.Sprintf("%d\n", len(data))
|
||||
if _, err := file.WriteString(lengthPrefix); err != nil {
|
||||
return fmt.Errorf("failed to write length prefix: %w", err)
|
||||
}
|
||||
if _, err := file.Write(data); err != nil {
|
||||
return fmt.Errorf("failed to write data: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Retrieve loads a message from disk by ID
|
||||
func (fpl *FilePersistenceLayer) Retrieve(id string) (*Message, error) {
|
||||
fpl.mu.RLock()
|
||||
defer fpl.mu.RUnlock()
|
||||
|
||||
// Search all topic directories
|
||||
topicDirs, err := fpl.getTopicDirectories()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get topic directories: %w", err)
|
||||
}
|
||||
|
||||
for _, topicDir := range topicDirs {
|
||||
files, err := fpl.getTopicFiles(topicDir)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
for _, file := range files {
|
||||
msg, err := fpl.findMessageInFile(file, id)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
if msg != nil {
|
||||
return msg, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("message not found: %s", id)
|
||||
}
|
||||
|
||||
// Delete removes a message from disk by ID
|
||||
func (fpl *FilePersistenceLayer) Delete(id string) error {
|
||||
fpl.mu.Lock()
|
||||
defer fpl.mu.Unlock()
|
||||
|
||||
// This is a simplified implementation
|
||||
// In a production system, you might want to mark messages as deleted
|
||||
// and compact files periodically instead of rewriting entire files
|
||||
return fmt.Errorf("delete operation not yet implemented")
|
||||
}
|
||||
|
||||
// List returns messages for a topic with optional limit
|
||||
func (fpl *FilePersistenceLayer) List(topic string, limit int) ([]*Message, error) {
|
||||
fpl.mu.RLock()
|
||||
defer fpl.mu.RUnlock()
|
||||
|
||||
topicDir := filepath.Join(fpl.basePath, topic)
|
||||
if _, err := os.Stat(topicDir); os.IsNotExist(err) {
|
||||
return []*Message{}, nil
|
||||
}
|
||||
|
||||
files, err := fpl.getTopicFiles(topicDir)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get topic files: %w", err)
|
||||
}
|
||||
|
||||
var messages []*Message
|
||||
count := 0
|
||||
|
||||
// Read files in chronological order (newest first)
|
||||
sort.Slice(files, func(i, j int) bool {
|
||||
infoI, _ := os.Stat(files[i])
|
||||
infoJ, _ := os.Stat(files[j])
|
||||
return infoI.ModTime().After(infoJ.ModTime())
|
||||
})
|
||||
|
||||
for _, file := range files {
|
||||
fileMessages, err := fpl.readMessagesFromFile(file)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
for _, msg := range fileMessages {
|
||||
messages = append(messages, msg)
|
||||
count++
|
||||
if limit > 0 && count >= limit {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if limit > 0 && count >= limit {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return messages, nil
|
||||
}
|
||||
|
||||
// Cleanup removes messages older than maxAge
|
||||
func (fpl *FilePersistenceLayer) Cleanup(maxAge time.Duration) error {
|
||||
fpl.mu.Lock()
|
||||
defer fpl.mu.Unlock()
|
||||
|
||||
cutoff := time.Now().Add(-maxAge)
|
||||
|
||||
topicDirs, err := fpl.getTopicDirectories()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get topic directories: %w", err)
|
||||
}
|
||||
|
||||
for _, topicDir := range topicDirs {
|
||||
files, err := fpl.getTopicFiles(topicDir)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
for _, file := range files {
|
||||
// Check file modification time
|
||||
info, err := os.Stat(file)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
if info.ModTime().Before(cutoff) {
|
||||
os.Remove(file)
|
||||
}
|
||||
}
|
||||
|
||||
// Remove empty topic directories
|
||||
if isEmpty, _ := fpl.isDirectoryEmpty(topicDir); isEmpty {
|
||||
os.Remove(topicDir)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetMetrics returns persistence layer metrics
|
||||
func (fpl *FilePersistenceLayer) GetMetrics() (PersistenceMetrics, error) {
|
||||
fpl.mu.RLock()
|
||||
defer fpl.mu.RUnlock()
|
||||
|
||||
metrics := PersistenceMetrics{}
|
||||
|
||||
// Calculate storage size and file count
|
||||
err := filepath.Walk(fpl.basePath, func(path string, info os.FileInfo, err error) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !info.IsDir() {
|
||||
metrics.FileCount++
|
||||
metrics.StorageSize += info.Size()
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
return metrics, err
|
||||
}
|
||||
|
||||
// Private helper methods
|
||||
|
||||
func (fpl *FilePersistenceLayer) getWritableFile(topicDir string, dataSize int) (string, error) {
|
||||
files, err := fpl.getTopicFiles(topicDir)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// Find a file with enough space
|
||||
for _, file := range files {
|
||||
info, err := os.Stat(file)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
if info.Size()+int64(dataSize) <= fpl.maxFileSize {
|
||||
return file, nil
|
||||
}
|
||||
}
|
||||
|
||||
// Create new file
|
||||
timestamp := time.Now().Format("20060102_150405")
|
||||
filename := filepath.Join(topicDir, fmt.Sprintf("messages_%s.dat", timestamp))
|
||||
return filename, nil
|
||||
}
|
||||
|
||||
func (fpl *FilePersistenceLayer) getTopicDirectories() ([]string, error) {
|
||||
entries, err := ioutil.ReadDir(fpl.basePath)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var dirs []string
|
||||
for _, entry := range entries {
|
||||
if entry.IsDir() {
|
||||
dirs = append(dirs, filepath.Join(fpl.basePath, entry.Name()))
|
||||
}
|
||||
}
|
||||
|
||||
return dirs, nil
|
||||
}
|
||||
|
||||
func (fpl *FilePersistenceLayer) getTopicFiles(topicDir string) ([]string, error) {
|
||||
entries, err := ioutil.ReadDir(topicDir)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var files []string
|
||||
for _, entry := range entries {
|
||||
if !entry.IsDir() && filepath.Ext(entry.Name()) == ".dat" {
|
||||
files = append(files, filepath.Join(topicDir, entry.Name()))
|
||||
}
|
||||
}
|
||||
|
||||
return files, nil
|
||||
}
|
||||
|
||||
func (fpl *FilePersistenceLayer) findMessageInFile(filename, messageID string) (*Message, error) {
|
||||
file, err := os.Open(filename)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
data, err := ioutil.ReadAll(file)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Parse messages from file data
|
||||
messages, err := fpl.parseFileData(data)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, msg := range messages {
|
||||
if msg.ID == messageID {
|
||||
return msg, nil
|
||||
}
|
||||
}
|
||||
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (fpl *FilePersistenceLayer) readMessagesFromFile(filename string) ([]*Message, error) {
|
||||
file, err := os.Open(filename)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
data, err := ioutil.ReadAll(file)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return fpl.parseFileData(data)
|
||||
}
|
||||
|
||||
func (fpl *FilePersistenceLayer) parseFileData(data []byte) ([]*Message, error) {
|
||||
var messages []*Message
|
||||
offset := 0
|
||||
|
||||
for offset < len(data) {
|
||||
// Read length prefix
|
||||
lengthEnd := -1
|
||||
for i := offset; i < len(data); i++ {
|
||||
if data[i] == '\n' {
|
||||
lengthEnd = i
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if lengthEnd == -1 {
|
||||
break // No more complete messages
|
||||
}
|
||||
|
||||
lengthStr := string(data[offset:lengthEnd])
|
||||
var messageLength int
|
||||
if _, err := fmt.Sscanf(lengthStr, "%d", &messageLength); err != nil {
|
||||
break // Invalid length prefix
|
||||
}
|
||||
|
||||
messageStart := lengthEnd + 1
|
||||
messageEnd := messageStart + messageLength
|
||||
|
||||
if messageEnd > len(data) {
|
||||
break // Incomplete message
|
||||
}
|
||||
|
||||
messageData := data[messageStart:messageEnd]
|
||||
|
||||
// Apply decompression if needed
|
||||
if fpl.compression {
|
||||
decompressed, err := fpl.decompress(messageData)
|
||||
if err != nil {
|
||||
offset = messageEnd
|
||||
continue
|
||||
}
|
||||
messageData = decompressed
|
||||
}
|
||||
|
||||
// Apply decryption if needed
|
||||
if fpl.encryption.Enabled {
|
||||
decrypted, err := fpl.decrypt(messageData)
|
||||
if err != nil {
|
||||
offset = messageEnd
|
||||
continue
|
||||
}
|
||||
messageData = decrypted
|
||||
}
|
||||
|
||||
// Parse message
|
||||
var persistedMsg PersistedMessage
|
||||
if err := json.Unmarshal(messageData, &persistedMsg); err != nil {
|
||||
offset = messageEnd
|
||||
continue
|
||||
}
|
||||
|
||||
messages = append(messages, persistedMsg.Message)
|
||||
offset = messageEnd
|
||||
}
|
||||
|
||||
return messages, nil
|
||||
}
|
||||
|
||||
func (fpl *FilePersistenceLayer) isDirectoryEmpty(dir string) (bool, error) {
|
||||
entries, err := ioutil.ReadDir(dir)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return len(entries) == 0, nil
|
||||
}
|
||||
|
||||
func (fpl *FilePersistenceLayer) encrypt(data []byte) ([]byte, error) {
|
||||
// Placeholder for encryption implementation
|
||||
// In a real implementation, you would use proper encryption libraries
|
||||
return data, nil
|
||||
}
|
||||
|
||||
func (fpl *FilePersistenceLayer) decrypt(data []byte) ([]byte, error) {
|
||||
// Placeholder for decryption implementation
|
||||
return data, nil
|
||||
}
|
||||
|
||||
func (fpl *FilePersistenceLayer) compress(data []byte) ([]byte, error) {
|
||||
// Placeholder for compression implementation
|
||||
// In a real implementation, you would use libraries like gzip
|
||||
return data, nil
|
||||
}
|
||||
|
||||
func (fpl *FilePersistenceLayer) decompress(data []byte) ([]byte, error) {
|
||||
// Placeholder for decompression implementation
|
||||
return data, nil
|
||||
}
|
||||
|
||||
// InMemoryPersistenceLayer implements in-memory persistence for testing/development
|
||||
type InMemoryPersistenceLayer struct {
|
||||
messages map[string]*Message
|
||||
topics map[string][]string
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
// NewInMemoryPersistenceLayer creates a new in-memory persistence layer
|
||||
func NewInMemoryPersistenceLayer() *InMemoryPersistenceLayer {
|
||||
return &InMemoryPersistenceLayer{
|
||||
messages: make(map[string]*Message),
|
||||
topics: make(map[string][]string),
|
||||
}
|
||||
}
|
||||
|
||||
// Store stores a message in memory
|
||||
func (impl *InMemoryPersistenceLayer) Store(msg *Message) error {
|
||||
impl.mu.Lock()
|
||||
defer impl.mu.Unlock()
|
||||
|
||||
impl.messages[msg.ID] = msg
|
||||
|
||||
if _, exists := impl.topics[msg.Topic]; !exists {
|
||||
impl.topics[msg.Topic] = make([]string, 0)
|
||||
}
|
||||
impl.topics[msg.Topic] = append(impl.topics[msg.Topic], msg.ID)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Retrieve retrieves a message from memory
|
||||
func (impl *InMemoryPersistenceLayer) Retrieve(id string) (*Message, error) {
|
||||
impl.mu.RLock()
|
||||
defer impl.mu.RUnlock()
|
||||
|
||||
msg, exists := impl.messages[id]
|
||||
if !exists {
|
||||
return nil, fmt.Errorf("message not found: %s", id)
|
||||
}
|
||||
|
||||
return msg, nil
|
||||
}
|
||||
|
||||
// Delete removes a message from memory
|
||||
func (impl *InMemoryPersistenceLayer) Delete(id string) error {
|
||||
impl.mu.Lock()
|
||||
defer impl.mu.Unlock()
|
||||
|
||||
msg, exists := impl.messages[id]
|
||||
if !exists {
|
||||
return fmt.Errorf("message not found: %s", id)
|
||||
}
|
||||
|
||||
delete(impl.messages, id)
|
||||
|
||||
// Remove from topic index
|
||||
if messageIDs, exists := impl.topics[msg.Topic]; exists {
|
||||
for i, msgID := range messageIDs {
|
||||
if msgID == id {
|
||||
impl.topics[msg.Topic] = append(messageIDs[:i], messageIDs[i+1:]...)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// List returns messages for a topic
|
||||
func (impl *InMemoryPersistenceLayer) List(topic string, limit int) ([]*Message, error) {
|
||||
impl.mu.RLock()
|
||||
defer impl.mu.RUnlock()
|
||||
|
||||
messageIDs, exists := impl.topics[topic]
|
||||
if !exists {
|
||||
return []*Message{}, nil
|
||||
}
|
||||
|
||||
var messages []*Message
|
||||
count := 0
|
||||
|
||||
for _, msgID := range messageIDs {
|
||||
if limit > 0 && count >= limit {
|
||||
break
|
||||
}
|
||||
|
||||
if msg, exists := impl.messages[msgID]; exists {
|
||||
messages = append(messages, msg)
|
||||
count++
|
||||
}
|
||||
}
|
||||
|
||||
return messages, nil
|
||||
}
|
||||
|
||||
// Cleanup removes messages older than maxAge
|
||||
func (impl *InMemoryPersistenceLayer) Cleanup(maxAge time.Duration) error {
|
||||
impl.mu.Lock()
|
||||
defer impl.mu.Unlock()
|
||||
|
||||
cutoff := time.Now().Add(-maxAge)
|
||||
var toDelete []string
|
||||
|
||||
for id, msg := range impl.messages {
|
||||
if msg.Timestamp.Before(cutoff) {
|
||||
toDelete = append(toDelete, id)
|
||||
}
|
||||
}
|
||||
|
||||
for _, id := range toDelete {
|
||||
impl.Delete(id)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
478
pkg/transport/router.go
Normal file
478
pkg/transport/router.go
Normal file
@@ -0,0 +1,478 @@
|
||||
package transport
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"sort"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// MessageRouter handles intelligent message routing and transport selection
|
||||
type MessageRouter struct {
|
||||
rules []RoutingRule
|
||||
fallback TransportType
|
||||
loadBalancer LoadBalancer
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
// RoutingRule defines message routing logic
|
||||
type RoutingRule struct {
|
||||
ID string
|
||||
Name string
|
||||
Condition MessageFilter
|
||||
Transport TransportType
|
||||
Priority int
|
||||
Enabled bool
|
||||
Created time.Time
|
||||
LastUsed time.Time
|
||||
UsageCount int64
|
||||
}
|
||||
|
||||
// RouteMessage selects the appropriate transport for a message
|
||||
func (mr *MessageRouter) RouteMessage(msg *Message, transports map[TransportType]Transport) (Transport, error) {
|
||||
mr.mu.RLock()
|
||||
defer mr.mu.RUnlock()
|
||||
|
||||
// Find matching rules (sorted by priority)
|
||||
matchingRules := mr.findMatchingRules(msg)
|
||||
|
||||
// Try each matching rule in priority order
|
||||
for _, rule := range matchingRules {
|
||||
if transport, exists := transports[rule.Transport]; exists {
|
||||
// Check transport health
|
||||
if health := transport.Health(); health.Status == "healthy" {
|
||||
mr.updateRuleUsage(rule.ID)
|
||||
return transport, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Use load balancer for available transports
|
||||
if mr.loadBalancer != nil {
|
||||
availableTransports := mr.getHealthyTransports(transports)
|
||||
if len(availableTransports) > 0 {
|
||||
selectedType := mr.loadBalancer.SelectTransport(availableTransports, msg)
|
||||
if transport, exists := transports[selectedType]; exists {
|
||||
return transport, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Fall back to default transport
|
||||
if fallbackTransport, exists := transports[mr.fallback]; exists {
|
||||
if health := fallbackTransport.Health(); health.Status != "unhealthy" {
|
||||
return fallbackTransport, nil
|
||||
}
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("no available transport for message")
|
||||
}
|
||||
|
||||
// AddRule adds a new routing rule
|
||||
func (mr *MessageRouter) AddRule(rule RoutingRule) {
|
||||
mr.mu.Lock()
|
||||
defer mr.mu.Unlock()
|
||||
|
||||
if rule.ID == "" {
|
||||
rule.ID = fmt.Sprintf("rule_%d", time.Now().UnixNano())
|
||||
}
|
||||
rule.Created = time.Now()
|
||||
rule.Enabled = true
|
||||
|
||||
mr.rules = append(mr.rules, rule)
|
||||
mr.sortRulesByPriority()
|
||||
}
|
||||
|
||||
// RemoveRule removes a routing rule by ID
|
||||
func (mr *MessageRouter) RemoveRule(ruleID string) bool {
|
||||
mr.mu.Lock()
|
||||
defer mr.mu.Unlock()
|
||||
|
||||
for i, rule := range mr.rules {
|
||||
if rule.ID == ruleID {
|
||||
mr.rules = append(mr.rules[:i], mr.rules[i+1:]...)
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// UpdateRule updates an existing routing rule
|
||||
func (mr *MessageRouter) UpdateRule(ruleID string, updates func(*RoutingRule)) bool {
|
||||
mr.mu.Lock()
|
||||
defer mr.mu.Unlock()
|
||||
|
||||
for i := range mr.rules {
|
||||
if mr.rules[i].ID == ruleID {
|
||||
updates(&mr.rules[i])
|
||||
mr.sortRulesByPriority()
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// GetRules returns all routing rules
|
||||
func (mr *MessageRouter) GetRules() []RoutingRule {
|
||||
mr.mu.RLock()
|
||||
defer mr.mu.RUnlock()
|
||||
|
||||
rules := make([]RoutingRule, len(mr.rules))
|
||||
copy(rules, mr.rules)
|
||||
return rules
|
||||
}
|
||||
|
||||
// EnableRule enables a routing rule
|
||||
func (mr *MessageRouter) EnableRule(ruleID string) bool {
|
||||
return mr.UpdateRule(ruleID, func(rule *RoutingRule) {
|
||||
rule.Enabled = true
|
||||
})
|
||||
}
|
||||
|
||||
// DisableRule disables a routing rule
|
||||
func (mr *MessageRouter) DisableRule(ruleID string) bool {
|
||||
return mr.UpdateRule(ruleID, func(rule *RoutingRule) {
|
||||
rule.Enabled = false
|
||||
})
|
||||
}
|
||||
|
||||
// SetFallbackTransport sets the fallback transport type
|
||||
func (mr *MessageRouter) SetFallbackTransport(transportType TransportType) {
|
||||
mr.mu.Lock()
|
||||
defer mr.mu.Unlock()
|
||||
mr.fallback = transportType
|
||||
}
|
||||
|
||||
// SetLoadBalancer sets the load balancer
|
||||
func (mr *MessageRouter) SetLoadBalancer(lb LoadBalancer) {
|
||||
mr.mu.Lock()
|
||||
defer mr.mu.Unlock()
|
||||
mr.loadBalancer = lb
|
||||
}
|
||||
|
||||
// Private helper methods
|
||||
|
||||
func (mr *MessageRouter) findMatchingRules(msg *Message) []RoutingRule {
|
||||
var matching []RoutingRule
|
||||
|
||||
for _, rule := range mr.rules {
|
||||
if rule.Enabled && (rule.Condition == nil || rule.Condition(msg)) {
|
||||
matching = append(matching, rule)
|
||||
}
|
||||
}
|
||||
|
||||
return matching
|
||||
}
|
||||
|
||||
func (mr *MessageRouter) sortRulesByPriority() {
|
||||
sort.Slice(mr.rules, func(i, j int) bool {
|
||||
return mr.rules[i].Priority > mr.rules[j].Priority
|
||||
})
|
||||
}
|
||||
|
||||
func (mr *MessageRouter) updateRuleUsage(ruleID string) {
|
||||
for i := range mr.rules {
|
||||
if mr.rules[i].ID == ruleID {
|
||||
mr.rules[i].LastUsed = time.Now()
|
||||
mr.rules[i].UsageCount++
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (mr *MessageRouter) getHealthyTransports(transports map[TransportType]Transport) []TransportType {
|
||||
var healthy []TransportType
|
||||
|
||||
for transportType, transport := range transports {
|
||||
if health := transport.Health(); health.Status == "healthy" {
|
||||
healthy = append(healthy, transportType)
|
||||
}
|
||||
}
|
||||
|
||||
return healthy
|
||||
}
|
||||
|
||||
// LoadBalancer implementations
|
||||
|
||||
// RoundRobinLoadBalancer hands out transports in rotation.
type RoundRobinLoadBalancer struct {
	// counter is the number of selections made so far.
	counter int64
	// mu guards counter.
	mu sync.Mutex
}
|
||||
|
||||
func NewRoundRobinLoadBalancer() *RoundRobinLoadBalancer {
|
||||
return &RoundRobinLoadBalancer{}
|
||||
}
|
||||
|
||||
func (lb *RoundRobinLoadBalancer) SelectTransport(transports []TransportType, msg *Message) TransportType {
|
||||
if len(transports) == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
lb.mu.Lock()
|
||||
defer lb.mu.Unlock()
|
||||
|
||||
selected := transports[lb.counter%int64(len(transports))]
|
||||
lb.counter++
|
||||
return selected
|
||||
}
|
||||
|
||||
func (lb *RoundRobinLoadBalancer) UpdateStats(transport TransportType, latency time.Duration, success bool) {
|
||||
// Round-robin doesn't use stats
|
||||
}
|
||||
|
||||
// WeightedLoadBalancer implements weighted load balancing based on performance
|
||||
type WeightedLoadBalancer struct {
|
||||
stats map[TransportType]*TransportStats
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
// TransportStats accumulates the request outcomes recorded for one
// transport.
type TransportStats struct {
	// TotalRequests counts every recorded request.
	TotalRequests int64
	// SuccessRequests counts the requests recorded as successful.
	SuccessRequests int64
	// TotalLatency is the sum of all recorded latencies.
	TotalLatency time.Duration
	// LastUpdate is the time of the most recent recording.
	LastUpdate time.Time
	// Weight is the selection weight computed after the last recording.
	Weight float64
}
|
||||
|
||||
func NewWeightedLoadBalancer() *WeightedLoadBalancer {
|
||||
return &WeightedLoadBalancer{
|
||||
stats: make(map[TransportType]*TransportStats),
|
||||
}
|
||||
}
|
||||
|
||||
func (lb *WeightedLoadBalancer) SelectTransport(transports []TransportType, msg *Message) TransportType {
|
||||
if len(transports) == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
lb.mu.RLock()
|
||||
defer lb.mu.RUnlock()
|
||||
|
||||
// Calculate weights and select based on weighted random selection
|
||||
totalWeight := 0.0
|
||||
weights := make(map[TransportType]float64)
|
||||
|
||||
for _, transport := range transports {
|
||||
weight := lb.calculateWeight(transport)
|
||||
weights[transport] = weight
|
||||
totalWeight += weight
|
||||
}
|
||||
|
||||
if totalWeight == 0 {
|
||||
// Fall back to random selection
|
||||
return transports[rand.Intn(len(transports))]
|
||||
}
|
||||
|
||||
// Weighted random selection
|
||||
target := rand.Float64() * totalWeight
|
||||
current := 0.0
|
||||
|
||||
for _, transport := range transports {
|
||||
current += weights[transport]
|
||||
if current >= target {
|
||||
return transport
|
||||
}
|
||||
}
|
||||
|
||||
// Fallback (shouldn't happen)
|
||||
return transports[0]
|
||||
}
|
||||
|
||||
func (lb *WeightedLoadBalancer) UpdateStats(transport TransportType, latency time.Duration, success bool) {
|
||||
lb.mu.Lock()
|
||||
defer lb.mu.Unlock()
|
||||
|
||||
stats, exists := lb.stats[transport]
|
||||
if !exists {
|
||||
stats = &TransportStats{
|
||||
Weight: 1.0, // Default weight
|
||||
}
|
||||
lb.stats[transport] = stats
|
||||
}
|
||||
|
||||
stats.TotalRequests++
|
||||
stats.TotalLatency += latency
|
||||
stats.LastUpdate = time.Now()
|
||||
|
||||
if success {
|
||||
stats.SuccessRequests++
|
||||
}
|
||||
|
||||
// Recalculate weight based on performance
|
||||
stats.Weight = lb.calculateWeight(transport)
|
||||
}
|
||||
|
||||
func (lb *WeightedLoadBalancer) calculateWeight(transport TransportType) float64 {
|
||||
stats, exists := lb.stats[transport]
|
||||
if !exists {
|
||||
return 1.0 // Default weight for unknown transports
|
||||
}
|
||||
|
||||
if stats.TotalRequests == 0 {
|
||||
return 1.0
|
||||
}
|
||||
|
||||
// Calculate success rate
|
||||
successRate := float64(stats.SuccessRequests) / float64(stats.TotalRequests)
|
||||
|
||||
// Calculate average latency
|
||||
avgLatency := stats.TotalLatency / time.Duration(stats.TotalRequests)
|
||||
|
||||
// Weight formula: success rate / (latency factor)
|
||||
// Lower latency and higher success rate = higher weight
|
||||
latencyFactor := float64(avgLatency) / float64(time.Millisecond)
|
||||
if latencyFactor < 1 {
|
||||
latencyFactor = 1
|
||||
}
|
||||
|
||||
weight := successRate / latencyFactor
|
||||
|
||||
// Ensure minimum weight
|
||||
if weight < 0.1 {
|
||||
weight = 0.1
|
||||
}
|
||||
|
||||
return weight
|
||||
}
|
||||
|
||||
// LeastLatencyLoadBalancer selects the transport with the lowest latency
|
||||
type LeastLatencyLoadBalancer struct {
|
||||
stats map[TransportType]*LatencyStats
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
// LatencyStats is a bounded window of latency samples for one transport.
type LatencyStats struct {
	RecentLatencies []time.Duration // samples, oldest first
	MaxSamples      int             // window size; older samples are dropped
	LastUpdate      time.Time       // when the last sample was added
}
|
||||
|
||||
func NewLeastLatencyLoadBalancer() *LeastLatencyLoadBalancer {
|
||||
return &LeastLatencyLoadBalancer{
|
||||
stats: make(map[TransportType]*LatencyStats),
|
||||
}
|
||||
}
|
||||
|
||||
func (lb *LeastLatencyLoadBalancer) SelectTransport(transports []TransportType, msg *Message) TransportType {
|
||||
if len(transports) == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
lb.mu.RLock()
|
||||
defer lb.mu.RUnlock()
|
||||
|
||||
bestTransport := transports[0]
|
||||
bestLatency := time.Hour // Large initial value
|
||||
|
||||
for _, transport := range transports {
|
||||
avgLatency := lb.getAverageLatency(transport)
|
||||
if avgLatency < bestLatency {
|
||||
bestLatency = avgLatency
|
||||
bestTransport = transport
|
||||
}
|
||||
}
|
||||
|
||||
return bestTransport
|
||||
}
|
||||
|
||||
func (lb *LeastLatencyLoadBalancer) UpdateStats(transport TransportType, latency time.Duration, success bool) {
|
||||
if !success {
|
||||
return // Only track successful requests
|
||||
}
|
||||
|
||||
lb.mu.Lock()
|
||||
defer lb.mu.Unlock()
|
||||
|
||||
stats, exists := lb.stats[transport]
|
||||
if !exists {
|
||||
stats = &LatencyStats{
|
||||
RecentLatencies: make([]time.Duration, 0),
|
||||
MaxSamples: 10, // Keep last 10 samples
|
||||
}
|
||||
lb.stats[transport] = stats
|
||||
}
|
||||
|
||||
// Add new latency sample
|
||||
stats.RecentLatencies = append(stats.RecentLatencies, latency)
|
||||
|
||||
// Keep only recent samples
|
||||
if len(stats.RecentLatencies) > stats.MaxSamples {
|
||||
stats.RecentLatencies = stats.RecentLatencies[1:]
|
||||
}
|
||||
|
||||
stats.LastUpdate = time.Now()
|
||||
}
|
||||
|
||||
func (lb *LeastLatencyLoadBalancer) getAverageLatency(transport TransportType) time.Duration {
|
||||
stats, exists := lb.stats[transport]
|
||||
if !exists || len(stats.RecentLatencies) == 0 {
|
||||
return time.Millisecond * 100 // Default estimate
|
||||
}
|
||||
|
||||
total := time.Duration(0)
|
||||
for _, latency := range stats.RecentLatencies {
|
||||
total += latency
|
||||
}
|
||||
|
||||
return total / time.Duration(len(stats.RecentLatencies))
|
||||
}
|
||||
|
||||
// Common routing rule factory functions
|
||||
|
||||
// CreateTopicRule creates a rule based on message topic
|
||||
func CreateTopicRule(name string, topic string, transport TransportType, priority int) RoutingRule {
|
||||
return RoutingRule{
|
||||
Name: name,
|
||||
Condition: func(msg *Message) bool { return msg.Topic == topic },
|
||||
Transport: transport,
|
||||
Priority: priority,
|
||||
}
|
||||
}
|
||||
|
||||
// CreateTopicPatternRule creates a rule based on topic pattern matching
|
||||
func CreateTopicPatternRule(name string, pattern string, transport TransportType, priority int) RoutingRule {
|
||||
return RoutingRule{
|
||||
Name: name,
|
||||
Condition: func(msg *Message) bool {
|
||||
// Simple pattern matching (can be enhanced with regex)
|
||||
return msg.Topic == pattern ||
|
||||
(len(pattern) > 0 && pattern[len(pattern)-1] == '*' &&
|
||||
len(msg.Topic) >= len(pattern)-1 &&
|
||||
msg.Topic[:len(pattern)-1] == pattern[:len(pattern)-1])
|
||||
},
|
||||
Transport: transport,
|
||||
Priority: priority,
|
||||
}
|
||||
}
|
||||
|
||||
// CreatePriorityRule creates a rule based on message priority
|
||||
func CreatePriorityRule(name string, msgPriority MessagePriority, transport TransportType, priority int) RoutingRule {
|
||||
return RoutingRule{
|
||||
Name: name,
|
||||
Condition: func(msg *Message) bool { return msg.Priority == msgPriority },
|
||||
Transport: transport,
|
||||
Priority: priority,
|
||||
}
|
||||
}
|
||||
|
||||
// CreateTypeRule creates a rule based on message type
|
||||
func CreateTypeRule(name string, msgType MessageType, transport TransportType, priority int) RoutingRule {
|
||||
return RoutingRule{
|
||||
Name: name,
|
||||
Condition: func(msg *Message) bool { return msg.Type == msgType },
|
||||
Transport: transport,
|
||||
Priority: priority,
|
||||
}
|
||||
}
|
||||
|
||||
// CreateSourceRule creates a rule based on message source
|
||||
func CreateSourceRule(name string, source string, transport TransportType, priority int) RoutingRule {
|
||||
return RoutingRule{
|
||||
Name: name,
|
||||
Condition: func(msg *Message) bool { return msg.Source == source },
|
||||
Transport: transport,
|
||||
Priority: priority,
|
||||
}
|
||||
}
|
||||
566
pkg/transport/serialization.go
Normal file
566
pkg/transport/serialization.go
Normal file
@@ -0,0 +1,566 @@
|
||||
package transport
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"compress/gzip"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// SerializationFormat names a wire encoding for messages.
type SerializationFormat string

// Known serialization formats. Of these, only JSON has a serializer
// registered by NewSerializationLayer.
const (
	SerializationJSON     SerializationFormat = "json"
	SerializationMsgPack  SerializationFormat = "msgpack"
	SerializationProtobuf SerializationFormat = "protobuf"
	SerializationAvro     SerializationFormat = "avro"
)
|
||||
|
||||
// CompressionType names a payload compression algorithm.
type CompressionType string

// Known compression algorithms. DefaultCompressor implements only
// CompressionNone and CompressionGZip.
const (
	CompressionNone   CompressionType = "none"
	CompressionGZip   CompressionType = "gzip"
	CompressionLZ4    CompressionType = "lz4"
	CompressionSnappy CompressionType = "snappy"
)
|
||||
|
||||
// SerializationConfig configures serialization behavior
|
||||
type SerializationConfig struct {
|
||||
Format SerializationFormat
|
||||
Compression CompressionType
|
||||
Encryption bool
|
||||
Validation bool
|
||||
}
|
||||
|
||||
// SerializedMessage represents a serialized message with metadata
|
||||
type SerializedMessage struct {
|
||||
Format SerializationFormat `json:"format"`
|
||||
Compression CompressionType `json:"compression"`
|
||||
Encrypted bool `json:"encrypted"`
|
||||
Checksum string `json:"checksum"`
|
||||
Data []byte `json:"data"`
|
||||
Size int `json:"size"`
|
||||
Timestamp int64 `json:"timestamp"`
|
||||
}
|
||||
|
||||
// Serializer interface defines serialization operations
|
||||
type Serializer interface {
|
||||
Serialize(msg *Message) (*SerializedMessage, error)
|
||||
Deserialize(serialized *SerializedMessage) (*Message, error)
|
||||
GetFormat() SerializationFormat
|
||||
GetConfig() SerializationConfig
|
||||
}
|
||||
|
||||
// SerializationLayer manages multiple serializers and format selection
|
||||
type SerializationLayer struct {
|
||||
serializers map[SerializationFormat]Serializer
|
||||
defaultFormat SerializationFormat
|
||||
compressor Compressor
|
||||
encryptor Encryptor
|
||||
validator Validator
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
// Compressor interface for data compression
|
||||
type Compressor interface {
|
||||
Compress(data []byte, algorithm CompressionType) ([]byte, error)
|
||||
Decompress(data []byte, algorithm CompressionType) ([]byte, error)
|
||||
GetSupportedAlgorithms() []CompressionType
|
||||
}
|
||||
|
||||
// Encryptor encrypts and decrypts raw payload bytes.
type Encryptor interface {
	// Encrypt returns the encrypted form of data.
	Encrypt(data []byte) ([]byte, error)
	// Decrypt reverses Encrypt.
	Decrypt(data []byte) ([]byte, error)
	// IsEnabled reports whether the layer should use this encryptor.
	IsEnabled() bool
}
|
||||
|
||||
// Validator interface for data validation
|
||||
type Validator interface {
|
||||
Validate(msg *Message) error
|
||||
GenerateChecksum(data []byte) string
|
||||
VerifyChecksum(data []byte, checksum string) bool
|
||||
}
|
||||
|
||||
// NewSerializationLayer creates a new serialization layer
|
||||
func NewSerializationLayer() *SerializationLayer {
|
||||
sl := &SerializationLayer{
|
||||
serializers: make(map[SerializationFormat]Serializer),
|
||||
defaultFormat: SerializationJSON,
|
||||
compressor: NewDefaultCompressor(),
|
||||
validator: NewDefaultValidator(),
|
||||
}
|
||||
|
||||
// Register default serializers
|
||||
sl.RegisterSerializer(NewJSONSerializer())
|
||||
|
||||
return sl
|
||||
}
|
||||
|
||||
// RegisterSerializer registers a new serializer
|
||||
func (sl *SerializationLayer) RegisterSerializer(serializer Serializer) {
|
||||
sl.mu.Lock()
|
||||
defer sl.mu.Unlock()
|
||||
sl.serializers[serializer.GetFormat()] = serializer
|
||||
}
|
||||
|
||||
// SetDefaultFormat sets the default serialization format
|
||||
func (sl *SerializationLayer) SetDefaultFormat(format SerializationFormat) {
|
||||
sl.mu.Lock()
|
||||
defer sl.mu.Unlock()
|
||||
sl.defaultFormat = format
|
||||
}
|
||||
|
||||
// SetCompressor sets the compression handler
|
||||
func (sl *SerializationLayer) SetCompressor(compressor Compressor) {
|
||||
sl.mu.Lock()
|
||||
defer sl.mu.Unlock()
|
||||
sl.compressor = compressor
|
||||
}
|
||||
|
||||
// SetEncryptor sets the encryption handler
|
||||
func (sl *SerializationLayer) SetEncryptor(encryptor Encryptor) {
|
||||
sl.mu.Lock()
|
||||
defer sl.mu.Unlock()
|
||||
sl.encryptor = encryptor
|
||||
}
|
||||
|
||||
// SetValidator sets the validation handler
|
||||
func (sl *SerializationLayer) SetValidator(validator Validator) {
|
||||
sl.mu.Lock()
|
||||
defer sl.mu.Unlock()
|
||||
sl.validator = validator
|
||||
}
|
||||
|
||||
// Serialize serializes a message using the specified or default format
|
||||
func (sl *SerializationLayer) Serialize(msg *Message, format ...SerializationFormat) (*SerializedMessage, error) {
|
||||
sl.mu.RLock()
|
||||
defer sl.mu.RUnlock()
|
||||
|
||||
// Determine format to use
|
||||
selectedFormat := sl.defaultFormat
|
||||
if len(format) > 0 {
|
||||
selectedFormat = format[0]
|
||||
}
|
||||
|
||||
// Get serializer
|
||||
serializer, exists := sl.serializers[selectedFormat]
|
||||
if !exists {
|
||||
return nil, fmt.Errorf("unsupported serialization format: %s", selectedFormat)
|
||||
}
|
||||
|
||||
// Validate message if validator is configured
|
||||
if sl.validator != nil {
|
||||
if err := sl.validator.Validate(msg); err != nil {
|
||||
return nil, fmt.Errorf("message validation failed: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Serialize message
|
||||
serialized, err := serializer.Serialize(msg)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("serialization failed: %w", err)
|
||||
}
|
||||
|
||||
// Apply compression if configured
|
||||
config := serializer.GetConfig()
|
||||
if config.Compression != CompressionNone && sl.compressor != nil {
|
||||
compressed, err := sl.compressor.Compress(serialized.Data, config.Compression)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("compression failed: %w", err)
|
||||
}
|
||||
serialized.Data = compressed
|
||||
serialized.Compression = config.Compression
|
||||
}
|
||||
|
||||
// Apply encryption if configured
|
||||
if config.Encryption && sl.encryptor != nil && sl.encryptor.IsEnabled() {
|
||||
encrypted, err := sl.encryptor.Encrypt(serialized.Data)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("encryption failed: %w", err)
|
||||
}
|
||||
serialized.Data = encrypted
|
||||
serialized.Encrypted = true
|
||||
}
|
||||
|
||||
// Generate checksum
|
||||
if sl.validator != nil {
|
||||
serialized.Checksum = sl.validator.GenerateChecksum(serialized.Data)
|
||||
}
|
||||
|
||||
// Update metadata
|
||||
serialized.Size = len(serialized.Data)
|
||||
serialized.Timestamp = msg.Timestamp.UnixNano()
|
||||
|
||||
return serialized, nil
|
||||
}
|
||||
|
||||
// Deserialize deserializes a message
|
||||
func (sl *SerializationLayer) Deserialize(serialized *SerializedMessage) (*Message, error) {
|
||||
sl.mu.RLock()
|
||||
defer sl.mu.RUnlock()
|
||||
|
||||
// Verify checksum if available
|
||||
if sl.validator != nil && serialized.Checksum != "" {
|
||||
if !sl.validator.VerifyChecksum(serialized.Data, serialized.Checksum) {
|
||||
return nil, fmt.Errorf("checksum verification failed")
|
||||
}
|
||||
}
|
||||
|
||||
data := serialized.Data
|
||||
|
||||
// Apply decryption if needed
|
||||
if serialized.Encrypted && sl.encryptor != nil && sl.encryptor.IsEnabled() {
|
||||
decrypted, err := sl.encryptor.Decrypt(data)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("decryption failed: %w", err)
|
||||
}
|
||||
data = decrypted
|
||||
}
|
||||
|
||||
// Apply decompression if needed
|
||||
if serialized.Compression != CompressionNone && sl.compressor != nil {
|
||||
decompressed, err := sl.compressor.Decompress(data, serialized.Compression)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("decompression failed: %w", err)
|
||||
}
|
||||
data = decompressed
|
||||
}
|
||||
|
||||
// Get serializer
|
||||
serializer, exists := sl.serializers[serialized.Format]
|
||||
if !exists {
|
||||
return nil, fmt.Errorf("unsupported serialization format: %s", serialized.Format)
|
||||
}
|
||||
|
||||
// Create temporary serialized message for deserializer
|
||||
tempSerialized := &SerializedMessage{
|
||||
Format: serialized.Format,
|
||||
Data: data,
|
||||
}
|
||||
|
||||
// Deserialize message
|
||||
msg, err := serializer.Deserialize(tempSerialized)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("deserialization failed: %w", err)
|
||||
}
|
||||
|
||||
// Validate deserialized message if validator is configured
|
||||
if sl.validator != nil {
|
||||
if err := sl.validator.Validate(msg); err != nil {
|
||||
return nil, fmt.Errorf("deserialized message validation failed: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return msg, nil
|
||||
}
|
||||
|
||||
// GetSupportedFormats returns all supported serialization formats
|
||||
func (sl *SerializationLayer) GetSupportedFormats() []SerializationFormat {
|
||||
sl.mu.RLock()
|
||||
defer sl.mu.RUnlock()
|
||||
|
||||
formats := make([]SerializationFormat, 0, len(sl.serializers))
|
||||
for format := range sl.serializers {
|
||||
formats = append(formats, format)
|
||||
}
|
||||
return formats
|
||||
}
|
||||
|
||||
// JSONSerializer implements JSON serialization
|
||||
type JSONSerializer struct {
|
||||
config SerializationConfig
|
||||
}
|
||||
|
||||
// NewJSONSerializer creates a new JSON serializer
|
||||
func NewJSONSerializer() *JSONSerializer {
|
||||
return &JSONSerializer{
|
||||
config: SerializationConfig{
|
||||
Format: SerializationJSON,
|
||||
Compression: CompressionNone,
|
||||
Encryption: false,
|
||||
Validation: true,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// SetConfig updates the serializer configuration
|
||||
func (js *JSONSerializer) SetConfig(config SerializationConfig) {
|
||||
js.config = config
|
||||
js.config.Format = SerializationJSON // Ensure format is correct
|
||||
}
|
||||
|
||||
// Serialize serializes a message to JSON
|
||||
func (js *JSONSerializer) Serialize(msg *Message) (*SerializedMessage, error) {
|
||||
data, err := json.Marshal(msg)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("JSON marshal failed: %w", err)
|
||||
}
|
||||
|
||||
return &SerializedMessage{
|
||||
Format: SerializationJSON,
|
||||
Compression: CompressionNone,
|
||||
Encrypted: false,
|
||||
Data: data,
|
||||
Size: len(data),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Deserialize deserializes a message from JSON
|
||||
func (js *JSONSerializer) Deserialize(serialized *SerializedMessage) (*Message, error) {
|
||||
var msg Message
|
||||
if err := json.Unmarshal(serialized.Data, &msg); err != nil {
|
||||
return nil, fmt.Errorf("JSON unmarshal failed: %w", err)
|
||||
}
|
||||
return &msg, nil
|
||||
}
|
||||
|
||||
// GetFormat returns the serialization format
|
||||
func (js *JSONSerializer) GetFormat() SerializationFormat {
|
||||
return SerializationJSON
|
||||
}
|
||||
|
||||
// GetConfig returns the serializer configuration
|
||||
func (js *JSONSerializer) GetConfig() SerializationConfig {
|
||||
return js.config
|
||||
}
|
||||
|
||||
// DefaultCompressor implements basic compression operations
|
||||
type DefaultCompressor struct {
|
||||
supportedAlgorithms []CompressionType
|
||||
}
|
||||
|
||||
// NewDefaultCompressor creates a new default compressor
|
||||
func NewDefaultCompressor() *DefaultCompressor {
|
||||
return &DefaultCompressor{
|
||||
supportedAlgorithms: []CompressionType{
|
||||
CompressionNone,
|
||||
CompressionGZip,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// Compress compresses data using the specified algorithm
|
||||
func (dc *DefaultCompressor) Compress(data []byte, algorithm CompressionType) ([]byte, error) {
|
||||
switch algorithm {
|
||||
case CompressionNone:
|
||||
return data, nil
|
||||
|
||||
case CompressionGZip:
|
||||
var buf bytes.Buffer
|
||||
writer := gzip.NewWriter(&buf)
|
||||
|
||||
if _, err := writer.Write(data); err != nil {
|
||||
return nil, fmt.Errorf("gzip write failed: %w", err)
|
||||
}
|
||||
|
||||
if err := writer.Close(); err != nil {
|
||||
return nil, fmt.Errorf("gzip close failed: %w", err)
|
||||
}
|
||||
|
||||
return buf.Bytes(), nil
|
||||
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported compression algorithm: %s", algorithm)
|
||||
}
|
||||
}
|
||||
|
||||
// Decompress decompresses data using the specified algorithm
|
||||
func (dc *DefaultCompressor) Decompress(data []byte, algorithm CompressionType) ([]byte, error) {
|
||||
switch algorithm {
|
||||
case CompressionNone:
|
||||
return data, nil
|
||||
|
||||
case CompressionGZip:
|
||||
reader, err := gzip.NewReader(bytes.NewReader(data))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("gzip reader creation failed: %w", err)
|
||||
}
|
||||
defer reader.Close()
|
||||
|
||||
decompressed, err := io.ReadAll(reader)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("gzip read failed: %w", err)
|
||||
}
|
||||
|
||||
return decompressed, nil
|
||||
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported compression algorithm: %s", algorithm)
|
||||
}
|
||||
}
|
||||
|
||||
// GetSupportedAlgorithms returns supported compression algorithms
|
||||
func (dc *DefaultCompressor) GetSupportedAlgorithms() []CompressionType {
|
||||
return dc.supportedAlgorithms
|
||||
}
|
||||
|
||||
// DefaultValidator is the built-in Validator: required-field checks plus a
// simple additive checksum.
type DefaultValidator struct {
	strictMode bool // when true, Validate also requires Data and Timestamp
}
|
||||
|
||||
// NewDefaultValidator creates a new default validator
|
||||
func NewDefaultValidator() *DefaultValidator {
|
||||
return &DefaultValidator{
|
||||
strictMode: false,
|
||||
}
|
||||
}
|
||||
|
||||
// SetStrictMode enables/disables strict validation
|
||||
func (dv *DefaultValidator) SetStrictMode(enabled bool) {
|
||||
dv.strictMode = enabled
|
||||
}
|
||||
|
||||
// Validate validates a message
|
||||
func (dv *DefaultValidator) Validate(msg *Message) error {
|
||||
if msg == nil {
|
||||
return fmt.Errorf("message is nil")
|
||||
}
|
||||
|
||||
if msg.ID == "" {
|
||||
return fmt.Errorf("message ID is empty")
|
||||
}
|
||||
|
||||
if msg.Topic == "" {
|
||||
return fmt.Errorf("message topic is empty")
|
||||
}
|
||||
|
||||
if msg.Source == "" {
|
||||
return fmt.Errorf("message source is empty")
|
||||
}
|
||||
|
||||
if msg.Type == "" {
|
||||
return fmt.Errorf("message type is empty")
|
||||
}
|
||||
|
||||
if dv.strictMode {
|
||||
if msg.Data == nil {
|
||||
return fmt.Errorf("message data is nil")
|
||||
}
|
||||
|
||||
if msg.Timestamp.IsZero() {
|
||||
return fmt.Errorf("message timestamp is zero")
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GenerateChecksum generates a simple checksum for data
|
||||
func (dv *DefaultValidator) GenerateChecksum(data []byte) string {
|
||||
// Simple checksum implementation
|
||||
// In production, use a proper hash function like SHA-256
|
||||
var sum uint32
|
||||
for _, b := range data {
|
||||
sum += uint32(b)
|
||||
}
|
||||
return fmt.Sprintf("%08x", sum)
|
||||
}
|
||||
|
||||
// VerifyChecksum verifies a checksum
|
||||
func (dv *DefaultValidator) VerifyChecksum(data []byte, checksum string) bool {
|
||||
return dv.GenerateChecksum(data) == checksum
|
||||
}
|
||||
|
||||
// NoOpEncryptor is a pass-through Encryptor intended for tests: it never
// alters the data it is given.
type NoOpEncryptor struct {
	enabled bool // value reported by IsEnabled
}

// NewNoOpEncryptor returns a pass-through encryptor that reports itself as
// disabled.
func NewNoOpEncryptor() *NoOpEncryptor {
	return new(NoOpEncryptor)
}

// SetEnabled sets the value IsEnabled reports.
func (e *NoOpEncryptor) SetEnabled(enabled bool) {
	e.enabled = enabled
}

// Encrypt returns data as-is and never fails.
func (e *NoOpEncryptor) Encrypt(data []byte) ([]byte, error) {
	return data, nil
}

// Decrypt returns data as-is and never fails.
func (e *NoOpEncryptor) Decrypt(data []byte) ([]byte, error) {
	return data, nil
}

// IsEnabled reports the flag last set through SetEnabled (false initially).
func (e *NoOpEncryptor) IsEnabled() bool {
	return e.enabled
}
|
||||
|
||||
// SerializationMetrics is a snapshot of serialization activity counters.
type SerializationMetrics struct {
	SerializedMessages   int64   // calls to RecordSerialization
	DeserializedMessages int64   // calls to RecordDeserialization
	SerializationErrors  int64   // calls to RecordError
	CompressionRatio     float64 // smoothed serializedSize/originalSize
	AverageMessageSize   int64   // TotalDataProcessed / SerializedMessages
	TotalDataProcessed   int64   // sum of original sizes, in bytes
}

// MetricsCollector accumulates SerializationMetrics and is safe for
// concurrent use.
type MetricsCollector struct {
	metrics SerializationMetrics
	mu      sync.RWMutex // guards metrics
}

// NewMetricsCollector returns a collector with all counters at zero.
func NewMetricsCollector() *MetricsCollector {
	return &MetricsCollector{}
}

// RecordSerialization records one serialized message of originalSize bytes
// that produced serializedSize bytes of output.
//
// CompressionRatio is smoothed by averaging the previous value with the new
// sample. The first sample (or the first after the ratio is zero) is stored
// as-is; averaging it with the zero initial value would halve it.
// Samples with originalSize <= 0 do not affect the ratio.
func (mc *MetricsCollector) RecordSerialization(originalSize, serializedSize int) {
	mc.mu.Lock()
	defer mc.mu.Unlock()

	mc.metrics.SerializedMessages++
	mc.metrics.TotalDataProcessed += int64(originalSize)

	if originalSize > 0 {
		ratio := float64(serializedSize) / float64(originalSize)
		if mc.metrics.CompressionRatio == 0 {
			// No prior sample to blend with: seed the running value.
			mc.metrics.CompressionRatio = ratio
		} else {
			mc.metrics.CompressionRatio = (mc.metrics.CompressionRatio + ratio) / 2
		}
	}

	// SerializedMessages is at least 1 here, so the division is safe.
	mc.metrics.AverageMessageSize = mc.metrics.TotalDataProcessed / mc.metrics.SerializedMessages
}

// RecordDeserialization counts one deserialized message.
func (mc *MetricsCollector) RecordDeserialization() {
	mc.mu.Lock()
	defer mc.mu.Unlock()
	mc.metrics.DeserializedMessages++
}

// RecordError counts one serialization error.
func (mc *MetricsCollector) RecordError() {
	mc.mu.Lock()
	defer mc.mu.Unlock()
	mc.metrics.SerializationErrors++
}

// GetMetrics returns a copy of the current counters.
func (mc *MetricsCollector) GetMetrics() SerializationMetrics {
	mc.mu.RLock()
	defer mc.mu.RUnlock()
	return mc.metrics
}

// Reset sets every counter back to zero.
func (mc *MetricsCollector) Reset() {
	mc.mu.Lock()
	defer mc.mu.Unlock()
	mc.metrics = SerializationMetrics{}
}
|
||||
490
pkg/transport/tcp_transport.go
Normal file
490
pkg/transport/tcp_transport.go
Normal file
@@ -0,0 +1,490 @@
|
||||
package transport
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// TCPTransport implements TCP transport for remote communication
|
||||
type TCPTransport struct {
|
||||
address string
|
||||
port int
|
||||
listener net.Listener
|
||||
connections map[string]net.Conn
|
||||
metrics TransportMetrics
|
||||
connected bool
|
||||
isServer bool
|
||||
receiveChan chan *Message
|
||||
tlsConfig *tls.Config
|
||||
mu sync.RWMutex
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
retryConfig RetryConfig
|
||||
}
|
||||
|
||||
// NewTCPTransport creates a new TCP transport
|
||||
func NewTCPTransport(address string, port int, isServer bool) *TCPTransport {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
return &TCPTransport{
|
||||
address: address,
|
||||
port: port,
|
||||
connections: make(map[string]net.Conn),
|
||||
metrics: TransportMetrics{},
|
||||
isServer: isServer,
|
||||
receiveChan: make(chan *Message, 1000),
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
retryConfig: RetryConfig{
|
||||
MaxRetries: 3,
|
||||
InitialDelay: time.Second,
|
||||
MaxDelay: 30 * time.Second,
|
||||
BackoffFactor: 2.0,
|
||||
Jitter: true,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// SetTLSConfig configures TLS for secure communication
|
||||
func (tt *TCPTransport) SetTLSConfig(config *tls.Config) {
|
||||
tt.tlsConfig = config
|
||||
}
|
||||
|
||||
// SetRetryConfig configures retry behavior
|
||||
func (tt *TCPTransport) SetRetryConfig(config RetryConfig) {
|
||||
tt.retryConfig = config
|
||||
}
|
||||
|
||||
// Connect establishes the TCP connection
|
||||
func (tt *TCPTransport) Connect(ctx context.Context) error {
|
||||
tt.mu.Lock()
|
||||
defer tt.mu.Unlock()
|
||||
|
||||
if tt.connected {
|
||||
return nil
|
||||
}
|
||||
|
||||
if tt.isServer {
|
||||
return tt.startServer()
|
||||
} else {
|
||||
return tt.connectToServer(ctx)
|
||||
}
|
||||
}
|
||||
|
||||
// Disconnect closes the TCP connection
|
||||
func (tt *TCPTransport) Disconnect(ctx context.Context) error {
|
||||
tt.mu.Lock()
|
||||
defer tt.mu.Unlock()
|
||||
|
||||
if !tt.connected {
|
||||
return nil
|
||||
}
|
||||
|
||||
tt.cancel()
|
||||
|
||||
if tt.isServer && tt.listener != nil {
|
||||
tt.listener.Close()
|
||||
}
|
||||
|
||||
// Close all connections
|
||||
for id, conn := range tt.connections {
|
||||
conn.Close()
|
||||
delete(tt.connections, id)
|
||||
}
|
||||
|
||||
close(tt.receiveChan)
|
||||
tt.connected = false
|
||||
tt.metrics.Connections = 0
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Send transmits a message through TCP
|
||||
func (tt *TCPTransport) Send(ctx context.Context, msg *Message) error {
|
||||
start := time.Now()
|
||||
|
||||
tt.mu.RLock()
|
||||
if !tt.connected {
|
||||
tt.mu.RUnlock()
|
||||
tt.metrics.Errors++
|
||||
return fmt.Errorf("transport not connected")
|
||||
}
|
||||
|
||||
// Serialize message
|
||||
data, err := json.Marshal(msg)
|
||||
if err != nil {
|
||||
tt.mu.RUnlock()
|
||||
tt.metrics.Errors++
|
||||
return fmt.Errorf("failed to marshal message: %w", err)
|
||||
}
|
||||
|
||||
// Add length prefix for framing
|
||||
frame := fmt.Sprintf("%d\n%s", len(data), data)
|
||||
frameBytes := []byte(frame)
|
||||
|
||||
// Send to all connections with retry
|
||||
var sendErr error
|
||||
connectionCount := len(tt.connections)
|
||||
tt.mu.RUnlock()
|
||||
|
||||
if connectionCount == 0 {
|
||||
tt.metrics.Errors++
|
||||
return fmt.Errorf("no active connections")
|
||||
}
|
||||
|
||||
tt.mu.RLock()
|
||||
for connID, conn := range tt.connections {
|
||||
if err := tt.sendWithRetry(ctx, conn, frameBytes); err != nil {
|
||||
sendErr = fmt.Errorf("failed to send to connection %s: %w", connID, err)
|
||||
// Remove failed connection
|
||||
go tt.removeConnection(connID)
|
||||
}
|
||||
}
|
||||
tt.mu.RUnlock()
|
||||
|
||||
if sendErr == nil {
|
||||
tt.updateSendMetrics(msg, time.Since(start))
|
||||
} else {
|
||||
tt.metrics.Errors++
|
||||
}
|
||||
|
||||
return sendErr
|
||||
}
|
||||
|
||||
// Receive returns a channel for receiving messages
|
||||
func (tt *TCPTransport) Receive(ctx context.Context) (<-chan *Message, error) {
|
||||
tt.mu.RLock()
|
||||
defer tt.mu.RUnlock()
|
||||
|
||||
if !tt.connected {
|
||||
return nil, fmt.Errorf("transport not connected")
|
||||
}
|
||||
|
||||
return tt.receiveChan, nil
|
||||
}
|
||||
|
||||
// Health returns the health status of the transport
|
||||
func (tt *TCPTransport) Health() ComponentHealth {
|
||||
tt.mu.RLock()
|
||||
defer tt.mu.RUnlock()
|
||||
|
||||
status := "unhealthy"
|
||||
var responseTime time.Duration
|
||||
|
||||
if tt.connected {
|
||||
if len(tt.connections) > 0 {
|
||||
status = "healthy"
|
||||
responseTime = time.Millisecond * 10 // Estimate for TCP
|
||||
} else {
|
||||
status = "degraded" // Connected but no active connections
|
||||
}
|
||||
}
|
||||
|
||||
return ComponentHealth{
|
||||
Status: status,
|
||||
LastCheck: time.Now(),
|
||||
ResponseTime: responseTime,
|
||||
ErrorCount: tt.metrics.Errors,
|
||||
}
|
||||
}
|
||||
|
||||
// GetMetrics returns transport-specific metrics
|
||||
func (tt *TCPTransport) GetMetrics() TransportMetrics {
|
||||
tt.mu.RLock()
|
||||
defer tt.mu.RUnlock()
|
||||
|
||||
return TransportMetrics{
|
||||
BytesSent: tt.metrics.BytesSent,
|
||||
BytesReceived: tt.metrics.BytesReceived,
|
||||
MessagesSent: tt.metrics.MessagesSent,
|
||||
MessagesReceived: tt.metrics.MessagesReceived,
|
||||
Connections: len(tt.connections),
|
||||
Errors: tt.metrics.Errors,
|
||||
Latency: tt.metrics.Latency,
|
||||
}
|
||||
}
|
||||
|
||||
// Private helper methods
|
||||
|
||||
func (tt *TCPTransport) startServer() error {
|
||||
addr := fmt.Sprintf("%s:%d", tt.address, tt.port)
|
||||
|
||||
var listener net.Listener
|
||||
var err error
|
||||
|
||||
if tt.tlsConfig != nil {
|
||||
listener, err = tls.Listen("tcp", addr, tt.tlsConfig)
|
||||
} else {
|
||||
listener, err = net.Listen("tcp", addr)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to listen on %s: %w", addr, err)
|
||||
}
|
||||
|
||||
tt.listener = listener
|
||||
tt.connected = true
|
||||
|
||||
// Start accepting connections
|
||||
go tt.acceptConnections()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (tt *TCPTransport) connectToServer(ctx context.Context) error {
|
||||
addr := fmt.Sprintf("%s:%d", tt.address, tt.port)
|
||||
|
||||
var conn net.Conn
|
||||
var err error
|
||||
|
||||
// Retry connection with exponential backoff
|
||||
delay := tt.retryConfig.InitialDelay
|
||||
for attempt := 0; attempt <= tt.retryConfig.MaxRetries; attempt++ {
|
||||
if tt.tlsConfig != nil {
|
||||
conn, err = tls.Dial("tcp", addr, tt.tlsConfig)
|
||||
} else {
|
||||
conn, err = net.Dial("tcp", addr)
|
||||
}
|
||||
|
||||
if err == nil {
|
||||
break
|
||||
}
|
||||
|
||||
if attempt == tt.retryConfig.MaxRetries {
|
||||
return fmt.Errorf("failed to connect to %s after %d attempts: %w", addr, attempt+1, err)
|
||||
}
|
||||
|
||||
// Wait with exponential backoff
|
||||
select {
|
||||
case <-time.After(delay):
|
||||
delay = time.Duration(float64(delay) * tt.retryConfig.BackoffFactor)
|
||||
if delay > tt.retryConfig.MaxDelay {
|
||||
delay = tt.retryConfig.MaxDelay
|
||||
}
|
||||
// Add jitter if enabled
|
||||
if tt.retryConfig.Jitter {
|
||||
jitter := time.Duration(float64(delay) * 0.1)
|
||||
delay += time.Duration(float64(jitter) * (2*time.Now().UnixNano()%1000/1000.0 - 1))
|
||||
}
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
}
|
||||
}
|
||||
|
||||
connID := fmt.Sprintf("client_%d", time.Now().UnixNano())
|
||||
tt.connections[connID] = conn
|
||||
tt.connected = true
|
||||
tt.metrics.Connections = 1
|
||||
|
||||
// Start receiving from server
|
||||
go tt.handleConnection(connID, conn)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (tt *TCPTransport) acceptConnections() {
|
||||
for {
|
||||
select {
|
||||
case <-tt.ctx.Done():
|
||||
return
|
||||
default:
|
||||
conn, err := tt.listener.Accept()
|
||||
if err != nil {
|
||||
if tt.ctx.Err() != nil {
|
||||
return // Context cancelled
|
||||
}
|
||||
tt.metrics.Errors++
|
||||
continue
|
||||
}
|
||||
|
||||
connID := fmt.Sprintf("server_%d", time.Now().UnixNano())
|
||||
|
||||
tt.mu.Lock()
|
||||
tt.connections[connID] = conn
|
||||
tt.metrics.Connections = len(tt.connections)
|
||||
tt.mu.Unlock()
|
||||
|
||||
go tt.handleConnection(connID, conn)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (tt *TCPTransport) handleConnection(connID string, conn net.Conn) {
|
||||
defer tt.removeConnection(connID)
|
||||
|
||||
buffer := make([]byte, 4096)
|
||||
var messageBuffer []byte
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-tt.ctx.Done():
|
||||
return
|
||||
default:
|
||||
conn.SetReadDeadline(time.Now().Add(30 * time.Second))
|
||||
n, err := conn.Read(buffer)
|
||||
if err != nil {
|
||||
if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
|
||||
continue // Continue on timeout
|
||||
}
|
||||
return // Connection closed or error
|
||||
}
|
||||
|
||||
messageBuffer = append(messageBuffer, buffer[:n]...)
|
||||
|
||||
// Process complete messages
|
||||
for {
|
||||
msg, remaining, err := tt.extractMessage(messageBuffer)
|
||||
if err != nil {
|
||||
return // Invalid message format
|
||||
}
|
||||
if msg == nil {
|
||||
break // No complete message yet
|
||||
}
|
||||
|
||||
// Deliver message
|
||||
select {
|
||||
case tt.receiveChan <- msg:
|
||||
tt.updateReceiveMetrics(msg)
|
||||
case <-tt.ctx.Done():
|
||||
return
|
||||
default:
|
||||
// Channel full, drop message
|
||||
tt.metrics.Errors++
|
||||
}
|
||||
|
||||
messageBuffer = remaining
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (tt *TCPTransport) sendWithRetry(ctx context.Context, conn net.Conn, data []byte) error {
|
||||
delay := tt.retryConfig.InitialDelay
|
||||
|
||||
for attempt := 0; attempt <= tt.retryConfig.MaxRetries; attempt++ {
|
||||
conn.SetWriteDeadline(time.Now().Add(10 * time.Second))
|
||||
_, err := conn.Write(data)
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if attempt == tt.retryConfig.MaxRetries {
|
||||
return err
|
||||
}
|
||||
|
||||
// Wait with exponential backoff
|
||||
select {
|
||||
case <-time.After(delay):
|
||||
delay = time.Duration(float64(delay) * tt.retryConfig.BackoffFactor)
|
||||
if delay > tt.retryConfig.MaxDelay {
|
||||
delay = tt.retryConfig.MaxDelay
|
||||
}
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
}
|
||||
}
|
||||
|
||||
return fmt.Errorf("max retries exceeded")
|
||||
}
|
||||
|
||||
func (tt *TCPTransport) removeConnection(connID string) {
|
||||
tt.mu.Lock()
|
||||
defer tt.mu.Unlock()
|
||||
|
||||
if conn, exists := tt.connections[connID]; exists {
|
||||
conn.Close()
|
||||
delete(tt.connections, connID)
|
||||
tt.metrics.Connections = len(tt.connections)
|
||||
}
|
||||
}
|
||||
|
||||
func (tt *TCPTransport) extractMessage(buffer []byte) (*Message, []byte, error) {
|
||||
// Look for length prefix (format: "length\nmessage_data")
|
||||
newlineIndex := -1
|
||||
for i, b := range buffer {
|
||||
if b == '\n' {
|
||||
newlineIndex = i
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if newlineIndex == -1 {
|
||||
return nil, buffer, nil // No complete length prefix yet
|
||||
}
|
||||
|
||||
// Parse length
|
||||
lengthStr := string(buffer[:newlineIndex])
|
||||
var messageLength int
|
||||
if _, err := fmt.Sscanf(lengthStr, "%d", &messageLength); err != nil {
|
||||
return nil, nil, fmt.Errorf("invalid length prefix: %s", lengthStr)
|
||||
}
|
||||
|
||||
// Check if we have the complete message
|
||||
messageStart := newlineIndex + 1
|
||||
messageEnd := messageStart + messageLength
|
||||
if len(buffer) < messageEnd {
|
||||
return nil, buffer, nil // Incomplete message
|
||||
}
|
||||
|
||||
// Extract and parse message
|
||||
messageData := buffer[messageStart:messageEnd]
|
||||
var msg Message
|
||||
if err := json.Unmarshal(messageData, &msg); err != nil {
|
||||
return nil, nil, fmt.Errorf("failed to unmarshal message: %w", err)
|
||||
}
|
||||
|
||||
// Return message and remaining buffer
|
||||
remaining := buffer[messageEnd:]
|
||||
return &msg, remaining, nil
|
||||
}
|
||||
|
||||
func (tt *TCPTransport) updateSendMetrics(msg *Message, latency time.Duration) {
|
||||
tt.mu.Lock()
|
||||
defer tt.mu.Unlock()
|
||||
|
||||
tt.metrics.MessagesSent++
|
||||
tt.metrics.Latency = latency
|
||||
|
||||
// Estimate message size
|
||||
messageSize := int64(len(msg.ID) + len(msg.Topic) + len(msg.Source))
|
||||
if msg.Data != nil {
|
||||
messageSize += int64(len(fmt.Sprintf("%v", msg.Data)))
|
||||
}
|
||||
tt.metrics.BytesSent += messageSize
|
||||
}
|
||||
|
||||
func (tt *TCPTransport) updateReceiveMetrics(msg *Message) {
|
||||
tt.mu.Lock()
|
||||
defer tt.mu.Unlock()
|
||||
|
||||
tt.metrics.MessagesReceived++
|
||||
|
||||
// Estimate message size
|
||||
messageSize := int64(len(msg.ID) + len(msg.Topic) + len(msg.Source))
|
||||
if msg.Data != nil {
|
||||
messageSize += int64(len(fmt.Sprintf("%v", msg.Data)))
|
||||
}
|
||||
tt.metrics.BytesReceived += messageSize
|
||||
}
|
||||
|
||||
// GetAddress returns the transport address (for testing/debugging)
|
||||
func (tt *TCPTransport) GetAddress() string {
|
||||
return fmt.Sprintf("%s:%d", tt.address, tt.port)
|
||||
}
|
||||
|
||||
// GetConnectionCount returns the number of active connections
|
||||
func (tt *TCPTransport) GetConnectionCount() int {
|
||||
tt.mu.RLock()
|
||||
defer tt.mu.RUnlock()
|
||||
return len(tt.connections)
|
||||
}
|
||||
|
||||
// IsSecure returns whether TLS is enabled
|
||||
func (tt *TCPTransport) IsSecure() bool {
|
||||
return tt.tlsConfig != nil
|
||||
}
|
||||
399
pkg/transport/unix_transport.go
Normal file
399
pkg/transport/unix_transport.go
Normal file
@@ -0,0 +1,399 @@
|
||||
package transport
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// UnixSocketTransport implements Unix socket transport for local IPC
|
||||
type UnixSocketTransport struct {
|
||||
socketPath string
|
||||
listener net.Listener
|
||||
connections map[string]net.Conn
|
||||
metrics TransportMetrics
|
||||
connected bool
|
||||
isServer bool
|
||||
receiveChan chan *Message
|
||||
mu sync.RWMutex
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
}
|
||||
|
||||
// NewUnixSocketTransport creates a new Unix socket transport
|
||||
func NewUnixSocketTransport(socketPath string, isServer bool) *UnixSocketTransport {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
return &UnixSocketTransport{
|
||||
socketPath: socketPath,
|
||||
connections: make(map[string]net.Conn),
|
||||
metrics: TransportMetrics{},
|
||||
isServer: isServer,
|
||||
receiveChan: make(chan *Message, 1000),
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
}
|
||||
}
|
||||
|
||||
// Connect establishes the Unix socket connection
|
||||
func (ut *UnixSocketTransport) Connect(ctx context.Context) error {
|
||||
ut.mu.Lock()
|
||||
defer ut.mu.Unlock()
|
||||
|
||||
if ut.connected {
|
||||
return nil
|
||||
}
|
||||
|
||||
if ut.isServer {
|
||||
return ut.startServer()
|
||||
} else {
|
||||
return ut.connectToServer()
|
||||
}
|
||||
}
|
||||
|
||||
// Disconnect closes the Unix socket connection
|
||||
func (ut *UnixSocketTransport) Disconnect(ctx context.Context) error {
|
||||
ut.mu.Lock()
|
||||
defer ut.mu.Unlock()
|
||||
|
||||
if !ut.connected {
|
||||
return nil
|
||||
}
|
||||
|
||||
ut.cancel()
|
||||
|
||||
if ut.isServer && ut.listener != nil {
|
||||
ut.listener.Close()
|
||||
// Remove socket file
|
||||
os.Remove(ut.socketPath)
|
||||
}
|
||||
|
||||
// Close all connections
|
||||
for id, conn := range ut.connections {
|
||||
conn.Close()
|
||||
delete(ut.connections, id)
|
||||
}
|
||||
|
||||
close(ut.receiveChan)
|
||||
ut.connected = false
|
||||
ut.metrics.Connections = 0
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Send transmits a message through the Unix socket
|
||||
func (ut *UnixSocketTransport) Send(ctx context.Context, msg *Message) error {
|
||||
start := time.Now()
|
||||
|
||||
ut.mu.RLock()
|
||||
if !ut.connected {
|
||||
ut.mu.RUnlock()
|
||||
ut.metrics.Errors++
|
||||
return fmt.Errorf("transport not connected")
|
||||
}
|
||||
|
||||
// Serialize message
|
||||
data, err := json.Marshal(msg)
|
||||
if err != nil {
|
||||
ut.mu.RUnlock()
|
||||
ut.metrics.Errors++
|
||||
return fmt.Errorf("failed to marshal message: %w", err)
|
||||
}
|
||||
|
||||
// Add length prefix for framing
|
||||
frame := fmt.Sprintf("%d\n%s", len(data), data)
|
||||
frameBytes := []byte(frame)
|
||||
|
||||
// Send to all connections
|
||||
var sendErr error
|
||||
connectionCount := len(ut.connections)
|
||||
ut.mu.RUnlock()
|
||||
|
||||
if connectionCount == 0 {
|
||||
ut.metrics.Errors++
|
||||
return fmt.Errorf("no active connections")
|
||||
}
|
||||
|
||||
ut.mu.RLock()
|
||||
for connID, conn := range ut.connections {
|
||||
if err := ut.sendToConnection(conn, frameBytes); err != nil {
|
||||
sendErr = fmt.Errorf("failed to send to connection %s: %w", connID, err)
|
||||
// Remove failed connection
|
||||
go ut.removeConnection(connID)
|
||||
}
|
||||
}
|
||||
ut.mu.RUnlock()
|
||||
|
||||
if sendErr == nil {
|
||||
ut.updateSendMetrics(msg, time.Since(start))
|
||||
} else {
|
||||
ut.metrics.Errors++
|
||||
}
|
||||
|
||||
return sendErr
|
||||
}
|
||||
|
||||
// Receive returns a channel for receiving messages
|
||||
func (ut *UnixSocketTransport) Receive(ctx context.Context) (<-chan *Message, error) {
|
||||
ut.mu.RLock()
|
||||
defer ut.mu.RUnlock()
|
||||
|
||||
if !ut.connected {
|
||||
return nil, fmt.Errorf("transport not connected")
|
||||
}
|
||||
|
||||
return ut.receiveChan, nil
|
||||
}
|
||||
|
||||
// Health returns the health status of the transport
|
||||
func (ut *UnixSocketTransport) Health() ComponentHealth {
|
||||
ut.mu.RLock()
|
||||
defer ut.mu.RUnlock()
|
||||
|
||||
status := "unhealthy"
|
||||
if ut.connected {
|
||||
if len(ut.connections) > 0 {
|
||||
status = "healthy"
|
||||
} else {
|
||||
status = "degraded" // Connected but no active connections
|
||||
}
|
||||
}
|
||||
|
||||
return ComponentHealth{
|
||||
Status: status,
|
||||
LastCheck: time.Now(),
|
||||
ResponseTime: time.Millisecond, // Fast for local sockets
|
||||
ErrorCount: ut.metrics.Errors,
|
||||
}
|
||||
}
|
||||
|
||||
// GetMetrics returns transport-specific metrics
|
||||
func (ut *UnixSocketTransport) GetMetrics() TransportMetrics {
|
||||
ut.mu.RLock()
|
||||
defer ut.mu.RUnlock()
|
||||
|
||||
return TransportMetrics{
|
||||
BytesSent: ut.metrics.BytesSent,
|
||||
BytesReceived: ut.metrics.BytesReceived,
|
||||
MessagesSent: ut.metrics.MessagesSent,
|
||||
MessagesReceived: ut.metrics.MessagesReceived,
|
||||
Connections: len(ut.connections),
|
||||
Errors: ut.metrics.Errors,
|
||||
Latency: ut.metrics.Latency,
|
||||
}
|
||||
}
|
||||
|
||||
// Private helper methods
|
||||
|
||||
func (ut *UnixSocketTransport) startServer() error {
|
||||
// Remove existing socket file
|
||||
os.Remove(ut.socketPath)
|
||||
|
||||
listener, err := net.Listen("unix", ut.socketPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to listen on socket %s: %w", ut.socketPath, err)
|
||||
}
|
||||
|
||||
ut.listener = listener
|
||||
ut.connected = true
|
||||
|
||||
// Start accepting connections
|
||||
go ut.acceptConnections()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (ut *UnixSocketTransport) connectToServer() error {
|
||||
conn, err := net.Dial("unix", ut.socketPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to connect to socket %s: %w", ut.socketPath, err)
|
||||
}
|
||||
|
||||
connID := fmt.Sprintf("client_%d", time.Now().UnixNano())
|
||||
ut.connections[connID] = conn
|
||||
ut.connected = true
|
||||
ut.metrics.Connections = 1
|
||||
|
||||
// Start receiving from server
|
||||
go ut.handleConnection(connID, conn)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (ut *UnixSocketTransport) acceptConnections() {
|
||||
for {
|
||||
select {
|
||||
case <-ut.ctx.Done():
|
||||
return
|
||||
default:
|
||||
conn, err := ut.listener.Accept()
|
||||
if err != nil {
|
||||
if ut.ctx.Err() != nil {
|
||||
return // Context cancelled
|
||||
}
|
||||
ut.metrics.Errors++
|
||||
continue
|
||||
}
|
||||
|
||||
connID := fmt.Sprintf("server_%d", time.Now().UnixNano())
|
||||
|
||||
ut.mu.Lock()
|
||||
ut.connections[connID] = conn
|
||||
ut.metrics.Connections = len(ut.connections)
|
||||
ut.mu.Unlock()
|
||||
|
||||
go ut.handleConnection(connID, conn)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (ut *UnixSocketTransport) handleConnection(connID string, conn net.Conn) {
|
||||
defer ut.removeConnection(connID)
|
||||
|
||||
buffer := make([]byte, 4096)
|
||||
var messageBuffer []byte
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ut.ctx.Done():
|
||||
return
|
||||
default:
|
||||
conn.SetReadDeadline(time.Now().Add(time.Second)) // Non-blocking read
|
||||
n, err := conn.Read(buffer)
|
||||
if err != nil {
|
||||
if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
|
||||
continue // Continue on timeout
|
||||
}
|
||||
return // Connection closed or error
|
||||
}
|
||||
|
||||
messageBuffer = append(messageBuffer, buffer[:n]...)
|
||||
|
||||
// Process complete messages
|
||||
for {
|
||||
msg, remaining, err := ut.extractMessage(messageBuffer)
|
||||
if err != nil {
|
||||
return // Invalid message format
|
||||
}
|
||||
if msg == nil {
|
||||
break // No complete message yet
|
||||
}
|
||||
|
||||
// Deliver message
|
||||
select {
|
||||
case ut.receiveChan <- msg:
|
||||
ut.updateReceiveMetrics(msg)
|
||||
case <-ut.ctx.Done():
|
||||
return
|
||||
default:
|
||||
// Channel full, drop message
|
||||
ut.metrics.Errors++
|
||||
}
|
||||
|
||||
messageBuffer = remaining
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (ut *UnixSocketTransport) sendToConnection(conn net.Conn, data []byte) error {
|
||||
conn.SetWriteDeadline(time.Now().Add(5 * time.Second))
|
||||
_, err := conn.Write(data)
|
||||
return err
|
||||
}
|
||||
|
||||
func (ut *UnixSocketTransport) removeConnection(connID string) {
|
||||
ut.mu.Lock()
|
||||
defer ut.mu.Unlock()
|
||||
|
||||
if conn, exists := ut.connections[connID]; exists {
|
||||
conn.Close()
|
||||
delete(ut.connections, connID)
|
||||
ut.metrics.Connections = len(ut.connections)
|
||||
}
|
||||
}
|
||||
|
||||
func (ut *UnixSocketTransport) extractMessage(buffer []byte) (*Message, []byte, error) {
|
||||
// Look for length prefix (format: "length\nmessage_data")
|
||||
newlineIndex := -1
|
||||
for i, b := range buffer {
|
||||
if b == '\n' {
|
||||
newlineIndex = i
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if newlineIndex == -1 {
|
||||
return nil, buffer, nil // No complete length prefix yet
|
||||
}
|
||||
|
||||
// Parse length
|
||||
lengthStr := string(buffer[:newlineIndex])
|
||||
var messageLength int
|
||||
if _, err := fmt.Sscanf(lengthStr, "%d", &messageLength); err != nil {
|
||||
return nil, nil, fmt.Errorf("invalid length prefix: %s", lengthStr)
|
||||
}
|
||||
|
||||
// Check if we have the complete message
|
||||
messageStart := newlineIndex + 1
|
||||
messageEnd := messageStart + messageLength
|
||||
if len(buffer) < messageEnd {
|
||||
return nil, buffer, nil // Incomplete message
|
||||
}
|
||||
|
||||
// Extract and parse message
|
||||
messageData := buffer[messageStart:messageEnd]
|
||||
var msg Message
|
||||
if err := json.Unmarshal(messageData, &msg); err != nil {
|
||||
return nil, nil, fmt.Errorf("failed to unmarshal message: %w", err)
|
||||
}
|
||||
|
||||
// Return message and remaining buffer
|
||||
remaining := buffer[messageEnd:]
|
||||
return &msg, remaining, nil
|
||||
}
|
||||
|
||||
func (ut *UnixSocketTransport) updateSendMetrics(msg *Message, latency time.Duration) {
|
||||
ut.mu.Lock()
|
||||
defer ut.mu.Unlock()
|
||||
|
||||
ut.metrics.MessagesSent++
|
||||
ut.metrics.Latency = latency
|
||||
|
||||
// Estimate message size
|
||||
messageSize := int64(len(msg.ID) + len(msg.Topic) + len(msg.Source))
|
||||
if msg.Data != nil {
|
||||
messageSize += int64(len(fmt.Sprintf("%v", msg.Data)))
|
||||
}
|
||||
ut.metrics.BytesSent += messageSize
|
||||
}
|
||||
|
||||
func (ut *UnixSocketTransport) updateReceiveMetrics(msg *Message) {
|
||||
ut.mu.Lock()
|
||||
defer ut.mu.Unlock()
|
||||
|
||||
ut.metrics.MessagesReceived++
|
||||
|
||||
// Estimate message size
|
||||
messageSize := int64(len(msg.ID) + len(msg.Topic) + len(msg.Source))
|
||||
if msg.Data != nil {
|
||||
messageSize += int64(len(fmt.Sprintf("%v", msg.Data)))
|
||||
}
|
||||
ut.metrics.BytesReceived += messageSize
|
||||
}
|
||||
|
||||
// GetSocketPath returns the socket path (for testing/debugging)
|
||||
func (ut *UnixSocketTransport) GetSocketPath() string {
|
||||
return ut.socketPath
|
||||
}
|
||||
|
||||
// GetConnectionCount returns the number of active connections
|
||||
func (ut *UnixSocketTransport) GetConnectionCount() int {
|
||||
ut.mu.RLock()
|
||||
defer ut.mu.RUnlock()
|
||||
return len(ut.connections)
|
||||
}
|
||||
427
pkg/transport/websocket_transport.go
Normal file
427
pkg/transport/websocket_transport.go
Normal file
@@ -0,0 +1,427 @@
|
||||
package transport
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
)
|
||||
|
||||
// WebSocketTransport implements WebSocket transport for real-time monitoring
|
||||
type WebSocketTransport struct {
|
||||
address string
|
||||
port int
|
||||
path string
|
||||
upgrader websocket.Upgrader
|
||||
connections map[string]*websocket.Conn
|
||||
metrics TransportMetrics
|
||||
connected bool
|
||||
isServer bool
|
||||
receiveChan chan *Message
|
||||
server *http.Server
|
||||
mu sync.RWMutex
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
pingInterval time.Duration
|
||||
pongTimeout time.Duration
|
||||
}
|
||||
|
||||
// NewWebSocketTransport creates a new WebSocket transport
|
||||
func NewWebSocketTransport(address string, port int, path string, isServer bool) *WebSocketTransport {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
return &WebSocketTransport{
|
||||
address: address,
|
||||
port: port,
|
||||
path: path,
|
||||
connections: make(map[string]*websocket.Conn),
|
||||
metrics: TransportMetrics{},
|
||||
isServer: isServer,
|
||||
receiveChan: make(chan *Message, 1000),
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
upgrader: websocket.Upgrader{
|
||||
CheckOrigin: func(r *http.Request) bool {
|
||||
return true // Allow all origins for now
|
||||
},
|
||||
ReadBufferSize: 1024,
|
||||
WriteBufferSize: 1024,
|
||||
},
|
||||
pingInterval: 30 * time.Second,
|
||||
pongTimeout: 10 * time.Second,
|
||||
}
|
||||
}
|
||||
|
||||
// SetPingPongSettings configures WebSocket ping/pong settings
|
||||
func (wt *WebSocketTransport) SetPingPongSettings(pingInterval, pongTimeout time.Duration) {
|
||||
wt.pingInterval = pingInterval
|
||||
wt.pongTimeout = pongTimeout
|
||||
}
|
||||
|
||||
// Connect establishes the WebSocket connection
|
||||
func (wt *WebSocketTransport) Connect(ctx context.Context) error {
|
||||
wt.mu.Lock()
|
||||
defer wt.mu.Unlock()
|
||||
|
||||
if wt.connected {
|
||||
return nil
|
||||
}
|
||||
|
||||
if wt.isServer {
|
||||
return wt.startServer()
|
||||
} else {
|
||||
return wt.connectToServer(ctx)
|
||||
}
|
||||
}
|
||||
|
||||
// Disconnect closes the WebSocket connection
|
||||
func (wt *WebSocketTransport) Disconnect(ctx context.Context) error {
|
||||
wt.mu.Lock()
|
||||
defer wt.mu.Unlock()
|
||||
|
||||
if !wt.connected {
|
||||
return nil
|
||||
}
|
||||
|
||||
wt.cancel()
|
||||
|
||||
if wt.isServer && wt.server != nil {
|
||||
shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
wt.server.Shutdown(shutdownCtx)
|
||||
}
|
||||
|
||||
// Close all connections
|
||||
for id, conn := range wt.connections {
|
||||
conn.Close()
|
||||
delete(wt.connections, id)
|
||||
}
|
||||
|
||||
close(wt.receiveChan)
|
||||
wt.connected = false
|
||||
wt.metrics.Connections = 0
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Send transmits a message through WebSocket
|
||||
func (wt *WebSocketTransport) Send(ctx context.Context, msg *Message) error {
|
||||
start := time.Now()
|
||||
|
||||
wt.mu.RLock()
|
||||
if !wt.connected {
|
||||
wt.mu.RUnlock()
|
||||
wt.metrics.Errors++
|
||||
return fmt.Errorf("transport not connected")
|
||||
}
|
||||
|
||||
// Serialize message
|
||||
data, err := json.Marshal(msg)
|
||||
if err != nil {
|
||||
wt.mu.RUnlock()
|
||||
wt.metrics.Errors++
|
||||
return fmt.Errorf("failed to marshal message: %w", err)
|
||||
}
|
||||
|
||||
// Send to all connections
|
||||
var sendErr error
|
||||
connectionCount := len(wt.connections)
|
||||
wt.mu.RUnlock()
|
||||
|
||||
if connectionCount == 0 {
|
||||
wt.metrics.Errors++
|
||||
return fmt.Errorf("no active connections")
|
||||
}
|
||||
|
||||
wt.mu.RLock()
|
||||
for connID, conn := range wt.connections {
|
||||
if err := wt.sendToConnection(conn, data); err != nil {
|
||||
sendErr = fmt.Errorf("failed to send to connection %s: %w", connID, err)
|
||||
// Remove failed connection
|
||||
go wt.removeConnection(connID)
|
||||
}
|
||||
}
|
||||
wt.mu.RUnlock()
|
||||
|
||||
if sendErr == nil {
|
||||
wt.updateSendMetrics(msg, time.Since(start))
|
||||
} else {
|
||||
wt.metrics.Errors++
|
||||
}
|
||||
|
||||
return sendErr
|
||||
}
|
||||
|
||||
// Receive returns a channel for receiving messages
|
||||
func (wt *WebSocketTransport) Receive(ctx context.Context) (<-chan *Message, error) {
|
||||
wt.mu.RLock()
|
||||
defer wt.mu.RUnlock()
|
||||
|
||||
if !wt.connected {
|
||||
return nil, fmt.Errorf("transport not connected")
|
||||
}
|
||||
|
||||
return wt.receiveChan, nil
|
||||
}
|
||||
|
||||
// Health returns the health status of the transport
|
||||
func (wt *WebSocketTransport) Health() ComponentHealth {
|
||||
wt.mu.RLock()
|
||||
defer wt.mu.RUnlock()
|
||||
|
||||
status := "unhealthy"
|
||||
if wt.connected {
|
||||
if len(wt.connections) > 0 {
|
||||
status = "healthy"
|
||||
} else {
|
||||
status = "degraded" // Connected but no active connections
|
||||
}
|
||||
}
|
||||
|
||||
return ComponentHealth{
|
||||
Status: status,
|
||||
LastCheck: time.Now(),
|
||||
ResponseTime: time.Millisecond * 5, // Very fast for WebSocket
|
||||
ErrorCount: wt.metrics.Errors,
|
||||
}
|
||||
}
|
||||
|
||||
// GetMetrics returns transport-specific metrics
|
||||
func (wt *WebSocketTransport) GetMetrics() TransportMetrics {
|
||||
wt.mu.RLock()
|
||||
defer wt.mu.RUnlock()
|
||||
|
||||
return TransportMetrics{
|
||||
BytesSent: wt.metrics.BytesSent,
|
||||
BytesReceived: wt.metrics.BytesReceived,
|
||||
MessagesSent: wt.metrics.MessagesSent,
|
||||
MessagesReceived: wt.metrics.MessagesReceived,
|
||||
Connections: len(wt.connections),
|
||||
Errors: wt.metrics.Errors,
|
||||
Latency: wt.metrics.Latency,
|
||||
}
|
||||
}
|
||||
|
||||
// Private helper methods
|
||||
|
||||
func (wt *WebSocketTransport) startServer() error {
|
||||
mux := http.NewServeMux()
|
||||
mux.HandleFunc(wt.path, wt.handleWebSocket)
|
||||
|
||||
// Add health check endpoint
|
||||
mux.HandleFunc("/health", func(w http.ResponseWriter, r *http.Request) {
|
||||
health := wt.Health()
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(health)
|
||||
})
|
||||
|
||||
// Add metrics endpoint
|
||||
mux.HandleFunc("/metrics", func(w http.ResponseWriter, r *http.Request) {
|
||||
metrics := wt.GetMetrics()
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(metrics)
|
||||
})
|
||||
|
||||
addr := fmt.Sprintf("%s:%d", wt.address, wt.port)
|
||||
wt.server = &http.Server{
|
||||
Addr: addr,
|
||||
Handler: mux,
|
||||
ReadTimeout: 60 * time.Second,
|
||||
WriteTimeout: 60 * time.Second,
|
||||
IdleTimeout: 120 * time.Second,
|
||||
}
|
||||
|
||||
wt.connected = true
|
||||
|
||||
// Start server in goroutine
|
||||
go func() {
|
||||
if err := wt.server.ListenAndServe(); err != nil && err != http.ErrServerClosed {
|
||||
wt.metrics.Errors++
|
||||
}
|
||||
}()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (wt *WebSocketTransport) connectToServer(ctx context.Context) error {
|
||||
url := fmt.Sprintf("ws://%s:%d%s", wt.address, wt.port, wt.path)
|
||||
|
||||
conn, _, err := websocket.DefaultDialer.DialContext(ctx, url, nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to connect to WebSocket server: %w", err)
|
||||
}
|
||||
|
||||
connID := fmt.Sprintf("client_%d", time.Now().UnixNano())
|
||||
wt.connections[connID] = conn
|
||||
wt.connected = true
|
||||
wt.metrics.Connections = 1
|
||||
|
||||
// Start handling connection
|
||||
go wt.handleConnection(connID, conn)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (wt *WebSocketTransport) handleWebSocket(w http.ResponseWriter, r *http.Request) {
|
||||
conn, err := wt.upgrader.Upgrade(w, r, nil)
|
||||
if err != nil {
|
||||
wt.metrics.Errors++
|
||||
return
|
||||
}
|
||||
|
||||
connID := fmt.Sprintf("server_%d", time.Now().UnixNano())
|
||||
|
||||
wt.mu.Lock()
|
||||
wt.connections[connID] = conn
|
||||
wt.metrics.Connections = len(wt.connections)
|
||||
wt.mu.Unlock()
|
||||
|
||||
go wt.handleConnection(connID, conn)
|
||||
}
|
||||
|
||||
func (wt *WebSocketTransport) handleConnection(connID string, conn *websocket.Conn) {
|
||||
defer wt.removeConnection(connID)
|
||||
|
||||
// Set up ping/pong handling
|
||||
conn.SetReadDeadline(time.Now().Add(wt.pongTimeout))
|
||||
conn.SetPongHandler(func(string) error {
|
||||
conn.SetReadDeadline(time.Now().Add(wt.pongTimeout))
|
||||
return nil
|
||||
})
|
||||
|
||||
// Start ping routine
|
||||
go wt.pingRoutine(connID, conn)
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-wt.ctx.Done():
|
||||
return
|
||||
default:
|
||||
var msg Message
|
||||
err := conn.ReadJSON(&msg)
|
||||
if err != nil {
|
||||
if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) {
|
||||
wt.metrics.Errors++
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Deliver message
|
||||
select {
|
||||
case wt.receiveChan <- &msg:
|
||||
wt.updateReceiveMetrics(&msg)
|
||||
case <-wt.ctx.Done():
|
||||
return
|
||||
default:
|
||||
// Channel full, drop message
|
||||
wt.metrics.Errors++
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (wt *WebSocketTransport) pingRoutine(connID string, conn *websocket.Conn) {
|
||||
ticker := time.NewTicker(wt.pingInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
if err := conn.WriteControl(websocket.PingMessage, []byte{}, time.Now().Add(10*time.Second)); err != nil {
|
||||
return // Connection is likely closed
|
||||
}
|
||||
case <-wt.ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (wt *WebSocketTransport) sendToConnection(conn *websocket.Conn, data []byte) error {
|
||||
conn.SetWriteDeadline(time.Now().Add(10 * time.Second))
|
||||
return conn.WriteMessage(websocket.TextMessage, data)
|
||||
}
|
||||
|
||||
func (wt *WebSocketTransport) removeConnection(connID string) {
|
||||
wt.mu.Lock()
|
||||
defer wt.mu.Unlock()
|
||||
|
||||
if conn, exists := wt.connections[connID]; exists {
|
||||
conn.Close()
|
||||
delete(wt.connections, connID)
|
||||
wt.metrics.Connections = len(wt.connections)
|
||||
}
|
||||
}
|
||||
|
||||
func (wt *WebSocketTransport) updateSendMetrics(msg *Message, latency time.Duration) {
|
||||
wt.mu.Lock()
|
||||
defer wt.mu.Unlock()
|
||||
|
||||
wt.metrics.MessagesSent++
|
||||
wt.metrics.Latency = latency
|
||||
|
||||
// Estimate message size
|
||||
messageSize := int64(len(msg.ID) + len(msg.Topic) + len(msg.Source))
|
||||
if msg.Data != nil {
|
||||
messageSize += int64(len(fmt.Sprintf("%v", msg.Data)))
|
||||
}
|
||||
wt.metrics.BytesSent += messageSize
|
||||
}
|
||||
|
||||
func (wt *WebSocketTransport) updateReceiveMetrics(msg *Message) {
|
||||
wt.mu.Lock()
|
||||
defer wt.mu.Unlock()
|
||||
|
||||
wt.metrics.MessagesReceived++
|
||||
|
||||
// Estimate message size
|
||||
messageSize := int64(len(msg.ID) + len(msg.Topic) + len(msg.Source))
|
||||
if msg.Data != nil {
|
||||
messageSize += int64(len(fmt.Sprintf("%v", msg.Data)))
|
||||
}
|
||||
wt.metrics.BytesReceived += messageSize
|
||||
}
|
||||
|
||||
// Broadcast sends a message to all connected clients (server mode only)
|
||||
func (wt *WebSocketTransport) Broadcast(ctx context.Context, msg *Message) error {
|
||||
if !wt.isServer {
|
||||
return fmt.Errorf("broadcast only available in server mode")
|
||||
}
|
||||
|
||||
return wt.Send(ctx, msg)
|
||||
}
|
||||
|
||||
// GetURL returns the WebSocket URL (for testing/debugging)
|
||||
func (wt *WebSocketTransport) GetURL() string {
|
||||
return fmt.Sprintf("ws://%s:%d%s", wt.address, wt.port, wt.path)
|
||||
}
|
||||
|
||||
// GetConnectionCount returns the number of active connections
|
||||
func (wt *WebSocketTransport) GetConnectionCount() int {
|
||||
wt.mu.RLock()
|
||||
defer wt.mu.RUnlock()
|
||||
return len(wt.connections)
|
||||
}
|
||||
|
||||
// SetAllowedOrigins configures CORS for WebSocket connections
|
||||
func (wt *WebSocketTransport) SetAllowedOrigins(origins []string) {
|
||||
if len(origins) == 0 {
|
||||
wt.upgrader.CheckOrigin = func(r *http.Request) bool {
|
||||
return true // Allow all origins
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
originMap := make(map[string]bool)
|
||||
for _, origin := range origins {
|
||||
originMap[origin] = true
|
||||
}
|
||||
|
||||
wt.upgrader.CheckOrigin = func(r *http.Request) bool {
|
||||
origin := r.Header.Get("Origin")
|
||||
return originMap[origin]
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user