feat: create v2-prep branch with comprehensive planning

Restructured project for V2 refactor:

**Structure Changes:**
- Moved all V1 code to orig/ folder (preserved with git mv)
- Created docs/planning/ directory
- Added orig/README_V1.md explaining V1 preservation

**Planning Documents:**
- 00_V2_MASTER_PLAN.md: Complete architecture overview
  - Executive summary of critical V1 issues
  - High-level component architecture diagrams
  - 5-phase implementation roadmap
  - Success metrics and risk mitigation

- 07_TASK_BREAKDOWN.md: Atomic task breakdown
  - 99+ hours of detailed tasks
  - Every task < 2 hours (atomic)
  - Clear dependencies and success criteria
  - Organized by implementation phase

**V2 Key Improvements:**
- Per-exchange parsers (factory pattern)
- Multi-layer strict validation
- Multi-index pool cache
- Background validation pipeline
- Comprehensive observability

**Critical Issues Addressed:**
- Zero address tokens (strict validation + cache enrichment)
- Parsing accuracy (protocol-specific parsers)
- No audit trail (background validation channel)
- Inefficient lookups (multi-index cache)
- Stats disconnection (event-driven metrics)

Next Steps:
1. Review planning documents
2. Begin Phase 1: Foundation (P1-001 through P1-010)
3. Implement parsers in Phase 2
4. Build cache system in Phase 3
5. Add validation pipeline in Phase 4
6. Migrate and test in Phase 5

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
Administrator
2025-11-10 10:14:26 +01:00
parent 1773daffe7
commit 803de231ba
411 changed files with 20390 additions and 8680 deletions

View File

@@ -0,0 +1,673 @@
package transport
import (
	"context"
	"fmt"
	"math"
	"runtime"
	"sort"
	"sync"
	"sync/atomic"
	"time"

	"github.com/fraktal/mev-beta/internal/logger"
	"github.com/fraktal/mev-beta/pkg/security"
)
// BenchmarkSuite runs performance benchmarks against the transport layer and
// keeps the collected results, the active configuration and aggregate metrics.
type BenchmarkSuite struct {
	logger     *logger.Logger       // receives warnings raised while measuring
	messageBus *UniversalMessageBus // bus that benchmark traffic is published to
	results    []BenchmarkResult    // one entry per successfully completed run
	config     BenchmarkConfig      // parameters used by subsequent runs
	metrics    BenchmarkMetrics     // aggregate statistics of the last RunAll
	mu         sync.RWMutex         // taken by RunAll, SetConfig and the getters
}
// BenchmarkConfig holds the parameters that control which benchmarks are run
// and how long each one lasts.
type BenchmarkConfig struct {
	// MessageSizes lists the payload sizes to test.
	MessageSizes []int
	// Concurrency lists the numbers of concurrent senders to test.
	Concurrency []int
	// Duration is the measured duration of each benchmark.
	Duration time.Duration
	// WarmupDuration is the warmup period that precedes the measurement.
	WarmupDuration time.Duration
	// TransportTypes lists the transport types to benchmark.
	TransportTypes []TransportType
	// MessageTypes lists the message types to test.
	MessageTypes []MessageType
	// SerializationFormats lists the serialization formats to test.
	SerializationFormats []SerializationFormat
	// EnableMetrics states whether detailed metrics should be collected.
	EnableMetrics bool
	// OutputFormat names the output format (json, csv, console).
	OutputFormat string
}
// BenchmarkResult is the outcome of one benchmark run for a single
// combination of transport, message size, concurrency and serialization.
type BenchmarkResult struct {
	// Identification of the run.
	TestName      string              `json:"test_name"`
	Transport     TransportType       `json:"transport"`
	MessageSize   int                 `json:"message_size"`
	Concurrency   int                 `json:"concurrency"`
	Serialization SerializationFormat `json:"serialization"`
	Duration      time.Duration       `json:"duration"`

	// Traffic counters and derived throughput.
	MessagesSent      int64   `json:"messages_sent"`
	MessagesReceived  int64   `json:"messages_received"`
	BytesSent         int64   `json:"bytes_sent"`
	BytesReceived     int64   `json:"bytes_received"`
	ThroughputMsgSec  float64 `json:"throughput_msg_sec"`
	ThroughputByteSec float64 `json:"throughput_byte_sec"`

	// Latency percentiles observed by the subscriber.
	LatencyP50 time.Duration `json:"latency_p50"`
	LatencyP95 time.Duration `json:"latency_p95"`
	LatencyP99 time.Duration `json:"latency_p99"`

	// Error and resource figures.
	ErrorRate   float64 `json:"error_rate"`
	CPUUsage    float64 `json:"cpu_usage"`
	MemoryUsage int64   `json:"memory_usage"`
	GCPauses    int64   `json:"gc_pauses"`

	// Timestamp is taken when the run is set up.
	Timestamp time.Time `json:"timestamp"`
}
// BenchmarkMetrics aggregates statistics over a whole benchmark session.
type BenchmarkMetrics struct {
	// Test counts; TotalTests is the sum of the passed and failed counts.
	TotalTests  int `json:"total_tests"`
	PassedTests int `json:"passed_tests"`
	FailedTests int `json:"failed_tests"`

	// TotalDuration is the wall-clock time of the session.
	TotalDuration time.Duration `json:"total_duration"`

	// Best figures observed across all recorded results.
	HighestThroughput float64       `json:"highest_throughput"`
	LowestLatency     time.Duration `json:"lowest_latency"`
	BestTransport     TransportType `json:"best_transport"`

	// Timestamp is the moment the session started.
	Timestamp time.Time `json:"timestamp"`
}
// LatencyTracker accumulates per-message latency samples.
type LatencyTracker struct {
	latencies []time.Duration // samples in arrival order
	mu        sync.Mutex      // held by the methods while they touch latencies
}
// NewBenchmarkSuite returns a suite bound to messageBus and logger. It starts
// with an empty result list and a default configuration: five payload sizes
// from 64 B to 16 KiB, four concurrency levels, 30 s runs after a 5 s warmup,
// the memory, unix-socket and TCP transports, JSON serialization and console
// output.
func NewBenchmarkSuite(messageBus *UniversalMessageBus, logger *logger.Logger) *BenchmarkSuite {
	defaults := BenchmarkConfig{
		MessageSizes:         []int{64, 256, 1024, 4096, 16384},
		Concurrency:          []int{1, 10, 50, 100},
		Duration:             30 * time.Second,
		WarmupDuration:       5 * time.Second,
		TransportTypes:       []TransportType{TransportMemory, TransportUnixSocket, TransportTCP},
		MessageTypes:         []MessageType{MessageTypeEvent, MessageTypeCommand},
		SerializationFormats: []SerializationFormat{SerializationJSON},
		EnableMetrics:        true,
		OutputFormat:         "console",
	}

	suite := new(BenchmarkSuite)
	suite.logger = logger
	suite.messageBus = messageBus
	suite.results = []BenchmarkResult{}
	suite.config = defaults
	return suite
}
// SetConfig replaces the suite's configuration with config while holding the
// write lock.
func (bs *BenchmarkSuite) SetConfig(config BenchmarkConfig) {
	bs.mu.Lock()
	bs.config = config
	bs.mu.Unlock()
}
// RunAll executes one benchmark for every combination of configured
// transport, message size, concurrency level and serialization format.
//
// Successful runs are appended to the stored results and counted as passed;
// failing runs are logged (when a logger is set) and counted as failed. The
// write lock is held for the whole call, so the getters block until it
// returns.
//
// RunAll stops before starting the next combination once ctx is done and
// returns an error wrapping ctx.Err(); otherwise it returns nil. The totals in
// the metrics are finalized on every exit path.
func (bs *BenchmarkSuite) RunAll(ctx context.Context) error {
	bs.mu.Lock()
	defer bs.mu.Unlock()

	startTime := time.Now()
	bs.metrics = BenchmarkMetrics{
		Timestamp: startTime,
	}
	// Registered after the Unlock defer, so it runs while the lock is still held.
	defer func() {
		bs.metrics.TotalTests = bs.metrics.PassedTests + bs.metrics.FailedTests
		bs.metrics.TotalDuration = time.Since(startTime)
	}()

	for _, transport := range bs.config.TransportTypes {
		for _, msgSize := range bs.config.MessageSizes {
			for _, concurrency := range bs.config.Concurrency {
				for _, serialization := range bs.config.SerializationFormats {
					// A cancelled context would otherwise still walk every
					// remaining combination and record empty measurements.
					if err := ctx.Err(); err != nil {
						return fmt.Errorf("benchmark run aborted: %w", err)
					}

					result, err := bs.runSingleBenchmark(ctx, transport, msgSize, concurrency, serialization)
					if err != nil {
						bs.metrics.FailedTests++
						if bs.logger != nil {
							bs.logger.Warn("Benchmark run failed", "test", result.TestName, "error", err)
						}
						continue
					}
					bs.results = append(bs.results, result)
					bs.metrics.PassedTests++
					bs.updateBestMetrics(result)
				}
			}
		}
	}
	return nil
}
// RunThroughputBenchmark runs a single benchmark for the given transport,
// message size and concurrency using JSON serialization.
func (bs *BenchmarkSuite) RunThroughputBenchmark(ctx context.Context, transport TransportType, messageSize int, concurrency int) (BenchmarkResult, error) {
	res, err := bs.runSingleBenchmark(ctx, transport, messageSize, concurrency, SerializationJSON)
	return res, err
}
// RunLatencyBenchmark runs a single benchmark with exactly one sender and JSON
// serialization for the given transport and message size.
func (bs *BenchmarkSuite) RunLatencyBenchmark(ctx context.Context, transport TransportType, messageSize int) (BenchmarkResult, error) {
	const singleSender = 1
	res, err := bs.runSingleBenchmark(ctx, transport, messageSize, singleSender, SerializationJSON)
	return res, err
}
// RunScalabilityBenchmark runs one JSON-serialized benchmark per configured
// concurrency level, in configuration order, and returns the results in that
// order. The first failing level aborts the sweep: the results gathered so far
// are discarded and a wrapped error naming the level is returned.
func (bs *BenchmarkSuite) RunScalabilityBenchmark(ctx context.Context, transport TransportType, messageSize int) ([]BenchmarkResult, error) {
	var collected []BenchmarkResult

	levels := bs.config.Concurrency
	for idx := 0; idx < len(levels); idx++ {
		level := levels[idx]
		res, runErr := bs.runSingleBenchmark(ctx, transport, messageSize, level, SerializationJSON)
		if runErr != nil {
			return nil, fmt.Errorf("scalability benchmark failed at concurrency %d: %w", level, runErr)
		}
		collected = append(collected, res)
	}
	return collected, nil
}
// GetResults returns a copy of the recorded benchmark results, taken under the
// read lock. The returned slice is never nil.
func (bs *BenchmarkSuite) GetResults() []BenchmarkResult {
	bs.mu.RLock()
	defer bs.mu.RUnlock()

	snapshot := make([]BenchmarkResult, 0, len(bs.results))
	snapshot = append(snapshot, bs.results...)
	return snapshot
}
// GetMetrics returns a copy of the aggregate benchmark metrics, read under the
// read lock.
func (bs *BenchmarkSuite) GetMetrics() BenchmarkMetrics {
	bs.mu.RLock()
	snapshot := bs.metrics
	bs.mu.RUnlock()
	return snapshot
}
// GetBestPerformingTransport returns the transport recorded in the metrics as
// having produced the highest message throughput.
func (bs *BenchmarkSuite) GetBestPerformingTransport() TransportType {
	bs.mu.RLock()
	best := bs.metrics.BestTransport
	bs.mu.RUnlock()
	return best
}
// Private methods

// runSingleBenchmark measures one combination of transport, message size,
// concurrency and serialization.
//
// It subscribes to a dedicated topic, optionally runs a warmup, then lets
// `concurrency` senders publish for the configured duration while a monitor
// goroutine samples resource usage. Afterwards it derives throughput, error
// rate, memory/GC deltas and latency percentiles.
//
// An error is returned only when the subscription cannot be created. The
// subscription is removed before the function returns.
func (bs *BenchmarkSuite) runSingleBenchmark(ctx context.Context, transport TransportType, messageSize int, concurrency int, serialization SerializationFormat) (BenchmarkResult, error) {
	testName := fmt.Sprintf("%s_%db_%dc_%s", transport, messageSize, concurrency, serialization)
	result := BenchmarkResult{
		TestName:      testName,
		Transport:     transport,
		MessageSize:   messageSize,
		Concurrency:   concurrency,
		Serialization: serialization,
		Duration:      bs.config.Duration,
		Timestamp:     time.Now(),
	}

	latencyTracker := &LatencyTracker{
		latencies: make([]time.Duration, 0),
	}
	topic := fmt.Sprintf("benchmark_%s", testName)

	// The receive counters live outside result: the handler may still fire
	// while this function is returning and must not write into the value that
	// is being copied out.
	var received, receivedBytes int64
	subscription, err := bs.messageBus.Subscribe(topic, func(ctx context.Context, msg *Message) error {
		if sentAt, ok := msg.Metadata["start_time"].(time.Time); ok {
			latencyTracker.AddLatency(time.Since(sentAt))
		}
		atomic.AddInt64(&received, 1)
		atomic.AddInt64(&receivedBytes, int64(messageSize))
		return nil
	})
	if err != nil {
		return result, fmt.Errorf("failed to subscribe: %w", err)
	}
	defer bs.messageBus.Unsubscribe(subscription.ID)

	// Warmup phase
	if bs.config.WarmupDuration > 0 {
		bs.warmup(ctx, topic, messageSize, concurrency, bs.config.WarmupDuration)

		// Warmup traffic is not counted as sent, so drop what the subscriber
		// saw during warmup to keep the received figures and latencies
		// comparable. Deliveries still in flight may slip through.
		atomic.StoreInt64(&received, 0)
		atomic.StoreInt64(&receivedBytes, 0)
		latencyTracker.mu.Lock()
		latencyTracker.latencies = latencyTracker.latencies[:0]
		latencyTracker.mu.Unlock()
	}

	// Start system monitoring
	var cpuUsage float64
	var memUsageBefore, memUsageAfter runtime.MemStats
	runtime.ReadMemStats(&memUsageBefore)
	monitorCtx, monitorCancel := context.WithCancel(ctx)
	defer monitorCancel()
	monitorDone := make(chan struct{})
	go func() {
		defer close(monitorDone)
		bs.monitorSystemResources(monitorCtx, &cpuUsage)
	}()

	// Main benchmark
	startTime := time.Now()
	benchmarkCtx, cancel := context.WithTimeout(ctx, bs.config.Duration)
	defer cancel()

	// Launch concurrent senders
	var wg sync.WaitGroup
	var totalSent int64
	var totalErrors int64
	for i := 0; i < concurrency; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			bs.senderWorker(benchmarkCtx, topic, messageSize, &totalSent, &totalErrors)
		}()
	}
	wg.Wait()

	// Wait a bit for remaining messages to be processed
	time.Sleep(100 * time.Millisecond)
	actualDuration := time.Since(startTime)
	runtime.ReadMemStats(&memUsageAfter)

	// The monitor publishes its average only when its context is cancelled, so
	// stop it and wait for it to finish before reading cpuUsage.
	monitorCancel()
	<-monitorDone

	// Calculate results
	result.MessagesSent = totalSent
	result.MessagesReceived = atomic.LoadInt64(&received)
	result.BytesSent = totalSent * int64(messageSize)
	result.BytesReceived = atomic.LoadInt64(&receivedBytes)
	result.ThroughputMsgSec = float64(totalSent) / actualDuration.Seconds()
	result.ThroughputByteSec = float64(result.BytesSent) / actualDuration.Seconds()
	// Errors as a percentage of publish attempts. Guarding the zero case keeps
	// NaN/Inf out of the result when nothing could be sent.
	if attempts := totalSent + totalErrors; attempts > 0 {
		result.ErrorRate = float64(totalErrors) / float64(attempts) * 100
	}
	result.CPUUsage = cpuUsage

	// Alloc can shrink when a GC cycle runs during the benchmark. Order the
	// unsigned subtraction so it cannot wrap, and carry the sign separately.
	var memDiff uint64
	sign := int64(1)
	if memUsageAfter.Alloc >= memUsageBefore.Alloc {
		memDiff = memUsageAfter.Alloc - memUsageBefore.Alloc
	} else {
		memDiff = memUsageBefore.Alloc - memUsageAfter.Alloc
		sign = -1
	}
	memDiffInt64, err := security.SafeUint64ToInt64(memDiff)
	if err != nil {
		bs.logger.Warn("Memory usage difference exceeds int64 max", "diff", memDiff, "error", err)
		result.MemoryUsage = math.MaxInt64
	} else {
		result.MemoryUsage = sign * memDiffInt64
	}

	// Number of GC cycles completed during the run.
	result.GCPauses = int64(memUsageAfter.NumGC) - int64(memUsageBefore.NumGC)

	// GetPercentile locks the tracker itself and returns 0 when there are no
	// samples, so no unlocked length check is needed here.
	result.LatencyP50 = latencyTracker.GetPercentile(50)
	result.LatencyP95 = latencyTracker.GetPercentile(95)
	result.LatencyP99 = latencyTracker.GetPercentile(99)
	return result, nil
}
// warmup publishes to topic with `concurrency` senders for `duration` (or
// until ctx is done) and blocks until all of them have stopped. The send and
// error counts of the warmup senders are thrown away.
func (bs *BenchmarkSuite) warmup(ctx context.Context, topic string, messageSize int, concurrency int, duration time.Duration) {
	warmupCtx, stop := context.WithTimeout(ctx, duration)
	defer stop()

	var senders sync.WaitGroup
	launch := func() {
		defer senders.Done()
		// Private throwaway counters for this sender.
		var sent, failed int64
		bs.senderWorker(warmupCtx, topic, messageSize, &sent, &failed)
	}
	for n := 0; n < concurrency; n++ {
		senders.Add(1)
		go launch()
	}
	senders.Wait()
}
// senderWorker publishes event messages to topic in a tight loop until ctx is
// done. Every message carries the same messageSize-byte payload and its
// publish time under the "start_time" metadata key. Successful publishes
// increment *totalSent and failed ones increment *totalErrors, both
// atomically.
func (bs *BenchmarkSuite) senderWorker(ctx context.Context, topic string, messageSize int, totalSent, totalErrors *int64) {
	// Repeating 0..255 byte pattern, built once and shared by all messages.
	payload := make([]byte, messageSize)
	for pos := 0; pos < len(payload); pos++ {
		payload[pos] = byte(pos)
	}

	for {
		select {
		case <-ctx.Done():
			return
		default:
		}

		msg := NewMessage(MessageTypeEvent, topic, "benchmark", payload)
		msg.Metadata["start_time"] = time.Now()

		counter := totalSent
		if err := bs.messageBus.Publish(ctx, msg); err != nil {
			counter = totalErrors
		}
		atomic.AddInt64(counter, 1)
	}
}
// monitorSystemResources takes a sample every 100 ms until ctx is done and
// then stores the mean of the samples in *cpuUsage. If no sample was taken,
// *cpuUsage is left untouched.
//
// Each sample is the runtime's GC cycle count divided by the seconds elapsed
// since monitoring began, times 100 and capped at 100. This is a rough
// stand-in rather than a real CPU measurement; proper OS-specific monitoring
// would be needed for production use.
func (bs *BenchmarkSuite) monitorSystemResources(ctx context.Context, cpuUsage *float64) {
	const sampleInterval = 100 * time.Millisecond

	ticker := time.NewTicker(sampleInterval)
	defer ticker.Stop()

	var (
		sum   float64
		count int
	)
	began := time.Now()

	for {
		select {
		case <-ticker.C:
			var ms runtime.MemStats
			runtime.ReadMemStats(&ms)

			sample := float64(ms.NumGC) / time.Since(began).Seconds() * 100
			if sample > 100 {
				sample = 100
			}
			sum += sample
			count++

		case <-ctx.Done():
			// The result is published only once, on shutdown.
			if count > 0 {
				*cpuUsage = sum / float64(count)
			}
			return
		}
	}
}
// updateBestMetrics folds result into the running "best" figures.
//
// The highest throughput and the transport that achieved it are replaced when
// result beats them. The lowest latency is replaced when result's P50 is
// smaller, or when no latency has been recorded yet. A P50 of zero means the
// run collected no latency samples; it is ignored so that it cannot reset an
// already recorded latency to the "unset" value 0.
func (bs *BenchmarkSuite) updateBestMetrics(result BenchmarkResult) {
	if result.ThroughputMsgSec > bs.metrics.HighestThroughput {
		bs.metrics.HighestThroughput = result.ThroughputMsgSec
		bs.metrics.BestTransport = result.Transport
	}

	if result.LatencyP50 <= 0 {
		return
	}
	if bs.metrics.LowestLatency == 0 || result.LatencyP50 < bs.metrics.LowestLatency {
		bs.metrics.LowestLatency = result.LatencyP50
	}
}
// LatencyTracker methods

// AddLatency appends latency to the recorded samples while holding the
// tracker's mutex.
func (lt *LatencyTracker) AddLatency(latency time.Duration) {
	lt.mu.Lock()
	lt.latencies = append(lt.latencies, latency)
	lt.mu.Unlock()
}
// GetPercentile returns the recorded latency at the given percentile, or 0
// when no samples have been recorded.
//
// The samples are copied and sorted in ascending order; the stored samples
// keep their arrival order. The result is the element at index
// floor(len * percentile / 100), clamped to the valid range, so percentiles
// below 0 yield the smallest sample and percentiles of 100 or more yield the
// largest.
func (lt *LatencyTracker) GetPercentile(percentile int) time.Duration {
	lt.mu.Lock()
	defer lt.mu.Unlock()

	if len(lt.latencies) == 0 {
		return 0
	}

	sorted := make([]time.Duration, len(lt.latencies))
	copy(sorted, lt.latencies)
	// A benchmark run can record a very large number of samples, so use the
	// library's O(n log n) sort instead of a quadratic insertion sort.
	sort.Slice(sorted, func(i, j int) bool { return sorted[i] < sorted[j] })

	index := int(float64(len(sorted)) * float64(percentile) / 100.0)
	switch {
	case index < 0:
		index = 0
	case index >= len(sorted):
		index = len(sorted) - 1
	}
	return sorted[index]
}
// Benchmark report generation

// GenerateReport assembles a report from the current state of the suite: a
// summary, the recorded results, the aggregate metrics, the active
// configuration and an analysis section. The read lock is held while the
// report is built.
//
// The results are copied, as GetResults does, so the report does not share a
// backing array with the suite's internal result list.
func (bs *BenchmarkSuite) GenerateReport() BenchmarkReport {
	bs.mu.RLock()
	defer bs.mu.RUnlock()

	results := make([]BenchmarkResult, len(bs.results))
	copy(results, bs.results)

	report := BenchmarkReport{
		Summary:   bs.generateSummary(),
		Results:   results,
		Metrics:   bs.metrics,
		Config:    bs.config,
		Timestamp: time.Now(),
	}
	report.Analysis = bs.generateAnalysis()
	return report
}
// BenchmarkReport bundles everything produced by GenerateReport.
type BenchmarkReport struct {
	// Summary is the high-level overview, including transport rankings.
	Summary ReportSummary `json:"summary"`
	// Results holds the individual benchmark results.
	Results []BenchmarkResult `json:"results"`
	// Metrics holds the aggregate session metrics.
	Metrics BenchmarkMetrics `json:"metrics"`
	// Config is the configuration active when the report was generated.
	Config BenchmarkConfig `json:"config"`
	// Analysis holds scalability findings, bottlenecks and recommendations.
	Analysis ReportAnalysis `json:"analysis"`
	// Timestamp is the time the report was generated.
	Timestamp time.Time `json:"timestamp"`
}
// ReportSummary is the high-level overview section of a benchmark report.
type ReportSummary struct {
	// TotalTests and Duration are taken from the session metrics.
	TotalTests int           `json:"total_tests"`
	Duration   time.Duration `json:"duration"`

	// Best figures and the transport that achieved the best throughput.
	BestThroughput       float64       `json:"best_throughput"`
	BestLatency          time.Duration `json:"best_latency"`
	RecommendedTransport TransportType `json:"recommended_transport"`

	// TransportRankings lists the transports ordered by score.
	TransportRankings []TransportRanking `json:"transport_rankings"`
}
// TransportRanking is one transport's entry in the performance ranking.
type TransportRanking struct {
	Transport TransportType `json:"transport"`
	// AvgThroughput is the mean message throughput over the transport's results.
	AvgThroughput float64 `json:"avg_throughput"`
	// AvgLatency is the mean P50 latency over the transport's results.
	AvgLatency time.Duration `json:"avg_latency"`
	// Score grows with throughput and shrinks with latency.
	Score float64 `json:"score"`
	// Rank is the 1-based position after sorting by Score, best first.
	Rank int `json:"rank"`
}
// ReportAnalysis is the detailed analysis section of a benchmark report.
type ReportAnalysis struct {
	// ScalabilityAnalysis describes how throughput scales with concurrency.
	ScalabilityAnalysis ScalabilityAnalysis `json:"scalability"`
	// PerformanceBottlenecks lists the problems detected in the results.
	PerformanceBottlenecks []PerformanceIssue `json:"bottlenecks"`
	// Recommendations lists suggested optimizations.
	Recommendations []Recommendation `json:"recommendations"`
}
// ScalabilityAnalysis summarizes how throughput scales with concurrency.
type ScalabilityAnalysis struct {
	// LinearScaling reports whether scaling was judged close to linear.
	LinearScaling bool `json:"linear_scaling"`
	// ScalingFactor is the ratio of achieved to ideal linear throughput.
	ScalingFactor float64 `json:"scaling_factor"`
	// OptimalConcurrency is the concurrency level with the highest throughput.
	OptimalConcurrency int `json:"optimal_concurrency"`
}
// PerformanceIssue describes one performance problem found in the results.
type PerformanceIssue struct {
	// Issue is a human-readable description of the problem.
	Issue string `json:"issue"`
	// Severity is a label such as "high" or "medium".
	Severity string `json:"severity"`
	// Impact describes the consequence of the problem.
	Impact string `json:"impact"`
	// Suggestion proposes a remedy.
	Suggestion string `json:"suggestion"`
}
// Recommendation is one optimization suggestion in a benchmark report.
type Recommendation struct {
	// Category groups the suggestion, e.g. "Transport Selection".
	Category string `json:"category"`
	// Description states what to do.
	Description string `json:"description"`
	// Priority is a label such as "high" or "medium".
	Priority string `json:"priority"`
	// Expected describes the improvement to expect.
	Expected string `json:"expected_improvement"`
}
// generateSummary builds the report summary from the aggregate metrics and
// the transport rankings computed from the recorded results. It takes no lock
// itself.
func (bs *BenchmarkSuite) generateSummary() ReportSummary {
	var summary ReportSummary
	summary.TransportRankings = bs.calculateTransportRankings()

	m := bs.metrics
	summary.TotalTests = m.TotalTests
	summary.Duration = m.TotalDuration
	summary.BestThroughput = m.HighestThroughput
	summary.BestLatency = m.LowestLatency
	summary.RecommendedTransport = m.BestTransport
	return summary
}
// calculateTransportRankings groups the recorded results by transport and
// ranks the transports by score, best first.
//
// For each transport the mean message throughput and the mean P50 latency are
// computed. The score is the mean throughput divided by the mean latency in
// whole microseconds. A latency below one microsecond (including zero, when
// no latency samples were recorded) is treated as one microsecond so the
// score never becomes +Inf or NaN.
//
// Rank is the 1-based position after sorting. The relative order of
// transports with equal scores is unspecified. The result is nil when there
// are no results.
func (bs *BenchmarkSuite) calculateTransportRankings() []TransportRanking {
	// Group results by transport
	transportStats := make(map[TransportType][]BenchmarkResult)
	for _, result := range bs.results {
		transportStats[result.Transport] = append(transportStats[result.Transport], result)
	}

	var rankings []TransportRanking
	for transport, results := range transportStats {
		var totalThroughput float64
		var totalLatency time.Duration
		for _, result := range results {
			totalThroughput += result.ThroughputMsgSec
			totalLatency += result.LatencyP50
		}
		avgThroughput := totalThroughput / float64(len(results))
		avgLatency := totalLatency / time.Duration(len(results))

		// Score calculation (higher throughput + lower latency = better score)
		latencyMicros := avgLatency.Microseconds()
		if latencyMicros < 1 {
			latencyMicros = 1
		}
		score := avgThroughput / float64(latencyMicros)

		rankings = append(rankings, TransportRanking{
			Transport:     transport,
			AvgThroughput: avgThroughput,
			AvgLatency:    avgLatency,
			Score:         score,
		})
	}

	// Sort by score (descending)
	sort.Slice(rankings, func(i, j int) bool {
		return rankings[i].Score > rankings[j].Score
	})

	// Assign ranks
	for i := range rankings {
		rankings[i].Rank = i + 1
	}
	return rankings
}
// generateAnalysis builds the analysis section of a report by combining the
// scalability analysis, the detected bottlenecks and the recommendations.
func (bs *BenchmarkSuite) generateAnalysis() ReportAnalysis {
	var analysis ReportAnalysis
	analysis.ScalabilityAnalysis = bs.analyzeScalability()
	analysis.PerformanceBottlenecks = bs.identifyBottlenecks()
	analysis.Recommendations = bs.generateRecommendations()
	return analysis
}
// analyzeScalability compares the best observed throughput with what ideal
// linear scaling from the first usable result would have delivered.
//
// Only results with a positive concurrency and duration are considered.
// Throughput is received messages per second of configured duration. The
// first usable result is the baseline; the ideal throughput is the baseline
// scaled by the ratio of the best result's concurrency to the baseline's.
// Scaling counts as linear when the achieved/ideal ratio is at least 0.8.
//
// With fewer than two usable results the analysis is not meaningful and
// {false, 0, 1} is returned. When the baseline throughput is zero the ratio is
// undefined, so the factor is reported as 0 (non-linear) instead of Inf/NaN.
//
// NOTE(review): results of all transports and message sizes are mixed here —
// confirm that this is intended.
func (bs *BenchmarkSuite) analyzeScalability() ScalabilityAnalysis {
	notMeaningful := ScalabilityAnalysis{
		LinearScaling:      false,
		ScalingFactor:      0.0,
		OptimalConcurrency: 1,
	}
	if len(bs.results) < 2 {
		return notMeaningful
	}

	// Analyze throughput vs concurrency relationship
	var throughputData []float64
	var concurrencyData []int
	for _, result := range bs.results {
		if result.Concurrency > 0 && result.Duration > 0 {
			throughput := float64(result.MessagesReceived) / result.Duration.Seconds()
			throughputData = append(throughputData, throughput)
			concurrencyData = append(concurrencyData, result.Concurrency)
		}
	}
	if len(throughputData) < 2 {
		return notMeaningful
	}

	// Find the best throughput and the concurrency that produced it.
	maxThroughput := 0.0
	maxThroughputConcurrency := 1
	for i, throughput := range throughputData {
		if throughput > maxThroughput {
			maxThroughput = throughput
			maxThroughputConcurrency = concurrencyData[i]
		}
	}

	// Calculate scaling factor (actual vs ideal)
	baseThroughput := throughputData[0]
	baseConcurrency := float64(concurrencyData[0]) // > 0 by the filter above
	idealThroughput := baseThroughput * float64(maxThroughputConcurrency) / baseConcurrency
	if idealThroughput <= 0 {
		// A zero baseline would make the ratio Inf or NaN.
		return ScalabilityAnalysis{
			LinearScaling:      false,
			ScalingFactor:      0.0,
			OptimalConcurrency: maxThroughputConcurrency,
		}
	}
	actualScalingFactor := maxThroughput / idealThroughput

	return ScalabilityAnalysis{
		// Determine if scaling is linear (within 20% of ideal)
		LinearScaling:      actualScalingFactor >= 0.8,
		ScalingFactor:      actualScalingFactor,
		OptimalConcurrency: maxThroughputConcurrency,
	}
}
// identifyBottlenecks scans the recorded results and reports a high-severity
// issue for every result whose error rate exceeds 5% and a medium-severity
// issue for every result whose P99 latency exceeds 100 ms. A single result can
// produce both. The result is nil when nothing is flagged.
func (bs *BenchmarkSuite) identifyBottlenecks() []PerformanceIssue {
	const (
		maxErrorRate  = 5.0
		maxP99Latency = 100 * time.Millisecond
	)

	var found []PerformanceIssue
	for i := range bs.results {
		r := &bs.results[i]

		if r.ErrorRate > maxErrorRate {
			found = append(found, PerformanceIssue{
				Issue:      fmt.Sprintf("High error rate (%0.2f%%) for %s", r.ErrorRate, r.Transport),
				Severity:   "high",
				Impact:     "Reduced reliability and performance",
				Suggestion: "Check transport configuration and network stability",
			})
		}

		if r.LatencyP99 > maxP99Latency {
			found = append(found, PerformanceIssue{
				Issue:      fmt.Sprintf("High P99 latency (%v) for %s", r.LatencyP99, r.Transport),
				Severity:   "medium",
				Impact:     "Poor user experience for latency-sensitive operations",
				Suggestion: "Consider using faster transport or optimizing message serialization",
			})
		}
	}
	return found
}
// generateRecommendations returns two fixed suggestions: use the transport
// recorded as best in the metrics, and tune the concurrency level to the
// workload.
func (bs *BenchmarkSuite) generateRecommendations() []Recommendation {
	transportAdvice := Recommendation{
		Category:    "Transport Selection",
		Description: fmt.Sprintf("Use %s for best overall performance", bs.metrics.BestTransport),
		Priority:    "high",
		Expected:    "20-50% improvement in throughput",
	}
	concurrencyAdvice := Recommendation{
		Category:    "Concurrency",
		Description: "Optimize concurrency level based on workload characteristics",
		Priority:    "medium",
		Expected:    "10-30% improvement in resource utilization",
	}
	return []Recommendation{transportAdvice, concurrencyAdvice}
}

591
orig/pkg/transport/dlq.go Normal file
View File

@@ -0,0 +1,591 @@
package transport
import (
"context"
"fmt"
"sort"
"sync"
"time"
)
// DeadLetterQueue stores messages that failed processing and supports
// retrying them, expiring them and reprocessing them on demand.
type DeadLetterQueue struct {
	messages      map[string][]*DLQMessage // queued messages, keyed by topic
	config        DLQConfig                // effective configuration, defaults applied
	metrics       DLQMetrics               // running counters
	reprocessor   MessageReprocessor       // performs reprocessing; may be nil
	mu            sync.RWMutex             // taken by the exported methods
	cleanupTicker *time.Ticker             // drives the periodic cleanup goroutine
	ctx           context.Context          // cancelled by Stop; ends the background goroutines
	cancel        context.CancelFunc       // cancels ctx
}
// DLQMessage is one failed message held in the dead letter queue, together
// with its retry bookkeeping.
type DLQMessage struct {
	// ID identifies the entry; it is built from the topic and the enqueue time.
	ID string
	// OriginalMessage is the message that failed.
	OriginalMessage *Message
	// Topic is the topic the message was queued under.
	Topic string

	// FirstFailed is the time the message entered the queue.
	FirstFailed time.Time
	// LastAttempt is the time of the most recent (re)processing attempt.
	LastAttempt time.Time
	// AttemptCount starts at 1 and grows with every failed reprocess attempt.
	AttemptCount int
	// MaxRetries is the attempt limit copied from the queue configuration.
	MaxRetries int
	// FailureReason is the most recent failure description.
	FailureReason string

	// RetryDelay and NextRetry schedule the next attempt; they are not set for
	// messages that were queued as permanent failures.
	RetryDelay time.Duration
	NextRetry  time.Time

	// Metadata is an initially empty bag for additional attributes.
	Metadata map[string]interface{}
	// Permanent marks a message that will not be reprocessed.
	Permanent bool
}
// DLQConfig controls the behavior of a DeadLetterQueue. NewDeadLetterQueue
// replaces zero values with defaults for every field except AutoReprocess and
// PermanentFailures.
type DLQConfig struct {
	// MaxMessages caps the queue size; the oldest message is evicted when a
	// new one arrives at the cap.
	MaxMessages int
	// MaxRetries is the attempt limit given to each queued message.
	MaxRetries int
	// RetentionTime is the maximum age used by the periodic cleanup.
	RetentionTime time.Duration
	// AutoReprocess enables the background reprocessing goroutine.
	AutoReprocess bool
	// ReprocessInterval is the tick interval of the background goroutines.
	ReprocessInterval time.Duration

	// BackoffStrategy selects how retry delays are calculated.
	BackoffStrategy BackoffStrategy
	// InitialRetryDelay is the base delay for every strategy.
	InitialRetryDelay time.Duration
	// MaxRetryDelay caps the linear and exponential delays.
	MaxRetryDelay time.Duration
	// BackoffMultiplier is the base of the exponential strategy.
	BackoffMultiplier float64

	// PermanentFailures lists error patterns that mark messages as permanently
	// failed. A pattern matches an identical reason; a pattern ending in '*'
	// matches every reason that starts with the text before the '*'.
	PermanentFailures []string
	// ReprocessBatchSize limits how many messages one reprocess tick handles.
	ReprocessBatchSize int
}
// BackoffStrategy names a method for calculating retry delays.
type BackoffStrategy string

