feat: create v2-prep branch with comprehensive planning
Restructured project for V2 refactor: **Structure Changes:** - Moved all V1 code to orig/ folder (preserved with git mv) - Created docs/planning/ directory - Added orig/README_V1.md explaining V1 preservation **Planning Documents:** - 00_V2_MASTER_PLAN.md: Complete architecture overview - Executive summary of critical V1 issues - High-level component architecture diagrams - 5-phase implementation roadmap - Success metrics and risk mitigation - 07_TASK_BREAKDOWN.md: Atomic task breakdown - 99+ hours of detailed tasks - Every task < 2 hours (atomic) - Clear dependencies and success criteria - Organized by implementation phase **V2 Key Improvements:** - Per-exchange parsers (factory pattern) - Multi-layer strict validation - Multi-index pool cache - Background validation pipeline - Comprehensive observability **Critical Issues Addressed:** - Zero address tokens (strict validation + cache enrichment) - Parsing accuracy (protocol-specific parsers) - No audit trail (background validation channel) - Inefficient lookups (multi-index cache) - Stats disconnection (event-driven metrics) Next Steps: 1. Review planning documents 2. Begin Phase 1: Foundation (P1-001 through P1-010) 3. Implement parsers in Phase 2 4. Build cache system in Phase 3 5. Add validation pipeline in Phase 4 6. Migrate and test in Phase 5 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
673
orig/pkg/transport/benchmarks.go
Normal file
673
orig/pkg/transport/benchmarks.go
Normal file
@@ -0,0 +1,673 @@
|
||||
package transport
|
||||
|
||||
import (
	"context"
	"fmt"
	"math"
	"runtime"
	"sort"
	"sync"
	"sync/atomic"
	"time"

	"github.com/fraktal/mev-beta/internal/logger"
	"github.com/fraktal/mev-beta/pkg/security"
)
|
||||
|
||||
// BenchmarkSuite provides comprehensive performance testing for the transport layer.
// It drives the shared UniversalMessageBus through a configurable matrix of
// transports, payload sizes, concurrency levels and serialization formats,
// collecting per-run results and aggregate metrics.
type BenchmarkSuite struct {
	logger     *logger.Logger       // used for warnings (e.g. counter conversion overflow)
	messageBus *UniversalMessageBus // bus under test; benchmarks publish/subscribe through it
	results    []BenchmarkResult    // completed run results; guarded by mu
	config     BenchmarkConfig      // benchmark matrix parameters; guarded by mu
	metrics    BenchmarkMetrics     // aggregate pass/fail and best-of stats; guarded by mu
	mu         sync.RWMutex         // guards results, config and metrics
}

// BenchmarkConfig configures benchmark parameters.
// NewBenchmarkSuite installs a default matrix; SetConfig replaces it wholesale.
type BenchmarkConfig struct {
	MessageSizes         []int                 // Message payload sizes to test (bytes)
	Concurrency          []int                 // Concurrency levels (sender goroutine counts) to test
	Duration             time.Duration         // Duration of each benchmark run
	WarmupDuration       time.Duration         // Warmup period before measurements (0 disables warmup)
	TransportTypes       []TransportType       // Transport types to benchmark
	MessageTypes         []MessageType         // Message types to test
	SerializationFormats []SerializationFormat // Serialization formats to test
	EnableMetrics        bool                  // Whether to collect detailed metrics
	OutputFormat         string                // Output format (json, csv, console)
}
|
||||
|
||||
// BenchmarkResult contains results from a single benchmark run.
type BenchmarkResult struct {
	TestName          string              `json:"test_name"` // "<transport>_<size>b_<concurrency>c_<format>"
	Transport         TransportType       `json:"transport"`
	MessageSize       int                 `json:"message_size"` // payload size in bytes
	Concurrency       int                 `json:"concurrency"`  // number of sender goroutines
	Serialization     SerializationFormat `json:"serialization"`
	Duration          time.Duration       `json:"duration"` // configured (not measured) run duration
	MessagesSent      int64               `json:"messages_sent"`
	MessagesReceived  int64               `json:"messages_received"`
	BytesSent         int64               `json:"bytes_sent"`
	BytesReceived     int64               `json:"bytes_received"`
	ThroughputMsgSec  float64             `json:"throughput_msg_sec"`
	ThroughputByteSec float64             `json:"throughput_byte_sec"`
	LatencyP50        time.Duration       `json:"latency_p50"`
	LatencyP95        time.Duration       `json:"latency_p95"`
	LatencyP99        time.Duration       `json:"latency_p99"`
	ErrorRate         float64             `json:"error_rate"`   // percentage of sends that failed
	CPUUsage          float64             `json:"cpu_usage"`    // GC-based approximation, not real CPU% (see monitorSystemResources)
	MemoryUsage       int64               `json:"memory_usage"` // heap Alloc delta over the run, bytes
	GCPauses          int64               `json:"gc_pauses"`    // GC cycle count delta over the run
	Timestamp         time.Time           `json:"timestamp"`
}

// BenchmarkMetrics tracks overall benchmark statistics across a RunAll sweep.
type BenchmarkMetrics struct {
	TotalTests        int           `json:"total_tests"`
	PassedTests       int           `json:"passed_tests"`
	FailedTests       int           `json:"failed_tests"`
	TotalDuration     time.Duration `json:"total_duration"`
	HighestThroughput float64       `json:"highest_throughput"` // best msgs/sec seen across all runs
	LowestLatency     time.Duration `json:"lowest_latency"`     // lowest P50 seen across all runs
	BestTransport     TransportType `json:"best_transport"`     // transport that achieved HighestThroughput
	Timestamp         time.Time     `json:"timestamp"`          // sweep start time
}

