// mev-beta/pkg/arbitrum/rpc_manager.go

package arbitrum

import (
	"context"
	"fmt"
	"sync"
	"sync/atomic"
	"time"

	"github.com/fraktal/mev-beta/internal/logger"
	pkgerrors "github.com/fraktal/mev-beta/pkg/errors"
)

// RPCEndpointHealth tracks health metrics for an RPC endpoint. The counter
// fields are updated atomically so they can be read without the mutex; mu
// guards the remaining, non-atomic fields.
type RPCEndpointHealth struct {
	URL              string
	SuccessCount     int64
	FailureCount     int64
	ConsecutiveFails int64
	LastChecked      time.Time
	IsHealthy        bool
	ResponseTime     time.Duration
	mu               sync.RWMutex
}

// RecordSuccess records a successful RPC call.
func (h *RPCEndpointHealth) RecordSuccess(responseTime time.Duration) {
	h.mu.Lock()
	defer h.mu.Unlock()
	atomic.AddInt64(&h.SuccessCount, 1)
	atomic.StoreInt64(&h.ConsecutiveFails, 0)
	h.LastChecked = time.Now()
	h.ResponseTime = responseTime
	h.IsHealthy = true
}

// RecordFailure records a failed RPC call.
func (h *RPCEndpointHealth) RecordFailure() {
	h.mu.Lock()
	defer h.mu.Unlock()
	atomic.AddInt64(&h.FailureCount, 1)
	atomic.AddInt64(&h.ConsecutiveFails, 1)
	h.LastChecked = time.Now()
	// Mark unhealthy after 3+ consecutive failures; a single success resets
	// the streak and marks the endpoint healthy again (see RecordSuccess).
	if atomic.LoadInt64(&h.ConsecutiveFails) >= 3 {
		h.IsHealthy = false
	}
}

// GetStats returns health statistics.
func (h *RPCEndpointHealth) GetStats() (success, failure, consecutive int64, healthy bool) {
	h.mu.RLock()
	defer h.mu.RUnlock()
	return atomic.LoadInt64(&h.SuccessCount), atomic.LoadInt64(&h.FailureCount),
		atomic.LoadInt64(&h.ConsecutiveFails), h.IsHealthy
}

// RPCManager manages multiple RPC endpoints with round-robin load balancing.
type RPCManager struct {
	endpoints      []*RateLimitedClient
	health         []*RPCEndpointHealth
	currentIndex   int64
	logger         *logger.Logger
	mu             sync.RWMutex
	rotationPolicy RotationPolicy
}

// RotationPolicy defines how RPC endpoints are rotated.
type RotationPolicy string

const (
	// RoundRobin cycles through endpoints in order, ignoring health.
	RoundRobin RotationPolicy = "round-robin"
	// HealthAware prefers the next healthy endpoint and falls back to
	// round-robin when all endpoints are unhealthy.
	HealthAware RotationPolicy = "health-aware"
	// LeastFailures picks the endpoint with the fewest recorded failures.
	LeastFailures RotationPolicy = "least-failures"
)
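
// The manager defaults to RoundRobin (see NewRPCManager). A caller can
// switch policies at runtime, for example (illustrative only):
//
//	mgr.SetRotationPolicy(HealthAware)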

// NewRPCManager creates a new RPC manager with multiple endpoints.
func NewRPCManager(logger *logger.Logger) *RPCManager {
	return &RPCManager{
		endpoints:      make([]*RateLimitedClient, 0),
		health:         make([]*RPCEndpointHealth, 0),
		currentIndex:   0,
		logger:         logger,
		rotationPolicy: RoundRobin, // Default to simple round-robin
	}
}

// AddEndpoint adds an RPC endpoint to the manager.
func (rm *RPCManager) AddEndpoint(client *RateLimitedClient, url string) error {
	if client == nil {
		return fmt.Errorf("client cannot be nil")
	}
	rm.mu.Lock()
	defer rm.mu.Unlock()
	rm.endpoints = append(rm.endpoints, client)
	rm.health = append(rm.health, &RPCEndpointHealth{
		URL:       url,
		IsHealthy: true,
	})
	rm.logger.Info(fmt.Sprintf("✅ Added RPC endpoint %d: %s", len(rm.endpoints), url))
	return nil
}

// GetNextClient returns the next RPC client using the configured rotation policy.
func (rm *RPCManager) GetNextClient(ctx context.Context) (*RateLimitedClient, int, error) {
	rm.mu.RLock()
	defer rm.mu.RUnlock()
	if len(rm.endpoints) == 0 {
		return nil, -1, fmt.Errorf("no RPC endpoints available")
	}
	var clientIndex int
	switch rm.rotationPolicy {
	case HealthAware:
		clientIndex = rm.selectHealthAware()
	case LeastFailures:
		clientIndex = rm.selectLeastFailures()
	default: // RoundRobin
		clientIndex = rm.selectRoundRobin()
	}
	if clientIndex < 0 || clientIndex >= len(rm.endpoints) {
		return nil, -1, fmt.Errorf("invalid endpoint index: %d", clientIndex)
	}
	return rm.endpoints[clientIndex], clientIndex, nil
}
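
// The intended call pattern is select, call, then report the outcome so the
// health trackers stay current. BlockByNumber stands in for any RPC call and
// assumes the embedded Client is a go-ethereum ethclient (a sketch, not
// confirmed by this file):
//
//	cli, idx, err := mgr.GetNextClient(ctx)
//	if err != nil {
//		return err
//	}
//	start := time.Now()
//	if _, err := cli.Client.BlockByNumber(ctx, nil); err != nil {
//		mgr.RecordFailure(idx)
//		return err
//	}
//	mgr.RecordSuccess(idx, time.Since(start))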

// selectRoundRobin selects the next endpoint using simple round-robin.
func (rm *RPCManager) selectRoundRobin() int {
	// AddInt64 hands each caller a distinct counter value, so concurrent
	// callers map to successive endpoints without extra locking.
	current := atomic.AddInt64(&rm.currentIndex, 1)
	return int((current - 1) % int64(len(rm.endpoints)))
}

// selectHealthAware selects an endpoint, preferring healthy ones.
func (rm *RPCManager) selectHealthAware() int {
	// Scan forward from the current index for the first healthy endpoint,
	// reading health through GetStats so the check is properly locked.
	start := int(atomic.LoadInt64(&rm.currentIndex))
	for i := 0; i < len(rm.health); i++ {
		idx := (start + i) % len(rm.endpoints)
		if _, _, _, healthy := rm.health[idx].GetStats(); healthy {
			atomic.AddInt64(&rm.currentIndex, 1)
			return idx
		}
	}
	// If all endpoints are unhealthy, fall back to round-robin.
	return rm.selectRoundRobin()
}

// selectLeastFailures selects the endpoint with the fewest failures.
func (rm *RPCManager) selectLeastFailures() int {
	if len(rm.health) == 0 {
		return 0
	}
	minIndex := 0
	minFailures := atomic.LoadInt64(&rm.health[0].FailureCount)
	for i := 1; i < len(rm.health); i++ {
		failures := atomic.LoadInt64(&rm.health[i].FailureCount)
		if failures < minFailures {
			minFailures = failures
			minIndex = i
		}
	}
	atomic.AddInt64(&rm.currentIndex, 1)
	return minIndex
}

// RecordSuccess records a successful call to an endpoint.
func (rm *RPCManager) RecordSuccess(endpointIndex int, responseTime time.Duration) {
	rm.mu.RLock()
	defer rm.mu.RUnlock()
	if endpointIndex < 0 || endpointIndex >= len(rm.health) {
		return
	}
	rm.health[endpointIndex].RecordSuccess(responseTime)
}

// RecordFailure records a failed call to an endpoint.
func (rm *RPCManager) RecordFailure(endpointIndex int) {
	rm.mu.RLock()
	defer rm.mu.RUnlock()
	if endpointIndex < 0 || endpointIndex >= len(rm.health) {
		return
	}
	rm.health[endpointIndex].RecordFailure()
}

// GetEndpointHealth returns health information for a specific endpoint. The
// returned pointer is shared with the manager, so read it via its methods.
func (rm *RPCManager) GetEndpointHealth(endpointIndex int) (*RPCEndpointHealth, error) {
	rm.mu.RLock()
	defer rm.mu.RUnlock()
	if endpointIndex < 0 || endpointIndex >= len(rm.health) {
		return nil, fmt.Errorf("invalid endpoint index: %d", endpointIndex)
	}
	return rm.health[endpointIndex], nil
}
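
// For example (illustrative; idx is an index previously returned by
// GetNextClient):
//
//	if h, err := mgr.GetEndpointHealth(idx); err == nil {
//		success, failure, consecutive, healthy := h.GetStats()
//		log.Debug(fmt.Sprintf("%s: %d ok, %d failed (streak %d, healthy=%v)",
//			h.URL, success, failure, consecutive, healthy))
//	}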

// GetAllHealthStats returns health statistics for all endpoints.
func (rm *RPCManager) GetAllHealthStats() []map[string]interface{} {
	rm.mu.RLock()
	defer rm.mu.RUnlock()
	return rm.allHealthStatsLocked()
}

// allHealthStatsLocked builds the per-endpoint stats; rm.mu must be held. It
// exists so GetStats can reuse it without re-acquiring rm.mu, which could
// deadlock if a writer were queued between the two read locks.
func (rm *RPCManager) allHealthStatsLocked() []map[string]interface{} {
	stats := make([]map[string]interface{}, 0, len(rm.health))
	for i, h := range rm.health {
		success, failure, consecutive, healthy := h.GetStats()
		// LastChecked and ResponseTime are not atomic; snapshot them under
		// the endpoint's own lock to avoid racing with Record* calls.
		h.mu.RLock()
		lastChecked, responseTime := h.LastChecked, h.ResponseTime
		h.mu.RUnlock()
		stats = append(stats, map[string]interface{}{
			"index":             i,
			"url":               h.URL,
			"success_count":     success,
			"failure_count":     failure,
			"consecutive_fails": consecutive,
			"is_healthy":        healthy,
			"last_checked":      lastChecked,
			"response_time_ms":  responseTime.Milliseconds(),
		})
	}
	return stats
}

// SetRotationPolicy sets the rotation policy for endpoint selection.
func (rm *RPCManager) SetRotationPolicy(policy RotationPolicy) {
	rm.mu.Lock()
	defer rm.mu.Unlock()
	rm.rotationPolicy = policy
	rm.logger.Info(fmt.Sprintf("📊 RPC rotation policy set to: %s", policy))
}

// HealthCheckAll performs a health check on all endpoints concurrently.
func (rm *RPCManager) HealthCheckAll(ctx context.Context) error {
	// Snapshot the slice header under the read lock; endpoints are only
	// appended under the write lock, so iterating the snapshot is safe.
	rm.mu.RLock()
	endpoints := rm.endpoints
	rm.mu.RUnlock()
	var wg sync.WaitGroup
	errs := make([]error, 0)
	errMu := sync.Mutex{}
	for i, client := range endpoints {
		wg.Add(1)
		go func(idx int, cli *RateLimitedClient) {
			defer wg.Done()
			if err := rm.healthCheckEndpoint(ctx, idx, cli); err != nil {
				errMu.Lock()
				errs = append(errs, err)
				errMu.Unlock()
			}
		}(i, client)
	}
	wg.Wait()
	if len(errs) > 0 {
		return fmt.Errorf("health check failures: %v", errs)
	}
	return nil
}
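
// Callers would typically drive this from a background loop; the 30-second
// interval below is an illustrative choice, not taken from this file:
//
//	go func() {
//		ticker := time.NewTicker(30 * time.Second)
//		defer ticker.Stop()
//		for {
//			select {
//			case <-ctx.Done():
//				return
//			case <-ticker.C:
//				if err := mgr.HealthCheckAll(ctx); err != nil {
//					log.Warn(fmt.Sprintf("health check: %v", err))
//				}
//			}
//		}
//	}()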

// healthCheckEndpoint performs a health check on a single endpoint.
func (rm *RPCManager) healthCheckEndpoint(ctx context.Context, index int, client *RateLimitedClient) error {
	checkCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
	defer cancel()
	if client == nil || client.Client == nil {
		rm.RecordFailure(index)
		rm.logger.Warn(fmt.Sprintf("⚠️ RPC endpoint %d is nil", index))
		return fmt.Errorf("endpoint %d is nil", index)
	}
	// Fetch the chain ID as a cheap liveness probe.
	start := time.Now()
	_, err := client.Client.ChainID(checkCtx)
	responseTime := time.Since(start)
	if err != nil {
		rm.RecordFailure(index)
		return pkgerrors.WrapContextError(err, "RPCManager.healthCheckEndpoint",
			map[string]interface{}{
				"endpoint_index": index,
				"response_time":  responseTime.String(),
			})
	}
	rm.RecordSuccess(index, responseTime)
	return nil
}

// Close closes all RPC client connections and drops the endpoint lists.
func (rm *RPCManager) Close() error {
	rm.mu.Lock()
	defer rm.mu.Unlock()
	for i, client := range rm.endpoints {
		if client != nil && client.Client != nil {
			rm.logger.Debug(fmt.Sprintf("Closing RPC endpoint %d", i))
			client.Client.Close()
		}
	}
	rm.endpoints = nil
	rm.health = nil
	return nil
}

// GetStats returns a summary of all endpoint statistics.
func (rm *RPCManager) GetStats() map[string]interface{} {
	rm.mu.RLock()
	defer rm.mu.RUnlock()
	totalSuccess := int64(0)
	totalFailure := int64(0)
	healthyCount := 0
	for _, h := range rm.health {
		success, failure, _, healthy := h.GetStats()
		totalSuccess += success
		totalFailure += failure
		if healthy {
			healthyCount++
		}
	}
	totalRequests := totalSuccess + totalFailure
	successRate := 0.0
	if totalRequests > 0 {
		successRate = float64(totalSuccess) / float64(totalRequests) * 100
	}
	return map[string]interface{}{
		"total_endpoints": len(rm.endpoints),
		"healthy_count":   healthyCount,
		"total_requests":  totalRequests,
		"total_success":   totalSuccess,
		"total_failure":   totalFailure,
		"success_rate":    fmt.Sprintf("%.2f%%", successRate),
		"current_policy":  rm.rotationPolicy,
		// Use the locked helper: calling GetAllHealthStats here would try to
		// re-acquire rm.mu and could deadlock behind a waiting writer.
		"endpoint_details": rm.allHealthStatsLocked(),
	}
}
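
// The summary serializes cleanly for structured logging or a metrics
// endpoint; this sketch assumes the caller imports encoding/json:
//
//	if data, err := json.Marshal(mgr.GetStats()); err == nil {
//		log.Info(string(data))
//	}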