// Supported backoff strategies.
const (
	// BackoffFixed always waits the initial retry delay.
	BackoffFixed BackoffStrategy = "fixed"
	// BackoffLinear waits the initial delay multiplied by the attempt count.
	BackoffLinear BackoffStrategy = "linear"
	// BackoffExponential multiplies the delay by the backoff multiplier for
	// every attempt after the first.
	BackoffExponential BackoffStrategy = "exponential"
	// BackoffCustom has no dedicated delay calculation in this file and is
	// handled like an unknown strategy.
	BackoffCustom BackoffStrategy = "custom"
)
// DLQMetrics holds the running statistics of a dead letter queue.
type DLQMetrics struct {
	// Lifetime counters.
	MessagesAdded       int64
	MessagesReprocessed int64
	MessagesExpired     int64
	MessagesPermanent   int64
	ReprocessSuccesses  int64
	ReprocessFailures   int64

	// QueueSize is the current number of queued messages.
	QueueSize int64
	// OldestMessage is the earliest FirstFailed time among the queued
	// messages, or the zero time when the queue is empty.
	OldestMessage time.Time
}
// MessageReprocessor is the strategy a DeadLetterQueue delegates reprocessing
// to.
type MessageReprocessor interface {
	// Reprocess performs one reprocessing attempt for msg. The queue treats a
	// nil error as success and removes the message.
	Reprocess(ctx context.Context, msg *DLQMessage) error

	// CanReprocess reports whether msg is eligible for an attempt; the queue
	// does not call Reprocess when it returns false.
	CanReprocess(msg *DLQMessage) bool

	// ShouldRetry takes a message and the error of a failed attempt.
	// NOTE(review): it is not called in the visible part of this file —
	// confirm its intended contract against the callers.
	ShouldRetry(msg *DLQMessage, err error) bool
}
// DefaultMessageReprocessor is a basic MessageReprocessor that hands the
// original message back to a publisher.
type DefaultMessageReprocessor struct {
	publisher MessagePublisher // target for republished messages; may be nil
}
// MessagePublisher is the minimal publishing capability needed to republish
// a message.
type MessagePublisher interface {
	// Publish sends msg and returns an error when that fails.
	Publish(ctx context.Context, msg *Message) error
}
// NewDeadLetterQueue creates a dead letter queue from config and starts its
// background work: the cleanup goroutine always, the reprocessing goroutine
// only when config.AutoReprocess is set. Call Stop to end them.
//
// Zero-valued settings are replaced with defaults: 10000 messages, 3 retries,
// 24 h retention, a 5 min reprocess interval, retry delays from 1 min up to
// 1 h, a backoff multiplier of 2 with the exponential strategy, and a
// reprocess batch size of 10.
func NewDeadLetterQueue(config DLQConfig) *DeadLetterQueue {
	// Fill in defaults on the local copy before it is stored.
	if config.MaxMessages == 0 {
		config.MaxMessages = 10000
	}
	if config.MaxRetries == 0 {
		config.MaxRetries = 3
	}
	if config.ReprocessBatchSize == 0 {
		config.ReprocessBatchSize = 10
	}
	if config.RetentionTime == 0 {
		config.RetentionTime = 24 * time.Hour
	}
	if config.ReprocessInterval == 0 {
		config.ReprocessInterval = 5 * time.Minute
	}
	if config.InitialRetryDelay == 0 {
		config.InitialRetryDelay = time.Minute
	}
	if config.MaxRetryDelay == 0 {
		config.MaxRetryDelay = time.Hour
	}
	if config.BackoffMultiplier == 0 {
		config.BackoffMultiplier = 2.0
	}
	if config.BackoffStrategy == "" {
		config.BackoffStrategy = BackoffExponential
	}

	ctx, cancel := context.WithCancel(context.Background())
	queue := &DeadLetterQueue{
		messages: make(map[string][]*DLQMessage),
		config:   config,
		ctx:      ctx,
		cancel:   cancel,
	}

	queue.startCleanupRoutine()
	if queue.config.AutoReprocess {
		queue.startReprocessRoutine()
	}
	return queue
}
// AddMessage queues a failed message under topic with the generic failure
// reason "unknown failure".
func (dlq *DeadLetterQueue) AddMessage(topic string, msg *Message) error {
	const defaultReason = "unknown failure"
	return dlq.AddMessageWithReason(topic, msg, defaultReason)
}
// AddMessageWithReason queues a failed message under topic and records reason
// as its failure reason.
//
// When the queue already holds the configured maximum, the oldest message is
// evicted first. A reason that matches one of the configured permanent-failure
// patterns marks the entry as permanent; every other entry gets a retry delay
// and a next-retry time. The method always returns nil.
func (dlq *DeadLetterQueue) AddMessageWithReason(topic string, msg *Message, reason string) error {
	dlq.mu.Lock()
	defer dlq.mu.Unlock()

	// Make room if the queue is full.
	if dlq.getTotalMessageCount() >= dlq.config.MaxMessages {
		dlq.removeOldestMessage()
	}

	entry := &DLQMessage{
		ID:              fmt.Sprintf("dlq_%s_%d", topic, time.Now().UnixNano()),
		OriginalMessage: msg,
		Topic:           topic,
		FirstFailed:     time.Now(),
		LastAttempt:     time.Now(),
		AttemptCount:    1,
		MaxRetries:      dlq.config.MaxRetries,
		FailureReason:   reason,
		Metadata:        make(map[string]interface{}),
		Permanent:       dlq.isPermanentFailure(reason),
	}
	if !entry.Permanent {
		entry.RetryDelay = dlq.calculateRetryDelay(entry)
		entry.NextRetry = time.Now().Add(entry.RetryDelay)
	}

	// Appending to a missing topic creates its slice.
	dlq.messages[topic] = append(dlq.messages[topic], entry)

	dlq.metrics.MessagesAdded++
	dlq.metrics.QueueSize++
	if entry.Permanent {
		dlq.metrics.MessagesPermanent++
	}
	dlq.updateOldestMessage()
	return nil
}
// GetMessages returns the messages queued under topic. The returned slice is
// a copy, so callers may modify it freely; its elements still point at the
// queued messages. An unknown topic yields an empty slice. The error is
// always nil.
func (dlq *DeadLetterQueue) GetMessages(topic string) ([]*DLQMessage, error) {
	dlq.mu.RLock()
	defer dlq.mu.RUnlock()

	queued, ok := dlq.messages[topic]
	if !ok {
		return []*DLQMessage{}, nil
	}

	snapshot := make([]*DLQMessage, 0, len(queued))
	snapshot = append(snapshot, queued...)
	return snapshot, nil
}
// GetAllMessages returns the queued messages of every topic. The map and its
// slices are copies; the slice elements still point at the queued messages.
func (dlq *DeadLetterQueue) GetAllMessages() map[string][]*DLQMessage {
	dlq.mu.RLock()
	defer dlq.mu.RUnlock()

	all := make(map[string][]*DLQMessage)
	for topic, queued := range dlq.messages {
		dup := make([]*DLQMessage, len(queued))
		copy(dup, queued)
		all[topic] = dup
	}
	return all
}
// ReprocessMessage makes one reprocessing attempt for the message with the
// given ID.
//
// On success the message is removed from the queue and nil is returned. On
// failure the attempt is recorded on the message: it becomes permanent once
// its attempt count reaches its retry limit, otherwise a new retry time is
// scheduled; the failure is returned wrapped. An error is also returned when
// the ID is unknown or the message is already marked permanent.
//
// NOTE(review): the write lock stays held while the reprocessor runs, so a
// reprocessor that calls back into this queue would block — confirm that no
// configured reprocessor does.
func (dlq *DeadLetterQueue) ReprocessMessage(messageID string) error {
	dlq.mu.Lock()
	defer dlq.mu.Unlock()

	// Locate the message across all topics.
	var (
		target   *DLQMessage
		topic    string
		position int
	)
search:
	for name, queued := range dlq.messages {
		for i, candidate := range queued {
			if candidate.ID == messageID {
				target, topic, position = candidate, name, i
				break search
			}
		}
	}

	switch {
	case target == nil:
		return fmt.Errorf("message not found: %s", messageID)
	case target.Permanent:
		return fmt.Errorf("message marked as permanent failure: %s", messageID)
	}

	err := dlq.attemptReprocess(target)
	if err == nil {
		// Success - the message leaves the queue.
		dlq.removeMessageByIndex(topic, position)
		dlq.metrics.ReprocessSuccesses++
		dlq.metrics.QueueSize--
		return nil
	}

	// Failure - record the attempt and decide whether another one follows.
	target.AttemptCount++
	target.LastAttempt = time.Now()
	target.FailureReason = err.Error()
	if target.AttemptCount < target.MaxRetries {
		target.RetryDelay = dlq.calculateRetryDelay(target)
		target.NextRetry = time.Now().Add(target.RetryDelay)
	} else {
		target.Permanent = true
		dlq.metrics.MessagesPermanent++
	}
	dlq.metrics.ReprocessFailures++
	return fmt.Errorf("reprocessing failed: %w", err)
}
// PurgeMessages drops every message queued under topic and adjusts the queue
// size and oldest-message metrics. An unknown topic is a no-op. The method
// always returns nil.
func (dlq *DeadLetterQueue) PurgeMessages(topic string) error {
	dlq.mu.Lock()
	defer dlq.mu.Unlock()

	queued, ok := dlq.messages[topic]
	if !ok {
		return nil
	}

	delete(dlq.messages, topic)
	dlq.metrics.QueueSize -= int64(len(queued))
	dlq.updateOldestMessage()
	return nil
}
// PurgeAllMessages empties the queue: every topic is dropped, the queue size
// metric is reset to zero and the oldest-message time is cleared. The method
// always returns nil.
func (dlq *DeadLetterQueue) PurgeAllMessages() error {
	dlq.mu.Lock()
	defer dlq.mu.Unlock()

	dlq.metrics.OldestMessage = time.Time{}
	dlq.metrics.QueueSize = 0
	dlq.messages = make(map[string][]*DLQMessage)
	return nil
}
// GetMessageCount returns the number of messages queued across all topics,
// counted under the read lock.
func (dlq *DeadLetterQueue) GetMessageCount() int {
	dlq.mu.RLock()
	total := dlq.getTotalMessageCount()
	dlq.mu.RUnlock()
	return total
}
// GetMetrics returns a copy of the current queue metrics, read under the read
// lock.
func (dlq *DeadLetterQueue) GetMetrics() DLQMetrics {
	dlq.mu.RLock()
	snapshot := dlq.metrics
	dlq.mu.RUnlock()
	return snapshot
}
// SetReprocessor installs the reprocessor used for later reprocessing
// attempts, replacing any previous one, while holding the write lock.
func (dlq *DeadLetterQueue) SetReprocessor(reprocessor MessageReprocessor) {
	dlq.mu.Lock()
	dlq.reprocessor = reprocessor
	dlq.mu.Unlock()
}
// Cleanup removes every message whose first failure is not newer than maxAge
// ago. Topics left without messages are dropped, and the expired counter, the
// queue size and the oldest-message time are updated. The method always
// returns nil.
func (dlq *DeadLetterQueue) Cleanup(maxAge time.Duration) error {
	dlq.mu.Lock()
	defer dlq.mu.Unlock()

	threshold := time.Now().Add(-maxAge)
	var expired int64

	for topic, queued := range dlq.messages {
		kept := make([]*DLQMessage, 0)
		for _, m := range queued {
			if !m.FirstFailed.After(threshold) {
				expired++
				continue
			}
			kept = append(kept, m)
		}

		if len(kept) == 0 {
			// Remove empty topics
			delete(dlq.messages, topic)
			continue
		}
		dlq.messages[topic] = kept
	}

	dlq.metrics.MessagesExpired += expired
	dlq.metrics.QueueSize -= expired
	dlq.updateOldestMessage()
	return nil
}
// Stop shuts the queue down: it cancels the queue's context, which ends the
// background goroutines, and stops the cleanup ticker if one exists. It does
// not wait for the goroutines to exit. The method always returns nil.
func (dlq *DeadLetterQueue) Stop() error {
	dlq.cancel()
	if ticker := dlq.cleanupTicker; ticker != nil {
		ticker.Stop()
	}
	return nil
}
// Private helper methods

// getTotalMessageCount sums the number of queued messages over all topics. It
// takes no lock itself.
func (dlq *DeadLetterQueue) getTotalMessageCount() int {
	total := 0
	for topic := range dlq.messages {
		total += len(dlq.messages[topic])
	}
	return total
}
// removeOldestMessage evicts the message with the earliest FirstFailed time
// across all topics and decrements the queue size metric. Nothing happens when
// no message with a non-zero FirstFailed time is selected. It takes no lock
// itself.
func (dlq *DeadLetterQueue) removeOldestMessage() {
	var (
		earliest time.Time
		atTopic  string
		atIndex  int
	)

	for topic, queued := range dlq.messages {
		for i := range queued {
			failedAt := queued[i].FirstFailed
			// The zero time doubles as the "nothing selected yet" marker.
			if !earliest.IsZero() && !failedAt.Before(earliest) {
				continue
			}
			earliest, atTopic, atIndex = failedAt, topic, i
		}
	}

	if earliest.IsZero() {
		return
	}
	dlq.removeMessageByIndex(atTopic, atIndex)
	dlq.metrics.QueueSize--
}
// removeMessageByIndex removes the message at position index from topic's
// slice, keeping the order of the remaining messages, and drops the topic
// once it has no messages left. An index outside the slice, including any
// index for an unknown topic, is ignored. It takes no lock itself and does
// not touch the metrics.
func (dlq *DeadLetterQueue) removeMessageByIndex(topic string, index int) {
	messages := dlq.messages[topic]
	if index < 0 || index >= len(messages) {
		return
	}

	last := len(messages) - 1
	copy(messages[index:], messages[index+1:])
	// Clear the vacated tail slot so the backing array does not keep the
	// removed message reachable.
	messages[last] = nil
	messages = messages[:last]

	if len(messages) == 0 {
		delete(dlq.messages, topic)
		return
	}
	dlq.messages[topic] = messages
}
// isPermanentFailure reports whether reason matches one of the configured
// permanent-failure patterns. A pattern matches when it equals reason, or when
// it ends in '*' and reason starts with the text before that '*'.
func (dlq *DeadLetterQueue) isPermanentFailure(reason string) bool {
	for _, pattern := range dlq.config.PermanentFailures {
		if reason == pattern {
			return true
		}

		// Only patterns with a trailing '*' take part in prefix matching.
		n := len(pattern)
		if n == 0 || pattern[n-1] != '*' {
			continue
		}
		prefix := pattern[:n-1]
		if len(reason) >= len(prefix) && reason[:len(prefix)] == prefix {
			return true
		}
	}
	return false
}
// calculateRetryDelay returns the delay before the next attempt for msg,
// according to the configured backoff strategy:
//
//   - fixed: the initial retry delay;
//   - linear: the initial delay times the attempt count, capped at the
//     maximum retry delay;
//   - exponential: the initial delay times multiplier^(attempt count - 1),
//     capped at the maximum retry delay;
//   - anything else (including custom): the initial retry delay.
//
// The cap is checked before the value is converted or multiplied as a
// time.Duration, so large attempt counts or multipliers yield the maximum
// delay instead of an overflowed, possibly negative, duration.
func (dlq *DeadLetterQueue) calculateRetryDelay(msg *DLQMessage) time.Duration {
	initial := dlq.config.InitialRetryDelay
	maxDelay := dlq.config.MaxRetryDelay

	switch dlq.config.BackoffStrategy {
	case BackoffFixed:
		return initial

	case BackoffLinear:
		attempts := time.Duration(msg.AttemptCount)
		// For a positive base this is equivalent to attempts*initial > maxDelay
		// but cannot overflow.
		if initial > 0 && attempts > maxDelay/initial {
			return maxDelay
		}
		delay := attempts * initial
		if delay > maxDelay {
			return maxDelay
		}
		return delay

	case BackoffExponential:
		delay := float64(initial) *
			pow(dlq.config.BackoffMultiplier, float64(msg.AttemptCount-1))
		// Written as a negated <= so that NaN is capped as well.
		if !(delay <= float64(maxDelay)) {
			return maxDelay
		}
		return time.Duration(delay)

	default:
		return initial
	}
}
// attemptReprocess hands msg to the configured reprocessor, passing the
// queue's context. It returns an error without attempting anything when no
// reprocessor is configured or when the reprocessor declines the message;
// otherwise it returns the reprocessor's result.
func (dlq *DeadLetterQueue) attemptReprocess(msg *DLQMessage) error {
	rp := dlq.reprocessor
	switch {
	case rp == nil:
		return fmt.Errorf("no reprocessor configured")
	case !rp.CanReprocess(msg):
		return fmt.Errorf("message cannot be reprocessed")
	}
	return rp.Reprocess(dlq.ctx, msg)
}
// updateOldestMessage recomputes the oldest-message metric as the earliest
// FirstFailed time among all queued messages; an empty queue yields the zero
// time. It takes no lock itself.
func (dlq *DeadLetterQueue) updateOldestMessage() {
	var earliest time.Time
	for _, queued := range dlq.messages {
		for i := range queued {
			failedAt := queued[i].FirstFailed
			// The zero time doubles as the "nothing seen yet" marker.
			if !earliest.IsZero() && !failedAt.Before(earliest) {
				continue
			}
			earliest = failedAt
		}
	}
	dlq.metrics.OldestMessage = earliest
}
// startCleanupRoutine creates dlq.cleanupTicker and launches a goroutine that
// calls Cleanup with the configured RetentionTime on every tick until dlq.ctx
// is cancelled.
//
// Note that the ticker period is config.ReprocessInterval, and that this
// function does not stop the ticker itself when the goroutine exits.
func (dlq *DeadLetterQueue) startCleanupRoutine() {
	dlq.cleanupTicker = time.NewTicker(dlq.config.ReprocessInterval)

	sweep := func() {
		for {
			select {
			case <-dlq.ctx.Done():
				return
			case <-dlq.cleanupTicker.C:
				dlq.Cleanup(dlq.config.RetentionTime)
			}
		}
	}
	go sweep()
}
// startReprocessRoutine launches a goroutine that calls
// processRetryableMessages every config.ReprocessInterval until dlq.ctx is
// cancelled. The ticker is created before the goroutine starts and stopped
// when the goroutine exits.
func (dlq *DeadLetterQueue) startReprocessRoutine() {
	interval := dlq.config.ReprocessInterval

	go func(tick *time.Ticker) {
		defer tick.Stop()
		for {
			select {
			case <-dlq.ctx.Done():
				return
			case <-tick.C:
				dlq.processRetryableMessages()
			}
		}
	}(time.NewTicker(interval))
}
// processRetryableMessages reprocesses the retryable messages whose NextRetry
// time has passed, earliest first, up to config.ReprocessBatchSize messages
// per call.
//
// Everything needed from each message (its NextRetry time and its ID) is
// captured while dlq.mu is held, so no message field is read after the lock
// is released. The lock is dropped before ReprocessMessage is called, exactly
// as before.
func (dlq *DeadLetterQueue) processRetryableMessages() {
	// candidate is a lock-free snapshot of one due message.
	type candidate struct {
		when  time.Time
		retry func()
	}

	now := time.Now()

	dlq.mu.Lock()
	retryable := dlq.getRetryableMessages()
	due := make([]candidate, 0, len(retryable))
	for _, msg := range retryable {
		if !now.After(msg.NextRetry) {
			continue
		}
		id := msg.ID // copied under the lock; the closure never touches msg
		due = append(due, candidate{
			when:  msg.NextRetry,
			retry: func() { dlq.ReprocessMessage(id) },
		})
	}
	dlq.mu.Unlock()

	// Earliest scheduled retry first.
	sort.Slice(due, func(i, j int) bool {
		return due[i].when.Before(due[j].when)
	})

	batchSize := dlq.config.ReprocessBatchSize
	if len(due) < batchSize {
		batchSize = len(due)
	}
	for i := 0; i < batchSize; i++ {
		due[i].retry()
	}
}
// getRetryableMessages collects, across all topics, every message that is not
// marked Permanent and whose AttemptCount is still below its MaxRetries.
// It returns nil when there are none and does no locking of its own.
func (dlq *DeadLetterQueue) getRetryableMessages() []*DLQMessage {
	var eligible []*DLQMessage
	for _, queue := range dlq.messages {
		for i := range queue {
			candidate := queue[i]
			if candidate.Permanent || candidate.AttemptCount >= candidate.MaxRetries {
				continue
			}
			eligible = append(eligible, candidate)
		}
	}
	return eligible
}
// Implementation of DefaultMessageReprocessor
// NewDefaultMessageReprocessor returns a DefaultMessageReprocessor that
// republishes messages through publisher.
func NewDefaultMessageReprocessor(publisher MessagePublisher) *DefaultMessageReprocessor {
	reprocessor := new(DefaultMessageReprocessor)
	reprocessor.publisher = publisher
	return reprocessor
}
// Reprocess republishes msg.OriginalMessage through the configured publisher.
// It returns an error when no publisher is set, otherwise whatever Publish
// returns.
func (r *DefaultMessageReprocessor) Reprocess(ctx context.Context, msg *DLQMessage) error {
	if pub := r.publisher; pub != nil {
		return pub.Publish(ctx, msg.OriginalMessage)
	}
	return fmt.Errorf("no publisher configured")
}
// CanReprocess reports whether msg is eligible for another attempt: it must
// not be marked Permanent and must have attempts left.
func (r *DefaultMessageReprocessor) CanReprocess(msg *DLQMessage) bool {
	if msg.Permanent {
		return false
	}
	return msg.AttemptCount < msg.MaxRetries
}
// ShouldRetry reports whether msg still has retry attempts left. The err
// argument is currently ignored; the decision depends only on the attempt
// counters.
func (r *DefaultMessageReprocessor) ShouldRetry(msg *DLQMessage, err error) bool {
	attemptsLeft := msg.MaxRetries > msg.AttemptCount
	return attemptsLeft
}
// pow returns base raised to the integer part of exp, computed by repeated
// multiplication.
//
// Only the whole-number part of exp is used (it is truncated toward zero), so
// pow(2, 2.9) is 4 and pow(2, 0.5) is 1. A negative exponent yields the
// reciprocal, e.g. pow(2, -1) is 0.5. For whole exponents >= 1 the result is
// the same product of exp factors as before.
func pow(base, exp float64) float64 {
	if exp < 0 {
		return 1 / pow(base, -exp)
	}

	result := 1.0
	for i := 0; i < int(exp); i++ {
		next := result * base
		// Once multiplying no longer changes the value (0, +/-Inf, base == 1)
		// every further iteration is a no-op, so stop early.
		if next == result {
			break
		}
		result = next
	}
	return result
}

View File

@@ -0,0 +1,612 @@
package transport
import (
"context"
"fmt"
"sync"
"time"
)
// FailoverManager handles transport failover and redundancy.
//
// It keeps a set of registered transports, designates one of them as primary,
// and runs two background goroutines (started by NewFailoverManager) that
// health-check the transports and switch the primary when needed.
type FailoverManager struct {
	// transports holds every registered transport, keyed by its ID.
	transports map[string]*ManagedTransport
	// primaryTransport is the ID of the active transport; "" means none.
	primaryTransport string
	// backupTransports lists the IDs considered as failover targets.
	backupTransports []string
	failoverPolicy FailoverPolicy
	// healthChecker is consulted by performHealthChecks; replaceable via
	// SetHealthChecker.
	healthChecker HealthChecker
	// circuitBreaker wraps every Send.
	circuitBreaker *CircuitBreaker
	// mu is taken by the exported methods and the background routines before
	// they touch the fields above or metrics.
	mu sync.RWMutex
	// ctx is cancelled by Stop and ends the background goroutines.
	ctx context.Context
	cancel context.CancelFunc
	metrics FailoverMetrics
	// notifications carries FailoverEvents to consumers; it is created with a
	// buffer of 100 and events are dropped when it is full.
	notifications chan FailoverEvent
}
// ManagedTransport wraps a transport with management metadata.
type ManagedTransport struct {
	// Transport is the wrapped transport implementation.
	Transport Transport
	ID string
	Name string
	// Priority ranks transports when choosing a primary; larger wins.
	Priority int
	// Status is recomputed on every health check.
	Status TransportStatus
	// LastHealthCheck is set at registration and after every health check.
	LastHealthCheck time.Time
	// FailureCount counts consecutive failed health checks; a successful
	// check resets it to zero.
	FailureCount int
	// LastFailure is the time of the most recent failed health check.
	LastFailure time.Time
	Config TransportConfig
	// Metrics is not written by the failover code visible in this file.
	Metrics TransportMetrics
}
// TransportStatus represents the current status of a transport.
type TransportStatus string

// Status values assigned by the failover manager's health checks.
const (
	// StatusHealthy: the last health check succeeded.
	StatusHealthy TransportStatus = "healthy"
	// StatusDegraded: health checks are failing but the failure threshold has
	// not been reached; the transport is still handed out as usable.
	StatusDegraded TransportStatus = "degraded"
	// StatusUnhealthy: consecutive failures reached the failure threshold.
	StatusUnhealthy TransportStatus = "unhealthy"
	// StatusDisabled is declared but never assigned in this file.
	StatusDisabled TransportStatus = "disabled"
)
// FailoverPolicy defines when and how to failover.
//
// FailureThreshold, RetryInterval and MaxRetries are also copied into the
// manager's circuit breaker configuration by NewFailoverManager
// (RetryInterval becomes the breaker's RecoveryTimeout). FailoverTimeout and
// RequireAllHealthy are not read by the code in this file.
type FailoverPolicy struct {
	FailureThreshold int // Number of failures before marking unhealthy
	HealthCheckInterval time.Duration // How often to check health
	FailoverTimeout time.Duration // Timeout for failover operations
	RetryInterval time.Duration // Interval between retry attempts
	MaxRetries int // Maximum retry attempts
	AutoFailback bool // Whether to automatically failback to primary
	FailbackDelay time.Duration // Delay before attempting failback
	RequireAllHealthy bool // Whether all transports must be healthy
}
// FailoverMetrics tracks failover statistics.
//
// Within this file, switchPrimary maintains the failover/failback counters,
// timestamps, FailoverDuration and CurrentTransport, and the monitor loop
// increments HealthCheckFailures when a failover or failback attempt returns
// an error. FailoverSuccessRate and CircuitBreakerTrips are not updated by
// the code visible here.
type FailoverMetrics struct {
	TotalFailovers int64 `json:"total_failovers"`
	TotalFailbacks int64 `json:"total_failbacks"`
	// CurrentTransport is the ID of the primary after the last switch.
	CurrentTransport string `json:"current_transport"`
	LastFailover time.Time `json:"last_failover"`
	LastFailback time.Time `json:"last_failback"`
	// FailoverDuration is the time the most recent switch took.
	FailoverDuration time.Duration `json:"failover_duration"`
	FailoverSuccessRate float64 `json:"failover_success_rate"`
	HealthCheckFailures int64 `json:"health_check_failures"`
	CircuitBreakerTrips int64 `json:"circuit_breaker_trips"`
}
// FailoverEvent represents a failover-related event delivered on the
// manager's notification channel.
type FailoverEvent struct {
	Type FailoverEventType `json:"type"`
	// FromTransport is the previous primary for failover/failback events; it
	// is left empty for health-check events.
	FromTransport string `json:"from_transport"`
	// ToTransport is the new primary, or the transport whose status changed
	// for health-check events.
	ToTransport string `json:"to_transport"`
	// Reason is a human-readable explanation of the event.
	Reason string `json:"reason"`
	Timestamp time.Time `json:"timestamp"`
	// Success is true for completed switches; for health-check events it is
	// true only when the new status is healthy.
	Success bool `json:"success"`
	Duration time.Duration `json:"duration"`
}
// FailoverEventType defines types of failover events.
type FailoverEventType string

// Event types. Only EventFailover, EventFailback and EventHealthCheck are
// emitted by the code in this file; the others are declared for consumers.
const (
	// EventFailover: the primary was switched (automatic or manual failover).
	EventFailover FailoverEventType = "failover"
	// EventFailback: the primary was switched by an automatic failback.
	EventFailback FailoverEventType = "failback"
	// EventHealthCheck: a transport's status changed during a health check.
	EventHealthCheck FailoverEventType = "health_check"
	EventCircuitBreak FailoverEventType = "circuit_break"
	EventRecovery FailoverEventType = "recovery"
)
// HealthChecker interface for custom health checking logic.
// Install an implementation with FailoverManager.SetHealthChecker.
type HealthChecker interface {
	// CheckHealth reports whether transport is currently healthy. The
	// failover manager treats a non-nil error the same as an unhealthy result.
	CheckHealth(ctx context.Context, transport Transport) (bool, error)
	// GetHealthScore returns a numeric health score for transport. The
	// default implementation returns 1.0, 0.5 or 0.0.
	GetHealthScore(transport Transport) float64
}
// NewFailoverManager builds a FailoverManager governed by policy.
//
// The manager starts with no transports, a DefaultHealthChecker, a circuit
// breaker derived from policy (RetryInterval is used as the breaker's
// recovery timeout) and a notification channel buffered to 100 events. Two
// background goroutines, the health-check loop and the failover monitor, are
// started before returning; call Stop to end them.
func NewFailoverManager(policy FailoverPolicy) *FailoverManager {
	ctx, cancel := context.WithCancel(context.Background())

	breakerConfig := CircuitBreakerConfig{
		FailureThreshold: policy.FailureThreshold,
		RecoveryTimeout:  policy.RetryInterval,
		MaxRetries:       policy.MaxRetries,
	}

	fm := new(FailoverManager)
	fm.transports = make(map[string]*ManagedTransport)
	fm.failoverPolicy = policy
	fm.healthChecker = NewDefaultHealthChecker()
	fm.circuitBreaker = NewCircuitBreaker(breakerConfig)
	fm.ctx, fm.cancel = ctx, cancel
	fm.notifications = make(chan FailoverEvent, 100)

	// Background routines run until fm.ctx is cancelled.
	go fm.healthCheckLoop()
	go fm.failoverMonitorLoop()

	return fm
}
// RegisterTransport adds a transport to the failover manager, or replaces the
// transport previously registered under the same id.
//
// The new transport starts out healthy. It becomes the primary when there is
// no primary yet or when its priority is strictly higher than the current
// primary's; in that case the previous primary is moved to the backup list
// rather than being dropped from it. Otherwise the transport is registered as
// a backup. Re-registering an id never leaves duplicate backup entries, and
// re-registering the current primary keeps it primary.
//
// An error is returned for an empty id (the empty string is reserved to mean
// "no primary") or a nil transport.
func (fm *FailoverManager) RegisterTransport(id, name string, transport Transport, priority int, config TransportConfig) error {
	if id == "" {
		return fmt.Errorf("transport id must not be empty")
	}
	if transport == nil {
		return fmt.Errorf("transport must not be nil: %s", id)
	}

	fm.mu.Lock()
	defer fm.mu.Unlock()

	fm.transports[id] = &ManagedTransport{
		Transport:       transport,
		ID:              id,
		Name:            name,
		Priority:        priority,
		Status:          StatusHealthy,
		LastHealthCheck: time.Now(),
		Config:          config,
	}

	// Drop any backup entry left over from an earlier registration of id.
	backups := fm.backupTransports[:0]
	for _, backupID := range fm.backupTransports {
		if backupID != id {
			backups = append(backups, backupID)
		}
	}
	fm.backupTransports = backups

	if id == fm.primaryTransport {
		return nil // re-registration of the current primary
	}

	current, hasPrimary := fm.transports[fm.primaryTransport]
	if hasPrimary && priority <= current.Priority {
		fm.backupTransports = append(fm.backupTransports, id)
		return nil
	}

	// The new transport takes over; keep the old primary available as backup.
	if hasPrimary {
		fm.backupTransports = append(fm.backupTransports, fm.primaryTransport)
	}
	fm.primaryTransport = id
	return nil
}
// UnregisterTransport removes a transport from the failover manager.
//
// If the removed transport was the primary, a new primary is chosen with
// selectNewPrimary. The backup list is then rebuilt so that it contains
// neither the removed transport nor the (possibly newly promoted) primary;
// leaving the promoted transport in the list would let a later failover
// "switch" to the transport that is already active.
//
// It returns an error when id is not registered.
func (fm *FailoverManager) UnregisterTransport(id string) error {
	fm.mu.Lock()
	defer fm.mu.Unlock()

	if _, exists := fm.transports[id]; !exists {
		return fmt.Errorf("transport not found: %s", id)
	}
	delete(fm.transports, id)

	if fm.primaryTransport == id {
		fm.selectNewPrimary()
	}

	backups := fm.backupTransports[:0]
	for _, backupID := range fm.backupTransports {
		if backupID == id || backupID == fm.primaryTransport {
			continue
		}
		backups = append(backups, backupID)
	}
	fm.backupTransports = backups

	return nil
}
// GetActiveTransport returns the currently active transport.
//
// When the primary is healthy or degraded it is returned directly. Otherwise
// a failover is attempted and the resulting primary is returned. An error is
// returned when no primary is set, when the primary ID is not registered, or
// when the failover fails.
//
// The read lock is released before performFailover is called: that method
// takes the write lock itself, and sync.RWMutex cannot be upgraded or locked
// recursively, so calling it with the read lock held would block forever.
func (fm *FailoverManager) GetActiveTransport() (Transport, error) {
	fm.mu.RLock()
	primaryID := fm.primaryTransport
	managed, exists := fm.transports[primaryID]
	var (
		active Transport
		usable bool
	)
	if exists {
		active = managed.Transport
		usable = managed.Status == StatusHealthy || managed.Status == StatusDegraded
	}
	fm.mu.RUnlock()

	if primaryID == "" {
		return nil, fmt.Errorf("no active transport available")
	}
	if !exists {
		return nil, fmt.Errorf("primary transport not found: %s", primaryID)
	}
	if usable {
		return active, nil
	}

	// Try to failover to a backup.
	if err := fm.performFailover(); err != nil {
		return nil, fmt.Errorf("failover failed: %w", err)
	}

	// Return the new primary after failover.
	fm.mu.RLock()
	defer fm.mu.RUnlock()
	newPrimary, ok := fm.transports[fm.primaryTransport]
	if !ok {
		return nil, fmt.Errorf("primary transport not found: %s", fm.primaryTransport)
	}
	return newPrimary.Transport, nil
}
// Send delivers msg through the transport returned by GetActiveTransport,
// which may fail over to a backup first. The send itself runs inside the
// manager's circuit breaker, so it is rejected while the breaker is open and
// its outcome feeds the breaker's failure count.
func (fm *FailoverManager) Send(ctx context.Context, msg *Message) error {
	active, err := fm.GetActiveTransport()
	if err != nil {
		return fmt.Errorf("no available transport: %w", err)
	}

	deliver := func() error {
		return active.Send(ctx, msg)
	}
	return fm.circuitBreaker.Execute(deliver)
}
// Receive returns the receive channel of the transport that is active at the
// time of the call. The channel stays bound to that transport; it is not
// re-pointed if the primary changes later.
func (fm *FailoverManager) Receive(ctx context.Context) (<-chan *Message, error) {
	active, err := fm.GetActiveTransport()
	if err == nil {
		return active.Receive(ctx)
	}
	return nil, fmt.Errorf("no available transport: %w", err)
}
// ForceFailover makes targetTransportID the primary regardless of
// priorities. The target must be registered and either healthy or degraded;
// otherwise an error is returned and nothing changes.
func (fm *FailoverManager) ForceFailover(targetTransportID string) error {
	fm.mu.Lock()
	defer fm.mu.Unlock()

	target, exists := fm.transports[targetTransportID]
	if !exists {
		return fmt.Errorf("target transport not found: %s", targetTransportID)
	}

	switch target.Status {
	case StatusHealthy, StatusDegraded:
		return fm.switchPrimary(targetTransportID, "manual failover")
	default:
		return fmt.Errorf("target transport is not healthy: %s", target.Status)
	}
}
// GetTransportStatus returns a snapshot mapping every registered transport ID
// to its current status. The returned map is a fresh copy owned by the caller.
func (fm *FailoverManager) GetTransportStatus() map[string]TransportStatus {
	fm.mu.RLock()
	defer fm.mu.RUnlock()

	snapshot := make(map[string]TransportStatus)
	for id := range fm.transports {
		snapshot[id] = fm.transports[id].Status
	}
	return snapshot
}
// GetMetrics returns a copy of the failover metrics, taken under the read
// lock.
func (fm *FailoverManager) GetMetrics() FailoverMetrics {
	fm.mu.RLock()
	snapshot := fm.metrics
	fm.mu.RUnlock()
	return snapshot
}
// GetNotifications exposes the receive side of the failover event channel.
// Every caller gets the same channel, so each event is seen by only one
// receiver. The channel is closed by Stop.
func (fm *FailoverManager) GetNotifications() <-chan FailoverEvent {
	var events <-chan FailoverEvent = fm.notifications
	return events
}
// SetHealthChecker replaces the health checker used by subsequent health
// check rounds. The swap happens under the manager's write lock.
func (fm *FailoverManager) SetHealthChecker(checker HealthChecker) {
	fm.mu.Lock()
	fm.healthChecker = checker
	fm.mu.Unlock()
}
// Stop gracefully stops the failover manager: it cancels the manager's
// context, which ends the background loops, and closes the notification
// channel so that consumers ranging over it terminate.
//
// Stop is idempotent; a second call returns nil instead of panicking on a
// double close. The cancel and close happen under the write lock, so they
// cannot interleave with code that emits events while holding that lock.
// Stop always returns nil.
func (fm *FailoverManager) Stop() error {
	fm.mu.Lock()
	defer fm.mu.Unlock()

	if fm.ctx.Err() != nil {
		return nil // already stopped
	}

	fm.cancel()
	close(fm.notifications)
	return nil
}
// Private methods
// healthCheckLoop runs performHealthChecks once per HealthCheckInterval until
// the manager's context is cancelled. It is started as a goroutine by
// NewFailoverManager.
//
// time.NewTicker panics on a non-positive duration, which would crash the
// whole process from this background goroutine when the policy leaves
// HealthCheckInterval unset. In that case a fallback interval is used.
func (fm *FailoverManager) healthCheckLoop() {
	const fallbackInterval = 30 * time.Second

	interval := fm.failoverPolicy.HealthCheckInterval
	if interval <= 0 {
		interval = fallbackInterval
	}

	ticker := time.NewTicker(interval)
	defer ticker.Stop()

	for {
		select {
		case <-fm.ctx.Done():
			return
		case <-ticker.C:
			fm.performHealthChecks()
		}
	}
}
// failoverMonitorLoop checks once per second whether a failover or failback
// is due and performs it. It is started as a goroutine by NewFailoverManager
// and returns as soon as the manager's context is cancelled.
//
// A failed attempt is counted in metrics.HealthCheckFailures. The counter is
// updated under the write lock because GetMetrics reads it concurrently. A
// ticker inside the select replaces the former sleep so that cancellation is
// not delayed by up to a second.
func (fm *FailoverManager) failoverMonitorLoop() {
	ticker := time.NewTicker(time.Second)
	defer ticker.Stop()

	recordFailure := func() {
		fm.mu.Lock()
		fm.metrics.HealthCheckFailures++
		fm.mu.Unlock()
	}

	for {
		select {
		case <-fm.ctx.Done():
			return
		case <-ticker.C:
			if fm.shouldPerformFailover() {
				if err := fm.performFailover(); err != nil {
					recordFailure()
				}
			}
			if fm.shouldPerformFailback() {
				if err := fm.performFailback(); err != nil {
					recordFailure()
				}
			}
		}
	}
}
// performHealthChecks runs the configured health checker against every
// registered transport and updates its bookkeeping.
//
// A successful check resets FailureCount and marks the transport healthy. A
// failed check (error or unhealthy result) increments FailureCount and marks
// the transport degraded, or unhealthy once the count reaches the policy's
// FailureThreshold. Whenever a transport's status changes, an
// EventHealthCheck notification is emitted. The write lock is held for the
// whole round, including the health-check calls themselves.
func (fm *FailoverManager) performHealthChecks() {
	fm.mu.Lock()
	defer fm.mu.Unlock()

	for id, managed := range fm.transports {
		ok, err := fm.healthChecker.CheckHealth(fm.ctx, managed.Transport)
		managed.LastHealthCheck = time.Now()
		before := managed.Status

		if err == nil && ok {
			// Reset failure count on successful health check.
			managed.FailureCount = 0
			managed.Status = StatusHealthy
		} else {
			managed.FailureCount++
			managed.LastFailure = time.Now()
			managed.Status = StatusDegraded
			if managed.FailureCount >= fm.failoverPolicy.FailureThreshold {
				managed.Status = StatusUnhealthy
			}
		}

		if before == managed.Status {
			continue
		}

		// Notify status change.
		fm.notifyEvent(FailoverEvent{
			Type:        EventHealthCheck,
			ToTransport: id,
			Reason:      fmt.Sprintf("status changed from %s to %s", before, managed.Status),
			Timestamp:   time.Now(),
			Success:     managed.Status == StatusHealthy,
		})
	}
}
// shouldPerformFailover reports whether the current primary is marked
// unhealthy. It returns false when no primary is set.
func (fm *FailoverManager) shouldPerformFailover() bool {
	fm.mu.RLock()
	defer fm.mu.RUnlock()

	if fm.primaryTransport != "" {
		return fm.transports[fm.primaryTransport].Status == StatusUnhealthy
	}
	return false
}
// shouldPerformFailback reports whether automatic failback is enabled and a
// healthy transport other than the current primary is the best choice.
//
// The best choice is the healthy transport with the highest priority.
// Transports with a priority of zero or below are eligible. On equal priority
// the current primary wins, then the lexicographically smaller ID, so the
// answer does not depend on map iteration order and two equal-priority
// transports cannot trigger endless back-and-forth switching.
func (fm *FailoverManager) shouldPerformFailback() bool {
	if !fm.failoverPolicy.AutoFailback {
		return false
	}

	fm.mu.RLock()
	defer fm.mu.RUnlock()

	var (
		bestID string
		best   *ManagedTransport
	)
	for id, candidate := range fm.transports {
		if candidate.Status != StatusHealthy {
			continue
		}
		switch {
		case best == nil, candidate.Priority > best.Priority:
			bestID, best = id, candidate
		case candidate.Priority < best.Priority:
			// keep the current best
		case id == fm.primaryTransport:
			bestID, best = id, candidate
		case bestID != fm.primaryTransport && id < bestID:
			bestID, best = id, candidate
		}
	}

	// Failback only if a different transport is the preferred one.
	return best != nil && bestID != fm.primaryTransport
}
// performFailover promotes the best available backup to primary.
//
// If the current primary is already healthy or degraded by the time the lock
// is obtained (for example because a concurrent caller has just failed over),
// nothing is done and nil is returned. Otherwise the healthy or degraded
// backup with the highest priority is chosen; backups with a priority of zero
// or below are eligible, and the current primary and IDs that are no longer
// registered are skipped. An error is returned when no usable backup exists.
//
// The method takes the write lock itself and must not be called with fm.mu
// held.
func (fm *FailoverManager) performFailover() error {
	fm.mu.Lock()
	defer fm.mu.Unlock()

	if current, ok := fm.transports[fm.primaryTransport]; ok {
		if current.Status == StatusHealthy || current.Status == StatusDegraded {
			return nil
		}
	}

	// Find the best backup transport.
	var (
		bestID string
		best   *ManagedTransport
	)
	for _, backupID := range fm.backupTransports {
		if backupID == fm.primaryTransport {
			continue
		}
		backup, ok := fm.transports[backupID]
		if !ok {
			continue
		}
		if backup.Status != StatusHealthy && backup.Status != StatusDegraded {
			continue
		}
		if best == nil || backup.Priority > best.Priority {
			bestID, best = backupID, backup
		}
	}

	if best == nil {
		return fmt.Errorf("no healthy backup transport available")
	}
	return fm.switchPrimary(bestID, "automatic failover")
}
// performFailback switches the primary to the preferred healthy transport
// when that is not already the primary and the policy's FailbackDelay has
// elapsed since the last failover. It returns nil when no switch is needed.
//
// The preferred transport is the healthy one with the highest priority.
// Transports with a priority of zero or below are eligible. On equal priority
// the current primary wins, then the lexicographically smaller ID, so the
// choice is deterministic and equal-priority transports do not flap.
//
// The method takes the write lock itself and must not be called with fm.mu
// held.
func (fm *FailoverManager) performFailback() error {
	fm.mu.Lock()
	defer fm.mu.Unlock()

	var (
		bestID string
		best   *ManagedTransport
	)
	for id, candidate := range fm.transports {
		if candidate.Status != StatusHealthy {
			continue
		}
		switch {
		case best == nil, candidate.Priority > best.Priority:
			bestID, best = id, candidate
		case candidate.Priority < best.Priority:
			// keep the current best
		case id == fm.primaryTransport:
			bestID, best = id, candidate
		case bestID != fm.primaryTransport && id < bestID:
			bestID, best = id, candidate
		}
	}

	if best == nil || bestID == fm.primaryTransport {
		return nil // No failback needed
	}

	// Wait for failback delay.
	if time.Since(fm.metrics.LastFailover) < fm.failoverPolicy.FailbackDelay {
		return nil
	}

	return fm.switchPrimary(bestID, "automatic failback")
}
// switchPrimary makes newPrimaryID the primary, rebuilds the backup list from
// all other registered transports, updates the metrics and emits a
// notification.
//
// The reason string doubles as a mode switch: "automatic failback" is counted
// and reported as a failback, any other reason as a failover. Metrics are
// only touched when the primary actually changes; the notification is sent
// either way. Every caller in this file holds the write lock when calling.
// The function always returns nil.
func (fm *FailoverManager) switchPrimary(newPrimaryID, reason string) error {
	began := time.Now()
	previous := fm.primaryTransport
	isFailback := reason == "automatic failback"

	// Install the new primary; everything else becomes a backup.
	fm.primaryTransport = newPrimaryID
	backups := make([]string, 0)
	for id := range fm.transports {
		if id == newPrimaryID {
			continue
		}
		backups = append(backups, id)
	}
	fm.backupTransports = backups

	elapsed := time.Since(began)
	if previous != newPrimaryID {
		if isFailback {
			fm.metrics.TotalFailbacks++
			fm.metrics.LastFailback = time.Now()
		} else {
			fm.metrics.TotalFailovers++
			fm.metrics.LastFailover = time.Now()
		}
		fm.metrics.FailoverDuration = elapsed
		fm.metrics.CurrentTransport = newPrimaryID
	}

	kind := EventFailover
	if isFailback {
		kind = EventFailback
	}
	fm.notifyEvent(FailoverEvent{
		Type:          kind,
		FromTransport: previous,
		ToTransport:   newPrimaryID,
		Reason:        reason,
		Timestamp:     time.Now(),
		Success:       true,
		Duration:      elapsed,
	})

	return nil
}
// selectNewPrimary picks a new primary from the registered transports and
// stores its ID in fm.primaryTransport ("" when nothing usable is left).
//
// Healthy transports are preferred over degraded ones. Degraded transports
// are accepted as a last resort because the rest of the manager already
// treats them as usable. Within the same status the highest priority wins,
// including priorities of zero or below, and ties go to the lexicographically
// smaller ID so the result does not depend on map iteration order.
// The caller must hold the write lock.
func (fm *FailoverManager) selectNewPrimary() {
	rank := func(status TransportStatus) int {
		switch status {
		case StatusHealthy:
			return 2
		case StatusDegraded:
			return 1
		default:
			return 0
		}
	}

	var (
		found        bool
		bestID       string
		bestRank     int
		bestPriority int
	)
	for id, candidate := range fm.transports {
		r := rank(candidate.Status)
		if r == 0 {
			continue
		}

		better := !found
		switch {
		case better:
		case r != bestRank:
			better = r > bestRank
		case candidate.Priority != bestPriority:
			better = candidate.Priority > bestPriority
		default:
			better = id < bestID
		}

		if better {
			found = true
			bestID, bestRank, bestPriority = id, r, candidate.Priority
		}
	}

	fm.primaryTransport = bestID
}
// notifyEvent offers event to the notification channel without blocking; the
// event is dropped when the channel's buffer is full.
//
// Once the manager's context has been cancelled (Stop cancels it before
// closing the channel) events are discarded, because sending on a closed
// channel panics even inside a select with a default case.
func (fm *FailoverManager) notifyEvent(event FailoverEvent) {
	if fm.ctx.Err() != nil {
		return
	}

	select {
	case fm.notifications <- event:
	default:
		// Channel full, drop event.
	}
}
// DefaultHealthChecker is the stateless health checker a FailoverManager uses
// unless another one is installed. It derives its answers from the
// transport's own Health report.
type DefaultHealthChecker struct{}