// LatencyTracker accumulates message latencies for percentile calculation.
// Safe for concurrent use via its mutex.
type LatencyTracker struct {
	latencies []time.Duration // unsorted samples; guarded by mu
	mu        sync.Mutex
}
|
||||
|
||||
// NewBenchmarkSuite creates a new benchmark suite
|
||||
func NewBenchmarkSuite(messageBus *UniversalMessageBus, logger *logger.Logger) *BenchmarkSuite {
|
||||
return &BenchmarkSuite{
|
||||
logger: logger,
|
||||
messageBus: messageBus,
|
||||
results: make([]BenchmarkResult, 0),
|
||||
config: BenchmarkConfig{
|
||||
MessageSizes: []int{64, 256, 1024, 4096, 16384},
|
||||
Concurrency: []int{1, 10, 50, 100},
|
||||
Duration: 30 * time.Second,
|
||||
WarmupDuration: 5 * time.Second,
|
||||
TransportTypes: []TransportType{TransportMemory, TransportUnixSocket, TransportTCP},
|
||||
MessageTypes: []MessageType{MessageTypeEvent, MessageTypeCommand},
|
||||
SerializationFormats: []SerializationFormat{SerializationJSON},
|
||||
EnableMetrics: true,
|
||||
OutputFormat: "console",
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// SetConfig updates the benchmark configuration
|
||||
func (bs *BenchmarkSuite) SetConfig(config BenchmarkConfig) {
|
||||
bs.mu.Lock()
|
||||
defer bs.mu.Unlock()
|
||||
bs.config = config
|
||||
}
|
||||
|
||||
// RunAll executes all benchmark tests
|
||||
func (bs *BenchmarkSuite) RunAll(ctx context.Context) error {
|
||||
bs.mu.Lock()
|
||||
defer bs.mu.Unlock()
|
||||
|
||||
startTime := time.Now()
|
||||
bs.metrics = BenchmarkMetrics{
|
||||
Timestamp: startTime,
|
||||
}
|
||||
|
||||
for _, transport := range bs.config.TransportTypes {
|
||||
for _, msgSize := range bs.config.MessageSizes {
|
||||
for _, concurrency := range bs.config.Concurrency {
|
||||
for _, serialization := range bs.config.SerializationFormats {
|
||||
result, err := bs.runSingleBenchmark(ctx, transport, msgSize, concurrency, serialization)
|
||||
if err != nil {
|
||||
bs.metrics.FailedTests++
|
||||
continue
|
||||
}
|
||||
|
||||
bs.results = append(bs.results, result)
|
||||
bs.metrics.PassedTests++
|
||||
bs.updateBestMetrics(result)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
bs.metrics.TotalTests = bs.metrics.PassedTests + bs.metrics.FailedTests
|
||||
bs.metrics.TotalDuration = time.Since(startTime)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// RunThroughputBenchmark tests message throughput for a single
// transport/size/concurrency combination, using JSON serialization.
// It is a convenience wrapper around runSingleBenchmark.
func (bs *BenchmarkSuite) RunThroughputBenchmark(ctx context.Context, transport TransportType, messageSize int, concurrency int) (BenchmarkResult, error) {
	return bs.runSingleBenchmark(ctx, transport, messageSize, concurrency, SerializationJSON)
}
|
||||
|
||||
// RunLatencyBenchmark tests message latency for a single transport/size
// combination. It runs with concurrency 1 so queuing effects from parallel
// senders do not inflate the measured latencies. Uses JSON serialization.
func (bs *BenchmarkSuite) RunLatencyBenchmark(ctx context.Context, transport TransportType, messageSize int) (BenchmarkResult, error) {
	return bs.runSingleBenchmark(ctx, transport, messageSize, 1, SerializationJSON)
}
|
||||
|
||||
// RunScalabilityBenchmark tests scalability across different concurrency levels
|
||||
func (bs *BenchmarkSuite) RunScalabilityBenchmark(ctx context.Context, transport TransportType, messageSize int) ([]BenchmarkResult, error) {
|
||||
var results []BenchmarkResult
|
||||
|
||||
for _, concurrency := range bs.config.Concurrency {
|
||||
result, err := bs.runSingleBenchmark(ctx, transport, messageSize, concurrency, SerializationJSON)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("scalability benchmark failed at concurrency %d: %w", concurrency, err)
|
||||
}
|
||||
results = append(results, result)
|
||||
}
|
||||
|
||||
return results, nil
|
||||
}
|
||||
|
||||
// GetResults returns all benchmark results
|
||||
func (bs *BenchmarkSuite) GetResults() []BenchmarkResult {
|
||||
bs.mu.RLock()
|
||||
defer bs.mu.RUnlock()
|
||||
|
||||
results := make([]BenchmarkResult, len(bs.results))
|
||||
copy(results, bs.results)
|
||||
return results
|
||||
}
|
||||
|
||||
// GetMetrics returns a snapshot of the aggregate benchmark metrics.
func (bs *BenchmarkSuite) GetMetrics() BenchmarkMetrics {
	bs.mu.RLock()
	defer bs.mu.RUnlock()
	return bs.metrics
}

// GetBestPerformingTransport returns the transport with the highest observed
// throughput (the zero value if no benchmark has completed yet).
func (bs *BenchmarkSuite) GetBestPerformingTransport() TransportType {
	bs.mu.RLock()
	defer bs.mu.RUnlock()
	return bs.metrics.BestTransport
}
|
||||
|
||||
// Private methods
|
||||
|
||||
func (bs *BenchmarkSuite) runSingleBenchmark(ctx context.Context, transport TransportType, messageSize int, concurrency int, serialization SerializationFormat) (BenchmarkResult, error) {
|
||||
testName := fmt.Sprintf("%s_%db_%dc_%s", transport, messageSize, concurrency, serialization)
|
||||
|
||||
result := BenchmarkResult{
|
||||
TestName: testName,
|
||||
Transport: transport,
|
||||
MessageSize: messageSize,
|
||||
Concurrency: concurrency,
|
||||
Serialization: serialization,
|
||||
Duration: bs.config.Duration,
|
||||
Timestamp: time.Now(),
|
||||
}
|
||||
|
||||
// Setup test environment
|
||||
latencyTracker := &LatencyTracker{
|
||||
latencies: make([]time.Duration, 0),
|
||||
}
|
||||
|
||||
// Create test topic
|
||||
topic := fmt.Sprintf("benchmark_%s", testName)
|
||||
|
||||
// Subscribe to topic
|
||||
subscription, err := bs.messageBus.Subscribe(topic, func(ctx context.Context, msg *Message) error {
|
||||
if startTime, ok := msg.Metadata["start_time"].(time.Time); ok {
|
||||
latency := time.Since(startTime)
|
||||
latencyTracker.AddLatency(latency)
|
||||
}
|
||||
atomic.AddInt64(&result.MessagesReceived, 1)
|
||||
atomic.AddInt64(&result.BytesReceived, int64(messageSize))
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return result, fmt.Errorf("failed to subscribe: %w", err)
|
||||
}
|
||||
defer bs.messageBus.Unsubscribe(subscription.ID)
|
||||
|
||||
// Warmup phase
|
||||
if bs.config.WarmupDuration > 0 {
|
||||
bs.warmup(ctx, topic, messageSize, concurrency, bs.config.WarmupDuration)
|
||||
}
|
||||
|
||||
// Start system monitoring
|
||||
var cpuUsage float64
|
||||
var memUsageBefore, memUsageAfter runtime.MemStats
|
||||
runtime.ReadMemStats(&memUsageBefore)
|
||||
|
||||
monitorCtx, monitorCancel := context.WithCancel(ctx)
|
||||
defer monitorCancel()
|
||||
|
||||
go bs.monitorSystemResources(monitorCtx, &cpuUsage)
|
||||
|
||||
// Main benchmark
|
||||
startTime := time.Now()
|
||||
benchmarkCtx, cancel := context.WithTimeout(ctx, bs.config.Duration)
|
||||
defer cancel()
|
||||
|
||||
// Launch concurrent senders
|
||||
var wg sync.WaitGroup
|
||||
var totalSent int64
|
||||
var totalErrors int64
|
||||
|
||||
for i := 0; i < concurrency; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
bs.senderWorker(benchmarkCtx, topic, messageSize, &totalSent, &totalErrors)
|
||||
}()
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
// Wait a bit for remaining messages to be processed
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
actualDuration := time.Since(startTime)
|
||||
runtime.ReadMemStats(&memUsageAfter)
|
||||
|
||||
// Calculate results
|
||||
result.MessagesSent = totalSent
|
||||
result.BytesSent = totalSent * int64(messageSize)
|
||||
result.ThroughputMsgSec = float64(totalSent) / actualDuration.Seconds()
|
||||
result.ThroughputByteSec = float64(result.BytesSent) / actualDuration.Seconds()
|
||||
result.ErrorRate = float64(totalErrors) / float64(totalSent) * 100
|
||||
result.CPUUsage = cpuUsage
|
||||
// Calculate memory usage difference safely
|
||||
memDiff := memUsageAfter.Alloc - memUsageBefore.Alloc
|
||||
memDiffInt64, err := security.SafeUint64ToInt64(memDiff)
|
||||
if err != nil {
|
||||
bs.logger.Warn("Memory usage difference exceeds int64 max", "diff", memDiff, "error", err)
|
||||
result.MemoryUsage = math.MaxInt64
|
||||
} else {
|
||||
result.MemoryUsage = memDiffInt64
|
||||
}
|
||||
|
||||
// Calculate GC pauses difference safely
|
||||
gcDiff := int64(memUsageAfter.NumGC) - int64(memUsageBefore.NumGC)
|
||||
result.GCPauses = gcDiff
|
||||
|
||||
// Calculate latency percentiles
|
||||
if len(latencyTracker.latencies) > 0 {
|
||||
result.LatencyP50 = latencyTracker.GetPercentile(50)
|
||||
result.LatencyP95 = latencyTracker.GetPercentile(95)
|
||||
result.LatencyP99 = latencyTracker.GetPercentile(99)
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (bs *BenchmarkSuite) warmup(ctx context.Context, topic string, messageSize int, concurrency int, duration time.Duration) {
|
||||
warmupCtx, cancel := context.WithTimeout(ctx, duration)
|
||||
defer cancel()
|
||||
|
||||
var wg sync.WaitGroup
|
||||
for i := 0; i < concurrency; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
var dummy1, dummy2 int64
|
||||
bs.senderWorker(warmupCtx, topic, messageSize, &dummy1, &dummy2)
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func (bs *BenchmarkSuite) senderWorker(ctx context.Context, topic string, messageSize int, totalSent, totalErrors *int64) {
|
||||
payload := make([]byte, messageSize)
|
||||
for i := range payload {
|
||||
payload[i] = byte(i % 256)
|
||||
}
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
default:
|
||||
msg := NewMessage(MessageTypeEvent, topic, "benchmark", payload)
|
||||
msg.Metadata["start_time"] = time.Now()
|
||||
|
||||
if err := bs.messageBus.Publish(ctx, msg); err != nil {
|
||||
atomic.AddInt64(totalErrors, 1)
|
||||
} else {
|
||||
atomic.AddInt64(totalSent, 1)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// monitorSystemResources samples a rough CPU-usage proxy every 100ms until
// ctx is cancelled, then writes the average of all samples through cpuUsage.
//
// NOTE(review): the "CPU" figure is GC cycles per elapsed second clamped to
// 100 — a crude stand-in, not a real CPU measurement. Use OS-specific
// monitoring for production numbers. The final write to *cpuUsage happens on
// cancellation only; the caller must not read it before this goroutine exits.
func (bs *BenchmarkSuite) monitorSystemResources(ctx context.Context, cpuUsage *float64) {
	ticker := time.NewTicker(100 * time.Millisecond)
	defer ticker.Stop()

	var samples []float64
	startTime := time.Now()

	for {
		select {
		case <-ctx.Done():
			// Publish the average of all collected samples, if any.
			if len(samples) > 0 {
				var total float64
				for _, sample := range samples {
					total += sample
				}
				*cpuUsage = total / float64(len(samples))
			}
			return
		case <-ticker.C:
			// Simple CPU usage estimation based on runtime stats.
			var stats runtime.MemStats
			runtime.ReadMemStats(&stats)

			// This is a simplified CPU usage calculation.
			// In production, you'd want to use proper OS-specific CPU monitoring.
			// elapsed is at least one tick (100ms), so no division by zero here.
			elapsed := time.Since(startTime).Seconds()
			cpuSample := float64(stats.NumGC) / elapsed * 100 // Rough approximation
			if cpuSample > 100 {
				cpuSample = 100
			}
			samples = append(samples, cpuSample)
		}
	}
}
|
||||
|
||||
func (bs *BenchmarkSuite) updateBestMetrics(result BenchmarkResult) {
|
||||
if result.ThroughputMsgSec > bs.metrics.HighestThroughput {
|
||||
bs.metrics.HighestThroughput = result.ThroughputMsgSec
|
||||
bs.metrics.BestTransport = result.Transport
|
||||
}
|
||||
|
||||
if bs.metrics.LowestLatency == 0 || result.LatencyP50 < bs.metrics.LowestLatency {
|
||||
bs.metrics.LowestLatency = result.LatencyP50
|
||||
}
|
||||
}
|
||||
|
||||
// LatencyTracker methods

// AddLatency records one observed message latency. Safe for concurrent use.
func (lt *LatencyTracker) AddLatency(latency time.Duration) {
	lt.mu.Lock()
	defer lt.mu.Unlock()
	lt.latencies = append(lt.latencies, latency)
}
|
||||
|
||||
func (lt *LatencyTracker) GetPercentile(percentile int) time.Duration {
|
||||
lt.mu.Lock()
|
||||
defer lt.mu.Unlock()
|
||||
|
||||
if len(lt.latencies) == 0 {
|
||||
return 0
|
||||
}
|
||||
|
||||
// Sort latencies
|
||||
sorted := make([]time.Duration, len(lt.latencies))
|
||||
copy(sorted, lt.latencies)
|
||||
|
||||
// Simple insertion sort for small datasets
|
||||
for i := 1; i < len(sorted); i++ {
|
||||
for j := i; j > 0 && sorted[j] < sorted[j-1]; j-- {
|
||||
sorted[j], sorted[j-1] = sorted[j-1], sorted[j]
|
||||
}
|
||||
}
|
||||
|
||||
// Calculate percentile index
|
||||
index := int(float64(len(sorted)) * float64(percentile) / 100.0)
|
||||
if index >= len(sorted) {
|
||||
index = len(sorted) - 1
|
||||
}
|
||||
|
||||
return sorted[index]
|
||||
}
|
||||
|
||||
// Benchmark report generation
|
||||
|
||||
// GenerateReport generates a comprehensive benchmark report
|
||||
func (bs *BenchmarkSuite) GenerateReport() BenchmarkReport {
|
||||
bs.mu.RLock()
|
||||
defer bs.mu.RUnlock()
|
||||
|
||||
report := BenchmarkReport{
|
||||
Summary: bs.generateSummary(),
|
||||
Results: bs.results,
|
||||
Metrics: bs.metrics,
|
||||
Config: bs.config,
|
||||
Timestamp: time.Now(),
|
||||
}
|
||||
|
||||
report.Analysis = bs.generateAnalysis()
|
||||
|
||||
return report
|
||||
}
|
||||
|
||||
// BenchmarkReport contains a complete benchmark report.
type BenchmarkReport struct {
	Summary   ReportSummary     `json:"summary"`
	Results   []BenchmarkResult `json:"results"` // raw per-run results
	Metrics   BenchmarkMetrics  `json:"metrics"` // aggregate pass/fail and best-of stats
	Config    BenchmarkConfig   `json:"config"`  // configuration the runs were produced with
	Analysis  ReportAnalysis    `json:"analysis"`
	Timestamp time.Time         `json:"timestamp"` // report generation time
}

// ReportSummary provides a high-level summary of a benchmark sweep.
type ReportSummary struct {
	TotalTests           int                `json:"total_tests"`
	Duration             time.Duration      `json:"duration"`
	BestThroughput       float64            `json:"best_throughput"` // highest msgs/sec observed
	BestLatency          time.Duration      `json:"best_latency"`    // lowest P50 observed
	RecommendedTransport TransportType      `json:"recommended_transport"`
	TransportRankings    []TransportRanking `json:"transport_rankings"` // best-first
}

// TransportRanking ranks transports by performance.
type TransportRanking struct {
	Transport     TransportType `json:"transport"`
	AvgThroughput float64       `json:"avg_throughput"` // mean msgs/sec across this transport's runs
	AvgLatency    time.Duration `json:"avg_latency"`    // mean P50 across this transport's runs
	Score         float64       `json:"score"`          // throughput per microsecond of avg latency
	Rank          int           `json:"rank"`           // 1 = best
}

// ReportAnalysis provides detailed analysis derived from the raw results.
type ReportAnalysis struct {
	ScalabilityAnalysis    ScalabilityAnalysis `json:"scalability"`
	PerformanceBottlenecks []PerformanceIssue  `json:"bottlenecks"`
	Recommendations        []Recommendation    `json:"recommendations"`
}

// ScalabilityAnalysis analyzes scaling characteristics.
type ScalabilityAnalysis struct {
	LinearScaling      bool    `json:"linear_scaling"`      // true when throughput scaled within 20% of ideal linear
	ScalingFactor      float64 `json:"scaling_factor"`      // actual peak throughput / ideal linear projection
	OptimalConcurrency int     `json:"optimal_concurrency"` // concurrency level with the highest observed throughput
}

// PerformanceIssue identifies a performance problem found in the results.
type PerformanceIssue struct {
	Issue      string `json:"issue"`
	Severity   string `json:"severity"` // "high" or "medium" as emitted by identifyBottlenecks
	Impact     string `json:"impact"`
	Suggestion string `json:"suggestion"`
}

// Recommendation provides an optimization suggestion.
type Recommendation struct {
	Category    string `json:"category"`
	Description string `json:"description"`
	Priority    string `json:"priority"`
	Expected    string `json:"expected_improvement"`
}
|
||||
|
||||
func (bs *BenchmarkSuite) generateSummary() ReportSummary {
|
||||
rankings := bs.calculateTransportRankings()
|
||||
|
||||
return ReportSummary{
|
||||
TotalTests: bs.metrics.TotalTests,
|
||||
Duration: bs.metrics.TotalDuration,
|
||||
BestThroughput: bs.metrics.HighestThroughput,
|
||||
BestLatency: bs.metrics.LowestLatency,
|
||||
RecommendedTransport: bs.metrics.BestTransport,
|
||||
TransportRankings: rankings,
|
||||
}
|
||||
}
|
||||
|
||||
func (bs *BenchmarkSuite) calculateTransportRankings() []TransportRanking {
|
||||
// Group results by transport
|
||||
transportStats := make(map[TransportType][]BenchmarkResult)
|
||||
for _, result := range bs.results {
|
||||
transportStats[result.Transport] = append(transportStats[result.Transport], result)
|
||||
}
|
||||
|
||||
var rankings []TransportRanking
|
||||
for transport, results := range transportStats {
|
||||
var totalThroughput float64
|
||||
var totalLatency time.Duration
|
||||
|
||||
for _, result := range results {
|
||||
totalThroughput += result.ThroughputMsgSec
|
||||
totalLatency += result.LatencyP50
|
||||
}
|
||||
|
||||
avgThroughput := totalThroughput / float64(len(results))
|
||||
avgLatency := totalLatency / time.Duration(len(results))
|
||||
|
||||
// Score calculation (higher throughput + lower latency = better score)
|
||||
score := avgThroughput / float64(avgLatency.Microseconds())
|
||||
|
||||
rankings = append(rankings, TransportRanking{
|
||||
Transport: transport,
|
||||
AvgThroughput: avgThroughput,
|
||||
AvgLatency: avgLatency,
|
||||
Score: score,
|
||||
})
|
||||
}
|
||||
|
||||
// Sort by score (descending)
|
||||
for i := 0; i < len(rankings); i++ {
|
||||
for j := i + 1; j < len(rankings); j++ {
|
||||
if rankings[j].Score > rankings[i].Score {
|
||||
rankings[i], rankings[j] = rankings[j], rankings[i]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Assign ranks
|
||||
for i := range rankings {
|
||||
rankings[i].Rank = i + 1
|
||||
}
|
||||
|
||||
return rankings
|
||||
}
|
||||
|
||||
func (bs *BenchmarkSuite) generateAnalysis() ReportAnalysis {
|
||||
return ReportAnalysis{
|
||||
ScalabilityAnalysis: bs.analyzeScalability(),
|
||||
PerformanceBottlenecks: bs.identifyBottlenecks(),
|
||||
Recommendations: bs.generateRecommendations(),
|
||||
}
|
||||
}
|
||||
|
||||
func (bs *BenchmarkSuite) analyzeScalability() ScalabilityAnalysis {
|
||||
if len(bs.results) < 2 {
|
||||
return ScalabilityAnalysis{
|
||||
LinearScaling: false,
|
||||
ScalingFactor: 0.0,
|
||||
OptimalConcurrency: 1,
|
||||
}
|
||||
}
|
||||
|
||||
// Analyze throughput vs concurrency relationship
|
||||
var throughputData []float64
|
||||
var concurrencyData []int
|
||||
|
||||
for _, result := range bs.results {
|
||||
if result.Concurrency > 0 && result.Duration > 0 {
|
||||
throughput := float64(result.MessagesReceived) / result.Duration.Seconds()
|
||||
throughputData = append(throughputData, throughput)
|
||||
concurrencyData = append(concurrencyData, result.Concurrency)
|
||||
}
|
||||
}
|
||||
|
||||
if len(throughputData) < 2 {
|
||||
return ScalabilityAnalysis{
|
||||
LinearScaling: false,
|
||||
ScalingFactor: 0.0,
|
||||
OptimalConcurrency: 1,
|
||||
}
|
||||
}
|
||||
|
||||
// Calculate scaling efficiency
|
||||
// Compare actual throughput improvement with ideal linear scaling
|
||||
maxThroughput := 0.0
|
||||
maxThroughputConcurrency := 1
|
||||
baseThroughput := throughputData[0]
|
||||
baseConcurrency := float64(concurrencyData[0])
|
||||
|
||||
for i, throughput := range throughputData {
|
||||
if throughput > maxThroughput {
|
||||
maxThroughput = throughput
|
||||
maxThroughputConcurrency = concurrencyData[i]
|
||||
}
|
||||
}
|
||||
|
||||
// Calculate scaling factor (actual vs ideal)
|
||||
idealThroughput := baseThroughput * float64(maxThroughputConcurrency) / baseConcurrency
|
||||
actualScalingFactor := maxThroughput / idealThroughput
|
||||
|
||||
// Determine if scaling is linear (within 20% of ideal)
|
||||
linearScaling := actualScalingFactor >= 0.8
|
||||
|
||||
return ScalabilityAnalysis{
|
||||
LinearScaling: linearScaling,
|
||||
ScalingFactor: actualScalingFactor,
|
||||
OptimalConcurrency: maxThroughputConcurrency,
|
||||
}
|
||||
}
|
||||
|
||||
func (bs *BenchmarkSuite) identifyBottlenecks() []PerformanceIssue {
|
||||
var issues []PerformanceIssue
|
||||
|
||||
// Analyze results for common performance issues
|
||||
for _, result := range bs.results {
|
||||
if result.ErrorRate > 5.0 {
|
||||
issues = append(issues, PerformanceIssue{
|
||||
Issue: fmt.Sprintf("High error rate (%0.2f%%) for %s", result.ErrorRate, result.Transport),
|
||||
Severity: "high",
|
||||
Impact: "Reduced reliability and performance",
|
||||
Suggestion: "Check transport configuration and network stability",
|
||||
})
|
||||
}
|
||||
|
||||
if result.LatencyP99 > 100*time.Millisecond {
|
||||
issues = append(issues, PerformanceIssue{
|
||||
Issue: fmt.Sprintf("High P99 latency (%v) for %s", result.LatencyP99, result.Transport),
|
||||
Severity: "medium",
|
||||
Impact: "Poor user experience for latency-sensitive operations",
|
||||
Suggestion: "Consider using faster transport or optimizing message serialization",
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
return issues
|
||||
}
|
||||
|
||||
func (bs *BenchmarkSuite) generateRecommendations() []Recommendation {
|
||||
var recommendations []Recommendation
|
||||
|
||||
recommendations = append(recommendations, Recommendation{
|
||||
Category: "Transport Selection",
|
||||
Description: fmt.Sprintf("Use %s for best overall performance", bs.metrics.BestTransport),
|
||||
Priority: "high",
|
||||
Expected: "20-50% improvement in throughput",
|
||||
})
|
||||
|
||||
recommendations = append(recommendations, Recommendation{
|
||||
Category: "Concurrency",
|
||||
Description: "Optimize concurrency level based on workload characteristics",
|
||||
Priority: "medium",
|
||||
Expected: "10-30% improvement in resource utilization",
|
||||
})
|
||||
|
||||
return recommendations
|
||||
}
|
||||
591
orig/pkg/transport/dlq.go
Normal file
591
orig/pkg/transport/dlq.go
Normal file
@@ -0,0 +1,591 @@
|
||||
package transport
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sort"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// DeadLetterQueue handles failed messages with retry and reprocessing capabilities.
// Messages are grouped per topic; background routines (started by the
// constructor) handle expiry and, optionally, automatic reprocessing.
type DeadLetterQueue struct {
	messages      map[string][]*DLQMessage // failed messages keyed by topic; guarded by mu
	config        DLQConfig
	metrics       DLQMetrics
	reprocessor   MessageReprocessor // pluggable reprocessing strategy
	mu            sync.RWMutex       // guards messages and metrics
	cleanupTicker *time.Ticker       // presumably drives periodic retention cleanup — set by startCleanupRoutine (not shown here)
	ctx           context.Context    // lifetime of the background routines
	cancel        context.CancelFunc // stops the background routines
}

// DLQMessage represents a message in the dead letter queue.
type DLQMessage struct {
	ID              string        // unique id: "dlq_<topic>_<unix-nanos>"
	OriginalMessage *Message      // the message that failed delivery
	Topic           string        // topic the message originally belonged to
	FirstFailed     time.Time     // when the message first entered the queue
	LastAttempt     time.Time     // time of the most recent delivery attempt
	AttemptCount    int           // delivery attempts so far (starts at 1 on insert)
	MaxRetries      int           // copied from DLQConfig.MaxRetries at insert time
	FailureReason   string        // human-readable cause of the failure
	RetryDelay      time.Duration // current backoff delay (unset for permanent failures)
	NextRetry       time.Time     // earliest time the next retry may run
	Metadata        map[string]interface{}
	Permanent       bool // true when the failure reason matched a permanent-failure pattern; never retried
}
|
||||
|
||||
// DLQConfig configures dead letter queue behavior. Zero-valued fields are
// replaced with defaults by NewDeadLetterQueue (noted per field below).
type DLQConfig struct {
	MaxMessages        int             // queue capacity; oldest entry evicted when full (default 10000)
	MaxRetries         int             // per-message retry limit (default 3)
	RetentionTime      time.Duration   // how long entries are kept (default 24h)
	AutoReprocess      bool            // start the background reprocess routine when true
	ReprocessInterval  time.Duration   // cadence of the background reprocess routine (default 5m)
	BackoffStrategy    BackoffStrategy // retry delay curve (default exponential)
	InitialRetryDelay  time.Duration   // first retry delay (default 1m)
	MaxRetryDelay      time.Duration   // cap on computed retry delays (default 1h)
	BackoffMultiplier  float64         // growth factor used by the backoff calculation (default 2.0)
	PermanentFailures  []string        // Error patterns that mark messages as permanently failed
	ReprocessBatchSize int             // messages handled per reprocess pass (default 10)
}

// BackoffStrategy defines retry delay calculation methods.
// The exact delay formula per strategy lives in calculateRetryDelay (not
// shown in this chunk).
type BackoffStrategy string

const (
	BackoffFixed       BackoffStrategy = "fixed"
	BackoffLinear      BackoffStrategy = "linear"
	BackoffExponential BackoffStrategy = "exponential"
	BackoffCustom      BackoffStrategy = "custom"
)
|
||||
|
||||
// DLQMetrics tracks dead letter queue statistics. All fields are guarded by
// the queue's mutex (they are plain int64s, not atomics).
type DLQMetrics struct {
	MessagesAdded       int64     // total messages ever added
	MessagesReprocessed int64
	MessagesExpired     int64
	MessagesPermanent   int64 // messages added with a permanent failure reason
	ReprocessSuccesses  int64
	ReprocessFailures   int64
	QueueSize           int64     // current number of queued messages
	OldestMessage       time.Time // maintained via updateOldestMessage (not shown here)
}

// MessageReprocessor handles message reprocessing logic.
type MessageReprocessor interface {
	// Reprocess attempts to handle the failed message again.
	Reprocess(ctx context.Context, msg *DLQMessage) error
	// CanReprocess reports whether the message is eligible for another attempt.
	CanReprocess(msg *DLQMessage) bool
	// ShouldRetry decides, after a failed attempt, whether to try again.
	ShouldRetry(msg *DLQMessage, err error) bool
}

// DefaultMessageReprocessor implements basic reprocessing logic.
// It holds a MessagePublisher — presumably to republish the original
// message; its method implementations are not visible in this chunk.
type DefaultMessageReprocessor struct {
	publisher MessagePublisher
}

// MessagePublisher interface for republishing messages.
type MessagePublisher interface {
	Publish(ctx context.Context, msg *Message) error
}
|
||||
|
||||
// NewDeadLetterQueue creates a new dead letter queue
|
||||
func NewDeadLetterQueue(config DLQConfig) *DeadLetterQueue {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
dlq := &DeadLetterQueue{
|
||||
messages: make(map[string][]*DLQMessage),
|
||||
config: config,
|
||||
metrics: DLQMetrics{},
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
}
|
||||
|
||||
// Set default configuration values
|
||||
if dlq.config.MaxMessages == 0 {
|
||||
dlq.config.MaxMessages = 10000
|
||||
}
|
||||
if dlq.config.MaxRetries == 0 {
|
||||
dlq.config.MaxRetries = 3
|
||||
}
|
||||
if dlq.config.RetentionTime == 0 {
|
||||
dlq.config.RetentionTime = 24 * time.Hour
|
||||
}
|
||||
if dlq.config.ReprocessInterval == 0 {
|
||||
dlq.config.ReprocessInterval = 5 * time.Minute
|
||||
}
|
||||
if dlq.config.InitialRetryDelay == 0 {
|
||||
dlq.config.InitialRetryDelay = time.Minute
|
||||
}
|
||||
if dlq.config.MaxRetryDelay == 0 {
|
||||
dlq.config.MaxRetryDelay = time.Hour
|
||||
}
|
||||
if dlq.config.BackoffMultiplier == 0 {
|
||||
dlq.config.BackoffMultiplier = 2.0
|
||||
}
|
||||
if dlq.config.BackoffStrategy == "" {
|
||||
dlq.config.BackoffStrategy = BackoffExponential
|
||||
}
|
||||
if dlq.config.ReprocessBatchSize == 0 {
|
||||
dlq.config.ReprocessBatchSize = 10
|
||||
}
|
||||
|
||||
// Start cleanup routine
|
||||
dlq.startCleanupRoutine()
|
||||
|
||||
// Start reprocessing routine if enabled
|
||||
if dlq.config.AutoReprocess {
|
||||
dlq.startReprocessRoutine()
|
||||
}
|
||||
|
||||
return dlq
|
||||
}
|
||||
|
||||
// AddMessage adds a failed message to the dead letter queue with a generic
// "unknown failure" reason. Prefer AddMessageWithReason when the cause is
// known, since permanent-failure detection matches on the reason string.
func (dlq *DeadLetterQueue) AddMessage(topic string, msg *Message) error {
	return dlq.AddMessageWithReason(topic, msg, "unknown failure")
}
|
||||
|
||||
// AddMessageWithReason adds a failed message with a specific failure reason
|
||||
func (dlq *DeadLetterQueue) AddMessageWithReason(topic string, msg *Message, reason string) error {
|
||||
dlq.mu.Lock()
|
||||
defer dlq.mu.Unlock()
|
||||
|
||||
// Check if we've exceeded max messages
|
||||
totalMessages := dlq.getTotalMessageCount()
|
||||
if totalMessages >= dlq.config.MaxMessages {
|
||||
// Remove oldest message to make room
|
||||
dlq.removeOldestMessage()
|
||||
}
|
||||
|
||||
// Check if this is a permanent failure
|
||||
permanent := dlq.isPermanentFailure(reason)
|
||||
|
||||
dlqMsg := &DLQMessage{
|
||||
ID: fmt.Sprintf("dlq_%s_%d", topic, time.Now().UnixNano()),
|
||||
OriginalMessage: msg,
|
||||
Topic: topic,
|
||||
FirstFailed: time.Now(),
|
||||
LastAttempt: time.Now(),
|
||||
AttemptCount: 1,
|
||||
MaxRetries: dlq.config.MaxRetries,
|
||||
FailureReason: reason,
|
||||
Metadata: make(map[string]interface{}),
|
||||
Permanent: permanent,
|
||||
}
|
||||
|
||||
if !permanent {
|
||||
dlqMsg.RetryDelay = dlq.calculateRetryDelay(dlqMsg)
|
||||
dlqMsg.NextRetry = time.Now().Add(dlqMsg.RetryDelay)
|
||||
}
|
||||
|
||||
// Add to queue
|
||||
if _, exists := dlq.messages[topic]; !exists {
|
||||
dlq.messages[topic] = make([]*DLQMessage, 0)
|
||||
}
|
||||
dlq.messages[topic] = append(dlq.messages[topic], dlqMsg)
|
||||
|
||||
// Update metrics
|
||||
dlq.metrics.MessagesAdded++
|
||||
dlq.metrics.QueueSize++
|
||||
if permanent {
|
||||
dlq.metrics.MessagesPermanent++
|
||||
}
|
||||
dlq.updateOldestMessage()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetMessages returns all messages for a topic
|
||||
func (dlq *DeadLetterQueue) GetMessages(topic string) ([]*DLQMessage, error) {
|
||||
dlq.mu.RLock()
|
||||
defer dlq.mu.RUnlock()
|
||||
|
||||
messages, exists := dlq.messages[topic]
|
||||
if !exists {
|
||||
return []*DLQMessage{}, nil
|
||||
}
|
||||
|
||||
// Return a copy to avoid race conditions
|
||||
result := make([]*DLQMessage, len(messages))
|
||||
copy(result, messages)
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// GetAllMessages returns all messages across all topics
|
||||
func (dlq *DeadLetterQueue) GetAllMessages() map[string][]*DLQMessage {
|
||||
dlq.mu.RLock()
|
||||
defer dlq.mu.RUnlock()
|
||||
|
||||
result := make(map[string][]*DLQMessage)
|
||||
for topic, messages := range dlq.messages {
|
||||
result[topic] = make([]*DLQMessage, len(messages))
|
||||
copy(result[topic], messages)
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// ReprocessMessage attempts to reprocess a specific message
// identified by its DLQ ID. On success the message is removed from the
// queue; on failure its attempt count and retry schedule are updated,
// and it is flagged permanent once MaxRetries is exhausted.
// NOTE(review): the reprocessor runs while dlq.mu is held, so a slow
// Reprocess call blocks every other queue operation — confirm this is
// acceptable.
func (dlq *DeadLetterQueue) ReprocessMessage(messageID string) error {
	dlq.mu.Lock()
	defer dlq.mu.Unlock()

	// Find message
	var dlqMsg *DLQMessage
	var topic string
	var index int

	// Linear scan over every topic; IDs are unique so the first hit wins.
	for t, messages := range dlq.messages {
		for i, msg := range messages {
			if msg.ID == messageID {
				dlqMsg = msg
				topic = t
				index = i
				break
			}
		}
		if dlqMsg != nil {
			break
		}
	}

	if dlqMsg == nil {
		return fmt.Errorf("message not found: %s", messageID)
	}

	if dlqMsg.Permanent {
		return fmt.Errorf("message marked as permanent failure: %s", messageID)
	}

	// Attempt reprocessing
	err := dlq.attemptReprocess(dlqMsg)
	if err == nil {
		// Success - remove from queue
		dlq.removeMessageByIndex(topic, index)
		dlq.metrics.ReprocessSuccesses++
		dlq.metrics.QueueSize--
		return nil
	}

	// Failed - update retry information
	dlqMsg.AttemptCount++
	dlqMsg.LastAttempt = time.Now()
	dlqMsg.FailureReason = err.Error()

	if dlqMsg.AttemptCount >= dlqMsg.MaxRetries {
		// Out of retries: keep the message but never retry it again.
		dlqMsg.Permanent = true
		dlq.metrics.MessagesPermanent++
	} else {
		dlqMsg.RetryDelay = dlq.calculateRetryDelay(dlqMsg)
		dlqMsg.NextRetry = time.Now().Add(dlqMsg.RetryDelay)
	}

	dlq.metrics.ReprocessFailures++
	return fmt.Errorf("reprocessing failed: %w", err)
}
|
||||
|
||||
// PurgeMessages removes all messages for a topic
|
||||
func (dlq *DeadLetterQueue) PurgeMessages(topic string) error {
|
||||
dlq.mu.Lock()
|
||||
defer dlq.mu.Unlock()
|
||||
|
||||
if messages, exists := dlq.messages[topic]; exists {
|
||||
count := len(messages)
|
||||
delete(dlq.messages, topic)
|
||||
dlq.metrics.QueueSize -= int64(count)
|
||||
dlq.updateOldestMessage()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// PurgeAllMessages removes all messages from the queue
|
||||
func (dlq *DeadLetterQueue) PurgeAllMessages() error {
|
||||
dlq.mu.Lock()
|
||||
defer dlq.mu.Unlock()
|
||||
|
||||
dlq.messages = make(map[string][]*DLQMessage)
|
||||
dlq.metrics.QueueSize = 0
|
||||
dlq.metrics.OldestMessage = time.Time{}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetMessageCount returns the total number of messages in the queue
|
||||
func (dlq *DeadLetterQueue) GetMessageCount() int {
|
||||
dlq.mu.RLock()
|
||||
defer dlq.mu.RUnlock()
|
||||
return dlq.getTotalMessageCount()
|
||||
}
|
||||
|
||||
// GetMetrics returns current DLQ metrics
// as a copy of the metrics struct, safe for the caller to inspect
// without further locking.
func (dlq *DeadLetterQueue) GetMetrics() DLQMetrics {
	dlq.mu.RLock()
	defer dlq.mu.RUnlock()
	return dlq.metrics
}

// SetReprocessor sets the message reprocessor
// used by ReprocessMessage and the background reprocess routine.
func (dlq *DeadLetterQueue) SetReprocessor(reprocessor MessageReprocessor) {
	dlq.mu.Lock()
	defer dlq.mu.Unlock()
	dlq.reprocessor = reprocessor
}
|
||||
|
||||
// Cleanup removes expired messages
|
||||
func (dlq *DeadLetterQueue) Cleanup(maxAge time.Duration) error {
|
||||
dlq.mu.Lock()
|
||||
defer dlq.mu.Unlock()
|
||||
|
||||
cutoff := time.Now().Add(-maxAge)
|
||||
expiredCount := 0
|
||||
|
||||
for topic, messages := range dlq.messages {
|
||||
filtered := make([]*DLQMessage, 0)
|
||||
for _, msg := range messages {
|
||||
if msg.FirstFailed.After(cutoff) {
|
||||
filtered = append(filtered, msg)
|
||||
} else {
|
||||
expiredCount++
|
||||
}
|
||||
}
|
||||
dlq.messages[topic] = filtered
|
||||
|
||||
// Remove empty topics
|
||||
if len(filtered) == 0 {
|
||||
delete(dlq.messages, topic)
|
||||
}
|
||||
}
|
||||
|
||||
dlq.metrics.MessagesExpired += int64(expiredCount)
|
||||
dlq.metrics.QueueSize -= int64(expiredCount)
|
||||
dlq.updateOldestMessage()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Stop gracefully shuts down the dead letter queue:
// it cancels the background cleanup/reprocess goroutines and stops the
// cleanup ticker. Queued messages remain in memory.
func (dlq *DeadLetterQueue) Stop() error {
	dlq.cancel()

	if dlq.cleanupTicker != nil {
		dlq.cleanupTicker.Stop()
	}

	return nil
}
|
||||
|
||||
// Private helper methods
|
||||
|
||||
func (dlq *DeadLetterQueue) getTotalMessageCount() int {
|
||||
count := 0
|
||||
for _, messages := range dlq.messages {
|
||||
count += len(messages)
|
||||
}
|
||||
return count
|
||||
}
|
||||
|
||||
func (dlq *DeadLetterQueue) removeOldestMessage() {
|
||||
var oldestTime time.Time
|
||||
var oldestTopic string
|
||||
var oldestIndex int
|
||||
|
||||
for topic, messages := range dlq.messages {
|
||||
for i, msg := range messages {
|
||||
if oldestTime.IsZero() || msg.FirstFailed.Before(oldestTime) {
|
||||
oldestTime = msg.FirstFailed
|
||||
oldestTopic = topic
|
||||
oldestIndex = i
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !oldestTime.IsZero() {
|
||||
dlq.removeMessageByIndex(oldestTopic, oldestIndex)
|
||||
dlq.metrics.QueueSize--
|
||||
}
|
||||
}
|
||||
|
||||
func (dlq *DeadLetterQueue) removeMessageByIndex(topic string, index int) {
|
||||
messages := dlq.messages[topic]
|
||||
dlq.messages[topic] = append(messages[:index], messages[index+1:]...)
|
||||
|
||||
if len(dlq.messages[topic]) == 0 {
|
||||
delete(dlq.messages, topic)
|
||||
}
|
||||
}
|
||||
|
||||
func (dlq *DeadLetterQueue) isPermanentFailure(reason string) bool {
|
||||
for _, pattern := range dlq.config.PermanentFailures {
|
||||
if pattern == reason {
|
||||
return true
|
||||
}
|
||||
// Simple pattern matching (can be enhanced with regex)
|
||||
if len(pattern) > 0 && pattern[len(pattern)-1] == '*' {
|
||||
prefix := pattern[:len(pattern)-1]
|
||||
if len(reason) >= len(prefix) && reason[:len(prefix)] == prefix {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (dlq *DeadLetterQueue) calculateRetryDelay(msg *DLQMessage) time.Duration {
|
||||
switch dlq.config.BackoffStrategy {
|
||||
case BackoffFixed:
|
||||
return dlq.config.InitialRetryDelay
|
||||
|
||||
case BackoffLinear:
|
||||
delay := time.Duration(msg.AttemptCount) * dlq.config.InitialRetryDelay
|
||||
if delay > dlq.config.MaxRetryDelay {
|
||||
return dlq.config.MaxRetryDelay
|
||||
}
|
||||
return delay
|
||||
|
||||
case BackoffExponential:
|
||||
delay := time.Duration(float64(dlq.config.InitialRetryDelay) *
|
||||
pow(dlq.config.BackoffMultiplier, float64(msg.AttemptCount-1)))
|
||||
if delay > dlq.config.MaxRetryDelay {
|
||||
return dlq.config.MaxRetryDelay
|
||||
}
|
||||
return delay
|
||||
|
||||
default:
|
||||
return dlq.config.InitialRetryDelay
|
||||
}
|
||||
}
|
||||
|
||||
func (dlq *DeadLetterQueue) attemptReprocess(msg *DLQMessage) error {
|
||||
if dlq.reprocessor == nil {
|
||||
return fmt.Errorf("no reprocessor configured")
|
||||
}
|
||||
|
||||
if !dlq.reprocessor.CanReprocess(msg) {
|
||||
return fmt.Errorf("message cannot be reprocessed")
|
||||
}
|
||||
|
||||
return dlq.reprocessor.Reprocess(dlq.ctx, msg)
|
||||
}
|
||||
|
||||
func (dlq *DeadLetterQueue) updateOldestMessage() {
|
||||
var oldest time.Time
|
||||
|
||||
for _, messages := range dlq.messages {
|
||||
for _, msg := range messages {
|
||||
if oldest.IsZero() || msg.FirstFailed.Before(oldest) {
|
||||
oldest = msg.FirstFailed
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
dlq.metrics.OldestMessage = oldest
|
||||
}
|
||||
|
||||
// startCleanupRoutine launches a background goroutine that periodically
// expires messages older than the configured retention time. The
// goroutine exits when the DLQ context is cancelled; the ticker itself
// is stopped by Stop().
// NOTE(review): the ticker period is ReprocessInterval rather than a
// dedicated cleanup interval — confirm this is intentional.
func (dlq *DeadLetterQueue) startCleanupRoutine() {
	dlq.cleanupTicker = time.NewTicker(dlq.config.ReprocessInterval)

	go func() {
		for {
			select {
			case <-dlq.cleanupTicker.C:
				// Cleanup currently always returns nil, so the error is
				// not checked here.
				dlq.Cleanup(dlq.config.RetentionTime)
			case <-dlq.ctx.Done():
				return
			}
		}
	}()
}
|
||||
|
||||
func (dlq *DeadLetterQueue) startReprocessRoutine() {
|
||||
ticker := time.NewTicker(dlq.config.ReprocessInterval)
|
||||
|
||||
go func() {
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
dlq.processRetryableMessages()
|
||||
case <-dlq.ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func (dlq *DeadLetterQueue) processRetryableMessages() {
|
||||
dlq.mu.Lock()
|
||||
retryable := dlq.getRetryableMessages()
|
||||
dlq.mu.Unlock()
|
||||
|
||||
// Sort by next retry time
|
||||
sort.Slice(retryable, func(i, j int) bool {
|
||||
return retryable[i].NextRetry.Before(retryable[j].NextRetry)
|
||||
})
|
||||
|
||||
// Process batch
|
||||
batchSize := dlq.config.ReprocessBatchSize
|
||||
if len(retryable) < batchSize {
|
||||
batchSize = len(retryable)
|
||||
}
|
||||
|
||||
for i := 0; i < batchSize; i++ {
|
||||
msg := retryable[i]
|
||||
if time.Now().After(msg.NextRetry) {
|
||||
dlq.ReprocessMessage(msg.ID)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (dlq *DeadLetterQueue) getRetryableMessages() []*DLQMessage {
|
||||
var retryable []*DLQMessage
|
||||
|
||||
for _, messages := range dlq.messages {
|
||||
for _, msg := range messages {
|
||||
if !msg.Permanent && msg.AttemptCount < msg.MaxRetries {
|
||||
retryable = append(retryable, msg)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return retryable
|
||||
}
|
||||
|
||||
// Implementation of DefaultMessageReprocessor

// NewDefaultMessageReprocessor builds a reprocessor that republishes
// failed messages through the given publisher.
func NewDefaultMessageReprocessor(publisher MessagePublisher) *DefaultMessageReprocessor {
	return &DefaultMessageReprocessor{
		publisher: publisher,
	}
}

// Reprocess republishes the original message through the configured
// publisher; it fails if no publisher was supplied.
func (r *DefaultMessageReprocessor) Reprocess(ctx context.Context, msg *DLQMessage) error {
	if r.publisher == nil {
		return fmt.Errorf("no publisher configured")
	}

	return r.publisher.Publish(ctx, msg.OriginalMessage)
}

// CanReprocess reports whether a message is still eligible for another
// attempt (not permanent and under its retry budget).
func (r *DefaultMessageReprocessor) CanReprocess(msg *DLQMessage) bool {
	return !msg.Permanent && msg.AttemptCount < msg.MaxRetries
}

// ShouldRetry reports whether another retry should be scheduled after a
// failed attempt. The error argument is currently unused.
func (r *DefaultMessageReprocessor) ShouldRetry(msg *DLQMessage, err error) bool {
	// Simple retry logic - can be enhanced based on error types
	return msg.AttemptCount < msg.MaxRetries
}
|
||||
|
||||
// pow raises base to an integer power by repeated multiplication.
// Negative exponents return the reciprocal (the previous version
// returned base unchanged for exp < 0). The exponent's fractional part
// is truncated — callers pass integer-valued attempt counts — so this
// is NOT a general replacement for math.Pow.
func pow(base, exp float64) float64 {
	if exp < 0 {
		return 1 / pow(base, -exp)
	}
	result := 1.0
	for i := 0; i < int(exp); i++ {
		result *= base
	}
	return result
}
|
||||
612
orig/pkg/transport/failover.go
Normal file
612
orig/pkg/transport/failover.go
Normal file
@@ -0,0 +1,612 @@
|
||||
package transport
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// FailoverManager handles transport failover and redundancy
// across a set of registered transports: one primary plus prioritized
// backups, with background health checking and automatic
// failover/failback.
type FailoverManager struct {
	transports       map[string]*ManagedTransport // all registered transports, by ID
	primaryTransport string                       // ID of the currently active transport
	backupTransports []string                     // IDs of standby transports
	failoverPolicy   FailoverPolicy
	healthChecker    HealthChecker
	circuitBreaker   *CircuitBreaker
	mu               sync.RWMutex // guards all fields above plus metrics
	ctx              context.Context
	cancel           context.CancelFunc
	metrics          FailoverMetrics
	notifications    chan FailoverEvent // buffered event stream; see GetNotifications
}

// ManagedTransport wraps a transport with management metadata
// tracked by the failover manager (priority, health, failure history).
type ManagedTransport struct {
	Transport       Transport
	ID              string
	Name            string
	Priority        int // higher wins when choosing a primary
	Status          TransportStatus
	LastHealthCheck time.Time
	FailureCount    int // consecutive health-check failures
	LastFailure     time.Time
	Config          TransportConfig
	Metrics         TransportMetrics
}

// TransportStatus represents the current status of a transport
type TransportStatus string

const (
	StatusHealthy   TransportStatus = "healthy"
	StatusDegraded  TransportStatus = "degraded"  // failing, but below the failure threshold
	StatusUnhealthy TransportStatus = "unhealthy" // failure threshold reached
	StatusDisabled  TransportStatus = "disabled"
)

// FailoverPolicy defines when and how to failover
type FailoverPolicy struct {
	FailureThreshold    int           // Number of failures before marking unhealthy
	HealthCheckInterval time.Duration // How often to check health
	FailoverTimeout     time.Duration // Timeout for failover operations
	RetryInterval       time.Duration // Interval between retry attempts
	MaxRetries          int           // Maximum retry attempts
	AutoFailback        bool          // Whether to automatically failback to primary
	FailbackDelay       time.Duration // Delay before attempting failback
	RequireAllHealthy   bool          // Whether all transports must be healthy
}

// FailoverMetrics tracks failover statistics
// exposed via GetMetrics as a copied snapshot.
type FailoverMetrics struct {
	TotalFailovers      int64         `json:"total_failovers"`
	TotalFailbacks      int64         `json:"total_failbacks"`
	CurrentTransport    string        `json:"current_transport"`
	LastFailover        time.Time     `json:"last_failover"`
	LastFailback        time.Time     `json:"last_failback"`
	FailoverDuration    time.Duration `json:"failover_duration"`
	FailoverSuccessRate float64       `json:"failover_success_rate"`
	HealthCheckFailures int64         `json:"health_check_failures"`
	CircuitBreakerTrips int64         `json:"circuit_breaker_trips"`
}

// FailoverEvent represents a failover-related event
// published on the manager's notification channel.
type FailoverEvent struct {
	Type          FailoverEventType `json:"type"`
	FromTransport string            `json:"from_transport"`
	ToTransport   string            `json:"to_transport"`
	Reason        string            `json:"reason"`
	Timestamp     time.Time         `json:"timestamp"`
	Success       bool              `json:"success"`
	Duration      time.Duration     `json:"duration"`
}

// FailoverEventType defines types of failover events
type FailoverEventType string

const (
	EventFailover     FailoverEventType = "failover"
	EventFailback     FailoverEventType = "failback"
	EventHealthCheck  FailoverEventType = "health_check"
	EventCircuitBreak FailoverEventType = "circuit_break"
	EventRecovery     FailoverEventType = "recovery"
)

// HealthChecker interface for custom health checking logic
// plugged into the failover manager via SetHealthChecker.
type HealthChecker interface {
	CheckHealth(ctx context.Context, transport Transport) (bool, error)
	GetHealthScore(transport Transport) float64
}
|
||||
|
||||
// NewFailoverManager creates a new failover manager
// configured with the given policy. It installs a default health
// checker and a circuit breaker derived from the policy, and starts the
// background health-check and failover-monitor goroutines (stopped via
// Stop).
func NewFailoverManager(policy FailoverPolicy) *FailoverManager {
	ctx, cancel := context.WithCancel(context.Background())

	fm := &FailoverManager{
		transports:     make(map[string]*ManagedTransport),
		failoverPolicy: policy,
		healthChecker:  NewDefaultHealthChecker(),
		circuitBreaker: NewCircuitBreaker(CircuitBreakerConfig{
			FailureThreshold: policy.FailureThreshold,
			RecoveryTimeout:  policy.RetryInterval,
			MaxRetries:       policy.MaxRetries,
		}),
		ctx:           ctx,
		cancel:        cancel,
		notifications: make(chan FailoverEvent, 100), // buffered; overflow events are dropped
	}

	// Start background routines
	go fm.healthCheckLoop()
	go fm.failoverMonitorLoop()

	return fm
}
|
||||
|
||||
// RegisterTransport adds a transport to the failover manager
|
||||
func (fm *FailoverManager) RegisterTransport(id, name string, transport Transport, priority int, config TransportConfig) error {
|
||||
fm.mu.Lock()
|
||||
defer fm.mu.Unlock()
|
||||
|
||||
managedTransport := &ManagedTransport{
|
||||
Transport: transport,
|
||||
ID: id,
|
||||
Name: name,
|
||||
Priority: priority,
|
||||
Status: StatusHealthy,
|
||||
LastHealthCheck: time.Now(),
|
||||
Config: config,
|
||||
}
|
||||
|
||||
fm.transports[id] = managedTransport
|
||||
|
||||
// Set as primary if it's the first or highest priority transport
|
||||
if fm.primaryTransport == "" || priority > fm.transports[fm.primaryTransport].Priority {
|
||||
fm.primaryTransport = id
|
||||
} else {
|
||||
fm.backupTransports = append(fm.backupTransports, id)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// UnregisterTransport removes a transport from the failover manager
|
||||
func (fm *FailoverManager) UnregisterTransport(id string) error {
|
||||
fm.mu.Lock()
|
||||
defer fm.mu.Unlock()
|
||||
|
||||
if _, exists := fm.transports[id]; !exists {
|
||||
return fmt.Errorf("transport not found: %s", id)
|
||||
}
|
||||
|
||||
delete(fm.transports, id)
|
||||
|
||||
// Update primary if needed
|
||||
if fm.primaryTransport == id {
|
||||
fm.selectNewPrimary()
|
||||
}
|
||||
|
||||
// Remove from backups
|
||||
for i, backupID := range fm.backupTransports {
|
||||
if backupID == id {
|
||||
fm.backupTransports = append(fm.backupTransports[:i], fm.backupTransports[i+1:]...)
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetActiveTransport returns the currently active transport
|
||||
func (fm *FailoverManager) GetActiveTransport() (Transport, error) {
|
||||
fm.mu.RLock()
|
||||
defer fm.mu.RUnlock()
|
||||
|
||||
if fm.primaryTransport == "" {
|
||||
return nil, fmt.Errorf("no active transport available")
|
||||
}
|
||||
|
||||
transport, exists := fm.transports[fm.primaryTransport]
|
||||
if !exists {
|
||||
return nil, fmt.Errorf("primary transport not found: %s", fm.primaryTransport)
|
||||
}
|
||||
|
||||
if transport.Status == StatusHealthy || transport.Status == StatusDegraded {
|
||||
return transport.Transport, nil
|
||||
}
|
||||
|
||||
// Try to failover to a backup
|
||||
if err := fm.performFailover(); err != nil {
|
||||
return nil, fmt.Errorf("failover failed: %w", err)
|
||||
}
|
||||
|
||||
// Return new primary after failover
|
||||
newPrimary := fm.transports[fm.primaryTransport]
|
||||
return newPrimary.Transport, nil
|
||||
}
|
||||
|
||||
// Send sends a message through the active transport with automatic failover
|
||||
func (fm *FailoverManager) Send(ctx context.Context, msg *Message) error {
|
||||
transport, err := fm.GetActiveTransport()
|
||||
if err != nil {
|
||||
return fmt.Errorf("no available transport: %w", err)
|
||||
}
|
||||
|
||||
// Try to send through circuit breaker
|
||||
return fm.circuitBreaker.Execute(func() error {
|
||||
return transport.Send(ctx, msg)
|
||||
})
|
||||
}
|
||||
|
||||
// Receive receives messages from the active transport
|
||||
func (fm *FailoverManager) Receive(ctx context.Context) (<-chan *Message, error) {
|
||||
transport, err := fm.GetActiveTransport()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("no available transport: %w", err)
|
||||
}
|
||||
|
||||
return transport.Receive(ctx)
|
||||
}
|
||||
|
||||
// ForceFailover manually triggers a failover to a specific transport
|
||||
func (fm *FailoverManager) ForceFailover(targetTransportID string) error {
|
||||
fm.mu.Lock()
|
||||
defer fm.mu.Unlock()
|
||||
|
||||
target, exists := fm.transports[targetTransportID]
|
||||
if !exists {
|
||||
return fmt.Errorf("target transport not found: %s", targetTransportID)
|
||||
}
|
||||
|
||||
if target.Status != StatusHealthy && target.Status != StatusDegraded {
|
||||
return fmt.Errorf("target transport is not healthy: %s", target.Status)
|
||||
}
|
||||
|
||||
return fm.switchPrimary(targetTransportID, "manual failover")
|
||||
}
|
||||
|
||||
// GetTransportStatus returns the status of all transports
|
||||
func (fm *FailoverManager) GetTransportStatus() map[string]TransportStatus {
|
||||
fm.mu.RLock()
|
||||
defer fm.mu.RUnlock()
|
||||
|
||||
status := make(map[string]TransportStatus)
|
||||
for id, transport := range fm.transports {
|
||||
status[id] = transport.Status
|
||||
}
|
||||
return status
|
||||
}
|
||||
|
||||
// GetMetrics returns failover metrics
// as a copy, safe for the caller to read without further locking.
func (fm *FailoverManager) GetMetrics() FailoverMetrics {
	fm.mu.RLock()
	defer fm.mu.RUnlock()
	return fm.metrics
}

// GetNotifications returns a channel for failover events
// emitted by health checks, failovers and failbacks. The channel is
// closed by Stop().
func (fm *FailoverManager) GetNotifications() <-chan FailoverEvent {
	return fm.notifications
}

// SetHealthChecker sets a custom health checker
// used by the periodic health-check loop.
func (fm *FailoverManager) SetHealthChecker(checker HealthChecker) {
	fm.mu.Lock()
	defer fm.mu.Unlock()
	fm.healthChecker = checker
}

// Stop gracefully stops the failover manager:
// it cancels the background loops and closes the notification channel.
// NOTE(review): the background goroutines are not awaited before the
// channel is closed, so a notifyEvent racing with Stop could send on a
// closed channel and panic — confirm shutdown ordering.
func (fm *FailoverManager) Stop() error {
	fm.cancel()
	close(fm.notifications)
	return nil
}
|
||||
|
||||
// Private methods
|
||||
|
||||
func (fm *FailoverManager) healthCheckLoop() {
|
||||
ticker := time.NewTicker(fm.failoverPolicy.HealthCheckInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-fm.ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
fm.performHealthChecks()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (fm *FailoverManager) failoverMonitorLoop() {
|
||||
for {
|
||||
select {
|
||||
case <-fm.ctx.Done():
|
||||
return
|
||||
default:
|
||||
if fm.shouldPerformFailover() {
|
||||
if err := fm.performFailover(); err != nil {
|
||||
fm.metrics.HealthCheckFailures++
|
||||
}
|
||||
}
|
||||
|
||||
if fm.shouldPerformFailback() {
|
||||
if err := fm.performFailback(); err != nil {
|
||||
fm.metrics.HealthCheckFailures++
|
||||
}
|
||||
}
|
||||
|
||||
time.Sleep(time.Second) // Check every second
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// performHealthChecks runs the configured health checker against every
// registered transport, updates failure counts and status, and emits a
// health-check event whenever a transport's status changes.
// NOTE(review): the checks run while holding the write lock, so a slow
// CheckHealth blocks all other manager operations — confirm the checker
// is cheap, or move the calls outside the lock.
func (fm *FailoverManager) performHealthChecks() {
	fm.mu.Lock()
	defer fm.mu.Unlock()

	for id, transport := range fm.transports {
		healthy, err := fm.healthChecker.CheckHealth(fm.ctx, transport.Transport)
		transport.LastHealthCheck = time.Now()

		previousStatus := transport.Status

		if err != nil || !healthy {
			transport.FailureCount++
			transport.LastFailure = time.Now()

			// Degrade first; mark unhealthy once the configured failure
			// threshold is reached.
			if transport.FailureCount >= fm.failoverPolicy.FailureThreshold {
				transport.Status = StatusUnhealthy
			} else {
				transport.Status = StatusDegraded
			}
		} else {
			// Reset failure count on successful health check
			transport.FailureCount = 0
			transport.Status = StatusHealthy
		}

		// Notify status change
		if previousStatus != transport.Status {
			fm.notifyEvent(FailoverEvent{
				Type:        EventHealthCheck,
				ToTransport: id,
				Reason:      fmt.Sprintf("status changed from %s to %s", previousStatus, transport.Status),
				Timestamp:   time.Now(),
				Success:     transport.Status == StatusHealthy,
			})
		}
	}
}
|
||||
|
||||
func (fm *FailoverManager) shouldPerformFailover() bool {
|
||||
fm.mu.RLock()
|
||||
defer fm.mu.RUnlock()
|
||||
|
||||
if fm.primaryTransport == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
primary := fm.transports[fm.primaryTransport]
|
||||
return primary.Status == StatusUnhealthy
|
||||
}
|
||||
|
||||
func (fm *FailoverManager) shouldPerformFailback() bool {
|
||||
if !fm.failoverPolicy.AutoFailback {
|
||||
return false
|
||||
}
|
||||
|
||||
fm.mu.RLock()
|
||||
defer fm.mu.RUnlock()
|
||||
|
||||
// Find the highest priority healthy transport
|
||||
var highestPriority int
|
||||
var highestPriorityID string
|
||||
|
||||
for id, transport := range fm.transports {
|
||||
if transport.Status == StatusHealthy && transport.Priority > highestPriority {
|
||||
highestPriority = transport.Priority
|
||||
highestPriorityID = id
|
||||
}
|
||||
}
|
||||
|
||||
// Failback if there's a higher priority transport available
|
||||
return highestPriorityID != "" && highestPriorityID != fm.primaryTransport
|
||||
}
|
||||
|
||||
func (fm *FailoverManager) performFailover() error {
|
||||
fm.mu.Lock()
|
||||
defer fm.mu.Unlock()
|
||||
|
||||
// Find the best backup transport
|
||||
var bestBackup string
|
||||
var bestPriority int
|
||||
|
||||
for _, backupID := range fm.backupTransports {
|
||||
backup := fm.transports[backupID]
|
||||
if (backup.Status == StatusHealthy || backup.Status == StatusDegraded) && backup.Priority > bestPriority {
|
||||
bestBackup = backupID
|
||||
bestPriority = backup.Priority
|
||||
}
|
||||
}
|
||||
|
||||
if bestBackup == "" {
|
||||
return fmt.Errorf("no healthy backup transport available")
|
||||
}
|
||||
|
||||
return fm.switchPrimary(bestBackup, "automatic failover")
|
||||
}
|
||||
|
||||
func (fm *FailoverManager) performFailback() error {
|
||||
fm.mu.Lock()
|
||||
defer fm.mu.Unlock()
|
||||
|
||||
// Find the highest priority healthy transport
|
||||
var highestPriority int
|
||||
var highestPriorityID string
|
||||
|
||||
for id, transport := range fm.transports {
|
||||
if transport.Status == StatusHealthy && transport.Priority > highestPriority {
|
||||
highestPriority = transport.Priority
|
||||
highestPriorityID = id
|
||||
}
|
||||
}
|
||||
|
||||
if highestPriorityID == "" || highestPriorityID == fm.primaryTransport {
|
||||
return nil // No failback needed
|
||||
}
|
||||
|
||||
// Wait for failback delay
|
||||
if time.Since(fm.metrics.LastFailover) < fm.failoverPolicy.FailbackDelay {
|
||||
return nil
|
||||
}
|
||||
|
||||
return fm.switchPrimary(highestPriorityID, "automatic failback")
|
||||
}
|
||||
|
||||
// switchPrimary makes newPrimaryID the primary transport, rebuilds the
// backup list from all remaining transports, updates failover/failback
// metrics and emits the corresponding event. Caller must hold fm.mu.
// The reason string doubles as the discriminator between failover and
// failback: the exact value "automatic failback" selects the failback
// metric/event paths.
func (fm *FailoverManager) switchPrimary(newPrimaryID, reason string) error {
	start := time.Now()
	oldPrimary := fm.primaryTransport

	// Update primary and backup lists
	fm.primaryTransport = newPrimaryID

	// Rebuild backup list
	fm.backupTransports = make([]string, 0)
	for id := range fm.transports {
		if id != newPrimaryID {
			fm.backupTransports = append(fm.backupTransports, id)
		}
	}

	// Update metrics (only when the primary actually changed).
	duration := time.Since(start)
	if oldPrimary != newPrimaryID {
		if reason == "automatic failback" {
			fm.metrics.TotalFailbacks++
			fm.metrics.LastFailback = time.Now()
		} else {
			fm.metrics.TotalFailovers++
			fm.metrics.LastFailover = time.Now()
		}
		fm.metrics.FailoverDuration = duration
		fm.metrics.CurrentTransport = newPrimaryID
	}

	// Notify
	eventType := EventFailover
	if reason == "automatic failback" {
		eventType = EventFailback
	}

	fm.notifyEvent(FailoverEvent{
		Type:          eventType,
		FromTransport: oldPrimary,
		ToTransport:   newPrimaryID,
		Reason:        reason,
		Timestamp:     time.Now(),
		Success:       true,
		Duration:      duration,
	})

	return nil
}
|
||||
|
||||
func (fm *FailoverManager) selectNewPrimary() {
|
||||
var bestID string
|
||||
var bestPriority int
|
||||
|
||||
for id, transport := range fm.transports {
|
||||
if transport.Status == StatusHealthy && transport.Priority > bestPriority {
|
||||
bestID = id
|
||||
bestPriority = transport.Priority
|
||||
}
|
||||
}
|
||||
|
||||
fm.primaryTransport = bestID
|
||||
}
|
||||
|
||||
// notifyEvent delivers an event to the notifications channel without
// blocking; when the buffer is full the event is silently dropped.
// NOTE(review): sending after Stop() has closed the channel would
// panic — confirm callers cannot outlive Stop.
func (fm *FailoverManager) notifyEvent(event FailoverEvent) {
	select {
	case fm.notifications <- event:
	default:
		// Channel full, drop event
	}
}
|
||||
|
||||
// DefaultHealthChecker implements basic health checking
// by consulting the transport's own Health() report.
type DefaultHealthChecker struct{}

// NewDefaultHealthChecker returns a stateless default health checker.
func NewDefaultHealthChecker() *DefaultHealthChecker {
	return &DefaultHealthChecker{}
}

// CheckHealth reports whether the transport declares itself healthy.
// The error result is always nil in this implementation.
func (dhc *DefaultHealthChecker) CheckHealth(ctx context.Context, transport Transport) (bool, error) {
	health := transport.Health()
	return health.Status == "healthy", nil
}

// GetHealthScore maps the transport's reported status to a score:
// healthy=1.0, degraded=0.5, anything else=0.0.
func (dhc *DefaultHealthChecker) GetHealthScore(transport Transport) float64 {
	health := transport.Health()
	switch health.Status {
	case "healthy":
		return 1.0
	case "degraded":
		return 0.5
	default:
		return 0.0
	}
}
|
||||
|
||||
// CircuitBreaker implements circuit breaker pattern for transport operations
// using the classic closed → open → half-open state machine.
type CircuitBreaker struct {
	config       CircuitBreakerConfig
	state        CircuitBreakerState
	failureCount int       // consecutive failures since the last success
	lastFailure  time.Time // when the breaker last recorded a failure
	mu           sync.Mutex
}

// CircuitBreakerConfig tunes the breaker's behavior.
type CircuitBreakerConfig struct {
	FailureThreshold int           // consecutive failures that open the breaker
	RecoveryTimeout  time.Duration // how long an open breaker rejects calls
	MaxRetries       int
}

// CircuitBreakerState names the breaker's current mode.
type CircuitBreakerState string

const (
	StateClosed   CircuitBreakerState = "closed"    // normal operation
	StateOpen     CircuitBreakerState = "open"      // calls rejected
	StateHalfOpen CircuitBreakerState = "half_open" // probing after the recovery timeout
)

// NewCircuitBreaker builds a breaker in the closed (passing) state.
func NewCircuitBreaker(config CircuitBreakerConfig) *CircuitBreaker {
	return &CircuitBreaker{
		config: config,
		state:  StateClosed,
	}
}
|
||||
|
||||
func (cb *CircuitBreaker) Execute(operation func() error) error {
|
||||
cb.mu.Lock()
|
||||
defer cb.mu.Unlock()
|
||||
|
||||
if cb.state == StateOpen {
|
||||
if time.Since(cb.lastFailure) < cb.config.RecoveryTimeout {
|
||||
return fmt.Errorf("circuit breaker is open")
|
||||
}
|
||||
cb.state = StateHalfOpen
|
||||
}
|
||||
|
||||
err := operation()
|
||||
if err != nil {
|
||||
cb.onFailure()
|
||||
return err
|
||||
}
|
||||
|
||||
cb.onSuccess()
|
||||
return nil
|
||||
}
|
||||
|
||||
// onFailure records a failed call and opens the breaker once the
// failure threshold is reached. Caller must hold cb.mu.
func (cb *CircuitBreaker) onFailure() {
	cb.failureCount++
	cb.lastFailure = time.Now()

	if cb.failureCount >= cb.config.FailureThreshold {
		cb.state = StateOpen
	}
}

// onSuccess resets the failure count and closes the breaker.
// Caller must hold cb.mu.
func (cb *CircuitBreaker) onSuccess() {
	cb.failureCount = 0
	cb.state = StateClosed
}

// GetState returns the breaker's current state.
func (cb *CircuitBreaker) GetState() CircuitBreakerState {
	cb.mu.Lock()
	defer cb.mu.Unlock()
	return cb.state
}
|
||||
230
orig/pkg/transport/memory_transport.go
Normal file
230
orig/pkg/transport/memory_transport.go
Normal file
@@ -0,0 +1,230 @@
|
||||
package transport
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// MemoryTransport implements in-memory message transport for local communication
// using one buffered channel per topic.
type MemoryTransport struct {
	channels  map[string]chan *Message // per-topic delivery channels
	metrics   TransportMetrics
	connected bool
	mu        sync.RWMutex // guards channels, connected and metrics
}

// NewMemoryTransport creates a new in-memory transport
// in the disconnected state; call Connect before sending.
func NewMemoryTransport() *MemoryTransport {
	return &MemoryTransport{
		channels: make(map[string]chan *Message),
		metrics:  TransportMetrics{},
	}
}
|
||||
|
||||
// Connect establishes the transport connection
|
||||
func (mt *MemoryTransport) Connect(ctx context.Context) error {
|
||||
mt.mu.Lock()
|
||||
defer mt.mu.Unlock()
|
||||
|
||||
if mt.connected {
|
||||
return nil
|
||||
}
|
||||
|
||||
mt.connected = true
|
||||
mt.metrics.Connections = 1
|
||||
return nil
|
||||
}
|
||||
|
||||
// Disconnect closes the transport connection,
// closing and discarding every topic channel. Disconnecting when
// already disconnected is a no-op.
// NOTE(review): a Send racing with Disconnect may write to a channel
// being closed here, which would panic — confirm callers stop sending
// before disconnecting.
func (mt *MemoryTransport) Disconnect(ctx context.Context) error {
	mt.mu.Lock()
	defer mt.mu.Unlock()

	if !mt.connected {
		return nil
	}

	// Close all channels
	for _, ch := range mt.channels {
		close(ch)
	}
	mt.channels = make(map[string]chan *Message)
	mt.connected = false
	mt.metrics.Connections = 0

	return nil
}
|
||||
|
||||
// Send transmits a message through the memory transport
|
||||
func (mt *MemoryTransport) Send(ctx context.Context, msg *Message) error {
|
||||
start := time.Now()
|
||||
|
||||
mt.mu.RLock()
|
||||
if !mt.connected {
|
||||
mt.mu.RUnlock()
|
||||
mt.metrics.Errors++
|
||||
return fmt.Errorf("transport not connected")
|
||||
}
|
||||
|
||||
// Get or create channel for topic
|
||||
ch, exists := mt.channels[msg.Topic]
|
||||
if !exists {
|
||||
mt.mu.RUnlock()
|
||||
mt.mu.Lock()
|
||||
// Double-check after acquiring write lock
|
||||
if ch, exists = mt.channels[msg.Topic]; !exists {
|
||||
ch = make(chan *Message, 1000) // Buffered channel
|
||||
mt.channels[msg.Topic] = ch
|
||||
}
|
||||
mt.mu.Unlock()
|
||||
} else {
|
||||
mt.mu.RUnlock()
|
||||
}
|
||||
|
||||
// Send message
|
||||
select {
|
||||
case ch <- msg:
|
||||
mt.updateSendMetrics(msg, time.Since(start))
|
||||
return nil
|
||||
case <-ctx.Done():
|
||||
mt.metrics.Errors++
|
||||
return ctx.Err()
|
||||
default:
|
||||
mt.metrics.Errors++
|
||||
return fmt.Errorf("channel full for topic: %s", msg.Topic)
|
||||
}
|
||||
}
|
||||
|
||||
// Receive returns a single merged channel fanning in messages from
// every topic channel that exists at the moment of the call. One
// goroutine per topic channel forwards messages into the merged
// channel; the merged channel is closed once all forwarders exit
// (their topic channel was closed by Disconnect, or ctx was cancelled).
//
// NOTE(review): only channels present when Receive is called are
// merged — topics created later by Send are never picked up by an
// already-returned merged channel. Confirm callers either subscribe
// before publishing or re-invoke Receive; otherwise messages on new
// topics are silently unobservable here.
func (mt *MemoryTransport) Receive(ctx context.Context) (<-chan *Message, error) {
	mt.mu.RLock()
	defer mt.mu.RUnlock()

	if !mt.connected {
		return nil, fmt.Errorf("transport not connected")
	}

	// Create a merged channel that receives from all topic channels
	merged := make(chan *Message, 1000)

	go func() {
		defer close(merged)

		// Use a wait group to handle multiple topic channels
		var wg sync.WaitGroup

		// Second RLock from this goroutine: shared read locks may be
		// held concurrently, so this does not deadlock with the outer
		// deferred RUnlock.
		mt.mu.RLock()
		for topic, ch := range mt.channels {
			wg.Add(1)
			go func(topicCh <-chan *Message, topicName string) {
				defer wg.Done()
				for {
					select {
					case msg, ok := <-topicCh:
						if !ok {
							return
						}
						select {
						case merged <- msg:
							mt.updateReceiveMetrics(msg)
						case <-ctx.Done():
							return
						}
					case <-ctx.Done():
						return
					}
				}
			}(ch, topic)
		}
		mt.mu.RUnlock()

		wg.Wait()
	}()

	return merged, nil
}
|
||||
|
||||
// Health returns the health status of the transport
|
||||
func (mt *MemoryTransport) Health() ComponentHealth {
|
||||
mt.mu.RLock()
|
||||
defer mt.mu.RUnlock()
|
||||
|
||||
status := "unhealthy"
|
||||
if mt.connected {
|
||||
status = "healthy"
|
||||
}
|
||||
|
||||
return ComponentHealth{
|
||||
Status: status,
|
||||
LastCheck: time.Now(),
|
||||
ResponseTime: time.Microsecond, // Very fast for memory transport
|
||||
ErrorCount: mt.metrics.Errors,
|
||||
}
|
||||
}
|
||||
|
||||
// GetMetrics returns transport-specific metrics
|
||||
func (mt *MemoryTransport) GetMetrics() TransportMetrics {
|
||||
mt.mu.RLock()
|
||||
defer mt.mu.RUnlock()
|
||||
|
||||
// Create a copy to avoid race conditions
|
||||
return TransportMetrics{
|
||||
BytesSent: mt.metrics.BytesSent,
|
||||
BytesReceived: mt.metrics.BytesReceived,
|
||||
MessagesSent: mt.metrics.MessagesSent,
|
||||
MessagesReceived: mt.metrics.MessagesReceived,
|
||||
Connections: mt.metrics.Connections,
|
||||
Errors: mt.metrics.Errors,
|
||||
Latency: mt.metrics.Latency,
|
||||
}
|
||||
}
|
||||
|
||||
// Private helper methods
|
||||
|
||||
func (mt *MemoryTransport) updateSendMetrics(msg *Message, latency time.Duration) {
|
||||
mt.mu.Lock()
|
||||
defer mt.mu.Unlock()
|
||||
|
||||
mt.metrics.MessagesSent++
|
||||
mt.metrics.Latency = latency
|
||||
|
||||
// Estimate message size (simplified)
|
||||
messageSize := int64(len(msg.ID) + len(msg.Topic) + len(msg.Source))
|
||||
if msg.Data != nil {
|
||||
messageSize += int64(len(fmt.Sprintf("%v", msg.Data)))
|
||||
}
|
||||
mt.metrics.BytesSent += messageSize
|
||||
}
|
||||
|
||||
func (mt *MemoryTransport) updateReceiveMetrics(msg *Message) {
|
||||
mt.mu.Lock()
|
||||
defer mt.mu.Unlock()
|
||||
|
||||
mt.metrics.MessagesReceived++
|
||||
|
||||
// Estimate message size (simplified)
|
||||
messageSize := int64(len(msg.ID) + len(msg.Topic) + len(msg.Source))
|
||||
if msg.Data != nil {
|
||||
messageSize += int64(len(fmt.Sprintf("%v", msg.Data)))
|
||||
}
|
||||
mt.metrics.BytesReceived += messageSize
|
||||
}
|
||||
|
||||
// GetChannelForTopic returns the channel for a specific topic (for testing/debugging)
|
||||
func (mt *MemoryTransport) GetChannelForTopic(topic string) (<-chan *Message, bool) {
|
||||
mt.mu.RLock()
|
||||
defer mt.mu.RUnlock()
|
||||
|
||||
ch, exists := mt.channels[topic]
|
||||
return ch, exists
|
||||
}
|
||||
|
||||
// GetTopicCount returns the number of active topic channels
|
||||
func (mt *MemoryTransport) GetTopicCount() int {
|
||||
mt.mu.RLock()
|
||||
defer mt.mu.RUnlock()
|
||||
|
||||
return len(mt.channels)
|
||||
}
|
||||
410
orig/pkg/transport/message_bus.go
Normal file
410
orig/pkg/transport/message_bus.go
Normal file
@@ -0,0 +1,410 @@
|
||||
package transport
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// MessageType represents the type of message being sent.
type MessageType string

const (
	MessageTypeEvent        MessageType = "event"        // fire-and-forget occurrence
	MessageTypeCommand      MessageType = "command"      // instruction for a target
	MessageTypeQuery        MessageType = "query"        // request expecting a response
	MessageTypeResponse     MessageType = "response"     // reply to a query (set by Reply)
	MessageTypeNotification MessageType = "notification" // informational broadcast
	MessageTypeHeartbeat    MessageType = "heartbeat"    // liveness signal
)

// MessagePriority defines message processing priority; larger values
// are more urgent (PriorityLow = 0 … PriorityCritical = 3).
type MessagePriority int

const (
	PriorityLow MessagePriority = iota
	PriorityNormal
	PriorityHigh
	PriorityCritical
)
|
||||
|
||||
// Message represents a universal message in the system. Defaults for
// most fields are filled by NewMessage and by Publish (ID, Timestamp).
type Message struct {
	ID            string                 `json:"id"`                       // unique id, see GenerateMessageID
	Type          MessageType            `json:"type"`                     // event/command/query/...
	Topic         string                 `json:"topic"`                    // routing key; required (see validateMessage)
	Source        string                 `json:"source"`                   // sender identity; required
	Target        string                 `json:"target,omitempty"`         // optional addressee
	Priority      MessagePriority        `json:"priority"`                 // processing priority
	Timestamp     time.Time              `json:"timestamp"`                // set on publish if zero
	Data          interface{}            `json:"data"`                     // payload; required (non-nil)
	Headers       map[string]string      `json:"headers,omitempty"`        // transport-agnostic metadata
	CorrelationID string                 `json:"correlation_id,omitempty"` // links request and response
	TTL           time.Duration          `json:"ttl,omitempty"`            // optional time-to-live
	Retries       int                    `json:"retries"`                  // attempts made so far
	MaxRetries    int                    `json:"max_retries"`              // retry budget (NewMessage sets 3)
	Metadata      map[string]interface{} `json:"metadata,omitempty"`       // free-form extras
}
|
||||
|
||||
// MessageHandler processes incoming messages. A non-nil error marks
// the delivery as failed for this subscriber.
type MessageHandler func(ctx context.Context, msg *Message) error

// MessageFilter determines if a message should be processed.
type MessageFilter func(msg *Message) bool

// Subscription represents a topic subscription created by Subscribe.
type Subscription struct {
	ID      string              // unique id ("sub_<topic>_<nanos>")
	Topic   string              // subscribed topic
	Filter  MessageFilter       // optional; nil means accept all
	Handler MessageHandler      // invoked for each delivered message
	Options SubscriptionOptions // behavior knobs (queueing, retry, DLQ)
	created time.Time           // creation time
	active  bool                // cleared by Unsubscribe; guarded by mu
	mu      sync.RWMutex        // guards created/active
}
|
||||
|
||||
// SubscriptionOptions configures subscription behavior; defaults are
// applied in Subscribe and overridden via the With* option functions.
type SubscriptionOptions struct {
	QueueSize    int           // per-subscription buffer (default 1000)
	BatchSize    int           // messages per handler batch (default 1)
	BatchTimeout time.Duration // max wait before a partial batch flushes
	DLQEnabled   bool          // route failed messages to the dead letter queue
	RetryEnabled bool          // retry failed deliveries
	Persistent   bool          // persist delivered messages
	Durable      bool          // survive restarts
}
|
||||
|
||||
// MessageBusInterface defines the universal message bus contract,
// implemented by UniversalMessageBus.
type MessageBusInterface interface {
	// Core messaging operations
	Publish(ctx context.Context, msg *Message) error
	Subscribe(topic string, handler MessageHandler, opts ...SubscriptionOption) (*Subscription, error)
	Unsubscribe(subscriptionID string) error

	// Advanced messaging patterns (request/response via correlation IDs)
	Request(ctx context.Context, msg *Message, timeout time.Duration) (*Message, error)
	Reply(ctx context.Context, originalMsg *Message, response *Message) error

	// Topic management
	CreateTopic(topic string, config TopicConfig) error
	DeleteTopic(topic string) error
	ListTopics() []string
	GetTopicInfo(topic string) (*TopicInfo, error)

	// Queue operations
	QueueMessage(topic string, msg *Message) error
	DequeueMessage(topic string, timeout time.Duration) (*Message, error)
	PeekMessage(topic string) (*Message, error)

	// Dead letter queue
	GetDLQMessages(topic string) ([]*Message, error)
	ReprocessDLQMessage(messageID string) error
	PurgeDLQ(topic string) error

	// Lifecycle management
	Start(ctx context.Context) error
	Stop(ctx context.Context) error
	Health() HealthStatus

	// Metrics and monitoring
	GetMetrics() MessageBusMetrics
	GetSubscriptions() []*Subscription
	GetActiveConnections() int
}

// SubscriptionOption configures subscription behavior via the
// functional-options pattern (see WithQueueSize and friends).
type SubscriptionOption func(*SubscriptionOptions)
|
||||
|
||||
// TopicConfig defines topic configuration.
type TopicConfig struct {
	Persistent      bool            // store messages in the topic
	Replicated      bool            // replicate across nodes
	RetentionPolicy RetentionPolicy // limits on stored messages
	Partitions      int             // partition count
	MaxMessageSize  int64           // max payload size in bytes
	TTL             time.Duration   // message time-to-live
}

// RetentionPolicy defines message retention behavior; zero values mean
// "no limit" for that dimension.
type RetentionPolicy struct {
	MaxMessages int           // cap on stored message count
	MaxAge      time.Duration // cap on message age
	MaxSize     int64         // cap on total stored bytes
}

// TopicInfo provides topic statistics (see GetTopicInfo).
type TopicInfo struct {
	Name            string
	Config          TopicConfig
	MessageCount    int64     // stored messages, processed or not
	SubscriberCount int       // current subscriber count
	LastActivity    time.Time // last store/publish on the topic
	SizeBytes       int64     // rough estimate of stored bytes
}
|
||||
|
||||
// HealthStatus represents system health as reported by Health().
type HealthStatus struct {
	Status     string                     // "healthy" or "degraded"
	Uptime     time.Duration              // time since start
	LastCheck  time.Time                  // when this snapshot was taken
	Components map[string]ComponentHealth // per-transport health, keyed by transport type
	Errors     []HealthError              // one entry per unhealthy component
}

// ComponentHealth represents component-specific health.
type ComponentHealth struct {
	Status       string        // "healthy"/"unhealthy"
	LastCheck    time.Time     // when the component was probed
	ResponseTime time.Duration // probe latency
	ErrorCount   int64         // cumulative errors seen by the component
}

// HealthError represents a health check error.
type HealthError struct {
	Component string    // component name (transport type)
	Message   string    // human-readable description
	Timestamp time.Time // when the error was recorded
	Severity  string    // e.g. "warning"
}
|
||||
|
||||
// MessageBusMetrics provides operational metrics, assembled from the
// internal metrics collector by GetMetrics.
type MessageBusMetrics struct {
	MessagesPublished   int64         // successful Publish calls
	MessagesConsumed    int64         // messages delivered to handlers
	MessagesFailed      int64         // transport send failures
	MessagesInDLQ       int64         // current dead-letter queue depth
	ActiveSubscriptions int           // live subscription count
	TopicCount          int           // known topics
	AverageLatency      time.Duration // rolling publish latency
	ThroughputPerSec    float64       // messages per second
	ErrorRate           float64       // failed / total ratio
	MemoryUsage         int64         // bytes
	CPUUsage            float64       // percent
}
|
||||
|
||||
// TransportType defines available transport mechanisms.
type TransportType string

const (
	TransportMemory     TransportType = "memory" // in-process (MemoryTransport); default fallback
	TransportUnixSocket TransportType = "unix"
	TransportTCP        TransportType = "tcp"
	TransportWebSocket  TransportType = "websocket"
	TransportRedis      TransportType = "redis"
	TransportNATS       TransportType = "nats"
)

// TransportConfig configures the transport layer.
type TransportConfig struct {
	Type           TransportType          // which mechanism to use
	Address        string                 // endpoint, if the transport needs one
	Options        map[string]interface{} // transport-specific settings
	RetryConfig    RetryConfig            // retry/backoff policy
	SecurityConfig SecurityConfig         // TLS/auth settings
}
|
||||
|
||||
// RetryConfig defines retry behavior with exponential backoff.
type RetryConfig struct {
	MaxRetries    int           // attempts before giving up
	InitialDelay  time.Duration // delay before the first retry
	MaxDelay      time.Duration // backoff ceiling
	BackoffFactor float64       // multiplier applied per attempt
	Jitter        bool          // randomize delays to avoid thundering herds
}

// SecurityConfig defines security settings for a transport.
type SecurityConfig struct {
	Enabled     bool       // master switch
	TLSConfig   *TLSConfig // nil when TLS is not used
	AuthConfig  *AuthConfig
	Encryption  bool // encrypt payloads
	Compression bool // compress payloads
}

// TLSConfig for secure transport.
type TLSConfig struct {
	CertFile string // path to certificate
	KeyFile  string // path to private key
	CAFile   string // path to CA bundle
	Verify   bool   // verify the peer certificate
}

// AuthConfig for authentication.
type AuthConfig struct {
	Username string
	Password string
	Token    string
	Method   string // auth scheme name
}
|
||||
|
||||
// UniversalMessageBus implements MessageBusInterface. Maps and the
// started flag are guarded by mu; ctx/cancel bound the lifetime of the
// background loops started in Start.
type UniversalMessageBus struct {
	config        MessageBusConfig
	transports    map[TransportType]Transport // registered transports; memory is added lazily by Start
	router        *MessageRouter              // picks a transport per message
	topics        map[string]*Topic           // topic name -> stored topic state
	subscriptions map[string]*Subscription    // subscription ID -> subscription
	dlq           *DeadLetterQueue            // dead-letter storage (see dlq.go)
	metrics       *MetricsCollector           // counters/latency sink (see serialization.go)
	persistence   PersistenceLayer            // optional durable store
	mu            sync.RWMutex                // guards maps and started
	ctx           context.Context             // cancelled by Stop
	cancel        context.CancelFunc
	started       bool
}

// MessageBusConfig configures the message bus.
type MessageBusConfig struct {
	DefaultTransport    TransportType // preferred transport
	EnablePersistence   bool          // store published messages in topics
	EnableMetrics       bool
	EnableDLQ           bool // park undeliverable messages in the DLQ
	MaxMessageSize      int64
	DefaultTTL          time.Duration
	HealthCheckInterval time.Duration // cadence of healthCheckLoop
	CleanupInterval     time.Duration // cadence of cleanupLoop
}
|
||||
|
||||
// Transport is the interface each transport mechanism implements;
// MemoryTransport is the in-process reference implementation.
type Transport interface {
	Send(ctx context.Context, msg *Message) error
	Receive(ctx context.Context) (<-chan *Message, error)
	Connect(ctx context.Context) error
	Disconnect(ctx context.Context) error
	Health() ComponentHealth
	GetMetrics() TransportMetrics
}

// TransportMetrics carries transport-specific counters; byte counts
// are rough estimates (see MemoryTransport's update helpers).
type TransportMetrics struct {
	BytesSent        int64
	BytesReceived    int64
	MessagesSent     int64
	MessagesReceived int64
	Connections      int           // active connection count
	Errors           int64         // cumulative failures
	Latency          time.Duration // latency of the most recent send
}
|
||||
|
||||
// Note: MessageRouter and RoutingRule are defined in router.go

// LoadBalancer selects a transport for a message and records outcome
// statistics for future selections.
type LoadBalancer interface {
	SelectTransport(transports []TransportType, msg *Message) TransportType
	UpdateStats(transport TransportType, latency time.Duration, success bool)
}

// Topic represents a message topic; mutable fields are guarded by mu.
type Topic struct {
	Name         string
	Config       TopicConfig
	Messages     []StoredMessage // append-only store used when persistence is enabled
	Subscribers  []*Subscription // current subscribers
	Created      time.Time
	LastActivity time.Time // updated on each store
	mu           sync.RWMutex
}

// StoredMessage represents a persisted message.
type StoredMessage struct {
	Message   *Message
	Stored    time.Time // when the message was stored
	Processed bool      // set by DequeueMessage when handed out
}

// Note: DeadLetterQueue and DLQConfig are defined in dlq.go

// Note: MetricsCollector is defined in serialization.go

// PersistenceLayer handles message persistence for durable topics.
type PersistenceLayer interface {
	Store(msg *Message) error
	Retrieve(id string) (*Message, error)
	Delete(id string) error
	List(topic string, limit int) ([]*Message, error)
	Cleanup(maxAge time.Duration) error
}
|
||||
|
||||
// Factory functions for common subscription options
|
||||
func WithQueueSize(size int) SubscriptionOption {
|
||||
return func(opts *SubscriptionOptions) {
|
||||
opts.QueueSize = size
|
||||
}
|
||||
}
|
||||
|
||||
func WithBatchProcessing(size int, timeout time.Duration) SubscriptionOption {
|
||||
return func(opts *SubscriptionOptions) {
|
||||
opts.BatchSize = size
|
||||
opts.BatchTimeout = timeout
|
||||
}
|
||||
}
|
||||
|
||||
func WithDLQ(enabled bool) SubscriptionOption {
|
||||
return func(opts *SubscriptionOptions) {
|
||||
opts.DLQEnabled = enabled
|
||||
}
|
||||
}
|
||||
|
||||
func WithRetry(enabled bool) SubscriptionOption {
|
||||
return func(opts *SubscriptionOptions) {
|
||||
opts.RetryEnabled = enabled
|
||||
}
|
||||
}
|
||||
|
||||
func WithPersistence(enabled bool) SubscriptionOption {
|
||||
return func(opts *SubscriptionOptions) {
|
||||
opts.Persistent = enabled
|
||||
}
|
||||
}
|
||||
|
||||
// NewUniversalMessageBus creates a new message bus instance
|
||||
func NewUniversalMessageBus(config MessageBusConfig) *UniversalMessageBus {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
return &UniversalMessageBus{
|
||||
config: config,
|
||||
transports: make(map[TransportType]Transport),
|
||||
topics: make(map[string]*Topic),
|
||||
subscriptions: make(map[string]*Subscription),
|
||||
router: NewMessageRouter(),
|
||||
dlq: NewDeadLetterQueue(DLQConfig{}),
|
||||
metrics: NewMetricsCollector(),
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
}
|
||||
}
|
||||
|
||||
// NewMessageRouter creates a new message router
|
||||
func NewMessageRouter() *MessageRouter {
|
||||
return &MessageRouter{
|
||||
rules: make([]RoutingRule, 0),
|
||||
fallback: TransportMemory,
|
||||
}
|
||||
}
|
||||
|
||||
// Note: NewDeadLetterQueue is defined in dlq.go
|
||||
// Note: NewMetricsCollector is defined in serialization.go
|
||||
|
||||
// Helper function to generate message ID
|
||||
func GenerateMessageID() string {
|
||||
return fmt.Sprintf("msg_%d_%d", time.Now().UnixNano(), time.Now().Nanosecond())
|
||||
}
|
||||
|
||||
// Helper function to create message with defaults
|
||||
func NewMessage(msgType MessageType, topic string, source string, data interface{}) *Message {
|
||||
return &Message{
|
||||
ID: GenerateMessageID(),
|
||||
Type: msgType,
|
||||
Topic: topic,
|
||||
Source: source,
|
||||
Priority: PriorityNormal,
|
||||
Timestamp: time.Now(),
|
||||
Data: data,
|
||||
Headers: make(map[string]string),
|
||||
Metadata: make(map[string]interface{}),
|
||||
Retries: 0,
|
||||
MaxRetries: 3,
|
||||
}
|
||||
}
|
||||
742
orig/pkg/transport/message_bus_impl.go
Normal file
742
orig/pkg/transport/message_bus_impl.go
Normal file
@@ -0,0 +1,742 @@
|
||||
package transport
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Publish sends a message to the specified topic.
//
// Pipeline: validate → stamp defaults (Timestamp, ID) → record metrics
// → route to a transport → transport.Send (with DLQ fallback) →
// optional topic persistence → asynchronous local delivery.
//
// NOTE(review): mb.started is read without holding mb.mu, which races
// with Start/Stop — confirm callers serialize lifecycle transitions.
// NOTE(review): the publish_latency sample is time.Since(msg.Timestamp);
// for messages stamped just above it is ~0, so the metric only carries
// information for pre-stamped messages.
func (mb *UniversalMessageBus) Publish(ctx context.Context, msg *Message) error {
	if !mb.started {
		return fmt.Errorf("message bus not started")
	}

	// Validate message
	if err := mb.validateMessage(msg); err != nil {
		return fmt.Errorf("invalid message: %w", err)
	}

	// Set timestamp if not set
	if msg.Timestamp.IsZero() {
		msg.Timestamp = time.Now()
	}

	// Set ID if not set
	if msg.ID == "" {
		msg.ID = GenerateMessageID()
	}

	// Update metrics
	mb.metrics.IncrementCounter("messages_published_total")
	mb.metrics.RecordLatency("publish_latency", time.Since(msg.Timestamp))

	// Route message to appropriate transport
	transport, err := mb.router.RouteMessage(msg, mb.transports)
	if err != nil {
		mb.metrics.IncrementCounter("routing_errors_total")
		return fmt.Errorf("routing failed: %w", err)
	}

	// Send via transport; a failed send is parked in the DLQ when enabled,
	// but Publish still reports the failure to the caller.
	if err := transport.Send(ctx, msg); err != nil {
		mb.metrics.IncrementCounter("send_errors_total")
		// Try dead letter queue if enabled
		if mb.config.EnableDLQ {
			if dlqErr := mb.dlq.AddMessage(msg.Topic, msg); dlqErr != nil {
				return fmt.Errorf("send failed and DLQ failed: %v, original error: %w", dlqErr, err)
			}
		}
		return fmt.Errorf("send failed: %w", err)
	}

	// Store in topic if persistence enabled
	if mb.config.EnablePersistence {
		if err := mb.addMessageToTopic(msg); err != nil {
			// Log error but don't fail the publish
			mb.metrics.IncrementCounter("persistence_errors_total")
		}
	}

	// Deliver to local subscribers asynchronously; Publish does not wait
	// for handler completion.
	go mb.deliverToSubscribers(ctx, msg)

	return nil
}
|
||||
|
||||
// Subscribe creates a subscription to a topic. Defaults (queue size
// 1000, batch size 1, 1s batch timeout, retry on, DLQ per bus config)
// are applied first, then each SubscriptionOption in order.
//
// The subscription is registered in the bus map under mb.mu and then
// attached to the topic via addSubscriberToTopic (defined elsewhere).
//
// NOTE(review): mb.started is read without holding mb.mu — same
// lifecycle race as Publish; confirm callers serialize Start/Stop.
func (mb *UniversalMessageBus) Subscribe(topic string, handler MessageHandler, opts ...SubscriptionOption) (*Subscription, error) {
	if !mb.started {
		return nil, fmt.Errorf("message bus not started")
	}

	// Apply subscription options over the defaults.
	options := SubscriptionOptions{
		QueueSize:    1000,
		BatchSize:    1,
		BatchTimeout: time.Second,
		DLQEnabled:   mb.config.EnableDLQ,
		RetryEnabled: true,
		Persistent:   false,
		Durable:      false,
	}

	for _, opt := range opts {
		opt(&options)
	}

	// Create subscription; the ID embeds the topic and a nanosecond
	// timestamp for uniqueness.
	subscription := &Subscription{
		ID:      fmt.Sprintf("sub_%s_%d", topic, time.Now().UnixNano()),
		Topic:   topic,
		Handler: handler,
		Options: options,
		created: time.Now(),
		active:  true,
	}

	mb.mu.Lock()
	mb.subscriptions[subscription.ID] = subscription
	mb.mu.Unlock()

	// Add to topic subscribers
	mb.addSubscriberToTopic(topic, subscription)

	mb.metrics.IncrementCounter("subscriptions_created_total")

	return subscription, nil
}
|
||||
|
||||
// Unsubscribe removes a subscription by ID: marks it inactive, drops
// it from the bus map, and detaches it from its topic's subscriber
// list. Returns an error for unknown IDs.
//
// NOTE(review): removeSubscriberFromTopic (defined elsewhere) is
// invoked while mb.mu is held; confirm it does not re-acquire mb.mu —
// Go's sync.Mutex is not reentrant and that would deadlock.
func (mb *UniversalMessageBus) Unsubscribe(subscriptionID string) error {
	mb.mu.Lock()
	defer mb.mu.Unlock()

	subscription, exists := mb.subscriptions[subscriptionID]
	if !exists {
		return fmt.Errorf("subscription not found: %s", subscriptionID)
	}

	// Mark as inactive so in-flight deliveries can observe the change.
	subscription.mu.Lock()
	subscription.active = false
	subscription.mu.Unlock()

	// Remove from subscriptions map
	delete(mb.subscriptions, subscriptionID)

	// Remove from topic subscribers
	mb.removeSubscriberFromTopic(subscription.Topic, subscriptionID)

	mb.metrics.IncrementCounter("subscriptions_removed_total")

	return nil
}
|
||||
|
||||
// Request sends a request and waits for a response
|
||||
func (mb *UniversalMessageBus) Request(ctx context.Context, msg *Message, timeout time.Duration) (*Message, error) {
|
||||
if !mb.started {
|
||||
return nil, fmt.Errorf("message bus not started")
|
||||
}
|
||||
|
||||
// Set correlation ID for request-response
|
||||
if msg.CorrelationID == "" {
|
||||
msg.CorrelationID = GenerateMessageID()
|
||||
}
|
||||
|
||||
// Create response channel
|
||||
responseChannel := make(chan *Message, 1)
|
||||
defer close(responseChannel)
|
||||
|
||||
// Subscribe to response topic
|
||||
responseTopic := fmt.Sprintf("response.%s", msg.CorrelationID)
|
||||
subscription, err := mb.Subscribe(responseTopic, func(ctx context.Context, response *Message) error {
|
||||
select {
|
||||
case responseChannel <- response:
|
||||
default:
|
||||
// Channel full, ignore
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to subscribe to response topic: %w", err)
|
||||
}
|
||||
defer mb.Unsubscribe(subscription.ID)
|
||||
|
||||
// Send request
|
||||
if err := mb.Publish(ctx, msg); err != nil {
|
||||
return nil, fmt.Errorf("failed to publish request: %w", err)
|
||||
}
|
||||
|
||||
// Wait for response with timeout
|
||||
select {
|
||||
case response := <-responseChannel:
|
||||
return response, nil
|
||||
case <-time.After(timeout):
|
||||
return nil, fmt.Errorf("request timeout after %v", timeout)
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
}
|
||||
}
|
||||
|
||||
// Reply sends a response to a request
|
||||
func (mb *UniversalMessageBus) Reply(ctx context.Context, originalMsg *Message, response *Message) error {
|
||||
if originalMsg.CorrelationID == "" {
|
||||
return fmt.Errorf("original message has no correlation ID")
|
||||
}
|
||||
|
||||
// Set response properties
|
||||
response.Type = MessageTypeResponse
|
||||
response.Topic = fmt.Sprintf("response.%s", originalMsg.CorrelationID)
|
||||
response.CorrelationID = originalMsg.CorrelationID
|
||||
response.Target = originalMsg.Source
|
||||
|
||||
return mb.Publish(ctx, response)
|
||||
}
|
||||
|
||||
// CreateTopic creates a new topic with configuration
|
||||
func (mb *UniversalMessageBus) CreateTopic(topicName string, config TopicConfig) error {
|
||||
mb.mu.Lock()
|
||||
defer mb.mu.Unlock()
|
||||
|
||||
if _, exists := mb.topics[topicName]; exists {
|
||||
return fmt.Errorf("topic already exists: %s", topicName)
|
||||
}
|
||||
|
||||
topic := &Topic{
|
||||
Name: topicName,
|
||||
Config: config,
|
||||
Messages: make([]StoredMessage, 0),
|
||||
Subscribers: make([]*Subscription, 0),
|
||||
Created: time.Now(),
|
||||
LastActivity: time.Now(),
|
||||
}
|
||||
|
||||
mb.topics[topicName] = topic
|
||||
mb.metrics.IncrementCounter("topics_created_total")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeleteTopic removes a topic
|
||||
func (mb *UniversalMessageBus) DeleteTopic(topicName string) error {
|
||||
mb.mu.Lock()
|
||||
defer mb.mu.Unlock()
|
||||
|
||||
topic, exists := mb.topics[topicName]
|
||||
if !exists {
|
||||
return fmt.Errorf("topic not found: %s", topicName)
|
||||
}
|
||||
|
||||
// Remove all subscribers
|
||||
for _, sub := range topic.Subscribers {
|
||||
mb.Unsubscribe(sub.ID)
|
||||
}
|
||||
|
||||
delete(mb.topics, topicName)
|
||||
mb.metrics.IncrementCounter("topics_deleted_total")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ListTopics returns all topic names
|
||||
func (mb *UniversalMessageBus) ListTopics() []string {
|
||||
mb.mu.RLock()
|
||||
defer mb.mu.RUnlock()
|
||||
|
||||
topics := make([]string, 0, len(mb.topics))
|
||||
for name := range mb.topics {
|
||||
topics = append(topics, name)
|
||||
}
|
||||
|
||||
return topics
|
||||
}
|
||||
|
||||
// GetTopicInfo returns topic information
|
||||
func (mb *UniversalMessageBus) GetTopicInfo(topicName string) (*TopicInfo, error) {
|
||||
mb.mu.RLock()
|
||||
defer mb.mu.RUnlock()
|
||||
|
||||
topic, exists := mb.topics[topicName]
|
||||
if !exists {
|
||||
return nil, fmt.Errorf("topic not found: %s", topicName)
|
||||
}
|
||||
|
||||
topic.mu.RLock()
|
||||
defer topic.mu.RUnlock()
|
||||
|
||||
// Calculate size
|
||||
var sizeBytes int64
|
||||
for _, stored := range topic.Messages {
|
||||
// Rough estimation of message size
|
||||
sizeBytes += int64(len(fmt.Sprintf("%+v", stored.Message)))
|
||||
}
|
||||
|
||||
return &TopicInfo{
|
||||
Name: topic.Name,
|
||||
Config: topic.Config,
|
||||
MessageCount: int64(len(topic.Messages)),
|
||||
SubscriberCount: len(topic.Subscribers),
|
||||
LastActivity: topic.LastActivity,
|
||||
SizeBytes: sizeBytes,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// QueueMessage adds a message to a topic queue.
//
// NOTE(review): the topic parameter is ignored — the message is stored
// under msg.Topic via addMessageToTopic. Confirm whether a mismatch
// between the two should be an error.
func (mb *UniversalMessageBus) QueueMessage(topic string, msg *Message) error {
	return mb.addMessageToTopic(msg)
}
|
||||
|
||||
// DequeueMessage hands out the first unprocessed message stored for a
// topic, marking it Processed. It polls every 10ms until the timeout
// elapses, then fails.
//
// NOTE(review): this is a busy-wait poll (10ms sleep loop) rather than
// a condition-variable wait — acceptable for a local queue, but worth
// confirming for hot paths.
// NOTE(review): messages are only flagged Processed, never removed
// here — presumably the retention policy elsewhere reclaims them;
// confirm, otherwise the slice grows without bound.
func (mb *UniversalMessageBus) DequeueMessage(topic string, timeout time.Duration) (*Message, error) {
	start := time.Now()

	for time.Since(start) < timeout {
		mb.mu.RLock()
		topicObj, exists := mb.topics[topic]
		mb.mu.RUnlock()

		if !exists {
			return nil, fmt.Errorf("topic not found: %s", topic)
		}

		topicObj.mu.Lock()
		if len(topicObj.Messages) > 0 {
			// Get first unprocessed message
			for i, stored := range topicObj.Messages {
				if !stored.Processed {
					topicObj.Messages[i].Processed = true
					topicObj.mu.Unlock()
					return stored.Message, nil
				}
			}
		}
		topicObj.mu.Unlock()

		// Wait a bit before trying again
		time.Sleep(10 * time.Millisecond)
	}

	return nil, fmt.Errorf("no message available within timeout")
}
|
||||
|
||||
// PeekMessage returns the next message without removing it
|
||||
func (mb *UniversalMessageBus) PeekMessage(topic string) (*Message, error) {
|
||||
mb.mu.RLock()
|
||||
topicObj, exists := mb.topics[topic]
|
||||
mb.mu.RUnlock()
|
||||
|
||||
if !exists {
|
||||
return nil, fmt.Errorf("topic not found: %s", topic)
|
||||
}
|
||||
|
||||
topicObj.mu.RLock()
|
||||
defer topicObj.mu.RUnlock()
|
||||
|
||||
for _, stored := range topicObj.Messages {
|
||||
if !stored.Processed {
|
||||
return stored.Message, nil
|
||||
}
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("no messages available")
|
||||
}
|
||||
|
||||
// Start initializes and starts the message bus: it registers and
// connects a default in-memory transport when none is configured,
// launches the background health-check, cleanup and metrics loops
// (defined elsewhere in this package), and flips started.
// Starting an already-started bus is an error.
//
// NOTE(review): the background loops are presumed to exit when mb.ctx
// is cancelled by Stop — confirm each loop selects on mb.ctx.Done().
func (mb *UniversalMessageBus) Start(ctx context.Context) error {
	mb.mu.Lock()
	defer mb.mu.Unlock()

	if mb.started {
		return fmt.Errorf("message bus already started")
	}

	// Initialize default transport if none configured
	if len(mb.transports) == 0 {
		memTransport := NewMemoryTransport()
		mb.transports[TransportMemory] = memTransport
		if err := memTransport.Connect(ctx); err != nil {
			return fmt.Errorf("failed to connect default transport: %w", err)
		}
	}

	// Start background routines
	go mb.healthCheckLoop()
	go mb.cleanupLoop()
	go mb.metricsLoop()

	mb.started = true
	mb.metrics.RecordEvent("message_bus_started")

	return nil
}
|
||||
|
||||
// Stop gracefully shuts down the message bus: cancels the internal
// context (signalling background loops to exit), disconnects every
// transport, and clears started. Stopping a bus that is not running is
// a no-op.
//
// NOTE(review): transport Disconnect errors are silently swallowed —
// the comment promises logging but no logger is wired in here; confirm
// whether errors should be surfaced or logged.
func (mb *UniversalMessageBus) Stop(ctx context.Context) error {
	mb.mu.Lock()
	defer mb.mu.Unlock()

	if !mb.started {
		return nil
	}

	// Cancel context to stop background routines
	mb.cancel()

	// Disconnect all transports
	for _, transport := range mb.transports {
		if err := transport.Disconnect(ctx); err != nil {
			// Log error but continue shutdown
		}
	}

	mb.started = false
	mb.metrics.RecordEvent("message_bus_stopped")

	return nil
}
|
||||
|
||||
// Health returns the current health status: per-transport component
// health, an overall status that degrades when any component is not
// "healthy", and one HealthError per unhealthy component.
//
// NOTE(review): Uptime is time.Since(time.Now()) — always ~0; the
// inline comment concedes real uptime tracking is missing (would need
// a start-time field on the struct).
// NOTE(review): mb.transports is iterated without holding mb.mu —
// confirm the transport map is effectively immutable after Start.
func (mb *UniversalMessageBus) Health() HealthStatus {
	components := make(map[string]ComponentHealth)

	// Check transport health
	for transportType, transport := range mb.transports {
		components[string(transportType)] = transport.Health()
	}

	// Overall status degrades if any component is unhealthy.
	status := "healthy"
	var errors []HealthError

	for name, component := range components {
		if component.Status != "healthy" {
			status = "degraded"
			errors = append(errors, HealthError{
				Component: name,
				Message:   fmt.Sprintf("Component %s is %s", name, component.Status),
				Timestamp: time.Now(),
				Severity:  "warning",
			})
		}
	}

	return HealthStatus{
		Status:     status,
		Uptime:     time.Since(time.Now()), // Would track actual uptime
		LastCheck:  time.Now(),
		Components: components,
		Errors:     errors,
	}
}
|
||||
|
||||
// GetMetrics returns current operational metrics
|
||||
func (mb *UniversalMessageBus) GetMetrics() MessageBusMetrics {
|
||||
_ = mb.metrics.GetAll() // metrics not used
|
||||
|
||||
return MessageBusMetrics{
|
||||
MessagesPublished: mb.getMetricInt64("messages_published_total"),
|
||||
MessagesConsumed: mb.getMetricInt64("messages_consumed_total"),
|
||||
MessagesFailed: mb.getMetricInt64("send_errors_total"),
|
||||
MessagesInDLQ: int64(mb.dlq.GetMessageCount()),
|
||||
ActiveSubscriptions: len(mb.subscriptions),
|
||||
TopicCount: len(mb.topics),
|
||||
AverageLatency: mb.getMetricDuration("average_latency"),
|
||||
ThroughputPerSec: mb.getMetricFloat64("throughput_per_second"),
|
||||
ErrorRate: mb.getMetricFloat64("error_rate"),
|
||||
MemoryUsage: mb.getMetricInt64("memory_usage_bytes"),
|
||||
CPUUsage: mb.getMetricFloat64("cpu_usage_percent"),
|
||||
}
|
||||
}
|
||||
|
||||
// GetSubscriptions returns all active subscriptions
|
||||
func (mb *UniversalMessageBus) GetSubscriptions() []*Subscription {
|
||||
mb.mu.RLock()
|
||||
defer mb.mu.RUnlock()
|
||||
|
||||
subscriptions := make([]*Subscription, 0, len(mb.subscriptions))
|
||||
for _, sub := range mb.subscriptions {
|
||||
subscriptions = append(subscriptions, sub)
|
||||
}
|
||||
|
||||
return subscriptions
|
||||
}
|
||||
|
||||
// GetActiveConnections returns the number of active connections
|
||||
func (mb *UniversalMessageBus) GetActiveConnections() int {
|
||||
count := 0
|
||||
for _, transport := range mb.transports {
|
||||
metrics := transport.GetMetrics()
|
||||
count += metrics.Connections
|
||||
}
|
||||
return count
|
||||
}
|
||||
|
||||
// Helper methods
|
||||
|
||||
func (mb *UniversalMessageBus) validateMessage(msg *Message) error {
|
||||
if msg == nil {
|
||||
return fmt.Errorf("message is nil")
|
||||
}
|
||||
if msg.Topic == "" {
|
||||
return fmt.Errorf("message topic is empty")
|
||||
}
|
||||
if msg.Source == "" {
|
||||
return fmt.Errorf("message source is empty")
|
||||
}
|
||||
if msg.Data == nil {
|
||||
return fmt.Errorf("message data is nil")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (mb *UniversalMessageBus) addMessageToTopic(msg *Message) error {
|
||||
mb.mu.RLock()
|
||||
topic, exists := mb.topics[msg.Topic]
|
||||
mb.mu.RUnlock()
|
||||
|
||||
if !exists {
|
||||
// Create topic automatically
|
||||
config := TopicConfig{
|
||||
Persistent: true,
|
||||
RetentionPolicy: RetentionPolicy{MaxMessages: 10000, MaxAge: 24 * time.Hour},
|
||||
}
|
||||
if err := mb.CreateTopic(msg.Topic, config); err != nil {
|
||||
return err
|
||||
}
|
||||
topic = mb.topics[msg.Topic]
|
||||
}
|
||||
|
||||
topic.mu.Lock()
|
||||
defer topic.mu.Unlock()
|
||||
|
||||
stored := StoredMessage{
|
||||
Message: msg,
|
||||
Stored: time.Now(),
|
||||
Processed: false,
|
||||
}
|
||||
|
||||
topic.Messages = append(topic.Messages, stored)
|
||||
topic.LastActivity = time.Now()
|
||||
|
||||
// Apply retention policy
|
||||
mb.applyRetentionPolicy(topic)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (mb *UniversalMessageBus) addSubscriberToTopic(topicName string, subscription *Subscription) {
|
||||
mb.mu.RLock()
|
||||
topic, exists := mb.topics[topicName]
|
||||
mb.mu.RUnlock()
|
||||
|
||||
if !exists {
|
||||
// Create topic automatically
|
||||
config := TopicConfig{Persistent: false}
|
||||
mb.CreateTopic(topicName, config)
|
||||
topic = mb.topics[topicName]
|
||||
}
|
||||
|
||||
topic.mu.Lock()
|
||||
topic.Subscribers = append(topic.Subscribers, subscription)
|
||||
topic.mu.Unlock()
|
||||
}
|
||||
|
||||
func (mb *UniversalMessageBus) removeSubscriberFromTopic(topicName, subscriptionID string) {
|
||||
mb.mu.RLock()
|
||||
topic, exists := mb.topics[topicName]
|
||||
mb.mu.RUnlock()
|
||||
|
||||
if !exists {
|
||||
return
|
||||
}
|
||||
|
||||
topic.mu.Lock()
|
||||
defer topic.mu.Unlock()
|
||||
|
||||
for i, sub := range topic.Subscribers {
|
||||
if sub.ID == subscriptionID {
|
||||
topic.Subscribers = append(topic.Subscribers[:i], topic.Subscribers[i+1:]...)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// deliverToSubscribers fans msg out to every active subscriber of its topic.
// Each handler runs in its own goroutine so one slow or panicking handler
// cannot stall the others; a failed handler routes the message to the DLQ
// when enabled. Unknown topics are silently dropped.
func (mb *UniversalMessageBus) deliverToSubscribers(ctx context.Context, msg *Message) {
	mb.mu.RLock()
	topic, exists := mb.topics[msg.Topic]
	mb.mu.RUnlock()

	if !exists {
		return
	}

	// Snapshot the subscriber list so the topic lock is not held while
	// filters and handlers run.
	topic.mu.RLock()
	subscribers := make([]*Subscription, len(topic.Subscribers))
	copy(subscribers, topic.Subscribers)
	topic.mu.RUnlock()

	for _, sub := range subscribers {
		sub.mu.RLock()
		if !sub.active {
			sub.mu.RUnlock()
			continue
		}

		// Apply filter if present
		if sub.Filter != nil && !sub.Filter(msg) {
			sub.mu.RUnlock()
			continue
		}

		// Capture the handler under the lock, then release before dispatch.
		handler := sub.Handler
		sub.mu.RUnlock()

		// Deliver message in goroutine
		go func(subscription *Subscription, message *Message) {
			defer func() {
				if r := recover(); r != nil {
					// Contain handler panics; only a counter records them.
					// NOTE(review): a panicked message is neither retried nor
					// sent to the DLQ — confirm that is intended.
					mb.metrics.IncrementCounter("handler_panics_total")
				}
			}()

			if err := handler(ctx, message); err != nil {
				mb.metrics.IncrementCounter("handler_errors_total")
				if mb.config.EnableDLQ {
					mb.dlq.AddMessage(message.Topic, message)
				}
			} else {
				mb.metrics.IncrementCounter("messages_consumed_total")
			}
		}(sub, msg)
	}
}
|
||||
|
||||
func (mb *UniversalMessageBus) applyRetentionPolicy(topic *Topic) {
|
||||
policy := topic.Config.RetentionPolicy
|
||||
|
||||
// Remove old messages
|
||||
if policy.MaxAge > 0 {
|
||||
cutoff := time.Now().Add(-policy.MaxAge)
|
||||
filtered := make([]StoredMessage, 0)
|
||||
for _, stored := range topic.Messages {
|
||||
if stored.Stored.After(cutoff) {
|
||||
filtered = append(filtered, stored)
|
||||
}
|
||||
}
|
||||
topic.Messages = filtered
|
||||
}
|
||||
|
||||
// Limit number of messages
|
||||
if policy.MaxMessages > 0 && len(topic.Messages) > policy.MaxMessages {
|
||||
topic.Messages = topic.Messages[len(topic.Messages)-policy.MaxMessages:]
|
||||
}
|
||||
}
|
||||
|
||||
func (mb *UniversalMessageBus) healthCheckLoop() {
|
||||
ticker := time.NewTicker(mb.config.HealthCheckInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
mb.performHealthCheck()
|
||||
case <-mb.ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (mb *UniversalMessageBus) cleanupLoop() {
|
||||
ticker := time.NewTicker(mb.config.CleanupInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
mb.performCleanup()
|
||||
case <-mb.ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (mb *UniversalMessageBus) metricsLoop() {
|
||||
ticker := time.NewTicker(time.Minute)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
mb.updateMetrics()
|
||||
case <-mb.ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (mb *UniversalMessageBus) performHealthCheck() {
|
||||
// Check all transports
|
||||
for transportType, transport := range mb.transports {
|
||||
health := transport.Health()
|
||||
mb.metrics.RecordGauge(fmt.Sprintf("transport_%s_healthy", transportType),
|
||||
map[string]float64{"healthy": 1, "unhealthy": 0, "degraded": 0.5}[health.Status])
|
||||
}
|
||||
}
|
||||
|
||||
func (mb *UniversalMessageBus) performCleanup() {
|
||||
// Clean up processed messages in topics
|
||||
mb.mu.RLock()
|
||||
topics := make([]*Topic, 0, len(mb.topics))
|
||||
for _, topic := range mb.topics {
|
||||
topics = append(topics, topic)
|
||||
}
|
||||
mb.mu.RUnlock()
|
||||
|
||||
for _, topic := range topics {
|
||||
topic.mu.Lock()
|
||||
mb.applyRetentionPolicy(topic)
|
||||
topic.mu.Unlock()
|
||||
}
|
||||
|
||||
// Clean up DLQ
|
||||
mb.dlq.Cleanup(time.Hour * 24) // Clean messages older than 24 hours
|
||||
}
|
||||
|
||||
// updateMetrics recomputes derived gauges (throughput, error rate) from the
// raw counters. Invoked once per minute by metricsLoop.
func (mb *UniversalMessageBus) updateMetrics() {
	// Update throughput metrics
	publishedCount := mb.getMetricInt64("messages_published_total")
	if publishedCount > 0 {
		// Calculate per-second rate (simplified)
		// NOTE(review): this divides the *cumulative* publish count by 60
		// rather than the delta since the previous tick, so the gauge grows
		// without bound instead of reflecting recent throughput. A proper fix
		// needs a stored previous-count snapshot — confirm and address.
		mb.metrics.RecordGauge("throughput_per_second", float64(publishedCount)/60.0)
	}

	// Update error rate
	// NOTE(review): errors are counted over the lifetime totals as well, so
	// this is a lifetime error ratio, not a recent-window rate.
	errorCount := mb.getMetricInt64("send_errors_total")
	totalCount := publishedCount
	if totalCount > 0 {
		errorRate := float64(errorCount) / float64(totalCount)
		mb.metrics.RecordGauge("error_rate", errorRate)
	}
}
|
||||
|
||||
func (mb *UniversalMessageBus) getMetricInt64(key string) int64 {
|
||||
if val, ok := mb.metrics.Get(key).(int64); ok {
|
||||
return val
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func (mb *UniversalMessageBus) getMetricFloat64(key string) float64 {
|
||||
if val, ok := mb.metrics.Get(key).(float64); ok {
|
||||
return val
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func (mb *UniversalMessageBus) getMetricDuration(key string) time.Duration {
|
||||
if val, ok := mb.metrics.Get(key).(time.Duration); ok {
|
||||
return val
|
||||
}
|
||||
return 0
|
||||
}
|
||||
820
orig/pkg/transport/persistence.go
Normal file
820
orig/pkg/transport/persistence.go
Normal file
@@ -0,0 +1,820 @@
|
||||
package transport
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"compress/gzip"
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"crypto/rand"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sort"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// FilePersistenceLayer implements file-based message persistence.
// Messages are stored as length-prefixed JSON records in rotating .dat files
// under basePath/<topic>/.
type FilePersistenceLayer struct {
	basePath    string           // root directory; one subdirectory per topic
	maxFileSize int64            // rotate to a new .dat file once this size would be exceeded
	maxFiles    int              // configured file cap — NOTE(review): not enforced in any visible code path
	compression bool             // gzip payloads before writing
	encryption  EncryptionConfig // optional AES-GCM encryption at rest
	mu          sync.RWMutex     // guards config fields and serializes file operations
}
|
||||
|
||||
// EncryptionConfig configures message encryption at rest.
type EncryptionConfig struct {
	Enabled   bool   // when false, data passes through unmodified
	Algorithm string // informational label — NOTE(review): the code always uses AES-GCM regardless
	Key       []byte // AES key; must be 16/24/32 bytes for aes.NewCipher to accept it
}
|
||||
|
||||
// PersistedMessage represents a message stored on disk: the on-disk JSON
// envelope wrapping a Message with storage metadata.
type PersistedMessage struct {
	ID        string                 `json:"id"`      // mirrors Message.ID for lookup
	Topic     string                 `json:"topic"`   // mirrors Message.Topic
	Message   *Message               `json:"message"` // the payload itself
	Stored    time.Time              `json:"stored"`  // when this envelope was written
	Metadata  map[string]interface{} `json:"metadata"`
	Encrypted bool                   `json:"encrypted"` // whether the record was encrypted at rest
}
|
||||
|
||||
// PersistenceMetrics tracks persistence layer statistics.
// NOTE(review): only StorageSize and FileCount are populated by the visible
// GetMetrics implementation; the counters are not incremented anywhere in view.
type PersistenceMetrics struct {
	MessagesStored    int64
	MessagesRetrieved int64
	MessagesDeleted   int64
	StorageSize       int64 // total bytes across all data files
	FileCount         int   // number of files under basePath
	LastCleanup       time.Time
	Errors            int64
}
|
||||
|
||||
// NewFilePersistenceLayer creates a new file-based persistence layer
|
||||
func NewFilePersistenceLayer(basePath string) *FilePersistenceLayer {
|
||||
return &FilePersistenceLayer{
|
||||
basePath: basePath,
|
||||
maxFileSize: 100 * 1024 * 1024, // 100MB default
|
||||
maxFiles: 1000,
|
||||
compression: false,
|
||||
}
|
||||
}
|
||||
|
||||
// SetMaxFileSize configures the maximum file size
|
||||
func (fpl *FilePersistenceLayer) SetMaxFileSize(size int64) {
|
||||
fpl.mu.Lock()
|
||||
defer fpl.mu.Unlock()
|
||||
fpl.maxFileSize = size
|
||||
}
|
||||
|
||||
// SetMaxFiles configures the maximum number of files
|
||||
func (fpl *FilePersistenceLayer) SetMaxFiles(count int) {
|
||||
fpl.mu.Lock()
|
||||
defer fpl.mu.Unlock()
|
||||
fpl.maxFiles = count
|
||||
}
|
||||
|
||||
// EnableCompression enables/disables compression
|
||||
func (fpl *FilePersistenceLayer) EnableCompression(enabled bool) {
|
||||
fpl.mu.Lock()
|
||||
defer fpl.mu.Unlock()
|
||||
fpl.compression = enabled
|
||||
}
|
||||
|
||||
// SetEncryption configures encryption settings
|
||||
func (fpl *FilePersistenceLayer) SetEncryption(config EncryptionConfig) {
|
||||
fpl.mu.Lock()
|
||||
defer fpl.mu.Unlock()
|
||||
fpl.encryption = config
|
||||
}
|
||||
|
||||
// Store persists a message to disk
|
||||
func (fpl *FilePersistenceLayer) Store(msg *Message) error {
|
||||
fpl.mu.Lock()
|
||||
defer fpl.mu.Unlock()
|
||||
|
||||
// Create directory if it doesn't exist
|
||||
topicDir := filepath.Join(fpl.basePath, msg.Topic)
|
||||
if err := os.MkdirAll(topicDir, 0750); err != nil {
|
||||
return fmt.Errorf("failed to create topic directory: %w", err)
|
||||
}
|
||||
|
||||
// Create persisted message
|
||||
persistedMsg := &PersistedMessage{
|
||||
ID: msg.ID,
|
||||
Topic: msg.Topic,
|
||||
Message: msg,
|
||||
Stored: time.Now(),
|
||||
Metadata: make(map[string]interface{}),
|
||||
}
|
||||
|
||||
// Serialize message
|
||||
data, err := json.Marshal(persistedMsg)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal message: %w", err)
|
||||
}
|
||||
|
||||
// Apply encryption if enabled
|
||||
if fpl.encryption.Enabled {
|
||||
encryptedData, err := fpl.encrypt(data)
|
||||
if err != nil {
|
||||
return fmt.Errorf("encryption failed: %w", err)
|
||||
}
|
||||
data = encryptedData
|
||||
persistedMsg.Encrypted = true
|
||||
}
|
||||
|
||||
// Apply compression if enabled
|
||||
if fpl.compression {
|
||||
compressedData, err := fpl.compress(data)
|
||||
if err != nil {
|
||||
return fmt.Errorf("compression failed: %w", err)
|
||||
}
|
||||
data = compressedData
|
||||
}
|
||||
|
||||
// Find appropriate file to write to
|
||||
filename, err := fpl.getWritableFile(topicDir, len(data))
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get writable file: %w", err)
|
||||
}
|
||||
|
||||
// Write to file
|
||||
file, err := os.OpenFile(filename, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0600)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to open file: %w", err)
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
// Write length prefix and data
|
||||
lengthPrefix := fmt.Sprintf("%d\n", len(data))
|
||||
if _, err := file.WriteString(lengthPrefix); err != nil {
|
||||
return fmt.Errorf("failed to write length prefix: %w", err)
|
||||
}
|
||||
if _, err := file.Write(data); err != nil {
|
||||
return fmt.Errorf("failed to write data: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Retrieve loads a message from disk by ID
|
||||
func (fpl *FilePersistenceLayer) Retrieve(id string) (*Message, error) {
|
||||
fpl.mu.RLock()
|
||||
defer fpl.mu.RUnlock()
|
||||
|
||||
// Search all topic directories
|
||||
topicDirs, err := fpl.getTopicDirectories()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get topic directories: %w", err)
|
||||
}
|
||||
|
||||
for _, topicDir := range topicDirs {
|
||||
files, err := fpl.getTopicFiles(topicDir)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
for _, file := range files {
|
||||
msg, err := fpl.findMessageInFile(file, id)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
if msg != nil {
|
||||
return msg, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("message not found: %s", id)
|
||||
}
|
||||
|
||||
// Delete removes a message from disk by ID
|
||||
func (fpl *FilePersistenceLayer) Delete(id string) error {
|
||||
fpl.mu.Lock()
|
||||
defer fpl.mu.Unlock()
|
||||
|
||||
// Search all topic directories to find the message
|
||||
topicDirs, err := fpl.getTopicDirectories()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get topic directories: %w", err)
|
||||
}
|
||||
|
||||
for _, topicDir := range topicDirs {
|
||||
files, err := fpl.getTopicFiles(topicDir)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
for _, file := range files {
|
||||
if err := fpl.deleteMessageFromFile(file, id); err == nil {
|
||||
return nil // Successfully deleted
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return fmt.Errorf("message not found: %s", id)
|
||||
}
|
||||
|
||||
// List returns messages for a topic with optional limit
|
||||
func (fpl *FilePersistenceLayer) List(topic string, limit int) ([]*Message, error) {
|
||||
fpl.mu.RLock()
|
||||
defer fpl.mu.RUnlock()
|
||||
|
||||
topicDir := filepath.Join(fpl.basePath, topic)
|
||||
if _, err := os.Stat(topicDir); os.IsNotExist(err) {
|
||||
return []*Message{}, nil
|
||||
}
|
||||
|
||||
files, err := fpl.getTopicFiles(topicDir)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get topic files: %w", err)
|
||||
}
|
||||
|
||||
var messages []*Message
|
||||
count := 0
|
||||
|
||||
// Read files in chronological order (newest first)
|
||||
sort.Slice(files, func(i, j int) bool {
|
||||
infoI, _ := os.Stat(files[i])
|
||||
infoJ, _ := os.Stat(files[j])
|
||||
return infoI.ModTime().After(infoJ.ModTime())
|
||||
})
|
||||
|
||||
for _, file := range files {
|
||||
fileMessages, err := fpl.readMessagesFromFile(file)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
for _, msg := range fileMessages {
|
||||
messages = append(messages, msg)
|
||||
count++
|
||||
if limit > 0 && count >= limit {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if limit > 0 && count >= limit {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return messages, nil
|
||||
}
|
||||
|
||||
// Cleanup removes messages older than maxAge
|
||||
func (fpl *FilePersistenceLayer) Cleanup(maxAge time.Duration) error {
|
||||
fpl.mu.Lock()
|
||||
defer fpl.mu.Unlock()
|
||||
|
||||
cutoff := time.Now().Add(-maxAge)
|
||||
|
||||
topicDirs, err := fpl.getTopicDirectories()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get topic directories: %w", err)
|
||||
}
|
||||
|
||||
for _, topicDir := range topicDirs {
|
||||
files, err := fpl.getTopicFiles(topicDir)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
for _, file := range files {
|
||||
// Check file modification time
|
||||
info, err := os.Stat(file)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
if info.ModTime().Before(cutoff) {
|
||||
os.Remove(file)
|
||||
}
|
||||
}
|
||||
|
||||
// Remove empty topic directories
|
||||
if isEmpty, _ := fpl.isDirectoryEmpty(topicDir); isEmpty {
|
||||
os.Remove(topicDir)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetMetrics returns persistence layer metrics
|
||||
func (fpl *FilePersistenceLayer) GetMetrics() (PersistenceMetrics, error) {
|
||||
fpl.mu.RLock()
|
||||
defer fpl.mu.RUnlock()
|
||||
|
||||
metrics := PersistenceMetrics{}
|
||||
|
||||
// Calculate storage size and file count
|
||||
err := filepath.Walk(fpl.basePath, func(path string, info os.FileInfo, err error) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !info.IsDir() {
|
||||
metrics.FileCount++
|
||||
metrics.StorageSize += info.Size()
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
return metrics, err
|
||||
}
|
||||
|
||||
// Private helper methods
|
||||
|
||||
func (fpl *FilePersistenceLayer) getWritableFile(topicDir string, dataSize int) (string, error) {
|
||||
files, err := fpl.getTopicFiles(topicDir)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// Find a file with enough space
|
||||
for _, file := range files {
|
||||
info, err := os.Stat(file)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
if info.Size()+int64(dataSize) <= fpl.maxFileSize {
|
||||
return file, nil
|
||||
}
|
||||
}
|
||||
|
||||
// Create new file
|
||||
timestamp := time.Now().Format("20060102_150405")
|
||||
filename := filepath.Join(topicDir, fmt.Sprintf("messages_%s.dat", timestamp))
|
||||
return filename, nil
|
||||
}
|
||||
|
||||
func (fpl *FilePersistenceLayer) getTopicDirectories() ([]string, error) {
|
||||
entries, err := os.ReadDir(fpl.basePath)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var dirs []string
|
||||
for _, entry := range entries {
|
||||
if entry.IsDir() {
|
||||
dirs = append(dirs, filepath.Join(fpl.basePath, entry.Name()))
|
||||
}
|
||||
}
|
||||
|
||||
return dirs, nil
|
||||
}
|
||||
|
||||
func (fpl *FilePersistenceLayer) getTopicFiles(topicDir string) ([]string, error) {
|
||||
entries, err := os.ReadDir(topicDir)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var files []string
|
||||
for _, entry := range entries {
|
||||
if !entry.IsDir() && filepath.Ext(entry.Name()) == ".dat" {
|
||||
files = append(files, filepath.Join(topicDir, entry.Name()))
|
||||
}
|
||||
}
|
||||
|
||||
return files, nil
|
||||
}
|
||||
|
||||
func (fpl *FilePersistenceLayer) findMessageInFile(filename, messageID string) (*Message, error) {
|
||||
file, err := os.Open(filename)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
data, err := io.ReadAll(file)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Parse messages from file data
|
||||
messages, err := fpl.parseFileData(data)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, msg := range messages {
|
||||
if msg.ID == messageID {
|
||||
return msg, nil
|
||||
}
|
||||
}
|
||||
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (fpl *FilePersistenceLayer) readMessagesFromFile(filename string) ([]*Message, error) {
|
||||
file, err := os.Open(filename)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
data, err := io.ReadAll(file)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return fpl.parseFileData(data)
|
||||
}
|
||||
|
||||
func (fpl *FilePersistenceLayer) parseFileData(data []byte) ([]*Message, error) {
|
||||
var messages []*Message
|
||||
offset := 0
|
||||
|
||||
for offset < len(data) {
|
||||
// Read length prefix
|
||||
lengthEnd := -1
|
||||
for i := offset; i < len(data); i++ {
|
||||
if data[i] == '\n' {
|
||||
lengthEnd = i
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if lengthEnd == -1 {
|
||||
break // No more complete messages
|
||||
}
|
||||
|
||||
lengthStr := string(data[offset:lengthEnd])
|
||||
var messageLength int
|
||||
if _, err := fmt.Sscanf(lengthStr, "%d", &messageLength); err != nil {
|
||||
break // Invalid length prefix
|
||||
}
|
||||
|
||||
messageStart := lengthEnd + 1
|
||||
messageEnd := messageStart + messageLength
|
||||
|
||||
if messageEnd > len(data) {
|
||||
break // Incomplete message
|
||||
}
|
||||
|
||||
messageData := data[messageStart:messageEnd]
|
||||
|
||||
// Apply decompression if needed
|
||||
if fpl.compression {
|
||||
decompressed, err := fpl.decompress(messageData)
|
||||
if err != nil {
|
||||
offset = messageEnd
|
||||
continue
|
||||
}
|
||||
messageData = decompressed
|
||||
}
|
||||
|
||||
// Apply decryption if needed
|
||||
if fpl.encryption.Enabled {
|
||||
decrypted, err := fpl.decrypt(messageData)
|
||||
if err != nil {
|
||||
offset = messageEnd
|
||||
continue
|
||||
}
|
||||
messageData = decrypted
|
||||
}
|
||||
|
||||
// Parse message
|
||||
var persistedMsg PersistedMessage
|
||||
if err := json.Unmarshal(messageData, &persistedMsg); err != nil {
|
||||
offset = messageEnd
|
||||
continue
|
||||
}
|
||||
|
||||
messages = append(messages, persistedMsg.Message)
|
||||
offset = messageEnd
|
||||
}
|
||||
|
||||
return messages, nil
|
||||
}
|
||||
|
||||
func (fpl *FilePersistenceLayer) isDirectoryEmpty(dir string) (bool, error) {
|
||||
entries, err := os.ReadDir(dir)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return len(entries) == 0, nil
|
||||
}
|
||||
|
||||
func (fpl *FilePersistenceLayer) encrypt(data []byte) ([]byte, error) {
|
||||
if !fpl.encryption.Enabled || len(fpl.encryption.Key) == 0 {
|
||||
return data, nil
|
||||
}
|
||||
|
||||
// Create cipher block
|
||||
block, err := aes.NewCipher(fpl.encryption.Key)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create cipher: %w", err)
|
||||
}
|
||||
|
||||
// Generate random nonce
|
||||
nonce := make([]byte, 12) // GCM standard nonce size
|
||||
if _, err := io.ReadFull(rand.Reader, nonce); err != nil {
|
||||
return nil, fmt.Errorf("failed to generate nonce: %w", err)
|
||||
}
|
||||
|
||||
// Create GCM mode
|
||||
gcm, err := cipher.NewGCM(block)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create GCM: %w", err)
|
||||
}
|
||||
|
||||
// Encrypt and authenticate
|
||||
ciphertext := gcm.Seal(nil, nonce, data, nil)
|
||||
|
||||
// Prepend nonce to ciphertext
|
||||
result := make([]byte, len(nonce)+len(ciphertext))
|
||||
copy(result, nonce)
|
||||
copy(result[len(nonce):], ciphertext)
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (fpl *FilePersistenceLayer) decrypt(data []byte) ([]byte, error) {
|
||||
if !fpl.encryption.Enabled || len(fpl.encryption.Key) == 0 {
|
||||
return data, nil
|
||||
}
|
||||
|
||||
if len(data) < 12 {
|
||||
return nil, fmt.Errorf("encrypted data too short")
|
||||
}
|
||||
|
||||
// Create cipher block
|
||||
block, err := aes.NewCipher(fpl.encryption.Key)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create cipher: %w", err)
|
||||
}
|
||||
|
||||
// Create GCM mode
|
||||
gcm, err := cipher.NewGCM(block)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create GCM: %w", err)
|
||||
}
|
||||
|
||||
// Extract nonce and ciphertext
|
||||
nonce := data[:12]
|
||||
ciphertext := data[12:]
|
||||
|
||||
// Decrypt and verify
|
||||
plaintext, err := gcm.Open(nil, nonce, ciphertext, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("decryption failed: %w", err)
|
||||
}
|
||||
|
||||
return plaintext, nil
|
||||
}
|
||||
|
||||
func (fpl *FilePersistenceLayer) compress(data []byte) ([]byte, error) {
|
||||
var buf bytes.Buffer
|
||||
gzWriter := gzip.NewWriter(&buf)
|
||||
|
||||
if _, err := gzWriter.Write(data); err != nil {
|
||||
gzWriter.Close()
|
||||
return nil, fmt.Errorf("compression failed: %w", err)
|
||||
}
|
||||
|
||||
if err := gzWriter.Close(); err != nil {
|
||||
return nil, fmt.Errorf("failed to close gzip writer: %w", err)
|
||||
}
|
||||
|
||||
return buf.Bytes(), nil
|
||||
}
|
||||
|
||||
func (fpl *FilePersistenceLayer) decompress(data []byte) ([]byte, error) {
|
||||
buf := bytes.NewReader(data)
|
||||
gzReader, err := gzip.NewReader(buf)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create gzip reader: %w", err)
|
||||
}
|
||||
defer gzReader.Close()
|
||||
|
||||
decompressed, err := io.ReadAll(gzReader)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("decompression failed: %w", err)
|
||||
}
|
||||
|
||||
return decompressed, nil
|
||||
}
|
||||
|
||||
// deleteMessageFromFile removes a specific message from a file
|
||||
func (fpl *FilePersistenceLayer) deleteMessageFromFile(filename, messageID string) error {
|
||||
// Read all messages from file
|
||||
messages, err := fpl.readMessagesFromFile(filename)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to read messages from file: %w", err)
|
||||
}
|
||||
|
||||
// Check if message exists in this file
|
||||
found := false
|
||||
var filteredMessages []*Message
|
||||
for _, msg := range messages {
|
||||
if msg.ID != messageID {
|
||||
filteredMessages = append(filteredMessages, msg)
|
||||
} else {
|
||||
found = true
|
||||
}
|
||||
}
|
||||
|
||||
if !found {
|
||||
return fmt.Errorf("message not found in file")
|
||||
}
|
||||
|
||||
// If no messages remain, delete the file
|
||||
if len(filteredMessages) == 0 {
|
||||
return os.Remove(filename)
|
||||
}
|
||||
|
||||
// Rewrite file with remaining messages
|
||||
return fpl.rewriteFileWithMessages(filename, filteredMessages)
|
||||
}
|
||||
|
||||
// rewriteFileWithMessages rewrites a file with the given messages
|
||||
func (fpl *FilePersistenceLayer) rewriteFileWithMessages(filename string, messages []*Message) error {
|
||||
// Create temporary file
|
||||
tempFile := filename + ".tmp"
|
||||
file, err := os.Create(tempFile)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create temp file: %w", err)
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
// Write each message to temp file
|
||||
for _, msg := range messages {
|
||||
// Create persisted message
|
||||
persistedMsg := &PersistedMessage{
|
||||
ID: msg.ID,
|
||||
Topic: msg.Topic,
|
||||
Message: msg,
|
||||
Stored: time.Now(),
|
||||
Metadata: make(map[string]interface{}),
|
||||
}
|
||||
|
||||
// Serialize message
|
||||
data, err := json.Marshal(persistedMsg)
|
||||
if err != nil {
|
||||
os.Remove(tempFile)
|
||||
return fmt.Errorf("failed to marshal message: %w", err)
|
||||
}
|
||||
|
||||
// Apply encryption if enabled
|
||||
if fpl.encryption.Enabled {
|
||||
encryptedData, err := fpl.encrypt(data)
|
||||
if err != nil {
|
||||
os.Remove(tempFile)
|
||||
return fmt.Errorf("encryption failed: %w", err)
|
||||
}
|
||||
data = encryptedData
|
||||
persistedMsg.Encrypted = true
|
||||
}
|
||||
|
||||
// Apply compression if enabled
|
||||
if fpl.compression {
|
||||
compressedData, err := fpl.compress(data)
|
||||
if err != nil {
|
||||
os.Remove(tempFile)
|
||||
return fmt.Errorf("compression failed: %w", err)
|
||||
}
|
||||
data = compressedData
|
||||
}
|
||||
|
||||
// Write length prefix and data
|
||||
lengthPrefix := fmt.Sprintf("%d\n", len(data))
|
||||
if _, err := file.WriteString(lengthPrefix); err != nil {
|
||||
os.Remove(tempFile)
|
||||
return fmt.Errorf("failed to write length prefix: %w", err)
|
||||
}
|
||||
if _, err := file.Write(data); err != nil {
|
||||
os.Remove(tempFile)
|
||||
return fmt.Errorf("failed to write data: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Close temp file before rename
|
||||
file.Close()
|
||||
|
||||
// Replace original file with temp file
|
||||
return os.Rename(tempFile, filename)
|
||||
}
|
||||
|
||||
// InMemoryPersistenceLayer implements in-memory persistence for
// testing/development. Nothing is written to disk; contents are lost on
// process exit.
type InMemoryPersistenceLayer struct {
	messages map[string]*Message // message ID -> message
	topics   map[string][]string // topic name -> insertion-ordered message IDs
	mu       sync.RWMutex        // guards both maps
}
|
||||
|
||||
// NewInMemoryPersistenceLayer creates a new in-memory persistence layer
|
||||
func NewInMemoryPersistenceLayer() *InMemoryPersistenceLayer {
|
||||
return &InMemoryPersistenceLayer{
|
||||
messages: make(map[string]*Message),
|
||||
topics: make(map[string][]string),
|
||||
}
|
||||
}
|
||||
|
||||
// Store stores a message in memory
|
||||
func (impl *InMemoryPersistenceLayer) Store(msg *Message) error {
|
||||
impl.mu.Lock()
|
||||
defer impl.mu.Unlock()
|
||||
|
||||
impl.messages[msg.ID] = msg
|
||||
|
||||
if _, exists := impl.topics[msg.Topic]; !exists {
|
||||
impl.topics[msg.Topic] = make([]string, 0)
|
||||
}
|
||||
impl.topics[msg.Topic] = append(impl.topics[msg.Topic], msg.ID)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Retrieve retrieves a message from memory
|
||||
func (impl *InMemoryPersistenceLayer) Retrieve(id string) (*Message, error) {
|
||||
impl.mu.RLock()
|
||||
defer impl.mu.RUnlock()
|
||||
|
||||
msg, exists := impl.messages[id]
|
||||
if !exists {
|
||||
return nil, fmt.Errorf("message not found: %s", id)
|
||||
}
|
||||
|
||||
return msg, nil
|
||||
}
|
||||
|
||||
// Delete removes a message from memory
|
||||
func (impl *InMemoryPersistenceLayer) Delete(id string) error {
|
||||
impl.mu.Lock()
|
||||
defer impl.mu.Unlock()
|
||||
|
||||
msg, exists := impl.messages[id]
|
||||
if !exists {
|
||||
return fmt.Errorf("message not found: %s", id)
|
||||
}
|
||||
|
||||
delete(impl.messages, id)
|
||||
|
||||
// Remove from topic index
|
||||
if messageIDs, exists := impl.topics[msg.Topic]; exists {
|
||||
for i, msgID := range messageIDs {
|
||||
if msgID == id {
|
||||
impl.topics[msg.Topic] = append(messageIDs[:i], messageIDs[i+1:]...)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// List returns messages for a topic
|
||||
func (impl *InMemoryPersistenceLayer) List(topic string, limit int) ([]*Message, error) {
|
||||
impl.mu.RLock()
|
||||
defer impl.mu.RUnlock()
|
||||
|
||||
messageIDs, exists := impl.topics[topic]
|
||||
if !exists {
|
||||
return []*Message{}, nil
|
||||
}
|
||||
|
||||
var messages []*Message
|
||||
count := 0
|
||||
|
||||
for _, msgID := range messageIDs {
|
||||
if limit > 0 && count >= limit {
|
||||
break
|
||||
}
|
||||
|
||||
if msg, exists := impl.messages[msgID]; exists {
|
||||
messages = append(messages, msg)
|
||||
count++
|
||||
}
|
||||
}
|
||||
|
||||
return messages, nil
|
||||
}
|
||||
|
||||
// Cleanup removes messages older than maxAge
|
||||
func (impl *InMemoryPersistenceLayer) Cleanup(maxAge time.Duration) error {
|
||||
impl.mu.Lock()
|
||||
defer impl.mu.Unlock()
|
||||
|
||||
cutoff := time.Now().Add(-maxAge)
|
||||
var toDelete []string
|
||||
|
||||
for id, msg := range impl.messages {
|
||||
if msg.Timestamp.Before(cutoff) {
|
||||
toDelete = append(toDelete, id)
|
||||
}
|
||||
}
|
||||
|
||||
for _, id := range toDelete {
|
||||
impl.Delete(id)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
560
orig/pkg/transport/provider_manager.go
Normal file
560
orig/pkg/transport/provider_manager.go
Normal file
@@ -0,0 +1,560 @@
|
||||
package transport
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"os"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/ethereum/go-ethereum/ethclient"
|
||||
"github.com/ethereum/go-ethereum/rpc"
|
||||
"golang.org/x/time/rate"
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
// ProviderConfig represents a single RPC provider configuration as loaded
// from the providers YAML file. At least one of HTTPEndpoint/WSEndpoint
// must be set (enforced by validateConfig).
type ProviderConfig struct {
	Name         string            `yaml:"name"` // unique provider name (required)
	Type         string            `yaml:"type"`
	HTTPEndpoint string            `yaml:"http_endpoint"` // HTTP RPC URL; optional if WSEndpoint set
	WSEndpoint   string            `yaml:"ws_endpoint"`   // WebSocket RPC URL; optional if HTTPEndpoint set
	Priority     int               `yaml:"priority"`      // lower value = higher priority (see getPriorityProvider)
	RateLimit    RateLimitConfig   `yaml:"rate_limit"`
	Features     []string          `yaml:"features"`
	HealthCheck  HealthCheckConfig `yaml:"health_check"`
	AnvilConfig  *AnvilConfig      `yaml:"anvil_config,omitempty"` // For Anvil fork providers
}

// AnvilConfig defines Anvil-specific configuration for local fork providers.
type AnvilConfig struct {
	ForkURL         string `yaml:"fork_url"` // upstream RPC endpoint to fork from
	ChainID         int    `yaml:"chain_id"`
	Port            int    `yaml:"port"`
	BlockTime       int    `yaml:"block_time"` // presumably seconds per block — TODO confirm
	AutoImpersonate bool   `yaml:"auto_impersonate"`
	StateInterval   int    `yaml:"state_interval"`
}
|
||||
|
||||
// RateLimitConfig defines rate limiting parameters for a provider. The
// values feed golang.org/x/time/rate (see createProvider); RequestsPerSecond
// must be positive (enforced by validateConfig).
type RateLimitConfig struct {
	RequestsPerSecond int           `yaml:"requests_per_second"`
	Burst             int           `yaml:"burst"`
	Timeout           time.Duration `yaml:"timeout"` // also used as the HTTP client timeout
	RetryDelay        time.Duration `yaml:"retry_delay"`
	MaxRetries        int           `yaml:"max_retries"`
}

// HealthCheckConfig defines health check parameters for a provider.
// Timeout bounds each individual check (see performProviderHealthCheck).
type HealthCheckConfig struct {
	Enabled  bool          `yaml:"enabled"`
	Interval time.Duration `yaml:"interval"`
	Timeout  time.Duration `yaml:"timeout"`
}

// RotationConfig defines the provider rotation strategy used by
// GetHealthyProvider ("round_robin", "weighted", or "priority_based").
type RotationConfig struct {
	Strategy            string        `yaml:"strategy"`
	HealthCheckRequired bool          `yaml:"health_check_required"` // when true, unhealthy providers are skipped
	FallbackEnabled     bool          `yaml:"fallback_enabled"`
	RetryFailedAfter    time.Duration `yaml:"retry_failed_after"`
}
|
||||
|
||||
// ProviderPoolConfig defines configuration for a named provider pool.
type ProviderPoolConfig struct {
	Strategy                 string   `yaml:"strategy"`
	MaxConcurrentConnections int      `yaml:"max_concurrent_connections"`
	HealthCheckInterval      string   `yaml:"health_check_interval"`
	FailoverEnabled          bool     `yaml:"failover_enabled"`
	Providers                []string `yaml:"providers"` // provider names belonging to this pool
}

// ProvidersConfig represents the complete provider configuration file
// (see LoadProvidersConfig).
type ProvidersConfig struct {
	ProviderPools map[string]ProviderPoolConfig `yaml:"provider_pools"`
	Providers     []ProviderConfig              `yaml:"providers"`
	Rotation      RotationConfig                `yaml:"rotation"`
	GlobalLimits  GlobalLimits                  `yaml:"global_limits"`
	Monitoring    MonitoringConfig              `yaml:"monitoring"`
}

// GlobalLimits defines global connection limits and timeouts.
type GlobalLimits struct {
	MaxConcurrentConnections int           `yaml:"max_concurrent_connections"`
	ConnectionTimeout        time.Duration `yaml:"connection_timeout"`
	ReadTimeout              time.Duration `yaml:"read_timeout"`
	WriteTimeout             time.Duration `yaml:"write_timeout"`
	IdleTimeout              time.Duration `yaml:"idle_timeout"`
}

// MonitoringConfig defines monitoring settings. Enabled gates the background
// health-check and metrics loops (see startBackgroundTasks).
type MonitoringConfig struct {
	Enabled                  bool          `yaml:"enabled"`
	MetricsInterval          time.Duration `yaml:"metrics_interval"` // ticker period for collectMetrics
	LogSlowRequests          bool          `yaml:"log_slow_requests"`
	SlowRequestThreshold     time.Duration `yaml:"slow_request_threshold"`
	TrackProviderPerformance bool          `yaml:"track_provider_performance"`
}
|
||||
|
||||
// Provider represents an active RPC provider connection together with its
// health state and usage statistics.
type Provider struct {
	Config          ProviderConfig
	HTTPClient      *ethclient.Client // nil when no HTTP endpoint is configured
	WSClient        *ethclient.Client // nil when the WS dial failed or is unconfigured
	RateLimiter     *rate.Limiter
	HTTPConn        *rpc.Client
	WSConn          *rpc.Client
	IsHealthy       bool      // guarded by mutex; updated by health checks
	LastHealthCheck time.Time // guarded by mutex
	RequestCount    int64     // accessed atomically (Get/IncrementRequestCount)
	ErrorCount      int64     // accessed atomically (Get/IncrementErrorCount)
	AvgResponseTime time.Duration // exponentially smoothed; guarded by mutex
	mutex           sync.RWMutex
}
|
||||
|
||||
// ProviderManager manages multiple RPC providers with rotation and failover.
// Background health-check and metrics goroutines run until Close is called.
type ProviderManager struct {
	providers       []*Provider
	config          ProvidersConfig
	currentProvider int          // round-robin cursor into providers
	mutex           sync.RWMutex // guards providers and currentProvider
	healthTicker    *time.Ticker
	metricsTicker   *time.Ticker
	stopChan        chan struct{} // closed by Close to stop background loops
}
|
||||
|
||||
// NewProviderManager creates a new provider manager from configuration
|
||||
func NewProviderManager(configPath string) (*ProviderManager, error) {
|
||||
// Load configuration
|
||||
config, err := LoadProvidersConfig(configPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to load provider config: %w", err)
|
||||
}
|
||||
|
||||
pm := &ProviderManager{
|
||||
config: config,
|
||||
stopChan: make(chan struct{}),
|
||||
}
|
||||
|
||||
// Initialize providers
|
||||
if err := pm.initializeProviders(); err != nil {
|
||||
return nil, fmt.Errorf("failed to initialize providers: %w", err)
|
||||
}
|
||||
|
||||
// Start health checks and metrics collection
|
||||
pm.startBackgroundTasks()
|
||||
|
||||
return pm, nil
|
||||
}
|
||||
|
||||
// LoadProvidersConfig loads provider configuration from YAML file
|
||||
func LoadProvidersConfig(path string) (ProvidersConfig, error) {
|
||||
var config ProvidersConfig
|
||||
|
||||
// Read the YAML file
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return config, fmt.Errorf("failed to read config file %s: %w", path, err)
|
||||
}
|
||||
|
||||
// Unmarshal the YAML data
|
||||
expanded := os.ExpandEnv(string(data))
|
||||
if strings.Contains(expanded, "${") {
|
||||
return config, fmt.Errorf("unresolved environment variables found in provider config %s", path)
|
||||
}
|
||||
|
||||
if err := yaml.Unmarshal([]byte(expanded), &config); err != nil {
|
||||
return config, fmt.Errorf("failed to parse YAML config: %w", err)
|
||||
}
|
||||
|
||||
// Validate the configuration
|
||||
if err := validateConfig(&config); err != nil {
|
||||
return config, fmt.Errorf("invalid configuration: %w", err)
|
||||
}
|
||||
|
||||
return config, nil
|
||||
}
|
||||
|
||||
// validateConfig validates the provider configuration
|
||||
func validateConfig(config *ProvidersConfig) error {
|
||||
if len(config.Providers) == 0 {
|
||||
return fmt.Errorf("no providers configured")
|
||||
}
|
||||
|
||||
for i, provider := range config.Providers {
|
||||
if provider.Name == "" {
|
||||
return fmt.Errorf("provider %d has no name", i)
|
||||
}
|
||||
if provider.HTTPEndpoint == "" && provider.WSEndpoint == "" {
|
||||
return fmt.Errorf("provider %s has no endpoints", provider.Name)
|
||||
}
|
||||
if provider.RateLimit.RequestsPerSecond <= 0 {
|
||||
return fmt.Errorf("provider %s has invalid rate limit", provider.Name)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// initializeProviders sets up all configured providers
|
||||
func (pm *ProviderManager) initializeProviders() error {
|
||||
pm.providers = make([]*Provider, 0, len(pm.config.Providers))
|
||||
|
||||
for _, providerConfig := range pm.config.Providers {
|
||||
provider, err := createProvider(providerConfig)
|
||||
if err != nil {
|
||||
// Log error but continue with other providers
|
||||
continue
|
||||
}
|
||||
pm.providers = append(pm.providers, provider)
|
||||
}
|
||||
|
||||
if len(pm.providers) == 0 {
|
||||
return fmt.Errorf("no providers successfully initialized")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// createProvider creates a new provider instance (shared utility function).
//
// It builds the token-bucket rate limiter from the config, dials the HTTP
// endpoint (a failure here is fatal), and dials the WebSocket endpoint
// (a failure here is non-fatal: a warning is printed and HTTP-only
// operation continues).
func createProvider(config ProviderConfig) (*Provider, error) {
	// Create rate limiter from requests-per-second and burst settings.
	rateLimiter := rate.NewLimiter(
		rate.Limit(config.RateLimit.RequestsPerSecond),
		config.RateLimit.Burst,
	)

	provider := &Provider{
		Config:      config,
		RateLimiter: rateLimiter,
		IsHealthy:   true, // Assume healthy until proven otherwise
	}

	// Initialize HTTP connection
	if config.HTTPEndpoint != "" {
		httpClient := &http.Client{
			Timeout: config.RateLimit.Timeout, // Use config timeout
		}

		rpcClient, err := rpc.DialHTTPWithClient(config.HTTPEndpoint, httpClient)
		if err != nil {
			return nil, fmt.Errorf("failed to connect to HTTP endpoint %s: %w", config.HTTPEndpoint, err)
		}

		provider.HTTPConn = rpcClient
		provider.HTTPClient = ethclient.NewClient(rpcClient)
	}

	// Initialize WebSocket connection
	if config.WSEndpoint != "" {
		wsClient, err := rpc.DialWebsocket(context.Background(), config.WSEndpoint, "")
		if err != nil {
			// Don't fail if WS connection fails, HTTP might still work
			fmt.Printf("Warning: failed to connect to WebSocket endpoint %s: %v\n", config.WSEndpoint, err)
		} else {
			provider.WSConn = wsClient
			provider.WSClient = ethclient.NewClient(wsClient)
		}
	}

	return provider, nil
}
|
||||
|
||||
// GetHealthyProvider returns the next healthy provider based on rotation strategy
|
||||
func (pm *ProviderManager) GetHealthyProvider() (*Provider, error) {
|
||||
pm.mutex.RLock()
|
||||
defer pm.mutex.RUnlock()
|
||||
|
||||
if len(pm.providers) == 0 {
|
||||
return nil, fmt.Errorf("no providers available")
|
||||
}
|
||||
|
||||
switch pm.config.Rotation.Strategy {
|
||||
case "round_robin":
|
||||
return pm.getNextRoundRobin()
|
||||
case "weighted":
|
||||
return pm.getWeightedProvider()
|
||||
case "priority_based":
|
||||
return pm.getPriorityProvider()
|
||||
default:
|
||||
return pm.getNextRoundRobin()
|
||||
}
|
||||
}
|
||||
|
||||
// getNextRoundRobin implements round-robin provider selection: it scans at
// most len(pm.providers) entries starting at the cursor, returns the first
// usable one, and advances the cursor past it.
//
// NOTE(review): this method writes pm.currentProvider, so callers must hold
// the write lock on pm.mutex — confirm that all callers do.
func (pm *ProviderManager) getNextRoundRobin() (*Provider, error) {
	startIndex := pm.currentProvider

	for i := 0; i < len(pm.providers); i++ {
		index := (startIndex + i) % len(pm.providers)
		provider := pm.providers[index]

		if pm.isProviderUsable(provider) {
			// Advance the cursor so the next call starts after this provider.
			pm.currentProvider = (index + 1) % len(pm.providers)
			return provider, nil
		}
	}

	return nil, fmt.Errorf("no healthy providers available")
}
|
||||
|
||||
// getPriorityProvider returns the highest priority healthy provider
|
||||
func (pm *ProviderManager) getPriorityProvider() (*Provider, error) {
|
||||
var bestProvider *Provider
|
||||
highestPriority := int(^uint(0) >> 1) // Max int
|
||||
|
||||
for _, provider := range pm.providers {
|
||||
if pm.isProviderUsable(provider) && provider.Config.Priority < highestPriority {
|
||||
bestProvider = provider
|
||||
highestPriority = provider.Config.Priority
|
||||
}
|
||||
}
|
||||
|
||||
if bestProvider == nil {
|
||||
return nil, fmt.Errorf("no healthy providers available")
|
||||
}
|
||||
|
||||
return bestProvider, nil
|
||||
}
|
||||
|
||||
// getWeightedProvider implements weighted provider selection based on performance
|
||||
func (pm *ProviderManager) getWeightedProvider() (*Provider, error) {
|
||||
// For now, fallback to priority-based selection
|
||||
// In a full implementation, this would consider response times and success rates
|
||||
return pm.getPriorityProvider()
|
||||
}
|
||||
|
||||
// isProviderUsable checks if a provider is healthy and within rate limits.
//
// NOTE: RateLimiter.Allow() consumes a token when it succeeds, so every
// usability check spends rate-limit budget even if the provider is not
// ultimately chosen by the caller.
func (pm *ProviderManager) isProviderUsable(provider *Provider) bool {
	provider.mutex.RLock()
	defer provider.mutex.RUnlock()

	// Check health status (only enforced when the rotation config requires it).
	if pm.config.Rotation.HealthCheckRequired && !provider.IsHealthy {
		return false
	}

	// Check rate limit
	if !provider.RateLimiter.Allow() {
		return false
	}

	return true
}
|
||||
|
||||
// GetHTTPClient returns an HTTP client for the current provider
|
||||
func (pm *ProviderManager) GetHTTPClient() (*ethclient.Client, error) {
|
||||
provider, err := pm.GetHealthyProvider()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if provider.HTTPClient == nil {
|
||||
return nil, fmt.Errorf("provider %s has no HTTP client", provider.Config.Name)
|
||||
}
|
||||
|
||||
return provider.HTTPClient, nil
|
||||
}
|
||||
|
||||
// GetWSClient returns a WebSocket client for the current provider
|
||||
func (pm *ProviderManager) GetWSClient() (*ethclient.Client, error) {
|
||||
provider, err := pm.GetHealthyProvider()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if provider.WSClient == nil {
|
||||
return nil, fmt.Errorf("provider %s has no WebSocket client", provider.Config.Name)
|
||||
}
|
||||
|
||||
return provider.WSClient, nil
|
||||
}
|
||||
|
||||
// GetRPCClient returns a raw RPC client for advanced operations
|
||||
func (pm *ProviderManager) GetRPCClient(preferWS bool) (*rpc.Client, error) {
|
||||
provider, err := pm.GetHealthyProvider()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if preferWS && provider.WSConn != nil {
|
||||
return provider.WSConn, nil
|
||||
}
|
||||
|
||||
if provider.HTTPConn != nil {
|
||||
return provider.HTTPConn, nil
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("provider %s has no available RPC client", provider.Config.Name)
|
||||
}
|
||||
|
||||
// startBackgroundTasks starts health checking and metrics collection
|
||||
func (pm *ProviderManager) startBackgroundTasks() {
|
||||
// Start health checks
|
||||
if pm.config.Monitoring.Enabled {
|
||||
pm.healthTicker = time.NewTicker(time.Minute) // Default 1 minute
|
||||
go pm.healthCheckLoop()
|
||||
|
||||
pm.metricsTicker = time.NewTicker(pm.config.Monitoring.MetricsInterval)
|
||||
go pm.metricsLoop()
|
||||
}
|
||||
}
|
||||
|
||||
// healthCheckLoop periodically checks provider health
|
||||
func (pm *ProviderManager) healthCheckLoop() {
|
||||
for {
|
||||
select {
|
||||
case <-pm.healthTicker.C:
|
||||
pm.performHealthChecks()
|
||||
case <-pm.stopChan:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// metricsLoop periodically collects provider metrics
|
||||
func (pm *ProviderManager) metricsLoop() {
|
||||
for {
|
||||
select {
|
||||
case <-pm.metricsTicker.C:
|
||||
pm.collectMetrics()
|
||||
case <-pm.stopChan:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// performHealthChecks kicks off a health check for every provider, each in
// its own goroutine. Results are recorded asynchronously by
// performProviderHealthCheck; each goroutine is bounded by the provider's
// configured health-check timeout.
func (pm *ProviderManager) performHealthChecks() {
	for _, provider := range pm.providers {
		go pm.checkProviderHealth(provider)
	}
}
|
||||
|
||||
// checkProviderHealth performs a health check on a single provider
|
||||
func (pm *ProviderManager) checkProviderHealth(provider *Provider) {
|
||||
pm.performProviderHealthCheck(provider, func(ctx context.Context, provider *Provider) error {
|
||||
// Try to get latest block number as health check
|
||||
if provider.HTTPClient != nil {
|
||||
_, err := provider.HTTPClient.BlockNumber(ctx)
|
||||
return err
|
||||
} else if provider.WSClient != nil {
|
||||
_, err := provider.WSClient.BlockNumber(ctx)
|
||||
return err
|
||||
}
|
||||
return fmt.Errorf("no client available for health check")
|
||||
})
|
||||
}
|
||||
|
||||
// performProviderHealthCheck runs healthChecker against the provider with
// proper synchronization and records the outcome.
//
// Counters (RequestCount/ErrorCount) are updated atomically because they
// are read concurrently via the atomic getters; IsHealthy, LastHealthCheck
// and AvgResponseTime are updated under the provider's write lock. The
// response time is folded into an exponential moving average that weights
// the newest sample at 20%.
func (pm *ProviderManager) performProviderHealthCheck(provider *Provider, healthChecker func(context.Context, *Provider) error) {
	// Bound the check by the provider's configured health-check timeout.
	ctx, cancel := context.WithTimeout(context.Background(), provider.Config.HealthCheck.Timeout)
	defer cancel()

	start := time.Now()
	err := healthChecker(ctx, provider)
	duration := time.Since(start)

	// Atomic: RequestCount is read lock-free elsewhere.
	atomic.AddInt64(&provider.RequestCount, 1)

	provider.mutex.Lock()
	defer provider.mutex.Unlock()

	provider.LastHealthCheck = time.Now()

	if err != nil {
		// Atomic: ErrorCount is read lock-free elsewhere.
		atomic.AddInt64(&provider.ErrorCount, 1)
		provider.IsHealthy = false
	} else {
		provider.IsHealthy = true
	}

	// Update average response time
	// Simple moving average calculation
	if provider.AvgResponseTime == 0 {
		provider.AvgResponseTime = duration
	} else {
		// Weight new measurement at 20% to smooth out spikes
		provider.AvgResponseTime = time.Duration(
			float64(provider.AvgResponseTime)*0.8 + float64(duration)*0.2,
		)
	}
}
|
||||
|
||||
// IncrementRequestCount atomically increments the provider's request counter.
func (p *Provider) IncrementRequestCount() {
	atomic.AddInt64(&p.RequestCount, 1)
}

// IncrementErrorCount atomically increments the provider's error counter.
func (p *Provider) IncrementErrorCount() {
	atomic.AddInt64(&p.ErrorCount, 1)
}

// GetRequestCount atomically reads the provider's request counter.
func (p *Provider) GetRequestCount() int64 {
	return atomic.LoadInt64(&p.RequestCount)
}

// GetErrorCount atomically reads the provider's error counter.
func (p *Provider) GetErrorCount() int64 {
	return atomic.LoadInt64(&p.ErrorCount)
}
|
||||
|
||||
// collectMetrics collects performance metrics.
// Currently a stub: nothing is gathered or reported yet; a full
// implementation would collect and publish per-provider statistics.
func (pm *ProviderManager) collectMetrics() {
}
|
||||
|
||||
// Close shuts down the provider manager
|
||||
func (pm *ProviderManager) Close() error {
|
||||
close(pm.stopChan)
|
||||
|
||||
if pm.healthTicker != nil {
|
||||
pm.healthTicker.Stop()
|
||||
}
|
||||
if pm.metricsTicker != nil {
|
||||
pm.metricsTicker.Stop()
|
||||
}
|
||||
|
||||
// Close all connections
|
||||
for _, provider := range pm.providers {
|
||||
if provider.HTTPConn != nil {
|
||||
provider.HTTPConn.Close()
|
||||
}
|
||||
if provider.WSConn != nil {
|
||||
provider.WSConn.Close()
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetProviderStats returns current provider statistics
|
||||
func (pm *ProviderManager) GetProviderStats() map[string]interface{} {
|
||||
pm.mutex.RLock()
|
||||
defer pm.mutex.RUnlock()
|
||||
|
||||
stats := make(map[string]interface{})
|
||||
for _, provider := range pm.providers {
|
||||
provider.mutex.RLock()
|
||||
providerStats := map[string]interface{}{
|
||||
"name": provider.Config.Name,
|
||||
"healthy": provider.IsHealthy,
|
||||
"last_health_check": provider.LastHealthCheck,
|
||||
"request_count": provider.GetRequestCount(), // RACE CONDITION FIX: Use atomic getter
|
||||
"error_count": provider.GetErrorCount(), // RACE CONDITION FIX: Use atomic getter
|
||||
"avg_response_time": provider.AvgResponseTime,
|
||||
}
|
||||
provider.mutex.RUnlock()
|
||||
stats[provider.Config.Name] = providerStats
|
||||
}
|
||||
|
||||
return stats
|
||||
}
|
||||
1282
orig/pkg/transport/provider_pools.go
Normal file
1282
orig/pkg/transport/provider_pools.go
Normal file
File diff suppressed because it is too large
Load Diff
513
orig/pkg/transport/router.go
Normal file
513
orig/pkg/transport/router.go
Normal file
@@ -0,0 +1,513 @@
|
||||
package transport
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"fmt"
|
||||
"math/big"
|
||||
"sort"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// cryptoRandInt returns a cryptographically secure random integer in [0, n).
// It returns an error when n is not positive or the system entropy source fails.
func cryptoRandInt(n int) (int, error) {
	if n <= 0 {
		return 0, fmt.Errorf("n must be positive")
	}
	v, err := rand.Int(rand.Reader, big.NewInt(int64(n)))
	if err != nil {
		return 0, err
	}
	return int(v.Int64()), nil
}
|
||||
|
||||
// cryptoRandFloat returns a cryptographically secure random float64 in
// [0.0, 1.0), built by drawing a uniform integer in [0, 2^53) and scaling.
func cryptoRandFloat() (float64, error) {
	limit := new(big.Int).Exp(big.NewInt(2), big.NewInt(53), nil) // 2^53
	v, err := rand.Int(rand.Reader, limit)
	if err != nil {
		return 0, err
	}
	return float64(v.Int64()) / float64(limit.Int64()), nil
}
|
||||
|
||||
// MessageRouter handles intelligent message routing and transport selection:
// it evaluates priority-ordered routing rules, optionally consults a load
// balancer, and finally falls back to a default transport type.
type MessageRouter struct {
	rules        []RoutingRule // kept sorted by Priority (descending)
	fallback     TransportType // used when no rule or balancer applies
	loadBalancer LoadBalancer  // optional; consulted when no rule matches
	mu           sync.RWMutex  // guards all fields above
}
|
||||
|
||||
// RoutingRule defines message routing logic: when Condition accepts a
// message, it is routed over Transport. Higher Priority rules are tried first.
type RoutingRule struct {
	ID         string        // unique identifier; auto-generated by AddRule when empty
	Name       string        // human-readable label
	Condition  MessageFilter // predicate; nil matches every message
	Transport  TransportType // transport used when the rule matches
	Priority   int           // higher values are evaluated first
	Enabled    bool          // disabled rules are skipped during routing
	Created    time.Time     // set by AddRule
	LastUsed   time.Time     // stamped each time the rule routes a message
	UsageCount int64         // number of messages routed by this rule
}
|
||||
|
||||
// RouteMessage selects the appropriate transport for a message.
//
// Selection order:
//  1. enabled routing rules that match msg, in priority order, provided the
//     target transport reports "healthy";
//  2. the load balancer (if configured) over all currently-healthy transports;
//  3. the fallback transport, as long as it is not reporting "unhealthy".
//
// NOTE(review): updateRuleUsage mutates rule fields while this method holds
// only the read lock — that looks like a data race; confirm and fix locking.
func (mr *MessageRouter) RouteMessage(msg *Message, transports map[TransportType]Transport) (Transport, error) {
	mr.mu.RLock()
	defer mr.mu.RUnlock()

	// Find matching rules (sorted by priority)
	matchingRules := mr.findMatchingRules(msg)

	// Try each matching rule in priority order
	for _, rule := range matchingRules {
		if transport, exists := transports[rule.Transport]; exists {
			// Check transport health
			if health := transport.Health(); health.Status == "healthy" {
				mr.updateRuleUsage(rule.ID)
				return transport, nil
			}
		}
	}

	// Use load balancer for available transports
	if mr.loadBalancer != nil {
		availableTransports := mr.getHealthyTransports(transports)
		if len(availableTransports) > 0 {
			selectedType := mr.loadBalancer.SelectTransport(availableTransports, msg)
			if transport, exists := transports[selectedType]; exists {
				return transport, nil
			}
		}
	}

	// Fall back to default transport
	if fallbackTransport, exists := transports[mr.fallback]; exists {
		if health := fallbackTransport.Health(); health.Status != "unhealthy" {
			return fallbackTransport, nil
		}
	}

	return nil, fmt.Errorf("no available transport for message")
}
|
||||
|
||||
// AddRule adds a new routing rule
|
||||
func (mr *MessageRouter) AddRule(rule RoutingRule) {
|
||||
mr.mu.Lock()
|
||||
defer mr.mu.Unlock()
|
||||
|
||||
if rule.ID == "" {
|
||||
rule.ID = fmt.Sprintf("rule_%d", time.Now().UnixNano())
|
||||
}
|
||||
rule.Created = time.Now()
|
||||
rule.Enabled = true
|
||||
|
||||
mr.rules = append(mr.rules, rule)
|
||||
mr.sortRulesByPriority()
|
||||
}
|
||||
|
||||
// RemoveRule removes a routing rule by ID
|
||||
func (mr *MessageRouter) RemoveRule(ruleID string) bool {
|
||||
mr.mu.Lock()
|
||||
defer mr.mu.Unlock()
|
||||
|
||||
for i, rule := range mr.rules {
|
||||
if rule.ID == ruleID {
|
||||
mr.rules = append(mr.rules[:i], mr.rules[i+1:]...)
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// UpdateRule updates an existing routing rule
|
||||
func (mr *MessageRouter) UpdateRule(ruleID string, updates func(*RoutingRule)) bool {
|
||||
mr.mu.Lock()
|
||||
defer mr.mu.Unlock()
|
||||
|
||||
for i := range mr.rules {
|
||||
if mr.rules[i].ID == ruleID {
|
||||
updates(&mr.rules[i])
|
||||
mr.sortRulesByPriority()
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// GetRules returns all routing rules
|
||||
func (mr *MessageRouter) GetRules() []RoutingRule {
|
||||
mr.mu.RLock()
|
||||
defer mr.mu.RUnlock()
|
||||
|
||||
rules := make([]RoutingRule, len(mr.rules))
|
||||
copy(rules, mr.rules)
|
||||
return rules
|
||||
}
|
||||
|
||||
// EnableRule enables a routing rule
|
||||
func (mr *MessageRouter) EnableRule(ruleID string) bool {
|
||||
return mr.UpdateRule(ruleID, func(rule *RoutingRule) {
|
||||
rule.Enabled = true
|
||||
})
|
||||
}
|
||||
|
||||
// DisableRule disables a routing rule
|
||||
func (mr *MessageRouter) DisableRule(ruleID string) bool {
|
||||
return mr.UpdateRule(ruleID, func(rule *RoutingRule) {
|
||||
rule.Enabled = false
|
||||
})
|
||||
}
|
||||
|
||||
// SetFallbackTransport sets the fallback transport type
|
||||
func (mr *MessageRouter) SetFallbackTransport(transportType TransportType) {
|
||||
mr.mu.Lock()
|
||||
defer mr.mu.Unlock()
|
||||
mr.fallback = transportType
|
||||
}
|
||||
|
||||
// SetLoadBalancer sets the load balancer
|
||||
func (mr *MessageRouter) SetLoadBalancer(lb LoadBalancer) {
|
||||
mr.mu.Lock()
|
||||
defer mr.mu.Unlock()
|
||||
mr.loadBalancer = lb
|
||||
}
|
||||
|
||||
// Private helper methods
|
||||
|
||||
func (mr *MessageRouter) findMatchingRules(msg *Message) []RoutingRule {
|
||||
var matching []RoutingRule
|
||||
|
||||
for _, rule := range mr.rules {
|
||||
if rule.Enabled && (rule.Condition == nil || rule.Condition(msg)) {
|
||||
matching = append(matching, rule)
|
||||
}
|
||||
}
|
||||
|
||||
return matching
|
||||
}
|
||||
|
||||
func (mr *MessageRouter) sortRulesByPriority() {
|
||||
sort.Slice(mr.rules, func(i, j int) bool {
|
||||
return mr.rules[i].Priority > mr.rules[j].Priority
|
||||
})
|
||||
}
|
||||
|
||||
// updateRuleUsage records that a rule fired: it stamps LastUsed and bumps
// UsageCount on the rule with the given ID.
//
// NOTE(review): RouteMessage calls this while holding only mr.mu.RLock, yet
// this mutates rule fields — looks like a data race; confirm and fix locking.
func (mr *MessageRouter) updateRuleUsage(ruleID string) {
	for i := range mr.rules {
		if mr.rules[i].ID == ruleID {
			mr.rules[i].LastUsed = time.Now()
			mr.rules[i].UsageCount++
			break
		}
	}
}
|
||||
|
||||
func (mr *MessageRouter) getHealthyTransports(transports map[TransportType]Transport) []TransportType {
|
||||
var healthy []TransportType
|
||||
|
||||
for transportType, transport := range transports {
|
||||
if health := transport.Health(); health.Status == "healthy" {
|
||||
healthy = append(healthy, transportType)
|
||||
}
|
||||
}
|
||||
|
||||
return healthy
|
||||
}
|
||||
|
||||
// LoadBalancer implementations
|
||||
|
||||
// RoundRobinLoadBalancer implements round-robin load balancing
|
||||
type RoundRobinLoadBalancer struct {
|
||||
counter int64
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
func NewRoundRobinLoadBalancer() *RoundRobinLoadBalancer {
|
||||
return &RoundRobinLoadBalancer{}
|
||||
}
|
||||
|
||||
func (lb *RoundRobinLoadBalancer) SelectTransport(transports []TransportType, msg *Message) TransportType {
|
||||
if len(transports) == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
lb.mu.Lock()
|
||||
defer lb.mu.Unlock()
|
||||
|
||||
selected := transports[lb.counter%int64(len(transports))]
|
||||
lb.counter++
|
||||
return selected
|
||||
}
|
||||
|
||||
func (lb *RoundRobinLoadBalancer) UpdateStats(transport TransportType, latency time.Duration, success bool) {
|
||||
// Round-robin doesn't use stats
|
||||
}
|
||||
|
||||
// WeightedLoadBalancer implements weighted load balancing based on observed
// per-transport performance (success rate and average latency).
type WeightedLoadBalancer struct {
	stats map[TransportType]*TransportStats // running stats per transport
	mu    sync.RWMutex                      // guards stats
}

// TransportStats accumulates request outcomes for one transport and caches
// the derived selection weight.
type TransportStats struct {
	TotalRequests   int64
	SuccessRequests int64
	TotalLatency    time.Duration // sum over all requests; mean = TotalLatency / TotalRequests
	LastUpdate      time.Time
	Weight          float64 // computed by calculateWeight; floored at 0.1
}
|
||||
|
||||
func NewWeightedLoadBalancer() *WeightedLoadBalancer {
|
||||
return &WeightedLoadBalancer{
|
||||
stats: make(map[TransportType]*TransportStats),
|
||||
}
|
||||
}
|
||||
|
||||
// SelectTransport performs weighted random selection over the candidate
// transports: each transport's weight is derived from its success rate and
// average latency, and a crypto/rand draw picks proportionally to weight.
// When all weights are zero, a uniform random candidate is chosen; any
// entropy failure falls back to the first candidate.
func (lb *WeightedLoadBalancer) SelectTransport(transports []TransportType, msg *Message) TransportType {
	if len(transports) == 0 {
		return ""
	}

	lb.mu.RLock()
	defer lb.mu.RUnlock()

	// Calculate weights and select based on weighted random selection
	totalWeight := 0.0
	weights := make(map[TransportType]float64)

	for _, transport := range transports {
		weight := lb.calculateWeight(transport)
		weights[transport] = weight
		totalWeight += weight
	}

	if totalWeight == 0 {
		// Fall back to uniform random selection using crypto/rand.
		idx, err := cryptoRandInt(len(transports))
		if err != nil {
			// Fallback to first transport if crypto/rand fails
			return transports[0]
		}
		return transports[idx]
	}

	// Weighted random selection: draw a target in [0, totalWeight) and walk
	// the cumulative weights until we cross it.
	targetF, err := cryptoRandFloat()
	if err != nil {
		// Fallback to first transport if crypto/rand fails
		return transports[0]
	}
	target := targetF * totalWeight
	current := 0.0

	for _, transport := range transports {
		current += weights[transport]
		if current >= target {
			return transport
		}
	}

	// Fallback (shouldn't happen)
	return transports[0]
}
|
||||
|
||||
func (lb *WeightedLoadBalancer) UpdateStats(transport TransportType, latency time.Duration, success bool) {
|
||||
lb.mu.Lock()
|
||||
defer lb.mu.Unlock()
|
||||
|
||||
stats, exists := lb.stats[transport]
|
||||
if !exists {
|
||||
stats = &TransportStats{
|
||||
Weight: 1.0, // Default weight
|
||||
}
|
||||
lb.stats[transport] = stats
|
||||
}
|
||||
|
||||
stats.TotalRequests++
|
||||
stats.TotalLatency += latency
|
||||
stats.LastUpdate = time.Now()
|
||||
|
||||
if success {
|
||||
stats.SuccessRequests++
|
||||
}
|
||||
|
||||
// Recalculate weight based on performance
|
||||
stats.Weight = lb.calculateWeight(transport)
|
||||
}
|
||||
|
||||
func (lb *WeightedLoadBalancer) calculateWeight(transport TransportType) float64 {
|
||||
stats, exists := lb.stats[transport]
|
||||
if !exists {
|
||||
return 1.0 // Default weight for unknown transports
|
||||
}
|
||||
|
||||
if stats.TotalRequests == 0 {
|
||||
return 1.0
|
||||
}
|
||||
|
||||
// Calculate success rate
|
||||
successRate := float64(stats.SuccessRequests) / float64(stats.TotalRequests)
|
||||
|
||||
// Calculate average latency
|
||||
avgLatency := stats.TotalLatency / time.Duration(stats.TotalRequests)
|
||||
|
||||
// Weight formula: success rate / (latency factor)
|
||||
// Lower latency and higher success rate = higher weight
|
||||
latencyFactor := float64(avgLatency) / float64(time.Millisecond)
|
||||
if latencyFactor < 1 {
|
||||
latencyFactor = 1
|
||||
}
|
||||
|
||||
weight := successRate / latencyFactor
|
||||
|
||||
// Ensure minimum weight
|
||||
if weight < 0.1 {
|
||||
weight = 0.1
|
||||
}
|
||||
|
||||
return weight
|
||||
}
|
||||
|
||||
// LeastLatencyLoadBalancer selects the transport with the lowest average
// latency over a small rolling window of successful requests.
type LeastLatencyLoadBalancer struct {
	stats map[TransportType]*LatencyStats // per-transport latency windows
	mu    sync.RWMutex                    // guards stats
}

// LatencyStats holds a bounded, newest-last window of latency samples.
type LatencyStats struct {
	RecentLatencies []time.Duration // oldest entries are evicted first
	MaxSamples      int             // window size (set to 10 by UpdateStats)
	LastUpdate      time.Time
}
|
||||
|
||||
func NewLeastLatencyLoadBalancer() *LeastLatencyLoadBalancer {
|
||||
return &LeastLatencyLoadBalancer{
|
||||
stats: make(map[TransportType]*LatencyStats),
|
||||
}
|
||||
}
|
||||
|
||||
func (lb *LeastLatencyLoadBalancer) SelectTransport(transports []TransportType, msg *Message) TransportType {
|
||||
if len(transports) == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
lb.mu.RLock()
|
||||
defer lb.mu.RUnlock()
|
||||
|
||||
bestTransport := transports[0]
|
||||
bestLatency := time.Hour // Large initial value
|
||||
|
||||
for _, transport := range transports {
|
||||
avgLatency := lb.getAverageLatency(transport)
|
||||
if avgLatency < bestLatency {
|
||||
bestLatency = avgLatency
|
||||
bestTransport = transport
|
||||
}
|
||||
}
|
||||
|
||||
return bestTransport
|
||||
}
|
||||
|
||||
func (lb *LeastLatencyLoadBalancer) UpdateStats(transport TransportType, latency time.Duration, success bool) {
|
||||
if !success {
|
||||
return // Only track successful requests
|
||||
}
|
||||
|
||||
lb.mu.Lock()
|
||||
defer lb.mu.Unlock()
|
||||
|
||||
stats, exists := lb.stats[transport]
|
||||
if !exists {
|
||||
stats = &LatencyStats{
|
||||
RecentLatencies: make([]time.Duration, 0),
|
||||
MaxSamples: 10, // Keep last 10 samples
|
||||
}
|
||||
lb.stats[transport] = stats
|
||||
}
|
||||
|
||||
// Add new latency sample
|
||||
stats.RecentLatencies = append(stats.RecentLatencies, latency)
|
||||
|
||||
// Keep only recent samples
|
||||
if len(stats.RecentLatencies) > stats.MaxSamples {
|
||||
stats.RecentLatencies = stats.RecentLatencies[1:]
|
||||
}
|
||||
|
||||
stats.LastUpdate = time.Now()
|
||||
}
|
||||
|
||||
func (lb *LeastLatencyLoadBalancer) getAverageLatency(transport TransportType) time.Duration {
|
||||
stats, exists := lb.stats[transport]
|
||||
if !exists || len(stats.RecentLatencies) == 0 {
|
||||
return time.Millisecond * 100 // Default estimate
|
||||
}
|
||||
|
||||
total := time.Duration(0)
|
||||
for _, latency := range stats.RecentLatencies {
|
||||
total += latency
|
||||
}
|
||||
|
||||
return total / time.Duration(len(stats.RecentLatencies))
|
||||
}
|
||||
|
||||
// Common routing rule factory functions
|
||||
|
||||
// CreateTopicRule creates a rule based on message topic
|
||||
func CreateTopicRule(name string, topic string, transport TransportType, priority int) RoutingRule {
|
||||
return RoutingRule{
|
||||
Name: name,
|
||||
Condition: func(msg *Message) bool { return msg.Topic == topic },
|
||||
Transport: transport,
|
||||
Priority: priority,
|
||||
}
|
||||
}
|
||||
|
||||
// CreateTopicPatternRule creates a rule based on topic pattern matching
|
||||
func CreateTopicPatternRule(name string, pattern string, transport TransportType, priority int) RoutingRule {
|
||||
return RoutingRule{
|
||||
Name: name,
|
||||
Condition: func(msg *Message) bool {
|
||||
// Simple pattern matching (can be enhanced with regex)
|
||||
return msg.Topic == pattern ||
|
||||
(len(pattern) > 0 && pattern[len(pattern)-1] == '*' &&
|
||||
len(msg.Topic) >= len(pattern)-1 &&
|
||||
msg.Topic[:len(pattern)-1] == pattern[:len(pattern)-1])
|
||||
},
|
||||
Transport: transport,
|
||||
Priority: priority,
|
||||
}
|
||||
}
|
||||
|
||||
// CreatePriorityRule creates a rule based on message priority
|
||||
func CreatePriorityRule(name string, msgPriority MessagePriority, transport TransportType, priority int) RoutingRule {
|
||||
return RoutingRule{
|
||||
Name: name,
|
||||
Condition: func(msg *Message) bool { return msg.Priority == msgPriority },
|
||||
Transport: transport,
|
||||
Priority: priority,
|
||||
}
|
||||
}
|
||||
|
||||
// CreateTypeRule creates a rule based on message type
|
||||
func CreateTypeRule(name string, msgType MessageType, transport TransportType, priority int) RoutingRule {
|
||||
return RoutingRule{
|
||||
Name: name,
|
||||
Condition: func(msg *Message) bool { return msg.Type == msgType },
|
||||
Transport: transport,
|
||||
Priority: priority,
|
||||
}
|
||||
}
|
||||
|
||||
// CreateSourceRule creates a rule based on message source
|
||||
func CreateSourceRule(name string, source string, transport TransportType, priority int) RoutingRule {
|
||||
return RoutingRule{
|
||||
Name: name,
|
||||
Condition: func(msg *Message) bool { return msg.Source == source },
|
||||
Transport: transport,
|
||||
Priority: priority,
|
||||
}
|
||||
}
|
||||
629
orig/pkg/transport/serialization.go
Normal file
629
orig/pkg/transport/serialization.go
Normal file
@@ -0,0 +1,629 @@
|
||||
package transport
|
||||
|
||||
import (
	"bytes"
	"compress/gzip"
	"encoding/json"
	"fmt"
	"hash/crc32"
	"io"
	"sync"
	"time"
)
|
||||
|
||||
// SerializationFormat defines supported serialization formats.
// Only JSON has a built-in serializer in this file; the others require an
// externally registered Serializer implementation.
type SerializationFormat string

const (
	SerializationJSON     SerializationFormat = "json"
	SerializationMsgPack  SerializationFormat = "msgpack"
	SerializationProtobuf SerializationFormat = "protobuf"
	SerializationAvro     SerializationFormat = "avro"
)

// CompressionType defines supported compression algorithms.
// DefaultCompressor implements only "none" and "gzip"; lz4/snappy are
// declared for future use.
type CompressionType string

const (
	CompressionNone   CompressionType = "none"
	CompressionGZip   CompressionType = "gzip"
	CompressionLZ4    CompressionType = "lz4"
	CompressionSnappy CompressionType = "snappy"
)
|
||||
|
||||
// SerializationConfig configures serialization behavior for one serializer.
type SerializationConfig struct {
	Format      SerializationFormat // wire format the serializer produces
	Compression CompressionType     // compression applied by the SerializationLayer
	Encryption  bool                // whether the layer should encrypt the payload
	Validation  bool                // whether payloads should be validated
}

// SerializedMessage represents a serialized message together with the
// metadata needed to reverse the pipeline (format, compression, encryption,
// checksum).
type SerializedMessage struct {
	Format      SerializationFormat `json:"format"`
	Compression CompressionType     `json:"compression"`
	Encrypted   bool                `json:"encrypted"`
	Checksum    string              `json:"checksum"` // checksum of Data as stored (post compression/encryption)
	Data        []byte              `json:"data"`
	Size        int                 `json:"size"`      // len(Data)
	Timestamp   int64               `json:"timestamp"` // source message timestamp, UnixNano
}
|
||||
|
||||
// Serializer interface defines serialization operations for one wire format.
type Serializer interface {
	Serialize(msg *Message) (*SerializedMessage, error)
	Deserialize(serialized *SerializedMessage) (*Message, error)
	GetFormat() SerializationFormat
	GetConfig() SerializationConfig
}

// SerializationLayer manages multiple serializers and applies the shared
// compression / encryption / validation pipeline around them.
type SerializationLayer struct {
	serializers   map[SerializationFormat]Serializer // registered serializers by format
	defaultFormat SerializationFormat                // used when Serialize gets no explicit format
	compressor    Compressor                         // nil disables compression
	encryptor     Encryptor                          // nil disables encryption
	validator     Validator                          // nil disables validation/checksums
	mu            sync.RWMutex                       // guards all fields above
}

// Compressor interface for data compression.
type Compressor interface {
	Compress(data []byte, algorithm CompressionType) ([]byte, error)
	Decompress(data []byte, algorithm CompressionType) ([]byte, error)
	GetSupportedAlgorithms() []CompressionType
}

// Encryptor interface for data encryption.
type Encryptor interface {
	Encrypt(data []byte) ([]byte, error)
	Decrypt(data []byte) ([]byte, error)
	IsEnabled() bool
}

// Validator interface for message validation and payload checksums.
type Validator interface {
	Validate(msg *Message) error
	GenerateChecksum(data []byte) string
	VerifyChecksum(data []byte, checksum string) bool
}
|
||||
|
||||
// NewSerializationLayer creates a new serialization layer
|
||||
func NewSerializationLayer() *SerializationLayer {
|
||||
sl := &SerializationLayer{
|
||||
serializers: make(map[SerializationFormat]Serializer),
|
||||
defaultFormat: SerializationJSON,
|
||||
compressor: NewDefaultCompressor(),
|
||||
validator: NewDefaultValidator(),
|
||||
}
|
||||
|
||||
// Register default serializers
|
||||
sl.RegisterSerializer(NewJSONSerializer())
|
||||
|
||||
return sl
|
||||
}
|
||||
|
||||
// RegisterSerializer registers a serializer under its own reported format,
// replacing any previously registered serializer for that format.
func (sl *SerializationLayer) RegisterSerializer(serializer Serializer) {
	sl.mu.Lock()
	defer sl.mu.Unlock()
	sl.serializers[serializer.GetFormat()] = serializer
}

// SetDefaultFormat sets the format used when Serialize is called without an
// explicit format argument.
func (sl *SerializationLayer) SetDefaultFormat(format SerializationFormat) {
	sl.mu.Lock()
	defer sl.mu.Unlock()
	sl.defaultFormat = format
}

// SetCompressor sets the compression handler.
func (sl *SerializationLayer) SetCompressor(compressor Compressor) {
	sl.mu.Lock()
	defer sl.mu.Unlock()
	sl.compressor = compressor
}

// SetEncryptor sets the encryption handler.
func (sl *SerializationLayer) SetEncryptor(encryptor Encryptor) {
	sl.mu.Lock()
	defer sl.mu.Unlock()
	sl.encryptor = encryptor
}

// SetValidator sets the validation handler.
func (sl *SerializationLayer) SetValidator(validator Validator) {
	sl.mu.Lock()
	defer sl.mu.Unlock()
	sl.validator = validator
}
|
||||
|
||||
// Serialize serializes a message using the specified or default format.
// Pipeline order: validate -> serialize -> compress -> encrypt -> checksum.
// The optional format argument overrides the layer default; only the first
// value is consulted. The checksum therefore covers the final
// (compressed and/or encrypted) payload bytes, not the raw serialization.
func (sl *SerializationLayer) Serialize(msg *Message, format ...SerializationFormat) (*SerializedMessage, error) {
	sl.mu.RLock()
	defer sl.mu.RUnlock()

	// Determine format to use.
	selectedFormat := sl.defaultFormat
	if len(format) > 0 {
		selectedFormat = format[0]
	}

	// Look up the serializer registered for the chosen format.
	serializer, exists := sl.serializers[selectedFormat]
	if !exists {
		return nil, fmt.Errorf("unsupported serialization format: %s", selectedFormat)
	}

	// Validate message if validator is configured.
	if sl.validator != nil {
		if err := sl.validator.Validate(msg); err != nil {
			return nil, fmt.Errorf("message validation failed: %w", err)
		}
	}

	// Serialize message into its wire representation.
	serialized, err := serializer.Serialize(msg)
	if err != nil {
		return nil, fmt.Errorf("serialization failed: %w", err)
	}

	// Apply compression if the serializer's own config requests it.
	config := serializer.GetConfig()
	if config.Compression != CompressionNone && sl.compressor != nil {
		compressed, err := sl.compressor.Compress(serialized.Data, config.Compression)
		if err != nil {
			return nil, fmt.Errorf("compression failed: %w", err)
		}
		serialized.Data = compressed
		serialized.Compression = config.Compression
	}

	// Apply encryption if configured and an enabled encryptor is installed.
	if config.Encryption && sl.encryptor != nil && sl.encryptor.IsEnabled() {
		encrypted, err := sl.encryptor.Encrypt(serialized.Data)
		if err != nil {
			return nil, fmt.Errorf("encryption failed: %w", err)
		}
		serialized.Data = encrypted
		serialized.Encrypted = true
	}

	// Generate checksum over the final payload bytes.
	if sl.validator != nil {
		serialized.Checksum = sl.validator.GenerateChecksum(serialized.Data)
	}

	// Update metadata to reflect the final payload.
	serialized.Size = len(serialized.Data)
	serialized.Timestamp = msg.Timestamp.UnixNano()

	return serialized, nil
}
|
||||
|
||||
// Deserialize reverses the Serialize pipeline: verify checksum -> decrypt ->
// decompress -> deserialize -> validate. The checksum is checked against the
// payload exactly as transported (i.e. before decryption/decompression),
// matching how Serialize computed it.
func (sl *SerializationLayer) Deserialize(serialized *SerializedMessage) (*Message, error) {
	sl.mu.RLock()
	defer sl.mu.RUnlock()

	// Verify checksum if available. An empty checksum skips verification.
	if sl.validator != nil && serialized.Checksum != "" {
		if !sl.validator.VerifyChecksum(serialized.Data, serialized.Checksum) {
			return nil, fmt.Errorf("checksum verification failed")
		}
	}

	data := serialized.Data

	// Apply decryption if needed.
	if serialized.Encrypted && sl.encryptor != nil && sl.encryptor.IsEnabled() {
		decrypted, err := sl.encryptor.Decrypt(data)
		if err != nil {
			return nil, fmt.Errorf("decryption failed: %w", err)
		}
		data = decrypted
	}

	// Apply decompression if needed.
	if serialized.Compression != CompressionNone && sl.compressor != nil {
		decompressed, err := sl.compressor.Decompress(data, serialized.Compression)
		if err != nil {
			return nil, fmt.Errorf("decompression failed: %w", err)
		}
		data = decompressed
	}

	// Get serializer for the declared wire format.
	serializer, exists := sl.serializers[serialized.Format]
	if !exists {
		return nil, fmt.Errorf("unsupported serialization format: %s", serialized.Format)
	}

	// Create a temporary serialized message so the serializer sees the fully
	// decoded payload rather than the transport bytes.
	tempSerialized := &SerializedMessage{
		Format: serialized.Format,
		Data:   data,
	}

	// Deserialize message.
	msg, err := serializer.Deserialize(tempSerialized)
	if err != nil {
		return nil, fmt.Errorf("deserialization failed: %w", err)
	}

	// Validate deserialized message if validator is configured.
	if sl.validator != nil {
		if err := sl.validator.Validate(msg); err != nil {
			return nil, fmt.Errorf("deserialized message validation failed: %w", err)
		}
	}

	return msg, nil
}
|
||||
|
||||
// GetSupportedFormats returns all supported serialization formats
|
||||
func (sl *SerializationLayer) GetSupportedFormats() []SerializationFormat {
|
||||
sl.mu.RLock()
|
||||
defer sl.mu.RUnlock()
|
||||
|
||||
formats := make([]SerializationFormat, 0, len(sl.serializers))
|
||||
for format := range sl.serializers {
|
||||
formats = append(formats, format)
|
||||
}
|
||||
return formats
|
||||
}
|
||||
|
||||
// JSONSerializer implements JSON serialization via encoding/json.
type JSONSerializer struct {
	config SerializationConfig // pipeline options; Format is pinned to JSON
}

// NewJSONSerializer creates a JSON serializer with no compression or
// encryption and validation enabled.
func NewJSONSerializer() *JSONSerializer {
	return &JSONSerializer{
		config: SerializationConfig{
			Format:      SerializationJSON,
			Compression: CompressionNone,
			Encryption:  false,
			Validation:  true,
		},
	}
}

// SetConfig updates the serializer configuration. The Format field is forced
// back to JSON regardless of what the caller passed in.
func (js *JSONSerializer) SetConfig(config SerializationConfig) {
	js.config = config
	js.config.Format = SerializationJSON // Ensure format is correct
}
|
||||
|
||||
// Serialize serializes a message to JSON
|
||||
func (js *JSONSerializer) Serialize(msg *Message) (*SerializedMessage, error) {
|
||||
data, err := json.Marshal(msg)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("JSON marshal failed: %w", err)
|
||||
}
|
||||
|
||||
return &SerializedMessage{
|
||||
Format: SerializationJSON,
|
||||
Compression: CompressionNone,
|
||||
Encrypted: false,
|
||||
Data: data,
|
||||
Size: len(data),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Deserialize deserializes a message from JSON
|
||||
func (js *JSONSerializer) Deserialize(serialized *SerializedMessage) (*Message, error) {
|
||||
var msg Message
|
||||
if err := json.Unmarshal(serialized.Data, &msg); err != nil {
|
||||
return nil, fmt.Errorf("JSON unmarshal failed: %w", err)
|
||||
}
|
||||
return &msg, nil
|
||||
}
|
||||
|
||||
// GetFormat returns the serialization format (always JSON for this type).
func (js *JSONSerializer) GetFormat() SerializationFormat {
	return SerializationJSON
}

// GetConfig returns a copy of the serializer configuration.
func (js *JSONSerializer) GetConfig() SerializationConfig {
	return js.config
}
|
||||
|
||||
// DefaultCompressor implements basic compression operations. It supports
// pass-through ("none") and gzip; lz4/snappy are declared in CompressionType
// but not implemented here.
type DefaultCompressor struct {
	supportedAlgorithms []CompressionType // algorithms this instance accepts
}

// NewDefaultCompressor creates a compressor supporting "none" and gzip.
func NewDefaultCompressor() *DefaultCompressor {
	return &DefaultCompressor{
		supportedAlgorithms: []CompressionType{
			CompressionNone,
			CompressionGZip,
		},
	}
}
|
||||
|
||||
// Compress compresses data using the specified algorithm
|
||||
func (dc *DefaultCompressor) Compress(data []byte, algorithm CompressionType) ([]byte, error) {
|
||||
switch algorithm {
|
||||
case CompressionNone:
|
||||
return data, nil
|
||||
|
||||
case CompressionGZip:
|
||||
var buf bytes.Buffer
|
||||
writer := gzip.NewWriter(&buf)
|
||||
|
||||
if _, err := writer.Write(data); err != nil {
|
||||
return nil, fmt.Errorf("gzip write failed: %w", err)
|
||||
}
|
||||
|
||||
if err := writer.Close(); err != nil {
|
||||
return nil, fmt.Errorf("gzip close failed: %w", err)
|
||||
}
|
||||
|
||||
return buf.Bytes(), nil
|
||||
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported compression algorithm: %s", algorithm)
|
||||
}
|
||||
}
|
||||
|
||||
// Decompress decompresses data using the specified algorithm
|
||||
func (dc *DefaultCompressor) Decompress(data []byte, algorithm CompressionType) ([]byte, error) {
|
||||
switch algorithm {
|
||||
case CompressionNone:
|
||||
return data, nil
|
||||
|
||||
case CompressionGZip:
|
||||
reader, err := gzip.NewReader(bytes.NewReader(data))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("gzip reader creation failed: %w", err)
|
||||
}
|
||||
defer reader.Close()
|
||||
|
||||
decompressed, err := io.ReadAll(reader)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("gzip read failed: %w", err)
|
||||
}
|
||||
|
||||
return decompressed, nil
|
||||
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported compression algorithm: %s", algorithm)
|
||||
}
|
||||
}
|
||||
|
||||
// GetSupportedAlgorithms returns supported compression algorithms
|
||||
func (dc *DefaultCompressor) GetSupportedAlgorithms() []CompressionType {
|
||||
return dc.supportedAlgorithms
|
||||
}
|
||||
|
||||
// DefaultValidator implements basic message validation and payload checksums.
type DefaultValidator struct {
	strictMode bool // when true, Data and Timestamp are also required
}

// NewDefaultValidator creates a validator with strict mode disabled.
func NewDefaultValidator() *DefaultValidator {
	return &DefaultValidator{
		strictMode: false,
	}
}

// SetStrictMode enables/disables strict validation (non-nil Data and a
// non-zero Timestamp become mandatory in Validate).
func (dv *DefaultValidator) SetStrictMode(enabled bool) {
	dv.strictMode = enabled
}
|
||||
|
||||
// Validate validates a message
|
||||
func (dv *DefaultValidator) Validate(msg *Message) error {
|
||||
if msg == nil {
|
||||
return fmt.Errorf("message is nil")
|
||||
}
|
||||
|
||||
if msg.ID == "" {
|
||||
return fmt.Errorf("message ID is empty")
|
||||
}
|
||||
|
||||
if msg.Topic == "" {
|
||||
return fmt.Errorf("message topic is empty")
|
||||
}
|
||||
|
||||
if msg.Source == "" {
|
||||
return fmt.Errorf("message source is empty")
|
||||
}
|
||||
|
||||
if msg.Type == "" {
|
||||
return fmt.Errorf("message type is empty")
|
||||
}
|
||||
|
||||
if dv.strictMode {
|
||||
if msg.Data == nil {
|
||||
return fmt.Errorf("message data is nil")
|
||||
}
|
||||
|
||||
if msg.Timestamp.IsZero() {
|
||||
return fmt.Errorf("message timestamp is zero")
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GenerateChecksum generates a simple checksum for data
|
||||
func (dv *DefaultValidator) GenerateChecksum(data []byte) string {
|
||||
// Simple checksum implementation
|
||||
// In production, use a proper hash function like SHA-256
|
||||
var sum uint32
|
||||
for _, b := range data {
|
||||
sum += uint32(b)
|
||||
}
|
||||
return fmt.Sprintf("%08x", sum)
|
||||
}
|
||||
|
||||
// VerifyChecksum reports whether checksum matches the checksum freshly
// computed over data. It delegates to GenerateChecksum, so it remains
// consistent with whatever algorithm that method uses.
func (dv *DefaultValidator) VerifyChecksum(data []byte, checksum string) bool {
	return dv.GenerateChecksum(data) == checksum
}
|
||||
|
||||
// NoOpEncryptor implements a no-operation encryptor for testing. Encrypt and
// Decrypt return their input unchanged regardless of the enabled flag; the
// flag only controls what IsEnabled reports (and therefore whether the
// SerializationLayer invokes this encryptor at all).
type NoOpEncryptor struct {
	enabled bool // reported by IsEnabled
}

// NewNoOpEncryptor creates a new no-op encryptor, disabled by default.
func NewNoOpEncryptor() *NoOpEncryptor {
	return &NoOpEncryptor{enabled: false}
}

// SetEnabled enables/disables the encryptor as seen by IsEnabled.
func (noe *NoOpEncryptor) SetEnabled(enabled bool) {
	noe.enabled = enabled
}

// Encrypt returns data unchanged.
func (noe *NoOpEncryptor) Encrypt(data []byte) ([]byte, error) {
	return data, nil
}

// Decrypt returns data unchanged.
func (noe *NoOpEncryptor) Decrypt(data []byte) ([]byte, error) {
	return data, nil
}

// IsEnabled returns whether encryption is enabled.
func (noe *NoOpEncryptor) IsEnabled() bool {
	return noe.enabled
}
|
||||
|
||||
// SerializationMetrics tracks serialization performance counters.
type SerializationMetrics struct {
	SerializedMessages   int64   // recorded Serialize operations
	DeserializedMessages int64   // recorded Deserialize operations
	SerializationErrors  int64   // errors plus misc events (see RecordEvent/IncrementCounter)
	CompressionRatio     float64 // running average of serialized/original size
	AverageMessageSize   int64   // TotalDataProcessed / SerializedMessages
	TotalDataProcessed   int64   // sum of original payload sizes, in bytes
}

// MetricsCollector collects serialization metrics behind a mutex.
type MetricsCollector struct {
	metrics SerializationMetrics
	mu      sync.RWMutex // guards metrics
}

// NewMetricsCollector creates a metrics collector with zeroed counters.
func NewMetricsCollector() *MetricsCollector {
	return &MetricsCollector{}
}
|
||||
|
||||
// RecordSerialization records a serialization operation
|
||||
func (mc *MetricsCollector) RecordSerialization(originalSize, serializedSize int) {
|
||||
mc.mu.Lock()
|
||||
defer mc.mu.Unlock()
|
||||
|
||||
mc.metrics.SerializedMessages++
|
||||
mc.metrics.TotalDataProcessed += int64(originalSize)
|
||||
|
||||
// Update compression ratio
|
||||
if originalSize > 0 {
|
||||
ratio := float64(serializedSize) / float64(originalSize)
|
||||
mc.metrics.CompressionRatio = (mc.metrics.CompressionRatio + ratio) / 2
|
||||
}
|
||||
|
||||
// Update average message size
|
||||
mc.metrics.AverageMessageSize = mc.metrics.TotalDataProcessed / mc.metrics.SerializedMessages
|
||||
}
|
||||
|
||||
// RecordDeserialization records one deserialization operation.
func (mc *MetricsCollector) RecordDeserialization() {
	mc.mu.Lock()
	defer mc.mu.Unlock()
	mc.metrics.DeserializedMessages++
}

// RecordError records a serialization error.
func (mc *MetricsCollector) RecordError() {
	mc.mu.Lock()
	defer mc.mu.Unlock()
	mc.metrics.SerializationErrors++
}
|
||||
|
||||
// IncrementCounter increments a named counter.
// NOTE(review): every named counter is currently folded into
// SerializationErrors, which inflates the error metric for non-error
// counters — confirm whether a real per-name counter map is intended.
func (mc *MetricsCollector) IncrementCounter(name string) {
	mc.mu.Lock()
	defer mc.mu.Unlock()
	// For simplicity, map all counters to serialization errors for now
	mc.metrics.SerializationErrors++
}

// RecordLatency records a latency metric. Currently a no-op placeholder;
// the lock is taken for symmetry with the other recorders.
func (mc *MetricsCollector) RecordLatency(name string, duration time.Duration) {
	mc.mu.Lock()
	defer mc.mu.Unlock()
	// For now, we don't track specific latencies
	// This can be enhanced later with proper metrics storage
}

// RecordEvent records an event metric.
// NOTE(review): events are also counted as SerializationErrors, so the
// error count becomes meaningless once events are recorded — verify before
// alerting on "serialization_errors".
func (mc *MetricsCollector) RecordEvent(name string) {
	mc.mu.Lock()
	defer mc.mu.Unlock()
	mc.metrics.SerializationErrors++ // Simple implementation
}

// RecordGauge records a gauge metric. Currently a no-op placeholder; the
// value is not stored anywhere.
func (mc *MetricsCollector) RecordGauge(name string, value float64) {
	mc.mu.Lock()
	defer mc.mu.Unlock()
	// Simple implementation - not storing actual values
}
|
||||
|
||||
// GetAll returns all metrics
|
||||
func (mc *MetricsCollector) GetAll() map[string]interface{} {
|
||||
mc.mu.RLock()
|
||||
defer mc.mu.RUnlock()
|
||||
return map[string]interface{}{
|
||||
"serialized_messages": mc.metrics.SerializedMessages,
|
||||
"deserialized_messages": mc.metrics.DeserializedMessages,
|
||||
"serialization_errors": mc.metrics.SerializationErrors,
|
||||
"compression_ratio": mc.metrics.CompressionRatio,
|
||||
"average_message_size": mc.metrics.AverageMessageSize,
|
||||
"total_data_processed": mc.metrics.TotalDataProcessed,
|
||||
}
|
||||
}
|
||||
|
||||
// Get returns a specific metric
|
||||
func (mc *MetricsCollector) Get(name string) interface{} {
|
||||
mc.mu.RLock()
|
||||
defer mc.mu.RUnlock()
|
||||
switch name {
|
||||
case "serialized_messages":
|
||||
return mc.metrics.SerializedMessages
|
||||
case "deserialized_messages":
|
||||
return mc.metrics.DeserializedMessages
|
||||
case "serialization_errors":
|
||||
return mc.metrics.SerializationErrors
|
||||
case "compression_ratio":
|
||||
return mc.metrics.CompressionRatio
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// GetMetrics returns a copy of the current metrics snapshot.
func (mc *MetricsCollector) GetMetrics() SerializationMetrics {
	mc.mu.RLock()
	defer mc.mu.RUnlock()
	return mc.metrics
}

// Reset zeroes all metrics.
func (mc *MetricsCollector) Reset() {
	mc.mu.Lock()
	defer mc.mu.Unlock()
	mc.metrics = SerializationMetrics{}
}
|
||||
451
orig/pkg/transport/tcp_transport.go
Normal file
451
orig/pkg/transport/tcp_transport.go
Normal file
@@ -0,0 +1,451 @@
|
||||
package transport
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// TCPTransport implements TCP transport for remote communication. It runs
// either as a server (accepting many inbound connections) or as a client
// (one connection to a remote server), optionally secured with TLS.
type TCPTransport struct {
	address     string              // bind/dial host
	port        int                 // bind/dial port
	listener    net.Listener        // server-mode listener; nil in client mode
	connections map[string]net.Conn // active connections keyed by generated ID
	metrics     TransportMetrics    // running counters (bytes, messages, errors)
	connected   bool                // true after a successful Connect
	isServer    bool                // server vs client mode
	receiveChan chan *Message       // inbound messages; closed by Disconnect
	tlsConfig   *tls.Config         // optional TLS; nil means plaintext TCP
	mu          sync.RWMutex        // guards connections, connected, metrics
	ctx         context.Context     // lifetime of background goroutines
	cancel      context.CancelFunc  // cancels ctx on Disconnect
	retryConfig RetryConfig         // backoff policy for dial/send retries
}
|
||||
|
||||
// NewTCPTransport creates a new TCP transport
|
||||
func NewTCPTransport(address string, port int, isServer bool) *TCPTransport {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
return &TCPTransport{
|
||||
address: address,
|
||||
port: port,
|
||||
connections: make(map[string]net.Conn),
|
||||
metrics: TransportMetrics{},
|
||||
isServer: isServer,
|
||||
receiveChan: make(chan *Message, 1000),
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
retryConfig: RetryConfig{
|
||||
MaxRetries: 3,
|
||||
InitialDelay: time.Second,
|
||||
MaxDelay: 30 * time.Second,
|
||||
BackoffFactor: 2.0,
|
||||
Jitter: true,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// SetTLSConfig configures TLS for secure communication
|
||||
func (tt *TCPTransport) SetTLSConfig(config *tls.Config) {
|
||||
tt.tlsConfig = config
|
||||
}
|
||||
|
||||
// SetRetryConfig configures retry behavior
|
||||
func (tt *TCPTransport) SetRetryConfig(config RetryConfig) {
|
||||
tt.retryConfig = config
|
||||
}
|
||||
|
||||
// Connect establishes the transport: in server mode it starts listening for
// inbound connections; in client mode it dials the configured server.
// No-op when already connected. tt.mu is held for the whole call, including
// the dial retries inside connectToServer.
func (tt *TCPTransport) Connect(ctx context.Context) error {
	tt.mu.Lock()
	defer tt.mu.Unlock()

	if tt.connected {
		return nil
	}

	if tt.isServer {
		return tt.startServer()
	} else {
		return tt.connectToServer(ctx)
	}
}
|
||||
|
||||
// Disconnect tears the transport down: cancels the context, closes the
// listener (server mode) and every tracked connection, closes the receive
// channel, and marks the transport disconnected. Returns nil when already
// disconnected.
//
// NOTE(review): receiveChan is closed here while per-connection goroutines
// may still be winding down after cancellation — confirm nothing sends on
// receiveChan afterwards, or this can panic on send-to-closed-channel.
func (tt *TCPTransport) Disconnect(ctx context.Context) error {
	tt.mu.Lock()
	defer tt.mu.Unlock()

	if !tt.connected {
		return nil
	}

	// Signal all background goroutines (accept loop, connection handlers).
	tt.cancel()

	if tt.isServer && tt.listener != nil {
		tt.listener.Close()
	}

	// Close all connections.
	for id, conn := range tt.connections {
		conn.Close()
		delete(tt.connections, id)
	}

	close(tt.receiveChan)
	tt.connected = false
	tt.metrics.Connections = 0

	return nil
}
|
||||
|
||||
// Send transmits a message to every active connection, framing each payload
// as "<length>\n<json>". Failed connections are removed asynchronously.
// When several connections fail, only the LAST failure is returned.
//
// NOTE(review): tt.metrics.Errors++ below happens while holding only a read
// lock (or no lock), which races with concurrent writers — confirm metrics
// should move under tt.mu.Lock() or become atomics.
// NOTE(review): the connection set can change between the two RLock
// sections, so connectionCount may be stale by the time we iterate.
func (tt *TCPTransport) Send(ctx context.Context, msg *Message) error {
	start := time.Now()

	tt.mu.RLock()
	if !tt.connected {
		tt.mu.RUnlock()
		tt.metrics.Errors++
		return fmt.Errorf("transport not connected")
	}

	// Serialize message to JSON.
	data, err := json.Marshal(msg)
	if err != nil {
		tt.mu.RUnlock()
		tt.metrics.Errors++
		return fmt.Errorf("failed to marshal message: %w", err)
	}

	// Add length prefix for framing.
	frame := fmt.Sprintf("%d\n%s", len(data), data)
	frameBytes := []byte(frame)

	// Send to all connections with retry.
	var sendErr error
	connectionCount := len(tt.connections)
	tt.mu.RUnlock()

	if connectionCount == 0 {
		tt.metrics.Errors++
		return fmt.Errorf("no active connections")
	}

	tt.mu.RLock()
	for connID, conn := range tt.connections {
		if err := tt.sendWithRetry(ctx, conn, frameBytes); err != nil {
			sendErr = fmt.Errorf("failed to send to connection %s: %w", connID, err)
			// Remove failed connection asynchronously: removeConnection needs
			// the write lock, which we cannot take while holding the read lock.
			go tt.removeConnection(connID)
		}
	}
	tt.mu.RUnlock()

	if sendErr == nil {
		tt.updateSendMetrics(msg, time.Since(start))
	} else {
		tt.metrics.Errors++
	}

	return sendErr
}
|
||||
|
||||
// Receive returns a channel for receiving messages
|
||||
func (tt *TCPTransport) Receive(ctx context.Context) (<-chan *Message, error) {
|
||||
tt.mu.RLock()
|
||||
defer tt.mu.RUnlock()
|
||||
|
||||
if !tt.connected {
|
||||
return nil, fmt.Errorf("transport not connected")
|
||||
}
|
||||
|
||||
return tt.receiveChan, nil
|
||||
}
|
||||
|
||||
// Health returns the health status of the transport
|
||||
func (tt *TCPTransport) Health() ComponentHealth {
|
||||
tt.mu.RLock()
|
||||
defer tt.mu.RUnlock()
|
||||
|
||||
status := "unhealthy"
|
||||
var responseTime time.Duration
|
||||
|
||||
if tt.connected {
|
||||
if len(tt.connections) > 0 {
|
||||
status = "healthy"
|
||||
responseTime = time.Millisecond * 10 // Estimate for TCP
|
||||
} else {
|
||||
status = "degraded" // Connected but no active connections
|
||||
}
|
||||
}
|
||||
|
||||
return ComponentHealth{
|
||||
Status: status,
|
||||
LastCheck: time.Now(),
|
||||
ResponseTime: responseTime,
|
||||
ErrorCount: tt.metrics.Errors,
|
||||
}
|
||||
}
|
||||
|
||||
// GetMetrics returns a snapshot copy of the transport metrics. Connections
// is derived from the live connection map rather than the stored counter.
func (tt *TCPTransport) GetMetrics() TransportMetrics {
	tt.mu.RLock()
	defer tt.mu.RUnlock()

	return TransportMetrics{
		BytesSent:        tt.metrics.BytesSent,
		BytesReceived:    tt.metrics.BytesReceived,
		MessagesSent:     tt.metrics.MessagesSent,
		MessagesReceived: tt.metrics.MessagesReceived,
		Connections:      len(tt.connections),
		Errors:           tt.metrics.Errors,
		Latency:          tt.metrics.Latency,
	}
}
|
||||
|
||||
// Private helper methods
|
||||
|
||||
// startServer begins listening on the configured address, using TLS when a
// TLS config is set, and spawns the accept loop. Callers must hold tt.mu
// (Connect does).
func (tt *TCPTransport) startServer() error {
	addr := fmt.Sprintf("%s:%d", tt.address, tt.port)

	var listener net.Listener
	var err error

	if tt.tlsConfig != nil {
		listener, err = tls.Listen("tcp", addr, tt.tlsConfig)
	} else {
		listener, err = net.Listen("tcp", addr)
	}

	if err != nil {
		return fmt.Errorf("failed to listen on %s: %w", addr, err)
	}

	tt.listener = listener
	tt.connected = true

	// Accept connections until the transport context is cancelled.
	go tt.acceptConnections()

	return nil
}
|
||||
|
||||
func (tt *TCPTransport) connectToServer(ctx context.Context) error {
|
||||
addr := net.JoinHostPort(tt.address, fmt.Sprintf("%d", tt.port))
|
||||
|
||||
var conn net.Conn
|
||||
var err error
|
||||
|
||||
// Retry connection with exponential backoff
|
||||
delay := tt.retryConfig.InitialDelay
|
||||
for attempt := 0; attempt <= tt.retryConfig.MaxRetries; attempt++ {
|
||||
if tt.tlsConfig != nil {
|
||||
conn, err = tls.Dial("tcp", addr, tt.tlsConfig)
|
||||
} else {
|
||||
conn, err = net.Dial("tcp", addr)
|
||||
}
|
||||
|
||||
if err == nil {
|
||||
break
|
||||
}
|
||||
|
||||
if attempt == tt.retryConfig.MaxRetries {
|
||||
return fmt.Errorf("failed to connect to %s after %d attempts: %w", addr, attempt+1, err)
|
||||
}
|
||||
|
||||
// Wait with exponential backoff
|
||||
select {
|
||||
case <-time.After(delay):
|
||||
delay = time.Duration(float64(delay) * tt.retryConfig.BackoffFactor)
|
||||
if delay > tt.retryConfig.MaxDelay {
|
||||
delay = tt.retryConfig.MaxDelay
|
||||
}
|
||||
// Add jitter if enabled
|
||||
if tt.retryConfig.Jitter {
|
||||
jitter := time.Duration(float64(delay) * 0.1)
|
||||
jitterFactor := float64(2*time.Now().UnixNano()%1000)/1000.0 - 1
|
||||
delay += time.Duration(float64(jitter) * jitterFactor)
|
||||
}
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
}
|
||||
}
|
||||
|
||||
connID := fmt.Sprintf("client_%d", time.Now().UnixNano())
|
||||
tt.connections[connID] = conn
|
||||
tt.connected = true
|
||||
tt.metrics.Connections = 1
|
||||
|
||||
// Start receiving from server
|
||||
go tt.handleConnection(connID, conn)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (tt *TCPTransport) acceptConnections() {
|
||||
for {
|
||||
select {
|
||||
case <-tt.ctx.Done():
|
||||
return
|
||||
default:
|
||||
conn, err := tt.listener.Accept()
|
||||
if err != nil {
|
||||
if tt.ctx.Err() != nil {
|
||||
return // Context cancelled
|
||||
}
|
||||
tt.metrics.Errors++
|
||||
continue
|
||||
}
|
||||
|
||||
connID := fmt.Sprintf("server_%d", time.Now().UnixNano())
|
||||
|
||||
tt.mu.Lock()
|
||||
tt.connections[connID] = conn
|
||||
tt.metrics.Connections = len(tt.connections)
|
||||
tt.mu.Unlock()
|
||||
|
||||
go tt.handleConnection(connID, conn)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// handleConnection reads length-prefixed frames from conn, reassembles
// them across partial reads, and delivers decoded messages on receiveChan.
// The connection is deregistered (and closed) when this goroutine exits.
func (tt *TCPTransport) handleConnection(connID string, conn net.Conn) {
	defer tt.removeConnection(connID)

	// Fixed read buffer plus an accumulation buffer for partial frames.
	buffer := make([]byte, 4096)
	var messageBuffer []byte

	for {
		select {
		case <-tt.ctx.Done():
			return
		default:
			// Bound each Read so the ctx.Done check above is revisited
			// at least every 30 seconds.
			conn.SetReadDeadline(time.Now().Add(30 * time.Second))
			n, err := conn.Read(buffer)
			if err != nil {
				if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
					continue // Continue on timeout
				}
				return // Connection closed or error
			}

			messageBuffer = append(messageBuffer, buffer[:n]...)

			// Process complete messages
			for {
				msg, remaining, err := ExtractMessage(messageBuffer)
				if err != nil {
					return // Invalid message format
				}
				if msg == nil {
					break // No complete message yet
				}

				// Deliver message
				select {
				case tt.receiveChan <- msg:
					tt.updateReceiveMetrics(msg)
				case <-tt.ctx.Done():
					return
				default:
					// Channel full, drop message
					// NOTE(review): Errors++ here is not lock-protected —
					// a data race with locked writers; confirm and fix.
					tt.metrics.Errors++
				}

				messageBuffer = remaining
			}
		}
	}
}
|
||||
|
||||
func (tt *TCPTransport) sendWithRetry(ctx context.Context, conn net.Conn, data []byte) error {
|
||||
delay := tt.retryConfig.InitialDelay
|
||||
|
||||
for attempt := 0; attempt <= tt.retryConfig.MaxRetries; attempt++ {
|
||||
conn.SetWriteDeadline(time.Now().Add(10 * time.Second))
|
||||
_, err := conn.Write(data)
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if attempt == tt.retryConfig.MaxRetries {
|
||||
return err
|
||||
}
|
||||
|
||||
// Wait with exponential backoff
|
||||
select {
|
||||
case <-time.After(delay):
|
||||
delay = time.Duration(float64(delay) * tt.retryConfig.BackoffFactor)
|
||||
if delay > tt.retryConfig.MaxDelay {
|
||||
delay = tt.retryConfig.MaxDelay
|
||||
}
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
}
|
||||
}
|
||||
|
||||
return fmt.Errorf("max retries exceeded")
|
||||
}
|
||||
|
||||
func (tt *TCPTransport) removeConnection(connID string) {
|
||||
tt.mu.Lock()
|
||||
defer tt.mu.Unlock()
|
||||
|
||||
if conn, exists := tt.connections[connID]; exists {
|
||||
conn.Close()
|
||||
delete(tt.connections, connID)
|
||||
tt.metrics.Connections = len(tt.connections)
|
||||
}
|
||||
}
|
||||
|
||||
func (tt *TCPTransport) updateSendMetrics(msg *Message, latency time.Duration) {
|
||||
tt.mu.Lock()
|
||||
defer tt.mu.Unlock()
|
||||
|
||||
tt.metrics.MessagesSent++
|
||||
tt.metrics.Latency = latency
|
||||
|
||||
// Estimate message size
|
||||
messageSize := int64(len(msg.ID) + len(msg.Topic) + len(msg.Source))
|
||||
if msg.Data != nil {
|
||||
messageSize += int64(len(fmt.Sprintf("%v", msg.Data)))
|
||||
}
|
||||
tt.metrics.BytesSent += messageSize
|
||||
}
|
||||
|
||||
func (tt *TCPTransport) updateReceiveMetrics(msg *Message) {
|
||||
tt.mu.Lock()
|
||||
defer tt.mu.Unlock()
|
||||
|
||||
tt.metrics.MessagesReceived++
|
||||
|
||||
// Estimate message size
|
||||
messageSize := int64(len(msg.ID) + len(msg.Topic) + len(msg.Source))
|
||||
if msg.Data != nil {
|
||||
messageSize += int64(len(fmt.Sprintf("%v", msg.Data)))
|
||||
}
|
||||
tt.metrics.BytesReceived += messageSize
|
||||
}
|
||||
|
||||
// GetAddress returns the transport address (for testing/debugging)
|
||||
func (tt *TCPTransport) GetAddress() string {
|
||||
return fmt.Sprintf("%s:%d", tt.address, tt.port)
|
||||
}
|
||||
|
||||
// GetConnectionCount returns the number of active connections
|
||||
func (tt *TCPTransport) GetConnectionCount() int {
|
||||
tt.mu.RLock()
|
||||
defer tt.mu.RUnlock()
|
||||
return len(tt.connections)
|
||||
}
|
||||
|
||||
// IsSecure returns whether TLS is enabled
|
||||
func (tt *TCPTransport) IsSecure() bool {
|
||||
return tt.tlsConfig != nil
|
||||
}
|
||||
316
orig/pkg/transport/unified_provider_manager.go
Normal file
316
orig/pkg/transport/unified_provider_manager.go
Normal file
@@ -0,0 +1,316 @@
|
||||
package transport
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
|
||||
"github.com/ethereum/go-ethereum/ethclient"
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
// UnifiedProviderManager manages all provider pools (read-only, execution, testing)
|
||||
// UnifiedProviderManager manages all provider pools (read-only, execution, testing)
type UnifiedProviderManager struct {
	ReadOnlyPool    *ReadOnlyProviderPool     // pool for read-only RPC traffic; nil when not configured
	ExecutionPool   *ExecutionProviderPool    // pool for transaction execution; nil when not configured
	TestingPool     *TestingProviderPool      // pool for testing (e.g. Anvil forks); nil when not configured
	config          ProvidersConfig           // full configuration loaded at construction time
	providerConfigs map[string]ProviderConfig // provider lookup by name, derived from config.Providers
}
|
||||
|
||||
// OperationMode defines the type of operation being performed
|
||||
// OperationMode defines the type of operation being performed
type OperationMode int

const (
	ModeReadOnly  OperationMode = iota // routed through the read-only pool
	ModeExecution                      // routed through the execution pool
	ModeTesting                        // routed through the testing pool
)
|
||||
|
||||
// NewUnifiedProviderManager creates a new unified provider manager
|
||||
func NewUnifiedProviderManager(configPath string) (*UnifiedProviderManager, error) {
|
||||
// Load configuration
|
||||
config, err := LoadProvidersConfig(configPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to load provider config: %w", err)
|
||||
}
|
||||
|
||||
// Create provider configs map for easy lookup
|
||||
providerConfigs := make(map[string]ProviderConfig)
|
||||
for _, provider := range config.Providers {
|
||||
providerConfigs[provider.Name] = provider
|
||||
}
|
||||
|
||||
manager := &UnifiedProviderManager{
|
||||
config: config,
|
||||
providerConfigs: providerConfigs,
|
||||
}
|
||||
|
||||
// Initialize provider pools
|
||||
if err := manager.initializePools(); err != nil {
|
||||
return nil, fmt.Errorf("failed to initialize provider pools: %w", err)
|
||||
}
|
||||
|
||||
return manager, nil
|
||||
}
|
||||
|
||||
// initializePools initializes all provider pools based on configuration
|
||||
func (upm *UnifiedProviderManager) initializePools() error {
|
||||
var err error
|
||||
|
||||
// Initialize read-only pool if configured
|
||||
if poolConfig, exists := upm.config.ProviderPools["read_only"]; exists {
|
||||
upm.ReadOnlyPool, err = NewReadOnlyProviderPool(poolConfig, upm.providerConfigs)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to initialize read-only pool: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Initialize execution pool if configured
|
||||
if poolConfig, exists := upm.config.ProviderPools["execution"]; exists {
|
||||
upm.ExecutionPool, err = NewExecutionProviderPool(poolConfig, upm.providerConfigs)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to initialize execution pool: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Initialize testing pool if configured
|
||||
if poolConfig, exists := upm.config.ProviderPools["testing"]; exists {
|
||||
upm.TestingPool, err = NewTestingProviderPool(poolConfig, upm.providerConfigs)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to initialize testing pool: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetPoolForMode returns the appropriate provider pool for the given operation mode
|
||||
func (upm *UnifiedProviderManager) GetPoolForMode(mode OperationMode) (ProviderPool, error) {
|
||||
switch mode {
|
||||
case ModeReadOnly:
|
||||
if upm.ReadOnlyPool == nil {
|
||||
return nil, fmt.Errorf("read-only pool not initialized")
|
||||
}
|
||||
return upm.ReadOnlyPool, nil
|
||||
case ModeExecution:
|
||||
if upm.ExecutionPool == nil {
|
||||
return nil, fmt.Errorf("execution pool not initialized")
|
||||
}
|
||||
return upm.ExecutionPool, nil
|
||||
case ModeTesting:
|
||||
if upm.TestingPool == nil {
|
||||
return nil, fmt.Errorf("testing pool not initialized")
|
||||
}
|
||||
return upm.TestingPool, nil
|
||||
default:
|
||||
return nil, fmt.Errorf("unknown operation mode: %d", mode)
|
||||
}
|
||||
}
|
||||
|
||||
// GetReadOnlyHTTPClient returns an HTTP client optimized for read-only operations
|
||||
func (upm *UnifiedProviderManager) GetReadOnlyHTTPClient() (*ethclient.Client, error) {
|
||||
if upm.ReadOnlyPool == nil {
|
||||
return nil, fmt.Errorf("read-only pool not initialized")
|
||||
}
|
||||
return upm.ReadOnlyPool.GetHTTPClient()
|
||||
}
|
||||
|
||||
// GetReadOnlyWSClient returns a WebSocket client for real-time data
|
||||
func (upm *UnifiedProviderManager) GetReadOnlyWSClient() (*ethclient.Client, error) {
|
||||
if upm.ReadOnlyPool == nil {
|
||||
return nil, fmt.Errorf("read-only pool not initialized")
|
||||
}
|
||||
return upm.ReadOnlyPool.GetWSClient()
|
||||
}
|
||||
|
||||
// GetExecutionHTTPClient returns an HTTP client optimized for transaction execution
|
||||
func (upm *UnifiedProviderManager) GetExecutionHTTPClient() (*ethclient.Client, error) {
|
||||
if upm.ExecutionPool == nil {
|
||||
return nil, fmt.Errorf("execution pool not initialized")
|
||||
}
|
||||
return upm.ExecutionPool.GetHTTPClient()
|
||||
}
|
||||
|
||||
// GetTestingHTTPClient returns an HTTP client for testing (preferably Anvil)
|
||||
func (upm *UnifiedProviderManager) GetTestingHTTPClient() (*ethclient.Client, error) {
|
||||
if upm.TestingPool == nil {
|
||||
return nil, fmt.Errorf("testing pool not initialized")
|
||||
}
|
||||
return upm.TestingPool.GetHTTPClient()
|
||||
}
|
||||
|
||||
// GetAllStats returns statistics for all provider pools
|
||||
func (upm *UnifiedProviderManager) GetAllStats() map[string]interface{} {
|
||||
stats := make(map[string]interface{})
|
||||
|
||||
if upm.ReadOnlyPool != nil {
|
||||
stats["read_only"] = upm.ReadOnlyPool.GetStats()
|
||||
}
|
||||
|
||||
if upm.ExecutionPool != nil {
|
||||
stats["execution"] = upm.ExecutionPool.GetStats()
|
||||
}
|
||||
|
||||
if upm.TestingPool != nil {
|
||||
stats["testing"] = upm.TestingPool.GetStats()
|
||||
}
|
||||
|
||||
// Add overall summary
|
||||
summary := map[string]interface{}{
|
||||
"total_pools": len(stats),
|
||||
"pools_initialized": []string{},
|
||||
}
|
||||
|
||||
for poolName := range stats {
|
||||
summary["pools_initialized"] = append(summary["pools_initialized"].([]string), poolName)
|
||||
}
|
||||
|
||||
stats["summary"] = summary
|
||||
|
||||
return stats
|
||||
}
|
||||
|
||||
// CreateTestingSnapshot creates a snapshot in the testing environment
|
||||
func (upm *UnifiedProviderManager) CreateTestingSnapshot() (string, error) {
|
||||
if upm.TestingPool == nil {
|
||||
return "", fmt.Errorf("testing pool not initialized")
|
||||
}
|
||||
return upm.TestingPool.CreateSnapshot()
|
||||
}
|
||||
|
||||
// RevertTestingSnapshot reverts to a snapshot in the testing environment
|
||||
func (upm *UnifiedProviderManager) RevertTestingSnapshot(snapshotID string) error {
|
||||
if upm.TestingPool == nil {
|
||||
return fmt.Errorf("testing pool not initialized")
|
||||
}
|
||||
return upm.TestingPool.RevertToSnapshot(snapshotID)
|
||||
}
|
||||
|
||||
// Close shuts down all provider pools
|
||||
func (upm *UnifiedProviderManager) Close() error {
|
||||
var errors []error
|
||||
|
||||
if upm.ReadOnlyPool != nil {
|
||||
if err := upm.ReadOnlyPool.Close(); err != nil {
|
||||
errors = append(errors, fmt.Errorf("failed to close read-only pool: %w", err))
|
||||
}
|
||||
}
|
||||
|
||||
if upm.ExecutionPool != nil {
|
||||
if err := upm.ExecutionPool.Close(); err != nil {
|
||||
errors = append(errors, fmt.Errorf("failed to close execution pool: %w", err))
|
||||
}
|
||||
}
|
||||
|
||||
if upm.TestingPool != nil {
|
||||
if err := upm.TestingPool.Close(); err != nil {
|
||||
errors = append(errors, fmt.Errorf("failed to close testing pool: %w", err))
|
||||
}
|
||||
}
|
||||
|
||||
if len(errors) > 0 {
|
||||
return fmt.Errorf("errors closing provider pools: %v", errors)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// LoadProvidersConfigFromFile loads configuration from a YAML file
|
||||
func LoadProvidersConfigFromFile(path string) (ProvidersConfig, error) {
|
||||
var config ProvidersConfig
|
||||
|
||||
// Read the YAML file
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return config, fmt.Errorf("failed to read config file %s: %w", path, err)
|
||||
}
|
||||
|
||||
// Unmarshal the YAML data
|
||||
if err := yaml.Unmarshal(data, &config); err != nil {
|
||||
return config, fmt.Errorf("failed to parse YAML config: %w", err)
|
||||
}
|
||||
|
||||
// Validate the configuration
|
||||
if err := validateProvidersConfig(&config); err != nil {
|
||||
return config, fmt.Errorf("invalid configuration: %w", err)
|
||||
}
|
||||
|
||||
return config, nil
|
||||
}
|
||||
|
||||
// validateProvidersConfig validates the provider configuration
|
||||
func validateProvidersConfig(config *ProvidersConfig) error {
|
||||
if len(config.Providers) == 0 {
|
||||
return fmt.Errorf("no providers configured")
|
||||
}
|
||||
|
||||
// Validate provider pools
|
||||
for poolName, poolConfig := range config.ProviderPools {
|
||||
if len(poolConfig.Providers) == 0 {
|
||||
return fmt.Errorf("provider pool '%s' has no providers", poolName)
|
||||
}
|
||||
|
||||
// Check that all referenced providers exist
|
||||
providerNames := make(map[string]bool)
|
||||
for _, provider := range config.Providers {
|
||||
providerNames[provider.Name] = true
|
||||
}
|
||||
|
||||
for _, providerName := range poolConfig.Providers {
|
||||
if !providerNames[providerName] {
|
||||
return fmt.Errorf("provider pool '%s' references unknown provider '%s'", poolName, providerName)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Validate individual providers
|
||||
for i, provider := range config.Providers {
|
||||
if provider.Name == "" {
|
||||
return fmt.Errorf("provider %d has no name", i)
|
||||
}
|
||||
if provider.HTTPEndpoint == "" && provider.WSEndpoint == "" {
|
||||
return fmt.Errorf("provider %s has no endpoints", provider.Name)
|
||||
}
|
||||
if provider.RateLimit.RequestsPerSecond <= 0 {
|
||||
return fmt.Errorf("provider %s has invalid rate limit", provider.Name)
|
||||
}
|
||||
|
||||
// Validate Anvil config if present
|
||||
if provider.Type == "anvil_fork" && provider.AnvilConfig == nil {
|
||||
return fmt.Errorf("provider %s is anvil_fork type but has no anvil_config", provider.Name)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetProviderByName returns a specific provider configuration by name
|
||||
func (upm *UnifiedProviderManager) GetProviderByName(name string) (ProviderConfig, bool) {
|
||||
config, exists := upm.providerConfigs[name]
|
||||
return config, exists
|
||||
}
|
||||
|
||||
// GetProvidersByType returns all providers of a specific type
|
||||
func (upm *UnifiedProviderManager) GetProvidersByType(providerType string) []ProviderConfig {
|
||||
var providers []ProviderConfig
|
||||
for _, provider := range upm.config.Providers {
|
||||
if provider.Type == providerType {
|
||||
providers = append(providers, provider)
|
||||
}
|
||||
}
|
||||
return providers
|
||||
}
|
||||
|
||||
// GetProvidersByFeature returns all providers that support a specific feature
|
||||
func (upm *UnifiedProviderManager) GetProvidersByFeature(feature string) []ProviderConfig {
|
||||
var providers []ProviderConfig
|
||||
for _, provider := range upm.config.Providers {
|
||||
for _, providerFeature := range provider.Features {
|
||||
if providerFeature == feature {
|
||||
providers = append(providers, provider)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
return providers
|
||||
}
|
||||
359
orig/pkg/transport/unix_transport.go
Normal file
359
orig/pkg/transport/unix_transport.go
Normal file
@@ -0,0 +1,359 @@
|
||||
package transport
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// UnixSocketTransport implements Unix socket transport for local IPC
|
||||
// UnixSocketTransport implements Unix socket transport for local IPC
type UnixSocketTransport struct {
	socketPath  string              // filesystem path of the unix socket
	listener    net.Listener        // server-side listener; nil in client mode
	connections map[string]net.Conn // active connections keyed by generated ID
	metrics     TransportMetrics    // counters; guarded by mu
	connected   bool                // true after a successful Connect
	isServer    bool                // listen (true) vs. dial (false)
	receiveChan chan *Message       // buffered delivery channel for decoded messages
	mu          sync.RWMutex        // guards connections, metrics, connected
	ctx         context.Context     // lifetime of background goroutines
	cancel      context.CancelFunc  // stops background goroutines on Disconnect
}
|
||||
|
||||
// NewUnixSocketTransport creates a new Unix socket transport
|
||||
func NewUnixSocketTransport(socketPath string, isServer bool) *UnixSocketTransport {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
return &UnixSocketTransport{
|
||||
socketPath: socketPath,
|
||||
connections: make(map[string]net.Conn),
|
||||
metrics: TransportMetrics{},
|
||||
isServer: isServer,
|
||||
receiveChan: make(chan *Message, 1000),
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
}
|
||||
}
|
||||
|
||||
// Connect establishes the Unix socket connection
|
||||
func (ut *UnixSocketTransport) Connect(ctx context.Context) error {
|
||||
ut.mu.Lock()
|
||||
defer ut.mu.Unlock()
|
||||
|
||||
if ut.connected {
|
||||
return nil
|
||||
}
|
||||
|
||||
if ut.isServer {
|
||||
return ut.startServer()
|
||||
} else {
|
||||
return ut.connectToServer()
|
||||
}
|
||||
}
|
||||
|
||||
// Disconnect closes the Unix socket connection
|
||||
// Disconnect stops background goroutines, closes the listener and all
// connections, removes the socket file (server mode), and closes the
// receive channel. Idempotent once disconnected.
func (ut *UnixSocketTransport) Disconnect(ctx context.Context) error {
	ut.mu.Lock()
	defer ut.mu.Unlock()

	if !ut.connected {
		return nil
	}

	// Signal all reader/accept goroutines to stop.
	ut.cancel()

	if ut.isServer && ut.listener != nil {
		ut.listener.Close()
		// Remove socket file
		os.Remove(ut.socketPath)
	}

	// Close all connections
	for id, conn := range ut.connections {
		conn.Close()
		delete(ut.connections, id)
	}

	// NOTE(review): closing receiveChan here can race with handleConnection
	// goroutines that are still selecting on `receiveChan <- msg` (cancel
	// does not wait for them) — a potential send-on-closed-channel panic.
	// Consider dropping this close and relying on ctx cancellation.
	close(ut.receiveChan)
	ut.connected = false
	ut.metrics.Connections = 0

	return nil
}
|
||||
|
||||
// Send transmits a message through the Unix socket
|
||||
func (ut *UnixSocketTransport) Send(ctx context.Context, msg *Message) error {
|
||||
start := time.Now()
|
||||
|
||||
ut.mu.RLock()
|
||||
if !ut.connected {
|
||||
ut.mu.RUnlock()
|
||||
ut.metrics.Errors++
|
||||
return fmt.Errorf("transport not connected")
|
||||
}
|
||||
|
||||
// Serialize message
|
||||
data, err := json.Marshal(msg)
|
||||
if err != nil {
|
||||
ut.mu.RUnlock()
|
||||
ut.metrics.Errors++
|
||||
return fmt.Errorf("failed to marshal message: %w", err)
|
||||
}
|
||||
|
||||
// Add length prefix for framing
|
||||
frame := fmt.Sprintf("%d\n%s", len(data), data)
|
||||
frameBytes := []byte(frame)
|
||||
|
||||
// Send to all connections
|
||||
var sendErr error
|
||||
connectionCount := len(ut.connections)
|
||||
ut.mu.RUnlock()
|
||||
|
||||
if connectionCount == 0 {
|
||||
ut.metrics.Errors++
|
||||
return fmt.Errorf("no active connections")
|
||||
}
|
||||
|
||||
ut.mu.RLock()
|
||||
for connID, conn := range ut.connections {
|
||||
if err := ut.sendToConnection(conn, frameBytes); err != nil {
|
||||
sendErr = fmt.Errorf("failed to send to connection %s: %w", connID, err)
|
||||
// Remove failed connection
|
||||
go ut.removeConnection(connID)
|
||||
}
|
||||
}
|
||||
ut.mu.RUnlock()
|
||||
|
||||
if sendErr == nil {
|
||||
ut.updateSendMetrics(msg, time.Since(start))
|
||||
} else {
|
||||
ut.metrics.Errors++
|
||||
}
|
||||
|
||||
return sendErr
|
||||
}
|
||||
|
||||
// Receive returns a channel for receiving messages
|
||||
func (ut *UnixSocketTransport) Receive(ctx context.Context) (<-chan *Message, error) {
|
||||
ut.mu.RLock()
|
||||
defer ut.mu.RUnlock()
|
||||
|
||||
if !ut.connected {
|
||||
return nil, fmt.Errorf("transport not connected")
|
||||
}
|
||||
|
||||
return ut.receiveChan, nil
|
||||
}
|
||||
|
||||
// Health returns the health status of the transport
|
||||
func (ut *UnixSocketTransport) Health() ComponentHealth {
|
||||
ut.mu.RLock()
|
||||
defer ut.mu.RUnlock()
|
||||
|
||||
status := "unhealthy"
|
||||
if ut.connected {
|
||||
if len(ut.connections) > 0 {
|
||||
status = "healthy"
|
||||
} else {
|
||||
status = "degraded" // Connected but no active connections
|
||||
}
|
||||
}
|
||||
|
||||
return ComponentHealth{
|
||||
Status: status,
|
||||
LastCheck: time.Now(),
|
||||
ResponseTime: time.Millisecond, // Fast for local sockets
|
||||
ErrorCount: ut.metrics.Errors,
|
||||
}
|
||||
}
|
||||
|
||||
// GetMetrics returns transport-specific metrics
|
||||
func (ut *UnixSocketTransport) GetMetrics() TransportMetrics {
|
||||
ut.mu.RLock()
|
||||
defer ut.mu.RUnlock()
|
||||
|
||||
return TransportMetrics{
|
||||
BytesSent: ut.metrics.BytesSent,
|
||||
BytesReceived: ut.metrics.BytesReceived,
|
||||
MessagesSent: ut.metrics.MessagesSent,
|
||||
MessagesReceived: ut.metrics.MessagesReceived,
|
||||
Connections: len(ut.connections),
|
||||
Errors: ut.metrics.Errors,
|
||||
Latency: ut.metrics.Latency,
|
||||
}
|
||||
}
|
||||
|
||||
// Private helper methods
|
||||
|
||||
func (ut *UnixSocketTransport) startServer() error {
|
||||
// Remove existing socket file
|
||||
os.Remove(ut.socketPath)
|
||||
|
||||
listener, err := net.Listen("unix", ut.socketPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to listen on socket %s: %w", ut.socketPath, err)
|
||||
}
|
||||
|
||||
ut.listener = listener
|
||||
ut.connected = true
|
||||
|
||||
// Start accepting connections
|
||||
go ut.acceptConnections()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (ut *UnixSocketTransport) connectToServer() error {
|
||||
conn, err := net.Dial("unix", ut.socketPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to connect to socket %s: %w", ut.socketPath, err)
|
||||
}
|
||||
|
||||
connID := fmt.Sprintf("client_%d", time.Now().UnixNano())
|
||||
ut.connections[connID] = conn
|
||||
ut.connected = true
|
||||
ut.metrics.Connections = 1
|
||||
|
||||
// Start receiving from server
|
||||
go ut.handleConnection(connID, conn)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (ut *UnixSocketTransport) acceptConnections() {
|
||||
for {
|
||||
select {
|
||||
case <-ut.ctx.Done():
|
||||
return
|
||||
default:
|
||||
conn, err := ut.listener.Accept()
|
||||
if err != nil {
|
||||
if ut.ctx.Err() != nil {
|
||||
return // Context cancelled
|
||||
}
|
||||
ut.metrics.Errors++
|
||||
continue
|
||||
}
|
||||
|
||||
connID := fmt.Sprintf("server_%d", time.Now().UnixNano())
|
||||
|
||||
ut.mu.Lock()
|
||||
ut.connections[connID] = conn
|
||||
ut.metrics.Connections = len(ut.connections)
|
||||
ut.mu.Unlock()
|
||||
|
||||
go ut.handleConnection(connID, conn)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// handleConnection reads length-prefixed frames from conn, reassembles
// them across partial reads, and delivers decoded messages on receiveChan.
// The connection is deregistered (and closed) when this goroutine exits.
func (ut *UnixSocketTransport) handleConnection(connID string, conn net.Conn) {
	defer ut.removeConnection(connID)

	// Fixed read buffer plus an accumulation buffer for partial frames.
	buffer := make([]byte, 4096)
	var messageBuffer []byte

	for {
		select {
		case <-ut.ctx.Done():
			return
		default:
			// Short deadline so the ctx.Done check above is revisited
			// roughly once a second.
			conn.SetReadDeadline(time.Now().Add(time.Second)) // Non-blocking read
			n, err := conn.Read(buffer)
			if err != nil {
				if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
					continue // Continue on timeout
				}
				return // Connection closed or error
			}

			messageBuffer = append(messageBuffer, buffer[:n]...)

			// Process complete messages
			for {
				msg, remaining, err := ExtractMessage(messageBuffer)
				if err != nil {
					return // Invalid message format
				}
				if msg == nil {
					break // No complete message yet
				}

				// Deliver message
				select {
				case ut.receiveChan <- msg:
					ut.updateReceiveMetrics(msg)
				case <-ut.ctx.Done():
					return
				default:
					// Channel full, drop message
					// NOTE(review): Errors++ here is not lock-protected —
					// a data race with locked writers; confirm and fix.
					ut.metrics.Errors++
				}

				messageBuffer = remaining
			}
		}
	}
}
|
||||
|
||||
func (ut *UnixSocketTransport) sendToConnection(conn net.Conn, data []byte) error {
|
||||
conn.SetWriteDeadline(time.Now().Add(5 * time.Second))
|
||||
_, err := conn.Write(data)
|
||||
return err
|
||||
}
|
||||
|
||||
func (ut *UnixSocketTransport) removeConnection(connID string) {
|
||||
ut.mu.Lock()
|
||||
defer ut.mu.Unlock()
|
||||
|
||||
if conn, exists := ut.connections[connID]; exists {
|
||||
conn.Close()
|
||||
delete(ut.connections, connID)
|
||||
ut.metrics.Connections = len(ut.connections)
|
||||
}
|
||||
}
|
||||
|
||||
func (ut *UnixSocketTransport) updateSendMetrics(msg *Message, latency time.Duration) {
|
||||
ut.mu.Lock()
|
||||
defer ut.mu.Unlock()
|
||||
|
||||
ut.metrics.MessagesSent++
|
||||
ut.metrics.Latency = latency
|
||||
|
||||
// Estimate message size
|
||||
messageSize := int64(len(msg.ID) + len(msg.Topic) + len(msg.Source))
|
||||
if msg.Data != nil {
|
||||
messageSize += int64(len(fmt.Sprintf("%v", msg.Data)))
|
||||
}
|
||||
ut.metrics.BytesSent += messageSize
|
||||
}
|
||||
|
||||
func (ut *UnixSocketTransport) updateReceiveMetrics(msg *Message) {
|
||||
ut.mu.Lock()
|
||||
defer ut.mu.Unlock()
|
||||
|
||||
ut.metrics.MessagesReceived++
|
||||
|
||||
// Estimate message size
|
||||
messageSize := int64(len(msg.ID) + len(msg.Topic) + len(msg.Source))
|
||||
if msg.Data != nil {
|
||||
messageSize += int64(len(fmt.Sprintf("%v", msg.Data)))
|
||||
}
|
||||
ut.metrics.BytesReceived += messageSize
|
||||
}
|
||||
|
||||
// GetSocketPath returns the socket path (for testing/debugging)
|
||||
func (ut *UnixSocketTransport) GetSocketPath() string {
|
||||
return ut.socketPath
|
||||
}
|
||||
|
||||
// GetConnectionCount returns the number of active connections
|
||||
func (ut *UnixSocketTransport) GetConnectionCount() int {
|
||||
ut.mu.RLock()
|
||||
defer ut.mu.RUnlock()
|
||||
return len(ut.connections)
|
||||
}
|
||||
48
orig/pkg/transport/utils.go
Normal file
48
orig/pkg/transport/utils.go
Normal file
@@ -0,0 +1,48 @@
|
||||
package transport
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// ExtractMessage extracts a message from a byte buffer with length prefix format
|
||||
// Format: "length\nmessage_data"
|
||||
func ExtractMessage(buffer []byte) (*Message, []byte, error) {
|
||||
// Look for length prefix (format: "length\nmessage_data")
|
||||
newlineIndex := -1
|
||||
for i, b := range buffer {
|
||||
if b == '\n' {
|
||||
newlineIndex = i
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if newlineIndex == -1 {
|
||||
return nil, buffer, nil // No complete length prefix yet
|
||||
}
|
||||
|
||||
// Parse length
|
||||
lengthStr := string(buffer[:newlineIndex])
|
||||
var messageLength int
|
||||
if _, err := fmt.Sscanf(lengthStr, "%d", &messageLength); err != nil {
|
||||
return nil, nil, fmt.Errorf("invalid length prefix: %s", lengthStr)
|
||||
}
|
||||
|
||||
// Check if we have the complete message
|
||||
messageStart := newlineIndex + 1
|
||||
messageEnd := messageStart + messageLength
|
||||
if len(buffer) < messageEnd {
|
||||
return nil, buffer, nil // Incomplete message
|
||||
}
|
||||
|
||||
// Extract and parse message
|
||||
messageData := buffer[messageStart:messageEnd]
|
||||
var msg Message
|
||||
if err := json.Unmarshal(messageData, &msg); err != nil {
|
||||
return nil, nil, fmt.Errorf("failed to unmarshal message: %w", err)
|
||||
}
|
||||
|
||||
// Return message and remaining buffer
|
||||
remaining := buffer[messageEnd:]
|
||||
return &msg, remaining, nil
|
||||
}
|
||||
428
orig/pkg/transport/websocket_transport.go
Normal file
428
orig/pkg/transport/websocket_transport.go
Normal file
@@ -0,0 +1,428 @@
|
||||
package transport
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
)
|
||||
|
||||
// WebSocketTransport implements WebSocket transport for real-time monitoring
|
||||
// WebSocketTransport implements WebSocket transport for real-time monitoring
type WebSocketTransport struct {
	address      string                     // listen/dial host
	port         int                        // listen/dial port
	path         string                     // HTTP path the WebSocket endpoint is served/dialed on
	upgrader     websocket.Upgrader         // server-side HTTP->WebSocket upgrader
	connections  map[string]*websocket.Conn // active connections keyed by generated ID
	metrics      TransportMetrics           // counters; guarded by mu
	connected    bool                       // true after a successful Connect
	isServer     bool                       // serve (true) vs. dial (false)
	receiveChan  chan *Message              // buffered delivery channel for decoded messages
	server       *http.Server               // HTTP server; nil in client mode
	mu           sync.RWMutex               // guards connections, metrics, connected
	ctx          context.Context            // lifetime of background goroutines
	cancel       context.CancelFunc         // stops background goroutines on Disconnect
	pingInterval time.Duration              // interval between keep-alive pings
	pongTimeout  time.Duration              // how long to wait for a pong reply
}
|
||||
|
||||
// NewWebSocketTransport creates a new WebSocket transport
|
||||
func NewWebSocketTransport(address string, port int, path string, isServer bool) *WebSocketTransport {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
return &WebSocketTransport{
|
||||
address: address,
|
||||
port: port,
|
||||
path: path,
|
||||
connections: make(map[string]*websocket.Conn),
|
||||
metrics: TransportMetrics{},
|
||||
isServer: isServer,
|
||||
receiveChan: make(chan *Message, 1000),
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
upgrader: websocket.Upgrader{
|
||||
CheckOrigin: func(r *http.Request) bool {
|
||||
return true // Allow all origins for now
|
||||
},
|
||||
ReadBufferSize: 1024,
|
||||
WriteBufferSize: 1024,
|
||||
},
|
||||
pingInterval: 30 * time.Second,
|
||||
pongTimeout: 10 * time.Second,
|
||||
}
|
||||
}
|
||||
|
||||
// SetPingPongSettings configures WebSocket ping/pong settings
|
||||
func (wt *WebSocketTransport) SetPingPongSettings(pingInterval, pongTimeout time.Duration) {
|
||||
wt.pingInterval = pingInterval
|
||||
wt.pongTimeout = pongTimeout
|
||||
}
|
||||
|
||||
// Connect establishes the WebSocket connection
|
||||
func (wt *WebSocketTransport) Connect(ctx context.Context) error {
|
||||
wt.mu.Lock()
|
||||
defer wt.mu.Unlock()
|
||||
|
||||
if wt.connected {
|
||||
return nil
|
||||
}
|
||||
|
||||
if wt.isServer {
|
||||
return wt.startServer()
|
||||
} else {
|
||||
return wt.connectToServer(ctx)
|
||||
}
|
||||
}
|
||||
|
||||
// Disconnect closes the WebSocket connection
|
||||
// Disconnect stops background goroutines, shuts the HTTP server down
// (server mode, 5s grace), closes every connection, and closes the
// receive channel. Idempotent once disconnected.
func (wt *WebSocketTransport) Disconnect(ctx context.Context) error {
	wt.mu.Lock()
	defer wt.mu.Unlock()

	if !wt.connected {
		return nil
	}

	// Signal all reader/ping goroutines to stop.
	wt.cancel()

	if wt.isServer && wt.server != nil {
		shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
		defer cancel()
		// Shutdown error intentionally ignored: best-effort graceful stop.
		wt.server.Shutdown(shutdownCtx)
	}

	// Close all connections
	for id, conn := range wt.connections {
		conn.Close()
		delete(wt.connections, id)
	}

	// NOTE(review): closing receiveChan can race with reader goroutines
	// still selecting on `receiveChan <- msg` (cancel does not wait for
	// them) — a potential send-on-closed-channel panic. Consider dropping
	// this close and relying on ctx cancellation.
	close(wt.receiveChan)
	wt.connected = false
	wt.metrics.Connections = 0

	return nil
}
|
||||
|
||||
// Send transmits a message through WebSocket
|
||||
func (wt *WebSocketTransport) Send(ctx context.Context, msg *Message) error {
|
||||
start := time.Now()
|
||||
|
||||
wt.mu.RLock()
|
||||
if !wt.connected {
|
||||
wt.mu.RUnlock()
|
||||
wt.metrics.Errors++
|
||||
return fmt.Errorf("transport not connected")
|
||||
}
|
||||
|
||||
// Serialize message
|
||||
data, err := json.Marshal(msg)
|
||||
if err != nil {
|
||||
wt.mu.RUnlock()
|
||||
wt.metrics.Errors++
|
||||
return fmt.Errorf("failed to marshal message: %w", err)
|
||||
}
|
||||
|
||||
// Send to all connections
|
||||
var sendErr error
|
||||
connectionCount := len(wt.connections)
|
||||
wt.mu.RUnlock()
|
||||
|
||||
if connectionCount == 0 {
|
||||
wt.metrics.Errors++
|
||||
return fmt.Errorf("no active connections")
|
||||
}
|
||||
|
||||
wt.mu.RLock()
|
||||
for connID, conn := range wt.connections {
|
||||
if err := wt.sendToConnection(conn, data); err != nil {
|
||||
sendErr = fmt.Errorf("failed to send to connection %s: %w", connID, err)
|
||||
// Remove failed connection
|
||||
go wt.removeConnection(connID)
|
||||
}
|
||||
}
|
||||
wt.mu.RUnlock()
|
||||
|
||||
if sendErr == nil {
|
||||
wt.updateSendMetrics(msg, time.Since(start))
|
||||
} else {
|
||||
wt.metrics.Errors++
|
||||
}
|
||||
|
||||
return sendErr
|
||||
}
|
||||
|
||||
// Receive returns a channel for receiving messages
|
||||
func (wt *WebSocketTransport) Receive(ctx context.Context) (<-chan *Message, error) {
|
||||
wt.mu.RLock()
|
||||
defer wt.mu.RUnlock()
|
||||
|
||||
if !wt.connected {
|
||||
return nil, fmt.Errorf("transport not connected")
|
||||
}
|
||||
|
||||
return wt.receiveChan, nil
|
||||
}
|
||||
|
||||
// Health returns the health status of the transport
|
||||
func (wt *WebSocketTransport) Health() ComponentHealth {
|
||||
wt.mu.RLock()
|
||||
defer wt.mu.RUnlock()
|
||||
|
||||
status := "unhealthy"
|
||||
if wt.connected {
|
||||
if len(wt.connections) > 0 {
|
||||
status = "healthy"
|
||||
} else {
|
||||
status = "degraded" // Connected but no active connections
|
||||
}
|
||||
}
|
||||
|
||||
return ComponentHealth{
|
||||
Status: status,
|
||||
LastCheck: time.Now(),
|
||||
ResponseTime: time.Millisecond * 5, // Very fast for WebSocket
|
||||
ErrorCount: wt.metrics.Errors,
|
||||
}
|
||||
}
|
||||
|
||||
// GetMetrics returns transport-specific metrics
|
||||
func (wt *WebSocketTransport) GetMetrics() TransportMetrics {
|
||||
wt.mu.RLock()
|
||||
defer wt.mu.RUnlock()
|
||||
|
||||
return TransportMetrics{
|
||||
BytesSent: wt.metrics.BytesSent,
|
||||
BytesReceived: wt.metrics.BytesReceived,
|
||||
MessagesSent: wt.metrics.MessagesSent,
|
||||
MessagesReceived: wt.metrics.MessagesReceived,
|
||||
Connections: len(wt.connections),
|
||||
Errors: wt.metrics.Errors,
|
||||
Latency: wt.metrics.Latency,
|
||||
}
|
||||
}
|
||||
|
||||
// Private helper methods
|
||||
|
||||
func (wt *WebSocketTransport) startServer() error {
|
||||
mux := http.NewServeMux()
|
||||
mux.HandleFunc(wt.path, wt.handleWebSocket)
|
||||
|
||||
// Add health check endpoint
|
||||
mux.HandleFunc("/health", func(w http.ResponseWriter, r *http.Request) {
|
||||
health := wt.Health()
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(health)
|
||||
})
|
||||
|
||||
// Add metrics endpoint
|
||||
mux.HandleFunc("/metrics", func(w http.ResponseWriter, r *http.Request) {
|
||||
metrics := wt.GetMetrics()
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(metrics)
|
||||
})
|
||||
|
||||
addr := fmt.Sprintf("%s:%d", wt.address, wt.port)
|
||||
wt.server = &http.Server{
|
||||
Addr: addr,
|
||||
Handler: mux,
|
||||
ReadHeaderTimeout: 5 * time.Second, // Prevent Slowloris attacks
|
||||
ReadTimeout: 60 * time.Second,
|
||||
WriteTimeout: 60 * time.Second,
|
||||
IdleTimeout: 120 * time.Second,
|
||||
}
|
||||
|
||||
wt.connected = true
|
||||
|
||||
// Start server in goroutine
|
||||
go func() {
|
||||
if err := wt.server.ListenAndServe(); err != nil && err != http.ErrServerClosed {
|
||||
wt.metrics.Errors++
|
||||
}
|
||||
}()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (wt *WebSocketTransport) connectToServer(ctx context.Context) error {
|
||||
url := fmt.Sprintf("ws://%s:%d%s", wt.address, wt.port, wt.path)
|
||||
|
||||
conn, _, err := websocket.DefaultDialer.DialContext(ctx, url, nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to connect to WebSocket server: %w", err)
|
||||
}
|
||||
|
||||
connID := fmt.Sprintf("client_%d", time.Now().UnixNano())
|
||||
wt.connections[connID] = conn
|
||||
wt.connected = true
|
||||
wt.metrics.Connections = 1
|
||||
|
||||
// Start handling connection
|
||||
go wt.handleConnection(connID, conn)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (wt *WebSocketTransport) handleWebSocket(w http.ResponseWriter, r *http.Request) {
|
||||
conn, err := wt.upgrader.Upgrade(w, r, nil)
|
||||
if err != nil {
|
||||
wt.metrics.Errors++
|
||||
return
|
||||
}
|
||||
|
||||
connID := fmt.Sprintf("server_%d", time.Now().UnixNano())
|
||||
|
||||
wt.mu.Lock()
|
||||
wt.connections[connID] = conn
|
||||
wt.metrics.Connections = len(wt.connections)
|
||||
wt.mu.Unlock()
|
||||
|
||||
go wt.handleConnection(connID, conn)
|
||||
}
|
||||
|
||||
func (wt *WebSocketTransport) handleConnection(connID string, conn *websocket.Conn) {
|
||||
defer wt.removeConnection(connID)
|
||||
|
||||
// Set up ping/pong handling
|
||||
conn.SetReadDeadline(time.Now().Add(wt.pongTimeout))
|
||||
conn.SetPongHandler(func(string) error {
|
||||
conn.SetReadDeadline(time.Now().Add(wt.pongTimeout))
|
||||
return nil
|
||||
})
|
||||
|
||||
// Start ping routine
|
||||
go wt.pingRoutine(connID, conn)
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-wt.ctx.Done():
|
||||
return
|
||||
default:
|
||||
var msg Message
|
||||
err := conn.ReadJSON(&msg)
|
||||
if err != nil {
|
||||
if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) {
|
||||
wt.metrics.Errors++
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Deliver message
|
||||
select {
|
||||
case wt.receiveChan <- &msg:
|
||||
wt.updateReceiveMetrics(&msg)
|
||||
case <-wt.ctx.Done():
|
||||
return
|
||||
default:
|
||||
// Channel full, drop message
|
||||
wt.metrics.Errors++
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (wt *WebSocketTransport) pingRoutine(connID string, conn *websocket.Conn) {
|
||||
ticker := time.NewTicker(wt.pingInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
if err := conn.WriteControl(websocket.PingMessage, []byte{}, time.Now().Add(10*time.Second)); err != nil {
|
||||
return // Connection is likely closed
|
||||
}
|
||||
case <-wt.ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (wt *WebSocketTransport) sendToConnection(conn *websocket.Conn, data []byte) error {
|
||||
conn.SetWriteDeadline(time.Now().Add(10 * time.Second))
|
||||
return conn.WriteMessage(websocket.TextMessage, data)
|
||||
}
|
||||
|
||||
func (wt *WebSocketTransport) removeConnection(connID string) {
|
||||
wt.mu.Lock()
|
||||
defer wt.mu.Unlock()
|
||||
|
||||
if conn, exists := wt.connections[connID]; exists {
|
||||
conn.Close()
|
||||
delete(wt.connections, connID)
|
||||
wt.metrics.Connections = len(wt.connections)
|
||||
}
|
||||
}
|
||||
|
||||
func (wt *WebSocketTransport) updateSendMetrics(msg *Message, latency time.Duration) {
|
||||
wt.mu.Lock()
|
||||
defer wt.mu.Unlock()
|
||||
|
||||
wt.metrics.MessagesSent++
|
||||
wt.metrics.Latency = latency
|
||||
|
||||
// Estimate message size
|
||||
messageSize := int64(len(msg.ID) + len(msg.Topic) + len(msg.Source))
|
||||
if msg.Data != nil {
|
||||
messageSize += int64(len(fmt.Sprintf("%v", msg.Data)))
|
||||
}
|
||||
wt.metrics.BytesSent += messageSize
|
||||
}
|
||||
|
||||
func (wt *WebSocketTransport) updateReceiveMetrics(msg *Message) {
|
||||
wt.mu.Lock()
|
||||
defer wt.mu.Unlock()
|
||||
|
||||
wt.metrics.MessagesReceived++
|
||||
|
||||
// Estimate message size
|
||||
messageSize := int64(len(msg.ID) + len(msg.Topic) + len(msg.Source))
|
||||
if msg.Data != nil {
|
||||
messageSize += int64(len(fmt.Sprintf("%v", msg.Data)))
|
||||
}
|
||||
wt.metrics.BytesReceived += messageSize
|
||||
}
|
||||
|
||||
// Broadcast sends a message to all connected clients (server mode only)
|
||||
func (wt *WebSocketTransport) Broadcast(ctx context.Context, msg *Message) error {
|
||||
if !wt.isServer {
|
||||
return fmt.Errorf("broadcast only available in server mode")
|
||||
}
|
||||
|
||||
return wt.Send(ctx, msg)
|
||||
}
|
||||
|
||||
// GetURL returns the WebSocket URL (for testing/debugging)
|
||||
func (wt *WebSocketTransport) GetURL() string {
|
||||
return fmt.Sprintf("ws://%s:%d%s", wt.address, wt.port, wt.path)
|
||||
}
|
||||
|
||||
// GetConnectionCount returns the number of active connections
|
||||
func (wt *WebSocketTransport) GetConnectionCount() int {
|
||||
wt.mu.RLock()
|
||||
defer wt.mu.RUnlock()
|
||||
return len(wt.connections)
|
||||
}
|
||||
|
||||
// SetAllowedOrigins configures CORS for WebSocket connections
|
||||
func (wt *WebSocketTransport) SetAllowedOrigins(origins []string) {
|
||||
if len(origins) == 0 {
|
||||
wt.upgrader.CheckOrigin = func(r *http.Request) bool {
|
||||
return true // Allow all origins
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
originMap := make(map[string]bool)
|
||||
for _, origin := range origins {
|
||||
originMap[origin] = true
|
||||
}
|
||||
|
||||
wt.upgrader.CheckOrigin = func(r *http.Request) bool {
|
||||
origin := r.Header.Get("Origin")
|
||||
return originMap[origin]
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user