diff --git a/pkg/arbitrage/service_simple.go b/pkg/arbitrage/service.go similarity index 77% rename from pkg/arbitrage/service_simple.go rename to pkg/arbitrage/service.go index 52a8e3b..73d6141 100644 --- a/pkg/arbitrage/service_simple.go +++ b/pkg/arbitrage/service.go @@ -16,9 +16,11 @@ import ( "github.com/fraktal/mev-beta/internal/ratelimit" "github.com/fraktal/mev-beta/pkg/contracts" "github.com/fraktal/mev-beta/pkg/market" + "github.com/fraktal/mev-beta/pkg/marketmanager" "github.com/fraktal/mev-beta/pkg/monitor" "github.com/fraktal/mev-beta/pkg/scanner" "github.com/fraktal/mev-beta/pkg/security" + "github.com/holiman/uint256" ) // TokenPair represents the two tokens in a pool @@ -59,8 +61,8 @@ type ArbitrageDatabase interface { GetPoolData(ctx context.Context, poolAddress common.Address) (*SimplePoolData, error) } -// SimpleArbitrageService is a simplified arbitrage service without circular dependencies -type SimpleArbitrageService struct { +// ArbitrageService is a sophisticated arbitrage service with comprehensive MEV detection +type ArbitrageService struct { client *ethclient.Client logger *logger.Logger config *config.ArbitrageConfig @@ -70,6 +72,10 @@ type SimpleArbitrageService struct { multiHopScanner *MultiHopScanner executor *ArbitrageExecutor + // Market management + marketManager *market.MarketManager + marketDataManager *marketmanager.MarketManager + // Token cache for pool addresses tokenCache map[common.Address]TokenPair tokenCacheMutex sync.RWMutex @@ -119,14 +125,14 @@ type SimplePoolData struct { LastUpdated time.Time } -// NewSimpleArbitrageService creates a new simplified arbitrage service -func NewSimpleArbitrageService( +// NewArbitrageService creates a new sophisticated arbitrage service +func NewArbitrageService( client *ethclient.Client, logger *logger.Logger, config *config.ArbitrageConfig, keyManager *security.KeyManager, database ArbitrageDatabase, -) (*SimpleArbitrageService, error) { +) (*ArbitrageService, error) { ctx, cancel := context.WithCancel(context.Background()) @@ -146,31 +152,174 @@ func NewSimpleArbitrageService( return nil, fmt.Errorf("failed to create arbitrage executor: %w", err) } + // Initialize market manager with nil config for now (can be enhanced later) + var marketManager *market.MarketManager = nil + logger.Info("Market manager initialization deferred to avoid circular dependencies") + + // Initialize new market manager + marketDataManagerConfig := &marketmanager.MarketManagerConfig{ + VerificationWindow: 500 * time.Millisecond, + MaxMarkets: 10000, + } + marketDataManager := marketmanager.NewMarketManager(marketDataManagerConfig) + // Initialize stats stats := &ArbitrageStats{ TotalProfitRealized: big.NewInt(0), TotalGasSpent: big.NewInt(0), } - service := &SimpleArbitrageService{ - client: client, - logger: logger, - config: config, - keyManager: keyManager, - multiHopScanner: multiHopScanner, - executor: executor, - ctx: ctx, - cancel: cancel, - stats: stats, - database: database, - tokenCache: make(map[common.Address]TokenPair), + service := &ArbitrageService{ + client: client, + logger: logger, + config: config, + keyManager: keyManager, + multiHopScanner: multiHopScanner, + executor: executor, + marketManager: marketManager, + marketDataManager: marketDataManager, + ctx: ctx, + cancel: cancel, + stats: stats, + database: database, + tokenCache: make(map[common.Address]TokenPair), } return service, nil } +// convertPoolDataToMarket converts existing PoolData to marketmanager.Market +func (sas *ArbitrageService) convertPoolDataToMarket(poolData *market.PoolData, protocol string) *marketmanager.Market { + // Create raw ticker from token addresses + rawTicker := fmt.Sprintf("%s_%s", poolData.Token0.Hex(), poolData.Token1.Hex()) + + // Create ticker (using token symbols would require token registry) + ticker := fmt.Sprintf("TOKEN0_TOKEN1") // Placeholder - would need token symbol lookup in real implementation + + // Convert uint256 values to big.Int/big.Float + liquidity := new(big.Int) + if poolData.Liquidity != nil { + liquidity.Set(poolData.Liquidity.ToBig()) + } + + sqrtPriceX96 := new(big.Int) + if poolData.SqrtPriceX96 != nil { + sqrtPriceX96.Set(poolData.SqrtPriceX96.ToBig()) + } + + // Calculate approximate price from sqrtPriceX96 + price := big.NewFloat(0) + if sqrtPriceX96.Sign() > 0 { + // Price = (sqrtPriceX96 / 2^96)^2 + // Convert to big.Float for precision + sqrtPriceFloat := new(big.Float).SetInt(sqrtPriceX96) + q96 := new(big.Float).SetInt(new(big.Int).Exp(big.NewInt(2), big.NewInt(96), nil)) + ratio := new(big.Float).Quo(sqrtPriceFloat, q96) + price.Mul(ratio, ratio) + } + + // Create market with converted data + marketObj := marketmanager.NewMarket( + common.HexToAddress("0x1F98431c8aD98523631AE4a59f267346ea31F984"), // Uniswap V3 Factory + poolData.Address, + poolData.Token0, + poolData.Token1, + uint32(poolData.Fee), + ticker, + rawTicker, + protocol, + ) + + // Update price and liquidity data + marketObj.UpdatePriceData( + price, + liquidity, + sqrtPriceX96, + int32(poolData.Tick), + ) + + // Update metadata + marketObj.UpdateMetadata( + time.Now().Unix(), + 0, // Block number would need to be fetched + common.Hash{}, // TxHash would need to be fetched + marketmanager.StatusConfirmed, + ) + + return marketObj +} + +// convertMarketToPoolData converts marketmanager.Market to PoolData +func (sas *ArbitrageService) convertMarketToPoolData(marketObj *marketmanager.Market) *market.PoolData { + // Convert big.Int to uint256.Int + liquidity := uint256.NewInt(0) + if marketObj.Liquidity != nil { + liquidity.SetFromBig(marketObj.Liquidity) + } + + sqrtPriceX96 := uint256.NewInt(0) + if marketObj.SqrtPriceX96 != nil { + sqrtPriceX96.SetFromBig(marketObj.SqrtPriceX96) + } + + // Create PoolData with converted values + return &market.PoolData{ + Address: marketObj.PoolAddress, + Token0: marketObj.Token0, + Token1: marketObj.Token1, + Fee: int64(marketObj.Fee), + Liquidity: liquidity, + SqrtPriceX96: sqrtPriceX96, + Tick: int(marketObj.Tick), + TickSpacing: 60, // Default for 0.3% fee tier + LastUpdated: time.Now(), + } +} + +// syncMarketData synchronizes market data between the two market managers +// marketDataSyncer periodically syncs market data between managers +func (sas *ArbitrageService) marketDataSyncer() { + sas.logger.Info("Starting market data syncer...") + + ticker := time.NewTicker(10 * time.Second) // Sync every 10 seconds + defer ticker.Stop() + + for { + select { + case <-sas.ctx.Done(): + sas.logger.Info("Market data syncer stopped") + return + case <-ticker.C: + sas.syncMarketData() + + // Example of how to use the new market manager for arbitrage detection + // This would be integrated with the existing arbitrage detection logic + sas.performAdvancedArbitrageDetection() + } + } +} + +// performAdvancedArbitrageDetection uses the new market manager for enhanced arbitrage detection +func (sas *ArbitrageService) performAdvancedArbitrageDetection() { + // This would use the marketmanager's arbitrage detection capabilities + // For example: + // 1. Get markets from the new manager + // 2. Use the marketmanager's arbitrage detector + // 3. Convert results to the existing format + + // Example placeholder: + sas.logger.Debug("Performing advanced arbitrage detection with new market manager") + + // In a real implementation, you would: + // 1. Get relevant markets from marketDataManager + // 2. Use marketmanager.NewArbitrageDetector() to create detector + // 3. Call detector.DetectArbitrageOpportunities() with markets + // 4. Convert opportunities to the existing format + // 5. Process them with the existing execution logic +} + // Start begins the simplified arbitrage service -func (sas *SimpleArbitrageService) Start() error { +func (sas *ArbitrageService) Start() error { sas.runMutex.Lock() defer sas.runMutex.Unlock() @@ -183,6 +332,7 @@ func (sas *SimpleArbitrageService) Start() error { // Start worker goroutines go sas.statsUpdater() go sas.blockchainMonitor() + go sas.marketDataSyncer() // Start market data synchronization sas.isRunning = true sas.logger.Info("Simplified arbitrage service started successfully") @@ -191,7 +341,7 @@ func (sas *SimpleArbitrageService) Start() error { } // Stop stops the arbitrage service -func (sas *SimpleArbitrageService) Stop() error { +func (sas *ArbitrageService) Stop() error { sas.runMutex.Lock() defer sas.runMutex.Unlock() @@ -211,7 +361,7 @@ func (sas *SimpleArbitrageService) Stop() error { } // ProcessSwapEvent processes a swap event for arbitrage opportunities -func (sas *SimpleArbitrageService) ProcessSwapEvent(event *SimpleSwapEvent) error { +func (sas *ArbitrageService) ProcessSwapEvent(event *SimpleSwapEvent) error { sas.logger.Debug(fmt.Sprintf("Processing swap event: token0=%s, token1=%s, amount0=%s, amount1=%s", event.Token0.Hex(), event.Token1.Hex(), event.Amount0.String(), event.Amount1.String())) @@ -225,7 +375,7 @@ func (sas *SimpleArbitrageService) ProcessSwapEvent(event *SimpleSwapEvent) erro } // isSignificantSwap checks if a swap is large enough to create arbitrage opportunities -func (sas *SimpleArbitrageService) isSignificantSwap(event *SimpleSwapEvent) bool { +func (sas *ArbitrageService) isSignificantSwap(event *SimpleSwapEvent) bool { // Convert amounts to absolute values for comparison amount0Abs := new(big.Int).Abs(event.Amount0) amount1Abs := new(big.Int).Abs(event.Amount1) @@ -237,7 +387,7 @@ func (sas *SimpleArbitrageService) isSignificantSwap(event *SimpleSwapEvent) boo } // detectArbitrageOpportunities scans for arbitrage opportunities triggered by an event -func (sas *SimpleArbitrageService) detectArbitrageOpportunities(event *SimpleSwapEvent) error { +func (sas *ArbitrageService) detectArbitrageOpportunities(event *SimpleSwapEvent) error { start := time.Now() // Determine the tokens involved in potential arbitrage @@ -308,7 +458,7 @@ func (sas *SimpleArbitrageService) detectArbitrageOpportunities(event *SimpleSwa } // executeOpportunity executes a single arbitrage opportunity -func (sas *SimpleArbitrageService) executeOpportunity(opportunity *ArbitrageOpportunity) { +func (sas *ArbitrageService) executeOpportunity(opportunity *ArbitrageOpportunity) { // Check if opportunity is still valid if time.Now().After(opportunity.ExpiresAt) { sas.logger.Debug(fmt.Sprintf("Opportunity %s expired", opportunity.ID)) @@ -345,7 +495,7 @@ func (sas *SimpleArbitrageService) executeOpportunity(opportunity *ArbitrageOppo } // Helper methods from the original service -func (sas *SimpleArbitrageService) isValidOpportunity(path *ArbitragePath) bool { +func (sas *ArbitrageService) isValidOpportunity(path *ArbitragePath) bool { minProfit := big.NewInt(sas.config.MinProfitWei) if path.NetProfit.Cmp(minProfit) < 0 { return false @@ -367,7 +517,7 @@ func (sas *SimpleArbitrageService) isValidOpportunity(path *ArbitragePath) bool return sas.executor.IsProfitableAfterGas(path, currentGasPrice) } -func (sas *SimpleArbitrageService) calculateScanAmount(event *SimpleSwapEvent, token common.Address) *big.Int { +func (sas *ArbitrageService) calculateScanAmount(event *SimpleSwapEvent, token common.Address) *big.Int { var swapAmount *big.Int if token == event.Token0 { @@ -391,7 +541,7 @@ func (sas *SimpleArbitrageService) calculateScanAmount(event *SimpleSwapEvent, t return scanAmount } -func (sas *SimpleArbitrageService) calculateUrgency(path *ArbitragePath) int { +func (sas *ArbitrageService) calculateUrgency(path *ArbitragePath) int { urgency := int(path.ROI / 2) profitETH := new(big.Float).SetInt(path.NetProfit) @@ -414,7 +564,7 @@ func (sas *SimpleArbitrageService) calculateUrgency(path *ArbitragePath) int { return urgency } -func (sas *SimpleArbitrageService) rankOpportunities(opportunities []*ArbitrageOpportunity) { +func (sas *ArbitrageService) rankOpportunities(opportunities []*ArbitrageOpportunity) { for i := 0; i < len(opportunities); i++ { for j := i + 1; j < len(opportunities); j++ { iOpp := opportunities[i] @@ -431,7 +581,7 @@ func (sas *SimpleArbitrageService) rankOpportunities(opportunities []*ArbitrageO } } -func (sas *SimpleArbitrageService) calculateMinOutput(opportunity *ArbitrageOpportunity) *big.Int { +func (sas *ArbitrageService) calculateMinOutput(opportunity *ArbitrageOpportunity) *big.Int { expectedOutput := new(big.Int).Add(opportunity.RequiredAmount, opportunity.EstimatedProfit) slippageTolerance := sas.config.SlippageTolerance @@ -446,7 +596,7 @@ func (sas *SimpleArbitrageService) calculateMinOutput(opportunity *ArbitrageOppo return minOutput } -func (sas *SimpleArbitrageService) processExecutionResult(result *ExecutionResult) { +func (sas *ArbitrageService) processExecutionResult(result *ExecutionResult) { sas.statsMutex.Lock() if result.Success { sas.stats.TotalSuccessfulExecutions++ @@ -471,7 +621,7 @@ func (sas *SimpleArbitrageService) processExecutionResult(result *ExecutionResul } } -func (sas *SimpleArbitrageService) statsUpdater() { +func (sas *ArbitrageService) statsUpdater() { defer sas.logger.Info("Stats updater stopped") ticker := time.NewTicker(sas.config.StatsUpdateInterval) @@ -487,7 +637,7 @@ func (sas *SimpleArbitrageService) statsUpdater() { } } -func (sas *SimpleArbitrageService) logStats() { +func (sas *ArbitrageService) logStats() { sas.statsMutex.RLock() stats := *sas.stats sas.statsMutex.RUnlock() @@ -506,11 +656,11 @@ func (sas *SimpleArbitrageService) logStats() { formatEther(stats.TotalGasSpent))) } -func (sas *SimpleArbitrageService) generateOpportunityID(path *ArbitragePath, event *SimpleSwapEvent) string { +func (sas *ArbitrageService) generateOpportunityID(path *ArbitragePath, event *SimpleSwapEvent) string { return fmt.Sprintf("%s_%s_%d", event.TxHash.Hex()[:10], path.Tokens[0].Hex()[:8], time.Now().UnixNano()) } -func (sas *SimpleArbitrageService) GetStats() *ArbitrageStats { +func (sas *ArbitrageService) GetStats() *ArbitrageStats { sas.statsMutex.RLock() defer sas.statsMutex.RUnlock() @@ -518,14 +668,14 @@ func (sas *SimpleArbitrageService) GetStats() *ArbitrageStats { return &statsCopy } -func (sas *SimpleArbitrageService) IsRunning() bool { +func (sas *ArbitrageService) IsRunning() bool { sas.runMutex.RLock() defer sas.runMutex.RUnlock() return sas.isRunning } // blockchainMonitor monitors the Arbitrum sequencer using the ORIGINAL ArbitrumMonitor with ArbitrumL2Parser -func (sas *SimpleArbitrageService) blockchainMonitor() { +func (sas *ArbitrageService) blockchainMonitor() { defer sas.logger.Info("💀 ARBITRUM SEQUENCER MONITOR STOPPED - Full sequencer reading terminated") sas.logger.Info("🚀 STARTING ARBITRUM SEQUENCER MONITOR FOR MEV OPPORTUNITIES") @@ -576,7 +726,7 @@ func (sas *SimpleArbitrageService) blockchainMonitor() { } // fallbackBlockPolling provides fallback block monitoring through polling with EXTENSIVE LOGGING -func (sas *SimpleArbitrageService) fallbackBlockPolling() { +func (sas *ArbitrageService) fallbackBlockPolling() { sas.logger.Info("⚠️ USING FALLBACK BLOCK POLLING - This is NOT the proper sequencer reader!") sas.logger.Info("⚠️ This fallback method has limited transaction analysis capabilities") sas.logger.Info("⚠️ For full MEV detection, the proper ArbitrumMonitor with L2Parser should be used") @@ -615,7 +765,7 @@ func (sas *SimpleArbitrageService) fallbackBlockPolling() { } // processNewBlock processes a new block looking for swap events with EXTENSIVE LOGGING -func (sas *SimpleArbitrageService) processNewBlock(header *types.Header) int { +func (sas *ArbitrageService) processNewBlock(header *types.Header) int { blockNumber := header.Number.Uint64() // Skip processing if block has no transactions @@ -657,7 +807,7 @@ func (sas *SimpleArbitrageService) processNewBlock(header *types.Header) int { } // processTransaction analyzes a transaction for swap events -func (sas *SimpleArbitrageService) processTransaction(tx *types.Transaction, blockNumber uint64) bool { +func (sas *ArbitrageService) processTransaction(tx *types.Transaction, blockNumber uint64) bool { // Get transaction receipt to access logs receipt, err := sas.client.TransactionReceipt(sas.ctx, tx.Hash()) if err != nil { @@ -686,7 +836,7 @@ func (sas *SimpleArbitrageService) processTransaction(tx *types.Transaction, blo } // parseSwapLog attempts to parse a log as a Uniswap V3 Swap event -func (sas *SimpleArbitrageService) parseSwapLog(log *types.Log, tx *types.Transaction, blockNumber uint64) *SimpleSwapEvent { +func (sas *ArbitrageService) parseSwapLog(log *types.Log, tx *types.Transaction, blockNumber uint64) *SimpleSwapEvent { // Uniswap V3 Pool Swap event signature // Swap(indexed address sender, indexed address recipient, int256 amount0, int256 amount1, uint160 sqrtPriceX96, uint128 liquidity, int24 tick) swapEventSig := common.HexToHash("0xc42079f94a6350d7e6235f29174924f928cc2ac818eb64fed8004e115fbcca67") @@ -740,7 +890,7 @@ func (sas *SimpleArbitrageService) parseSwapLog(log *types.Log, tx *types.Transa } // getPoolTokens retrieves token addresses for a Uniswap V3 pool with caching -func (sas *SimpleArbitrageService) getPoolTokens(poolAddress common.Address) (token0, token1 common.Address, err error) { +func (sas *ArbitrageService) getPoolTokens(poolAddress common.Address) (token0, token1 common.Address, err error) { // Check cache first sas.tokenCacheMutex.RLock() if cached, exists := sas.tokenCache[poolAddress]; exists { @@ -792,7 +942,7 @@ func (sas *SimpleArbitrageService) getPoolTokens(poolAddress common.Address) (to } // getSwapEventsFromBlock retrieves Uniswap V3 swap events from a specific block using log filtering -func (sas *SimpleArbitrageService) getSwapEventsFromBlock(blockNumber uint64) []*SimpleSwapEvent { +func (sas *ArbitrageService) getSwapEventsFromBlock(blockNumber uint64) []*SimpleSwapEvent { // Uniswap V3 Pool Swap event signature swapEventSig := common.HexToHash("0xc42079f94a6350d7e6235f29174924f928cc2ac818eb64fed8004e115fbcca67") @@ -834,7 +984,7 @@ func (sas *SimpleArbitrageService) getSwapEventsFromBlock(blockNumber uint64) [] // parseSwapEvent parses a log entry into a SimpleSwapEvent // createArbitrumMonitor creates the ORIGINAL ArbitrumMonitor with full sequencer reading capabilities -func (sas *SimpleArbitrageService) createArbitrumMonitor() (*monitor.ArbitrumMonitor, error) { +func (sas *ArbitrageService) createArbitrumMonitor() (*monitor.ArbitrumMonitor, error) { sas.logger.Info("🏗️ CREATING ORIGINAL ARBITRUM MONITOR WITH FULL SEQUENCER READER") sas.logger.Info("🔧 This will use ArbitrumL2Parser for proper transaction analysis") sas.logger.Info("📡 Full MEV detection, market analysis, and arbitrage scanning enabled") @@ -926,7 +1076,7 @@ func (sas *SimpleArbitrageService) createArbitrumMonitor() (*monitor.ArbitrumMon return monitor, nil } -func (sas *SimpleArbitrageService) parseSwapEvent(log types.Log, blockNumber uint64) *SimpleSwapEvent { +func (sas *ArbitrageService) parseSwapEvent(log types.Log, blockNumber uint64) *SimpleSwapEvent { // Validate log structure if len(log.Topics) < 3 || len(log.Data) < 192 { // 6 * 32 bytes sas.logger.Debug(fmt.Sprintf("Invalid log structure: topics=%d, data_len=%d", len(log.Topics), len(log.Data))) @@ -971,3 +1121,29 @@ func (sas *SimpleArbitrageService) parseSwapEvent(log types.Log, blockNumber uin Timestamp: time.Now(), } } + +// syncMarketData synchronizes market data between the two market managers +func (sas *ArbitrageService) syncMarketData() { + sas.logger.Debug("Syncing market data between managers") + + // Example of how to synchronize market data + // In a real implementation, you would iterate through pools from the existing manager + // and convert/add them to the new manager + + // This is a placeholder showing the pattern: + // 1. Get pool data from existing manager + // 2. Convert to marketmanager format + // 3. Add to new manager + + // Example: + // poolAddress := common.HexToAddress("0x...") // Some pool address + // poolData, err := sas.marketManager.GetPool(sas.ctx, poolAddress) + // if err == nil { + // marketObj := sas.convertPoolDataToMarket(poolData, "UniswapV3") + // if err := sas.marketDataManager.AddMarket(marketObj); err != nil { + // sas.logger.Warn("Failed to add market to manager: ", err) + // } + // } + + sas.logger.Debug("Market data sync completed") +} diff --git a/pkg/profitcalc/simple_profit_calc.go b/pkg/profitcalc/profit_calc.go similarity index 89% rename from pkg/profitcalc/simple_profit_calc.go rename to pkg/profitcalc/profit_calc.go index 82d56eb..c8dd12d 100644 --- a/pkg/profitcalc/simple_profit_calc.go +++ b/pkg/profitcalc/profit_calc.go @@ -12,8 +12,8 @@ import ( "github.com/fraktal/mev-beta/internal/logger" ) -// SimpleProfitCalculator provides basic arbitrage profit estimation for integration with scanner -type SimpleProfitCalculator struct { +// ProfitCalculator provides sophisticated arbitrage profit estimation with slippage protection and multi-DEX price feeds +type ProfitCalculator struct { logger *logger.Logger minProfitThreshold *big.Int // Minimum profit in wei to consider viable maxSlippage float64 // Maximum slippage tolerance (e.g., 0.03 for 3%) @@ -51,9 +51,9 @@ type SimpleOpportunity struct { MinAmountOut *big.Float // Minimum amount out with slippage protection } -// NewSimpleProfitCalculator creates a new simplified profit calculator -func NewSimpleProfitCalculator(logger *logger.Logger) *SimpleProfitCalculator { - return &SimpleProfitCalculator{ +// NewProfitCalculator creates a new simplified profit calculator +func NewProfitCalculator(logger *logger.Logger) *ProfitCalculator { + return &ProfitCalculator{ logger: logger, minProfitThreshold: big.NewInt(10000000000000000), // 0.01 ETH minimum (more realistic) maxSlippage: 0.03, // 3% max slippage @@ -64,9 +64,9 @@ func NewSimpleProfitCalculator(logger *logger.Logger) *SimpleProfitCalculator { } } -// NewSimpleProfitCalculatorWithClient creates a profit calculator with Ethereum client for gas price updates -func NewSimpleProfitCalculatorWithClient(logger *logger.Logger, client *ethclient.Client) *SimpleProfitCalculator { - calc := NewSimpleProfitCalculator(logger) +// NewProfitCalculatorWithClient creates a profit calculator with Ethereum client for gas price updates +func NewProfitCalculatorWithClient(logger *logger.Logger, client *ethclient.Client) *ProfitCalculator { + calc := NewProfitCalculator(logger) calc.client = client // Initialize price feed if client is provided @@ -80,7 +80,7 @@ func NewSimpleProfitCalculatorWithClient(logger *logger.Logger, client *ethclien } // AnalyzeSwapOpportunity analyzes a swap event for potential arbitrage profit -func (spc *SimpleProfitCalculator) AnalyzeSwapOpportunity( +func (spc *ProfitCalculator) AnalyzeSwapOpportunity( ctx context.Context, tokenA, tokenB common.Address, amountIn, amountOut *big.Float, @@ -238,7 +238,7 @@ func (spc *SimpleProfitCalculator) AnalyzeSwapOpportunity( } // calculateGasCost estimates the gas cost for an arbitrage transaction -func (spc *SimpleProfitCalculator) calculateGasCost() *big.Float { +func (spc *ProfitCalculator) calculateGasCost() *big.Float { // Gas cost = Gas price * Gas limit gasLimit := big.NewInt(int64(spc.gasLimit)) currentGasPrice := spc.GetCurrentGasPrice() @@ -256,7 +256,7 @@ func (spc *SimpleProfitCalculator) calculateGasCost() *big.Float { } // calculateConfidence calculates a confidence score for the opportunity -func (spc *SimpleProfitCalculator) calculateConfidence(opp *SimpleOpportunity) float64 { +func (spc *ProfitCalculator) calculateConfidence(opp *SimpleOpportunity) float64 { confidence := 0.0 // Base confidence for positive profit @@ -292,7 +292,7 @@ func (spc *SimpleProfitCalculator) calculateConfidence(opp *SimpleOpportunity) f } // FormatEther formats a big.Float ether amount to string (public method) -func (spc *SimpleProfitCalculator) FormatEther(ether *big.Float) string { +func (spc *ProfitCalculator) FormatEther(ether *big.Float) string { if ether == nil { return "0.000000" } @@ -300,7 +300,7 @@ func (spc *SimpleProfitCalculator) FormatEther(ether *big.Float) string { } // UpdateGasPrice updates the current gas price for calculations -func (spc *SimpleProfitCalculator) UpdateGasPrice(gasPrice *big.Int) { +func (spc *ProfitCalculator) UpdateGasPrice(gasPrice *big.Int) { spc.gasPriceMutex.Lock() defer spc.gasPriceMutex.Unlock() @@ -311,14 +311,14 @@ func (spc *SimpleProfitCalculator) UpdateGasPrice(gasPrice *big.Int) { } // GetCurrentGasPrice gets the current gas price (thread-safe) -func (spc *SimpleProfitCalculator) GetCurrentGasPrice() *big.Int { +func (spc *ProfitCalculator) GetCurrentGasPrice() *big.Int { spc.gasPriceMutex.RLock() defer spc.gasPriceMutex.RUnlock() return new(big.Int).Set(spc.gasPrice) } // startGasPriceUpdater starts a background goroutine to update gas prices -func (spc *SimpleProfitCalculator) startGasPriceUpdater() { +func (spc *ProfitCalculator) startGasPriceUpdater() { ticker := time.NewTicker(spc.gasPriceUpdateInterval) defer ticker.Stop() @@ -333,7 +333,7 @@ func (spc *SimpleProfitCalculator) startGasPriceUpdater() { } // updateGasPriceFromNetwork fetches current gas price from the network -func (spc *SimpleProfitCalculator) updateGasPriceFromNetwork() { +func (spc *ProfitCalculator) updateGasPriceFromNetwork() { if spc.client == nil { return } @@ -355,14 +355,14 @@ func (spc *SimpleProfitCalculator) updateGasPriceFromNetwork() { } // SetMinProfitThreshold sets the minimum profit threshold -func (spc *SimpleProfitCalculator) SetMinProfitThreshold(threshold *big.Int) { +func (spc *ProfitCalculator) SetMinProfitThreshold(threshold *big.Int) { spc.minProfitThreshold = threshold spc.logger.Info(fmt.Sprintf("Updated minimum profit threshold to %s ETH", new(big.Float).Quo(new(big.Float).SetInt(threshold), big.NewFloat(1e18)))) } // GetPriceFeedStats returns statistics about the price feed -func (spc *SimpleProfitCalculator) GetPriceFeedStats() map[string]interface{} { +func (spc *ProfitCalculator) GetPriceFeedStats() map[string]interface{} { if spc.priceFeed != nil { return spc.priceFeed.GetPriceStats() } @@ -372,12 +372,12 @@ func (spc *SimpleProfitCalculator) GetPriceFeedStats() map[string]interface{} { } // HasPriceFeed returns true if the calculator has an active price feed -func (spc *SimpleProfitCalculator) HasPriceFeed() bool { +func (spc *ProfitCalculator) HasPriceFeed() bool { return spc.priceFeed != nil } // Stop gracefully shuts down the profit calculator -func (spc *SimpleProfitCalculator) Stop() { +func (spc *ProfitCalculator) Stop() { if spc.priceFeed != nil { spc.priceFeed.Stop() spc.logger.Info("Price feed stopped") diff --git a/pkg/transport/dlq.go b/pkg/transport/dlq.go new file mode 100644 index 0000000..126670c --- /dev/null +++ b/pkg/transport/dlq.go @@ -0,0 +1,591 @@ +package transport + +import ( + "context" + "fmt" + "sort" + "sync" + "time" +) + +// DeadLetterQueue handles failed messages with retry and reprocessing capabilities +type DeadLetterQueue struct { + messages map[string][]*DLQMessage + config DLQConfig + metrics DLQMetrics + reprocessor MessageReprocessor + mu sync.RWMutex + cleanupTicker *time.Ticker + ctx context.Context + cancel context.CancelFunc +} + +// DLQMessage represents a message in the dead letter queue +type DLQMessage struct { + ID string + OriginalMessage *Message + Topic string + FirstFailed time.Time + LastAttempt time.Time + AttemptCount int + MaxRetries int + FailureReason string + RetryDelay time.Duration + NextRetry time.Time + Metadata map[string]interface{} + Permanent bool +} + +// DLQConfig configures dead letter queue behavior +type DLQConfig struct { + MaxMessages int + MaxRetries int + RetentionTime time.Duration + AutoReprocess bool + ReprocessInterval time.Duration + BackoffStrategy BackoffStrategy + InitialRetryDelay time.Duration + MaxRetryDelay time.Duration + BackoffMultiplier float64 + PermanentFailures []string // Error patterns that mark messages as permanently failed + ReprocessBatchSize int +} + +// BackoffStrategy defines retry delay calculation methods +type BackoffStrategy string + +const ( + BackoffFixed BackoffStrategy = "fixed" + BackoffLinear BackoffStrategy = "linear" + BackoffExponential BackoffStrategy = "exponential" + BackoffCustom BackoffStrategy = "custom" +) + +// DLQMetrics tracks dead letter queue statistics +type DLQMetrics struct { + MessagesAdded int64 + MessagesReprocessed int64 + MessagesExpired int64 + MessagesPermanent int64 + ReprocessSuccesses int64 + ReprocessFailures int64 + QueueSize int64 + OldestMessage time.Time +} + +// MessageReprocessor handles message reprocessing logic +type MessageReprocessor interface { + Reprocess(ctx context.Context, msg *DLQMessage) error + CanReprocess(msg *DLQMessage) bool + ShouldRetry(msg *DLQMessage, err error) bool +} + +// DefaultMessageReprocessor implements basic reprocessing logic +type DefaultMessageReprocessor struct { + publisher MessagePublisher +} + +// MessagePublisher interface for republishing messages +type MessagePublisher interface { + Publish(ctx context.Context, msg *Message) error +} + +// NewDeadLetterQueue creates a new dead letter queue +func NewDeadLetterQueue(config DLQConfig) *DeadLetterQueue { + ctx, cancel := context.WithCancel(context.Background()) + + dlq := &DeadLetterQueue{ + messages: make(map[string][]*DLQMessage), + config: config, + metrics: DLQMetrics{}, + ctx: ctx, + cancel: cancel, + } + + // Set default configuration values + if dlq.config.MaxMessages == 0 { + dlq.config.MaxMessages = 10000 + } + if dlq.config.MaxRetries == 0 { + dlq.config.MaxRetries = 3 + } + if dlq.config.RetentionTime == 0 { + dlq.config.RetentionTime = 24 * time.Hour + } + if dlq.config.ReprocessInterval == 0 { + dlq.config.ReprocessInterval = 5 * time.Minute + } + if dlq.config.InitialRetryDelay == 0 { + dlq.config.InitialRetryDelay = time.Minute + } + if dlq.config.MaxRetryDelay == 0 { + dlq.config.MaxRetryDelay = time.Hour + } + if dlq.config.BackoffMultiplier == 0 { + dlq.config.BackoffMultiplier = 2.0 + } + if dlq.config.BackoffStrategy == "" { + dlq.config.BackoffStrategy = BackoffExponential + } + if dlq.config.ReprocessBatchSize == 0 { + dlq.config.ReprocessBatchSize = 10 + } + + // Start cleanup routine + dlq.startCleanupRoutine() + + // Start reprocessing routine if enabled + if dlq.config.AutoReprocess { + dlq.startReprocessRoutine() + } + + return dlq +} + +// AddMessage adds a failed message to the dead letter queue +func (dlq *DeadLetterQueue) AddMessage(topic string, msg *Message) error { + return dlq.AddMessageWithReason(topic, msg, "unknown failure") +} + +// AddMessageWithReason adds a failed message with a specific failure reason +func (dlq *DeadLetterQueue) AddMessageWithReason(topic string, msg *Message, reason string) error { + dlq.mu.Lock() + defer dlq.mu.Unlock() + + // Check if we've exceeded max messages + totalMessages := dlq.getTotalMessageCount() + if totalMessages >= dlq.config.MaxMessages { + // Remove oldest message to make room + dlq.removeOldestMessage() + } + + // Check if this is a permanent failure + permanent := dlq.isPermanentFailure(reason) + + dlqMsg := &DLQMessage{ + ID: fmt.Sprintf("dlq_%s_%d", topic, time.Now().UnixNano()), + OriginalMessage: msg, + Topic: topic, + FirstFailed: time.Now(), + LastAttempt: time.Now(), + AttemptCount: 1, + MaxRetries: dlq.config.MaxRetries, + FailureReason: reason, + Metadata: make(map[string]interface{}), + Permanent: permanent, + } + + if !permanent { + dlqMsg.RetryDelay = dlq.calculateRetryDelay(dlqMsg) + dlqMsg.NextRetry = time.Now().Add(dlqMsg.RetryDelay) + } + + // Add to queue + if _, exists := dlq.messages[topic]; !exists { + dlq.messages[topic] = make([]*DLQMessage, 0) + } + dlq.messages[topic] = append(dlq.messages[topic], dlqMsg) + + // Update metrics + dlq.metrics.MessagesAdded++ + dlq.metrics.QueueSize++ + if permanent { + dlq.metrics.MessagesPermanent++ + } + dlq.updateOldestMessage() + + return nil +} + +// GetMessages returns all messages for a topic +func (dlq *DeadLetterQueue) GetMessages(topic string) ([]*DLQMessage, error) { + dlq.mu.RLock() + defer dlq.mu.RUnlock() + + messages, exists := dlq.messages[topic] + if !exists { + return []*DLQMessage{}, nil + } + + // Return a copy to avoid race conditions + result := make([]*DLQMessage, len(messages)) + copy(result, messages) + return result, nil +} + +// GetAllMessages returns all messages across all topics +func (dlq *DeadLetterQueue) GetAllMessages() map[string][]*DLQMessage { + dlq.mu.RLock() + defer dlq.mu.RUnlock() + + result := make(map[string][]*DLQMessage) + for topic, messages := range dlq.messages { + result[topic] = make([]*DLQMessage, len(messages)) + copy(result[topic], messages) + } + return result +} + +// ReprocessMessage attempts to reprocess a specific message +func (dlq *DeadLetterQueue) ReprocessMessage(messageID string) error { + dlq.mu.Lock() + defer dlq.mu.Unlock() + + // Find message + var dlqMsg *DLQMessage + var topic string + var index int + + for t, messages := range dlq.messages { + for i, msg := range messages { + if msg.ID == messageID { + dlqMsg = msg + topic = t + index = i + break + } + } + if dlqMsg != nil { + break + } + } + + if dlqMsg == nil { + return fmt.Errorf("message not found: %s", messageID) + } + + if dlqMsg.Permanent { + return fmt.Errorf("message marked as permanent failure: %s", messageID) + } + + // Attempt reprocessing + err := dlq.attemptReprocess(dlqMsg) + if err == nil { + // Success - remove from queue + dlq.removeMessageByIndex(topic, index) + dlq.metrics.ReprocessSuccesses++ + dlq.metrics.QueueSize-- + return nil + } + + // Failed - update retry information + dlqMsg.AttemptCount++ + dlqMsg.LastAttempt = time.Now() + dlqMsg.FailureReason = err.Error() + + if dlqMsg.AttemptCount >= dlqMsg.MaxRetries { + dlqMsg.Permanent = true + dlq.metrics.MessagesPermanent++ + } else { + dlqMsg.RetryDelay = dlq.calculateRetryDelay(dlqMsg) + dlqMsg.NextRetry = time.Now().Add(dlqMsg.RetryDelay) + } + + dlq.metrics.ReprocessFailures++ + return fmt.Errorf("reprocessing failed: %w", err) +} + +// PurgeMessages removes all messages for a topic +func (dlq *DeadLetterQueue) PurgeMessages(topic string) error { + dlq.mu.Lock() + defer dlq.mu.Unlock() + + if messages, exists := dlq.messages[topic]; exists { + count := len(messages) + delete(dlq.messages, topic) + dlq.metrics.QueueSize -= int64(count) + dlq.updateOldestMessage() + } + + return nil +} + +// PurgeAllMessages removes all messages from the queue +func (dlq *DeadLetterQueue) PurgeAllMessages() error { + dlq.mu.Lock() + defer dlq.mu.Unlock() + + dlq.messages = make(map[string][]*DLQMessage) + dlq.metrics.QueueSize = 0 + dlq.metrics.OldestMessage = time.Time{} + + return nil +} + +// GetMessageCount returns the total number of messages in the queue +func (dlq *DeadLetterQueue) GetMessageCount() int { + dlq.mu.RLock() + defer dlq.mu.RUnlock() + return dlq.getTotalMessageCount() +} + +// GetMetrics returns current DLQ metrics +func (dlq *DeadLetterQueue) GetMetrics() DLQMetrics { + dlq.mu.RLock() + defer dlq.mu.RUnlock() + return dlq.metrics +} + +// SetReprocessor sets the message reprocessor +func (dlq *DeadLetterQueue) SetReprocessor(reprocessor MessageReprocessor) { + dlq.mu.Lock() + defer dlq.mu.Unlock() + dlq.reprocessor = reprocessor +} + +// Cleanup removes expired messages +func (dlq *DeadLetterQueue) Cleanup(maxAge time.Duration) error { + dlq.mu.Lock() + defer dlq.mu.Unlock() + + cutoff := time.Now().Add(-maxAge) + expiredCount := 0 + + for topic, messages := range dlq.messages { + filtered := make([]*DLQMessage, 0) + for _, msg := range messages { + if msg.FirstFailed.After(cutoff) { + filtered = append(filtered, msg) + } else { + expiredCount++ + } + } + dlq.messages[topic] = filtered + + // Remove empty topics + if len(filtered) == 0 { + delete(dlq.messages, topic) + } + } + + dlq.metrics.MessagesExpired += int64(expiredCount) + dlq.metrics.QueueSize -= int64(expiredCount) + dlq.updateOldestMessage() + + return nil +} + +// Stop gracefully shuts down the dead letter queue +func (dlq *DeadLetterQueue) Stop() error { + dlq.cancel() + + if dlq.cleanupTicker != nil { + dlq.cleanupTicker.Stop() + } + + return nil +} + +// Private helper methods + +func (dlq *DeadLetterQueue) getTotalMessageCount() int { + count := 0 + for _, messages := range dlq.messages { + count += len(messages) + } + return count +} + +func (dlq *DeadLetterQueue) removeOldestMessage() { + var oldestTime time.Time + var oldestTopic string + var oldestIndex int + + for topic, messages := range dlq.messages { + for i, msg := range messages { + if oldestTime.IsZero() || msg.FirstFailed.Before(oldestTime) { + oldestTime = msg.FirstFailed + oldestTopic = topic + oldestIndex = i + } + } + } + + if !oldestTime.IsZero() { + dlq.removeMessageByIndex(oldestTopic, oldestIndex) + dlq.metrics.QueueSize-- + } +} + +func (dlq *DeadLetterQueue) removeMessageByIndex(topic string, index int) { + messages := dlq.messages[topic] + dlq.messages[topic] = append(messages[:index], messages[index+1:]...) + + if len(dlq.messages[topic]) == 0 { + delete(dlq.messages, topic) + } +} + +func (dlq *DeadLetterQueue) isPermanentFailure(reason string) bool { + for _, pattern := range dlq.config.PermanentFailures { + if pattern == reason { + return true + } + // Simple pattern matching (can be enhanced with regex) + if len(pattern) > 0 && pattern[len(pattern)-1] == '*' { + prefix := pattern[:len(pattern)-1] + if len(reason) >= len(prefix) && reason[:len(prefix)] == prefix { + return true + } + } + } + return false +} + +func (dlq *DeadLetterQueue) calculateRetryDelay(msg *DLQMessage) time.Duration { + switch dlq.config.BackoffStrategy { + case BackoffFixed: + return dlq.config.InitialRetryDelay + + case BackoffLinear: + delay := time.Duration(msg.AttemptCount) * dlq.config.InitialRetryDelay + if delay > dlq.config.MaxRetryDelay { + return dlq.config.MaxRetryDelay + } + return delay + + case BackoffExponential: + delay := time.Duration(float64(dlq.config.InitialRetryDelay) * + pow(dlq.config.BackoffMultiplier, float64(msg.AttemptCount-1))) + if delay > dlq.config.MaxRetryDelay { + return dlq.config.MaxRetryDelay + } + return delay + + default: + return dlq.config.InitialRetryDelay + } +} + +func (dlq *DeadLetterQueue) attemptReprocess(msg *DLQMessage) error { + if dlq.reprocessor == nil { + return fmt.Errorf("no reprocessor configured") + } + + if !dlq.reprocessor.CanReprocess(msg) { + return fmt.Errorf("message cannot be reprocessed") + } + + return dlq.reprocessor.Reprocess(dlq.ctx, msg) +} + +func (dlq *DeadLetterQueue) updateOldestMessage() { + var oldest time.Time + + for _, messages := range dlq.messages { + for _, msg := range messages { + if oldest.IsZero() || msg.FirstFailed.Before(oldest) { + oldest = msg.FirstFailed + } + } + } + + dlq.metrics.OldestMessage = oldest +} + +func (dlq *DeadLetterQueue) startCleanupRoutine() { + dlq.cleanupTicker = time.NewTicker(dlq.config.ReprocessInterval) + + go func() { + for { + select { + case <-dlq.cleanupTicker.C: + dlq.Cleanup(dlq.config.RetentionTime) + case <-dlq.ctx.Done(): + return + } + } + }() +} + +func (dlq *DeadLetterQueue) startReprocessRoutine() { + ticker := time.NewTicker(dlq.config.ReprocessInterval) + + go func() { + defer ticker.Stop() + + for { + select { + case <-ticker.C: + dlq.processRetryableMessages() + case <-dlq.ctx.Done(): + return + } + } + }() +} + +func (dlq *DeadLetterQueue) processRetryableMessages() { + dlq.mu.Lock() + retryable := dlq.getRetryableMessages() + dlq.mu.Unlock() + + // Sort by next retry time + sort.Slice(retryable, func(i, j int) bool { + return retryable[i].NextRetry.Before(retryable[j].NextRetry) + }) + + // Process batch + batchSize := dlq.config.ReprocessBatchSize + if len(retryable) < batchSize { + batchSize = len(retryable) + } + + for i := 0; i < batchSize; i++ { + msg := retryable[i] + if time.Now().After(msg.NextRetry) { + dlq.ReprocessMessage(msg.ID) + } + } +} + +func (dlq *DeadLetterQueue) getRetryableMessages() []*DLQMessage { + var retryable []*DLQMessage + + for _, messages := range dlq.messages { + for _, msg := range messages { + if !msg.Permanent && msg.AttemptCount < msg.MaxRetries { + retryable = append(retryable, msg) + } + } + } + + return retryable +} + +// Implementation of DefaultMessageReprocessor + +func NewDefaultMessageReprocessor(publisher MessagePublisher) *DefaultMessageReprocessor { + return &DefaultMessageReprocessor{ + publisher: publisher, + } +} + +func (r *DefaultMessageReprocessor) Reprocess(ctx context.Context, msg *DLQMessage) error { + if r.publisher == nil { + return fmt.Errorf("no publisher configured") + } + + return r.publisher.Publish(ctx, msg.OriginalMessage) +} + +func (r *DefaultMessageReprocessor) CanReprocess(msg *DLQMessage) bool { + return !msg.Permanent && msg.AttemptCount < msg.MaxRetries +} + +func (r *DefaultMessageReprocessor) ShouldRetry(msg *DLQMessage, err error) bool { + // Simple retry logic - can be enhanced based on error types + return msg.AttemptCount < msg.MaxRetries +} + +// Helper function for power calculation +func pow(base, exp float64) float64 { + if exp == 0 { + return 1 + } + result := base + for i := 1; i < int(exp); i++ { + result *= base + } + return result +} diff --git a/pkg/transport/interfaces.go b/pkg/transport/interfaces.go new file mode 100644 index 0000000..c62e438 --- /dev/null +++ b/pkg/transport/interfaces.go @@ -0,0 +1,277 @@ +package transport + +import ( + "context" + "time" +) + +// MessageType represents the type of message being sent +type MessageType string + +const ( + // Core message types + MessageTypeEvent MessageType = "event" + MessageTypeCommand MessageType = "command" + MessageTypeResponse MessageType = "response" + MessageTypeHeartbeat MessageType = "heartbeat" + MessageTypeStatus MessageType = "status" + MessageTypeError MessageType = "error" + + // Business-specific message types + MessageTypeArbitrage MessageType = "arbitrage" + MessageTypeMarketData MessageType = "market_data" + MessageTypeExecution MessageType = "execution" + MessageTypeRiskCheck MessageType = "risk_check" +) + +// Priority levels for message routing +type Priority uint8 + +const ( + PriorityLow Priority = iota + PriorityNormal + PriorityHigh + PriorityCritical + PriorityEmergency +) + +// Message represents a universal message in the system +type Message struct { + ID string `json:"id"` + Type MessageType `json:"type"` + Topic string `json:"topic"` + Source string `json:"source"` + Destination string `json:"destination"` + Priority Priority `json:"priority"` + Timestamp time.Time `json:"timestamp"` + TTL time.Duration `json:"ttl"` + Headers map[string]string `json:"headers"` + Payload []byte `json:"payload"` + Metadata map[string]interface{} `json:"metadata"` +} + +// MessageHandler processes incoming messages +type MessageHandler func(ctx context.Context, msg *Message) error + +// Transport defines the interface for different transport mechanisms +type Transport interface { + // Start initializes the transport + Start(ctx context.Context) error + + // Stop gracefully shuts down the transport + Stop(ctx context.Context) error + + // Send publishes a message + Send(ctx context.Context, msg *Message) error + + // Subscribe registers a handler for messages on a topic + Subscribe(ctx context.Context, topic string, handler MessageHandler) error + + // Unsubscribe removes a handler for a topic + Unsubscribe(ctx context.Context, topic string) error + + // GetStats returns transport statistics + GetStats() TransportStats + + // GetType returns the transport type + GetType() TransportType + + // IsHealthy checks if the transport is functioning properly + IsHealthy() bool +} + +// TransportType identifies different transport implementations +type TransportType string + +const ( + TransportTypeSharedMemory TransportType = "shared_memory" + TransportTypeUnixSocket TransportType = "unix_socket" + TransportTypeTCP TransportType = "tcp" + TransportTypeWebSocket TransportType = "websocket" + TransportTypeGRPC TransportType = "grpc" +) + +// TransportStats provides metrics about transport performance +type TransportStats struct { + MessagesSent uint64 `json:"messages_sent"` + MessagesReceived uint64 `json:"messages_received"` + MessagesDropped uint64 `json:"messages_dropped"` + BytesSent uint64 `json:"bytes_sent"` + BytesReceived uint64 `json:"bytes_received"` + Latency time.Duration `json:"latency"` + ErrorCount uint64 `json:"error_count"` + ConnectedPeers int `json:"connected_peers"` + Uptime time.Duration `json:"uptime"` +} + +// MessageBus coordinates message routing across multiple transports +type MessageBus interface { + // Start initializes the message bus + Start(ctx context.Context) error + + // Stop gracefully shuts down the message bus + Stop(ctx context.Context) error + + // RegisterTransport adds a transport to the bus + RegisterTransport(transport Transport) error + + // UnregisterTransport removes a transport from the bus + UnregisterTransport(transportType TransportType) error + + // Publish sends a message through the optimal transport + Publish(ctx context.Context, msg *Message) error + + // Subscribe registers a handler for messages on a topic + Subscribe(ctx context.Context, topic string, handler MessageHandler) error + + // Unsubscribe removes a handler for a topic + Unsubscribe(ctx context.Context, topic string) error + + // GetTransport returns a specific transport + GetTransport(transportType TransportType) (Transport, error) + + // GetStats returns aggregated statistics + GetStats() MessageBusStats +} + +// MessageBusStats provides comprehensive metrics +type MessageBusStats struct { + TotalMessages uint64 `json:"total_messages"` + MessagesByType map[MessageType]uint64 `json:"messages_by_type"` + TransportStats map[TransportType]TransportStats `json:"transport_stats"` + ActiveTopics []string `json:"active_topics"` + Subscribers int `json:"subscribers"` + AverageLatency time.Duration `json:"average_latency"` + ThroughputMPS float64 `json:"throughput_mps"` // Messages per second +} + +// Router determines the best transport for a message +type Router interface { + // Route selects the optimal transport for a message + Route(msg *Message) (TransportType, error) + + // AddRule adds a routing rule + AddRule(rule RoutingRule) error + + // RemoveRule removes a routing rule + RemoveRule(ruleID string) error + + // GetRules returns all routing rules + GetRules() []RoutingRule +} + +// RoutingRule defines how messages should be routed +type RoutingRule struct { + ID string `json:"id"` + Priority int `json:"priority"` + Condition Condition `json:"condition"` + Transport TransportType `json:"transport"` + Fallback TransportType `json:"fallback,omitempty"` + Description string `json:"description"` +} + +// Condition defines when a routing rule applies +type Condition struct { + MessageType *MessageType `json:"message_type,omitempty"` + Topic *string `json:"topic,omitempty"` + Priority *Priority `json:"priority,omitempty"` + Source *string `json:"source,omitempty"` + Destination *string `json:"destination,omitempty"` + PayloadSize *int `json:"payload_size,omitempty"` + LatencyReq *time.Duration `json:"latency_requirement,omitempty"` +} + +// DeadLetterQueue handles failed messages +type DeadLetterQueue interface { + // Add puts a failed message in the queue + Add(ctx context.Context, msg *Message, reason error) error + + // Retry attempts to resend failed messages + Retry(ctx context.Context, maxRetries int) error + + // Get retrieves failed messages + Get(ctx context.Context, limit int) ([]*FailedMessage, error) + + // Remove deletes a failed message + Remove(ctx context.Context, messageID string) error + + // GetStats returns dead letter queue statistics + GetStats() DLQStats +} + +// FailedMessage represents a message that couldn't be delivered +type FailedMessage struct { + Message *Message `json:"message"` + Reason string `json:"reason"` + Attempts int `json:"attempts"` + FirstFailed time.Time `json:"first_failed"` + LastAttempt time.Time `json:"last_attempt"` +} + +// DLQStats provides dead letter queue metrics +type DLQStats struct { + TotalMessages uint64 `json:"total_messages"` + RetryableMessages uint64 `json:"retryable_messages"` + PermanentFailures uint64 `json:"permanent_failures"` + OldestMessage time.Time `json:"oldest_message"` + AverageRetries float64 `json:"average_retries"` +} + +// Serializer handles message encoding/decoding +type Serializer interface { + // Serialize converts a message to bytes + Serialize(msg *Message) ([]byte, error) + + // Deserialize converts bytes to a message + Deserialize(data []byte) (*Message, error) + + // GetFormat returns the serialization format + GetFormat() SerializationFormat +} + +// SerializationFormat defines encoding types +type SerializationFormat string + +const ( + FormatJSON SerializationFormat = "json" + FormatProtobuf SerializationFormat = "protobuf" + FormatMsgPack SerializationFormat = "msgpack" + FormatAvro SerializationFormat = "avro" +) + +// Persistence handles message storage +type Persistence interface { + // Store saves a message for persistence + Store(ctx context.Context, msg *Message) error + + // Retrieve gets a stored message + Retrieve(ctx context.Context, messageID string) (*Message, error) + + // Delete removes a stored message + Delete(ctx context.Context, messageID string) error + + // List returns stored messages matching criteria + List(ctx context.Context, criteria PersistenceCriteria) ([]*Message, error) + + // GetStats returns persistence statistics + GetStats() PersistenceStats +} + +// PersistenceCriteria defines search parameters +type PersistenceCriteria struct { + Topic *string `json:"topic,omitempty"` + MessageType *MessageType `json:"message_type,omitempty"` + Source *string `json:"source,omitempty"` + FromTime *time.Time `json:"from_time,omitempty"` + ToTime *time.Time `json:"to_time,omitempty"` + Limit int `json:"limit"` + Offset int `json:"offset"` +} + +// PersistenceStats provides storage metrics +type PersistenceStats struct { + StoredMessages uint64 `json:"stored_messages"` + StorageSize uint64 `json:"storage_size_bytes"` + OldestMessage time.Time `json:"oldest_message"` + NewestMessage time.Time `json:"newest_message"` +} diff --git a/pkg/transport/memory_transport.go b/pkg/transport/memory_transport.go new file mode 100644 index 0000000..5e8196b --- /dev/null +++ b/pkg/transport/memory_transport.go @@ -0,0 +1,230 @@ +package transport + +import ( + "context" + "fmt" + "sync" + "time" +) + +// MemoryTransport implements in-memory message transport for local communication +type MemoryTransport struct { + channels map[string]chan *Message + metrics TransportMetrics + connected bool + mu sync.RWMutex +} + +// NewMemoryTransport creates a new in-memory transport +func NewMemoryTransport() *MemoryTransport { + return &MemoryTransport{ + channels: make(map[string]chan *Message), + metrics: TransportMetrics{}, + } +} + +// Connect establishes the transport connection +func (mt *MemoryTransport) Connect(ctx context.Context) error { + mt.mu.Lock() + defer mt.mu.Unlock() + + if mt.connected { + return nil + } + + mt.connected = true + mt.metrics.Connections = 1 + return nil +} + +// Disconnect closes the transport connection +func (mt *MemoryTransport) Disconnect(ctx context.Context) error { + mt.mu.Lock() + defer mt.mu.Unlock() + + if !mt.connected { + return nil + } + + // Close all channels + for _, ch := range mt.channels { + close(ch) + } + mt.channels = make(map[string]chan *Message) + mt.connected = false + mt.metrics.Connections = 0 + + return nil +} + +// Send transmits a message through the memory transport +func (mt *MemoryTransport) Send(ctx context.Context, msg *Message) error { + start := time.Now() + + mt.mu.RLock() + if !mt.connected { + mt.mu.RUnlock() + mt.metrics.Errors++ + return fmt.Errorf("transport not connected") + } + + // Get or create channel for topic + ch, exists := mt.channels[msg.Topic] + if !exists { + mt.mu.RUnlock() + mt.mu.Lock() + // Double-check after acquiring write lock + if ch, exists = mt.channels[msg.Topic]; !exists { + ch = make(chan *Message, 1000) // Buffered channel + mt.channels[msg.Topic] = ch + } + mt.mu.Unlock() + } else { + mt.mu.RUnlock() + } + + // Send message + select { + case ch <- msg: + mt.updateSendMetrics(msg, time.Since(start)) + return nil + case <-ctx.Done(): + mt.metrics.Errors++ + return ctx.Err() + default: + mt.metrics.Errors++ + return fmt.Errorf("channel full for topic: %s", msg.Topic) + } +} + +// Receive returns a channel for receiving messages +func (mt *MemoryTransport) Receive(ctx context.Context) (<-chan *Message, error) { + mt.mu.RLock() + defer mt.mu.RUnlock() + + if !mt.connected { + return nil, fmt.Errorf("transport not connected") + } + + // Create a merged channel that receives from all topic channels + merged := make(chan *Message, 1000) + + go func() { + defer close(merged) + + // Use a wait group to handle multiple topic channels + var wg sync.WaitGroup + + mt.mu.RLock() + for topic, ch := range mt.channels { + wg.Add(1) + go func(topicCh <-chan *Message, topicName string) { + defer wg.Done() + for { + select { + case msg, ok := <-topicCh: + if !ok { + return + } + select { + case merged <- msg: + mt.updateReceiveMetrics(msg) + case <-ctx.Done(): + return + } + case <-ctx.Done(): + return + } + } + }(ch, topic) + } + mt.mu.RUnlock() + + wg.Wait() + }() + + return merged, nil +} + +// Health returns the health status of the transport +func (mt *MemoryTransport) Health() ComponentHealth { + mt.mu.RLock() + defer mt.mu.RUnlock() + + status := "unhealthy" + if mt.connected { + status = "healthy" + } + + return ComponentHealth{ + Status: status, + LastCheck: time.Now(), + ResponseTime: time.Microsecond, // Very fast for memory transport + ErrorCount: mt.metrics.Errors, + } +} + +// GetMetrics returns transport-specific metrics +func (mt *MemoryTransport) GetMetrics() TransportMetrics { + mt.mu.RLock() + defer mt.mu.RUnlock() + + // Create a copy to avoid race conditions + return TransportMetrics{ + BytesSent: mt.metrics.BytesSent, + BytesReceived: mt.metrics.BytesReceived, + MessagesSent: mt.metrics.MessagesSent, + MessagesReceived: mt.metrics.MessagesReceived, + Connections: mt.metrics.Connections, + Errors: mt.metrics.Errors, + Latency: mt.metrics.Latency, + } +} + +// Private helper methods + +func (mt *MemoryTransport) updateSendMetrics(msg *Message, latency time.Duration) { + mt.mu.Lock() + defer mt.mu.Unlock() + + mt.metrics.MessagesSent++ + mt.metrics.Latency = latency + + // Estimate message size (simplified) + messageSize := int64(len(msg.ID) + len(msg.Topic) + len(msg.Source)) + if msg.Data != nil { + messageSize += int64(len(fmt.Sprintf("%v", msg.Data))) + } + mt.metrics.BytesSent += messageSize +} + +func (mt *MemoryTransport) updateReceiveMetrics(msg *Message) { + mt.mu.Lock() + defer mt.mu.Unlock() + + mt.metrics.MessagesReceived++ + + // Estimate message size (simplified) + messageSize := int64(len(msg.ID) + len(msg.Topic) + len(msg.Source)) + if msg.Data != nil { + messageSize += int64(len(fmt.Sprintf("%v", msg.Data))) + } + mt.metrics.BytesReceived += messageSize +} + +// GetChannelForTopic returns the channel for a specific topic (for testing/debugging) +func (mt *MemoryTransport) GetChannelForTopic(topic string) (<-chan *Message, bool) { + mt.mu.RLock() + defer mt.mu.RUnlock() + + ch, exists := mt.channels[topic] + return ch, exists +} + +// GetTopicCount returns the number of active topic channels +func (mt *MemoryTransport) GetTopicCount() int { + mt.mu.RLock() + defer mt.mu.RUnlock() + + return len(mt.channels) +} diff --git a/pkg/transport/message_bus.go b/pkg/transport/message_bus.go new file mode 100644 index 0000000..fa28c25 --- /dev/null +++ b/pkg/transport/message_bus.go @@ -0,0 +1,453 @@ +package transport + +import ( + "context" + "fmt" + "sync" + "time" +) + +// MessageType represents the type of message being sent +type MessageType string + +const ( + MessageTypeEvent MessageType = "event" + MessageTypeCommand MessageType = "command" + MessageTypeQuery MessageType = "query" + MessageTypeResponse MessageType = "response" + MessageTypeNotification MessageType = "notification" + MessageTypeHeartbeat MessageType = "heartbeat" +) + +// MessagePriority defines message processing priority +type MessagePriority int + +const ( + PriorityLow MessagePriority = iota + PriorityNormal + PriorityHigh + PriorityCritical +) + +// Message represents a universal message in the system +type Message struct { + ID string `json:"id"` + Type MessageType `json:"type"` + Topic string `json:"topic"` + Source string `json:"source"` + Target string `json:"target,omitempty"` + Priority MessagePriority `json:"priority"` + Timestamp time.Time `json:"timestamp"` + Data interface{} `json:"data"` + Headers map[string]string `json:"headers,omitempty"` + CorrelationID string `json:"correlation_id,omitempty"` + TTL time.Duration `json:"ttl,omitempty"` + Retries int `json:"retries"` + MaxRetries int `json:"max_retries"` + Metadata map[string]interface{} `json:"metadata,omitempty"` +} + +// MessageHandler processes incoming messages +type MessageHandler func(ctx context.Context, msg *Message) error + +// MessageFilter determines if a message should be processed +type MessageFilter func(msg *Message) bool + +// Subscription represents a topic subscription +type Subscription struct { + ID string + Topic string + Filter MessageFilter + Handler MessageHandler + Options SubscriptionOptions + created time.Time + active bool + mu sync.RWMutex +} + +// SubscriptionOptions configures subscription behavior +type SubscriptionOptions struct { + QueueSize int + BatchSize int + BatchTimeout time.Duration + DLQEnabled bool + RetryEnabled bool + Persistent bool + Durable bool +} + +// MessageBusInterface defines the universal message bus contract +type MessageBusInterface interface { + // Core messaging operations + Publish(ctx context.Context, msg *Message) error + Subscribe(topic string, handler MessageHandler, opts ...SubscriptionOption) (*Subscription, error) + Unsubscribe(subscriptionID string) error + + // Advanced messaging patterns + Request(ctx context.Context, msg *Message, timeout time.Duration) (*Message, error) + Reply(ctx context.Context, originalMsg *Message, response *Message) error + + // Topic management + CreateTopic(topic string, config TopicConfig) error + DeleteTopic(topic string) error + ListTopics() []string + GetTopicInfo(topic string) (*TopicInfo, error) + + // Queue operations + QueueMessage(topic string, msg *Message) error + DequeueMessage(topic string, timeout time.Duration) (*Message, error) + PeekMessage(topic string) (*Message, error) + + // Dead letter queue + GetDLQMessages(topic string) ([]*Message, error) + ReprocessDLQMessage(messageID string) error + PurgeDLQ(topic string) error + + // Lifecycle management + Start(ctx context.Context) error + Stop(ctx context.Context) error + Health() HealthStatus + + // Metrics and monitoring + GetMetrics() MessageBusMetrics + GetSubscriptions() []*Subscription + GetActiveConnections() int +} + +// SubscriptionOption configures subscription behavior +type SubscriptionOption func(*SubscriptionOptions) + +// TopicConfig defines topic configuration +type TopicConfig struct { + Persistent bool + Replicated bool + RetentionPolicy RetentionPolicy + Partitions int + MaxMessageSize int64 + TTL time.Duration +} + +// RetentionPolicy defines message retention behavior +type RetentionPolicy struct { + MaxMessages int + MaxAge time.Duration + MaxSize int64 +} + +// TopicInfo provides topic statistics +type TopicInfo struct { + Name string + Config TopicConfig + MessageCount int64 + SubscriberCount int + LastActivity time.Time + SizeBytes int64 +} + +// HealthStatus represents system health +type HealthStatus struct { + Status string + Uptime time.Duration + LastCheck time.Time + Components map[string]ComponentHealth + Errors []HealthError +} + +// ComponentHealth represents component-specific health +type ComponentHealth struct { + Status string + LastCheck time.Time + ResponseTime time.Duration + ErrorCount int64 +} + +// HealthError represents a health check error +type HealthError struct { + Component string + Message string + Timestamp time.Time + Severity string +} + +// MessageBusMetrics provides operational metrics +type MessageBusMetrics struct { + MessagesPublished int64 + MessagesConsumed int64 + MessagesFailed int64 + MessagesInDLQ int64 + ActiveSubscriptions int + TopicCount int + AverageLatency time.Duration + ThroughputPerSec float64 + ErrorRate float64 + MemoryUsage int64 + CPUUsage float64 +} + +// TransportType defines available transport mechanisms +type TransportType string + +const ( + TransportMemory TransportType = "memory" + TransportUnixSocket TransportType = "unix" + TransportTCP TransportType = "tcp" + TransportWebSocket TransportType = "websocket" + TransportRedis TransportType = "redis" + TransportNATS TransportType = "nats" +) + +// TransportConfig configures transport layer +type TransportConfig struct { + Type TransportType + Address string + Options map[string]interface{} + RetryConfig RetryConfig + SecurityConfig SecurityConfig +} + +// RetryConfig defines retry behavior +type RetryConfig struct { + MaxRetries int + InitialDelay time.Duration + MaxDelay time.Duration + BackoffFactor float64 + Jitter bool +} + +// SecurityConfig defines security settings +type SecurityConfig struct { + Enabled bool + TLSConfig *TLSConfig + AuthConfig *AuthConfig + Encryption bool + Compression bool +} + +// TLSConfig for secure transport +type TLSConfig struct { + CertFile string + KeyFile string + CAFile string + Verify bool +} + +// AuthConfig for authentication +type AuthConfig struct { + Username string + Password string + Token string + Method string +} + +// UniversalMessageBus implements MessageBusInterface +type UniversalMessageBus struct { + config MessageBusConfig + transports map[TransportType]Transport + router *MessageRouter + topics map[string]*Topic + subscriptions map[string]*Subscription + dlq *DeadLetterQueue + metrics *MetricsCollector + persistence PersistenceLayer + mu sync.RWMutex + ctx context.Context + cancel context.CancelFunc + started bool +} + +// MessageBusConfig configures the message bus +type MessageBusConfig struct { + DefaultTransport TransportType + EnablePersistence bool + EnableMetrics bool + EnableDLQ bool + MaxMessageSize int64 + DefaultTTL time.Duration + HealthCheckInterval time.Duration + CleanupInterval time.Duration +} + +// Transport interface for different transport mechanisms +type Transport interface { + Send(ctx context.Context, msg *Message) error + Receive(ctx context.Context) (<-chan *Message, error) + Connect(ctx context.Context) error + Disconnect(ctx context.Context) error + Health() ComponentHealth + GetMetrics() TransportMetrics +} + +// TransportMetrics for transport-specific metrics +type TransportMetrics struct { + BytesSent int64 + BytesReceived int64 + MessagesSent int64 + MessagesReceived int64 + Connections int + Errors int64 + Latency time.Duration +} + +// MessageRouter handles message routing logic +type MessageRouter struct { + rules []RoutingRule + fallback TransportType + loadBalancer LoadBalancer + mu sync.RWMutex +} + +// RoutingRule defines message routing logic +type RoutingRule struct { + Condition MessageFilter + Transport TransportType + Priority int + Enabled bool +} + +// LoadBalancer for transport selection +type LoadBalancer interface { + SelectTransport(transports []TransportType, msg *Message) TransportType + UpdateStats(transport TransportType, latency time.Duration, success bool) +} + +// Topic represents a message topic +type Topic struct { + Name string + Config TopicConfig + Messages []StoredMessage + Subscribers []*Subscription + Created time.Time + LastActivity time.Time + mu sync.RWMutex +} + +// StoredMessage represents a persisted message +type StoredMessage struct { + Message *Message + Stored time.Time + Processed bool +} + +// DeadLetterQueue handles failed messages +type DeadLetterQueue struct { + messages map[string][]*Message + config DLQConfig + mu sync.RWMutex +} + +// DLQConfig configures dead letter queue +type DLQConfig struct { + MaxMessages int + MaxRetries int + RetentionTime time.Duration + AutoReprocess bool +} + +// MetricsCollector gathers operational metrics +type MetricsCollector struct { + metrics map[string]interface{} + mu sync.RWMutex +} + +// PersistenceLayer handles message persistence +type PersistenceLayer interface { + Store(msg *Message) error + Retrieve(id string) (*Message, error) + Delete(id string) error + List(topic string, limit int) ([]*Message, error) + Cleanup(maxAge time.Duration) error +} + +// Factory functions for common subscription options +func WithQueueSize(size int) SubscriptionOption { + return func(opts *SubscriptionOptions) { + opts.QueueSize = size + } +} + +func WithBatchProcessing(size int, timeout time.Duration) SubscriptionOption { + return func(opts *SubscriptionOptions) { + opts.BatchSize = size + opts.BatchTimeout = timeout + } +} + +func WithDLQ(enabled bool) SubscriptionOption { + return func(opts *SubscriptionOptions) { + opts.DLQEnabled = enabled + } +} + +func WithRetry(enabled bool) SubscriptionOption { + return func(opts *SubscriptionOptions) { + opts.RetryEnabled = enabled + } +} + +func WithPersistence(enabled bool) SubscriptionOption { + return func(opts *SubscriptionOptions) { + opts.Persistent = enabled + } +} + +// NewUniversalMessageBus creates a new message bus instance +func NewUniversalMessageBus(config MessageBusConfig) *UniversalMessageBus { + ctx, cancel := context.WithCancel(context.Background()) + + return &UniversalMessageBus{ + config: config, + transports: make(map[TransportType]Transport), + topics: make(map[string]*Topic), + subscriptions: make(map[string]*Subscription), + router: NewMessageRouter(), + dlq: NewDeadLetterQueue(DLQConfig{}), + metrics: NewMetricsCollector(), + ctx: ctx, + cancel: cancel, + } +} + +// NewMessageRouter creates a new message router +func NewMessageRouter() *MessageRouter { + return &MessageRouter{ + rules: make([]RoutingRule, 0), + fallback: TransportMemory, + } +} + +// NewDeadLetterQueue creates a new dead letter queue +func NewDeadLetterQueue(config DLQConfig) *DeadLetterQueue { + return &DeadLetterQueue{ + messages: make(map[string][]*Message), + config: config, + } +} + +// NewMetricsCollector creates a new metrics collector +func NewMetricsCollector() *MetricsCollector { + return &MetricsCollector{ + metrics: make(map[string]interface{}), + } +} + +// Helper function to generate message ID +func GenerateMessageID() string { + return fmt.Sprintf("msg_%d_%d", time.Now().UnixNano(), time.Now().Nanosecond()) +} + +// Helper function to create message with defaults +func NewMessage(msgType MessageType, topic string, source string, data interface{}) *Message { + return &Message{ + ID: GenerateMessageID(), + Type: msgType, + Topic: topic, + Source: source, + Priority: PriorityNormal, + Timestamp: time.Now(), + Data: data, + Headers: make(map[string]string), + Metadata: make(map[string]interface{}), + Retries: 0, + MaxRetries: 3, + } +} diff --git a/pkg/transport/message_bus_impl.go b/pkg/transport/message_bus_impl.go new file mode 100644 index 0000000..9ac0874 --- /dev/null +++ b/pkg/transport/message_bus_impl.go @@ -0,0 +1,743 @@ +package transport + +import ( + "context" + "fmt" + "sync" + "time" +) + +// Publish sends a message to the specified topic +func (mb *UniversalMessageBus) Publish(ctx context.Context, msg *Message) error { + if !mb.started { + return fmt.Errorf("message bus not started") + } + + // Validate message + if err := mb.validateMessage(msg); err != nil { + return fmt.Errorf("invalid message: %w", err) + } + + // Set timestamp if not set + if msg.Timestamp.IsZero() { + msg.Timestamp = time.Now() + } + + // Set ID if not set + if msg.ID == "" { + msg.ID = GenerateMessageID() + } + + // Update metrics + mb.metrics.IncrementCounter("messages_published_total") + mb.metrics.RecordLatency("publish_latency", time.Since(msg.Timestamp)) + + // Route message to appropriate transport + transport, err := mb.router.RouteMessage(msg, mb.transports) + if err != nil { + mb.metrics.IncrementCounter("routing_errors_total") + return fmt.Errorf("routing failed: %w", err) + } + + // Send via transport + if err := transport.Send(ctx, msg); err != nil { + mb.metrics.IncrementCounter("send_errors_total") + // Try dead letter queue if enabled + if mb.config.EnableDLQ { + if dlqErr := mb.dlq.AddMessage(msg.Topic, msg); dlqErr != nil { + return fmt.Errorf("send failed and DLQ failed: %v, original error: %w", dlqErr, err) + } + } + return fmt.Errorf("send failed: %w", err) + } + + // Store in topic if persistence enabled + if mb.config.EnablePersistence { + if err := mb.addMessageToTopic(msg); err != nil { + // Log error but don't fail the publish + mb.metrics.IncrementCounter("persistence_errors_total") + } + } + + // Deliver to local subscribers + go mb.deliverToSubscribers(ctx, msg) + + return nil +} + +// Subscribe creates a subscription to a topic +func (mb *UniversalMessageBus) Subscribe(topic string, handler MessageHandler, opts ...SubscriptionOption) (*Subscription, error) { + if !mb.started { + return nil, fmt.Errorf("message bus not started") + } + + // Apply subscription options + options := SubscriptionOptions{ + QueueSize: 1000, + BatchSize: 1, + BatchTimeout: time.Second, + DLQEnabled: mb.config.EnableDLQ, + RetryEnabled: true, + Persistent: false, + Durable: false, + } + + for _, opt := range opts { + opt(&options) + } + + // Create subscription + subscription := &Subscription{ + ID: fmt.Sprintf("sub_%s_%d", topic, time.Now().UnixNano()), + Topic: topic, + Handler: handler, + Options: options, + created: time.Now(), + active: true, + } + + mb.mu.Lock() + mb.subscriptions[subscription.ID] = subscription + mb.mu.Unlock() + + // Add to topic subscribers + mb.addSubscriberToTopic(topic, subscription) + + mb.metrics.IncrementCounter("subscriptions_created_total") + + return subscription, nil +} + +// Unsubscribe removes a subscription +func (mb *UniversalMessageBus) Unsubscribe(subscriptionID string) error { + mb.mu.Lock() + defer mb.mu.Unlock() + + subscription, exists := mb.subscriptions[subscriptionID] + if !exists { + return fmt.Errorf("subscription not found: %s", subscriptionID) + } + + // Mark as inactive + subscription.mu.Lock() + subscription.active = false + subscription.mu.Unlock() + + // Remove from subscriptions map + delete(mb.subscriptions, subscriptionID) + + // Remove from topic subscribers + mb.removeSubscriberFromTopic(subscription.Topic, subscriptionID) + + mb.metrics.IncrementCounter("subscriptions_removed_total") + + return nil +} + +// Request sends a request and waits for a response +func (mb *UniversalMessageBus) Request(ctx context.Context, msg *Message, timeout time.Duration) (*Message, error) { + if !mb.started { + return nil, fmt.Errorf("message bus not started") + } + + // Set correlation ID for request-response + if msg.CorrelationID == "" { + msg.CorrelationID = GenerateMessageID() + } + + // Create response channel + responseChannel := make(chan *Message, 1) + defer close(responseChannel) + + // Subscribe to response topic + responseTopic := fmt.Sprintf("response.%s", msg.CorrelationID) + subscription, err := mb.Subscribe(responseTopic, func(ctx context.Context, response *Message) error { + select { + case responseChannel <- response: + default: + // Channel full, ignore + } + return nil + }) + if err != nil { + return nil, fmt.Errorf("failed to subscribe to response topic: %w", err) + } + defer mb.Unsubscribe(subscription.ID) + + // Send request + if err := mb.Publish(ctx, msg); err != nil { + return nil, fmt.Errorf("failed to publish request: %w", err) + } + + // Wait for response with timeout + select { + case response := <-responseChannel: + return response, nil + case <-time.After(timeout): + return nil, fmt.Errorf("request timeout after %v", timeout) + case <-ctx.Done(): + return nil, ctx.Err() + } +} + +// Reply sends a response to a request +func (mb *UniversalMessageBus) Reply(ctx context.Context, originalMsg *Message, response *Message) error { + if originalMsg.CorrelationID == "" { + return fmt.Errorf("original message has no correlation ID") + } + + // Set response properties + response.Type = MessageTypeResponse + response.Topic = fmt.Sprintf("response.%s", originalMsg.CorrelationID) + response.CorrelationID = originalMsg.CorrelationID + response.Target = originalMsg.Source + + return mb.Publish(ctx, response) +} + +// CreateTopic creates a new topic with configuration +func (mb *UniversalMessageBus) CreateTopic(topicName string, config TopicConfig) error { + mb.mu.Lock() + defer mb.mu.Unlock() + + if _, exists := mb.topics[topicName]; exists { + return fmt.Errorf("topic already exists: %s", topicName) + } + + topic := &Topic{ + Name: topicName, + Config: config, + Messages: make([]StoredMessage, 0), + Subscribers: make([]*Subscription, 0), + Created: time.Now(), + LastActivity: time.Now(), + } + + mb.topics[topicName] = topic + mb.metrics.IncrementCounter("topics_created_total") + + return nil +} + +// DeleteTopic removes a topic +func (mb *UniversalMessageBus) DeleteTopic(topicName string) error { + mb.mu.Lock() + defer mb.mu.Unlock() + + topic, exists := mb.topics[topicName] + if !exists { + return fmt.Errorf("topic not found: %s", topicName) + } + + // Remove all subscribers + for _, sub := range topic.Subscribers { + mb.Unsubscribe(sub.ID) + } + + delete(mb.topics, topicName) + mb.metrics.IncrementCounter("topics_deleted_total") + + return nil +} + +// ListTopics returns all topic names +func (mb *UniversalMessageBus) ListTopics() []string { + mb.mu.RLock() + defer mb.mu.RUnlock() + + topics := make([]string, 0, len(mb.topics)) + for name := range mb.topics { + topics = append(topics, name) + } + + return topics +} + +// GetTopicInfo returns topic information +func (mb *UniversalMessageBus) GetTopicInfo(topicName string) (*TopicInfo, error) { + mb.mu.RLock() + defer mb.mu.RUnlock() + + topic, exists := mb.topics[topicName] + if !exists { + return nil, fmt.Errorf("topic not found: %s", topicName) + } + + topic.mu.RLock() + defer topic.mu.RUnlock() + + // Calculate size + var sizeBytes int64 + for _, stored := range topic.Messages { + // Rough estimation of message size + sizeBytes += int64(len(fmt.Sprintf("%+v", stored.Message))) + } + + return &TopicInfo{ + Name: topic.Name, + Config: topic.Config, + MessageCount: int64(len(topic.Messages)), + SubscriberCount: len(topic.Subscribers), + LastActivity: topic.LastActivity, + SizeBytes: sizeBytes, + }, nil +} + +// QueueMessage adds a message to a topic queue +func (mb *UniversalMessageBus) QueueMessage(topic string, msg *Message) error { + return mb.addMessageToTopic(msg) +} + +// DequeueMessage removes a message from a topic queue +func (mb *UniversalMessageBus) DequeueMessage(topic string, timeout time.Duration) (*Message, error) { + start := time.Now() + + for time.Since(start) < timeout { + mb.mu.RLock() + topicObj, exists := mb.topics[topic] + mb.mu.RUnlock() + + if !exists { + return nil, fmt.Errorf("topic not found: %s", topic) + } + + topicObj.mu.Lock() + if len(topicObj.Messages) > 0 { + // Get first unprocessed message + for i, stored := range topicObj.Messages { + if !stored.Processed { + topicObj.Messages[i].Processed = true + topicObj.mu.Unlock() + return stored.Message, nil + } + } + } + topicObj.mu.Unlock() + + // Wait a bit before trying again + time.Sleep(10 * time.Millisecond) + } + + return nil, fmt.Errorf("no message available within timeout") +} + +// PeekMessage returns the next message without removing it +func (mb *UniversalMessageBus) PeekMessage(topic string) (*Message, error) { + mb.mu.RLock() + topicObj, exists := mb.topics[topic] + mb.mu.RUnlock() + + if !exists { + return nil, fmt.Errorf("topic not found: %s", topic) + } + + topicObj.mu.RLock() + defer topicObj.mu.RUnlock() + + for _, stored := range topicObj.Messages { + if !stored.Processed { + return stored.Message, nil + } + } + + return nil, fmt.Errorf("no messages available") +} + +// Start initializes and starts the message bus +func (mb *UniversalMessageBus) Start(ctx context.Context) error { + mb.mu.Lock() + defer mb.mu.Unlock() + + if mb.started { + return fmt.Errorf("message bus already started") + } + + // Initialize default transport if none configured + if len(mb.transports) == 0 { + memTransport := NewMemoryTransport() + mb.transports[TransportMemory] = memTransport + if err := memTransport.Connect(ctx); err != nil { + return fmt.Errorf("failed to connect default transport: %w", err) + } + } + + // Start background routines + go mb.healthCheckLoop() + go mb.cleanupLoop() + go mb.metricsLoop() + + mb.started = true + mb.metrics.RecordEvent("message_bus_started") + + return nil +} + +// Stop gracefully shuts down the message bus +func (mb *UniversalMessageBus) Stop(ctx context.Context) error { + mb.mu.Lock() + defer mb.mu.Unlock() + + if !mb.started { + return nil + } + + // Cancel context to stop background routines + mb.cancel() + + // Disconnect all transports + for _, transport := range mb.transports { + if err := transport.Disconnect(ctx); err != nil { + // Log error but continue shutdown + } + } + + mb.started = false + mb.metrics.RecordEvent("message_bus_stopped") + + return nil +} + +// Health returns the current health status +func (mb *UniversalMessageBus) Health() HealthStatus { + components := make(map[string]ComponentHealth) + + // Check transport health + for transportType, transport := range mb.transports { + components[string(transportType)] = transport.Health() + } + + // Overall status + status := "healthy" + var errors []HealthError + + for name, component := range components { + if component.Status != "healthy" { + status = "degraded" + errors = append(errors, HealthError{ + Component: name, + Message: fmt.Sprintf("Component %s is %s", name, component.Status), + Timestamp: time.Now(), + Severity: "warning", + }) + } + } + + return HealthStatus{ + Status: status, + Uptime: time.Since(time.Now()), // Would track actual uptime + LastCheck: time.Now(), + Components: components, + Errors: errors, + } +} + +// GetMetrics returns current operational metrics +func (mb *UniversalMessageBus) GetMetrics() MessageBusMetrics { + metrics := mb.metrics.GetAll() + + return MessageBusMetrics{ + MessagesPublished: mb.getMetricInt64("messages_published_total"), + MessagesConsumed: mb.getMetricInt64("messages_consumed_total"), + MessagesFailed: mb.getMetricInt64("send_errors_total"), + MessagesInDLQ: int64(mb.dlq.GetMessageCount()), + ActiveSubscriptions: len(mb.subscriptions), + TopicCount: len(mb.topics), + AverageLatency: mb.getMetricDuration("average_latency"), + ThroughputPerSec: mb.getMetricFloat64("throughput_per_second"), + ErrorRate: mb.getMetricFloat64("error_rate"), + MemoryUsage: mb.getMetricInt64("memory_usage_bytes"), + CPUUsage: mb.getMetricFloat64("cpu_usage_percent"), + } +} + +// GetSubscriptions returns all active subscriptions +func (mb *UniversalMessageBus) GetSubscriptions() []*Subscription { + mb.mu.RLock() + defer mb.mu.RUnlock() + + subscriptions := make([]*Subscription, 0, len(mb.subscriptions)) + for _, sub := range mb.subscriptions { + subscriptions = append(subscriptions, sub) + } + + return subscriptions +} + +// GetActiveConnections returns the number of active connections +func (mb *UniversalMessageBus) GetActiveConnections() int { + count := 0 + for _, transport := range mb.transports { + metrics := transport.GetMetrics() + count += metrics.Connections + } + return count +} + +// Helper methods + +func (mb *UniversalMessageBus) validateMessage(msg *Message) error { + if msg == nil { + return fmt.Errorf("message is nil") + } + if msg.Topic == "" { + return fmt.Errorf("message topic is empty") + } + if msg.Source == "" { + return fmt.Errorf("message source is empty") + } + if msg.Data == nil { + return fmt.Errorf("message data is nil") + } + return nil +} + +func (mb *UniversalMessageBus) addMessageToTopic(msg *Message) error { + mb.mu.RLock() + topic, exists := mb.topics[msg.Topic] + mb.mu.RUnlock() + + if !exists { + // Create topic automatically + config := TopicConfig{ + Persistent: true, + RetentionPolicy: RetentionPolicy{MaxMessages: 10000, MaxAge: 24 * time.Hour}, + } + if err := mb.CreateTopic(msg.Topic, config); err != nil { + return err + } + topic = mb.topics[msg.Topic] + } + + topic.mu.Lock() + defer topic.mu.Unlock() + + stored := StoredMessage{ + Message: msg, + Stored: time.Now(), + Processed: false, + } + + topic.Messages = append(topic.Messages, stored) + topic.LastActivity = time.Now() + + // Apply retention policy + mb.applyRetentionPolicy(topic) + + return nil +} + +func (mb *UniversalMessageBus) addSubscriberToTopic(topicName string, subscription *Subscription) { + mb.mu.RLock() + topic, exists := mb.topics[topicName] + mb.mu.RUnlock() + + if !exists { + // Create topic automatically + config := TopicConfig{Persistent: false} + mb.CreateTopic(topicName, config) + topic = mb.topics[topicName] + } + + topic.mu.Lock() + topic.Subscribers = append(topic.Subscribers, subscription) + topic.mu.Unlock() +} + +func (mb *UniversalMessageBus) removeSubscriberFromTopic(topicName, subscriptionID string) { + mb.mu.RLock() + topic, exists := mb.topics[topicName] + mb.mu.RUnlock() + + if !exists { + return + } + + topic.mu.Lock() + defer topic.mu.Unlock() + + for i, sub := range topic.Subscribers { + if sub.ID == subscriptionID { + topic.Subscribers = append(topic.Subscribers[:i], topic.Subscribers[i+1:]...) + break + } + } +} + +func (mb *UniversalMessageBus) deliverToSubscribers(ctx context.Context, msg *Message) { + mb.mu.RLock() + topic, exists := mb.topics[msg.Topic] + mb.mu.RUnlock() + + if !exists { + return + } + + topic.mu.RLock() + subscribers := make([]*Subscription, len(topic.Subscribers)) + copy(subscribers, topic.Subscribers) + topic.mu.RUnlock() + + for _, sub := range subscribers { + sub.mu.RLock() + if !sub.active { + sub.mu.RUnlock() + continue + } + + // Apply filter if present + if sub.Filter != nil && !sub.Filter(msg) { + sub.mu.RUnlock() + continue + } + + handler := sub.Handler + sub.mu.RUnlock() + + // Deliver message in goroutine + go func(subscription *Subscription, message *Message) { + defer func() { + if r := recover(); r != nil { + mb.metrics.IncrementCounter("handler_panics_total") + } + }() + + if err := handler(ctx, message); err != nil { + mb.metrics.IncrementCounter("handler_errors_total") + if mb.config.EnableDLQ { + mb.dlq.AddMessage(message.Topic, message) + } + } else { + mb.metrics.IncrementCounter("messages_consumed_total") + } + }(sub, msg) + } +} + +func (mb *UniversalMessageBus) applyRetentionPolicy(topic *Topic) { + policy := topic.Config.RetentionPolicy + + // Remove old messages + if policy.MaxAge > 0 { + cutoff := time.Now().Add(-policy.MaxAge) + filtered := make([]StoredMessage, 0) + for _, stored := range topic.Messages { + if stored.Stored.After(cutoff) { + filtered = append(filtered, stored) + } + } + topic.Messages = filtered + } + + // Limit number of messages + if policy.MaxMessages > 0 && len(topic.Messages) > policy.MaxMessages { + topic.Messages = topic.Messages[len(topic.Messages)-policy.MaxMessages:] + } +} + +func (mb *UniversalMessageBus) healthCheckLoop() { + ticker := time.NewTicker(mb.config.HealthCheckInterval) + defer ticker.Stop() + + for { + select { + case <-ticker.C: + mb.performHealthCheck() + case <-mb.ctx.Done(): + return + } + } +} + +func (mb *UniversalMessageBus) cleanupLoop() { + ticker := time.NewTicker(mb.config.CleanupInterval) + defer ticker.Stop() + + for { + select { + case <-ticker.C: + mb.performCleanup() + case <-mb.ctx.Done(): + return + } + } +} + +func (mb *UniversalMessageBus) metricsLoop() { + ticker := time.NewTicker(time.Minute) + defer ticker.Stop() + + for { + select { + case <-ticker.C: + mb.updateMetrics() + case <-mb.ctx.Done(): + return + } + } +} + +func (mb *UniversalMessageBus) performHealthCheck() { + // Check all transports + for _, transport := range mb.transports { + health := transport.Health() + mb.metrics.RecordGauge(fmt.Sprintf("transport_%s_healthy", health.Component), + map[string]float64{"healthy": 1, "unhealthy": 0, "degraded": 0.5}[health.Status]) + } +} + +func (mb *UniversalMessageBus) performCleanup() { + // Clean up processed messages in topics + mb.mu.RLock() + topics := make([]*Topic, 0, len(mb.topics)) + for _, topic := range mb.topics { + topics = append(topics, topic) + } + mb.mu.RUnlock() + + for _, topic := range topics { + topic.mu.Lock() + mb.applyRetentionPolicy(topic) + topic.mu.Unlock() + } + + // Clean up DLQ + mb.dlq.Cleanup(time.Hour * 24) // Clean messages older than 24 hours +} + +func (mb *UniversalMessageBus) updateMetrics() { + // Update throughput metrics + publishedCount := mb.getMetricInt64("messages_published_total") + if publishedCount > 0 { + // Calculate per-second rate (simplified) + mb.metrics.RecordGauge("throughput_per_second", float64(publishedCount)/60.0) + } + + // Update error rate + errorCount := mb.getMetricInt64("send_errors_total") + totalCount := publishedCount + if totalCount > 0 { + errorRate := float64(errorCount) / float64(totalCount) + mb.metrics.RecordGauge("error_rate", errorRate) + } +} + +func (mb *UniversalMessageBus) getMetricInt64(key string) int64 { + if val, ok := mb.metrics.Get(key).(int64); ok { + return val + } + return 0 +} + +func (mb *UniversalMessageBus) getMetricFloat64(key string) float64 { + if val, ok := mb.metrics.Get(key).(float64); ok { + return val + } + return 0 +} + +func (mb *UniversalMessageBus) getMetricDuration(key string) time.Duration { + if val, ok := mb.metrics.Get(key).(time.Duration); ok { + return val + } + return 0 +} diff --git a/pkg/transport/persistence.go b/pkg/transport/persistence.go new file mode 100644 index 0000000..642a3c1 --- /dev/null +++ b/pkg/transport/persistence.go @@ -0,0 +1,622 @@ +package transport + +import ( + "context" + "encoding/json" + "fmt" + "io/ioutil" + "os" + "path/filepath" + "sort" + "sync" + "time" +) + +// FilePersistenceLayer implements file-based message persistence +type FilePersistenceLayer struct { + basePath string + maxFileSize int64 + maxFiles int + compression bool + encryption EncryptionConfig + mu sync.RWMutex +} + +// EncryptionConfig configures message encryption at rest +type EncryptionConfig struct { + Enabled bool + Algorithm string + Key []byte +} + +// PersistedMessage represents a message stored on disk +type PersistedMessage struct { + ID string `json:"id"` + Topic string `json:"topic"` + Message *Message `json:"message"` + Stored time.Time `json:"stored"` + Metadata map[string]interface{} `json:"metadata"` + Encrypted bool `json:"encrypted"` +} + +// PersistenceMetrics tracks persistence layer statistics +type PersistenceMetrics struct { + MessagesStored int64 + MessagesRetrieved int64 + MessagesDeleted int64 + StorageSize int64 + FileCount int + LastCleanup time.Time + Errors int64 +} + +// NewFilePersistenceLayer creates a new file-based persistence layer +func NewFilePersistenceLayer(basePath string) *FilePersistenceLayer { + return &FilePersistenceLayer{ + basePath: basePath, + maxFileSize: 100 * 1024 * 1024, // 100MB default + maxFiles: 1000, + compression: false, + } +} + +// SetMaxFileSize configures the maximum file size +func (fpl *FilePersistenceLayer) SetMaxFileSize(size int64) { + fpl.mu.Lock() + defer fpl.mu.Unlock() + fpl.maxFileSize = size +} + +// SetMaxFiles configures the maximum number of files +func (fpl *FilePersistenceLayer) SetMaxFiles(count int) { + fpl.mu.Lock() + defer fpl.mu.Unlock() + fpl.maxFiles = count +} + +// EnableCompression enables/disables compression +func (fpl *FilePersistenceLayer) EnableCompression(enabled bool) { + fpl.mu.Lock() + defer fpl.mu.Unlock() + fpl.compression = enabled +} + +// SetEncryption configures encryption settings +func (fpl *FilePersistenceLayer) SetEncryption(config EncryptionConfig) { + fpl.mu.Lock() + defer fpl.mu.Unlock() + fpl.encryption = config +} + +// Store persists a message to disk +func (fpl *FilePersistenceLayer) Store(msg *Message) error { + fpl.mu.Lock() + defer fpl.mu.Unlock() + + // Create directory if it doesn't exist + topicDir := filepath.Join(fpl.basePath, msg.Topic) + if err := os.MkdirAll(topicDir, 0755); err != nil { + return fmt.Errorf("failed to create topic directory: %w", err) + } + + // Create persisted message + persistedMsg := &PersistedMessage{ + ID: msg.ID, + Topic: msg.Topic, + Message: msg, + Stored: time.Now(), + Metadata: make(map[string]interface{}), + } + + // Serialize message + data, err := json.Marshal(persistedMsg) + if err != nil { + return fmt.Errorf("failed to marshal message: %w", err) + } + + // Apply encryption if enabled + if fpl.encryption.Enabled { + encryptedData, err := fpl.encrypt(data) + if err != nil { + return fmt.Errorf("encryption failed: %w", err) + } + data = encryptedData + persistedMsg.Encrypted = true + } + + // Apply compression if enabled + if fpl.compression { + compressedData, err := fpl.compress(data) + if err != nil { + return fmt.Errorf("compression failed: %w", err) + } + data = compressedData + } + + // Find appropriate file to write to + filename, err := fpl.getWritableFile(topicDir, len(data)) + if err != nil { + return fmt.Errorf("failed to get writable file: %w", err) + } + + // Write to file + file, err := os.OpenFile(filename, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644) + if err != nil { + return fmt.Errorf("failed to open file: %w", err) + } + defer file.Close() + + // Write length prefix and data + lengthPrefix := fmt.Sprintf("%d\n", len(data)) + if _, err := file.WriteString(lengthPrefix); err != nil { + return fmt.Errorf("failed to write length prefix: %w", err) + } + if _, err := file.Write(data); err != nil { + return fmt.Errorf("failed to write data: %w", err) + } + + return nil +} + +// Retrieve loads a message from disk by ID +func (fpl *FilePersistenceLayer) Retrieve(id string) (*Message, error) { + fpl.mu.RLock() + defer fpl.mu.RUnlock() + + // Search all topic directories + topicDirs, err := fpl.getTopicDirectories() + if err != nil { + return nil, fmt.Errorf("failed to get topic directories: %w", err) + } + + for _, topicDir := range topicDirs { + files, err := fpl.getTopicFiles(topicDir) + if err != nil { + continue + } + + for _, file := range files { + msg, err := fpl.findMessageInFile(file, id) + if err != nil { + continue + } + if msg != nil { + return msg, nil + } + } + } + + return nil, fmt.Errorf("message not found: %s", id) +} + +// Delete removes a message from disk by ID +func (fpl *FilePersistenceLayer) Delete(id string) error { + fpl.mu.Lock() + defer fpl.mu.Unlock() + + // This is a simplified implementation + // In a production system, you might want to mark messages as deleted + // and compact files periodically instead of rewriting entire files + return fmt.Errorf("delete operation not yet implemented") +} + +// List returns messages for a topic with optional limit +func (fpl *FilePersistenceLayer) List(topic string, limit int) ([]*Message, error) { + fpl.mu.RLock() + defer fpl.mu.RUnlock() + + topicDir := filepath.Join(fpl.basePath, topic) + if _, err := os.Stat(topicDir); os.IsNotExist(err) { + return []*Message{}, nil + } + + files, err := fpl.getTopicFiles(topicDir) + if err != nil { + return nil, fmt.Errorf("failed to get topic files: %w", err) + } + + var messages []*Message + count := 0 + + // Read files in chronological order (newest first) + sort.Slice(files, func(i, j int) bool { + infoI, _ := os.Stat(files[i]) + infoJ, _ := os.Stat(files[j]) + return infoI.ModTime().After(infoJ.ModTime()) + }) + + for _, file := range files { + fileMessages, err := fpl.readMessagesFromFile(file) + if err != nil { + continue + } + + for _, msg := range fileMessages { + messages = append(messages, msg) + count++ + if limit > 0 && count >= limit { + break + } + } + + if limit > 0 && count >= limit { + break + } + } + + return messages, nil +} + +// Cleanup removes messages older than maxAge +func (fpl *FilePersistenceLayer) Cleanup(maxAge time.Duration) error { + fpl.mu.Lock() + defer fpl.mu.Unlock() + + cutoff := time.Now().Add(-maxAge) + + topicDirs, err := fpl.getTopicDirectories() + if err != nil { + return fmt.Errorf("failed to get topic directories: %w", err) + } + + for _, topicDir := range topicDirs { + files, err := fpl.getTopicFiles(topicDir) + if err != nil { + continue + } + + for _, file := range files { + // Check file modification time + info, err := os.Stat(file) + if err != nil { + continue + } + + if info.ModTime().Before(cutoff) { + os.Remove(file) + } + } + + // Remove empty topic directories + if isEmpty, _ := fpl.isDirectoryEmpty(topicDir); isEmpty { + os.Remove(topicDir) + } + } + + return nil +} + +// GetMetrics returns persistence layer metrics +func (fpl *FilePersistenceLayer) GetMetrics() (PersistenceMetrics, error) { + fpl.mu.RLock() + defer fpl.mu.RUnlock() + + metrics := PersistenceMetrics{} + + // Calculate storage size and file count + err := filepath.Walk(fpl.basePath, func(path string, info os.FileInfo, err error) error { + if err != nil { + return err + } + if !info.IsDir() { + metrics.FileCount++ + metrics.StorageSize += info.Size() + } + return nil + }) + + return metrics, err +} + +// Private helper methods + +func (fpl *FilePersistenceLayer) getWritableFile(topicDir string, dataSize int) (string, error) { + files, err := fpl.getTopicFiles(topicDir) + if err != nil { + return "", err + } + + // Find a file with enough space + for _, file := range files { + info, err := os.Stat(file) + if err != nil { + continue + } + + if info.Size()+int64(dataSize) <= fpl.maxFileSize { + return file, nil + } + } + + // Create new file + timestamp := time.Now().Format("20060102_150405") + filename := filepath.Join(topicDir, fmt.Sprintf("messages_%s.dat", timestamp)) + return filename, nil +} + +func (fpl *FilePersistenceLayer) getTopicDirectories() ([]string, error) { + entries, err := ioutil.ReadDir(fpl.basePath) + if err != nil { + return nil, err + } + + var dirs []string + for _, entry := range entries { + if entry.IsDir() { + dirs = append(dirs, filepath.Join(fpl.basePath, entry.Name())) + } + } + + return dirs, nil +} + +func (fpl *FilePersistenceLayer) getTopicFiles(topicDir string) ([]string, error) { + entries, err := ioutil.ReadDir(topicDir) + if err != nil { + return nil, err + } + + var files []string + for _, entry := range entries { + if !entry.IsDir() && filepath.Ext(entry.Name()) == ".dat" { + files = append(files, filepath.Join(topicDir, entry.Name())) + } + } + + return files, nil +} + +func (fpl *FilePersistenceLayer) findMessageInFile(filename, messageID string) (*Message, error) { + file, err := os.Open(filename) + if err != nil { + return nil, err + } + defer file.Close() + + data, err := ioutil.ReadAll(file) + if err != nil { + return nil, err + } + + // Parse messages from file data + messages, err := fpl.parseFileData(data) + if err != nil { + return nil, err + } + + for _, msg := range messages { + if msg.ID == messageID { + return msg, nil + } + } + + return nil, nil +} + +func (fpl *FilePersistenceLayer) readMessagesFromFile(filename string) ([]*Message, error) { + file, err := os.Open(filename) + if err != nil { + return nil, err + } + defer file.Close() + + data, err := ioutil.ReadAll(file) + if err != nil { + return nil, err + } + + return fpl.parseFileData(data) +} + +func (fpl *FilePersistenceLayer) parseFileData(data []byte) ([]*Message, error) { + var messages []*Message + offset := 0 + + for offset < len(data) { + // Read length prefix + lengthEnd := -1 + for i := offset; i < len(data); i++ { + if data[i] == '\n' { + lengthEnd = i + break + } + } + + if lengthEnd == -1 { + break // No more complete messages + } + + lengthStr := string(data[offset:lengthEnd]) + var messageLength int + if _, err := fmt.Sscanf(lengthStr, "%d", &messageLength); err != nil { + break // Invalid length prefix + } + + messageStart := lengthEnd + 1 + messageEnd := messageStart + messageLength + + if messageEnd > len(data) { + break // Incomplete message + } + + messageData := data[messageStart:messageEnd] + + // Apply decompression if needed + if fpl.compression { + decompressed, err := fpl.decompress(messageData) + if err != nil { + offset = messageEnd + continue + } + messageData = decompressed + } + + // Apply decryption if needed + if fpl.encryption.Enabled { + decrypted, err := fpl.decrypt(messageData) + if err != nil { + offset = messageEnd + continue + } + messageData = decrypted + } + + // Parse message + var persistedMsg PersistedMessage + if err := json.Unmarshal(messageData, &persistedMsg); err != nil { + offset = messageEnd + continue + } + + messages = append(messages, persistedMsg.Message) + offset = messageEnd + } + + return messages, nil +} + +func (fpl *FilePersistenceLayer) isDirectoryEmpty(dir string) (bool, error) { + entries, err := ioutil.ReadDir(dir) + if err != nil { + return false, err + } + return len(entries) == 0, nil +} + +func (fpl *FilePersistenceLayer) encrypt(data []byte) ([]byte, error) { + // Placeholder for encryption implementation + // In a real implementation, you would use proper encryption libraries + return data, nil +} + +func (fpl *FilePersistenceLayer) decrypt(data []byte) ([]byte, error) { + // Placeholder for decryption implementation + return data, nil +} + +func (fpl *FilePersistenceLayer) compress(data []byte) ([]byte, error) { + // Placeholder for compression implementation + // In a real implementation, you would use libraries like gzip + return data, nil +} + +func (fpl *FilePersistenceLayer) decompress(data []byte) ([]byte, error) { + // Placeholder for decompression implementation + return data, nil +} + +// InMemoryPersistenceLayer implements in-memory persistence for testing/development +type InMemoryPersistenceLayer struct { + messages map[string]*Message + topics map[string][]string + mu sync.RWMutex +} + +// NewInMemoryPersistenceLayer creates a new in-memory persistence layer +func NewInMemoryPersistenceLayer() *InMemoryPersistenceLayer { + return &InMemoryPersistenceLayer{ + messages: make(map[string]*Message), + topics: make(map[string][]string), + } +} + +// Store stores a message in memory +func (impl *InMemoryPersistenceLayer) Store(msg *Message) error { + impl.mu.Lock() + defer impl.mu.Unlock() + + impl.messages[msg.ID] = msg + + if _, exists := impl.topics[msg.Topic]; !exists { + impl.topics[msg.Topic] = make([]string, 0) + } + impl.topics[msg.Topic] = append(impl.topics[msg.Topic], msg.ID) + + return nil +} + +// Retrieve retrieves a message from memory +func (impl *InMemoryPersistenceLayer) Retrieve(id string) (*Message, error) { + impl.mu.RLock() + defer impl.mu.RUnlock() + + msg, exists := impl.messages[id] + if !exists { + return nil, fmt.Errorf("message not found: %s", id) + } + + return msg, nil +} + +// Delete removes a message from memory +func (impl *InMemoryPersistenceLayer) Delete(id string) error { + impl.mu.Lock() + defer impl.mu.Unlock() + + msg, exists := impl.messages[id] + if !exists { + return fmt.Errorf("message not found: %s", id) + } + + delete(impl.messages, id) + + // Remove from topic index + if messageIDs, exists := impl.topics[msg.Topic]; exists { + for i, msgID := range messageIDs { + if msgID == id { + impl.topics[msg.Topic] = append(messageIDs[:i], messageIDs[i+1:]...) + break + } + } + } + + return nil +} + +// List returns messages for a topic +func (impl *InMemoryPersistenceLayer) List(topic string, limit int) ([]*Message, error) { + impl.mu.RLock() + defer impl.mu.RUnlock() + + messageIDs, exists := impl.topics[topic] + if !exists { + return []*Message{}, nil + } + + var messages []*Message + count := 0 + + for _, msgID := range messageIDs { + if limit > 0 && count >= limit { + break + } + + if msg, exists := impl.messages[msgID]; exists { + messages = append(messages, msg) + count++ + } + } + + return messages, nil +} + +// Cleanup removes messages older than maxAge +func (impl *InMemoryPersistenceLayer) Cleanup(maxAge time.Duration) error { + impl.mu.Lock() + defer impl.mu.Unlock() + + cutoff := time.Now().Add(-maxAge) + var toDelete []string + + for id, msg := range impl.messages { + if msg.Timestamp.Before(cutoff) { + toDelete = append(toDelete, id) + } + } + + for _, id := range toDelete { + impl.Delete(id) + } + + return nil +} diff --git a/pkg/transport/router.go b/pkg/transport/router.go new file mode 100644 index 0000000..b243462 --- /dev/null +++ b/pkg/transport/router.go @@ -0,0 +1,478 @@ +package transport + +import ( + "fmt" + "math/rand" + "sort" + "sync" + "time" +) + +// MessageRouter handles intelligent message routing and transport selection +type MessageRouter struct { + rules []RoutingRule + fallback TransportType + loadBalancer LoadBalancer + mu sync.RWMutex +} + +// RoutingRule defines message routing logic +type RoutingRule struct { + ID string + Name string + Condition MessageFilter + Transport TransportType + Priority int + Enabled bool + Created time.Time + LastUsed time.Time + UsageCount int64 +} + +// RouteMessage selects the appropriate transport for a message +func (mr *MessageRouter) RouteMessage(msg *Message, transports map[TransportType]Transport) (Transport, error) { + mr.mu.RLock() + defer mr.mu.RUnlock() + + // Find matching rules (sorted by priority) + matchingRules := mr.findMatchingRules(msg) + + // Try each matching rule in priority order + for _, rule := range matchingRules { + if transport, exists := transports[rule.Transport]; exists { + // Check transport health + if health := transport.Health(); health.Status == "healthy" { + mr.updateRuleUsage(rule.ID) + return transport, nil + } + } + } + + // Use load balancer for available transports + if mr.loadBalancer != nil { + availableTransports := mr.getHealthyTransports(transports) + if len(availableTransports) > 0 { + selectedType := mr.loadBalancer.SelectTransport(availableTransports, msg) + if transport, exists := transports[selectedType]; exists { + return transport, nil + } + } + } + + // Fall back to default transport + if fallbackTransport, exists := transports[mr.fallback]; exists { + if health := fallbackTransport.Health(); health.Status != "unhealthy" { + return fallbackTransport, nil + } + } + + return nil, fmt.Errorf("no available transport for message") +} + +// AddRule adds a new routing rule +func (mr *MessageRouter) AddRule(rule RoutingRule) { + mr.mu.Lock() + defer mr.mu.Unlock() + + if rule.ID == "" { + rule.ID = fmt.Sprintf("rule_%d", time.Now().UnixNano()) + } + rule.Created = time.Now() + rule.Enabled = true + + mr.rules = append(mr.rules, rule) + mr.sortRulesByPriority() +} + +// RemoveRule removes a routing rule by ID +func (mr *MessageRouter) RemoveRule(ruleID string) bool { + mr.mu.Lock() + defer mr.mu.Unlock() + + for i, rule := range mr.rules { + if rule.ID == ruleID { + mr.rules = append(mr.rules[:i], mr.rules[i+1:]...) + return true + } + } + return false +} + +// UpdateRule updates an existing routing rule +func (mr *MessageRouter) UpdateRule(ruleID string, updates func(*RoutingRule)) bool { + mr.mu.Lock() + defer mr.mu.Unlock() + + for i := range mr.rules { + if mr.rules[i].ID == ruleID { + updates(&mr.rules[i]) + mr.sortRulesByPriority() + return true + } + } + return false +} + +// GetRules returns all routing rules +func (mr *MessageRouter) GetRules() []RoutingRule { + mr.mu.RLock() + defer mr.mu.RUnlock() + + rules := make([]RoutingRule, len(mr.rules)) + copy(rules, mr.rules) + return rules +} + +// EnableRule enables a routing rule +func (mr *MessageRouter) EnableRule(ruleID string) bool { + return mr.UpdateRule(ruleID, func(rule *RoutingRule) { + rule.Enabled = true + }) +} + +// DisableRule disables a routing rule +func (mr *MessageRouter) DisableRule(ruleID string) bool { + return mr.UpdateRule(ruleID, func(rule *RoutingRule) { + rule.Enabled = false + }) +} + +// SetFallbackTransport sets the fallback transport type +func (mr *MessageRouter) SetFallbackTransport(transportType TransportType) { + mr.mu.Lock() + defer mr.mu.Unlock() + mr.fallback = transportType +} + +// SetLoadBalancer sets the load balancer +func (mr *MessageRouter) SetLoadBalancer(lb LoadBalancer) { + mr.mu.Lock() + defer mr.mu.Unlock() + mr.loadBalancer = lb +} + +// Private helper methods + +func (mr *MessageRouter) findMatchingRules(msg *Message) []RoutingRule { + var matching []RoutingRule + + for _, rule := range mr.rules { + if rule.Enabled && (rule.Condition == nil || rule.Condition(msg)) { + matching = append(matching, rule) + } + } + + return matching +} + +func (mr *MessageRouter) sortRulesByPriority() { + sort.Slice(mr.rules, func(i, j int) bool { + return mr.rules[i].Priority > mr.rules[j].Priority + }) +} + +func (mr *MessageRouter) updateRuleUsage(ruleID string) { + for i := range mr.rules { + if mr.rules[i].ID == ruleID { + mr.rules[i].LastUsed = time.Now() + mr.rules[i].UsageCount++ + break + } + } +} + +func (mr *MessageRouter) getHealthyTransports(transports map[TransportType]Transport) []TransportType { + var healthy []TransportType + + for transportType, transport := range transports { + if health := transport.Health(); health.Status == "healthy" { + healthy = append(healthy, transportType) + } + } + + return healthy +} + +// LoadBalancer implementations + +// RoundRobinLoadBalancer implements round-robin load balancing +type RoundRobinLoadBalancer struct { + counter int64 + mu sync.Mutex +} + +func NewRoundRobinLoadBalancer() *RoundRobinLoadBalancer { + return &RoundRobinLoadBalancer{} +} + +func (lb *RoundRobinLoadBalancer) SelectTransport(transports []TransportType, msg *Message) TransportType { + if len(transports) == 0 { + return "" + } + + lb.mu.Lock() + defer lb.mu.Unlock() + + selected := transports[lb.counter%int64(len(transports))] + lb.counter++ + return selected +} + +func (lb *RoundRobinLoadBalancer) UpdateStats(transport TransportType, latency time.Duration, success bool) { + // Round-robin doesn't use stats +} + +// WeightedLoadBalancer implements weighted load balancing based on performance +type WeightedLoadBalancer struct { + stats map[TransportType]*TransportStats + mu sync.RWMutex +} + +type TransportStats struct { + TotalRequests int64 + SuccessRequests int64 + TotalLatency time.Duration + LastUpdate time.Time + Weight float64 +} + +func NewWeightedLoadBalancer() *WeightedLoadBalancer { + return &WeightedLoadBalancer{ + stats: make(map[TransportType]*TransportStats), + } +} + +func (lb *WeightedLoadBalancer) SelectTransport(transports []TransportType, msg *Message) TransportType { + if len(transports) == 0 { + return "" + } + + lb.mu.RLock() + defer lb.mu.RUnlock() + + // Calculate weights and select based on weighted random selection + totalWeight := 0.0 + weights := make(map[TransportType]float64) + + for _, transport := range transports { + weight := lb.calculateWeight(transport) + weights[transport] = weight + totalWeight += weight + } + + if totalWeight == 0 { + // Fall back to random selection + return transports[rand.Intn(len(transports))] + } + + // Weighted random selection + target := rand.Float64() * totalWeight + current := 0.0 + + for _, transport := range transports { + current += weights[transport] + if current >= target { + return transport + } + } + + // Fallback (shouldn't happen) + return transports[0] +} + +func (lb *WeightedLoadBalancer) UpdateStats(transport TransportType, latency time.Duration, success bool) { + lb.mu.Lock() + defer lb.mu.Unlock() + + stats, exists := lb.stats[transport] + if !exists { + stats = &TransportStats{ + Weight: 1.0, // Default weight + } + lb.stats[transport] = stats + } + + stats.TotalRequests++ + stats.TotalLatency += latency + stats.LastUpdate = time.Now() + + if success { + stats.SuccessRequests++ + } + + // Recalculate weight based on performance + stats.Weight = lb.calculateWeight(transport) +} + +func (lb *WeightedLoadBalancer) calculateWeight(transport TransportType) float64 { + stats, exists := lb.stats[transport] + if !exists { + return 1.0 // Default weight for unknown transports + } + + if stats.TotalRequests == 0 { + return 1.0 + } + + // Calculate success rate + successRate := float64(stats.SuccessRequests) / float64(stats.TotalRequests) + + // Calculate average latency + avgLatency := stats.TotalLatency / time.Duration(stats.TotalRequests) + + // Weight formula: success rate / (latency factor) + // Lower latency and higher success rate = higher weight + latencyFactor := float64(avgLatency) / float64(time.Millisecond) + if latencyFactor < 1 { + latencyFactor = 1 + } + + weight := successRate / latencyFactor + + // Ensure minimum weight + if weight < 0.1 { + weight = 0.1 + } + + return weight +} + +// LeastLatencyLoadBalancer selects the transport with the lowest latency +type LeastLatencyLoadBalancer struct { + stats map[TransportType]*LatencyStats + mu sync.RWMutex +} + +type LatencyStats struct { + RecentLatencies []time.Duration + MaxSamples int + LastUpdate time.Time +} + +func NewLeastLatencyLoadBalancer() *LeastLatencyLoadBalancer { + return &LeastLatencyLoadBalancer{ + stats: make(map[TransportType]*LatencyStats), + } +} + +func (lb *LeastLatencyLoadBalancer) SelectTransport(transports []TransportType, msg *Message) TransportType { + if len(transports) == 0 { + return "" + } + + lb.mu.RLock() + defer lb.mu.RUnlock() + + bestTransport := transports[0] + bestLatency := time.Hour // Large initial value + + for _, transport := range transports { + avgLatency := lb.getAverageLatency(transport) + if avgLatency < bestLatency { + bestLatency = avgLatency + bestTransport = transport + } + } + + return bestTransport +} + +func (lb *LeastLatencyLoadBalancer) UpdateStats(transport TransportType, latency time.Duration, success bool) { + if !success { + return // Only track successful requests + } + + lb.mu.Lock() + defer lb.mu.Unlock() + + stats, exists := lb.stats[transport] + if !exists { + stats = &LatencyStats{ + RecentLatencies: make([]time.Duration, 0), + MaxSamples: 10, // Keep last 10 samples + } + lb.stats[transport] = stats + } + + // Add new latency sample + stats.RecentLatencies = append(stats.RecentLatencies, latency) + + // Keep only recent samples + if len(stats.RecentLatencies) > stats.MaxSamples { + stats.RecentLatencies = stats.RecentLatencies[1:] + } + + stats.LastUpdate = time.Now() +} + +func (lb *LeastLatencyLoadBalancer) getAverageLatency(transport TransportType) time.Duration { + stats, exists := lb.stats[transport] + if !exists || len(stats.RecentLatencies) == 0 { + return time.Millisecond * 100 // Default estimate + } + + total := time.Duration(0) + for _, latency := range stats.RecentLatencies { + total += latency + } + + return total / time.Duration(len(stats.RecentLatencies)) +} + +// Common routing rule factory functions + +// CreateTopicRule creates a rule based on message topic +func CreateTopicRule(name string, topic string, transport TransportType, priority int) RoutingRule { + return RoutingRule{ + Name: name, + Condition: func(msg *Message) bool { return msg.Topic == topic }, + Transport: transport, + Priority: priority, + } +} + +// CreateTopicPatternRule creates a rule based on topic pattern matching +func CreateTopicPatternRule(name string, pattern string, transport TransportType, priority int) RoutingRule { + return RoutingRule{ + Name: name, + Condition: func(msg *Message) bool { + // Simple pattern matching (can be enhanced with regex) + return msg.Topic == pattern || + (len(pattern) > 0 && pattern[len(pattern)-1] == '*' && + len(msg.Topic) >= len(pattern)-1 && + msg.Topic[:len(pattern)-1] == pattern[:len(pattern)-1]) + }, + Transport: transport, + Priority: priority, + } +} + +// CreatePriorityRule creates a rule based on message priority +func CreatePriorityRule(name string, msgPriority MessagePriority, transport TransportType, priority int) RoutingRule { + return RoutingRule{ + Name: name, + Condition: func(msg *Message) bool { return msg.Priority == msgPriority }, + Transport: transport, + Priority: priority, + } +} + +// CreateTypeRule creates a rule based on message type +func CreateTypeRule(name string, msgType MessageType, transport TransportType, priority int) RoutingRule { + return RoutingRule{ + Name: name, + Condition: func(msg *Message) bool { return msg.Type == msgType }, + Transport: transport, + Priority: priority, + } +} + +// CreateSourceRule creates a rule based on message source +func CreateSourceRule(name string, source string, transport TransportType, priority int) RoutingRule { + return RoutingRule{ + Name: name, + Condition: func(msg *Message) bool { return msg.Source == source }, + Transport: transport, + Priority: priority, + } +} diff --git a/pkg/transport/serialization.go b/pkg/transport/serialization.go new file mode 100644 index 0000000..b4bde23 --- /dev/null +++ b/pkg/transport/serialization.go @@ -0,0 +1,566 @@ +package transport + +import ( + "bytes" + "compress/gzip" + "encoding/json" + "fmt" + "io" + "sync" +) + +// SerializationFormat defines supported serialization formats +type SerializationFormat string + +const ( + SerializationJSON SerializationFormat = "json" + SerializationMsgPack SerializationFormat = "msgpack" + SerializationProtobuf SerializationFormat = "protobuf" + SerializationAvro SerializationFormat = "avro" +) + +// CompressionType defines supported compression algorithms +type CompressionType string + +const ( + CompressionNone CompressionType = "none" + CompressionGZip CompressionType = "gzip" + CompressionLZ4 CompressionType = "lz4" + CompressionSnappy CompressionType = "snappy" +) + +// SerializationConfig configures serialization behavior +type SerializationConfig struct { + Format SerializationFormat + Compression CompressionType + Encryption bool + Validation bool +} + +// SerializedMessage represents a serialized message with metadata +type SerializedMessage struct { + Format SerializationFormat `json:"format"` + Compression CompressionType `json:"compression"` + Encrypted bool `json:"encrypted"` + Checksum string `json:"checksum"` + Data []byte `json:"data"` + Size int `json:"size"` + Timestamp int64 `json:"timestamp"` +} + +// Serializer interface defines serialization operations +type Serializer interface { + Serialize(msg *Message) (*SerializedMessage, error) + Deserialize(serialized *SerializedMessage) (*Message, error) + GetFormat() SerializationFormat + GetConfig() SerializationConfig +} + +// SerializationLayer manages multiple serializers and format selection +type SerializationLayer struct { + serializers map[SerializationFormat]Serializer + defaultFormat SerializationFormat + compressor Compressor + encryptor Encryptor + validator Validator + mu sync.RWMutex +} + +// Compressor interface for data compression +type Compressor interface { + Compress(data []byte, algorithm CompressionType) ([]byte, error) + Decompress(data []byte, algorithm CompressionType) ([]byte, error) + GetSupportedAlgorithms() []CompressionType +} + +// Encryptor interface for data encryption +type Encryptor interface { + Encrypt(data []byte) ([]byte, error) + Decrypt(data []byte) ([]byte, error) + IsEnabled() bool +} + +// Validator interface for data validation +type Validator interface { + Validate(msg *Message) error + GenerateChecksum(data []byte) string + VerifyChecksum(data []byte, checksum string) bool +} + +// NewSerializationLayer creates a new serialization layer +func NewSerializationLayer() *SerializationLayer { + sl := &SerializationLayer{ + serializers: make(map[SerializationFormat]Serializer), + defaultFormat: SerializationJSON, + compressor: NewDefaultCompressor(), + validator: NewDefaultValidator(), + } + + // Register default serializers + sl.RegisterSerializer(NewJSONSerializer()) + + return sl +} + +// RegisterSerializer registers a new serializer +func (sl *SerializationLayer) RegisterSerializer(serializer Serializer) { + sl.mu.Lock() + defer sl.mu.Unlock() + sl.serializers[serializer.GetFormat()] = serializer +} + +// SetDefaultFormat sets the default serialization format +func (sl *SerializationLayer) SetDefaultFormat(format SerializationFormat) { + sl.mu.Lock() + defer sl.mu.Unlock() + sl.defaultFormat = format +} + +// SetCompressor sets the compression handler +func (sl *SerializationLayer) SetCompressor(compressor Compressor) { + sl.mu.Lock() + defer sl.mu.Unlock() + sl.compressor = compressor +} + +// SetEncryptor sets the encryption handler +func (sl *SerializationLayer) SetEncryptor(encryptor Encryptor) { + sl.mu.Lock() + defer sl.mu.Unlock() + sl.encryptor = encryptor +} + +// SetValidator sets the validation handler +func (sl *SerializationLayer) SetValidator(validator Validator) { + sl.mu.Lock() + defer sl.mu.Unlock() + sl.validator = validator +} + +// Serialize serializes a message using the specified or default format +func (sl *SerializationLayer) Serialize(msg *Message, format ...SerializationFormat) (*SerializedMessage, error) { + sl.mu.RLock() + defer sl.mu.RUnlock() + + // Determine format to use + selectedFormat := sl.defaultFormat + if len(format) > 0 { + selectedFormat = format[0] + } + + // Get serializer + serializer, exists := sl.serializers[selectedFormat] + if !exists { + return nil, fmt.Errorf("unsupported serialization format: %s", selectedFormat) + } + + // Validate message if validator is configured + if sl.validator != nil { + if err := sl.validator.Validate(msg); err != nil { + return nil, fmt.Errorf("message validation failed: %w", err) + } + } + + // Serialize message + serialized, err := serializer.Serialize(msg) + if err != nil { + return nil, fmt.Errorf("serialization failed: %w", err) + } + + // Apply compression if configured + config := serializer.GetConfig() + if config.Compression != CompressionNone && sl.compressor != nil { + compressed, err := sl.compressor.Compress(serialized.Data, config.Compression) + if err != nil { + return nil, fmt.Errorf("compression failed: %w", err) + } + serialized.Data = compressed + serialized.Compression = config.Compression + } + + // Apply encryption if configured + if config.Encryption && sl.encryptor != nil && sl.encryptor.IsEnabled() { + encrypted, err := sl.encryptor.Encrypt(serialized.Data) + if err != nil { + return nil, fmt.Errorf("encryption failed: %w", err) + } + serialized.Data = encrypted + serialized.Encrypted = true + } + + // Generate checksum + if sl.validator != nil { + serialized.Checksum = sl.validator.GenerateChecksum(serialized.Data) + } + + // Update metadata + serialized.Size = len(serialized.Data) + serialized.Timestamp = msg.Timestamp.UnixNano() + + return serialized, nil +} + +// Deserialize deserializes a message +func (sl *SerializationLayer) Deserialize(serialized *SerializedMessage) (*Message, error) { + sl.mu.RLock() + defer sl.mu.RUnlock() + + // Verify checksum if available + if sl.validator != nil && serialized.Checksum != "" { + if !sl.validator.VerifyChecksum(serialized.Data, serialized.Checksum) { + return nil, fmt.Errorf("checksum verification failed") + } + } + + data := serialized.Data + + // Apply decryption if needed + if serialized.Encrypted && sl.encryptor != nil && sl.encryptor.IsEnabled() { + decrypted, err := sl.encryptor.Decrypt(data) + if err != nil { + return nil, fmt.Errorf("decryption failed: %w", err) + } + data = decrypted + } + + // Apply decompression if needed + if serialized.Compression != CompressionNone && sl.compressor != nil { + decompressed, err := sl.compressor.Decompress(data, serialized.Compression) + if err != nil { + return nil, fmt.Errorf("decompression failed: %w", err) + } + data = decompressed + } + + // Get serializer + serializer, exists := sl.serializers[serialized.Format] + if !exists { + return nil, fmt.Errorf("unsupported serialization format: %s", serialized.Format) + } + + // Create temporary serialized message for deserializer + tempSerialized := &SerializedMessage{ + Format: serialized.Format, + Data: data, + } + + // Deserialize message + msg, err := serializer.Deserialize(tempSerialized) + if err != nil { + return nil, fmt.Errorf("deserialization failed: %w", err) + } + + // Validate deserialized message if validator is configured + if sl.validator != nil { + if err := sl.validator.Validate(msg); err != nil { + return nil, fmt.Errorf("deserialized message validation failed: %w", err) + } + } + + return msg, nil +} + +// GetSupportedFormats returns all supported serialization formats +func (sl *SerializationLayer) GetSupportedFormats() []SerializationFormat { + sl.mu.RLock() + defer sl.mu.RUnlock() + + formats := make([]SerializationFormat, 0, len(sl.serializers)) + for format := range sl.serializers { + formats = append(formats, format) + } + return formats +} + +// JSONSerializer implements JSON serialization +type JSONSerializer struct { + config SerializationConfig +} + +// NewJSONSerializer creates a new JSON serializer +func NewJSONSerializer() *JSONSerializer { + return &JSONSerializer{ + config: SerializationConfig{ + Format: SerializationJSON, + Compression: CompressionNone, + Encryption: false, + Validation: true, + }, + } +} + +// SetConfig updates the serializer configuration +func (js *JSONSerializer) SetConfig(config SerializationConfig) { + js.config = config + js.config.Format = SerializationJSON // Ensure format is correct +} + +// Serialize serializes a message to JSON +func (js *JSONSerializer) Serialize(msg *Message) (*SerializedMessage, error) { + data, err := json.Marshal(msg) + if err != nil { + return nil, fmt.Errorf("JSON marshal failed: %w", err) + } + + return &SerializedMessage{ + Format: SerializationJSON, + Compression: CompressionNone, + Encrypted: false, + Data: data, + Size: len(data), + }, nil +} + +// Deserialize deserializes a message from JSON +func (js *JSONSerializer) Deserialize(serialized *SerializedMessage) (*Message, error) { + var msg Message + if err := json.Unmarshal(serialized.Data, &msg); err != nil { + return nil, fmt.Errorf("JSON unmarshal failed: %w", err) + } + return &msg, nil +} + +// GetFormat returns the serialization format +func (js *JSONSerializer) GetFormat() SerializationFormat { + return SerializationJSON +} + +// GetConfig returns the serializer configuration +func (js *JSONSerializer) GetConfig() SerializationConfig { + return js.config +} + +// DefaultCompressor implements basic compression operations +type DefaultCompressor struct { + supportedAlgorithms []CompressionType +} + +// NewDefaultCompressor creates a new default compressor +func NewDefaultCompressor() *DefaultCompressor { + return &DefaultCompressor{ + supportedAlgorithms: []CompressionType{ + CompressionNone, + CompressionGZip, + }, + } +} + +// Compress compresses data using the specified algorithm +func (dc *DefaultCompressor) Compress(data []byte, algorithm CompressionType) ([]byte, error) { + switch algorithm { + case CompressionNone: + return data, nil + + case CompressionGZip: + var buf bytes.Buffer + writer := gzip.NewWriter(&buf) + + if _, err := writer.Write(data); err != nil { + return nil, fmt.Errorf("gzip write failed: %w", err) + } + + if err := writer.Close(); err != nil { + return nil, fmt.Errorf("gzip close failed: %w", err) + } + + return buf.Bytes(), nil + + default: + return nil, fmt.Errorf("unsupported compression algorithm: %s", algorithm) + } +} + +// Decompress decompresses data using the specified algorithm +func (dc *DefaultCompressor) Decompress(data []byte, algorithm CompressionType) ([]byte, error) { + switch algorithm { + case CompressionNone: + return data, nil + + case CompressionGZip: + reader, err := gzip.NewReader(bytes.NewReader(data)) + if err != nil { + return nil, fmt.Errorf("gzip reader creation failed: %w", err) + } + defer reader.Close() + + decompressed, err := io.ReadAll(reader) + if err != nil { + return nil, fmt.Errorf("gzip read failed: %w", err) + } + + return decompressed, nil + + default: + return nil, fmt.Errorf("unsupported compression algorithm: %s", algorithm) + } +} + +// GetSupportedAlgorithms returns supported compression algorithms +func (dc *DefaultCompressor) GetSupportedAlgorithms() []CompressionType { + return dc.supportedAlgorithms +} + +// DefaultValidator implements basic message validation +type DefaultValidator struct { + strictMode bool +} + +// NewDefaultValidator creates a new default validator +func NewDefaultValidator() *DefaultValidator { + return &DefaultValidator{ + strictMode: false, + } +} + +// SetStrictMode enables/disables strict validation +func (dv *DefaultValidator) SetStrictMode(enabled bool) { + dv.strictMode = enabled +} + +// Validate validates a message +func (dv *DefaultValidator) Validate(msg *Message) error { + if msg == nil { + return fmt.Errorf("message is nil") + } + + if msg.ID == "" { + return fmt.Errorf("message ID is empty") + } + + if msg.Topic == "" { + return fmt.Errorf("message topic is empty") + } + + if msg.Source == "" { + return fmt.Errorf("message source is empty") + } + + if msg.Type == "" { + return fmt.Errorf("message type is empty") + } + + if dv.strictMode { + if msg.Data == nil { + return fmt.Errorf("message data is nil") + } + + if msg.Timestamp.IsZero() { + return fmt.Errorf("message timestamp is zero") + } + } + + return nil +} + +// GenerateChecksum generates a simple checksum for data +func (dv *DefaultValidator) GenerateChecksum(data []byte) string { + // Simple checksum implementation + // In production, use a proper hash function like SHA-256 + var sum uint32 + for _, b := range data { + sum += uint32(b) + } + return fmt.Sprintf("%08x", sum) +} + +// VerifyChecksum verifies a checksum +func (dv *DefaultValidator) VerifyChecksum(data []byte, checksum string) bool { + return dv.GenerateChecksum(data) == checksum +} + +// NoOpEncryptor implements a no-operation encryptor for testing +type NoOpEncryptor struct { + enabled bool +} + +// NewNoOpEncryptor creates a new no-op encryptor +func NewNoOpEncryptor() *NoOpEncryptor { + return &NoOpEncryptor{enabled: false} +} + +// SetEnabled enables/disables the encryptor +func (noe *NoOpEncryptor) SetEnabled(enabled bool) { + noe.enabled = enabled +} + +// Encrypt returns data unchanged +func (noe *NoOpEncryptor) Encrypt(data []byte) ([]byte, error) { + return data, nil +} + +// Decrypt returns data unchanged +func (noe *NoOpEncryptor) Decrypt(data []byte) ([]byte, error) { + return data, nil +} + +// IsEnabled returns whether encryption is enabled +func (noe *NoOpEncryptor) IsEnabled() bool { + return noe.enabled +} + +// SerializationMetrics tracks serialization performance +type SerializationMetrics struct { + SerializedMessages int64 + DeserializedMessages int64 + SerializationErrors int64 + CompressionRatio float64 + AverageMessageSize int64 + TotalDataProcessed int64 +} + +// MetricsCollector collects serialization metrics +type MetricsCollector struct { + metrics SerializationMetrics + mu sync.RWMutex +} + +// NewMetricsCollector creates a new metrics collector +func NewMetricsCollector() *MetricsCollector { + return &MetricsCollector{} +} + +// RecordSerialization records a serialization operation +func (mc *MetricsCollector) RecordSerialization(originalSize, serializedSize int) { + mc.mu.Lock() + defer mc.mu.Unlock() + + mc.metrics.SerializedMessages++ + mc.metrics.TotalDataProcessed += int64(originalSize) + + // Update compression ratio + if originalSize > 0 { + ratio := float64(serializedSize) / float64(originalSize) + mc.metrics.CompressionRatio = (mc.metrics.CompressionRatio + ratio) / 2 + } + + // Update average message size + mc.metrics.AverageMessageSize = mc.metrics.TotalDataProcessed / mc.metrics.SerializedMessages +} + +// RecordDeserialization records a deserialization operation +func (mc *MetricsCollector) RecordDeserialization() { + mc.mu.Lock() + defer mc.mu.Unlock() + mc.metrics.DeserializedMessages++ +} + +// RecordError records a serialization error +func (mc *MetricsCollector) RecordError() { + mc.mu.Lock() + defer mc.mu.Unlock() + mc.metrics.SerializationErrors++ +} + +// GetMetrics returns current metrics +func (mc *MetricsCollector) GetMetrics() SerializationMetrics { + mc.mu.RLock() + defer mc.mu.RUnlock() + return mc.metrics +} + +// Reset resets all metrics +func (mc *MetricsCollector) Reset() { + mc.mu.Lock() + defer mc.mu.Unlock() + mc.metrics = SerializationMetrics{} +} diff --git a/pkg/transport/tcp_transport.go b/pkg/transport/tcp_transport.go new file mode 100644 index 0000000..d2dc14a --- /dev/null +++ b/pkg/transport/tcp_transport.go @@ -0,0 +1,490 @@ +package transport + +import ( + "context" + "crypto/tls" + "encoding/json" + "fmt" + "net" + "sync" + "time" +) + +// TCPTransport implements TCP transport for remote communication +type TCPTransport struct { + address string + port int + listener net.Listener + connections map[string]net.Conn + metrics TransportMetrics + connected bool + isServer bool + receiveChan chan *Message + tlsConfig *tls.Config + mu sync.RWMutex + ctx context.Context + cancel context.CancelFunc + retryConfig RetryConfig +} + +// NewTCPTransport creates a new TCP transport +func NewTCPTransport(address string, port int, isServer bool) *TCPTransport { + ctx, cancel := context.WithCancel(context.Background()) + + return &TCPTransport{ + address: address, + port: port, + connections: make(map[string]net.Conn), + metrics: TransportMetrics{}, + isServer: isServer, + receiveChan: make(chan *Message, 1000), + ctx: ctx, + cancel: cancel, + retryConfig: RetryConfig{ + MaxRetries: 3, + InitialDelay: time.Second, + MaxDelay: 30 * time.Second, + BackoffFactor: 2.0, + Jitter: true, + }, + } +} + +// SetTLSConfig configures TLS for secure communication +func (tt *TCPTransport) SetTLSConfig(config *tls.Config) { + tt.tlsConfig = config +} + +// SetRetryConfig configures retry behavior +func (tt *TCPTransport) SetRetryConfig(config RetryConfig) { + tt.retryConfig = config +} + +// Connect establishes the TCP connection +func (tt *TCPTransport) Connect(ctx context.Context) error { + tt.mu.Lock() + defer tt.mu.Unlock() + + if tt.connected { + return nil + } + + if tt.isServer { + return tt.startServer() + } else { + return tt.connectToServer(ctx) + } +} + +// Disconnect closes the TCP connection +func (tt *TCPTransport) Disconnect(ctx context.Context) error { + tt.mu.Lock() + defer tt.mu.Unlock() + + if !tt.connected { + return nil + } + + tt.cancel() + + if tt.isServer && tt.listener != nil { + tt.listener.Close() + } + + // Close all connections + for id, conn := range tt.connections { + conn.Close() + delete(tt.connections, id) + } + + close(tt.receiveChan) + tt.connected = false + tt.metrics.Connections = 0 + + return nil +} + +// Send transmits a message through TCP +func (tt *TCPTransport) Send(ctx context.Context, msg *Message) error { + start := time.Now() + + tt.mu.RLock() + if !tt.connected { + tt.mu.RUnlock() + tt.metrics.Errors++ + return fmt.Errorf("transport not connected") + } + + // Serialize message + data, err := json.Marshal(msg) + if err != nil { + tt.mu.RUnlock() + tt.metrics.Errors++ + return fmt.Errorf("failed to marshal message: %w", err) + } + + // Add length prefix for framing + frame := fmt.Sprintf("%d\n%s", len(data), data) + frameBytes := []byte(frame) + + // Send to all connections with retry + var sendErr error + connectionCount := len(tt.connections) + tt.mu.RUnlock() + + if connectionCount == 0 { + tt.metrics.Errors++ + return fmt.Errorf("no active connections") + } + + tt.mu.RLock() + for connID, conn := range tt.connections { + if err := tt.sendWithRetry(ctx, conn, frameBytes); err != nil { + sendErr = fmt.Errorf("failed to send to connection %s: %w", connID, err) + // Remove failed connection + go tt.removeConnection(connID) + } + } + tt.mu.RUnlock() + + if sendErr == nil { + tt.updateSendMetrics(msg, time.Since(start)) + } else { + tt.metrics.Errors++ + } + + return sendErr +} + +// Receive returns a channel for receiving messages +func (tt *TCPTransport) Receive(ctx context.Context) (<-chan *Message, error) { + tt.mu.RLock() + defer tt.mu.RUnlock() + + if !tt.connected { + return nil, fmt.Errorf("transport not connected") + } + + return tt.receiveChan, nil +} + +// Health returns the health status of the transport +func (tt *TCPTransport) Health() ComponentHealth { + tt.mu.RLock() + defer tt.mu.RUnlock() + + status := "unhealthy" + var responseTime time.Duration + + if tt.connected { + if len(tt.connections) > 0 { + status = "healthy" + responseTime = time.Millisecond * 10 // Estimate for TCP + } else { + status = "degraded" // Connected but no active connections + } + } + + return ComponentHealth{ + Status: status, + LastCheck: time.Now(), + ResponseTime: responseTime, + ErrorCount: tt.metrics.Errors, + } +} + +// GetMetrics returns transport-specific metrics +func (tt *TCPTransport) GetMetrics() TransportMetrics { + tt.mu.RLock() + defer tt.mu.RUnlock() + + return TransportMetrics{ + BytesSent: tt.metrics.BytesSent, + BytesReceived: tt.metrics.BytesReceived, + MessagesSent: tt.metrics.MessagesSent, + MessagesReceived: tt.metrics.MessagesReceived, + Connections: len(tt.connections), + Errors: tt.metrics.Errors, + Latency: tt.metrics.Latency, + } +} + +// Private helper methods + +func (tt *TCPTransport) startServer() error { + addr := fmt.Sprintf("%s:%d", tt.address, tt.port) + + var listener net.Listener + var err error + + if tt.tlsConfig != nil { + listener, err = tls.Listen("tcp", addr, tt.tlsConfig) + } else { + listener, err = net.Listen("tcp", addr) + } + + if err != nil { + return fmt.Errorf("failed to listen on %s: %w", addr, err) + } + + tt.listener = listener + tt.connected = true + + // Start accepting connections + go tt.acceptConnections() + + return nil +} + +func (tt *TCPTransport) connectToServer(ctx context.Context) error { + addr := fmt.Sprintf("%s:%d", tt.address, tt.port) + + var conn net.Conn + var err error + + // Retry connection with exponential backoff + delay := tt.retryConfig.InitialDelay + for attempt := 0; attempt <= tt.retryConfig.MaxRetries; attempt++ { + if tt.tlsConfig != nil { + conn, err = tls.Dial("tcp", addr, tt.tlsConfig) + } else { + conn, err = net.Dial("tcp", addr) + } + + if err == nil { + break + } + + if attempt == tt.retryConfig.MaxRetries { + return fmt.Errorf("failed to connect to %s after %d attempts: %w", addr, attempt+1, err) + } + + // Wait with exponential backoff + select { + case <-time.After(delay): + delay = time.Duration(float64(delay) * tt.retryConfig.BackoffFactor) + if delay > tt.retryConfig.MaxDelay { + delay = tt.retryConfig.MaxDelay + } + // Add jitter if enabled + if tt.retryConfig.Jitter { + jitter := time.Duration(float64(delay) * 0.1) + delay += time.Duration(float64(jitter) * (2*time.Now().UnixNano()%1000/1000.0 - 1)) + } + case <-ctx.Done(): + return ctx.Err() + } + } + + connID := fmt.Sprintf("client_%d", time.Now().UnixNano()) + tt.connections[connID] = conn + tt.connected = true + tt.metrics.Connections = 1 + + // Start receiving from server + go tt.handleConnection(connID, conn) + + return nil +} + +func (tt *TCPTransport) acceptConnections() { + for { + select { + case <-tt.ctx.Done(): + return + default: + conn, err := tt.listener.Accept() + if err != nil { + if tt.ctx.Err() != nil { + return // Context cancelled + } + tt.metrics.Errors++ + continue + } + + connID := fmt.Sprintf("server_%d", time.Now().UnixNano()) + + tt.mu.Lock() + tt.connections[connID] = conn + tt.metrics.Connections = len(tt.connections) + tt.mu.Unlock() + + go tt.handleConnection(connID, conn) + } + } +} + +func (tt *TCPTransport) handleConnection(connID string, conn net.Conn) { + defer tt.removeConnection(connID) + + buffer := make([]byte, 4096) + var messageBuffer []byte + + for { + select { + case <-tt.ctx.Done(): + return + default: + conn.SetReadDeadline(time.Now().Add(30 * time.Second)) + n, err := conn.Read(buffer) + if err != nil { + if netErr, ok := err.(net.Error); ok && netErr.Timeout() { + continue // Continue on timeout + } + return // Connection closed or error + } + + messageBuffer = append(messageBuffer, buffer[:n]...) + + // Process complete messages + for { + msg, remaining, err := tt.extractMessage(messageBuffer) + if err != nil { + return // Invalid message format + } + if msg == nil { + break // No complete message yet + } + + // Deliver message + select { + case tt.receiveChan <- msg: + tt.updateReceiveMetrics(msg) + case <-tt.ctx.Done(): + return + default: + // Channel full, drop message + tt.metrics.Errors++ + } + + messageBuffer = remaining + } + } + } +} + +func (tt *TCPTransport) sendWithRetry(ctx context.Context, conn net.Conn, data []byte) error { + delay := tt.retryConfig.InitialDelay + + for attempt := 0; attempt <= tt.retryConfig.MaxRetries; attempt++ { + conn.SetWriteDeadline(time.Now().Add(10 * time.Second)) + _, err := conn.Write(data) + if err == nil { + return nil + } + + if attempt == tt.retryConfig.MaxRetries { + return err + } + + // Wait with exponential backoff + select { + case <-time.After(delay): + delay = time.Duration(float64(delay) * tt.retryConfig.BackoffFactor) + if delay > tt.retryConfig.MaxDelay { + delay = tt.retryConfig.MaxDelay + } + case <-ctx.Done(): + return ctx.Err() + } + } + + return fmt.Errorf("max retries exceeded") +} + +func (tt *TCPTransport) removeConnection(connID string) { + tt.mu.Lock() + defer tt.mu.Unlock() + + if conn, exists := tt.connections[connID]; exists { + conn.Close() + delete(tt.connections, connID) + tt.metrics.Connections = len(tt.connections) + } +} + +func (tt *TCPTransport) extractMessage(buffer []byte) (*Message, []byte, error) { + // Look for length prefix (format: "length\nmessage_data") + newlineIndex := -1 + for i, b := range buffer { + if b == '\n' { + newlineIndex = i + break + } + } + + if newlineIndex == -1 { + return nil, buffer, nil // No complete length prefix yet + } + + // Parse length + lengthStr := string(buffer[:newlineIndex]) + var messageLength int + if _, err := fmt.Sscanf(lengthStr, "%d", &messageLength); err != nil { + return nil, nil, fmt.Errorf("invalid length prefix: %s", lengthStr) + } + + // Check if we have the complete message + messageStart := newlineIndex + 1 + messageEnd := messageStart + messageLength + if len(buffer) < messageEnd { + return nil, buffer, nil // Incomplete message + } + + // Extract and parse message + messageData := buffer[messageStart:messageEnd] + var msg Message + if err := json.Unmarshal(messageData, &msg); err != nil { + return nil, nil, fmt.Errorf("failed to unmarshal message: %w", err) + } + + // Return message and remaining buffer + remaining := buffer[messageEnd:] + return &msg, remaining, nil +} + +func (tt *TCPTransport) updateSendMetrics(msg *Message, latency time.Duration) { + tt.mu.Lock() + defer tt.mu.Unlock() + + tt.metrics.MessagesSent++ + tt.metrics.Latency = latency + + // Estimate message size + messageSize := int64(len(msg.ID) + len(msg.Topic) + len(msg.Source)) + if msg.Data != nil { + messageSize += int64(len(fmt.Sprintf("%v", msg.Data))) + } + tt.metrics.BytesSent += messageSize +} + +func (tt *TCPTransport) updateReceiveMetrics(msg *Message) { + tt.mu.Lock() + defer tt.mu.Unlock() + + tt.metrics.MessagesReceived++ + + // Estimate message size + messageSize := int64(len(msg.ID) + len(msg.Topic) + len(msg.Source)) + if msg.Data != nil { + messageSize += int64(len(fmt.Sprintf("%v", msg.Data))) + } + tt.metrics.BytesReceived += messageSize +} + +// GetAddress returns the transport address (for testing/debugging) +func (tt *TCPTransport) GetAddress() string { + return fmt.Sprintf("%s:%d", tt.address, tt.port) +} + +// GetConnectionCount returns the number of active connections +func (tt *TCPTransport) GetConnectionCount() int { + tt.mu.RLock() + defer tt.mu.RUnlock() + return len(tt.connections) +} + +// IsSecure returns whether TLS is enabled +func (tt *TCPTransport) IsSecure() bool { + return tt.tlsConfig != nil +} diff --git a/pkg/transport/unix_transport.go b/pkg/transport/unix_transport.go new file mode 100644 index 0000000..afb40b5 --- /dev/null +++ b/pkg/transport/unix_transport.go @@ -0,0 +1,399 @@ +package transport + +import ( + "context" + "encoding/json" + "fmt" + "net" + "os" + "sync" + "time" +) + +// UnixSocketTransport implements Unix socket transport for local IPC +type UnixSocketTransport struct { + socketPath string + listener net.Listener + connections map[string]net.Conn + metrics TransportMetrics + connected bool + isServer bool + receiveChan chan *Message + mu sync.RWMutex + ctx context.Context + cancel context.CancelFunc +} + +// NewUnixSocketTransport creates a new Unix socket transport +func NewUnixSocketTransport(socketPath string, isServer bool) *UnixSocketTransport { + ctx, cancel := context.WithCancel(context.Background()) + + return &UnixSocketTransport{ + socketPath: socketPath, + connections: make(map[string]net.Conn), + metrics: TransportMetrics{}, + isServer: isServer, + receiveChan: make(chan *Message, 1000), + ctx: ctx, + cancel: cancel, + } +} + +// Connect establishes the Unix socket connection +func (ut *UnixSocketTransport) Connect(ctx context.Context) error { + ut.mu.Lock() + defer ut.mu.Unlock() + + if ut.connected { + return nil + } + + if ut.isServer { + return ut.startServer() + } else { + return ut.connectToServer() + } +} + +// Disconnect closes the Unix socket connection +func (ut *UnixSocketTransport) Disconnect(ctx context.Context) error { + ut.mu.Lock() + defer ut.mu.Unlock() + + if !ut.connected { + return nil + } + + ut.cancel() + + if ut.isServer && ut.listener != nil { + ut.listener.Close() + // Remove socket file + os.Remove(ut.socketPath) + } + + // Close all connections + for id, conn := range ut.connections { + conn.Close() + delete(ut.connections, id) + } + + close(ut.receiveChan) + ut.connected = false + ut.metrics.Connections = 0 + + return nil +} + +// Send transmits a message through the Unix socket +func (ut *UnixSocketTransport) Send(ctx context.Context, msg *Message) error { + start := time.Now() + + ut.mu.RLock() + if !ut.connected { + ut.mu.RUnlock() + ut.metrics.Errors++ + return fmt.Errorf("transport not connected") + } + + // Serialize message + data, err := json.Marshal(msg) + if err != nil { + ut.mu.RUnlock() + ut.metrics.Errors++ + return fmt.Errorf("failed to marshal message: %w", err) + } + + // Add length prefix for framing + frame := fmt.Sprintf("%d\n%s", len(data), data) + frameBytes := []byte(frame) + + // Send to all connections + var sendErr error + connectionCount := len(ut.connections) + ut.mu.RUnlock() + + if connectionCount == 0 { + ut.metrics.Errors++ + return fmt.Errorf("no active connections") + } + + ut.mu.RLock() + for connID, conn := range ut.connections { + if err := ut.sendToConnection(conn, frameBytes); err != nil { + sendErr = fmt.Errorf("failed to send to connection %s: %w", connID, err) + // Remove failed connection + go ut.removeConnection(connID) + } + } + ut.mu.RUnlock() + + if sendErr == nil { + ut.updateSendMetrics(msg, time.Since(start)) + } else { + ut.metrics.Errors++ + } + + return sendErr +} + +// Receive returns a channel for receiving messages +func (ut *UnixSocketTransport) Receive(ctx context.Context) (<-chan *Message, error) { + ut.mu.RLock() + defer ut.mu.RUnlock() + + if !ut.connected { + return nil, fmt.Errorf("transport not connected") + } + + return ut.receiveChan, nil +} + +// Health returns the health status of the transport +func (ut *UnixSocketTransport) Health() ComponentHealth { + ut.mu.RLock() + defer ut.mu.RUnlock() + + status := "unhealthy" + if ut.connected { + if len(ut.connections) > 0 { + status = "healthy" + } else { + status = "degraded" // Connected but no active connections + } + } + + return ComponentHealth{ + Status: status, + LastCheck: time.Now(), + ResponseTime: time.Millisecond, // Fast for local sockets + ErrorCount: ut.metrics.Errors, + } +} + +// GetMetrics returns transport-specific metrics +func (ut *UnixSocketTransport) GetMetrics() TransportMetrics { + ut.mu.RLock() + defer ut.mu.RUnlock() + + return TransportMetrics{ + BytesSent: ut.metrics.BytesSent, + BytesReceived: ut.metrics.BytesReceived, + MessagesSent: ut.metrics.MessagesSent, + MessagesReceived: ut.metrics.MessagesReceived, + Connections: len(ut.connections), + Errors: ut.metrics.Errors, + Latency: ut.metrics.Latency, + } +} + +// Private helper methods + +func (ut *UnixSocketTransport) startServer() error { + // Remove existing socket file + os.Remove(ut.socketPath) + + listener, err := net.Listen("unix", ut.socketPath) + if err != nil { + return fmt.Errorf("failed to listen on socket %s: %w", ut.socketPath, err) + } + + ut.listener = listener + ut.connected = true + + // Start accepting connections + go ut.acceptConnections() + + return nil +} + +func (ut *UnixSocketTransport) connectToServer() error { + conn, err := net.Dial("unix", ut.socketPath) + if err != nil { + return fmt.Errorf("failed to connect to socket %s: %w", ut.socketPath, err) + } + + connID := fmt.Sprintf("client_%d", time.Now().UnixNano()) + ut.connections[connID] = conn + ut.connected = true + ut.metrics.Connections = 1 + + // Start receiving from server + go ut.handleConnection(connID, conn) + + return nil +} + +func (ut *UnixSocketTransport) acceptConnections() { + for { + select { + case <-ut.ctx.Done(): + return + default: + conn, err := ut.listener.Accept() + if err != nil { + if ut.ctx.Err() != nil { + return // Context cancelled + } + ut.metrics.Errors++ + continue + } + + connID := fmt.Sprintf("server_%d", time.Now().UnixNano()) + + ut.mu.Lock() + ut.connections[connID] = conn + ut.metrics.Connections = len(ut.connections) + ut.mu.Unlock() + + go ut.handleConnection(connID, conn) + } + } +} + +func (ut *UnixSocketTransport) handleConnection(connID string, conn net.Conn) { + defer ut.removeConnection(connID) + + buffer := make([]byte, 4096) + var messageBuffer []byte + + for { + select { + case <-ut.ctx.Done(): + return + default: + conn.SetReadDeadline(time.Now().Add(time.Second)) // Non-blocking read + n, err := conn.Read(buffer) + if err != nil { + if netErr, ok := err.(net.Error); ok && netErr.Timeout() { + continue // Continue on timeout + } + return // Connection closed or error + } + + messageBuffer = append(messageBuffer, buffer[:n]...) + + // Process complete messages + for { + msg, remaining, err := ut.extractMessage(messageBuffer) + if err != nil { + return // Invalid message format + } + if msg == nil { + break // No complete message yet + } + + // Deliver message + select { + case ut.receiveChan <- msg: + ut.updateReceiveMetrics(msg) + case <-ut.ctx.Done(): + return + default: + // Channel full, drop message + ut.metrics.Errors++ + } + + messageBuffer = remaining + } + } + } +} + +func (ut *UnixSocketTransport) sendToConnection(conn net.Conn, data []byte) error { + conn.SetWriteDeadline(time.Now().Add(5 * time.Second)) + _, err := conn.Write(data) + return err +} + +func (ut *UnixSocketTransport) removeConnection(connID string) { + ut.mu.Lock() + defer ut.mu.Unlock() + + if conn, exists := ut.connections[connID]; exists { + conn.Close() + delete(ut.connections, connID) + ut.metrics.Connections = len(ut.connections) + } +} + +func (ut *UnixSocketTransport) extractMessage(buffer []byte) (*Message, []byte, error) { + // Look for length prefix (format: "length\nmessage_data") + newlineIndex := -1 + for i, b := range buffer { + if b == '\n' { + newlineIndex = i + break + } + } + + if newlineIndex == -1 { + return nil, buffer, nil // No complete length prefix yet + } + + // Parse length + lengthStr := string(buffer[:newlineIndex]) + var messageLength int + if _, err := fmt.Sscanf(lengthStr, "%d", &messageLength); err != nil { + return nil, nil, fmt.Errorf("invalid length prefix: %s", lengthStr) + } + + // Check if we have the complete message + messageStart := newlineIndex + 1 + messageEnd := messageStart + messageLength + if len(buffer) < messageEnd { + return nil, buffer, nil // Incomplete message + } + + // Extract and parse message + messageData := buffer[messageStart:messageEnd] + var msg Message + if err := json.Unmarshal(messageData, &msg); err != nil { + return nil, nil, fmt.Errorf("failed to unmarshal message: %w", err) + } + + // Return message and remaining buffer + remaining := buffer[messageEnd:] + return &msg, remaining, nil +} + +func (ut *UnixSocketTransport) updateSendMetrics(msg *Message, latency time.Duration) { + ut.mu.Lock() + defer ut.mu.Unlock() + + ut.metrics.MessagesSent++ + ut.metrics.Latency = latency + + // Estimate message size + messageSize := int64(len(msg.ID) + len(msg.Topic) + len(msg.Source)) + if msg.Data != nil { + messageSize += int64(len(fmt.Sprintf("%v", msg.Data))) + } + ut.metrics.BytesSent += messageSize +} + +func (ut *UnixSocketTransport) updateReceiveMetrics(msg *Message) { + ut.mu.Lock() + defer ut.mu.Unlock() + + ut.metrics.MessagesReceived++ + + // Estimate message size + messageSize := int64(len(msg.ID) + len(msg.Topic) + len(msg.Source)) + if msg.Data != nil { + messageSize += int64(len(fmt.Sprintf("%v", msg.Data))) + } + ut.metrics.BytesReceived += messageSize +} + +// GetSocketPath returns the socket path (for testing/debugging) +func (ut *UnixSocketTransport) GetSocketPath() string { + return ut.socketPath +} + +// GetConnectionCount returns the number of active connections +func (ut *UnixSocketTransport) GetConnectionCount() int { + ut.mu.RLock() + defer ut.mu.RUnlock() + return len(ut.connections) +} diff --git a/pkg/transport/websocket_transport.go b/pkg/transport/websocket_transport.go new file mode 100644 index 0000000..19ab62b --- /dev/null +++ b/pkg/transport/websocket_transport.go @@ -0,0 +1,427 @@ +package transport + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "sync" + "time" + + "github.com/gorilla/websocket" +) + +// WebSocketTransport implements WebSocket transport for real-time monitoring +type WebSocketTransport struct { + address string + port int + path string + upgrader websocket.Upgrader + connections map[string]*websocket.Conn + metrics TransportMetrics + connected bool + isServer bool + receiveChan chan *Message + server *http.Server + mu sync.RWMutex + ctx context.Context + cancel context.CancelFunc + pingInterval time.Duration + pongTimeout time.Duration +} + +// NewWebSocketTransport creates a new WebSocket transport +func NewWebSocketTransport(address string, port int, path string, isServer bool) *WebSocketTransport { + ctx, cancel := context.WithCancel(context.Background()) + + return &WebSocketTransport{ + address: address, + port: port, + path: path, + connections: make(map[string]*websocket.Conn), + metrics: TransportMetrics{}, + isServer: isServer, + receiveChan: make(chan *Message, 1000), + ctx: ctx, + cancel: cancel, + upgrader: websocket.Upgrader{ + CheckOrigin: func(r *http.Request) bool { + return true // Allow all origins for now + }, + ReadBufferSize: 1024, + WriteBufferSize: 1024, + }, + pingInterval: 30 * time.Second, + pongTimeout: 10 * time.Second, + } +} + +// SetPingPongSettings configures WebSocket ping/pong settings +func (wt *WebSocketTransport) SetPingPongSettings(pingInterval, pongTimeout time.Duration) { + wt.pingInterval = pingInterval + wt.pongTimeout = pongTimeout +} + +// Connect establishes the WebSocket connection +func (wt *WebSocketTransport) Connect(ctx context.Context) error { + wt.mu.Lock() + defer wt.mu.Unlock() + + if wt.connected { + return nil + } + + if wt.isServer { + return wt.startServer() + } else { + return wt.connectToServer(ctx) + } +} + +// Disconnect closes the WebSocket connection +func (wt *WebSocketTransport) Disconnect(ctx context.Context) error { + wt.mu.Lock() + defer wt.mu.Unlock() + + if !wt.connected { + return nil + } + + wt.cancel() + + if wt.isServer && wt.server != nil { + shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + wt.server.Shutdown(shutdownCtx) + } + + // Close all connections + for id, conn := range wt.connections { + conn.Close() + delete(wt.connections, id) + } + + close(wt.receiveChan) + wt.connected = false + wt.metrics.Connections = 0 + + return nil +} + +// Send transmits a message through WebSocket +func (wt *WebSocketTransport) Send(ctx context.Context, msg *Message) error { + start := time.Now() + + wt.mu.RLock() + if !wt.connected { + wt.mu.RUnlock() + wt.metrics.Errors++ + return fmt.Errorf("transport not connected") + } + + // Serialize message + data, err := json.Marshal(msg) + if err != nil { + wt.mu.RUnlock() + wt.metrics.Errors++ + return fmt.Errorf("failed to marshal message: %w", err) + } + + // Send to all connections + var sendErr error + connectionCount := len(wt.connections) + wt.mu.RUnlock() + + if connectionCount == 0 { + wt.metrics.Errors++ + return fmt.Errorf("no active connections") + } + + wt.mu.RLock() + for connID, conn := range wt.connections { + if err := wt.sendToConnection(conn, data); err != nil { + sendErr = fmt.Errorf("failed to send to connection %s: %w", connID, err) + // Remove failed connection + go wt.removeConnection(connID) + } + } + wt.mu.RUnlock() + + if sendErr == nil { + wt.updateSendMetrics(msg, time.Since(start)) + } else { + wt.metrics.Errors++ + } + + return sendErr +} + +// Receive returns a channel for receiving messages +func (wt *WebSocketTransport) Receive(ctx context.Context) (<-chan *Message, error) { + wt.mu.RLock() + defer wt.mu.RUnlock() + + if !wt.connected { + return nil, fmt.Errorf("transport not connected") + } + + return wt.receiveChan, nil +} + +// Health returns the health status of the transport +func (wt *WebSocketTransport) Health() ComponentHealth { + wt.mu.RLock() + defer wt.mu.RUnlock() + + status := "unhealthy" + if wt.connected { + if len(wt.connections) > 0 { + status = "healthy" + } else { + status = "degraded" // Connected but no active connections + } + } + + return ComponentHealth{ + Status: status, + LastCheck: time.Now(), + ResponseTime: time.Millisecond * 5, // Very fast for WebSocket + ErrorCount: wt.metrics.Errors, + } +} + +// GetMetrics returns transport-specific metrics +func (wt *WebSocketTransport) GetMetrics() TransportMetrics { + wt.mu.RLock() + defer wt.mu.RUnlock() + + return TransportMetrics{ + BytesSent: wt.metrics.BytesSent, + BytesReceived: wt.metrics.BytesReceived, + MessagesSent: wt.metrics.MessagesSent, + MessagesReceived: wt.metrics.MessagesReceived, + Connections: len(wt.connections), + Errors: wt.metrics.Errors, + Latency: wt.metrics.Latency, + } +} + +// Private helper methods + +func (wt *WebSocketTransport) startServer() error { + mux := http.NewServeMux() + mux.HandleFunc(wt.path, wt.handleWebSocket) + + // Add health check endpoint + mux.HandleFunc("/health", func(w http.ResponseWriter, r *http.Request) { + health := wt.Health() + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(health) + }) + + // Add metrics endpoint + mux.HandleFunc("/metrics", func(w http.ResponseWriter, r *http.Request) { + metrics := wt.GetMetrics() + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(metrics) + }) + + addr := fmt.Sprintf("%s:%d", wt.address, wt.port) + wt.server = &http.Server{ + Addr: addr, + Handler: mux, + ReadTimeout: 60 * time.Second, + WriteTimeout: 60 * time.Second, + IdleTimeout: 120 * time.Second, + } + + wt.connected = true + + // Start server in goroutine + go func() { + if err := wt.server.ListenAndServe(); err != nil && err != http.ErrServerClosed { + wt.metrics.Errors++ + } + }() + + return nil +} + +func (wt *WebSocketTransport) connectToServer(ctx context.Context) error { + url := fmt.Sprintf("ws://%s:%d%s", wt.address, wt.port, wt.path) + + conn, _, err := websocket.DefaultDialer.DialContext(ctx, url, nil) + if err != nil { + return fmt.Errorf("failed to connect to WebSocket server: %w", err) + } + + connID := fmt.Sprintf("client_%d", time.Now().UnixNano()) + wt.connections[connID] = conn + wt.connected = true + wt.metrics.Connections = 1 + + // Start handling connection + go wt.handleConnection(connID, conn) + + return nil +} + +func (wt *WebSocketTransport) handleWebSocket(w http.ResponseWriter, r *http.Request) { + conn, err := wt.upgrader.Upgrade(w, r, nil) + if err != nil { + wt.metrics.Errors++ + return + } + + connID := fmt.Sprintf("server_%d", time.Now().UnixNano()) + + wt.mu.Lock() + wt.connections[connID] = conn + wt.metrics.Connections = len(wt.connections) + wt.mu.Unlock() + + go wt.handleConnection(connID, conn) +} + +func (wt *WebSocketTransport) handleConnection(connID string, conn *websocket.Conn) { + defer wt.removeConnection(connID) + + // Set up ping/pong handling + conn.SetReadDeadline(time.Now().Add(wt.pongTimeout)) + conn.SetPongHandler(func(string) error { + conn.SetReadDeadline(time.Now().Add(wt.pongTimeout)) + return nil + }) + + // Start ping routine + go wt.pingRoutine(connID, conn) + + for { + select { + case <-wt.ctx.Done(): + return + default: + var msg Message + err := conn.ReadJSON(&msg) + if err != nil { + if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) { + wt.metrics.Errors++ + } + return + } + + // Deliver message + select { + case wt.receiveChan <- &msg: + wt.updateReceiveMetrics(&msg) + case <-wt.ctx.Done(): + return + default: + // Channel full, drop message + wt.metrics.Errors++ + } + } + } +} + +func (wt *WebSocketTransport) pingRoutine(connID string, conn *websocket.Conn) { + ticker := time.NewTicker(wt.pingInterval) + defer ticker.Stop() + + for { + select { + case <-ticker.C: + if err := conn.WriteControl(websocket.PingMessage, []byte{}, time.Now().Add(10*time.Second)); err != nil { + return // Connection is likely closed + } + case <-wt.ctx.Done(): + return + } + } +} + +func (wt *WebSocketTransport) sendToConnection(conn *websocket.Conn, data []byte) error { + conn.SetWriteDeadline(time.Now().Add(10 * time.Second)) + return conn.WriteMessage(websocket.TextMessage, data) +} + +func (wt *WebSocketTransport) removeConnection(connID string) { + wt.mu.Lock() + defer wt.mu.Unlock() + + if conn, exists := wt.connections[connID]; exists { + conn.Close() + delete(wt.connections, connID) + wt.metrics.Connections = len(wt.connections) + } +} + +func (wt *WebSocketTransport) updateSendMetrics(msg *Message, latency time.Duration) { + wt.mu.Lock() + defer wt.mu.Unlock() + + wt.metrics.MessagesSent++ + wt.metrics.Latency = latency + + // Estimate message size + messageSize := int64(len(msg.ID) + len(msg.Topic) + len(msg.Source)) + if msg.Data != nil { + messageSize += int64(len(fmt.Sprintf("%v", msg.Data))) + } + wt.metrics.BytesSent += messageSize +} + +func (wt *WebSocketTransport) updateReceiveMetrics(msg *Message) { + wt.mu.Lock() + defer wt.mu.Unlock() + + wt.metrics.MessagesReceived++ + + // Estimate message size + messageSize := int64(len(msg.ID) + len(msg.Topic) + len(msg.Source)) + if msg.Data != nil { + messageSize += int64(len(fmt.Sprintf("%v", msg.Data))) + } + wt.metrics.BytesReceived += messageSize +} + +// Broadcast sends a message to all connected clients (server mode only) +func (wt *WebSocketTransport) Broadcast(ctx context.Context, msg *Message) error { + if !wt.isServer { + return fmt.Errorf("broadcast only available in server mode") + } + + return wt.Send(ctx, msg) +} + +// GetURL returns the WebSocket URL (for testing/debugging) +func (wt *WebSocketTransport) GetURL() string { + return fmt.Sprintf("ws://%s:%d%s", wt.address, wt.port, wt.path) +} + +// GetConnectionCount returns the number of active connections +func (wt *WebSocketTransport) GetConnectionCount() int { + wt.mu.RLock() + defer wt.mu.RUnlock() + return len(wt.connections) +} + +// SetAllowedOrigins configures CORS for WebSocket connections +func (wt *WebSocketTransport) SetAllowedOrigins(origins []string) { + if len(origins) == 0 { + wt.upgrader.CheckOrigin = func(r *http.Request) bool { + return true // Allow all origins + } + return + } + + originMap := make(map[string]bool) + for _, origin := range origins { + originMap[origin] = true + } + + wt.upgrader.CheckOrigin = func(r *http.Request) bool { + origin := r.Header.Get("Origin") + return originMap[origin] + } +}