// NewDefaultHealthChecker returns a ready-to-use DefaultHealthChecker.
func NewDefaultHealthChecker() *DefaultHealthChecker {
	return new(DefaultHealthChecker)
}
// CheckHealth reports the transport as healthy exactly when its Health()
// status string is "healthy". The context is not used and the returned error
// is always nil.
func (dhc *DefaultHealthChecker) CheckHealth(ctx context.Context, transport Transport) (bool, error) {
	return transport.Health().Status == "healthy", nil
}
// GetHealthScore maps the transport's Health() status string to a score:
// 1.0 for "healthy", 0.5 for "degraded" and 0.0 for anything else.
func (dhc *DefaultHealthChecker) GetHealthScore(transport Transport) float64 {
	status := transport.Health().Status
	if status == "healthy" {
		return 1.0
	}
	if status == "degraded" {
		return 0.5
	}
	return 0.0
}
// CircuitBreaker implements the circuit breaker pattern for transport
// operations.
//
// It starts closed. Every failed operation increments a failure counter, and
// once the counter reaches config.FailureThreshold the breaker opens and
// rejects calls. After config.RecoveryTimeout has passed since the last
// failure, the next call is let through in the half-open state: a success
// closes the breaker and resets the counter, a failure re-opens it.
type CircuitBreaker struct {
	config CircuitBreakerConfig
	state  CircuitBreakerState
	// failureCount is the number of failures since the last success.
	failureCount int
	// lastFailure is when the most recent failure was recorded.
	lastFailure time.Time
	// mu guards state, failureCount and lastFailure.
	mu sync.Mutex
}

// CircuitBreakerConfig configures a CircuitBreaker.
type CircuitBreakerConfig struct {
	// FailureThreshold is the failure count at which the breaker opens.
	FailureThreshold int
	// RecoveryTimeout is how long an open breaker rejects calls after the
	// last failure before allowing a trial call.
	RecoveryTimeout time.Duration
	// MaxRetries is stored but not consulted by the breaker itself.
	MaxRetries int
}

// CircuitBreakerState names the state of a CircuitBreaker.
type CircuitBreakerState string

const (
	// StateClosed: calls pass through normally.
	StateClosed CircuitBreakerState = "closed"
	// StateOpen: calls are rejected until the recovery timeout elapses.
	StateOpen CircuitBreakerState = "open"
	// StateHalfOpen: a trial call is in progress after the timeout.
	StateHalfOpen CircuitBreakerState = "half_open"
)

// NewCircuitBreaker returns a closed CircuitBreaker using config.
func NewCircuitBreaker(config CircuitBreakerConfig) *CircuitBreaker {
	return &CircuitBreaker{
		config: config,
		state:  StateClosed,
	}
}

// Execute runs operation unless the breaker is open and still within its
// recovery timeout, in which case it returns an error without calling
// operation. The operation's own error is returned unchanged and recorded as
// a failure; a nil error is recorded as a success.
//
// The mutex is held only while the state is inspected and while the outcome
// is recorded, not while operation runs. A slow operation therefore neither
// serialises other callers nor blocks GetState, and operation may itself call
// GetState without deadlocking.
func (cb *CircuitBreaker) Execute(operation func() error) error {
	cb.mu.Lock()
	if cb.state == StateOpen {
		if time.Since(cb.lastFailure) < cb.config.RecoveryTimeout {
			cb.mu.Unlock()
			return fmt.Errorf("circuit breaker is open")
		}
		// Recovery timeout elapsed: let this call through as a probe.
		cb.state = StateHalfOpen
	}
	cb.mu.Unlock()

	err := operation()

	cb.mu.Lock()
	defer cb.mu.Unlock()
	if err != nil {
		cb.onFailure()
		return err
	}
	cb.onSuccess()
	return nil
}

// onFailure records a failed operation and opens the breaker once the
// failure threshold is reached. The caller must hold cb.mu.
func (cb *CircuitBreaker) onFailure() {
	cb.failureCount++
	cb.lastFailure = time.Now()
	if cb.failureCount >= cb.config.FailureThreshold {
		cb.state = StateOpen
	}
}

// onSuccess records a successful operation: the failure counter is reset and
// the breaker is closed. The caller must hold cb.mu.
func (cb *CircuitBreaker) onSuccess() {
	cb.failureCount = 0
	cb.state = StateClosed
}

// GetState returns the breaker's current state.
func (cb *CircuitBreaker) GetState() CircuitBreakerState {
	cb.mu.Lock()
	defer cb.mu.Unlock()
	return cb.state
}

View File

@@ -0,0 +1,230 @@
package transport
import (
"context"
"fmt"
"sync"
"time"
)
// MemoryTransport implements in-memory message transport for local communication.
// Messages are kept in one buffered Go channel per topic.
type MemoryTransport struct {
	// channels maps a topic name to its message channel; entries are created
	// on the first Send to a topic and discarded by Disconnect.
	channels map[string]chan *Message
	metrics TransportMetrics
	// connected is toggled by Connect and Disconnect.
	connected bool
	// mu guards channels, metrics and connected.
	mu sync.RWMutex
}
// NewMemoryTransport returns a disconnected in-memory transport with no topic
// channels and zeroed metrics. Call Connect before sending or receiving.
func NewMemoryTransport() *MemoryTransport {
	mt := &MemoryTransport{}
	mt.channels = make(map[string]chan *Message)
	mt.metrics = TransportMetrics{}
	return mt
}
// Connect marks the transport as connected and sets the connection metric to
// one. Calling it on an already connected transport is a no-op. The context
// is not used and the returned error is always nil.
func (mt *MemoryTransport) Connect(ctx context.Context) error {
	mt.mu.Lock()
	defer mt.mu.Unlock()

	if !mt.connected {
		mt.connected = true
		mt.metrics.Connections = 1
	}
	return nil
}
// Disconnect closes every topic channel, replaces the channel map with an
// empty one, marks the transport as disconnected and sets the connection
// metric to zero. Calling it on a disconnected transport is a no-op. The
// context is not used and the returned error is always nil.
func (mt *MemoryTransport) Disconnect(ctx context.Context) error {
	mt.mu.Lock()
	defer mt.mu.Unlock()

	if mt.connected {
		// Close all channels.
		for topic := range mt.channels {
			close(mt.channels[topic])
		}
		mt.channels = make(map[string]chan *Message)
		mt.connected = false
		mt.metrics.Connections = 0
	}
	return nil
}
// Send transmits a message through the memory transport by placing it on the
// buffered channel of msg.Topic, creating that channel (capacity 1000) on
// first use.
//
// Send never blocks. It returns an error when msg is nil, when ctx is already
// cancelled, when the transport is not connected, or when the topic's channel
// is full. Every failure is counted in the Errors metric.
//
// The channel lookup and the non-blocking send both happen while mt.mu is
// held. Disconnect closes the channels under the same lock, so a send can no
// longer hit a channel that is being closed concurrently, and the error
// counter is no longer updated without the lock.
func (mt *MemoryTransport) Send(ctx context.Context, msg *Message) error {
	start := time.Now()

	// countError bumps the error metric under the lock.
	countError := func() {
		mt.mu.Lock()
		mt.metrics.Errors++
		mt.mu.Unlock()
	}

	if msg == nil {
		countError()
		return fmt.Errorf("cannot send nil message")
	}
	if err := ctx.Err(); err != nil {
		countError()
		return err
	}

	mt.mu.Lock()
	if !mt.connected {
		mt.metrics.Errors++
		mt.mu.Unlock()
		return fmt.Errorf("transport not connected")
	}

	// Get or create channel for topic.
	ch, exists := mt.channels[msg.Topic]
	if !exists {
		ch = make(chan *Message, 1000) // Buffered channel
		mt.channels[msg.Topic] = ch
	}

	// Non-blocking send, so holding the lock here cannot stall other callers.
	sent := false
	select {
	case ch <- msg:
		sent = true
	default:
		mt.metrics.Errors++
	}
	mt.mu.Unlock()

	if !sent {
		return fmt.Errorf("channel full for topic: %s", msg.Topic)
	}

	// updateSendMetrics takes the lock itself, so call it after unlocking.
	mt.updateSendMetrics(msg, time.Since(start))
	return nil
}
// Receive returns a channel for receiving messages.
//
// The returned channel (buffer 1000) is fed by one forwarding goroutine per
// topic channel that exists when the background goroutine takes its snapshot
// of mt.channels. Topics created by later Send calls are not picked up by an
// already returned channel. The merged channel is closed once every forwarder
// has stopped, which happens when ctx is cancelled or when the topic channels
// are closed; if no topic exists yet it is closed immediately.
//
// It returns an error when the transport is not connected.
func (mt *MemoryTransport) Receive(ctx context.Context) (<-chan *Message, error) {
	mt.mu.RLock()
	defer mt.mu.RUnlock()
	if !mt.connected {
		return nil, fmt.Errorf("transport not connected")
	}
	// Create a merged channel that receives from all topic channels
	merged := make(chan *Message, 1000)
	go func() {
		defer close(merged)
		// Use a wait group to handle multiple topic channels
		var wg sync.WaitGroup
		// The read lock is taken again here because this goroutine runs
		// independently of the caller's (deferred) lock above.
		mt.mu.RLock()
		for topic, ch := range mt.channels {
			wg.Add(1)
			// ch and topic are passed as arguments so each forwarder gets its
			// own copy; topicName is currently unused inside the forwarder.
			go func(topicCh <-chan *Message, topicName string) {
				defer wg.Done()
				for {
					select {
					case msg, ok := <-topicCh:
						if !ok {
							// Topic channel was closed.
							return
						}
						select {
						case merged <- msg:
							mt.updateReceiveMetrics(msg)
						case <-ctx.Done():
							// The message already taken from topicCh is
							// dropped here.
							return
						}
					case <-ctx.Done():
						return
					}
				}
			}(ch, topic)
		}
		mt.mu.RUnlock()
		// Block until all forwarders exit, then the deferred close runs.
		wg.Wait()
	}()
	return merged, nil
}
// Health reports "healthy" while the transport is connected and "unhealthy"
// otherwise, together with the current time, a fixed one-microsecond response
// time and the accumulated error count.
func (mt *MemoryTransport) Health() ComponentHealth {
	mt.mu.RLock()
	defer mt.mu.RUnlock()

	report := ComponentHealth{
		Status:       "unhealthy",
		LastCheck:    time.Now(),
		ResponseTime: time.Microsecond, // Very fast for memory transport
		ErrorCount:   mt.metrics.Errors,
	}
	if mt.connected {
		report.Status = "healthy"
	}
	return report
}
// GetMetrics returns a copy of the transport's metrics taken under the read
// lock, so the caller's value is unaffected by later updates.
func (mt *MemoryTransport) GetMetrics() TransportMetrics {
	mt.mu.RLock()
	defer mt.mu.RUnlock()

	// TransportMetrics holds only plain values, so assignment copies it.
	var snapshot TransportMetrics
	snapshot.BytesSent = mt.metrics.BytesSent
	snapshot.BytesReceived = mt.metrics.BytesReceived
	snapshot.MessagesSent = mt.metrics.MessagesSent
	snapshot.MessagesReceived = mt.metrics.MessagesReceived
	snapshot.Connections = mt.metrics.Connections
	snapshot.Errors = mt.metrics.Errors
	snapshot.Latency = mt.metrics.Latency
	return snapshot
}
// Private helper methods
// updateSendMetrics records one sent message: it increments MessagesSent,
// stores latency as the latest Latency value and adds an estimated message
// size to BytesSent. The estimate is the length of the ID, Topic and Source
// strings plus, when Data is set, the length of its "%v" rendering. It takes
// the write lock itself.
func (mt *MemoryTransport) updateSendMetrics(msg *Message, latency time.Duration) {
	mt.mu.Lock()
	defer mt.mu.Unlock()

	// Estimate message size (simplified).
	estimated := int64(len(msg.ID) + len(msg.Topic) + len(msg.Source))
	if msg.Data != nil {
		rendered := fmt.Sprintf("%v", msg.Data)
		estimated += int64(len(rendered))
	}

	mt.metrics.MessagesSent++
	mt.metrics.Latency = latency
	mt.metrics.BytesSent += estimated
}
// updateReceiveMetrics records one received message: it increments
// MessagesReceived and adds an estimated message size to BytesReceived. The
// estimate is the length of the ID, Topic and Source strings plus, when Data
// is set, the length of its "%v" rendering. It takes the write lock itself.
func (mt *MemoryTransport) updateReceiveMetrics(msg *Message) {
	mt.mu.Lock()
	defer mt.mu.Unlock()

	// Estimate message size (simplified).
	estimated := int64(len(msg.ID) + len(msg.Topic) + len(msg.Source))
	if msg.Data != nil {
		rendered := fmt.Sprintf("%v", msg.Data)
		estimated += int64(len(rendered))
	}

	mt.metrics.MessagesReceived++
	mt.metrics.BytesReceived += estimated
}
// GetChannelForTopic returns the receive side of the channel for topic and
// whether such a channel exists (for testing/debugging). Receiving from the
// returned channel consumes messages from the transport.
func (mt *MemoryTransport) GetChannelForTopic(topic string) (<-chan *Message, bool) {
	mt.mu.RLock()
	topicCh, found := mt.channels[topic]
	mt.mu.RUnlock()
	return topicCh, found
}
// GetTopicCount returns how many topic channels currently exist.
func (mt *MemoryTransport) GetTopicCount() int {
	mt.mu.RLock()
	count := len(mt.channels)
	mt.mu.RUnlock()
	return count
}

View File

@@ -0,0 +1,410 @@
package transport
import (
"context"
"fmt"
"sync"
"time"
)
// MessageType represents the type of message being sent.
type MessageType string

// Message types understood by the bus. The string values are what appears in
// the "type" field of a JSON-encoded Message.
const (
	MessageTypeEvent MessageType = "event"
	MessageTypeCommand MessageType = "command"
	MessageTypeQuery MessageType = "query"
	MessageTypeResponse MessageType = "response"
	MessageTypeNotification MessageType = "notification"
	MessageTypeHeartbeat MessageType = "heartbeat"
)
// MessagePriority defines message processing priority.
type MessagePriority int

// Priorities in ascending order: PriorityLow is 0 and PriorityCritical is 3.
// NewMessage assigns PriorityNormal.
const (
	PriorityLow MessagePriority = iota
	PriorityNormal
	PriorityHigh
	PriorityCritical
)
// Message represents a universal message in the system.
// Use NewMessage to create one with an ID, timestamp and defaults filled in.
type Message struct {
	// ID identifies the message; Publish generates one when it is empty.
	ID string `json:"id"`
	Type MessageType `json:"type"`
	// Topic selects the channel or queue the message is routed to.
	Topic string `json:"topic"`
	Source string `json:"source"`
	Target string `json:"target,omitempty"`
	Priority MessagePriority `json:"priority"`
	// Timestamp is set by NewMessage, or by Publish when it is still zero.
	Timestamp time.Time `json:"timestamp"`
	// Data is the payload; its concrete type is up to the sender.
	Data interface{} `json:"data"`
	Headers map[string]string `json:"headers,omitempty"`
	CorrelationID string `json:"correlation_id,omitempty"`
	TTL time.Duration `json:"ttl,omitempty"`
	// Retries and MaxRetries start at 0 and 3 in NewMessage.
	Retries int `json:"retries"`
	MaxRetries int `json:"max_retries"`
	Metadata map[string]interface{} `json:"metadata,omitempty"`
}
// MessageHandler processes incoming messages. It is the callback supplied to
// Subscribe; a non-nil return value signals that handling failed.
type MessageHandler func(ctx context.Context, msg *Message) error
// MessageFilter determines if a message should be processed. It returns true
// for messages that should be handled.
type MessageFilter func(msg *Message) bool
// Subscription represents a topic subscription.
type Subscription struct {
	ID string
	Topic string
	// Filter optionally restricts which messages reach Handler.
	Filter MessageFilter
	// Handler is invoked for messages delivered to this subscription.
	Handler MessageHandler
	Options SubscriptionOptions
	// created, active and mu are internal bookkeeping for the bus.
	created time.Time
	active bool
	mu sync.RWMutex
}
// SubscriptionOptions configures subscription behavior.
// Values are normally set through SubscriptionOption functions such as
// WithQueueSize or WithDLQ passed to Subscribe.
type SubscriptionOptions struct {
	QueueSize int
	// BatchSize and BatchTimeout are set together by WithBatchProcessing.
	BatchSize int
	BatchTimeout time.Duration
	DLQEnabled bool
	RetryEnabled bool
	Persistent bool
	Durable bool
}
// MessageBusInterface defines the universal message bus contract.
// UniversalMessageBus is the implementation declared in this package.
type MessageBusInterface interface {
	// Core messaging operations
	Publish(ctx context.Context, msg *Message) error
	Subscribe(topic string, handler MessageHandler, opts ...SubscriptionOption) (*Subscription, error)
	Unsubscribe(subscriptionID string) error
	// Advanced messaging patterns
	Request(ctx context.Context, msg *Message, timeout time.Duration) (*Message, error)
	Reply(ctx context.Context, originalMsg *Message, response *Message) error
	// Topic management
	CreateTopic(topic string, config TopicConfig) error
	DeleteTopic(topic string) error
	ListTopics() []string
	GetTopicInfo(topic string) (*TopicInfo, error)
	// Queue operations
	QueueMessage(topic string, msg *Message) error
	DequeueMessage(topic string, timeout time.Duration) (*Message, error)
	PeekMessage(topic string) (*Message, error)
	// Dead letter queue
	GetDLQMessages(topic string) ([]*Message, error)
	ReprocessDLQMessage(messageID string) error
	PurgeDLQ(topic string) error
	// Lifecycle management
	Start(ctx context.Context) error
	Stop(ctx context.Context) error
	Health() HealthStatus
	// Metrics and monitoring
	GetMetrics() MessageBusMetrics
	GetSubscriptions() []*Subscription
	GetActiveConnections() int
}
// SubscriptionOption configures subscription behavior. It is a function that
// mutates the SubscriptionOptions it is given; Subscribe applies the supplied
// options in order on top of its defaults.
type SubscriptionOption func(*SubscriptionOptions)
// TopicConfig defines topic configuration, as passed to CreateTopic.
type TopicConfig struct {
	Persistent bool
	Replicated bool
	RetentionPolicy RetentionPolicy
	Partitions int
	MaxMessageSize int64
	TTL time.Duration
}
// RetentionPolicy defines message retention behavior in terms of message
// count, age and total size limits.
type RetentionPolicy struct {
	MaxMessages int
	MaxAge time.Duration
	MaxSize int64
}
// TopicInfo provides topic statistics, as returned by GetTopicInfo.
type TopicInfo struct {
	Name string
	Config TopicConfig
	MessageCount int64
	SubscriberCount int
	LastActivity time.Time
	SizeBytes int64
}
// HealthStatus represents system health: an overall status plus per-component
// detail and a list of recorded errors.
type HealthStatus struct {
	Status string
	Uptime time.Duration
	LastCheck time.Time
	// Components holds one ComponentHealth per named component.
	Components map[string]ComponentHealth
	Errors []HealthError
}
// ComponentHealth represents component-specific health. It is also what
// Transport.Health returns.
type ComponentHealth struct {
	// Status is a free-form string. The code in this package produces
	// "healthy" and "unhealthy" and additionally recognises "degraded".
	Status string
	LastCheck time.Time
	ResponseTime time.Duration
	ErrorCount int64
}
// HealthError represents a health check error.
type HealthError struct {
	Component string
	Message string
	Timestamp time.Time
	Severity string
}
// MessageBusMetrics provides operational metrics for the message bus, as
// returned by MessageBusInterface.GetMetrics.
// NOTE(review): units of ErrorRate, MemoryUsage and CPUUsage are not
// established by the code visible here; confirm where they are populated.
type MessageBusMetrics struct {
	MessagesPublished int64
	MessagesConsumed int64
	MessagesFailed int64
	MessagesInDLQ int64
	ActiveSubscriptions int
	TopicCount int
	AverageLatency time.Duration
	ThroughputPerSec float64
	ErrorRate float64
	MemoryUsage int64
	CPUUsage float64
}
// TransportType defines available transport mechanisms. Values of this type
// key the bus's transport map.
type TransportType string

// Known transport types. TransportMemory is the fallback of a router created
// by NewMessageRouter.
const (
	TransportMemory TransportType = "memory"
	TransportUnixSocket TransportType = "unix"
	TransportTCP TransportType = "tcp"
	TransportWebSocket TransportType = "websocket"
	TransportRedis TransportType = "redis"
	TransportNATS TransportType = "nats"
)
// TransportConfig configures transport layer: which mechanism to use, where
// to reach it, and its retry and security settings.
type TransportConfig struct {
	Type TransportType
	Address string
	// Options carries transport-specific settings.
	Options map[string]interface{}
	RetryConfig RetryConfig
	SecurityConfig SecurityConfig
}
// RetryConfig defines retry behavior.
type RetryConfig struct {
	MaxRetries int
	InitialDelay time.Duration
	MaxDelay time.Duration
	BackoffFactor float64
	Jitter bool
}
// SecurityConfig defines security settings.
type SecurityConfig struct {
	Enabled bool
	// TLSConfig and AuthConfig are optional and may be nil.
	TLSConfig *TLSConfig
	AuthConfig *AuthConfig
	Encryption bool
	Compression bool
}
// TLSConfig for secure transport. The *File fields are file paths.
type TLSConfig struct {
	CertFile string
	KeyFile string
	CAFile string
	Verify bool
}
// AuthConfig for authentication.
// NOTE(review): holds credentials in plain strings; avoid logging this value.
type AuthConfig struct {
	Username string
	Password string
	Token string
	Method string
}
// UniversalMessageBus implements MessageBusInterface.
// Create one with NewUniversalMessageBus.
type UniversalMessageBus struct {
	config MessageBusConfig
	// transports holds the available transports by type; the router picks
	// one of them for every published message.
	transports map[TransportType]Transport
	router *MessageRouter
	topics map[string]*Topic
	subscriptions map[string]*Subscription
	// dlq receives messages whose send failed when config.EnableDLQ is set.
	dlq *DeadLetterQueue
	metrics *MetricsCollector
	// persistence is not set by NewUniversalMessageBus.
	persistence PersistenceLayer
	mu sync.RWMutex
	ctx context.Context
	cancel context.CancelFunc
	// started is checked by Publish and Subscribe, which refuse to operate
	// while it is false. The constructor leaves it false.
	started bool
}
// MessageBusConfig configures the message bus.
type MessageBusConfig struct {
	DefaultTransport TransportType
	// EnablePersistence makes Publish store each sent message in its topic.
	EnablePersistence bool
	EnableMetrics bool
	// EnableDLQ makes Publish hand messages to the dead letter queue when a
	// send fails, and is the default for a subscription's DLQEnabled option.
	EnableDLQ bool
	MaxMessageSize int64
	DefaultTTL time.Duration
	HealthCheckInterval time.Duration
	CleanupInterval time.Duration
}
// Transport interface for different transport mechanisms.
// MemoryTransport is the in-process implementation in this package.
type Transport interface {
	// Send delivers msg through the transport.
	Send(ctx context.Context, msg *Message) error
	// Receive returns a channel on which incoming messages are delivered.
	Receive(ctx context.Context) (<-chan *Message, error)
	// Connect and Disconnect open and close the transport.
	Connect(ctx context.Context) error
	Disconnect(ctx context.Context) error
	// Health reports the transport's current health.
	Health() ComponentHealth
	// GetMetrics returns the transport's counters.
	GetMetrics() TransportMetrics
}
// TransportMetrics for transport-specific metrics.
// How the values are derived is up to each transport. MemoryTransport, for
// example, reports estimated byte counts and the latency of its most recent
// send.
type TransportMetrics struct {
	BytesSent int64
	BytesReceived int64
	MessagesSent int64
	MessagesReceived int64
	Connections int
	Errors int64
	Latency time.Duration
}
// Note: MessageRouter and RoutingRule are defined in router.go
// LoadBalancer for transport selection.
type LoadBalancer interface {
	// SelectTransport picks one of transports for msg.
	SelectTransport(transports []TransportType, msg *Message) TransportType
	// UpdateStats feeds back the outcome of an operation on transport so that
	// later selections can take it into account.
	UpdateStats(transport TransportType, latency time.Duration, success bool)
}
// Topic represents a message topic together with its stored messages and
// subscribers.
type Topic struct {
	Name string
	Config TopicConfig
	Messages []StoredMessage
	Subscribers []*Subscription
	Created time.Time
	LastActivity time.Time
	// mu is the topic's own lock; presumably it guards the fields above.
	// TODO confirm against the methods that use it.
	mu sync.RWMutex
}
// StoredMessage represents a persisted message: the message itself, when it
// was stored, and whether it has been processed.
type StoredMessage struct {
	Message *Message
	Stored time.Time
	Processed bool
}
// Note: DeadLetterQueue and DLQConfig are defined in dlq.go
// Note: MetricsCollector is defined in serialization.go
// PersistenceLayer handles message persistence. No implementation is visible
// in this part of the package.
type PersistenceLayer interface {
	// Store saves msg.
	Store(msg *Message) error
	// Retrieve loads the message with the given id.
	Retrieve(id string) (*Message, error)
	// Delete removes the message with the given id.
	Delete(id string) error
	// List returns up to limit messages of topic.
	List(topic string, limit int) ([]*Message, error)
	// Cleanup removes messages according to maxAge.
	Cleanup(maxAge time.Duration) error
}
// Factory functions for common subscription options
// WithQueueSize returns an option that sets the subscription's QueueSize.
func WithQueueSize(size int) SubscriptionOption {
	var apply SubscriptionOption = func(target *SubscriptionOptions) {
		target.QueueSize = size
	}
	return apply
}
// WithBatchProcessing returns an option that sets the subscription's
// BatchSize and BatchTimeout together.
func WithBatchProcessing(size int, timeout time.Duration) SubscriptionOption {
	var apply SubscriptionOption = func(target *SubscriptionOptions) {
		target.BatchSize, target.BatchTimeout = size, timeout
	}
	return apply
}
// WithDLQ returns an option that sets the subscription's DLQEnabled flag.
func WithDLQ(enabled bool) SubscriptionOption {
	var apply SubscriptionOption = func(target *SubscriptionOptions) {
		target.DLQEnabled = enabled
	}
	return apply
}
// WithRetry returns an option that sets the subscription's RetryEnabled flag.
func WithRetry(enabled bool) SubscriptionOption {
	var apply SubscriptionOption = func(target *SubscriptionOptions) {
		target.RetryEnabled = enabled
	}
	return apply
}
// WithPersistence returns an option that sets the subscription's Persistent
// flag.
func WithPersistence(enabled bool) SubscriptionOption {
	var apply SubscriptionOption = func(target *SubscriptionOptions) {
		target.Persistent = enabled
	}
	return apply
}
// NewUniversalMessageBus builds a message bus from config with empty
// transport, topic and subscription tables, a fresh router, dead letter queue
// and metrics collector, and its own cancellable context.
//
// The dead letter queue is created with a zero-value DLQConfig regardless of
// config. The bus is returned in the not-started state and has no persistence
// layer set.
func NewUniversalMessageBus(config MessageBusConfig) *UniversalMessageBus {
	ctx, cancel := context.WithCancel(context.Background())

	bus := new(UniversalMessageBus)
	bus.config = config
	bus.transports = make(map[TransportType]Transport)
	bus.topics = make(map[string]*Topic)
	bus.subscriptions = make(map[string]*Subscription)
	bus.router = NewMessageRouter()
	bus.dlq = NewDeadLetterQueue(DLQConfig{})
	bus.metrics = NewMetricsCollector()
	bus.ctx, bus.cancel = ctx, cancel
	return bus
}
// NewMessageRouter returns a router with an empty rule list that falls back
// to the in-memory transport.
func NewMessageRouter() *MessageRouter {
	router := new(MessageRouter)
	router.rules = []RoutingRule{}
	router.fallback = TransportMemory
	return router
}
// Note: NewDeadLetterQueue is defined in dlq.go
// Note: NewMetricsCollector is defined in serialization.go
// messageIDSeq is a process-wide counter that makes generated message IDs
// unique even when several are created within the same clock reading.
// messageIDMu guards it.
var (
	messageIDMu  sync.Mutex
	messageIDSeq uint64
)

// GenerateMessageID returns a new message ID of the form
// "msg_<unix-nanoseconds>_<sequence>".
//
// The second component used to be the nanosecond part of a second clock
// reading, which added no uniqueness: two calls within the clock's resolution
// produced the same ID. It is now a monotonically increasing per-process
// sequence number, so IDs generated by one process never collide. The
// function is safe for concurrent use.
func GenerateMessageID() string {
	messageIDMu.Lock()
	messageIDSeq++
	seq := messageIDSeq
	messageIDMu.Unlock()

	return fmt.Sprintf("msg_%d_%d", time.Now().UnixNano(), seq)
}
// NewMessage creates a message of the given type, topic, source and payload
// with defaults filled in: a generated ID, the current time as Timestamp,
// PriorityNormal, empty (non-nil) Headers and Metadata maps, zero Retries and
// a MaxRetries of 3.
func NewMessage(msgType MessageType, topic string, source string, data interface{}) *Message {
	msg := new(Message)

	// Caller-supplied fields.
	msg.Type = msgType
	msg.Topic = topic
	msg.Source = source
	msg.Data = data

	// Defaults.
	msg.ID = GenerateMessageID()
	msg.Priority = PriorityNormal
	msg.Timestamp = time.Now()
	msg.Headers = make(map[string]string)
	msg.Metadata = make(map[string]interface{})
	msg.Retries = 0
	msg.MaxRetries = 3

	return msg
}

View File

@@ -0,0 +1,742 @@
package transport
import (
"context"
"fmt"
"time"
)
// Publish validates msg, fills in a missing Timestamp and ID, routes it to a
// transport and sends it.
//
// It fails when the bus has not been started, when validation or routing
// fails, or when the transport's Send fails. On a send failure with the dead
// letter queue enabled, the message is also handed to the DLQ; the returned
// error always wraps the send error and additionally mentions a DLQ failure.
// After a successful send the message is stored in its topic when persistence
// is enabled (a storage failure is only counted, not returned) and delivered
// to local subscribers in a separate goroutine.
//
// The "publish_latency" metric records the time elapsed since msg.Timestamp.
func (mb *UniversalMessageBus) Publish(ctx context.Context, msg *Message) error {
	if !mb.started {
		return fmt.Errorf("message bus not started")
	}

	if validationErr := mb.validateMessage(msg); validationErr != nil {
		return fmt.Errorf("invalid message: %w", validationErr)
	}

	// Fill in defaults the caller left unset.
	if msg.Timestamp.IsZero() {
		msg.Timestamp = time.Now()
	}
	if msg.ID == "" {
		msg.ID = GenerateMessageID()
	}

	mb.metrics.IncrementCounter("messages_published_total")
	sinceTimestamp := time.Since(msg.Timestamp)
	mb.metrics.RecordLatency("publish_latency", sinceTimestamp)

	// Pick the transport for this message.
	target, routeErr := mb.router.RouteMessage(msg, mb.transports)
	if routeErr != nil {
		mb.metrics.IncrementCounter("routing_errors_total")
		return fmt.Errorf("routing failed: %w", routeErr)
	}

	sendErr := target.Send(ctx, msg)
	if sendErr != nil {
		mb.metrics.IncrementCounter("send_errors_total")
		if !mb.config.EnableDLQ {
			return fmt.Errorf("send failed: %w", sendErr)
		}
		// Park the message in the dead letter queue.
		if dlqErr := mb.dlq.AddMessage(msg.Topic, msg); dlqErr != nil {
			return fmt.Errorf("send failed and DLQ failed: %v, original error: %w", dlqErr, sendErr)
		}
		return fmt.Errorf("send failed: %w", sendErr)
	}

	// Persistence is best effort: count the failure but keep going.
	if mb.config.EnablePersistence && mb.addMessageToTopic(msg) != nil {
		mb.metrics.IncrementCounter("persistence_errors_total")
	}

	// Deliver to local subscribers without blocking the publisher.
	go mb.deliverToSubscribers(ctx, msg)
	return nil
}
// Subscribe registers handler for topic and returns the new subscription.
//
// Defaults (queue size 1000, batch size 1, one second batch timeout, DLQ as
// configured on the bus, retry enabled, not persistent, not durable) are
// applied first and then overridden by opts in order. It returns an error if
// the bus has not been started.
func (mb *UniversalMessageBus) Subscribe(topic string, handler MessageHandler, opts ...SubscriptionOption) (*Subscription, error) {
	if !mb.started {
		return nil, fmt.Errorf("message bus not started")
	}

	settings := SubscriptionOptions{
		QueueSize:    1000,
		BatchSize:    1,
		BatchTimeout: time.Second,
		DLQEnabled:   mb.config.EnableDLQ,
		RetryEnabled: true,
		Persistent:   false,
		Durable:      false,
	}
	for i := range opts {
		opts[i](&settings)
	}

	sub := &Subscription{
		ID:      fmt.Sprintf("sub_%s_%d", topic, time.Now().UnixNano()),
		Topic:   topic,
		Handler: handler,
		Options: settings,
		created: time.Now(),
		active:  true,
	}

	// Register in the bus-wide index first, then attach to the topic.
	mb.mu.Lock()
	mb.subscriptions[sub.ID] = sub
	mb.mu.Unlock()

	mb.addSubscriberToTopic(topic, sub)
	mb.metrics.IncrementCounter("subscriptions_created_total")
	return sub, nil
}
// Unsubscribe removes the subscription with the given ID from the bus and
// from its topic's subscriber list, and marks it inactive.
//
// It returns an error if no subscription with that ID is registered.
func (mb *UniversalMessageBus) Unsubscribe(subscriptionID string) error {
	// Only the map mutation happens under mb.mu. removeSubscriberFromTopic
	// acquires mb.mu itself, so calling it while still holding the write lock
	// (as the previous implementation did) blocks forever: the lock is not
	// reentrant.
	mb.mu.Lock()
	subscription, exists := mb.subscriptions[subscriptionID]
	if !exists {
		mb.mu.Unlock()
		return fmt.Errorf("subscription not found: %s", subscriptionID)
	}
	delete(mb.subscriptions, subscriptionID)
	mb.mu.Unlock()

	// Mark as inactive so deliveries that already copied the subscriber list
	// skip this subscription.
	subscription.mu.Lock()
	subscription.active = false
	subscription.mu.Unlock()

	mb.removeSubscriberFromTopic(subscription.Topic, subscriptionID)
	mb.metrics.IncrementCounter("subscriptions_removed_total")
	return nil
}
// Request publishes msg and waits for a single response on the topic
// "response.<CorrelationID>".
//
// A CorrelationID is generated when msg has none. The call returns the first
// response received, an error when timeout elapses, or ctx.Err() when ctx is
// done first. It returns an error if the bus is not started, msg is nil, the
// response subscription cannot be created, or publishing fails.
func (mb *UniversalMessageBus) Request(ctx context.Context, msg *Message, timeout time.Duration) (*Message, error) {
	if !mb.started {
		return nil, fmt.Errorf("message bus not started")
	}
	if msg == nil {
		// Guard before touching msg.CorrelationID; Publish would reject nil
		// anyway, but only after the dereference below.
		return nil, fmt.Errorf("invalid message: message is nil")
	}

	// Set correlation ID for request-response.
	if msg.CorrelationID == "" {
		msg.CorrelationID = GenerateMessageID()
	}

	// Buffered so the handler never blocks. The channel is deliberately never
	// closed: handlers run on their own goroutines and may still attempt a
	// send after this function has returned, which would panic on a closed
	// channel. The garbage collector reclaims it once unreferenced.
	responseChannel := make(chan *Message, 1)

	responseTopic := fmt.Sprintf("response.%s", msg.CorrelationID)
	subscription, err := mb.Subscribe(responseTopic, func(ctx context.Context, response *Message) error {
		select {
		case responseChannel <- response:
		default:
			// A response is already buffered; drop the extra one.
		}
		return nil
	})
	if err != nil {
		return nil, fmt.Errorf("failed to subscribe to response topic: %w", err)
	}
	defer func() {
		// Best-effort cleanup; there is nothing useful to do on failure.
		_ = mb.Unsubscribe(subscription.ID)
	}()

	if err := mb.Publish(ctx, msg); err != nil {
		return nil, fmt.Errorf("failed to publish request: %w", err)
	}

	// An explicit timer is stopped on return instead of lingering until it
	// fires, as a time.After timer would.
	timer := time.NewTimer(timeout)
	defer timer.Stop()

	select {
	case response := <-responseChannel:
		return response, nil
	case <-timer.C:
		return nil, fmt.Errorf("request timeout after %v", timeout)
	case <-ctx.Done():
		return nil, ctx.Err()
	}
}
// Reply publishes response as the answer to originalMsg.
//
// The response is rewritten in place: its Type becomes MessageTypeResponse,
// its Topic becomes "response.<CorrelationID>", and it inherits the original
// CorrelationID, with the original Source as its Target. An error is returned
// when originalMsg carries no CorrelationID or when publishing fails.
func (mb *UniversalMessageBus) Reply(ctx context.Context, originalMsg *Message, response *Message) error {
	correlationID := originalMsg.CorrelationID
	if correlationID == "" {
		return fmt.Errorf("original message has no correlation ID")
	}

	response.Type = MessageTypeResponse
	response.Topic = fmt.Sprintf("response.%s", correlationID)
	response.CorrelationID = correlationID
	response.Target = originalMsg.Source

	return mb.Publish(ctx, response)
}
// CreateTopic registers an empty topic named topicName with the given
// configuration. It returns an error if a topic with that name already
// exists.
func (mb *UniversalMessageBus) CreateTopic(topicName string, config TopicConfig) error {
	mb.mu.Lock()
	defer mb.mu.Unlock()

	if _, taken := mb.topics[topicName]; taken {
		return fmt.Errorf("topic already exists: %s", topicName)
	}

	created := new(Topic)
	created.Name = topicName
	created.Config = config
	created.Messages = make([]StoredMessage, 0)
	created.Subscribers = make([]*Subscription, 0)
	created.Created = time.Now()
	created.LastActivity = time.Now()

	mb.topics[topicName] = created
	mb.metrics.IncrementCounter("topics_created_total")
	return nil
}
// DeleteTopic removes the topic named topicName together with all of its
// subscriptions. It returns an error if the topic does not exist.
func (mb *UniversalMessageBus) DeleteTopic(topicName string) error {
	mb.mu.Lock()
	topic, exists := mb.topics[topicName]
	if !exists {
		mb.mu.Unlock()
		return fmt.Errorf("topic not found: %s", topicName)
	}
	delete(mb.topics, topicName)

	// Detach the subscriber list under the topic lock so deliveries that are
	// in flight see a consistent snapshot.
	topic.mu.Lock()
	subscribers := topic.Subscribers
	topic.Subscribers = nil
	topic.mu.Unlock()

	// The subscriptions are removed inline rather than through Unsubscribe:
	// Unsubscribe acquires mb.mu, which this function already holds, and the
	// lock is not reentrant, so the previous implementation deadlocked as
	// soon as the topic had a subscriber.
	for _, sub := range subscribers {
		sub.mu.Lock()
		sub.active = false
		sub.mu.Unlock()
		delete(mb.subscriptions, sub.ID)
	}
	mb.mu.Unlock()

	// Keep the counters in line with what one Unsubscribe per subscriber
	// would have recorded.
	for range subscribers {
		mb.metrics.IncrementCounter("subscriptions_removed_total")
	}
	mb.metrics.IncrementCounter("topics_deleted_total")
	return nil
}
// ListTopics returns the names of all registered topics. The order follows
// map iteration and is therefore unspecified.
func (mb *UniversalMessageBus) ListTopics() []string {
	mb.mu.RLock()
	defer mb.mu.RUnlock()

	names := make([]string, len(mb.topics))
	next := 0
	for name := range mb.topics {
		names[next] = name
		next++
	}
	return names
}
// GetTopicInfo returns a snapshot describing the topic named topicName:
// its config, message and subscriber counts, last activity time and an
// estimated size. It returns an error if the topic does not exist.
func (mb *UniversalMessageBus) GetTopicInfo(topicName string) (*TopicInfo, error) {
	mb.mu.RLock()
	defer mb.mu.RUnlock()

	topic, found := mb.topics[topicName]
	if !found {
		return nil, fmt.Errorf("topic not found: %s", topicName)
	}

	topic.mu.RLock()
	defer topic.mu.RUnlock()

	// The size is only an estimate: the length of each message's "%+v"
	// rendering, not its encoded size.
	var estimated int64
	for i := range topic.Messages {
		rendered := fmt.Sprintf("%+v", topic.Messages[i].Message)
		estimated += int64(len(rendered))
	}

	info := &TopicInfo{
		Name:            topic.Name,
		Config:          topic.Config,
		MessageCount:    int64(len(topic.Messages)),
		SubscriberCount: len(topic.Subscribers),
		LastActivity:    topic.LastActivity,
		SizeBytes:       estimated,
	}
	return info, nil
}
// QueueMessage stores msg in a topic queue.
//
// The message is filed under msg.Topic. When msg.Topic is empty the topic
// argument is used to fill it in; previously the argument was ignored
// entirely, so such messages landed in a topic with an empty name. A nil msg
// is rejected with an error instead of panicking.
//
// NOTE(review): when topic and a non-empty msg.Topic disagree, msg.Topic still
// wins, as before — confirm which one callers expect to take precedence.
func (mb *UniversalMessageBus) QueueMessage(topic string, msg *Message) error {
	if msg == nil {
		return fmt.Errorf("message is nil")
	}
	if msg.Topic == "" {
		msg.Topic = topic
	}
	return mb.addMessageToTopic(msg)
}
// DequeueMessage returns the first unprocessed message of topic, polling
// until one becomes available or timeout elapses.
//
// The returned message is flagged as processed; it stays in the topic's
// message list. The queue is always checked at least once, so a zero or
// negative timeout behaves as a non-blocking attempt (previously such a call
// returned without ever looking at the queue). It returns an error if the
// topic does not exist or no message arrives in time.
func (mb *UniversalMessageBus) DequeueMessage(topic string, timeout time.Duration) (*Message, error) {
	const pollInterval = 10 * time.Millisecond

	deadline := time.Now().Add(timeout)
	for {
		mb.mu.RLock()
		topicObj, exists := mb.topics[topic]
		mb.mu.RUnlock()
		if !exists {
			return nil, fmt.Errorf("topic not found: %s", topic)
		}

		// Claim the first unprocessed message, if any.
		topicObj.mu.Lock()
		for i := range topicObj.Messages {
			if topicObj.Messages[i].Processed {
				continue
			}
			topicObj.Messages[i].Processed = true
			msg := topicObj.Messages[i].Message
			topicObj.mu.Unlock()
			return msg, nil
		}
		topicObj.mu.Unlock()

		remaining := time.Until(deadline)
		if remaining <= 0 {
			return nil, fmt.Errorf("no message available within timeout")
		}
		// Never sleep past the deadline.
		if remaining > pollInterval {
			remaining = pollInterval
		}
		time.Sleep(remaining)
	}
}
// PeekMessage returns the first unprocessed message of topic without
// flagging it as processed. It returns an error if the topic does not exist
// or holds no unprocessed message.
func (mb *UniversalMessageBus) PeekMessage(topic string) (*Message, error) {
	mb.mu.RLock()
	queue, found := mb.topics[topic]
	mb.mu.RUnlock()
	if !found {
		return nil, fmt.Errorf("topic not found: %s", topic)
	}

	queue.mu.RLock()
	defer queue.mu.RUnlock()

	for i := range queue.Messages {
		if queue.Messages[i].Processed {
			continue
		}
		return queue.Messages[i].Message, nil
	}
	return nil, fmt.Errorf("no messages available")
}
// Start marks the bus as started and launches its background loops (health
// check, cleanup, metrics).
//
// When no transport has been registered, an in-memory transport is created,
// registered under TransportMemory and connected using ctx. It returns an
// error if the bus is already started or that default transport fails to
// connect.
func (mb *UniversalMessageBus) Start(ctx context.Context) error {
	mb.mu.Lock()
	defer mb.mu.Unlock()

	if mb.started {
		return fmt.Errorf("message bus already started")
	}

	// Fall back to an in-memory transport when nothing was configured.
	if len(mb.transports) == 0 {
		fallback := NewMemoryTransport()
		mb.transports[TransportMemory] = fallback
		if connectErr := fallback.Connect(ctx); connectErr != nil {
			return fmt.Errorf("failed to connect default transport: %w", connectErr)
		}
	}

	for _, loop := range []func(){mb.healthCheckLoop, mb.cleanupLoop, mb.metricsLoop} {
		go loop()
	}

	mb.started = true
	mb.metrics.RecordEvent("message_bus_started")
	return nil
}
// Stop cancels the bus context, disconnects every registered transport and
// marks the bus as stopped. Calling it on a bus that is not started is a
// no-op. It always returns nil.
func (mb *UniversalMessageBus) Stop(ctx context.Context) error {
	mb.mu.Lock()
	defer mb.mu.Unlock()

	if !mb.started {
		return nil
	}

	// Cancelling the context makes the background loops return.
	mb.cancel()

	for _, tr := range mb.transports {
		// Disconnect failures are deliberately ignored so that the shutdown
		// continues with the remaining transports.
		_ = tr.Disconnect(ctx)
	}

	mb.started = false
	mb.metrics.RecordEvent("message_bus_stopped")
	return nil
}
// Health collects the health of every registered transport and summarises
// it. The overall status is "healthy" unless at least one component reports
// a different status, in which case it is "degraded" and one warning entry
// per such component is included.
func (mb *UniversalMessageBus) Health() HealthStatus {
	components := make(map[string]ComponentHealth)
	for kind, tr := range mb.transports {
		components[string(kind)] = tr.Health()
	}

	overall := "healthy"
	var problems []HealthError
	for name, component := range components {
		if component.Status == "healthy" {
			continue
		}
		overall = "degraded"
		problems = append(problems, HealthError{
			Component: name,
			Message:   fmt.Sprintf("Component %s is %s", name, component.Status),
			Timestamp: time.Now(),
			Severity:  "warning",
		})
	}

	report := HealthStatus{
		Status: overall,
		// Placeholder: measures the time since "now", i.e. effectively zero.
		// A real uptime needs a recorded start time.
		Uptime:     time.Since(time.Now()),
		LastCheck:  time.Now(),
		Components: components,
		Errors:     problems,
	}
	return report
}
// GetMetrics assembles a MessageBusMetrics snapshot from the metrics
// collector, the dead letter queue and the current subscription and topic
// counts.
func (mb *UniversalMessageBus) GetMetrics() MessageBusMetrics {
	// The result of GetAll is discarded; the individual values are read by
	// key below.
	_ = mb.metrics.GetAll()

	var snapshot MessageBusMetrics
	snapshot.MessagesPublished = mb.getMetricInt64("messages_published_total")
	snapshot.MessagesConsumed = mb.getMetricInt64("messages_consumed_total")
	snapshot.MessagesFailed = mb.getMetricInt64("send_errors_total")
	snapshot.MessagesInDLQ = int64(mb.dlq.GetMessageCount())
	snapshot.ActiveSubscriptions = len(mb.subscriptions)
	snapshot.TopicCount = len(mb.topics)
	snapshot.AverageLatency = mb.getMetricDuration("average_latency")
	snapshot.ThroughputPerSec = mb.getMetricFloat64("throughput_per_second")
	snapshot.ErrorRate = mb.getMetricFloat64("error_rate")
	snapshot.MemoryUsage = mb.getMetricInt64("memory_usage_bytes")
	snapshot.CPUUsage = mb.getMetricFloat64("cpu_usage_percent")
	return snapshot
}
// GetSubscriptions returns every subscription currently registered on the
// bus, in unspecified order. The slice is new; the subscriptions it points
// to are shared with the bus.
func (mb *UniversalMessageBus) GetSubscriptions() []*Subscription {
	mb.mu.RLock()
	defer mb.mu.RUnlock()

	all := make([]*Subscription, len(mb.subscriptions))
	next := 0
	for _, sub := range mb.subscriptions {
		all[next] = sub
		next++
	}
	return all
}
// GetActiveConnections sums the Connections value reported by the metrics of
// every registered transport.
func (mb *UniversalMessageBus) GetActiveConnections() int {
	var total int
	for _, tr := range mb.transports {
		total += tr.GetMetrics().Connections
	}
	return total
}
// Helper methods
// validateMessage reports the first problem found with msg: a nil message,
// an empty Topic, an empty Source or nil Data, checked in that order. It
// returns nil when all checks pass.
func (mb *UniversalMessageBus) validateMessage(msg *Message) error {
	switch {
	case msg == nil:
		return fmt.Errorf("message is nil")
	case msg.Topic == "":
		return fmt.Errorf("message topic is empty")
	case msg.Source == "":
		return fmt.Errorf("message source is empty")
	case msg.Data == nil:
		return fmt.Errorf("message data is nil")
	default:
		return nil
	}
}
// addMessageToTopic appends msg to the stored messages of the topic named
// msg.Topic and then applies the topic's retention policy.
//
// A missing topic is created on the fly as a persistent topic retaining at
// most 10000 messages for 24 hours. It returns an error if msg is nil or the
// topic can neither be found nor created.
func (mb *UniversalMessageBus) addMessageToTopic(msg *Message) error {
	if msg == nil {
		return fmt.Errorf("message is nil")
	}

	mb.mu.RLock()
	topic, exists := mb.topics[msg.Topic]
	mb.mu.RUnlock()

	if !exists {
		// Create topic automatically.
		config := TopicConfig{
			Persistent:      true,
			RetentionPolicy: RetentionPolicy{MaxMessages: 10000, MaxAge: 24 * time.Hour},
		}
		// CreateTopic fails when another goroutine created the topic between
		// the lookup above and this call. That is not an error here, so the
		// outcome is decided by whether the topic exists afterwards.
		createErr := mb.CreateTopic(msg.Topic, config)

		// Re-read under the lock; an unlocked map read races with writers.
		mb.mu.RLock()
		topic, exists = mb.topics[msg.Topic]
		mb.mu.RUnlock()
		if !exists {
			if createErr != nil {
				return createErr
			}
			return fmt.Errorf("topic not found: %s", msg.Topic)
		}
	}

	topic.mu.Lock()
	defer topic.mu.Unlock()

	topic.Messages = append(topic.Messages, StoredMessage{
		Message:   msg,
		Stored:    time.Now(),
		Processed: false,
	})
	topic.LastActivity = time.Now()

	// Trim according to the retention policy while the topic lock is held.
	mb.applyRetentionPolicy(topic)
	return nil
}
// addSubscriberToTopic appends subscription to the subscriber list of the
// topic named topicName, creating a non-persistent topic first when it does
// not exist yet.
func (mb *UniversalMessageBus) addSubscriberToTopic(topicName string, subscription *Subscription) {
	mb.mu.RLock()
	topic, exists := mb.topics[topicName]
	mb.mu.RUnlock()

	if !exists {
		// Create topic automatically. An "already exists" failure only means
		// another goroutine won the race, so the error is intentionally
		// dropped and the lookup below decides the outcome.
		_ = mb.CreateTopic(topicName, TopicConfig{Persistent: false})

		// Re-read under the lock; an unlocked map read races with writers.
		mb.mu.RLock()
		topic, exists = mb.topics[topicName]
		mb.mu.RUnlock()
		if !exists {
			// The topic was deleted again concurrently; there is nothing to
			// attach to, and dereferencing a nil topic would panic.
			return
		}
	}

	topic.mu.Lock()
	topic.Subscribers = append(topic.Subscribers, subscription)
	topic.mu.Unlock()
}
// removeSubscriberFromTopic removes the first subscriber whose ID equals
// subscriptionID from the topic named topicName. An unknown topic or ID is
// silently ignored.
func (mb *UniversalMessageBus) removeSubscriberFromTopic(topicName, subscriptionID string) {
	mb.mu.RLock()
	topic, found := mb.topics[topicName]
	mb.mu.RUnlock()
	if !found {
		return
	}

	topic.mu.Lock()
	defer topic.mu.Unlock()

	for idx := range topic.Subscribers {
		if topic.Subscribers[idx].ID != subscriptionID {
			continue
		}
		head, tail := topic.Subscribers[:idx], topic.Subscribers[idx+1:]
		topic.Subscribers = append(head, tail...)
		return
	}
}
// deliverToSubscribers hands msg to every active subscriber of msg.Topic
// whose filter (if any) accepts it.
//
// Each handler runs on its own goroutine. A panicking handler is recovered
// and counted; a handler error is counted and, when the DLQ is enabled on the
// bus, the message is added to the dead letter queue; a nil error counts the
// message as consumed.
func (mb *UniversalMessageBus) deliverToSubscribers(ctx context.Context, msg *Message) {
	mb.mu.RLock()
	topic, found := mb.topics[msg.Topic]
	mb.mu.RUnlock()
	if !found {
		return
	}

	// Work on a snapshot so handlers run without holding the topic lock.
	topic.mu.RLock()
	snapshot := make([]*Subscription, len(topic.Subscribers))
	copy(snapshot, topic.Subscribers)
	topic.mu.RUnlock()

	for _, sub := range snapshot {
		sub.mu.RLock()
		// Short-circuit keeps the filter from running for inactive subscribers.
		skip := !sub.active || (sub.Filter != nil && !sub.Filter(msg))
		handler := sub.Handler
		sub.mu.RUnlock()
		if skip {
			continue
		}

		go func(subscription *Subscription, message *Message) {
			defer func() {
				if recovered := recover(); recovered != nil {
					mb.metrics.IncrementCounter("handler_panics_total")
				}
			}()

			handlerErr := handler(ctx, message)
			if handlerErr == nil {
				mb.metrics.IncrementCounter("messages_consumed_total")
				return
			}
			mb.metrics.IncrementCounter("handler_errors_total")
			if mb.config.EnableDLQ {
				mb.dlq.AddMessage(message.Topic, message)
			}
		}(sub, msg)
	}
}
// applyRetentionPolicy trims topic.Messages according to the topic's
// retention policy: messages stored at or before now-MaxAge are dropped when
// MaxAge is positive, and only the newest MaxMessages entries are kept when
// MaxMessages is positive. It does no locking itself; the visible callers
// hold topic.mu while calling it.
func (mb *UniversalMessageBus) applyRetentionPolicy(topic *Topic) {
	maxAge := topic.Config.RetentionPolicy.MaxAge
	maxMessages := topic.Config.RetentionPolicy.MaxMessages

	if maxAge > 0 {
		oldestAllowed := time.Now().Add(-maxAge)
		kept := make([]StoredMessage, 0)
		for i := range topic.Messages {
			if topic.Messages[i].Stored.After(oldestAllowed) {
				kept = append(kept, topic.Messages[i])
			}
		}
		topic.Messages = kept
	}

	if maxMessages > 0 {
		if excess := len(topic.Messages) - maxMessages; excess > 0 {
			topic.Messages = topic.Messages[excess:]
		}
	}
}
// healthCheckLoop runs performHealthCheck every config.HealthCheckInterval
// until the bus context is cancelled.
//
// A non-positive interval disables periodic health checks. time.NewTicker
// panics for a non-positive duration, so an unset interval would otherwise
// crash the process from this background goroutine.
func (mb *UniversalMessageBus) healthCheckLoop() {
	interval := mb.config.HealthCheckInterval
	if interval <= 0 {
		return
	}

	ticker := time.NewTicker(interval)
	defer ticker.Stop()

	for {
		select {
		case <-ticker.C:
			mb.performHealthCheck()
		case <-mb.ctx.Done():
			return
		}
	}
}
// cleanupLoop runs performCleanup every config.CleanupInterval until the bus
// context is cancelled.
//
// A non-positive interval disables periodic cleanup. time.NewTicker panics
// for a non-positive duration, so an unset interval would otherwise crash the
// process from this background goroutine.
func (mb *UniversalMessageBus) cleanupLoop() {
	interval := mb.config.CleanupInterval
	if interval <= 0 {
		return
	}

	ticker := time.NewTicker(interval)
	defer ticker.Stop()

	for {
		select {
		case <-ticker.C:
			mb.performCleanup()
		case <-mb.ctx.Done():
			return
		}
	}
}
// metricsLoop calls updateMetrics once per minute until the bus context is
// cancelled.
func (mb *UniversalMessageBus) metricsLoop() {
	tick := time.NewTicker(time.Minute)
	defer tick.Stop()

	done := mb.ctx.Done()
	for {
		select {
		case <-done:
			return
		case <-tick.C:
			mb.updateMetrics()
		}
	}
}
// performHealthCheck records one gauge per transport, named
// "transport_<type>_healthy": 1 for status "healthy", 0.5 for "degraded" and
// 0 for "unhealthy" or any other status.
func (mb *UniversalMessageBus) performHealthCheck() {
	for kind, tr := range mb.transports {
		var score float64
		switch tr.Health().Status {
		case "healthy":
			score = 1
		case "degraded":
			score = 0.5
		default:
			// "unhealthy" and unknown statuses both map to zero.
			score = 0
		}
		gauge := fmt.Sprintf("transport_%s_healthy", kind)
		mb.metrics.RecordGauge(gauge, score)
	}
}
// performCleanup applies the retention policy to every topic and then asks
// the dead letter queue to clean up with a 24 hour age argument.
func (mb *UniversalMessageBus) performCleanup() {
	// Snapshot the topics so the bus lock is not held while each topic is
	// trimmed under its own lock.
	mb.mu.RLock()
	snapshot := make([]*Topic, 0, len(mb.topics))
	for name := range mb.topics {
		snapshot = append(snapshot, mb.topics[name])
	}
	mb.mu.RUnlock()

	for i := range snapshot {
		current := snapshot[i]
		current.mu.Lock()
		mb.applyRetentionPolicy(current)
		current.mu.Unlock()
	}

	mb.dlq.Cleanup(time.Hour * 24)
}
// updateMetrics refreshes the "throughput_per_second" and "error_rate"
// gauges from the published and send-error counters. Neither gauge is
// written while the published counter is zero.
//
// NOTE(review): throughput is the published counter divided by 60. If that
// counter is cumulative this is not a per-second rate — confirm.
func (mb *UniversalMessageBus) updateMetrics() {
	published := mb.getMetricInt64("messages_published_total")
	if published > 0 {
		mb.metrics.RecordGauge("throughput_per_second", float64(published)/60.0)
	}

	failed := mb.getMetricInt64("send_errors_total")
	if published <= 0 {
		return
	}
	mb.metrics.RecordGauge("error_rate", float64(failed)/float64(published))
}
// getMetricInt64 returns the metric stored under key when it is an int64,
// and 0 otherwise.
func (mb *UniversalMessageBus) getMetricInt64(key string) int64 {
	// A failed type assertion leaves the zero value in place.
	value, _ := mb.metrics.Get(key).(int64)
	return value
}
// getMetricFloat64 returns the metric stored under key when it is a float64,
// and 0 otherwise.
func (mb *UniversalMessageBus) getMetricFloat64(key string) float64 {
	// A failed type assertion leaves the zero value in place.
	value, _ := mb.metrics.Get(key).(float64)
	return value
}
// getMetricDuration returns the metric stored under key when it is a
// time.Duration, and 0 otherwise.
func (mb *UniversalMessageBus) getMetricDuration(key string) time.Duration {
	// A failed type assertion leaves the zero value in place.
	value, _ := mb.metrics.Get(key).(time.Duration)
	return value
}

View File

@@ -0,0 +1,820 @@
package transport
import (
"bytes"
"compress/gzip"
"crypto/aes"
"crypto/cipher"
"crypto/rand"
"encoding/json"
"fmt"
"io"
"os"
"path/filepath"
"sort"
"sync"
"time"
)
// FilePersistenceLayer implements file-based message persistence.
//
// Messages are appended as length-prefixed records to ".dat" files, kept in
// one directory per topic underneath basePath. mu guards the configuration
// fields and serialises the public operations.
type FilePersistenceLayer struct {
	// basePath is the root directory; each topic is stored in a subdirectory.
	basePath string
	// maxFileSize is the limit getWritableFile checks before appending to an
	// existing file.
	maxFileSize int64
	// maxFiles is set by the constructor and SetMaxFiles.
	// NOTE(review): no code in this part of the file reads it — confirm it is
	// enforced elsewhere.
	maxFiles int
	// compression enables gzip for records on write and on read.
	compression bool
	// encryption holds the at-rest encryption settings.
	encryption EncryptionConfig
	mu sync.RWMutex
}
// EncryptionConfig configures message encryption at rest.
type EncryptionConfig struct {
	// Enabled turns encryption on. encrypt and decrypt pass data through
	// unchanged when it is false or when Key is empty.
	Enabled bool
	// Algorithm names the cipher.
	// NOTE(review): encrypt/decrypt in this file always use AES-GCM and do not
	// consult this field — confirm whether it is used elsewhere.
	Algorithm string
	// Key is passed to aes.NewCipher, which accepts 16, 24 or 32 byte keys.
	Key []byte
}
// PersistedMessage represents a message stored on disk. It is the JSON
// envelope written for every record; ID and Topic duplicate the
// corresponding fields of Message.
type PersistedMessage struct {
	ID string `json:"id"`
	Topic string `json:"topic"`
	// Message is the payload handed back by the read path.
	Message *Message `json:"message"`
	// Stored is the time the record was written.
	Stored time.Time `json:"stored"`
	Metadata map[string]interface{} `json:"metadata"`
	// Encrypted is meant to flag encrypted records.
	// NOTE(review): the writers in this file set it only after the envelope
	// has been marshaled, so the serialized value appears to stay false.
	Encrypted bool `json:"encrypted"`
}
// PersistenceMetrics tracks persistence layer statistics.
//
// NOTE(review): FilePersistenceLayer.GetMetrics in this file fills in only
// StorageSize and FileCount; the remaining fields are left at their zero
// values there.
type PersistenceMetrics struct {
	MessagesStored int64
	MessagesRetrieved int64
	MessagesDeleted int64
	// StorageSize is the summed size in bytes of all files found.
	StorageSize int64
	// FileCount is the number of non-directory entries found.
	FileCount int
	LastCleanup time.Time
	Errors int64
}
// NewFilePersistenceLayer returns a persistence layer rooted at basePath
// with a 100MB maximum file size, a limit of 1000 files, and compression and
// encryption turned off.
func NewFilePersistenceLayer(basePath string) *FilePersistenceLayer {
	const defaultMaxFileSize = 100 * 1024 * 1024 // 100MB

	layer := new(FilePersistenceLayer)
	layer.basePath = basePath
	layer.maxFileSize = defaultMaxFileSize
	layer.maxFiles = 1000
	layer.compression = false
	return layer
}
// SetMaxFileSize sets, under the write lock, the maximum size in bytes a
// data file may reach before a new one is started.
func (fpl *FilePersistenceLayer) SetMaxFileSize(size int64) {
	fpl.mu.Lock()
	fpl.maxFileSize = size
	fpl.mu.Unlock()
}
// SetMaxFiles stores the maximum file count under the write lock.
func (fpl *FilePersistenceLayer) SetMaxFiles(count int) {
	fpl.mu.Lock()
	fpl.maxFiles = count
	fpl.mu.Unlock()
}
// EnableCompression switches gzip compression of stored records on or off
// under the write lock.
func (fpl *FilePersistenceLayer) EnableCompression(enabled bool) {
	fpl.mu.Lock()
	fpl.compression = enabled
	fpl.mu.Unlock()
}
// SetEncryption replaces the encryption settings under the write lock.
func (fpl *FilePersistenceLayer) SetEncryption(config EncryptionConfig) {
	fpl.mu.Lock()
	fpl.encryption = config
	fpl.mu.Unlock()
}
// Store persists msg to disk.
//
// The message is wrapped in a PersistedMessage, JSON-encoded, optionally
// encrypted and then optionally compressed, and appended to a data file in
// the topic's directory as one record: a decimal length line followed by the
// payload. It returns an error if msg is nil or any step fails.
//
// NOTE(review): msg.Topic is used as a directory name below basePath without
// sanitising; a topic containing path separators or ".." would land outside
// the expected one-level layout — confirm topics are trusted.
func (fpl *FilePersistenceLayer) Store(msg *Message) error {
	if msg == nil {
		return fmt.Errorf("message is nil")
	}

	fpl.mu.Lock()
	defer fpl.mu.Unlock()

	// Create directory if it doesn't exist.
	topicDir := filepath.Join(fpl.basePath, msg.Topic)
	if err := os.MkdirAll(topicDir, 0750); err != nil {
		return fmt.Errorf("failed to create topic directory: %w", err)
	}

	// The Encrypted flag has to be set before marshaling to reach the disk;
	// it mirrors the condition under which encrypt really transforms the data.
	persistedMsg := &PersistedMessage{
		ID:        msg.ID,
		Topic:     msg.Topic,
		Message:   msg,
		Stored:    time.Now(),
		Metadata:  make(map[string]interface{}),
		Encrypted: fpl.encryption.Enabled && len(fpl.encryption.Key) > 0,
	}

	data, err := json.Marshal(persistedMsg)
	if err != nil {
		return fmt.Errorf("failed to marshal message: %w", err)
	}

	// Encrypt first, then compress; parseFileData reverses the two steps in
	// the opposite order, so this order is part of the on-disk format.
	if fpl.encryption.Enabled {
		encryptedData, err := fpl.encrypt(data)
		if err != nil {
			return fmt.Errorf("encryption failed: %w", err)
		}
		data = encryptedData
	}
	if fpl.compression {
		compressedData, err := fpl.compress(data)
		if err != nil {
			return fmt.Errorf("compression failed: %w", err)
		}
		data = compressedData
	}

	filename, err := fpl.getWritableFile(topicDir, len(data))
	if err != nil {
		return fmt.Errorf("failed to get writable file: %w", err)
	}

	file, err := os.OpenFile(filename, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0600)
	if err != nil {
		return fmt.Errorf("failed to open file: %w", err)
	}

	// Emit the length prefix and the payload with a single write so that a
	// failure between the two cannot leave a prefix without its payload.
	record := make([]byte, 0, len(data)+24)
	record = append(record, fmt.Sprintf("%d\n", len(data))...)
	record = append(record, data...)
	if _, err := file.Write(record); err != nil {
		file.Close()
		return fmt.Errorf("failed to write data: %w", err)
	}

	// Close can report a delayed write failure, so it must not be ignored.
	if err := file.Close(); err != nil {
		return fmt.Errorf("failed to close file: %w", err)
	}
	return nil
}
// Retrieve scans every data file of every topic directory for a message
// whose ID equals id and returns the first match. Directories and files that
// cannot be read are skipped. It returns an error if the topic directories
// cannot be listed or no message matches.
func (fpl *FilePersistenceLayer) Retrieve(id string) (*Message, error) {
	fpl.mu.RLock()
	defer fpl.mu.RUnlock()

	dirs, err := fpl.getTopicDirectories()
	if err != nil {
		return nil, fmt.Errorf("failed to get topic directories: %w", err)
	}

	for _, dir := range dirs {
		paths, listErr := fpl.getTopicFiles(dir)
		if listErr != nil {
			continue
		}
		for _, path := range paths {
			if found, findErr := fpl.findMessageInFile(path, id); findErr == nil && found != nil {
				return found, nil
			}
		}
	}

	return nil, fmt.Errorf("message not found: %s", id)
}
// Delete removes the message with the given ID from the first data file in
// which deleteMessageFromFile succeeds. Files where that call fails, for any
// reason, are skipped. It returns an error if the topic directories cannot
// be listed or no file yields a successful deletion.
func (fpl *FilePersistenceLayer) Delete(id string) error {
	fpl.mu.Lock()
	defer fpl.mu.Unlock()

	dirs, err := fpl.getTopicDirectories()
	if err != nil {
		return fmt.Errorf("failed to get topic directories: %w", err)
	}

	for _, dir := range dirs {
		paths, listErr := fpl.getTopicFiles(dir)
		if listErr != nil {
			continue
		}
		for _, path := range paths {
			if fpl.deleteMessageFromFile(path, id) != nil {
				continue
			}
			return nil
		}
	}

	return fmt.Errorf("message not found: %s", id)
}
// List returns the messages stored for topic, reading data files from the
// most recently modified to the oldest. A positive limit caps the number of
// messages returned; zero or a negative limit means no cap.
//
// A topic without a directory yields an empty, non-nil slice. Files that
// cannot be inspected or read are skipped. It returns an error only when the
// topic directory cannot be listed.
func (fpl *FilePersistenceLayer) List(topic string, limit int) ([]*Message, error) {
	fpl.mu.RLock()
	defer fpl.mu.RUnlock()

	topicDir := filepath.Join(fpl.basePath, topic)
	if _, err := os.Stat(topicDir); os.IsNotExist(err) {
		return []*Message{}, nil
	}

	files, err := fpl.getTopicFiles(topicDir)
	if err != nil {
		return nil, fmt.Errorf("failed to get topic files: %w", err)
	}

	// Stat every file once up front. Doing it inside the sort comparator
	// repeats the syscall for each comparison and, with errors ignored,
	// dereferences a nil FileInfo when a file disappears in the meantime.
	type datedFile struct {
		path    string
		modTime time.Time
	}
	dated := make([]datedFile, 0, len(files))
	for _, path := range files {
		info, statErr := os.Stat(path)
		if statErr != nil {
			continue // the file vanished or is unreadable; skip it
		}
		dated = append(dated, datedFile{path: path, modTime: info.ModTime()})
	}

	// Newest file first.
	sort.Slice(dated, func(i, j int) bool {
		return dated[i].modTime.After(dated[j].modTime)
	})

	var messages []*Message
	for _, candidate := range dated {
		fileMessages, readErr := fpl.readMessagesFromFile(candidate.path)
		if readErr != nil {
			continue
		}
		for _, msg := range fileMessages {
			messages = append(messages, msg)
			if limit > 0 && len(messages) >= limit {
				return messages, nil
			}
		}
	}

	return messages, nil
}
// Cleanup deletes every data file whose modification time lies before
// now-maxAge and then removes topic directories that have become empty.
// Removal works per file, not per message, and individual failures are
// ignored. It returns an error only when the topic directories cannot be
// listed.
func (fpl *FilePersistenceLayer) Cleanup(maxAge time.Duration) error {
	fpl.mu.Lock()
	defer fpl.mu.Unlock()

	cutoff := time.Now().Add(-maxAge)

	dirs, err := fpl.getTopicDirectories()
	if err != nil {
		return fmt.Errorf("failed to get topic directories: %w", err)
	}

	for _, dir := range dirs {
		paths, listErr := fpl.getTopicFiles(dir)
		if listErr != nil {
			continue
		}

		for _, path := range paths {
			info, statErr := os.Stat(path)
			if statErr == nil && info.ModTime().Before(cutoff) {
				os.Remove(path)
			}
		}

		// Drop the directory once nothing is left in it.
		empty, _ := fpl.isDirectoryEmpty(dir)
		if empty {
			os.Remove(dir)
		}
	}

	return nil
}
// GetMetrics walks basePath and reports the number of non-directory entries
// and their total size. The remaining metric fields stay at their zero
// values. The walk error, if any, is returned together with whatever was
// counted up to that point.
func (fpl *FilePersistenceLayer) GetMetrics() (PersistenceMetrics, error) {
	fpl.mu.RLock()
	defer fpl.mu.RUnlock()

	var (
		fileCount  int
		totalBytes int64
	)
	visit := func(path string, info os.FileInfo, walkErr error) error {
		if walkErr != nil {
			return walkErr
		}
		if info.IsDir() {
			return nil
		}
		fileCount++
		totalBytes += info.Size()
		return nil
	}

	err := filepath.Walk(fpl.basePath, visit)
	return PersistenceMetrics{FileCount: fileCount, StorageSize: totalBytes}, err
}
// Private helper methods
// getWritableFile picks the file in topicDir that the next record of
// dataSize bytes should be appended to: the first existing data file that
// stays within maxFileSize after the append, or otherwise a new
// "messages_<timestamp>.dat" path named after the current time (second
// resolution). The new file itself is not created here.
func (fpl *FilePersistenceLayer) getWritableFile(topicDir string, dataSize int) (string, error) {
	candidates, err := fpl.getTopicFiles(topicDir)
	if err != nil {
		return "", err
	}

	needed := int64(dataSize)
	for _, path := range candidates {
		info, statErr := os.Stat(path)
		if statErr != nil {
			continue
		}
		if info.Size()+needed <= fpl.maxFileSize {
			return path, nil
		}
	}

	// Nothing has room left: start a new file.
	stamp := time.Now().Format("20060102_150405")
	return filepath.Join(topicDir, fmt.Sprintf("messages_%s.dat", stamp)), nil
}
// getTopicDirectories returns the full paths of the directories located
// directly under basePath.
func (fpl *FilePersistenceLayer) getTopicDirectories() ([]string, error) {
	entries, err := os.ReadDir(fpl.basePath)
	if err != nil {
		return nil, err
	}

	var dirs []string
	for i := range entries {
		if !entries[i].IsDir() {
			continue
		}
		dirs = append(dirs, filepath.Join(fpl.basePath, entries[i].Name()))
	}
	return dirs, nil
}
// getTopicFiles returns the full paths of the regular entries in topicDir
// whose name ends in ".dat".
func (fpl *FilePersistenceLayer) getTopicFiles(topicDir string) ([]string, error) {
	entries, err := os.ReadDir(topicDir)
	if err != nil {
		return nil, err
	}

	var files []string
	for i := range entries {
		name := entries[i].Name()
		if entries[i].IsDir() || filepath.Ext(name) != ".dat" {
			continue
		}
		files = append(files, filepath.Join(topicDir, name))
	}
	return files, nil
}
// findMessageInFile reads all messages stored in filename and returns the
// first one whose ID equals messageID. It returns (nil, nil) when the file
// holds no such message and a non-nil error when the file cannot be read or
// parsed.
func (fpl *FilePersistenceLayer) findMessageInFile(filename, messageID string) (*Message, error) {
	// readMessagesFromFile performs the same open / read / parse sequence.
	stored, err := fpl.readMessagesFromFile(filename)
	if err != nil {
		return nil, err
	}

	for i := range stored {
		if candidate := stored[i]; candidate.ID == messageID {
			return candidate, nil
		}
	}
	return nil, nil
}
// readMessagesFromFile loads the whole file into memory and parses it into
// messages with parseFileData.
func (fpl *FilePersistenceLayer) readMessagesFromFile(filename string) ([]*Message, error) {
	contents, err := os.ReadFile(filename)
	if err != nil {
		return nil, err
	}
	return fpl.parseFileData(contents)
}
// parseFileData decodes the records contained in data.
//
// Each record is a decimal length line ("<n>\n") followed by n payload
// bytes. The payload is decompressed and decrypted according to the current
// settings (the reverse of the write path) and then unmarshaled as a
// PersistedMessage. Parsing stops at the first malformed or truncated length
// prefix. Records whose payload cannot be decoded, or that carry no message,
// are skipped. The returned error is always nil.
func (fpl *FilePersistenceLayer) parseFileData(data []byte) ([]*Message, error) {
	var messages []*Message

	offset := 0
	for offset < len(data) {
		// Locate the newline that terminates the length prefix.
		newline := bytes.IndexByte(data[offset:], '\n')
		if newline == -1 {
			break // no more complete records
		}
		lengthEnd := offset + newline

		var messageLength int
		if _, err := fmt.Sscanf(string(data[offset:lengthEnd]), "%d", &messageLength); err != nil {
			break // invalid length prefix
		}

		messageStart := lengthEnd + 1
		// Reject negative lengths, which would make the slice expression
		// below panic, and compare against the bytes actually remaining
		// instead of adding, so a huge length cannot overflow.
		if messageLength < 0 || messageLength > len(data)-messageStart {
			break // corrupt or incomplete record
		}
		messageEnd := messageStart + messageLength
		messageData := data[messageStart:messageEnd]

		// From here on the record is consumed whatever happens to it.
		offset = messageEnd

		if fpl.compression {
			decompressed, err := fpl.decompress(messageData)
			if err != nil {
				continue
			}
			messageData = decompressed
		}

		if fpl.encryption.Enabled {
			decrypted, err := fpl.decrypt(messageData)
			if err != nil {
				continue
			}
			messageData = decrypted
		}

		var persistedMsg PersistedMessage
		if err := json.Unmarshal(messageData, &persistedMsg); err != nil {
			continue
		}
		if persistedMsg.Message == nil {
			// An envelope without a message would hand callers a nil pointer
			// that they dereference (for example to compare IDs).
			continue
		}
		messages = append(messages, persistedMsg.Message)
	}

	return messages, nil
}
// isDirectoryEmpty reports whether dir has no entries. When the directory
// cannot be read it returns false together with the error.
func (fpl *FilePersistenceLayer) isDirectoryEmpty(dir string) (bool, error) {
	entries, err := os.ReadDir(dir)
	if err != nil {
		return false, err
	}
	empty := len(entries) == 0
	return empty, nil
}
// encrypt seals data with AES-GCM under the configured key and returns the
// random 12-byte nonce followed by the ciphertext. When encryption is
// disabled or no key is configured, data is returned unchanged.
func (fpl *FilePersistenceLayer) encrypt(data []byte) ([]byte, error) {
	key := fpl.encryption.Key
	if !fpl.encryption.Enabled || len(key) == 0 {
		return data, nil
	}

	const nonceSize = 12 // GCM standard nonce size

	block, err := aes.NewCipher(key)
	if err != nil {
		return nil, fmt.Errorf("failed to create cipher: %w", err)
	}

	nonce := make([]byte, nonceSize)
	if _, err := io.ReadFull(rand.Reader, nonce); err != nil {
		return nil, fmt.Errorf("failed to generate nonce: %w", err)
	}

	gcm, err := cipher.NewGCM(block)
	if err != nil {
		return nil, fmt.Errorf("failed to create GCM: %w", err)
	}

	// Seal appends the ciphertext to its first argument, so passing the nonce
	// as destination yields nonce||ciphertext in one step.
	return gcm.Seal(nonce, nonce, data, nil), nil
}
// decrypt reverses encrypt: it splits data into the leading 12-byte nonce
// and the ciphertext and opens it with AES-GCM under the configured key.
// When encryption is disabled or no key is configured, data is returned
// unchanged. It returns an error if data is shorter than a nonce, the cipher
// cannot be set up, or authentication fails.
func (fpl *FilePersistenceLayer) decrypt(data []byte) ([]byte, error) {
	key := fpl.encryption.Key
	if !fpl.encryption.Enabled || len(key) == 0 {
		return data, nil
	}

	const nonceSize = 12
	if len(data) < nonceSize {
		return nil, fmt.Errorf("encrypted data too short")
	}

	block, err := aes.NewCipher(key)
	if err != nil {
		return nil, fmt.Errorf("failed to create cipher: %w", err)
	}
	gcm, err := cipher.NewGCM(block)
	if err != nil {
		return nil, fmt.Errorf("failed to create GCM: %w", err)
	}

	plaintext, err := gcm.Open(nil, data[:nonceSize], data[nonceSize:], nil)
	if err != nil {
		return nil, fmt.Errorf("decryption failed: %w", err)
	}
	return plaintext, nil
}
// compress returns data encoded as a single gzip stream.
func (fpl *FilePersistenceLayer) compress(data []byte) ([]byte, error) {
	var out bytes.Buffer
	zw := gzip.NewWriter(&out)

	_, writeErr := zw.Write(data)
	if writeErr != nil {
		zw.Close()
		return nil, fmt.Errorf("compression failed: %w", writeErr)
	}

	// Close flushes the remaining compressed bytes into the buffer.
	closeErr := zw.Close()
	if closeErr != nil {
		return nil, fmt.Errorf("failed to close gzip writer: %w", closeErr)
	}
	return out.Bytes(), nil
}
// decompress inflates a gzip stream and returns the decoded bytes.
func (fpl *FilePersistenceLayer) decompress(data []byte) ([]byte, error) {
	zr, err := gzip.NewReader(bytes.NewReader(data))
	if err != nil {
		return nil, fmt.Errorf("failed to create gzip reader: %w", err)
	}
	defer zr.Close()

	plain, readErr := io.ReadAll(zr)
	if readErr != nil {
		return nil, fmt.Errorf("decompression failed: %w", readErr)
	}
	return plain, nil
}
// deleteMessageFromFile drops every message whose ID equals messageID from
// filename. The file is deleted when no message remains and rewritten with
// the remaining messages otherwise. It returns an error if the file cannot
// be read, does not contain the message, or cannot be removed or rewritten.
func (fpl *FilePersistenceLayer) deleteMessageFromFile(filename, messageID string) error {
	stored, err := fpl.readMessagesFromFile(filename)
	if err != nil {
		return fmt.Errorf("failed to read messages from file: %w", err)
	}

	var remaining []*Message
	for i := range stored {
		if stored[i].ID == messageID {
			continue
		}
		remaining = append(remaining, stored[i])
	}

	// Nothing was filtered out, so the message is not in this file.
	if len(remaining) == len(stored) {
		return fmt.Errorf("message not found in file")
	}

	if len(remaining) > 0 {
		return fpl.rewriteFileWithMessages(filename, remaining)
	}
	return os.Remove(filename)
}
// rewriteFileWithMessages replaces the contents of filename with messages.
//
// The records are first written to "<filename>.tmp" in the same format Store
// uses (JSON envelope, optional encryption, optional compression, length
// prefix) and the temp file is then renamed over the original. On any
// failure the temp file is closed and removed and the original is left
// untouched.
//
// NOTE(review): each envelope gets the current time as Stored, because the
// read path only hands back the inner Message and the original storage time
// is not available here.
func (fpl *FilePersistenceLayer) rewriteFileWithMessages(filename string, messages []*Message) error {
	tempFile := filename + ".tmp"

	// Same 0600 mode Store uses; os.Create would create the replacement with
	// 0666 (before umask) and thereby widen access to stored messages.
	file, err := os.OpenFile(tempFile, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600)
	if err != nil {
		return fmt.Errorf("failed to create temp file: %w", err)
	}

	// abort closes the temp file before removing it (removing an open file
	// fails on some platforms) and passes the error through.
	abort := func(cause error) error {
		file.Close()
		os.Remove(tempFile)
		return cause
	}

	for _, msg := range messages {
		if msg == nil {
			continue // nothing to persist; avoids a nil dereference below
		}

		persistedMsg := &PersistedMessage{
			ID:       msg.ID,
			Topic:    msg.Topic,
			Message:  msg,
			Stored:   time.Now(),
			Metadata: make(map[string]interface{}),
			// Set before marshaling so the flag reaches the disk; mirrors the
			// condition under which encrypt really transforms the data.
			Encrypted: fpl.encryption.Enabled && len(fpl.encryption.Key) > 0,
		}

		data, err := json.Marshal(persistedMsg)
		if err != nil {
			return abort(fmt.Errorf("failed to marshal message: %w", err))
		}

		// Encrypt first, then compress, matching the order the read path
		// reverses.
		if fpl.encryption.Enabled {
			encryptedData, err := fpl.encrypt(data)
			if err != nil {
				return abort(fmt.Errorf("encryption failed: %w", err))
			}
			data = encryptedData
		}
		if fpl.compression {
			compressedData, err := fpl.compress(data)
			if err != nil {
				return abort(fmt.Errorf("compression failed: %w", err))
			}
			data = compressedData
		}

		lengthPrefix := fmt.Sprintf("%d\n", len(data))
		if _, err := file.WriteString(lengthPrefix); err != nil {
			return abort(fmt.Errorf("failed to write length prefix: %w", err))
		}
		if _, err := file.Write(data); err != nil {
			return abort(fmt.Errorf("failed to write data: %w", err))
		}
	}

	// Close can report a delayed write failure; renaming a possibly
	// incomplete temp file over the original would lose data.
	if err := file.Close(); err != nil {
		os.Remove(tempFile)
		return fmt.Errorf("failed to close temp file: %w", err)
	}

	if err := os.Rename(tempFile, filename); err != nil {
		os.Remove(tempFile)
		return err
	}
	return nil
}
// InMemoryPersistenceLayer implements in-memory persistence for testing/development.
// All state lives in the two maps below; nothing is written to disk.
type InMemoryPersistenceLayer struct {
// messages maps a message ID to the stored message.
messages map[string]*Message
// topics maps a topic name to the IDs stored under it, in the order Store appended them.
topics map[string][]string
// mu guards messages and topics.
mu sync.RWMutex
}
// NewInMemoryPersistenceLayer returns an empty in-memory persistence layer
// with both of its indexes initialized and ready for use.
func NewInMemoryPersistenceLayer() *InMemoryPersistenceLayer {
	layer := new(InMemoryPersistenceLayer)
	layer.messages = map[string]*Message{}
	layer.topics = map[string][]string{}
	return layer
}
// Store records msg under its ID and appends the ID to the index of its
// topic. It always returns nil.
func (impl *InMemoryPersistenceLayer) Store(msg *Message) error {
	impl.mu.Lock()
	defer impl.mu.Unlock()

	impl.messages[msg.ID] = msg

	ids, known := impl.topics[msg.Topic]
	if !known {
		// First message for this topic: start a fresh ID list.
		ids = make([]string, 0)
	}
	impl.topics[msg.Topic] = append(ids, msg.ID)
	return nil
}
// Retrieve looks up the message stored under id. It returns an error when no
// such message exists.
func (impl *InMemoryPersistenceLayer) Retrieve(id string) (*Message, error) {
	impl.mu.RLock()
	defer impl.mu.RUnlock()

	if found, ok := impl.messages[id]; ok {
		return found, nil
	}
	return nil, fmt.Errorf("message not found: %s", id)
}
// Delete removes the message stored under id together with the first
// occurrence of that ID in its topic index. It returns an error when the
// message does not exist.
func (impl *InMemoryPersistenceLayer) Delete(id string) error {
	impl.mu.Lock()
	defer impl.mu.Unlock()

	target, ok := impl.messages[id]
	if !ok {
		return fmt.Errorf("message not found: %s", id)
	}
	delete(impl.messages, id)

	// Drop the ID from the topic index. Ranging over a missing topic is a
	// no-op, so no separate existence check is needed.
	ids := impl.topics[target.Topic]
	for pos := range ids {
		if ids[pos] != id {
			continue
		}
		// Shift the tail left by one and shrink the slice.
		copy(ids[pos:], ids[pos+1:])
		impl.topics[target.Topic] = ids[:len(ids)-1]
		break
	}
	return nil
}
// List returns the messages stored under topic in index order. A positive
// limit caps the number of messages returned; a limit of zero or less means
// no cap. An unknown topic yields an empty slice.
func (impl *InMemoryPersistenceLayer) List(topic string, limit int) ([]*Message, error) {
	impl.mu.RLock()
	defer impl.mu.RUnlock()

	ids, known := impl.topics[topic]
	if !known {
		return []*Message{}, nil
	}

	var result []*Message
	for _, id := range ids {
		if limit > 0 && len(result) >= limit {
			break
		}
		// IDs whose message is no longer present are skipped.
		if m, ok := impl.messages[id]; ok {
			result = append(result, m)
		}
	}
	return result, nil
}
// Cleanup removes every message whose Timestamp is older than maxAge,
// together with its entry in the topic index. It always returns nil.
//
// The removal is done inline under a single write lock. Calling Delete from
// here would try to acquire impl.mu a second time, and sync.RWMutex is not
// reentrant, so that would deadlock on the first expired message.
func (impl *InMemoryPersistenceLayer) Cleanup(maxAge time.Duration) error {
	impl.mu.Lock()
	defer impl.mu.Unlock()

	cutoff := time.Now().Add(-maxAge)

	// Pass 1: drop expired messages and remember which topics were affected.
	// Deleting map entries while ranging over the map is safe in Go.
	touched := make(map[string]struct{})
	for id, msg := range impl.messages {
		if msg.Timestamp.Before(cutoff) {
			touched[msg.Topic] = struct{}{}
			delete(impl.messages, id)
		}
	}

	// Pass 2: filter each affected topic index in place, keeping only IDs
	// that still refer to a stored message.
	for topic := range touched {
		ids := impl.topics[topic]
		kept := ids[:0]
		for _, id := range ids {
			if _, alive := impl.messages[id]; alive {
				kept = append(kept, id)
			}
		}
		impl.topics[topic] = kept
	}
	return nil
}

View File

@@ -0,0 +1,560 @@
package transport
import (
"context"
"fmt"
"net/http"
"os"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/ethereum/go-ethereum/ethclient"
"github.com/ethereum/go-ethereum/rpc"
"golang.org/x/time/rate"
"gopkg.in/yaml.v3"
)
// ProviderConfig represents a single RPC provider configuration
// as read from the providers YAML file.
type ProviderConfig struct {
// Name identifies the provider; validateConfig rejects an empty name.
Name string `yaml:"name"`
Type string `yaml:"type"`
// At least one of HTTPEndpoint / WSEndpoint must be set (see validateConfig).
HTTPEndpoint string `yaml:"http_endpoint"`
WSEndpoint string `yaml:"ws_endpoint"`
// Priority is used by priority-based selection, where the lowest value wins.
Priority int `yaml:"priority"`
// RateLimit configures the per-provider rate limiter and HTTP timeout.
RateLimit RateLimitConfig `yaml:"rate_limit"`
Features []string `yaml:"features"`
HealthCheck HealthCheckConfig `yaml:"health_check"`
AnvilConfig *AnvilConfig `yaml:"anvil_config,omitempty"` // For Anvil fork providers
}
// AnvilConfig defines Anvil-specific configuration
// (optional; attached to a ProviderConfig via its AnvilConfig field).
// NOTE(review): none of these fields are read in the code visible here, so
// units (e.g. of BlockTime and StateInterval) should be confirmed at the use site.
type AnvilConfig struct {
ForkURL string `yaml:"fork_url"`
ChainID int `yaml:"chain_id"`
Port int `yaml:"port"`
BlockTime int `yaml:"block_time"`
AutoImpersonate bool `yaml:"auto_impersonate"`
StateInterval int `yaml:"state_interval"`
}
// RateLimitConfig defines rate limiting parameters
type RateLimitConfig struct {
// RequestsPerSecond is the limiter rate; validateConfig requires it to be > 0.
RequestsPerSecond int `yaml:"requests_per_second"`
// Burst is passed to the limiter as its burst size.
Burst int `yaml:"burst"`
// Timeout is used as the HTTP client timeout when a provider is created.
Timeout time.Duration `yaml:"timeout"`
// RetryDelay and MaxRetries are not referenced in the code visible here.
RetryDelay time.Duration `yaml:"retry_delay"`
MaxRetries int `yaml:"max_retries"`
}
// HealthCheckConfig defines health check parameters
type HealthCheckConfig struct {
// Enabled and Interval are not referenced in the code visible here; the
// health ticker started by the manager uses a fixed one-minute period.
Enabled bool `yaml:"enabled"`
Interval time.Duration `yaml:"interval"`
// Timeout bounds a single health check request.
Timeout time.Duration `yaml:"timeout"`
}
// RotationConfig defines provider rotation strategy
type RotationConfig struct {
// Strategy selects the algorithm used by GetHealthyProvider: "round_robin",
// "weighted" or "priority_based". Any other value falls back to round robin.
Strategy string `yaml:"strategy"`
// HealthCheckRequired makes providers marked unhealthy ineligible for selection.
HealthCheckRequired bool `yaml:"health_check_required"`
// FallbackEnabled and RetryFailedAfter are not referenced in the code visible here.
FallbackEnabled bool `yaml:"fallback_enabled"`
RetryFailedAfter time.Duration `yaml:"retry_failed_after"`
}
// ProviderPoolConfig defines configuration for a provider pool
// NOTE(review): no code visible here reads these fields. HealthCheckInterval
// is a plain string here, unlike the time.Duration fields of the other config
// structs — presumably parsed by the consumer; confirm at the use site.
type ProviderPoolConfig struct {
Strategy string `yaml:"strategy"`
MaxConcurrentConnections int `yaml:"max_concurrent_connections"`
HealthCheckInterval string `yaml:"health_check_interval"`
FailoverEnabled bool `yaml:"failover_enabled"`
// Providers presumably lists provider names belonging to the pool — TODO confirm.
Providers []string `yaml:"providers"`
}
// ProvidersConfig represents the complete provider configuration
// (the top-level document of the providers YAML file).
type ProvidersConfig struct {
ProviderPools map[string]ProviderPoolConfig `yaml:"provider_pools"`
// Providers must contain at least one entry (enforced by validateConfig).
Providers []ProviderConfig `yaml:"providers"`
Rotation RotationConfig `yaml:"rotation"`
GlobalLimits GlobalLimits `yaml:"global_limits"`
Monitoring MonitoringConfig `yaml:"monitoring"`
}
// GlobalLimits defines global connection limits
// NOTE(review): these values are parsed from YAML but are not applied by any
// code visible here.
type GlobalLimits struct {
MaxConcurrentConnections int `yaml:"max_concurrent_connections"`
ConnectionTimeout time.Duration `yaml:"connection_timeout"`
ReadTimeout time.Duration `yaml:"read_timeout"`
WriteTimeout time.Duration `yaml:"write_timeout"`
IdleTimeout time.Duration `yaml:"idle_timeout"`
}
// MonitoringConfig defines monitoring settings
type MonitoringConfig struct {
// Enabled gates both background loops (health checks and metrics) of the ProviderManager.
Enabled bool `yaml:"enabled"`
// MetricsInterval is the period of the metrics ticker.
MetricsInterval time.Duration `yaml:"metrics_interval"`
// The remaining fields are not referenced in the code visible here.
LogSlowRequests bool `yaml:"log_slow_requests"`
SlowRequestThreshold time.Duration `yaml:"slow_request_threshold"`
TrackProviderPerformance bool `yaml:"track_provider_performance"`
}
// Provider represents an active RPC provider connection
// together with its rate limiter and health/usage statistics.
type Provider struct {
Config ProviderConfig
// HTTPClient / WSClient wrap HTTPConn / WSConn; each pair is nil when the
// corresponding endpoint is not configured (or, for WS, failed to connect).
HTTPClient *ethclient.Client
WSClient *ethclient.Client
RateLimiter *rate.Limiter
HTTPConn *rpc.Client
WSConn *rpc.Client
// IsHealthy, LastHealthCheck and AvgResponseTime are written by the health
// check while holding mutex.
IsHealthy bool
LastHealthCheck time.Time
// RequestCount and ErrorCount are updated with sync/atomic; use the
// Increment*/Get* methods rather than touching them directly.
RequestCount int64
ErrorCount int64
AvgResponseTime time.Duration
mutex sync.RWMutex
}
// ProviderManager manages multiple RPC providers with rotation and failover
type ProviderManager struct {
// providers holds the successfully initialized providers.
providers []*Provider
config ProvidersConfig
// currentProvider is the index at which the next round-robin scan starts.
currentProvider int
mutex sync.RWMutex
// healthTicker and metricsTicker drive the background loops; both stay nil
// when monitoring is disabled.
healthTicker *time.Ticker
metricsTicker *time.Ticker
// stopChan is closed by Close to stop the background loops.
stopChan chan struct{}
}
// NewProviderManager builds a provider manager from the YAML configuration
// at configPath: it loads and validates the configuration, connects the
// configured providers and launches the background tasks.
func NewProviderManager(configPath string) (*ProviderManager, error) {
	cfg, err := LoadProvidersConfig(configPath)
	if err != nil {
		return nil, fmt.Errorf("failed to load provider config: %w", err)
	}

	manager := &ProviderManager{config: cfg, stopChan: make(chan struct{})}
	if err = manager.initializeProviders(); err != nil {
		return nil, fmt.Errorf("failed to initialize providers: %w", err)
	}

	// Health checks and metrics collection run until Close is called.
	manager.startBackgroundTasks()
	return manager, nil
}
// LoadProvidersConfig loads provider configuration from the YAML file at
// path. $VAR and ${VAR} references in the file are replaced with values from
// the environment before parsing, and the parsed result is validated.
//
// It returns an error when the file cannot be read, when it references an
// environment variable that is not set, when the YAML is malformed, or when
// validation fails.
func LoadProvidersConfig(path string) (ProvidersConfig, error) {
	var config ProvidersConfig

	// Read the YAML file
	data, err := os.ReadFile(path)
	if err != nil {
		return config, fmt.Errorf("failed to read config file %s: %w", path, err)
	}

	// Expand environment references, recording every variable that is not
	// set. os.ExpandEnv cannot be used for this: it silently substitutes the
	// empty string for unset variables, so nothing is left behind to detect.
	var missing []string
	expanded := os.Expand(string(data), func(name string) string {
		value, ok := os.LookupEnv(name)
		if !ok {
			missing = append(missing, name)
		}
		return value
	})
	if len(missing) > 0 {
		return config, fmt.Errorf("unresolved environment variables found in provider config %s: %s",
			path, strings.Join(missing, ", "))
	}
	if strings.Contains(expanded, "${") {
		return config, fmt.Errorf("unresolved environment variables found in provider config %s", path)
	}

	// Unmarshal the YAML data
	if err := yaml.Unmarshal([]byte(expanded), &config); err != nil {
		return config, fmt.Errorf("failed to parse YAML config: %w", err)
	}

	// Validate the configuration
	if err := validateConfig(&config); err != nil {
		return config, fmt.Errorf("invalid configuration: %w", err)
	}
	return config, nil
}
// validateConfig checks that at least one provider is configured and that
// every provider has a name, at least one endpoint and a positive
// requests-per-second rate. The first violation found is returned.
func validateConfig(config *ProvidersConfig) error {
	if len(config.Providers) == 0 {
		return fmt.Errorf("no providers configured")
	}

	for idx := range config.Providers {
		p := &config.Providers[idx]
		switch {
		case p.Name == "":
			return fmt.Errorf("provider %d has no name", idx)
		case p.HTTPEndpoint == "" && p.WSEndpoint == "":
			return fmt.Errorf("provider %s has no endpoints", p.Name)
		case p.RateLimit.RequestsPerSecond <= 0:
			return fmt.Errorf("provider %s has invalid rate limit", p.Name)
		}
	}
	return nil
}
// initializeProviders sets up all configured providers. A provider that
// fails to initialize is reported as a warning and skipped so the remaining
// providers can still be used. An error is returned only when no provider
// could be initialized; it wraps the most recent failure.
func (pm *ProviderManager) initializeProviders() error {
	pm.providers = make([]*Provider, 0, len(pm.config.Providers))

	var lastErr error
	for _, providerConfig := range pm.config.Providers {
		provider, err := createProvider(providerConfig)
		if err != nil {
			// Report the failure (same style as the WebSocket warning in
			// createProvider) instead of dropping it silently.
			fmt.Printf("Warning: skipping provider %s: %v\n", providerConfig.Name, err)
			lastErr = err
			continue
		}
		pm.providers = append(pm.providers, provider)
	}

	if len(pm.providers) == 0 {
		if lastErr != nil {
			return fmt.Errorf("no providers successfully initialized: %w", lastErr)
		}
		return fmt.Errorf("no providers successfully initialized")
	}
	return nil
}
// createProvider creates a new provider instance (shared utility function).
//
// It builds the rate limiter from config.RateLimit and connects to the
// configured endpoints. A failing HTTP endpoint is an error; a failing
// WebSocket endpoint only produces a warning, because the HTTP connection
// may still be usable. The provider starts out marked healthy.
func createProvider(config ProviderConfig) (*Provider, error) {
	// A limiter with burst 0 never allows any event (unless the rate is
	// infinite), which would make the provider permanently unusable. When no
	// burst is configured, allow one second's worth of requests, minimum 1.
	burst := config.RateLimit.Burst
	if burst <= 0 {
		burst = config.RateLimit.RequestsPerSecond
	}
	if burst <= 0 {
		burst = 1
	}

	rateLimiter := rate.NewLimiter(
		rate.Limit(config.RateLimit.RequestsPerSecond),
		burst,
	)

	provider := &Provider{
		Config:      config,
		RateLimiter: rateLimiter,
		IsHealthy:   true, // Assume healthy until proven otherwise
	}

	// Initialize HTTP connection
	if config.HTTPEndpoint != "" {
		httpClient := &http.Client{
			Timeout: config.RateLimit.Timeout, // Use config timeout
		}
		rpcClient, err := rpc.DialHTTPWithClient(config.HTTPEndpoint, httpClient)
		if err != nil {
			return nil, fmt.Errorf("failed to connect to HTTP endpoint %s: %w", config.HTTPEndpoint, err)
		}
		provider.HTTPConn = rpcClient
		provider.HTTPClient = ethclient.NewClient(rpcClient)
	}

	// Initialize WebSocket connection
	if config.WSEndpoint != "" {
		wsClient, err := rpc.DialWebsocket(context.Background(), config.WSEndpoint, "")
		if err != nil {
			// Don't fail if WS connection fails, HTTP might still work
			fmt.Printf("Warning: failed to connect to WebSocket endpoint %s: %v\n", config.WSEndpoint, err)
		} else {
			provider.WSConn = wsClient
			provider.WSClient = ethclient.NewClient(wsClient)
		}
	}

	return provider, nil
}
// GetHealthyProvider returns the next usable provider according to the
// configured rotation strategy ("round_robin", "weighted" or
// "priority_based"; anything else is treated as round robin). It returns an
// error when no provider is configured or none is currently usable.
//
// The exclusive lock is taken on purpose: round-robin selection advances
// pm.currentProvider, so holding only a read lock would let concurrent
// callers write that field at the same time.
func (pm *ProviderManager) GetHealthyProvider() (*Provider, error) {
	pm.mutex.Lock()
	defer pm.mutex.Unlock()

	if len(pm.providers) == 0 {
		return nil, fmt.Errorf("no providers available")
	}

	switch pm.config.Rotation.Strategy {
	case "weighted":
		return pm.getWeightedProvider()
	case "priority_based":
		return pm.getPriorityProvider()
	default: // "round_robin" and any unrecognized strategy
		return pm.getNextRoundRobin()
	}
}
// getNextRoundRobin scans the providers once, starting at the current
// rotation position, and returns the first usable one. On success the
// rotation position moves to the slot after the chosen provider.
func (pm *ProviderManager) getNextRoundRobin() (*Provider, error) {
	total := len(pm.providers)
	first := pm.currentProvider

	for offset := 0; offset < total; offset++ {
		slot := (first + offset) % total
		candidate := pm.providers[slot]
		if !pm.isProviderUsable(candidate) {
			continue
		}
		pm.currentProvider = (slot + 1) % total
		return candidate, nil
	}
	return nil, fmt.Errorf("no healthy providers available")
}
// getPriorityProvider returns the usable provider with the numerically
// lowest Priority value. Among equal priorities the one configured first
// wins. An error is returned when no provider is usable.
func (pm *ProviderManager) getPriorityProvider() (*Provider, error) {
	const maxInt = int(^uint(0) >> 1)

	var best *Provider
	bestPriority := maxInt
	for _, candidate := range pm.providers {
		// Usability is always evaluated first, as it consumes a rate-limit token.
		if !pm.isProviderUsable(candidate) {
			continue
		}
		if candidate.Config.Priority >= bestPriority {
			continue
		}
		best, bestPriority = candidate, candidate.Config.Priority
	}

	if best != nil {
		return best, nil
	}
	return nil, fmt.Errorf("no healthy providers available")
}
// getWeightedProvider is the entry point for performance-weighted selection.
// It currently delegates to priority-based selection; a complete version
// would take response times and success rates into account.
func (pm *ProviderManager) getWeightedProvider() (*Provider, error) {
	chosen, err := pm.getPriorityProvider()
	return chosen, err
}
// isProviderUsable reports whether provider may serve a request right now:
// it must be healthy when the rotation config requires health checks, and
// its rate limiter must grant a token. Note that a successful check consumes
// that token.
func (pm *ProviderManager) isProviderUsable(provider *Provider) bool {
	provider.mutex.RLock()
	defer provider.mutex.RUnlock()

	rejectedAsUnhealthy := pm.config.Rotation.HealthCheckRequired && !provider.IsHealthy
	if rejectedAsUnhealthy {
		return false
	}
	return provider.RateLimiter.Allow()
}
// GetHTTPClient selects a usable provider and returns its HTTP client. It
// fails when no provider is usable or the selected one has no HTTP client.
func (pm *ProviderManager) GetHTTPClient() (*ethclient.Client, error) {
	provider, err := pm.GetHealthyProvider()
	if err != nil {
		return nil, err
	}
	if client := provider.HTTPClient; client != nil {
		return client, nil
	}
	return nil, fmt.Errorf("provider %s has no HTTP client", provider.Config.Name)
}
// GetWSClient selects a usable provider and returns its WebSocket client. It
// fails when no provider is usable or the selected one has no WebSocket
// client.
func (pm *ProviderManager) GetWSClient() (*ethclient.Client, error) {
	provider, err := pm.GetHealthyProvider()
	if err != nil {
		return nil, err
	}
	if client := provider.WSClient; client != nil {
		return client, nil
	}
	return nil, fmt.Errorf("provider %s has no WebSocket client", provider.Config.Name)
}
// GetRPCClient selects a usable provider and returns one of its raw RPC
// connections for advanced operations. The WebSocket connection is returned
// when preferWS is set and it exists; otherwise the HTTP connection is used.
func (pm *ProviderManager) GetRPCClient(preferWS bool) (*rpc.Client, error) {
	provider, err := pm.GetHealthyProvider()
	if err != nil {
		return nil, err
	}

	switch {
	case preferWS && provider.WSConn != nil:
		return provider.WSConn, nil
	case provider.HTTPConn != nil:
		return provider.HTTPConn, nil
	}
	return nil, fmt.Errorf("provider %s has no available RPC client", provider.Config.Name)
}
// startBackgroundTasks starts health checking and metrics collection. Both
// loops run only when monitoring is enabled in the configuration and stop
// when pm.stopChan is closed.
func (pm *ProviderManager) startBackgroundTasks() {
	if !pm.config.Monitoring.Enabled {
		return
	}

	pm.healthTicker = time.NewTicker(time.Minute) // Default 1 minute
	go pm.healthCheckLoop()

	// time.NewTicker panics on a non-positive duration, which is exactly what
	// an omitted metrics_interval yields; fall back to one minute.
	metricsInterval := pm.config.Monitoring.MetricsInterval
	if metricsInterval <= 0 {
		metricsInterval = time.Minute
	}
	pm.metricsTicker = time.NewTicker(metricsInterval)
	go pm.metricsLoop()
}
// healthCheckLoop runs a round of health checks on every tick of the health
// ticker and returns once the stop channel is closed.
func (pm *ProviderManager) healthCheckLoop() {
	ticks := pm.healthTicker.C
	for {
		select {
		case <-pm.stopChan:
			return
		case <-ticks:
			pm.performHealthChecks()
		}
	}
}
// metricsLoop collects provider metrics on every tick of the metrics ticker
// and returns once the stop channel is closed.
func (pm *ProviderManager) metricsLoop() {
	ticks := pm.metricsTicker.C
	for {
		select {
		case <-pm.stopChan:
			return
		case <-ticks:
			pm.collectMetrics()
		}
	}
}
// performHealthChecks launches one health check per provider, each in its
// own goroutine, and returns without waiting for them to finish.
func (pm *ProviderManager) performHealthChecks() {
	for idx := range pm.providers {
		go pm.checkProviderHealth(pm.providers[idx])
	}
}
// checkProviderHealth probes a single provider by requesting the latest
// block number, preferring the HTTP client over the WebSocket client, and
// records the outcome through performProviderHealthCheck.
func (pm *ProviderManager) checkProviderHealth(provider *Provider) {
	probe := func(ctx context.Context, p *Provider) error {
		var err error
		switch {
		case p.HTTPClient != nil:
			_, err = p.HTTPClient.BlockNumber(ctx)
		case p.WSClient != nil:
			_, err = p.WSClient.BlockNumber(ctx)
		default:
			err = fmt.Errorf("no client available for health check")
		}
		return err
	}
	pm.performProviderHealthCheck(provider, probe)
}
// performProviderHealthCheck runs healthChecker against provider under a
// timeout and records the result: the request/error counters are updated
// atomically, while IsHealthy, LastHealthCheck and AvgResponseTime are
// updated under the provider's mutex.
//
// The timeout comes from provider.Config.HealthCheck.Timeout. When that is
// not positive (e.g. omitted from the YAML) a default is used; otherwise the
// context would already be expired and every check would fail.
func (pm *ProviderManager) performProviderHealthCheck(provider *Provider, healthChecker func(context.Context, *Provider) error) {
	const defaultHealthCheckTimeout = 5 * time.Second

	timeout := provider.Config.HealthCheck.Timeout
	if timeout <= 0 {
		timeout = defaultHealthCheckTimeout
	}
	ctx, cancel := context.WithTimeout(context.Background(), timeout)
	defer cancel()

	start := time.Now()
	err := healthChecker(ctx, provider)
	duration := time.Since(start)

	// Counters are atomics so they can be read without the provider mutex.
	atomic.AddInt64(&provider.RequestCount, 1)

	provider.mutex.Lock()
	defer provider.mutex.Unlock()

	provider.LastHealthCheck = time.Now()
	if err != nil {
		atomic.AddInt64(&provider.ErrorCount, 1)
		provider.IsHealthy = false
	} else {
		provider.IsHealthy = true
	}

	// Update the smoothed response time. The first sample is taken as-is;
	// later samples are weighted at 20% to smooth out spikes.
	if provider.AvgResponseTime == 0 {
		provider.AvgResponseTime = duration
	} else {
		provider.AvgResponseTime = time.Duration(
			float64(provider.AvgResponseTime)*0.8 + float64(duration)*0.2,
		)
	}
}
// IncrementRequestCount atomically adds one to the provider's request
// counter; it does not take the provider mutex.
func (p *Provider) IncrementRequestCount() {
atomic.AddInt64(&p.RequestCount, 1)
}
// IncrementErrorCount atomically adds one to the provider's error counter;
// it does not take the provider mutex.
func (p *Provider) IncrementErrorCount() {
atomic.AddInt64(&p.ErrorCount, 1)
}
// GetRequestCount returns the provider's request counter using an atomic load.
func (p *Provider) GetRequestCount() int64 {
return atomic.LoadInt64(&p.RequestCount)
}
// GetErrorCount returns the provider's error counter using an atomic load.
func (p *Provider) GetErrorCount() int64 {
return atomic.LoadInt64(&p.ErrorCount)
}
// collectMetrics collects performance metrics
// It is invoked on every metrics tick by metricsLoop. The body is currently
// a placeholder and does nothing.
func (pm *ProviderManager) collectMetrics() {
// Implementation would collect and report metrics
// For now, just log basic stats
}
// Close shuts down the provider manager: it signals the background loops to
// stop, stops the tickers and closes every provider connection. It always
// returns nil.
//
// Close is idempotent. Closing an already-closed channel panics, so repeated
// or concurrent calls are serialized on pm.mutex and only the first one does
// any work.
func (pm *ProviderManager) Close() error {
	pm.mutex.Lock()
	defer pm.mutex.Unlock()

	if pm.stopChan != nil {
		select {
		case <-pm.stopChan:
			// Already closed by an earlier call; nothing left to do.
			return nil
		default:
			close(pm.stopChan)
		}
	}

	if pm.healthTicker != nil {
		pm.healthTicker.Stop()
	}
	if pm.metricsTicker != nil {
		pm.metricsTicker.Stop()
	}

	// Close all connections
	for _, provider := range pm.providers {
		if provider.HTTPConn != nil {
			provider.HTTPConn.Close()
		}
		if provider.WSConn != nil {
			provider.WSConn.Close()
		}
	}
	return nil
}
// GetProviderStats returns a snapshot of per-provider statistics keyed by
// provider name. Each entry is a map with the keys "name", "healthy",
// "last_health_check", "request_count", "error_count" and
// "avg_response_time".
func (pm *ProviderManager) GetProviderStats() map[string]interface{} {
	pm.mutex.RLock()
	defer pm.mutex.RUnlock()

	snapshot := make(map[string]interface{})
	for _, p := range pm.providers {
		// Mutex-guarded fields are read under the provider's read lock; the
		// two counters go through their atomic getters.
		p.mutex.RLock()
		entry := map[string]interface{}{
			"name":              p.Config.Name,
			"healthy":           p.IsHealthy,
			"last_health_check": p.LastHealthCheck,
			"request_count":     p.GetRequestCount(),
			"error_count":       p.GetErrorCount(),
			"avg_response_time": p.AvgResponseTime,
		}
		p.mutex.RUnlock()

		snapshot[p.Config.Name] = entry
	}
	return snapshot
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,513 @@
package transport
import (
"crypto/rand"
"fmt"
"math/big"
"sort"
"sync"
"time"
)
// cryptoRandInt draws a uniformly distributed integer from [0,n) using
// crypto/rand. It returns an error when n is not positive or when the random
// source fails.
func cryptoRandInt(n int) (int, error) {
	if n < 1 {
		return 0, fmt.Errorf("n must be positive")
	}

	upper := new(big.Int).SetInt64(int64(n))
	drawn, err := rand.Int(rand.Reader, upper)
	if err != nil {
		return 0, err
	}
	return int(drawn.Int64()), nil
}
// cryptoRandFloat draws a float64 from [0.0,1.0) using crypto/rand. It picks
// a uniform integer in [0, 2^53) and scales it by 2^53, so every result is
// exactly representable.
func cryptoRandFloat() (float64, error) {
	const mantissaBits = 53

	limit := new(big.Int).Lsh(big.NewInt(1), mantissaBits) // 2^53
	drawn, err := rand.Int(rand.Reader, limit)
	if err != nil {
		return 0, err
	}
	return float64(drawn.Int64()) / float64(int64(1)<<mantissaBits), nil
}
// MessageRouter handles intelligent message routing and transport selection
type MessageRouter struct {
// rules is kept sorted by descending Priority (see sortRulesByPriority).
rules []RoutingRule
// fallback is the transport type tried last, when neither a rule nor the
// load balancer produced a transport.
fallback TransportType
// loadBalancer is optional; nil disables load-balanced selection.
loadBalancer LoadBalancer
// mu guards all fields above.
mu sync.RWMutex
}
// RoutingRule defines message routing logic
type RoutingRule struct {
// ID identifies the rule; AddRule generates one when it is empty.
ID string
Name string
// Condition decides whether the rule applies to a message; a nil Condition
// matches every message.
Condition MessageFilter
// Transport is the transport type a matching message is routed to.
Transport TransportType
// Priority orders the rules: higher values are tried first.
Priority int
// Enabled rules are the only ones considered; AddRule always sets it to true.
Enabled bool
// Created is set by AddRule; LastUsed and UsageCount are updated each time
// the rule selects a transport.
Created time.Time
LastUsed time.Time
UsageCount int64
}
// RouteMessage selects the appropriate transport for a message.
//
// Selection happens in three stages:
//  1. matching rules in priority order — the first rule whose transport
//     exists and reports status "healthy" wins and has its usage recorded;
//  2. the load balancer, if one is set, chooses among the healthy transports;
//  3. the fallback transport, accepted unless it reports "unhealthy".
//
// An error is returned when none of the stages yields a transport.
//
// Locking: router state is snapshotted under the read lock, and the write
// lock is taken only for the usage update. Updating rule statistics while
// holding just the read lock would be a data race between concurrent
// callers, and holding any lock across transport health checks is
// unnecessary.
func (mr *MessageRouter) RouteMessage(msg *Message, transports map[TransportType]Transport) (Transport, error) {
	var (
		matchingRules []RoutingRule
		loadBalancer  LoadBalancer
		fallback      TransportType
	)
	func() {
		// Deferred unlock so a panicking rule condition cannot leave the
		// router locked.
		mr.mu.RLock()
		defer mr.mu.RUnlock()
		matchingRules = mr.findMatchingRules(msg) // copies, already sorted by priority
		loadBalancer = mr.loadBalancer
		fallback = mr.fallback
	}()

	// Try each matching rule in priority order
	for _, rule := range matchingRules {
		transport, exists := transports[rule.Transport]
		if !exists {
			continue
		}
		if health := transport.Health(); health.Status != "healthy" {
			continue
		}
		mr.mu.Lock()
		mr.updateRuleUsage(rule.ID)
		mr.mu.Unlock()
		return transport, nil
	}

	// Use load balancer for available transports
	if loadBalancer != nil {
		availableTransports := mr.getHealthyTransports(transports)
		if len(availableTransports) > 0 {
			selectedType := loadBalancer.SelectTransport(availableTransports, msg)
			if transport, exists := transports[selectedType]; exists {
				return transport, nil
			}
		}
	}

	// Fall back to default transport
	if fallbackTransport, exists := transports[fallback]; exists {
		if health := fallbackTransport.Health(); health.Status != "unhealthy" {
			return fallbackTransport, nil
		}
	}

	return nil, fmt.Errorf("no available transport for message")
}
// AddRule registers a routing rule. A rule without an ID receives a
// generated one; the creation time is stamped and the rule is enabled
// regardless of the values passed in. The rule list is then re-sorted by
// priority.
func (mr *MessageRouter) AddRule(rule RoutingRule) {
	mr.mu.Lock()
	defer mr.mu.Unlock()

	if len(rule.ID) == 0 {
		rule.ID = fmt.Sprintf("rule_%d", time.Now().UnixNano())
	}
	rule.Created, rule.Enabled = time.Now(), true

	mr.rules = append(mr.rules, rule)
	mr.sortRulesByPriority()
}
// RemoveRule deletes the first rule whose ID equals ruleID and reports
// whether such a rule was found.
func (mr *MessageRouter) RemoveRule(ruleID string) bool {
	mr.mu.Lock()
	defer mr.mu.Unlock()

	for idx := range mr.rules {
		if mr.rules[idx].ID != ruleID {
			continue
		}
		mr.rules = append(mr.rules[:idx], mr.rules[idx+1:]...)
		return true
	}
	return false
}
// UpdateRule applies updates to the rule identified by ruleID while holding
// the router's write lock, then re-sorts the rules in case the priority
// changed. It reports whether the rule was found.
func (mr *MessageRouter) UpdateRule(ruleID string, updates func(*RoutingRule)) bool {
	mr.mu.Lock()
	defer mr.mu.Unlock()

	for idx := range mr.rules {
		target := &mr.rules[idx]
		if target.ID != ruleID {
			continue
		}
		updates(target)
		mr.sortRulesByPriority()
		return true
	}
	return false
}
// GetRules returns a copy of the current routing rules, so callers may
// modify the result without affecting the router.
func (mr *MessageRouter) GetRules() []RoutingRule {
	mr.mu.RLock()
	defer mr.mu.RUnlock()

	snapshot := make([]RoutingRule, 0, len(mr.rules))
	snapshot = append(snapshot, mr.rules...)
	return snapshot
}
// EnableRule switches the rule identified by ruleID on and reports whether
// the rule exists.
func (mr *MessageRouter) EnableRule(ruleID string) bool {
	turnOn := func(r *RoutingRule) { r.Enabled = true }
	return mr.UpdateRule(ruleID, turnOn)
}
// DisableRule switches the rule identified by ruleID off and reports whether
// the rule exists.
func (mr *MessageRouter) DisableRule(ruleID string) bool {
	turnOff := func(r *RoutingRule) { r.Enabled = false }
	return mr.UpdateRule(ruleID, turnOff)
}
// SetFallbackTransport replaces the transport type used as the last resort
// during routing.
func (mr *MessageRouter) SetFallbackTransport(transportType TransportType) {
	mr.mu.Lock()
	mr.fallback = transportType
	mr.mu.Unlock()
}
// SetLoadBalancer replaces the load balancer consulted during routing.
// Passing nil turns load-balanced selection off.
func (mr *MessageRouter) SetLoadBalancer(lb LoadBalancer) {
	mr.mu.Lock()
	mr.loadBalancer = lb
	mr.mu.Unlock()
}
// Private helper methods

// findMatchingRules returns copies of the enabled rules that apply to msg,
// in the order they are stored. A rule without a condition applies to every
// message. It reads mr.rules without locking; the caller holds mr.mu.
func (mr *MessageRouter) findMatchingRules(msg *Message) []RoutingRule {
	var matched []RoutingRule
	for idx := range mr.rules {
		candidate := mr.rules[idx]
		if !candidate.Enabled {
			continue
		}
		if candidate.Condition != nil && !candidate.Condition(msg) {
			continue
		}
		matched = append(matched, candidate)
	}
	return matched
}
// sortRulesByPriority orders mr.rules in place so that rules with a higher
// Priority come first. The sort is not stable.
func (mr *MessageRouter) sortRulesByPriority() {
	rules := mr.rules
	higherFirst := func(a, b int) bool {
		return rules[b].Priority < rules[a].Priority
	}
	sort.Slice(rules, higherFirst)
}
// updateRuleUsage stamps LastUsed and increments UsageCount on the first
// rule whose ID equals ruleID. It does nothing when no rule matches.
func (mr *MessageRouter) updateRuleUsage(ruleID string) {
	for idx := range mr.rules {
		r := &mr.rules[idx]
		if r.ID != ruleID {
			continue
		}
		r.LastUsed = time.Now()
		r.UsageCount++
		return
	}
}
// getHealthyTransports returns the types of all transports whose health
// status is "healthy". The order follows map iteration and is therefore not
// deterministic.
func (mr *MessageRouter) getHealthyTransports(transports map[TransportType]Transport) []TransportType {
	var usable []TransportType
	for kind, tr := range transports {
		if tr.Health().Status != "healthy" {
			continue
		}
		usable = append(usable, kind)
	}
	return usable
}
// LoadBalancer implementations
// RoundRobinLoadBalancer implements round-robin load balancing
type RoundRobinLoadBalancer struct {
// counter is the number of selections made so far; the next pick is
// counter modulo the number of candidate transports.
counter int64
// mu guards counter.
mu sync.Mutex
}
// NewRoundRobinLoadBalancer returns a round-robin load balancer whose
// counter starts at zero.
func NewRoundRobinLoadBalancer() *RoundRobinLoadBalancer {
	return new(RoundRobinLoadBalancer)
}
// SelectTransport returns the next transport in rotation, cycling through
// transports in slice order. msg is not consulted. An empty slice yields the
// empty transport type.
func (lb *RoundRobinLoadBalancer) SelectTransport(transports []TransportType, msg *Message) TransportType {
	if len(transports) == 0 {
		return ""
	}

	lb.mu.Lock()
	defer lb.mu.Unlock()

	slot := lb.counter % int64(len(transports))
	lb.counter++
	return transports[slot]
}
// UpdateStats is a no-op: round-robin selection ignores latency and success
// information. The method has the same signature as the other load
// balancers' UpdateStats.
func (lb *RoundRobinLoadBalancer) UpdateStats(transport TransportType, latency time.Duration, success bool) {
// Round-robin doesn't use stats
}
// WeightedLoadBalancer implements weighted load balancing based on performance
type WeightedLoadBalancer struct {
// stats holds the accumulated performance figures per transport type.
stats map[TransportType]*TransportStats
// mu guards stats.
mu sync.RWMutex
}
// TransportStats accumulates the request figures the weighted load balancer
// derives a transport's weight from.
type TransportStats struct {
// TotalRequests counts every reported request; SuccessRequests only the successful ones.
TotalRequests int64
SuccessRequests int64
// TotalLatency is the sum of all reported latencies.
TotalLatency time.Duration
// LastUpdate is the time of the most recent UpdateStats call.
LastUpdate time.Time
// Weight caches the result of the last weight calculation.
Weight float64
}
// NewWeightedLoadBalancer returns a weighted load balancer with no recorded
// statistics.
func NewWeightedLoadBalancer() *WeightedLoadBalancer {
	balancer := new(WeightedLoadBalancer)
	balancer.stats = map[TransportType]*TransportStats{}
	return balancer
}
// SelectTransport picks a transport at random, with probability proportional
// to each candidate's current weight. msg is not consulted. An empty slice
// yields the empty transport type; if the random source fails, the first
// candidate is returned.
func (lb *WeightedLoadBalancer) SelectTransport(transports []TransportType, msg *Message) TransportType {
	if len(transports) == 0 {
		return ""
	}

	lb.mu.RLock()
	defer lb.mu.RUnlock()

	// Weights are kept in a slice parallel to transports.
	perTransport := make([]float64, len(transports))
	var sum float64
	for i, t := range transports {
		w := lb.calculateWeight(t)
		perTransport[i] = w
		sum += w
	}

	if sum == 0 {
		// No usable weights: choose uniformly instead.
		if idx, err := cryptoRandInt(len(transports)); err == nil {
			return transports[idx]
		}
		return transports[0]
	}

	fraction, err := cryptoRandFloat()
	if err != nil {
		return transports[0]
	}

	// Walk the cumulative weights until the random threshold is reached.
	threshold := fraction * sum
	var running float64
	for i, t := range transports {
		running += perTransport[i]
		if running >= threshold {
			return t
		}
	}

	// Not expected to be reached.
	return transports[0]
}
// UpdateStats records the outcome of one request on transport and refreshes
// the transport's cached weight. Statistics for a transport are created on
// first use.
func (lb *WeightedLoadBalancer) UpdateStats(transport TransportType, latency time.Duration, success bool) {
	lb.mu.Lock()
	defer lb.mu.Unlock()

	entry, known := lb.stats[transport]
	if !known {
		entry = &TransportStats{Weight: 1.0} // default weight
		lb.stats[transport] = entry
	}

	entry.TotalRequests++
	entry.TotalLatency += latency
	if success {
		entry.SuccessRequests++
	}
	entry.LastUpdate = time.Now()

	// The weight depends on the counters just updated.
	entry.Weight = lb.calculateWeight(transport)
}
// calculateWeight derives a selection weight for transport from its recorded
// statistics: success rate divided by the average latency in milliseconds
// (with the divisor clamped to at least 1), and never less than 0.1.
// Transports without recorded requests get the default weight 1.0. The
// caller holds lb.mu.
func (lb *WeightedLoadBalancer) calculateWeight(transport TransportType) float64 {
	const (
		defaultWeight = 1.0
		minimumWeight = 0.1
	)

	entry, known := lb.stats[transport]
	if !known || entry.TotalRequests == 0 {
		return defaultWeight
	}

	requests := entry.TotalRequests
	successRate := float64(entry.SuccessRequests) / float64(requests)
	meanLatency := entry.TotalLatency / time.Duration(requests)

	// Express latency in milliseconds; sub-millisecond averages count as 1
	// so the success rate is never amplified.
	divisor := float64(meanLatency) / float64(time.Millisecond)
	if divisor < 1 {
		divisor = 1
	}

	result := successRate / divisor
	if result < minimumWeight {
		return minimumWeight
	}
	return result
}
// LeastLatencyLoadBalancer selects the transport with the lowest latency
type LeastLatencyLoadBalancer struct {
// stats holds the recent latency samples per transport type.
stats map[TransportType]*LatencyStats
// mu guards stats.
mu sync.RWMutex
}
// LatencyStats keeps a bounded window of recent latency samples for one
// transport.
type LatencyStats struct {
// RecentLatencies holds the newest samples, oldest first.
RecentLatencies []time.Duration
// MaxSamples is the window size; the oldest sample is dropped once it is exceeded.
MaxSamples int
// LastUpdate is the time the most recent sample was recorded.
LastUpdate time.Time
}
// NewLeastLatencyLoadBalancer returns a least-latency load balancer with no
// recorded samples.
func NewLeastLatencyLoadBalancer() *LeastLatencyLoadBalancer {
	balancer := new(LeastLatencyLoadBalancer)
	balancer.stats = map[TransportType]*LatencyStats{}
	return balancer
}
// SelectTransport returns the candidate with the lowest average recent
// latency; ties go to the candidate listed first. msg is not consulted. An
// empty slice yields the empty transport type.
func (lb *LeastLatencyLoadBalancer) SelectTransport(transports []TransportType, msg *Message) TransportType {
	if len(transports) == 0 {
		return ""
	}

	lb.mu.RLock()
	defer lb.mu.RUnlock()

	// Start from the first candidate with a deliberately large latency bound.
	winner, lowest := transports[0], time.Hour
	for _, candidate := range transports {
		if avg := lb.getAverageLatency(candidate); avg < lowest {
			winner, lowest = candidate, avg
		}
	}
	return winner
}
// UpdateStats records the latency of a successful request on transport,
// keeping only the most recent MaxSamples samples (10 for entries created
// here). Failed requests are ignored.
func (lb *LeastLatencyLoadBalancer) UpdateStats(transport TransportType, latency time.Duration, success bool) {
	if !success {
		return
	}

	lb.mu.Lock()
	defer lb.mu.Unlock()

	entry, known := lb.stats[transport]
	if !known {
		entry = &LatencyStats{
			RecentLatencies: make([]time.Duration, 0),
			MaxSamples:      10,
		}
		lb.stats[transport] = entry
	}

	samples := append(entry.RecentLatencies, latency)
	if len(samples) > entry.MaxSamples {
		// Window is full: forget the oldest sample.
		samples = samples[1:]
	}
	entry.RecentLatencies = samples
	entry.LastUpdate = time.Now()
}
// getAverageLatency returns the mean of the recorded latency samples for
// transport, or a default estimate of 100ms when there are none. The caller
// holds lb.mu.
func (lb *LeastLatencyLoadBalancer) getAverageLatency(transport TransportType) time.Duration {
	entry, known := lb.stats[transport]
	if !known || len(entry.RecentLatencies) == 0 {
		return 100 * time.Millisecond
	}

	samples := entry.RecentLatencies
	var sum time.Duration
	for idx := range samples {
		sum += samples[idx]
	}
	return sum / time.Duration(len(samples))
}
// Common routing rule factory functions

// CreateTopicRule builds a rule that routes messages whose topic equals
// topic exactly to the given transport.
func CreateTopicRule(name string, topic string, transport TransportType, priority int) RoutingRule {
	sameTopic := func(msg *Message) bool {
		return msg.Topic == topic
	}
	return RoutingRule{
		Name:      name,
		Condition: sameTopic,
		Transport: transport,
		Priority:  priority,
	}
}
// CreateTopicPatternRule builds a rule that routes messages by topic
// pattern. A message matches when its topic equals pattern, or when pattern
// ends in '*' and the topic starts with everything before that '*'. No other
// wildcard syntax is supported.
func CreateTopicPatternRule(name string, pattern string, transport TransportType, priority int) RoutingRule {
	matchesPattern := func(msg *Message) bool {
		if msg.Topic == pattern {
			return true
		}
		size := len(pattern)
		if size == 0 || pattern[size-1] != '*' {
			return false
		}
		// Trailing wildcard: compare against the part before the '*'.
		prefix := pattern[:size-1]
		return len(msg.Topic) >= len(prefix) && msg.Topic[:len(prefix)] == prefix
	}
	return RoutingRule{
		Name:      name,
		Condition: matchesPattern,
		Transport: transport,
		Priority:  priority,
	}
}
// CreatePriorityRule builds a rule that routes messages whose priority
// equals msgPriority to the given transport. The priority argument is the
// rule's own ordering priority.
func CreatePriorityRule(name string, msgPriority MessagePriority, transport TransportType, priority int) RoutingRule {
	samePriority := func(msg *Message) bool {
		return msg.Priority == msgPriority
	}
	return RoutingRule{
		Name:      name,
		Condition: samePriority,
		Transport: transport,
		Priority:  priority,
	}
}
// CreateTypeRule builds a rule that routes messages whose type equals
// msgType to the given transport.
func CreateTypeRule(name string, msgType MessageType, transport TransportType, priority int) RoutingRule {
	sameType := func(msg *Message) bool {
		return msg.Type == msgType
	}
	return RoutingRule{
		Name:      name,
		Condition: sameType,
		Transport: transport,
		Priority:  priority,
	}
}
// CreateSourceRule builds a rule that routes messages whose source equals
// source to the given transport.
func CreateSourceRule(name string, source string, transport TransportType, priority int) RoutingRule {
	sameSource := func(msg *Message) bool {
		return msg.Source == source
	}
	return RoutingRule{
		Name:      name,
		Condition: sameSource,
		Transport: transport,
		Priority:  priority,
	}
}

View File

@@ -0,0 +1,629 @@
package transport
import (
"bytes"
"compress/gzip"
"encoding/json"
"fmt"
"io"
"sync"
"time"
)
// SerializationFormat defines supported serialization formats
type SerializationFormat string
// Known format identifiers. NewSerializationLayer uses SerializationJSON as
// its default and registers only the JSON serializer; serializers for the
// other formats have to be registered separately.
const (
SerializationJSON SerializationFormat = "json"
SerializationMsgPack SerializationFormat = "msgpack"
SerializationProtobuf SerializationFormat = "protobuf"
SerializationAvro SerializationFormat = "avro"
)
// CompressionType defines supported compression algorithms
type CompressionType string
// Known compression identifiers. Which of them a given Compressor actually
// implements is reported by its GetSupportedAlgorithms method.
const (
CompressionNone CompressionType = "none"
CompressionGZip CompressionType = "gzip"
CompressionLZ4 CompressionType = "lz4"
CompressionSnappy CompressionType = "snappy"
)
// SerializationConfig configures serialization behavior
// It is the value a Serializer returns from GetConfig.
type SerializationConfig struct {
// Format is the wire format; Compression the compression algorithm.
Format SerializationFormat
Compression CompressionType
// Encryption and Validation presumably switch the corresponding processing
// steps on — TODO confirm against the serializer implementations.
Encryption bool
Validation bool
}
// SerializedMessage represents a serialized message with metadata
// It is what Serializer.Serialize produces and Serializer.Deserialize consumes.
type SerializedMessage struct {
// Format, Compression and Encrypted describe how Data was produced.
Format SerializationFormat `json:"format"`
Compression CompressionType `json:"compression"`
Encrypted bool `json:"encrypted"`
// NOTE(review): the checksum algorithm, whether Size counts bytes before or
// after compression, and the unit of Timestamp are not visible here —
// confirm against the serializer implementation.
Checksum string `json:"checksum"`
Data []byte `json:"data"`
Size int `json:"size"`
Timestamp int64 `json:"timestamp"`
}
// Serializer interface defines serialization operations
// Implementations are registered with a SerializationLayer under the format
// they report from GetFormat.
type Serializer interface {
// Serialize converts msg into its serialized form.
Serialize(msg *Message) (*SerializedMessage, error)
// Deserialize reconstructs a message from its serialized form.
Deserialize(serialized *SerializedMessage) (*Message, error)
// GetFormat reports the format this serializer handles.
GetFormat() SerializationFormat
// GetConfig reports the serializer's configuration.
GetConfig() SerializationConfig
}
// SerializationLayer holds the registered serializers together with the
// optional compression, encryption and validation handlers, and runs messages
// through them in Serialize and Deserialize.
type SerializationLayer struct {
// serializers maps each registered format to its implementation.
serializers map[SerializationFormat]Serializer
// defaultFormat is used by Serialize when the caller names no format.
defaultFormat SerializationFormat
// compressor, encryptor and validator may each be nil, in which case the
// corresponding pipeline stage is skipped.
compressor Compressor
encryptor Encryptor
validator Validator
// mu is write-locked by the Register/Set methods and read-locked by
// Serialize, Deserialize and GetSupportedFormats.
mu sync.RWMutex
}
// Compressor compresses and decompresses byte slices with a caller-selected
// algorithm.
type Compressor interface {
// Compress returns data compressed with algorithm.
Compress(data []byte, algorithm CompressionType) ([]byte, error)
// Decompress reverses Compress for the same algorithm.
Decompress(data []byte, algorithm CompressionType) ([]byte, error)
// GetSupportedAlgorithms lists the algorithms this compressor accepts.
GetSupportedAlgorithms() []CompressionType
}
// Encryptor encrypts and decrypts payload bytes. SerializationLayer calls
// Encrypt/Decrypt only while IsEnabled reports true.
type Encryptor interface {
// Encrypt returns the encrypted form of data.
Encrypt(data []byte) ([]byte, error)
// Decrypt reverses Encrypt.
Decrypt(data []byte) ([]byte, error)
// IsEnabled reports whether the encryptor should be used at all.
IsEnabled() bool
}
// Validator checks messages for well-formedness and produces/verifies
// checksums over serialized payload bytes.
type Validator interface {
// Validate returns an error describing why msg is not acceptable, or nil.
Validate(msg *Message) error
// GenerateChecksum returns a checksum string for data.
GenerateChecksum(data []byte) string
// VerifyChecksum reports whether checksum matches data.
VerifyChecksum(data []byte, checksum string) bool
}
// NewSerializationLayer returns a layer that defaults to JSON, uses the
// default compressor and validator, and has no encryptor configured.
func NewSerializationLayer() *SerializationLayer {
	layer := new(SerializationLayer)
	layer.serializers = make(map[SerializationFormat]Serializer)
	layer.defaultFormat = SerializationJSON
	layer.compressor = NewDefaultCompressor()
	layer.validator = NewDefaultValidator()

	// JSON is the only serializer registered out of the box.
	layer.RegisterSerializer(NewJSONSerializer())
	return layer
}
// RegisterSerializer stores serializer under the format it reports via
// GetFormat, replacing any serializer previously registered for that format.
func (sl *SerializationLayer) RegisterSerializer(serializer Serializer) {
	sl.mu.Lock()
	defer sl.mu.Unlock()

	key := serializer.GetFormat()
	sl.serializers[key] = serializer
}
// SetDefaultFormat changes the format Serialize uses when the caller does not
// name one. The format is not checked against the registered serializers here.
func (sl *SerializationLayer) SetDefaultFormat(format SerializationFormat) {
	sl.mu.Lock()
	sl.defaultFormat = format
	sl.mu.Unlock()
}
// SetCompressor replaces the compression handler. Passing nil makes Serialize
// and Deserialize skip the compression stage.
func (sl *SerializationLayer) SetCompressor(compressor Compressor) {
	sl.mu.Lock()
	sl.compressor = compressor
	sl.mu.Unlock()
}
// SetEncryptor replaces the encryption handler. Passing nil makes Serialize
// and Deserialize skip the encryption stage.
func (sl *SerializationLayer) SetEncryptor(encryptor Encryptor) {
	sl.mu.Lock()
	sl.encryptor = encryptor
	sl.mu.Unlock()
}
// SetValidator replaces the validation handler. Passing nil disables message
// validation and checksum generation/verification.
func (sl *SerializationLayer) SetValidator(validator Validator) {
	sl.mu.Lock()
	sl.validator = validator
	sl.mu.Unlock()
}
// Serialize converts msg into a SerializedMessage using the requested format,
// or the layer's default format when none is given. Only the first element of
// format is consulted.
//
// Pipeline, in order: validate (if a validator is set) -> serialize ->
// compress (if the serializer's config asks for it and a compressor is set) ->
// encrypt (if configured, an encryptor is set, and it reports enabled) ->
// checksum the final bytes -> fill in Size and Timestamp.
//
// An error is returned for an unregistered format, a nil message, a serializer
// that yields no result, or a failure in any pipeline stage.
func (sl *SerializationLayer) Serialize(msg *Message, format ...SerializationFormat) (*SerializedMessage, error) {
	sl.mu.RLock()
	defer sl.mu.RUnlock()

	// Determine format to use.
	selectedFormat := sl.defaultFormat
	if len(format) > 0 {
		selectedFormat = format[0]
	}

	serializer, exists := sl.serializers[selectedFormat]
	if !exists {
		return nil, fmt.Errorf("unsupported serialization format: %s", selectedFormat)
	}

	// Validate message if a validator is configured.
	if sl.validator != nil {
		if err := sl.validator.Validate(msg); err != nil {
			return nil, fmt.Errorf("message validation failed: %w", err)
		}
	}

	// Without a validator (or with one that tolerates nil) a nil message would
	// reach msg.Timestamp below and panic; reject it explicitly instead.
	if msg == nil {
		return nil, fmt.Errorf("serialization failed: message is nil")
	}

	serialized, err := serializer.Serialize(msg)
	if err != nil {
		return nil, fmt.Errorf("serialization failed: %w", err)
	}
	if serialized == nil {
		return nil, fmt.Errorf("serialization failed: serializer returned no result")
	}

	// Apply compression if configured.
	config := serializer.GetConfig()
	if config.Compression != CompressionNone && sl.compressor != nil {
		compressed, err := sl.compressor.Compress(serialized.Data, config.Compression)
		if err != nil {
			return nil, fmt.Errorf("compression failed: %w", err)
		}
		serialized.Data = compressed
		serialized.Compression = config.Compression
	}

	// Apply encryption if configured.
	if config.Encryption && sl.encryptor != nil && sl.encryptor.IsEnabled() {
		encrypted, err := sl.encryptor.Encrypt(serialized.Data)
		if err != nil {
			return nil, fmt.Errorf("encryption failed: %w", err)
		}
		serialized.Data = encrypted
		serialized.Encrypted = true
	}

	// The checksum covers the bytes as they will travel, i.e. after
	// compression and encryption, so Deserialize can verify before decoding.
	if sl.validator != nil {
		serialized.Checksum = sl.validator.GenerateChecksum(serialized.Data)
	}

	serialized.Size = len(serialized.Data)
	serialized.Timestamp = msg.Timestamp.UnixNano()
	return serialized, nil
}
// Deserialize reverses Serialize: it verifies the checksum (when a validator
// is set and the envelope carries one), decrypts, decompresses, decodes the
// payload with the serializer registered for serialized.Format, and finally
// validates the resulting message.
//
// The input envelope is not modified. An empty Compression value is treated
// the same as CompressionNone so that zero-value envelopes are accepted.
// An error is returned for a nil envelope, a checksum mismatch, an
// unregistered format, or a failure in any stage.
func (sl *SerializationLayer) Deserialize(serialized *SerializedMessage) (*Message, error) {
	if serialized == nil {
		return nil, fmt.Errorf("deserialization failed: serialized message is nil")
	}

	sl.mu.RLock()
	defer sl.mu.RUnlock()

	// Verify checksum if available.
	if sl.validator != nil && serialized.Checksum != "" {
		if !sl.validator.VerifyChecksum(serialized.Data, serialized.Checksum) {
			return nil, fmt.Errorf("checksum verification failed")
		}
	}

	data := serialized.Data

	// Apply decryption if needed.
	if serialized.Encrypted && sl.encryptor != nil && sl.encryptor.IsEnabled() {
		decrypted, err := sl.encryptor.Decrypt(data)
		if err != nil {
			return nil, fmt.Errorf("decryption failed: %w", err)
		}
		data = decrypted
	}

	// Apply decompression if needed. The zero value "" means the field was
	// never set, which is equivalent to no compression.
	compressed := serialized.Compression != CompressionNone && serialized.Compression != ""
	if compressed && sl.compressor != nil {
		decompressed, err := sl.compressor.Decompress(data, serialized.Compression)
		if err != nil {
			return nil, fmt.Errorf("decompression failed: %w", err)
		}
		data = decompressed
	}

	serializer, exists := sl.serializers[serialized.Format]
	if !exists {
		return nil, fmt.Errorf("unsupported serialization format: %s", serialized.Format)
	}

	// Hand the serializer a fresh envelope holding the plain payload so the
	// caller's envelope stays untouched.
	tempSerialized := &SerializedMessage{
		Format: serialized.Format,
		Data:   data,
	}

	msg, err := serializer.Deserialize(tempSerialized)
	if err != nil {
		return nil, fmt.Errorf("deserialization failed: %w", err)
	}

	// Validate deserialized message if a validator is configured.
	if sl.validator != nil {
		if err := sl.validator.Validate(msg); err != nil {
			return nil, fmt.Errorf("deserialized message validation failed: %w", err)
		}
	}
	return msg, nil
}
// GetSupportedFormats lists the formats that currently have a serializer
// registered. The order follows map iteration and is therefore unspecified.
func (sl *SerializationLayer) GetSupportedFormats() []SerializationFormat {
	sl.mu.RLock()
	defer sl.mu.RUnlock()

	result := make([]SerializationFormat, len(sl.serializers))
	next := 0
	for registered := range sl.serializers {
		result[next] = registered
		next++
	}
	return result
}
// JSONSerializer is the Serializer implementation for SerializationJSON,
// built on encoding/json.
type JSONSerializer struct {
// config is returned by GetConfig; its Format is always SerializationJSON
// (enforced by NewJSONSerializer and SetConfig).
config SerializationConfig
}
// NewJSONSerializer returns a JSON serializer whose config requests no
// compression and no encryption, with the Validation flag set.
func NewJSONSerializer() *JSONSerializer {
	defaults := SerializationConfig{
		Format:      SerializationJSON,
		Compression: CompressionNone,
		Encryption:  false,
		Validation:  true,
	}
	return &JSONSerializer{config: defaults}
}
// SetConfig replaces the serializer's configuration. Whatever Format the
// caller supplied is overridden with SerializationJSON.
func (js *JSONSerializer) SetConfig(config SerializationConfig) {
	// config is a by-value copy, so fixing up Format here does not touch the
	// caller's struct.
	config.Format = SerializationJSON
	js.config = config
}
// Serialize marshals msg with encoding/json and wraps the bytes in a
// SerializedMessage marked as uncompressed and unencrypted. Checksum and
// Timestamp are left at their zero values.
func (js *JSONSerializer) Serialize(msg *Message) (*SerializedMessage, error) {
	payload, marshalErr := json.Marshal(msg)
	if marshalErr != nil {
		return nil, fmt.Errorf("JSON marshal failed: %w", marshalErr)
	}

	envelope := new(SerializedMessage)
	envelope.Format = SerializationJSON
	envelope.Compression = CompressionNone
	envelope.Encrypted = false
	envelope.Data = payload
	envelope.Size = len(payload)
	return envelope, nil
}
// Deserialize unmarshals serialized.Data as JSON into a new Message. Only the
// Data field of the envelope is consulted.
func (js *JSONSerializer) Deserialize(serialized *SerializedMessage) (*Message, error) {
	decoded := new(Message)
	unmarshalErr := json.Unmarshal(serialized.Data, decoded)
	if unmarshalErr != nil {
		return nil, fmt.Errorf("JSON unmarshal failed: %w", unmarshalErr)
	}
	return decoded, nil
}
// GetFormat returns SerializationJSON unconditionally; it does not consult
// the stored config.
func (js *JSONSerializer) GetFormat() SerializationFormat {
return SerializationJSON
}
// GetConfig returns a copy of the serializer's current configuration.
// SerializationLayer.Serialize uses it to decide on compression and encryption.
func (js *JSONSerializer) GetConfig() SerializationConfig {
return js.config
}
// DefaultCompressor is a Compressor backed by the standard library. Its
// Compress/Decompress methods handle CompressionNone and CompressionGZip.
type DefaultCompressor struct {
// supportedAlgorithms is the list reported by GetSupportedAlgorithms.
supportedAlgorithms []CompressionType
}
// NewDefaultCompressor returns a compressor advertising support for
// CompressionNone and CompressionGZip.
func NewDefaultCompressor() *DefaultCompressor {
	compressor := new(DefaultCompressor)
	compressor.supportedAlgorithms = []CompressionType{CompressionNone, CompressionGZip}
	return compressor
}
// Compress encodes data with the given algorithm. CompressionNone returns the
// input slice itself; CompressionGZip returns a new gzip stream; any other
// algorithm is an error.
func (dc *DefaultCompressor) Compress(data []byte, algorithm CompressionType) ([]byte, error) {
	if algorithm == CompressionNone {
		return data, nil
	}
	if algorithm != CompressionGZip {
		return nil, fmt.Errorf("unsupported compression algorithm: %s", algorithm)
	}

	var out bytes.Buffer
	gz := gzip.NewWriter(&out)
	if _, writeErr := gz.Write(data); writeErr != nil {
		return nil, fmt.Errorf("gzip write failed: %w", writeErr)
	}
	// Close flushes the remaining compressed bytes and the gzip trailer; the
	// buffer is not complete until it succeeds.
	if closeErr := gz.Close(); closeErr != nil {
		return nil, fmt.Errorf("gzip close failed: %w", closeErr)
	}
	return out.Bytes(), nil
}
// Decompress decodes data with the given algorithm. CompressionNone returns
// the input slice itself; CompressionGZip reads the whole gzip stream into
// memory; any other algorithm is an error.
func (dc *DefaultCompressor) Decompress(data []byte, algorithm CompressionType) ([]byte, error) {
	if algorithm == CompressionNone {
		return data, nil
	}
	if algorithm != CompressionGZip {
		return nil, fmt.Errorf("unsupported compression algorithm: %s", algorithm)
	}

	gz, openErr := gzip.NewReader(bytes.NewReader(data))
	if openErr != nil {
		return nil, fmt.Errorf("gzip reader creation failed: %w", openErr)
	}
	defer gz.Close()

	// The output size is not bounded here; the full stream is expanded.
	plain, readErr := io.ReadAll(gz)
	if readErr != nil {
		return nil, fmt.Errorf("gzip read failed: %w", readErr)
	}
	return plain, nil
}
// GetSupportedAlgorithms returns the compressor's algorithm list.
// The returned slice is the internal one, not a copy, so callers should treat
// it as read-only.
func (dc *DefaultCompressor) GetSupportedAlgorithms() []CompressionType {
return dc.supportedAlgorithms
}
// DefaultValidator is a Validator that checks required Message fields and
// computes a simple additive checksum.
type DefaultValidator struct {
// strictMode additionally requires non-nil Data and a non-zero Timestamp
// in Validate.
strictMode bool
}
// NewDefaultValidator returns a validator with strict mode switched off.
func NewDefaultValidator() *DefaultValidator {
	// The zero value already has strictMode == false.
	return new(DefaultValidator)
}
// SetStrictMode turns the extra Validate checks (non-nil Data, non-zero
// Timestamp) on or off. The field is written without synchronization.
func (dv *DefaultValidator) SetStrictMode(enabled bool) {
dv.strictMode = enabled
}
// Validate reports the first problem found in msg, or nil when it passes.
// A message must be non-nil and carry a non-empty ID, Topic, Source and Type
// (checked in that order). In strict mode Data must also be non-nil and
// Timestamp non-zero.
func (dv *DefaultValidator) Validate(msg *Message) error {
	// Cases are evaluated top to bottom, so the nil check guards every field
	// access below it.
	switch {
	case msg == nil:
		return fmt.Errorf("message is nil")
	case msg.ID == "":
		return fmt.Errorf("message ID is empty")
	case msg.Topic == "":
		return fmt.Errorf("message topic is empty")
	case msg.Source == "":
		return fmt.Errorf("message source is empty")
	case msg.Type == "":
		return fmt.Errorf("message type is empty")
	}

	if !dv.strictMode {
		return nil
	}

	switch {
	case msg.Data == nil:
		return fmt.Errorf("message data is nil")
	case msg.Timestamp.IsZero():
		return fmt.Errorf("message timestamp is zero")
	}
	return nil
}
// GenerateChecksum returns the sum of all bytes in data, wrapped to 32 bits
// and formatted as eight lowercase hex digits.
//
// This is a plain additive checksum, not a cryptographic hash: reordered
// bytes produce the same value. In production, use a proper hash function
// like SHA-256.
func (dv *DefaultValidator) GenerateChecksum(data []byte) string {
	var total uint32
	for i := 0; i < len(data); i++ {
		total += uint32(data[i])
	}
	return fmt.Sprintf("%08x", total)
}
// VerifyChecksum recomputes the checksum of data with GenerateChecksum and
// reports whether it equals checksum exactly.
func (dv *DefaultValidator) VerifyChecksum(data []byte, checksum string) bool {
	recomputed := dv.GenerateChecksum(data)
	return recomputed == checksum
}
// NoOpEncryptor is an encryptor stand-in for testing: it passes data through
// untouched and only tracks an enabled flag.
type NoOpEncryptor struct {
	enabled bool
}

// NewNoOpEncryptor returns a pass-through encryptor that starts disabled.
func NewNoOpEncryptor() *NoOpEncryptor {
	// The zero value already has enabled == false.
	return new(NoOpEncryptor)
}

// SetEnabled switches the value reported by IsEnabled.
func (e *NoOpEncryptor) SetEnabled(enabled bool) {
	e.enabled = enabled
}

// Encrypt hands back the very slice it was given and never fails.
func (e *NoOpEncryptor) Encrypt(data []byte) ([]byte, error) {
	var noErr error
	return data, noErr
}

// Decrypt hands back the very slice it was given and never fails.
func (e *NoOpEncryptor) Decrypt(data []byte) ([]byte, error) {
	var noErr error
	return data, noErr
}

// IsEnabled reports the flag last set with SetEnabled (false initially).
func (e *NoOpEncryptor) IsEnabled() bool {
	return e.enabled
}
// SerializationMetrics is a snapshot of the serialization counters kept by a
// MetricsCollector.
type SerializationMetrics struct {
	// SerializedMessages counts RecordSerialization calls.
	SerializedMessages int64
	// DeserializedMessages counts RecordDeserialization calls.
	DeserializedMessages int64
	// SerializationErrors counts RecordError calls.
	SerializationErrors int64
	// CompressionRatio is total serialized bytes divided by total original
	// bytes over all recorded serializations with a positive original size.
	CompressionRatio float64
	// AverageMessageSize is TotalDataProcessed / SerializedMessages.
	AverageMessageSize int64
	// TotalDataProcessed is the sum of the original sizes recorded.
	TotalDataProcessed int64
}

// MetricsCollector accumulates serialization metrics plus free-form named
// counters, events, gauges and latencies. All methods are safe for concurrent
// use. The zero value is ready to use.
type MetricsCollector struct {
	metrics SerializationMetrics

	// ratioOriginalBytes / ratioSerializedBytes are the running totals from
	// which CompressionRatio is derived.
	ratioOriginalBytes   int64
	ratioSerializedBytes int64

	// Named metrics; maps are created lazily on first write.
	counters  map[string]int64
	events    map[string]int64
	gauges    map[string]float64
	latencies map[string]time.Duration

	mu sync.RWMutex
}

// NewMetricsCollector returns an empty collector.
func NewMetricsCollector() *MetricsCollector {
	return &MetricsCollector{}
}

// RecordSerialization records one serialization of a message that was
// originalSize bytes before and serializedSize bytes after processing.
// Samples with originalSize <= 0 count as messages but do not influence
// CompressionRatio.
func (mc *MetricsCollector) RecordSerialization(originalSize, serializedSize int) {
	mc.mu.Lock()
	defer mc.mu.Unlock()

	mc.metrics.SerializedMessages++
	mc.metrics.TotalDataProcessed += int64(originalSize)

	// Derive the ratio from running totals so a single sample is reported
	// as-is (averaging against the initial 0 would halve it) and every sample
	// is weighted by its size.
	if originalSize > 0 {
		mc.ratioOriginalBytes += int64(originalSize)
		mc.ratioSerializedBytes += int64(serializedSize)
		mc.metrics.CompressionRatio = float64(mc.ratioSerializedBytes) / float64(mc.ratioOriginalBytes)
	}

	mc.metrics.AverageMessageSize = mc.metrics.TotalDataProcessed / mc.metrics.SerializedMessages
}

// RecordDeserialization records one deserialization.
func (mc *MetricsCollector) RecordDeserialization() {
	mc.mu.Lock()
	defer mc.mu.Unlock()
	mc.metrics.DeserializedMessages++
}

// RecordError records one serialization error.
func (mc *MetricsCollector) RecordError() {
	mc.mu.Lock()
	defer mc.mu.Unlock()
	mc.metrics.SerializationErrors++
}

// IncrementCounter adds one to the named counter. It does not affect
// SerializationErrors.
func (mc *MetricsCollector) IncrementCounter(name string) {
	mc.mu.Lock()
	defer mc.mu.Unlock()
	if mc.counters == nil {
		mc.counters = make(map[string]int64)
	}
	mc.counters[name]++
}

// RecordLatency stores duration as the most recent latency for name.
func (mc *MetricsCollector) RecordLatency(name string, duration time.Duration) {
	mc.mu.Lock()
	defer mc.mu.Unlock()
	if mc.latencies == nil {
		mc.latencies = make(map[string]time.Duration)
	}
	mc.latencies[name] = duration
}

// RecordEvent counts one occurrence of the named event. It does not affect
// SerializationErrors.
func (mc *MetricsCollector) RecordEvent(name string) {
	mc.mu.Lock()
	defer mc.mu.Unlock()
	if mc.events == nil {
		mc.events = make(map[string]int64)
	}
	mc.events[name]++
}

// RecordGauge stores value as the current reading of the named gauge.
func (mc *MetricsCollector) RecordGauge(name string, value float64) {
	mc.mu.Lock()
	defer mc.mu.Unlock()
	if mc.gauges == nil {
		mc.gauges = make(map[string]float64)
	}
	mc.gauges[name] = value
}

// GetAll returns the built-in serialization metrics under their fixed keys.
// When named metrics have been recorded, copies of them are added under the
// keys "counters", "events", "gauges" and "latencies".
func (mc *MetricsCollector) GetAll() map[string]interface{} {
	mc.mu.RLock()
	defer mc.mu.RUnlock()

	all := map[string]interface{}{
		"serialized_messages":   mc.metrics.SerializedMessages,
		"deserialized_messages": mc.metrics.DeserializedMessages,
		"serialization_errors":  mc.metrics.SerializationErrors,
		"compression_ratio":     mc.metrics.CompressionRatio,
		"average_message_size":  mc.metrics.AverageMessageSize,
		"total_data_processed":  mc.metrics.TotalDataProcessed,
	}

	// Copy the maps so callers cannot mutate collector state.
	if len(mc.counters) > 0 {
		counters := make(map[string]int64, len(mc.counters))
		for name, value := range mc.counters {
			counters[name] = value
		}
		all["counters"] = counters
	}
	if len(mc.events) > 0 {
		events := make(map[string]int64, len(mc.events))
		for name, value := range mc.events {
			events[name] = value
		}
		all["events"] = events
	}
	if len(mc.gauges) > 0 {
		gauges := make(map[string]float64, len(mc.gauges))
		for name, value := range mc.gauges {
			gauges[name] = value
		}
		all["gauges"] = gauges
	}
	if len(mc.latencies) > 0 {
		latencies := make(map[string]time.Duration, len(mc.latencies))
		for name, value := range mc.latencies {
			latencies[name] = value
		}
		all["latencies"] = latencies
	}
	return all
}

// Get returns a single metric by name, or nil when the name is unknown.
// Built-in names (the fixed keys of GetAll) take precedence; otherwise the
// name is looked up among counters, events, gauges and latencies, in that
// order.
func (mc *MetricsCollector) Get(name string) interface{} {
	mc.mu.RLock()
	defer mc.mu.RUnlock()

	switch name {
	case "serialized_messages":
		return mc.metrics.SerializedMessages
	case "deserialized_messages":
		return mc.metrics.DeserializedMessages
	case "serialization_errors":
		return mc.metrics.SerializationErrors
	case "compression_ratio":
		return mc.metrics.CompressionRatio
	case "average_message_size":
		return mc.metrics.AverageMessageSize
	case "total_data_processed":
		return mc.metrics.TotalDataProcessed
	}

	if value, ok := mc.counters[name]; ok {
		return value
	}
	if value, ok := mc.events[name]; ok {
		return value
	}
	if value, ok := mc.gauges[name]; ok {
		return value
	}
	if value, ok := mc.latencies[name]; ok {
		return value
	}
	return nil
}

// GetMetrics returns a copy of the built-in serialization metrics.
func (mc *MetricsCollector) GetMetrics() SerializationMetrics {
	mc.mu.RLock()
	defer mc.mu.RUnlock()
	return mc.metrics
}

// Reset clears the built-in metrics and all named metrics.
func (mc *MetricsCollector) Reset() {
	mc.mu.Lock()
	defer mc.mu.Unlock()
	mc.metrics = SerializationMetrics{}
	mc.ratioOriginalBytes = 0
	mc.ratioSerializedBytes = 0
	mc.counters = nil
	mc.events = nil
	mc.gauges = nil
	mc.latencies = nil
}

View File

@@ -0,0 +1,451 @@
package transport
import (
"context"
"crypto/tls"
"encoding/json"
"fmt"
"net"
"sync"
"time"
)
// TCPTransport sends and receives Messages over TCP, optionally wrapped in
// TLS. In server mode it listens and accepts any number of peers; in client
// mode it dials a single server. Outgoing messages are broadcast to every
// tracked connection.
type TCPTransport struct {
// address and port form the listen address (server) or dial target (client).
address string
port int
// listener is set only in server mode, by startServer.
listener net.Listener
// connections holds the live peers keyed by a generated connection ID.
connections map[string]net.Conn
metrics TransportMetrics
// connected is true between a successful Connect and Disconnect.
connected bool
isServer bool
// receiveChan delivers decoded incoming messages; created with capacity 1000
// by NewTCPTransport and closed by Disconnect.
receiveChan chan *Message
// tlsConfig enables TLS for listen/dial when non-nil.
tlsConfig *tls.Config
// mu is taken by most methods that touch connections, connected and metrics.
mu sync.RWMutex
// ctx is cancelled (via cancel) by Disconnect to stop the accept and
// per-connection read goroutines.
ctx context.Context
cancel context.CancelFunc
// retryConfig governs backoff for dialing and for writes.
retryConfig RetryConfig
}
// NewTCPTransport builds a transport for address:port without opening any
// sockets; call Connect to start listening (isServer) or dialing (client).
// Defaults: receive buffer of 1000 messages and a retry policy of 3 retries,
// 1s initial delay, 30s cap, factor 2.0, jitter enabled.
func NewTCPTransport(address string, port int, isServer bool) *TCPTransport {
	lifecycleCtx, stop := context.WithCancel(context.Background())

	defaultRetry := RetryConfig{
		MaxRetries:    3,
		InitialDelay:  time.Second,
		MaxDelay:      30 * time.Second,
		BackoffFactor: 2.0,
		Jitter:        true,
	}

	transport := new(TCPTransport)
	transport.address = address
	transport.port = port
	transport.connections = make(map[string]net.Conn)
	transport.metrics = TransportMetrics{}
	transport.isServer = isServer
	transport.receiveChan = make(chan *Message, 1000)
	transport.ctx = lifecycleCtx
	transport.cancel = stop
	transport.retryConfig = defaultRetry
	return transport
}
// SetTLSConfig sets the TLS configuration used by subsequent listen/dial
// operations; nil means plain TCP. The field is written without taking the
// transport mutex.
func (tt *TCPTransport) SetTLSConfig(config *tls.Config) {
tt.tlsConfig = config
}
// SetRetryConfig replaces the retry/backoff policy used when dialing and when
// writing frames. The field is written without taking the transport mutex.
func (tt *TCPTransport) SetRetryConfig(config RetryConfig) {
tt.retryConfig = config
}
// Connect brings the transport up: a server starts listening, a client dials
// the remote end (honouring ctx while waiting between attempts). Calling it
// on an already-connected transport is a no-op. The transport mutex is held
// for the whole call.
func (tt *TCPTransport) Connect(ctx context.Context) error {
	tt.mu.Lock()
	defer tt.mu.Unlock()

	switch {
	case tt.connected:
		return nil
	case tt.isServer:
		return tt.startServer()
	default:
		return tt.connectToServer(ctx)
	}
}
// Disconnect tears the transport down: it cancels the lifecycle context,
// closes the listener (server mode), closes and forgets every connection,
// closes the receive channel and marks the transport disconnected. It is a
// no-op when the transport is not connected and always returns nil; the ctx
// argument is not used.
//
// NOTE(review): handleConnection goroutines send on receiveChan and nothing
// here waits for them to exit before the channel is closed — confirm that a
// send on the closed channel cannot occur.
// NOTE(review): the context and channel are not recreated anywhere in view, so
// a Connect after Disconnect appears to reuse a cancelled context and a closed
// channel — confirm whether reconnecting is intended to be supported.
func (tt *TCPTransport) Disconnect(ctx context.Context) error {
tt.mu.Lock()
defer tt.mu.Unlock()
if !tt.connected {
return nil
}
// Signal the accept loop and all reader goroutines to stop.
tt.cancel()
if tt.isServer && tt.listener != nil {
// Close errors are ignored; closing also unblocks a pending Accept.
tt.listener.Close()
}
// Close all connections
for id, conn := range tt.connections {
conn.Close()
delete(tt.connections, id)
}
// Closing lets consumers ranging over the Receive channel terminate.
close(tt.receiveChan)
tt.connected = false
tt.metrics.Connections = 0
return nil
}
// Send JSON-encodes msg, frames it as "<length>\n<payload>" and writes the
// frame to every active connection, retrying each write according to the
// retry policy. A connection whose write ultimately fails is removed.
//
// It returns an error when the transport is not connected, the message cannot
// be marshalled, there are no connections, or at least one connection failed
// (the error describes the last failure encountered). Send metrics are
// updated only when every connection succeeded; otherwise the error counter
// is incremented.
func (tt *TCPTransport) Send(ctx context.Context, msg *Message) error {
	start := time.Now()

	// recordError bumps the error counter under the write lock so concurrent
	// senders and metric readers do not race on it.
	recordError := func() {
		tt.mu.Lock()
		tt.metrics.Errors++
		tt.mu.Unlock()
	}

	tt.mu.RLock()
	connected := tt.connected
	tt.mu.RUnlock()
	if !connected {
		recordError()
		return fmt.Errorf("transport not connected")
	}

	// Serialize message.
	data, err := json.Marshal(msg)
	if err != nil {
		recordError()
		return fmt.Errorf("failed to marshal message: %w", err)
	}

	// Add length prefix for framing.
	frameBytes := []byte(fmt.Sprintf("%d\n%s", len(data), data))

	// Snapshot the connections so the lock is not held across network writes
	// and backoff sleeps, which would stall accepts and disconnects.
	tt.mu.RLock()
	targets := make(map[string]net.Conn, len(tt.connections))
	for connID, conn := range tt.connections {
		targets[connID] = conn
	}
	tt.mu.RUnlock()

	if len(targets) == 0 {
		recordError()
		return fmt.Errorf("no active connections")
	}

	var sendErr error
	for connID, conn := range targets {
		if err := tt.sendWithRetry(ctx, conn, frameBytes); err != nil {
			sendErr = fmt.Errorf("failed to send to connection %s: %w", connID, err)
			// No lock is held here, so the failed connection can be dropped
			// synchronously.
			tt.removeConnection(connID)
		}
	}

	if sendErr != nil {
		recordError()
		return sendErr
	}
	tt.updateSendMetrics(msg, time.Since(start))
	return nil
}
// Receive hands out the transport's incoming-message channel. Every caller
// gets the same channel. It fails when the transport is not connected; the
// ctx argument is not used.
func (tt *TCPTransport) Receive(ctx context.Context) (<-chan *Message, error) {
	tt.mu.RLock()
	defer tt.mu.RUnlock()

	if tt.connected {
		return tt.receiveChan, nil
	}
	return nil, fmt.Errorf("transport not connected")
}
// Health summarises the transport state: "unhealthy" when not connected,
// "degraded" when connected but without any peer, "healthy" otherwise.
// ResponseTime is a fixed 10ms estimate in the healthy case and zero
// otherwise; it is not measured.
func (tt *TCPTransport) Health() ComponentHealth {
	tt.mu.RLock()
	defer tt.mu.RUnlock()

	state := "unhealthy"
	var estimatedLatency time.Duration
	switch {
	case !tt.connected:
		// keep the defaults
	case len(tt.connections) > 0:
		state = "healthy"
		estimatedLatency = time.Millisecond * 10 // Estimate for TCP
	default:
		state = "degraded" // Connected but no active connections
	}

	return ComponentHealth{
		Status:       state,
		LastCheck:    time.Now(),
		ResponseTime: estimatedLatency,
		ErrorCount:   tt.metrics.Errors,
	}
}
// GetMetrics returns a snapshot of the transport counters. Connections is
// taken from the live connection map rather than from the stored counter.
func (tt *TCPTransport) GetMetrics() TransportMetrics {
	tt.mu.RLock()
	defer tt.mu.RUnlock()

	var snapshot TransportMetrics
	snapshot.BytesSent = tt.metrics.BytesSent
	snapshot.BytesReceived = tt.metrics.BytesReceived
	snapshot.MessagesSent = tt.metrics.MessagesSent
	snapshot.MessagesReceived = tt.metrics.MessagesReceived
	snapshot.Connections = len(tt.connections)
	snapshot.Errors = tt.metrics.Errors
	snapshot.Latency = tt.metrics.Latency
	return snapshot
}
// Private helper methods
// startServer opens the listening socket (TLS when a TLS config is set),
// marks the transport connected and starts the accept loop in a goroutine.
// The caller must hold tt.mu.
func (tt *TCPTransport) startServer() error {
	// JoinHostPort brackets IPv6 literals ("[::1]:8080"), which plain
	// "%s:%d" formatting would render unparseable; this also matches how
	// connectToServer builds its address.
	addr := net.JoinHostPort(tt.address, fmt.Sprintf("%d", tt.port))

	var listener net.Listener
	var err error
	if tt.tlsConfig != nil {
		listener, err = tls.Listen("tcp", addr, tt.tlsConfig)
	} else {
		listener, err = net.Listen("tcp", addr)
	}
	if err != nil {
		return fmt.Errorf("failed to listen on %s: %w", addr, err)
	}

	tt.listener = listener
	tt.connected = true

	// Start accepting connections.
	go tt.acceptConnections()
	return nil
}
// connectToServer dials the remote address (TLS when a TLS config is set),
// retrying with exponential backoff per the retry policy. On success the
// connection is registered, the transport is marked connected and a reader
// goroutine is started. The caller must hold tt.mu, which therefore stays
// locked for the duration of all dial attempts and backoff waits.
//
// It returns ctx.Err() if ctx is done while dialing or waiting, and a wrapped
// dial error once the retry budget is exhausted.
func (tt *TCPTransport) connectToServer(ctx context.Context) error {
	addr := net.JoinHostPort(tt.address, fmt.Sprintf("%d", tt.port))

	var conn net.Conn
	var err error

	// Retry connection with exponential backoff.
	delay := tt.retryConfig.InitialDelay
	for attempt := 0; attempt <= tt.retryConfig.MaxRetries; attempt++ {
		// Dial through a Dialer so cancellation of ctx aborts an in-flight
		// connection attempt instead of only the wait between attempts.
		if tt.tlsConfig != nil {
			conn, err = (&tls.Dialer{Config: tt.tlsConfig}).DialContext(ctx, "tcp", addr)
		} else {
			conn, err = (&net.Dialer{}).DialContext(ctx, "tcp", addr)
		}
		if err == nil {
			break
		}
		if ctx.Err() != nil {
			return ctx.Err()
		}
		if attempt == tt.retryConfig.MaxRetries {
			return fmt.Errorf("failed to connect to %s after %d attempts: %w", addr, attempt+1, err)
		}

		// Wait with exponential backoff.
		select {
		case <-time.After(delay):
			delay = time.Duration(float64(delay) * tt.retryConfig.BackoffFactor)
			if delay > tt.retryConfig.MaxDelay {
				delay = tt.retryConfig.MaxDelay
			}
			// Add up to ±10% jitter if enabled. The modulo must be applied
			// before doubling so the factor spans [-1, 1) rather than being
			// always negative.
			if tt.retryConfig.Jitter {
				jitter := time.Duration(float64(delay) * 0.1)
				jitterFactor := 2*float64(time.Now().UnixNano()%1000)/1000.0 - 1
				delay += time.Duration(float64(jitter) * jitterFactor)
			}
		case <-ctx.Done():
			return ctx.Err()
		}
	}

	// A negative MaxRetries skips the loop entirely; never register a nil
	// connection.
	if conn == nil {
		return fmt.Errorf("failed to connect to %s: no connection attempts made", addr)
	}

	connID := fmt.Sprintf("client_%d", time.Now().UnixNano())
	tt.connections[connID] = conn
	tt.connected = true
	tt.metrics.Connections = 1

	// Start receiving from server.
	go tt.handleConnection(connID, conn)
	return nil
}
// acceptConnections is the server accept loop. It runs until the transport's
// context is cancelled, registering each accepted connection under a
// generated ID and starting a reader goroutine for it.
func (tt *TCPTransport) acceptConnections() {
	// Pause after a failed Accept so a persistent error (for example file
	// descriptor exhaustion) does not turn this loop into a busy spin.
	const acceptRetryDelay = 50 * time.Millisecond

	for {
		if tt.ctx.Err() != nil {
			return
		}

		conn, err := tt.listener.Accept()
		if err != nil {
			if tt.ctx.Err() != nil {
				return // Context cancelled
			}
			// The counter is shared with other goroutines; update it under
			// the lock.
			tt.mu.Lock()
			tt.metrics.Errors++
			tt.mu.Unlock()

			select {
			case <-tt.ctx.Done():
				return
			case <-time.After(acceptRetryDelay):
			}
			continue
		}

		connID := fmt.Sprintf("server_%d", time.Now().UnixNano())
		tt.mu.Lock()
		tt.connections[connID] = conn
		tt.metrics.Connections = len(tt.connections)
		tt.mu.Unlock()

		go tt.handleConnection(connID, conn)
	}
}
// handleConnection is the per-connection read loop. It accumulates bytes from
// conn, extracts complete messages with ExtractMessage and delivers them to
// the receive channel. The loop ends — and the connection is removed — when
// the transport context is cancelled, the read fails with a non-timeout
// error, or the stream contains an invalid frame.
//
// Delivery never blocks: when the receive channel is full the message is
// dropped and counted as an error.
func (tt *TCPTransport) handleConnection(connID string, conn net.Conn) {
	defer tt.removeConnection(connID)

	buffer := make([]byte, 4096)
	var messageBuffer []byte

	for {
		select {
		case <-tt.ctx.Done():
			return
		default:
		}

		// The deadline bounds each Read so cancellation is noticed at least
		// every 30 seconds even on an idle connection.
		conn.SetReadDeadline(time.Now().Add(30 * time.Second))
		n, err := conn.Read(buffer)
		if err != nil {
			if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
				continue // Continue on timeout
			}
			return // Connection closed or error
		}
		messageBuffer = append(messageBuffer, buffer[:n]...)

		// Process complete messages.
		for {
			msg, remaining, err := ExtractMessage(messageBuffer)
			if err != nil {
				return // Invalid message format
			}
			if msg == nil {
				break // No complete message yet
			}

			// Deliver message.
			select {
			case tt.receiveChan <- msg:
				tt.updateReceiveMetrics(msg)
			case <-tt.ctx.Done():
				return
			default:
				// Channel full, drop message. The counter is shared across
				// goroutines, so update it under the lock.
				tt.mu.Lock()
				tt.metrics.Errors++
				tt.mu.Unlock()
			}
			messageBuffer = remaining
		}
	}
}
// sendWithRetry writes data to conn, giving each attempt a 10 second write
// deadline. After a failed write it waits (exponential backoff capped at
// MaxDelay) and tries again, up to MaxRetries additional times. It returns
// nil on the first successful write, the last write error once retries are
// exhausted, or ctx.Err() if ctx is done while waiting.
func (tt *TCPTransport) sendWithRetry(ctx context.Context, conn net.Conn, data []byte) error {
	backoff := tt.retryConfig.InitialDelay
	attempt := 0
	for attempt <= tt.retryConfig.MaxRetries {
		conn.SetWriteDeadline(time.Now().Add(10 * time.Second))
		_, writeErr := conn.Write(data)
		switch {
		case writeErr == nil:
			return nil
		case attempt == tt.retryConfig.MaxRetries:
			return writeErr
		}

		// Wait out the backoff, unless the caller gives up first.
		select {
		case <-ctx.Done():
			return ctx.Err()
		case <-time.After(backoff):
			backoff = time.Duration(float64(backoff) * tt.retryConfig.BackoffFactor)
			if backoff > tt.retryConfig.MaxDelay {
				backoff = tt.retryConfig.MaxDelay
			}
		}
		attempt++
	}
	// Only reachable when MaxRetries is negative and the loop never ran.
	return fmt.Errorf("max retries exceeded")
}
// removeConnection closes and forgets the connection registered under connID
// and refreshes the connection count. Unknown IDs are ignored, so calling it
// more than once for the same connection is harmless.
func (tt *TCPTransport) removeConnection(connID string) {
	tt.mu.Lock()
	defer tt.mu.Unlock()

	conn, tracked := tt.connections[connID]
	if !tracked {
		return
	}
	conn.Close()
	delete(tt.connections, connID)
	tt.metrics.Connections = len(tt.connections)
}
// updateSendMetrics records one successfully sent message: it bumps the
// message count, stores latency as the most recent send latency and adds an
// estimated byte count. The estimate is the combined length of ID, Topic and
// Source plus the length of Data's default "%v" rendering — not the actual
// frame size on the wire.
func (tt *TCPTransport) updateSendMetrics(msg *Message, latency time.Duration) {
	tt.mu.Lock()
	defer tt.mu.Unlock()

	tt.metrics.MessagesSent++
	tt.metrics.Latency = latency

	approxBytes := len(msg.ID) + len(msg.Topic) + len(msg.Source)
	if msg.Data != nil {
		rendered := fmt.Sprintf("%v", msg.Data)
		approxBytes += len(rendered)
	}
	tt.metrics.BytesSent += int64(approxBytes)
}
// updateReceiveMetrics records one delivered incoming message: it bumps the
// message count and adds an estimated byte count (combined length of ID,
// Topic and Source plus the length of Data's default "%v" rendering).
func (tt *TCPTransport) updateReceiveMetrics(msg *Message) {
	tt.mu.Lock()
	defer tt.mu.Unlock()

	tt.metrics.MessagesReceived++

	approxBytes := len(msg.ID) + len(msg.Topic) + len(msg.Source)
	if msg.Data != nil {
		rendered := fmt.Sprintf("%v", msg.Data)
		approxBytes += len(rendered)
	}
	tt.metrics.BytesReceived += int64(approxBytes)
}
// GetAddress returns the transport's "host:port" address (for
// testing/debugging). IPv6 literal hosts are bracketed ("[::1]:8080") so the
// result is unambiguous and can be passed to net.Dial or net.SplitHostPort.
func (tt *TCPTransport) GetAddress() string {
	return net.JoinHostPort(tt.address, fmt.Sprintf("%d", tt.port))
}
// GetConnectionCount reports how many connections are currently tracked.
func (tt *TCPTransport) GetConnectionCount() int {
	tt.mu.RLock()
	active := len(tt.connections)
	tt.mu.RUnlock()
	return active
}
// IsSecure reports whether a TLS configuration has been set. It says nothing
// about the state of existing connections; the field is read without locking.
func (tt *TCPTransport) IsSecure() bool {
return tt.tlsConfig != nil
}

View File

@@ -0,0 +1,316 @@
package transport
import (
"fmt"
"os"
"github.com/ethereum/go-ethereum/ethclient"
"gopkg.in/yaml.v3"
)
// UnifiedProviderManager owns the provider pools (read-only, execution,
// testing) built from a single providers configuration.
type UnifiedProviderManager struct {
// Each pool is nil when the configuration has no entry for it
// ("read_only", "execution", "testing" respectively).
ReadOnlyPool *ReadOnlyProviderPool
ExecutionPool *ExecutionProviderPool
TestingPool *TestingProviderPool
// config is the loaded configuration the pools were built from.
config ProvidersConfig
// providerConfigs indexes config.Providers by provider name.
providerConfigs map[string]ProviderConfig
}
// OperationMode selects which provider pool an operation should use; see
// GetPoolForMode.
type OperationMode int
const (
// ModeReadOnly maps to the read-only pool.
ModeReadOnly OperationMode = iota
// ModeExecution maps to the execution pool.
ModeExecution
// ModeTesting maps to the testing pool.
ModeTesting
)
// NewUnifiedProviderManager loads the providers configuration from
// configPath, indexes the providers by name and initializes every pool the
// configuration defines.
//
// If any pool fails to initialize, pools that were already created are closed
// before the error is returned, so a failed construction does not leak them.
func NewUnifiedProviderManager(configPath string) (*UnifiedProviderManager, error) {
	// Load configuration.
	config, err := LoadProvidersConfig(configPath)
	if err != nil {
		return nil, fmt.Errorf("failed to load provider config: %w", err)
	}

	// Create provider configs map for easy lookup.
	providerConfigs := make(map[string]ProviderConfig, len(config.Providers))
	for _, provider := range config.Providers {
		providerConfigs[provider.Name] = provider
	}

	manager := &UnifiedProviderManager{
		config:          config,
		providerConfigs: providerConfigs,
	}

	// Initialize provider pools.
	if err := manager.initializePools(); err != nil {
		// Best-effort cleanup of the pools created before the failure; the
		// initialization error is the one worth reporting, so a close error
		// is deliberately dropped.
		_ = manager.Close()
		return nil, fmt.Errorf("failed to initialize provider pools: %w", err)
	}
	return manager, nil
}
// initializePools creates the read-only, execution and testing pools, in that
// order, for every pool name present in the configuration. Pools without a
// configuration entry stay nil. It stops at the first failure.
func (upm *UnifiedProviderManager) initializePools() error {
	pools := upm.config.ProviderPools

	if readOnlyCfg, ok := pools["read_only"]; ok {
		pool, err := NewReadOnlyProviderPool(readOnlyCfg, upm.providerConfigs)
		upm.ReadOnlyPool = pool
		if err != nil {
			return fmt.Errorf("failed to initialize read-only pool: %w", err)
		}
	}

	if executionCfg, ok := pools["execution"]; ok {
		pool, err := NewExecutionProviderPool(executionCfg, upm.providerConfigs)
		upm.ExecutionPool = pool
		if err != nil {
			return fmt.Errorf("failed to initialize execution pool: %w", err)
		}
	}

	if testingCfg, ok := pools["testing"]; ok {
		pool, err := NewTestingProviderPool(testingCfg, upm.providerConfigs)
		upm.TestingPool = pool
		if err != nil {
			return fmt.Errorf("failed to initialize testing pool: %w", err)
		}
	}

	return nil
}
// GetPoolForMode maps an operation mode to its provider pool. It fails when
// the matching pool was never initialized or when mode is not one of the
// declared modes.
func (upm *UnifiedProviderManager) GetPoolForMode(mode OperationMode) (ProviderPool, error) {
	// Each branch checks the concrete pointer before converting it to the
	// interface, so a nil pool is never returned as a non-nil ProviderPool.
	if mode == ModeReadOnly {
		if upm.ReadOnlyPool != nil {
			return upm.ReadOnlyPool, nil
		}
		return nil, fmt.Errorf("read-only pool not initialized")
	}
	if mode == ModeExecution {
		if upm.ExecutionPool != nil {
			return upm.ExecutionPool, nil
		}
		return nil, fmt.Errorf("execution pool not initialized")
	}
	if mode == ModeTesting {
		if upm.TestingPool != nil {
			return upm.TestingPool, nil
		}
		return nil, fmt.Errorf("testing pool not initialized")
	}
	return nil, fmt.Errorf("unknown operation mode: %d", mode)
}
// GetReadOnlyHTTPClient asks the read-only pool for an HTTP client. It fails
// when the read-only pool was not initialized.
func (upm *UnifiedProviderManager) GetReadOnlyHTTPClient() (*ethclient.Client, error) {
	if pool := upm.ReadOnlyPool; pool != nil {
		return pool.GetHTTPClient()
	}
	return nil, fmt.Errorf("read-only pool not initialized")
}
// GetReadOnlyWSClient asks the read-only pool for a WebSocket client. It
// fails when the read-only pool was not initialized.
func (upm *UnifiedProviderManager) GetReadOnlyWSClient() (*ethclient.Client, error) {
	if pool := upm.ReadOnlyPool; pool != nil {
		return pool.GetWSClient()
	}
	return nil, fmt.Errorf("read-only pool not initialized")
}
// GetExecutionHTTPClient asks the execution pool for an HTTP client. It fails
// when the execution pool was not initialized.
func (upm *UnifiedProviderManager) GetExecutionHTTPClient() (*ethclient.Client, error) {
	if pool := upm.ExecutionPool; pool != nil {
		return pool.GetHTTPClient()
	}
	return nil, fmt.Errorf("execution pool not initialized")
}
// GetTestingHTTPClient asks the testing pool for an HTTP client. It fails
// when the testing pool was not initialized.
func (upm *UnifiedProviderManager) GetTestingHTTPClient() (*ethclient.Client, error) {
	if pool := upm.TestingPool; pool != nil {
		return pool.GetHTTPClient()
	}
	return nil, fmt.Errorf("testing pool not initialized")
}
// GetAllStats collects the statistics of every initialized pool under the
// keys "read_only", "execution" and "testing", plus a "summary" entry holding
// "total_pools" (int) and "pools_initialized" ([]string).
//
// pools_initialized is always listed in the fixed order read_only, execution,
// testing so the output is stable between calls.
func (upm *UnifiedProviderManager) GetAllStats() map[string]interface{} {
	stats := make(map[string]interface{})
	// Non-nil even when empty, so the summary always carries a []string.
	initialized := make([]string, 0, 3)

	if upm.ReadOnlyPool != nil {
		stats["read_only"] = upm.ReadOnlyPool.GetStats()
		initialized = append(initialized, "read_only")
	}
	if upm.ExecutionPool != nil {
		stats["execution"] = upm.ExecutionPool.GetStats()
		initialized = append(initialized, "execution")
	}
	if upm.TestingPool != nil {
		stats["testing"] = upm.TestingPool.GetStats()
		initialized = append(initialized, "testing")
	}

	// Add overall summary.
	stats["summary"] = map[string]interface{}{
		"total_pools":       len(initialized),
		"pools_initialized": initialized,
	}
	return stats
}
// CreateTestingSnapshot delegates to the testing pool's CreateSnapshot and
// returns the snapshot ID it yields. It fails when the testing pool was not
// initialized.
func (upm *UnifiedProviderManager) CreateTestingSnapshot() (string, error) {
	if pool := upm.TestingPool; pool != nil {
		return pool.CreateSnapshot()
	}
	return "", fmt.Errorf("testing pool not initialized")
}
// RevertTestingSnapshot delegates to the testing pool's RevertToSnapshot for
// the given snapshot ID. It fails when the testing pool was not initialized.
func (upm *UnifiedProviderManager) RevertTestingSnapshot(snapshotID string) error {
	if pool := upm.TestingPool; pool != nil {
		return pool.RevertToSnapshot(snapshotID)
	}
	return fmt.Errorf("testing pool not initialized")
}
// Close closes every initialized pool (read-only, execution, testing, in that
// order). It keeps going after a failure and, if anything failed, returns a
// single error listing all failures.
func (upm *UnifiedProviderManager) Close() error {
	var failures []error

	if pool := upm.ReadOnlyPool; pool != nil {
		if closeErr := pool.Close(); closeErr != nil {
			failures = append(failures, fmt.Errorf("failed to close read-only pool: %w", closeErr))
		}
	}
	if pool := upm.ExecutionPool; pool != nil {
		if closeErr := pool.Close(); closeErr != nil {
			failures = append(failures, fmt.Errorf("failed to close execution pool: %w", closeErr))
		}
	}
	if pool := upm.TestingPool; pool != nil {
		if closeErr := pool.Close(); closeErr != nil {
			failures = append(failures, fmt.Errorf("failed to close testing pool: %w", closeErr))
		}
	}

	if len(failures) == 0 {
		return nil
	}
	return fmt.Errorf("errors closing provider pools: %v", failures)
}
// LoadProvidersConfigFromFile reads the YAML file at path, decodes it into a
// ProvidersConfig and validates the result. On failure it returns the config
// as far as it was populated together with an error naming the failing step.
func LoadProvidersConfigFromFile(path string) (ProvidersConfig, error) {
	var cfg ProvidersConfig

	raw, readErr := os.ReadFile(path)
	if readErr != nil {
		return cfg, fmt.Errorf("failed to read config file %s: %w", path, readErr)
	}

	parseErr := yaml.Unmarshal(raw, &cfg)
	if parseErr != nil {
		return cfg, fmt.Errorf("failed to parse YAML config: %w", parseErr)
	}

	validationErr := validateProvidersConfig(&cfg)
	if validationErr != nil {
		return cfg, fmt.Errorf("invalid configuration: %w", validationErr)
	}

	return cfg, nil
}
// validateProvidersConfig checks a providers configuration and returns the
// first problem found, or nil.
//
// Checks, in order:
//   - at least one provider is configured;
//   - every pool lists at least one provider and references only providers
//     that exist;
//   - every provider has a name, at least one endpoint (HTTP or WS) and a
//     positive requests-per-second rate limit;
//   - providers of type "anvil_fork" carry an Anvil configuration.
func validateProvidersConfig(config *ProvidersConfig) error {
	if len(config.Providers) == 0 {
		return fmt.Errorf("no providers configured")
	}

	// Build the set of known provider names once; it is the same for every
	// pool, so there is no need to rebuild it inside the pool loop.
	providerNames := make(map[string]struct{}, len(config.Providers))
	for _, provider := range config.Providers {
		providerNames[provider.Name] = struct{}{}
	}

	// Validate provider pools.
	for poolName, poolConfig := range config.ProviderPools {
		if len(poolConfig.Providers) == 0 {
			return fmt.Errorf("provider pool '%s' has no providers", poolName)
		}
		// Check that all referenced providers exist.
		for _, providerName := range poolConfig.Providers {
			if _, known := providerNames[providerName]; !known {
				return fmt.Errorf("provider pool '%s' references unknown provider '%s'", poolName, providerName)
			}
		}
	}

	// Validate individual providers.
	for i, provider := range config.Providers {
		if provider.Name == "" {
			return fmt.Errorf("provider %d has no name", i)
		}
		if provider.HTTPEndpoint == "" && provider.WSEndpoint == "" {
			return fmt.Errorf("provider %s has no endpoints", provider.Name)
		}
		if provider.RateLimit.RequestsPerSecond <= 0 {
			return fmt.Errorf("provider %s has invalid rate limit", provider.Name)
		}
		// Validate Anvil config if present.
		if provider.Type == "anvil_fork" && provider.AnvilConfig == nil {
			return fmt.Errorf("provider %s is anvil_fork type but has no anvil_config", provider.Name)
		}
	}
	return nil
}
// GetProviderByName looks up a provider configuration by its name. The
// boolean result reports whether a provider with that name is registered;
// when it is false the returned configuration is the zero value.
func (upm *UnifiedProviderManager) GetProviderByName(name string) (ProviderConfig, bool) {
	if found, ok := upm.providerConfigs[name]; ok {
		return found, true
	}
	return ProviderConfig{}, false
}
// GetProvidersByType collects, in configuration order, every provider whose
// Type equals providerType. The result is nil when nothing matches.
func (upm *UnifiedProviderManager) GetProvidersByType(providerType string) []ProviderConfig {
	var matched []ProviderConfig
	for _, candidate := range upm.config.Providers {
		if candidate.Type != providerType {
			continue
		}
		matched = append(matched, candidate)
	}
	return matched
}
// GetProvidersByFeature collects, in configuration order, every provider
// whose Features list contains feature. A provider is included at most once
// even if the feature is listed several times. The result is nil when nothing
// matches.
func (upm *UnifiedProviderManager) GetProvidersByFeature(feature string) []ProviderConfig {
	var supporting []ProviderConfig

nextProvider:
	for _, candidate := range upm.config.Providers {
		for _, offered := range candidate.Features {
			if offered == feature {
				supporting = append(supporting, candidate)
				continue nextProvider
			}
		}
	}

	return supporting
}

View File

@@ -0,0 +1,359 @@
package transport
import (
"context"
"encoding/json"
"fmt"
"net"
"os"
"sync"
"time"
)
// UnixSocketTransport implements Unix socket transport for local IPC.
//
// One instance runs in exactly one of two modes, chosen at construction time:
// as a server it listens on socketPath and tracks every accepted peer; as a
// client it dials socketPath and tracks that single connection. Messages read
// from any tracked connection are delivered on receiveChan.
type UnixSocketTransport struct {
// socketPath is the filesystem path of the Unix domain socket.
socketPath string
// listener is the listening socket; it is only set in server mode.
listener net.Listener
// connections holds the live connections keyed by a generated ID
// ("server_<nanos>" for accepted peers, "client_<nanos>" for a dialed one).
connections map[string]net.Conn
// metrics accumulates traffic and error counters for this transport.
metrics TransportMetrics
// connected is true between a successful Connect and the next Disconnect.
connected bool
// isServer selects server (listen/accept) or client (dial) behaviour.
isServer bool
// receiveChan carries decoded inbound messages to the consumer.
receiveChan chan *Message
// mu is the lock the methods take around connections, connected and metrics.
mu sync.RWMutex
// ctx is cancelled by Disconnect to stop the background goroutines.
ctx context.Context
// cancel cancels ctx.
cancel context.CancelFunc
}
// NewUnixSocketTransport builds a transport bound to socketPath. When
// isServer is true the transport will listen on that path once Connect is
// called; otherwise it will dial it. No I/O happens here.
//
// The receive channel is buffered for 1000 messages, and the transport owns
// an internal cancellable context that lives until Disconnect.
func NewUnixSocketTransport(socketPath string, isServer bool) *UnixSocketTransport {
	transport := &UnixSocketTransport{
		socketPath:  socketPath,
		isServer:    isServer,
		connections: make(map[string]net.Conn),
		metrics:     TransportMetrics{},
		receiveChan: make(chan *Message, 1000),
	}
	transport.ctx, transport.cancel = context.WithCancel(context.Background())
	return transport
}
// Connect brings the transport up: in server mode it starts listening and
// accepting peers, in client mode it dials the server. Calling Connect on an
// already connected transport is a no-op that returns nil. The ctx argument
// is currently not consulted.
func (ut *UnixSocketTransport) Connect(ctx context.Context) error {
	ut.mu.Lock()
	defer ut.mu.Unlock()

	switch {
	case ut.connected:
		return nil
	case ut.isServer:
		return ut.startServer()
	default:
		return ut.connectToServer()
	}
}
// Disconnect tears the transport down. It cancels the internal context,
// closes the listener and removes the socket file (server mode), closes and
// forgets every tracked connection, closes the receive channel and resets the
// connected flag and connection gauge.
//
// It is a no-op when the transport is not connected. The ctx argument is not
// used, errors from the individual Close/Remove calls are discarded, and the
// method always returns nil.
//
// NOTE(review): the receive channel is closed here without waiting for the
// handleConnection goroutines to exit, while those goroutines send on the
// same channel; a send that loses the race would panic - confirm/synchronise.
// NOTE(review): the context and channel are created only in the constructor,
// so an instance does not look reusable after Disconnect - confirm intent.
func (ut *UnixSocketTransport) Disconnect(ctx context.Context) error {
ut.mu.Lock()
defer ut.mu.Unlock()
if !ut.connected {
return nil
}
// Signal the accept/read loops to stop before closing their sockets.
ut.cancel()
if ut.isServer && ut.listener != nil {
ut.listener.Close()
// Remove socket file
os.Remove(ut.socketPath)
}
// Close all connections
for id, conn := range ut.connections {
conn.Close()
delete(ut.connections, id)
}
close(ut.receiveChan)
ut.connected = false
ut.metrics.Connections = 0
return nil
}
// Send serialises msg as JSON, frames it as "<length>\n<payload>" and writes
// the frame to every tracked connection.
//
// It returns an error when the transport is not connected, when msg cannot be
// marshalled, when there are no connections, or when writing to at least one
// connection fails (the error describes one of the failing connections;
// connections that fail are closed and dropped). Send metrics are recorded
// only when every write succeeds. The ctx argument is currently not
// consulted.
//
// The connection set is snapshotted under the read lock and the writes happen
// with no lock held, so a slow peer cannot stall other users of the
// transport. The error counter is only touched under the write lock, which is
// the same lock the metric readers and the other metric writers use.
func (ut *UnixSocketTransport) Send(ctx context.Context, msg *Message) error {
	start := time.Now()

	// countError bumps the error counter under the write lock.
	countError := func() {
		ut.mu.Lock()
		ut.metrics.Errors++
		ut.mu.Unlock()
	}

	ut.mu.RLock()
	connected := ut.connected
	targets := make(map[string]net.Conn, len(ut.connections))
	for id, conn := range ut.connections {
		targets[id] = conn
	}
	ut.mu.RUnlock()

	if !connected {
		countError()
		return fmt.Errorf("transport not connected")
	}

	// Serialize message
	data, err := json.Marshal(msg)
	if err != nil {
		countError()
		return fmt.Errorf("failed to marshal message: %w", err)
	}

	if len(targets) == 0 {
		countError()
		return fmt.Errorf("no active connections")
	}

	// Add length prefix for framing
	prefix := fmt.Sprintf("%d\n", len(data))
	frame := make([]byte, 0, len(prefix)+len(data))
	frame = append(frame, prefix...)
	frame = append(frame, data...)

	// Send to all connections
	var sendErr error
	for connID, conn := range targets {
		if err := ut.sendToConnection(conn, frame); err != nil {
			sendErr = fmt.Errorf("failed to send to connection %s: %w", connID, err)
			// No lock is held here, so the failed connection can be dropped
			// synchronously.
			ut.removeConnection(connID)
		}
	}

	if sendErr != nil {
		countError()
		return sendErr
	}

	ut.updateSendMetrics(msg, time.Since(start))
	return nil
}
// Receive hands out the channel on which inbound messages are delivered. It
// fails with an error when the transport is not connected. The same channel
// is returned on every call; the ctx argument is currently not consulted.
func (ut *UnixSocketTransport) Receive(ctx context.Context) (<-chan *Message, error) {
	ut.mu.RLock()
	isConnected := ut.connected
	inbound := ut.receiveChan
	ut.mu.RUnlock()

	if isConnected {
		return inbound, nil
	}
	return nil, fmt.Errorf("transport not connected")
}
// Health summarises the transport state as a ComponentHealth value:
// "unhealthy" when not connected, "degraded" when connected without any live
// connection, and "healthy" otherwise. ResponseTime is a fixed nominal value
// rather than a measurement, and ErrorCount mirrors the error counter.
func (ut *UnixSocketTransport) Health() ComponentHealth {
	ut.mu.RLock()
	defer ut.mu.RUnlock()

	var status string
	switch {
	case !ut.connected:
		status = "unhealthy"
	case len(ut.connections) == 0:
		status = "degraded" // Connected but no active connections
	default:
		status = "healthy"
	}

	report := ComponentHealth{
		Status:       status,
		LastCheck:    time.Now(),
		ResponseTime: time.Millisecond, // Fast for local sockets
		ErrorCount:   ut.metrics.Errors,
	}
	return report
}
// GetMetrics returns a snapshot of the transport counters taken under the
// read lock. Connections reflects the number of connections tracked right
// now rather than the stored gauge.
func (ut *UnixSocketTransport) GetMetrics() TransportMetrics {
	ut.mu.RLock()
	defer ut.mu.RUnlock()

	var snapshot TransportMetrics
	snapshot.BytesSent = ut.metrics.BytesSent
	snapshot.BytesReceived = ut.metrics.BytesReceived
	snapshot.MessagesSent = ut.metrics.MessagesSent
	snapshot.MessagesReceived = ut.metrics.MessagesReceived
	snapshot.Connections = len(ut.connections)
	snapshot.Errors = ut.metrics.Errors
	snapshot.Latency = ut.metrics.Latency
	return snapshot
}
// Private helper methods
// startServer begins listening on the socket path and launches the accept
// loop in the background. Any file already present at the socket path is
// removed first (best effort; the removal error is ignored). It is called
// from Connect, which holds the transport lock.
func (ut *UnixSocketTransport) startServer() error {
	// Remove existing socket file
	os.Remove(ut.socketPath)

	ln, listenErr := net.Listen("unix", ut.socketPath)
	if listenErr != nil {
		return fmt.Errorf("failed to listen on socket %s: %w", ut.socketPath, listenErr)
	}

	ut.listener, ut.connected = ln, true

	// Start accepting connections
	go ut.acceptConnections()
	return nil
}
// connectToServer dials the Unix socket, registers the resulting connection
// under a generated "client_<nanos>" ID and launches its read loop in the
// background. It is called from Connect, which holds the transport lock.
func (ut *UnixSocketTransport) connectToServer() error {
	conn, dialErr := net.Dial("unix", ut.socketPath)
	if dialErr != nil {
		return fmt.Errorf("failed to connect to socket %s: %w", ut.socketPath, dialErr)
	}

	id := fmt.Sprintf("client_%d", time.Now().UnixNano())
	ut.connections[id] = conn
	ut.metrics.Connections = 1
	ut.connected = true

	// Start receiving from server
	go ut.handleConnection(id, conn)
	return nil
}
// acceptConnections is the server accept loop. It runs until the transport
// context is cancelled, registering every accepted peer under a generated
// "server_<nanos>" ID and starting a read loop for it.
//
// An Accept failure that is not caused by shutdown is counted as an error
// (under the write lock, like every other metric update) and retried after a
// short pause, so a listener that keeps failing cannot turn this loop into a
// busy spin.
func (ut *UnixSocketTransport) acceptConnections() {
	const acceptRetryDelay = 50 * time.Millisecond

	for {
		if ut.ctx.Err() != nil {
			return
		}

		conn, err := ut.listener.Accept()
		if err != nil {
			if ut.ctx.Err() != nil {
				return // Context cancelled
			}

			ut.mu.Lock()
			ut.metrics.Errors++
			ut.mu.Unlock()

			// Back off briefly, but stay responsive to shutdown.
			select {
			case <-ut.ctx.Done():
				return
			case <-time.After(acceptRetryDelay):
			}
			continue
		}

		connID := fmt.Sprintf("server_%d", time.Now().UnixNano())

		ut.mu.Lock()
		ut.connections[connID] = conn
		ut.metrics.Connections = len(ut.connections)
		ut.mu.Unlock()

		go ut.handleConnection(connID, conn)
	}
}
// handleConnection is the per-connection read loop. It reads raw bytes from
// conn, accumulates them, extracts every complete length-prefixed message and
// delivers it on the receive channel. The connection is removed from the
// transport when the loop exits.
//
// The loop ends when the transport context is cancelled, when a read fails
// with anything other than a timeout, or when the stream contains a frame
// that cannot be parsed. Reads use a one second deadline so cancellation is
// noticed promptly. When the receive channel is full the message is dropped
// and counted as an error; that counter is updated under the write lock,
// matching how the rest of the metrics are guarded.
func (ut *UnixSocketTransport) handleConnection(connID string, conn net.Conn) {
	defer ut.removeConnection(connID)

	buffer := make([]byte, 4096)
	var messageBuffer []byte

	for {
		select {
		case <-ut.ctx.Done():
			return
		default:
		}

		// Short deadline keeps the loop responsive to cancellation.
		conn.SetReadDeadline(time.Now().Add(time.Second))
		n, err := conn.Read(buffer)
		if err != nil {
			if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
				continue // Continue on timeout
			}
			return // Connection closed or error
		}

		messageBuffer = append(messageBuffer, buffer[:n]...)

		// Process complete messages
		for {
			msg, remaining, err := ExtractMessage(messageBuffer)
			if err != nil {
				return // Invalid message format
			}
			if msg == nil {
				break // No complete message yet
			}

			// Deliver message
			select {
			case ut.receiveChan <- msg:
				ut.updateReceiveMetrics(msg)
			case <-ut.ctx.Done():
				return
			default:
				// Channel full, drop message
				ut.mu.Lock()
				ut.metrics.Errors++
				ut.mu.Unlock()
			}

			messageBuffer = remaining
		}
	}
}
// sendToConnection writes data to conn with a five second write deadline and
// returns the write error, if any.
func (ut *UnixSocketTransport) sendToConnection(conn net.Conn, data []byte) error {
	writeBy := time.Now().Add(5 * time.Second)
	conn.SetWriteDeadline(writeBy)

	_, writeErr := conn.Write(data)
	return writeErr
}
// removeConnection closes and forgets the connection registered under connID
// and refreshes the connection gauge. Unknown IDs are ignored, so it is safe
// to call more than once for the same connection.
func (ut *UnixSocketTransport) removeConnection(connID string) {
	ut.mu.Lock()
	defer ut.mu.Unlock()

	conn, tracked := ut.connections[connID]
	if !tracked {
		return
	}

	conn.Close()
	delete(ut.connections, connID)
	ut.metrics.Connections = len(ut.connections)
}
// updateSendMetrics records one successfully sent message: it bumps the sent
// counter, stores latency as the most recent send latency and adds an
// estimate of the message size to the sent-bytes total. The estimate is the
// combined length of the ID, Topic and Source fields plus the length of the
// default text rendering of Data; it is not the number of bytes on the wire.
func (ut *UnixSocketTransport) updateSendMetrics(msg *Message, latency time.Duration) {
	ut.mu.Lock()
	defer ut.mu.Unlock()

	// Estimate message size
	estimate := len(msg.ID) + len(msg.Topic) + len(msg.Source)
	if msg.Data != nil {
		estimate += len(fmt.Sprintf("%v", msg.Data))
	}

	ut.metrics.MessagesSent++
	ut.metrics.Latency = latency
	ut.metrics.BytesSent += int64(estimate)
}
// updateReceiveMetrics records one delivered inbound message: it bumps the
// received counter and adds an estimate of the message size to the
// received-bytes total. The estimate is the combined length of the ID, Topic
// and Source fields plus the length of the default text rendering of Data; it
// is not the number of bytes read from the socket.
func (ut *UnixSocketTransport) updateReceiveMetrics(msg *Message) {
	ut.mu.Lock()
	defer ut.mu.Unlock()

	// Estimate message size
	estimate := len(msg.ID) + len(msg.Topic) + len(msg.Source)
	if msg.Data != nil {
		estimate += len(fmt.Sprintf("%v", msg.Data))
	}

	ut.metrics.MessagesReceived++
	ut.metrics.BytesReceived += int64(estimate)
}
// GetSocketPath reports the filesystem path of the Unix socket this transport
// was created with. It is intended for tests and debugging.
func (ut *UnixSocketTransport) GetSocketPath() string {
	path := ut.socketPath
	return path
}
// GetConnectionCount reports how many connections the transport is tracking
// at the moment of the call.
func (ut *UnixSocketTransport) GetConnectionCount() int {
	ut.mu.RLock()
	count := len(ut.connections)
	ut.mu.RUnlock()
	return count
}

View File

@@ -0,0 +1,48 @@
package transport
import (
"encoding/json"
"fmt"
)
// ExtractMessage tries to pull one length-prefixed message off the front of
// buffer. The wire format is "<decimal length>\n<JSON payload>", where length
// is the byte length of the payload.
//
// Results:
//   - (msg, rest, nil): a complete message was decoded; rest is the unread
//     remainder of buffer (a sub-slice, not a copy).
//   - (nil, buffer, nil): buffer does not yet hold a complete frame; the
//     caller should read more data and try again.
//   - (nil, nil, err): the frame is malformed (unparseable or negative length
//     prefix, or payload that is not valid JSON for Message).
//
// The length prefix comes from the peer, so it is validated before it is used
// as a slice bound: a negative value, or one so large that adding it to the
// payload offset overflows, is rejected instead of causing a slice-bounds
// panic.
func ExtractMessage(buffer []byte) (*Message, []byte, error) {
	// Look for length prefix (format: "length\nmessage_data")
	newlineIndex := -1
	for i, b := range buffer {
		if b == '\n' {
			newlineIndex = i
			break
		}
	}

	if newlineIndex == -1 {
		return nil, buffer, nil // No complete length prefix yet
	}

	// Parse length
	lengthStr := string(buffer[:newlineIndex])
	var messageLength int
	if _, err := fmt.Sscanf(lengthStr, "%d", &messageLength); err != nil {
		return nil, nil, fmt.Errorf("invalid length prefix: %s", lengthStr)
	}
	if messageLength < 0 {
		return nil, nil, fmt.Errorf("invalid length prefix: %s", lengthStr)
	}

	// Check if we have the complete message
	messageStart := newlineIndex + 1
	messageEnd := messageStart + messageLength
	if messageEnd < messageStart {
		// Integer overflow: the declared length cannot be satisfied.
		return nil, nil, fmt.Errorf("invalid length prefix: %s", lengthStr)
	}
	if len(buffer) < messageEnd {
		return nil, buffer, nil // Incomplete message
	}

	// Extract and parse message
	messageData := buffer[messageStart:messageEnd]
	var msg Message
	if err := json.Unmarshal(messageData, &msg); err != nil {
		return nil, nil, fmt.Errorf("failed to unmarshal message: %w", err)
	}

	// Return message and remaining buffer
	remaining := buffer[messageEnd:]
	return &msg, remaining, nil
}

View File

@@ -0,0 +1,428 @@
package transport
import (
"context"
"encoding/json"
"fmt"
"net/http"
"sync"
"time"
"github.com/gorilla/websocket"
)
// WebSocketTransport implements WebSocket transport for real-time monitoring.
//
// One instance runs in exactly one of two modes, chosen at construction time:
// as a server it serves an HTTP endpoint that upgrades requests on path to
// WebSocket connections; as a client it dials ws://address:port/path.
// Messages read from any tracked connection are delivered on receiveChan.
type WebSocketTransport struct {
// address is the host part used both for listening and for dialing.
address string
// port is the TCP port used both for listening and for dialing.
port int
// path is the URL path of the WebSocket endpoint.
path string
// upgrader upgrades incoming HTTP requests in server mode; its CheckOrigin
// decides which origins are accepted.
upgrader websocket.Upgrader
// connections holds the live connections keyed by a generated ID
// ("server_<nanos>" for upgraded peers, "client_<nanos>" for a dialed one).
connections map[string]*websocket.Conn
// metrics accumulates traffic and error counters for this transport.
metrics TransportMetrics
// connected is true between a successful Connect and the next Disconnect.
connected bool
// isServer selects server (listen/upgrade) or client (dial) behaviour.
isServer bool
// receiveChan carries decoded inbound messages to the consumer.
receiveChan chan *Message
// server is the HTTP server; it is only set in server mode.
server *http.Server
// mu is the lock the methods take around connections, connected and metrics.
mu sync.RWMutex
// ctx is cancelled by Disconnect to stop the background goroutines.
ctx context.Context
// cancel cancels ctx.
cancel context.CancelFunc
// pingInterval is how often a ping is sent on each connection.
pingInterval time.Duration
// pongTimeout is used when computing read deadlines for a connection.
pongTimeout time.Duration
}
// NewWebSocketTransport builds a transport for the endpoint
// ws://address:port/path. When isServer is true the transport will serve that
// endpoint once Connect is called; otherwise it will dial it. No I/O happens
// here.
//
// Defaults: a receive channel buffered for 1000 messages, 1024-byte read and
// write buffers, a 30 second ping interval, a 10 second pong timeout, and an
// origin check that accepts every origin (see SetAllowedOrigins).
func NewWebSocketTransport(address string, port int, path string, isServer bool) *WebSocketTransport {
	// Allow all origins for now
	acceptAnyOrigin := func(r *http.Request) bool {
		return true
	}

	transport := &WebSocketTransport{
		address:      address,
		port:         port,
		path:         path,
		isServer:     isServer,
		connections:  make(map[string]*websocket.Conn),
		metrics:      TransportMetrics{},
		receiveChan:  make(chan *Message, 1000),
		pingInterval: 30 * time.Second,
		pongTimeout:  10 * time.Second,
		upgrader: websocket.Upgrader{
			CheckOrigin:     acceptAnyOrigin,
			ReadBufferSize:  1024,
			WriteBufferSize: 1024,
		},
	}
	transport.ctx, transport.cancel = context.WithCancel(context.Background())
	return transport
}
// SetPingPongSettings overrides the keep-alive timing: pingInterval is how
// often pings are sent and pongTimeout feeds the read deadline. The values
// are stored as given, without validation or locking, so this should be
// called before Connect.
func (wt *WebSocketTransport) SetPingPongSettings(pingInterval, pongTimeout time.Duration) {
	wt.pingInterval, wt.pongTimeout = pingInterval, pongTimeout
}
// Connect brings the transport up: in server mode it starts the HTTP server
// that upgrades WebSocket requests, in client mode it dials the server using
// ctx. Calling Connect on an already connected transport is a no-op that
// returns nil.
func (wt *WebSocketTransport) Connect(ctx context.Context) error {
	wt.mu.Lock()
	defer wt.mu.Unlock()

	switch {
	case wt.connected:
		return nil
	case wt.isServer:
		return wt.startServer()
	default:
		return wt.connectToServer(ctx)
	}
}
// Disconnect tears the transport down. It cancels the internal context, shuts
// the HTTP server down with a five second limit (server mode), closes and
// forgets every tracked connection, closes the receive channel and resets the
// connected flag and connection gauge.
//
// It is a no-op when the transport is not connected. The ctx argument is not
// used, errors from Shutdown and Close are discarded, and the method always
// returns nil. The transport lock is held for the whole call, including the
// server shutdown.
//
// NOTE(review): the receive channel is closed here without waiting for the
// handleConnection goroutines to exit, while those goroutines send on the
// same channel; a send that loses the race would panic - confirm/synchronise.
// NOTE(review): the context and channel are created only in the constructor,
// so an instance does not look reusable after Disconnect - confirm intent.
func (wt *WebSocketTransport) Disconnect(ctx context.Context) error {
wt.mu.Lock()
defer wt.mu.Unlock()
if !wt.connected {
return nil
}
// Signal the read and ping loops to stop before closing their sockets.
wt.cancel()
if wt.isServer && wt.server != nil {
// Bound the shutdown so a stuck handler cannot block Disconnect forever.
shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
wt.server.Shutdown(shutdownCtx)
}
// Close all connections
for id, conn := range wt.connections {
conn.Close()
delete(wt.connections, id)
}
close(wt.receiveChan)
wt.connected = false
wt.metrics.Connections = 0
return nil
}
// Send serialises msg as JSON and writes it as a text frame to every tracked
// connection.
//
// It returns an error when the transport is not connected, when msg cannot be
// marshalled, when there are no connections, or when writing to at least one
// connection fails (the error describes one of the failing connections;
// connections that fail are closed and dropped). Send metrics are recorded
// only when every write succeeds. The ctx argument is currently not
// consulted.
//
// The writes are performed while holding the transport's write lock. A
// WebSocket connection supports only one concurrent writer, so two Send calls
// must not write to the same connection at the same time; the exclusive lock
// serialises them. It also means the error counter and the connection map are
// only modified under the lock that their readers use.
func (wt *WebSocketTransport) Send(ctx context.Context, msg *Message) error {
	start := time.Now()

	// countError bumps the error counter under the write lock.
	countError := func() {
		wt.mu.Lock()
		wt.metrics.Errors++
		wt.mu.Unlock()
	}

	wt.mu.RLock()
	connected := wt.connected
	wt.mu.RUnlock()

	if !connected {
		countError()
		return fmt.Errorf("transport not connected")
	}

	// Serialize message
	data, err := json.Marshal(msg)
	if err != nil {
		countError()
		return fmt.Errorf("failed to marshal message: %w", err)
	}

	wt.mu.Lock()
	if len(wt.connections) == 0 {
		wt.metrics.Errors++
		wt.mu.Unlock()
		return fmt.Errorf("no active connections")
	}

	// Send to all connections
	var sendErr error
	for connID, conn := range wt.connections {
		if err := wt.sendToConnection(conn, data); err != nil {
			sendErr = fmt.Errorf("failed to send to connection %s: %w", connID, err)
			// Remove failed connection. The write lock is already held, so
			// this is done inline; deleting during range is safe in Go.
			conn.Close()
			delete(wt.connections, connID)
		}
	}
	wt.metrics.Connections = len(wt.connections)
	if sendErr != nil {
		wt.metrics.Errors++
	}
	wt.mu.Unlock()

	if sendErr != nil {
		return sendErr
	}

	wt.updateSendMetrics(msg, time.Since(start))
	return nil
}
// Receive hands out the channel on which inbound messages are delivered. It
// fails with an error when the transport is not connected. The same channel
// is returned on every call; the ctx argument is currently not consulted.
func (wt *WebSocketTransport) Receive(ctx context.Context) (<-chan *Message, error) {
	wt.mu.RLock()
	isConnected := wt.connected
	inbound := wt.receiveChan
	wt.mu.RUnlock()

	if isConnected {
		return inbound, nil
	}
	return nil, fmt.Errorf("transport not connected")
}
// Health summarises the transport state as a ComponentHealth value:
// "unhealthy" when not connected, "degraded" when connected without any live
// connection, and "healthy" otherwise. ResponseTime is a fixed nominal value
// rather than a measurement, and ErrorCount mirrors the error counter.
func (wt *WebSocketTransport) Health() ComponentHealth {
	wt.mu.RLock()
	defer wt.mu.RUnlock()

	var status string
	switch {
	case !wt.connected:
		status = "unhealthy"
	case len(wt.connections) == 0:
		status = "degraded" // Connected but no active connections
	default:
		status = "healthy"
	}

	report := ComponentHealth{
		Status:       status,
		LastCheck:    time.Now(),
		ResponseTime: 5 * time.Millisecond, // Very fast for WebSocket
		ErrorCount:   wt.metrics.Errors,
	}
	return report
}
// GetMetrics returns a snapshot of the transport counters taken under the
// read lock. Connections reflects the number of connections tracked right
// now rather than the stored gauge.
func (wt *WebSocketTransport) GetMetrics() TransportMetrics {
	wt.mu.RLock()
	defer wt.mu.RUnlock()

	var snapshot TransportMetrics
	snapshot.BytesSent = wt.metrics.BytesSent
	snapshot.BytesReceived = wt.metrics.BytesReceived
	snapshot.MessagesSent = wt.metrics.MessagesSent
	snapshot.MessagesReceived = wt.metrics.MessagesReceived
	snapshot.Connections = len(wt.connections)
	snapshot.Errors = wt.metrics.Errors
	snapshot.Latency = wt.metrics.Latency
	return snapshot
}
// Private helper methods
// startServer configures and launches the HTTP server for server mode. The
// server exposes three routes: the WebSocket endpoint on the configured path,
// "/health" (JSON-encoded Health result) and "/metrics" (JSON-encoded
// GetMetrics result). It is called from Connect, which holds the transport
// lock.
//
// The server is started in a background goroutine and this function returns
// nil immediately. A failure inside ListenAndServe other than a clean
// shutdown is recorded in the error counter, under the write lock like every
// other metric update.
//
// NOTE(review): because listening happens in the goroutine, a bind failure
// (for example a port already in use) is not reported to the caller and the
// transport still marks itself connected - confirm whether that is intended.
func (wt *WebSocketTransport) startServer() error {
	mux := http.NewServeMux()
	mux.HandleFunc(wt.path, wt.handleWebSocket)

	// Add health check endpoint
	mux.HandleFunc("/health", func(w http.ResponseWriter, r *http.Request) {
		health := wt.Health()
		w.Header().Set("Content-Type", "application/json")
		json.NewEncoder(w).Encode(health)
	})

	// Add metrics endpoint
	mux.HandleFunc("/metrics", func(w http.ResponseWriter, r *http.Request) {
		metrics := wt.GetMetrics()
		w.Header().Set("Content-Type", "application/json")
		json.NewEncoder(w).Encode(metrics)
	})

	addr := fmt.Sprintf("%s:%d", wt.address, wt.port)
	wt.server = &http.Server{
		Addr:              addr,
		Handler:           mux,
		ReadHeaderTimeout: 5 * time.Second, // Prevent Slowloris attacks
		ReadTimeout:       60 * time.Second,
		WriteTimeout:      60 * time.Second,
		IdleTimeout:       120 * time.Second,
	}
	wt.connected = true

	// Start server in goroutine
	go func() {
		if err := wt.server.ListenAndServe(); err != nil && err != http.ErrServerClosed {
			wt.mu.Lock()
			wt.metrics.Errors++
			wt.mu.Unlock()
		}
	}()

	return nil
}
// connectToServer dials ws://address:port/path using ctx, registers the
// resulting connection under a generated "client_<nanos>" ID and launches its
// read loop in the background. It is called from Connect, which holds the
// transport lock.
func (wt *WebSocketTransport) connectToServer(ctx context.Context) error {
	endpoint := fmt.Sprintf("ws://%s:%d%s", wt.address, wt.port, wt.path)

	conn, _, dialErr := websocket.DefaultDialer.DialContext(ctx, endpoint, nil)
	if dialErr != nil {
		return fmt.Errorf("failed to connect to WebSocket server: %w", dialErr)
	}

	id := fmt.Sprintf("client_%d", time.Now().UnixNano())
	wt.connections[id] = conn
	wt.metrics.Connections = 1
	wt.connected = true

	// Start handling connection
	go wt.handleConnection(id, conn)
	return nil
}
// handleWebSocket is the HTTP handler for the WebSocket endpoint. It upgrades
// the request, registers the new connection under a generated
// "server_<nanos>" ID and starts its read loop in the background.
//
// A failed upgrade is counted as an error and the handler returns without
// registering anything. The counter is updated under the write lock, like
// every other metric update.
func (wt *WebSocketTransport) handleWebSocket(w http.ResponseWriter, r *http.Request) {
	conn, err := wt.upgrader.Upgrade(w, r, nil)
	if err != nil {
		wt.mu.Lock()
		wt.metrics.Errors++
		wt.mu.Unlock()
		return
	}

	connID := fmt.Sprintf("server_%d", time.Now().UnixNano())

	wt.mu.Lock()
	wt.connections[connID] = conn
	wt.metrics.Connections = len(wt.connections)
	wt.mu.Unlock()

	go wt.handleConnection(connID, conn)
}
// handleConnection is the per-connection read loop. It arms the keep-alive
// read deadline, starts the ping routine, then decodes JSON messages from
// conn and delivers them on the receive channel. The connection is removed
// from the transport when the loop exits.
//
// The read deadline is only pushed forward when a pong arrives, and pings go
// out once per pingInterval. The deadline therefore has to cover a full ping
// interval plus the pong timeout; using the pong timeout alone would expire
// the deadline before the first ping had even been sent.
//
// The loop ends when the transport context is cancelled or when a read fails.
// Unexpected close errors and messages dropped because the receive channel is
// full are counted as errors, under the write lock like every other metric
// update.
func (wt *WebSocketTransport) handleConnection(connID string, conn *websocket.Conn) {
	defer wt.removeConnection(connID)

	// nextReadDeadline allows one ping interval plus the time granted for
	// the matching pong to arrive.
	nextReadDeadline := func() time.Time {
		window := wt.pongTimeout
		if wt.pingInterval > 0 {
			window += wt.pingInterval
		}
		return time.Now().Add(window)
	}

	// countError bumps the error counter under the write lock.
	countError := func() {
		wt.mu.Lock()
		wt.metrics.Errors++
		wt.mu.Unlock()
	}

	// Set up ping/pong handling
	conn.SetReadDeadline(nextReadDeadline())
	conn.SetPongHandler(func(string) error {
		conn.SetReadDeadline(nextReadDeadline())
		return nil
	})

	// Start ping routine
	go wt.pingRoutine(connID, conn)

	for {
		select {
		case <-wt.ctx.Done():
			return
		default:
		}

		var msg Message
		if err := conn.ReadJSON(&msg); err != nil {
			if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) {
				countError()
			}
			return
		}

		// Deliver message
		select {
		case wt.receiveChan <- &msg:
			wt.updateReceiveMetrics(&msg)
		case <-wt.ctx.Done():
			return
		default:
			// Channel full, drop message
			countError()
		}
	}
}
// pingRoutine sends a WebSocket ping on conn once per pingInterval until the
// transport context is cancelled or a ping cannot be written (which normally
// means the connection has been closed). Each ping is given ten seconds to be
// written.
//
// A non-positive pingInterval disables pings: the routine returns right away
// instead of creating a ticker, because time.NewTicker panics for such
// values. The connID argument is not used.
func (wt *WebSocketTransport) pingRoutine(connID string, conn *websocket.Conn) {
	interval := wt.pingInterval
	if interval <= 0 {
		return
	}

	ticker := time.NewTicker(interval)
	defer ticker.Stop()

	for {
		select {
		case <-ticker.C:
			if err := conn.WriteControl(websocket.PingMessage, []byte{}, time.Now().Add(10*time.Second)); err != nil {
				return // Connection is likely closed
			}
		case <-wt.ctx.Done():
			return
		}
	}
}
// sendToConnection writes data to conn as a single text message with a ten
// second write deadline and returns the write error, if any.
func (wt *WebSocketTransport) sendToConnection(conn *websocket.Conn, data []byte) error {
	writeBy := time.Now().Add(10 * time.Second)
	conn.SetWriteDeadline(writeBy)

	writeErr := conn.WriteMessage(websocket.TextMessage, data)
	return writeErr
}
// removeConnection closes and forgets the connection registered under connID
// and refreshes the connection gauge. Unknown IDs are ignored, so it is safe
// to call more than once for the same connection.
func (wt *WebSocketTransport) removeConnection(connID string) {
	wt.mu.Lock()
	defer wt.mu.Unlock()

	conn, tracked := wt.connections[connID]
	if !tracked {
		return
	}

	conn.Close()
	delete(wt.connections, connID)
	wt.metrics.Connections = len(wt.connections)
}
// updateSendMetrics records one successfully sent message: it bumps the sent
// counter, stores latency as the most recent send latency and adds an
// estimate of the message size to the sent-bytes total. The estimate is the
// combined length of the ID, Topic and Source fields plus the length of the
// default text rendering of Data; it is not the number of bytes on the wire.
func (wt *WebSocketTransport) updateSendMetrics(msg *Message, latency time.Duration) {
	wt.mu.Lock()
	defer wt.mu.Unlock()

	// Estimate message size
	estimate := len(msg.ID) + len(msg.Topic) + len(msg.Source)
	if msg.Data != nil {
		estimate += len(fmt.Sprintf("%v", msg.Data))
	}

	wt.metrics.MessagesSent++
	wt.metrics.Latency = latency
	wt.metrics.BytesSent += int64(estimate)
}
// updateReceiveMetrics records one delivered inbound message: it bumps the
// received counter and adds an estimate of the message size to the
// received-bytes total. The estimate is the combined length of the ID, Topic
// and Source fields plus the length of the default text rendering of Data; it
// is not the number of bytes read from the connection.
func (wt *WebSocketTransport) updateReceiveMetrics(msg *Message) {
	wt.mu.Lock()
	defer wt.mu.Unlock()

	// Estimate message size
	estimate := len(msg.ID) + len(msg.Topic) + len(msg.Source)
	if msg.Data != nil {
		estimate += len(fmt.Sprintf("%v", msg.Data))
	}

	wt.metrics.MessagesReceived++
	wt.metrics.BytesReceived += int64(estimate)
}
// Broadcast delivers msg to every connected client by delegating to Send. It
// is only meaningful on a server-mode transport; on a client-mode transport
// it returns an error without sending anything.
func (wt *WebSocketTransport) Broadcast(ctx context.Context, msg *Message) error {
	if wt.isServer {
		return wt.Send(ctx, msg)
	}
	return fmt.Errorf("broadcast only available in server mode")
}
// GetURL renders the endpoint this transport serves or dials as a ws:// URL
// built from the configured address, port and path. It is intended for tests
// and debugging.
func (wt *WebSocketTransport) GetURL() string {
	endpoint := fmt.Sprintf("ws://%s:%d%s", wt.address, wt.port, wt.path)
	return endpoint
}
// GetConnectionCount reports how many connections the transport is tracking
// at the moment of the call.
func (wt *WebSocketTransport) GetConnectionCount() int {
	wt.mu.RLock()
	count := len(wt.connections)
	wt.mu.RUnlock()
	return count
}
// SetAllowedOrigins installs the origin check used when upgrading incoming
// requests. With an empty list every origin is accepted. Otherwise a request
// is accepted only when its Origin header exactly matches one of the listed
// values; note that a request without an Origin header is then rejected
// unless the empty string is itself listed.
func (wt *WebSocketTransport) SetAllowedOrigins(origins []string) {
	if len(origins) == 0 {
		// Allow all origins
		wt.upgrader.CheckOrigin = func(r *http.Request) bool {
			return true
		}
		return
	}

	allowed := make(map[string]bool)
	for i := range origins {
		allowed[origins[i]] = true
	}

	wt.upgrader.CheckOrigin = func(r *http.Request) bool {
		return allowed[r.Header.Get("Origin")]
	}
}