package arbitrum

import (
	"context"
	"fmt"
	"sync"
	"sync/atomic"
	"time"

	"github.com/fraktal/mev-beta/internal/logger"
	pkgerrors "github.com/fraktal/mev-beta/pkg/errors"
)

// RPCEndpointHealth tracks health metrics for an RPC endpoint. The counter
// fields are updated atomically; LastChecked, IsHealthy, and ResponseTime
// are guarded by mu.
type RPCEndpointHealth struct {
	URL              string
	SuccessCount     int64
	FailureCount     int64
	ConsecutiveFails int64
	LastChecked      time.Time
	IsHealthy        bool
	ResponseTime     time.Duration
	mu               sync.RWMutex
}

// RecordSuccess records a successful RPC call and marks the endpoint healthy.
func (h *RPCEndpointHealth) RecordSuccess(responseTime time.Duration) {
	h.mu.Lock()
	defer h.mu.Unlock()
	atomic.AddInt64(&h.SuccessCount, 1)
	atomic.StoreInt64(&h.ConsecutiveFails, 0)
	h.LastChecked = time.Now()
	h.ResponseTime = responseTime
	h.IsHealthy = true
}

// RecordFailure records a failed RPC call and marks the endpoint unhealthy
// once it has failed three or more times in a row.
func (h *RPCEndpointHealth) RecordFailure() {
	h.mu.Lock()
	defer h.mu.Unlock()
	atomic.AddInt64(&h.FailureCount, 1)
	atomic.AddInt64(&h.ConsecutiveFails, 1)
	h.LastChecked = time.Now()
	if atomic.LoadInt64(&h.ConsecutiveFails) >= 3 {
		h.IsHealthy = false
	}
}

// GetStats returns health statistics.
func (h *RPCEndpointHealth) GetStats() (success, failure, consecutive int64, healthy bool) {
	h.mu.RLock()
	defer h.mu.RUnlock()
	return atomic.LoadInt64(&h.SuccessCount), atomic.LoadInt64(&h.FailureCount),
		atomic.LoadInt64(&h.ConsecutiveFails), h.IsHealthy
}

// healthy reports whether the endpoint is currently considered healthy.
// IsHealthy is written under mu, so it must also be read under mu.
func (h *RPCEndpointHealth) healthy() bool {
	h.mu.RLock()
	defer h.mu.RUnlock()
	return h.IsHealthy
}

// RPCManager manages multiple RPC endpoints with configurable load balancing.
type RPCManager struct {
	endpoints      []*RateLimitedClient
	health         []*RPCEndpointHealth
	currentIndex   int64 // accessed atomically
	logger         *logger.Logger
	mu             sync.RWMutex
	rotationPolicy RotationPolicy
}

// RotationPolicy defines how RPC endpoints are rotated.
type RotationPolicy string

const (
	// RoundRobin cycles through endpoints in order, regardless of health.
	RoundRobin RotationPolicy = "round-robin"
	// HealthAware prefers healthy endpoints and falls back to round-robin
	// when every endpoint is unhealthy.
	HealthAware RotationPolicy = "health-aware"
	// LeastFailures picks the endpoint with the fewest recorded failures.
	LeastFailures RotationPolicy = "least-failures"
)

// NewRPCManager creates a new RPC manager with no endpoints registered;
// callers add endpoints via AddEndpoint.
func NewRPCManager(log *logger.Logger) *RPCManager {
	return &RPCManager{
		endpoints:      make([]*RateLimitedClient, 0),
		health:         make([]*RPCEndpointHealth, 0),
		currentIndex:   0,
		logger:         log,
		rotationPolicy: RoundRobin, // default to simple round-robin
	}
}

// AddEndpoint adds an RPC endpoint to the manager.
func (rm *RPCManager) AddEndpoint(client *RateLimitedClient, url string) error {
	if client == nil {
		return fmt.Errorf("client cannot be nil")
	}

	rm.mu.Lock()
	defer rm.mu.Unlock()

	rm.endpoints = append(rm.endpoints, client)
	rm.health = append(rm.health, &RPCEndpointHealth{
		URL:       url,
		IsHealthy: true,
	})

	rm.logger.Info(fmt.Sprintf("✅ Added RPC endpoint %d: %s", len(rm.endpoints), url))
	return nil
}

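// Typical wiring, as a minimal sketch (the endpoint URL is illustrative,
// "client" is assumed to be a *RateLimitedClient constructed elsewhere in
// this package, and "log" a configured *logger.Logger):
//
//	rm := NewRPCManager(log)
//	rm.SetRotationPolicy(HealthAware)
//	if err := rm.AddEndpoint(client, "https://arb1.arbitrum.io/rpc"); err != nil {
//		log.Warn(fmt.Sprintf("add endpoint: %v", err))
//	}
//	cli, idx, err := rm.GetNextClient(ctx)
//	if err != nil {
//		return err
//	}
//	start := time.Now()
//	// ... perform the RPC call with cli ...
//	if callErr != nil {
//		rm.RecordFailure(idx)
//	} else {
//		rm.RecordSuccess(idx, time.Since(start))
//	}
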
// GetNextClient returns the next RPC client using the configured rotation
// policy, along with the endpoint's index for later RecordSuccess or
// RecordFailure calls.
func (rm *RPCManager) GetNextClient(ctx context.Context) (*RateLimitedClient, int, error) {
	rm.mu.RLock()
	defer rm.mu.RUnlock()

	if len(rm.endpoints) == 0 {
		return nil, -1, fmt.Errorf("no RPC endpoints available")
	}

	var clientIndex int

	switch rm.rotationPolicy {
	case HealthAware:
		clientIndex = rm.selectHealthAware()
	case LeastFailures:
		clientIndex = rm.selectLeastFailures()
	default: // RoundRobin
		clientIndex = rm.selectRoundRobin()
	}

	if clientIndex < 0 || clientIndex >= len(rm.endpoints) {
		return nil, -1, fmt.Errorf("invalid endpoint index: %d", clientIndex)
	}

	return rm.endpoints[clientIndex], clientIndex, nil
}

// selectRoundRobin selects the next endpoint using simple round-robin.
func (rm *RPCManager) selectRoundRobin() int {
	current := atomic.AddInt64(&rm.currentIndex, 1)
	return int((current - 1) % int64(len(rm.endpoints)))
}

// selectHealthAware selects an endpoint, preferring healthy ones. The health
// flag is read through the locked accessor to avoid racing with the Record*
// methods, which write it under the endpoint's own mutex.
func (rm *RPCManager) selectHealthAware() int {
	// Scan endpoints starting from the current cursor position.
	for i := 0; i < len(rm.health); i++ {
		idx := (int(atomic.LoadInt64(&rm.currentIndex)) + i) % len(rm.endpoints)
		if rm.health[idx].healthy() {
			atomic.AddInt64(&rm.currentIndex, 1)
			return idx
		}
	}

	// If all endpoints are unhealthy, fall back to round-robin.
	return rm.selectRoundRobin()
}

// selectLeastFailures selects the endpoint with the fewest recorded failures.
func (rm *RPCManager) selectLeastFailures() int {
	if len(rm.health) == 0 {
		return 0
	}

	minIndex := 0
	minFailures := atomic.LoadInt64(&rm.health[0].FailureCount)

	for i := 1; i < len(rm.health); i++ {
		failures := atomic.LoadInt64(&rm.health[i].FailureCount)
		if failures < minFailures {
			minFailures = failures
			minIndex = i
		}
	}

	// Advance the shared cursor so that any later fallback to round-robin
	// keeps rotating rather than stalling on one index.
	atomic.AddInt64(&rm.currentIndex, 1)
	return minIndex
}

// RecordSuccess records a successful call to an endpoint.
func (rm *RPCManager) RecordSuccess(endpointIndex int, responseTime time.Duration) {
	rm.mu.RLock()
	defer rm.mu.RUnlock()

	if endpointIndex < 0 || endpointIndex >= len(rm.health) {
		return
	}

	rm.health[endpointIndex].RecordSuccess(responseTime)
}

// RecordFailure records a failed call to an endpoint.
func (rm *RPCManager) RecordFailure(endpointIndex int) {
	rm.mu.RLock()
	defer rm.mu.RUnlock()

	if endpointIndex < 0 || endpointIndex >= len(rm.health) {
		return
	}

	rm.health[endpointIndex].RecordFailure()
}

// GetEndpointHealth returns health information for a specific endpoint.
// The returned value is shared with the manager, so callers should read it
// through GetStats rather than accessing its fields directly.
func (rm *RPCManager) GetEndpointHealth(endpointIndex int) (*RPCEndpointHealth, error) {
	rm.mu.RLock()
	defer rm.mu.RUnlock()

	if endpointIndex < 0 || endpointIndex >= len(rm.health) {
		return nil, fmt.Errorf("invalid endpoint index: %d", endpointIndex)
	}

	return rm.health[endpointIndex], nil
}

// GetAllHealthStats returns health statistics for all endpoints.
func (rm *RPCManager) GetAllHealthStats() []map[string]interface{} {
	rm.mu.RLock()
	defer rm.mu.RUnlock()
	return rm.allHealthStatsLocked()
}

// allHealthStatsLocked builds the per-endpoint statistics. Callers must hold
// rm.mu; keeping this separate lets GetStats reuse it without re-acquiring
// the read lock, which could deadlock against a waiting writer.
func (rm *RPCManager) allHealthStatsLocked() []map[string]interface{} {
	stats := make([]map[string]interface{}, 0, len(rm.health))
	for i, h := range rm.health {
		success, failure, consecutive, healthy := h.GetStats()
		// LastChecked and ResponseTime are guarded by the endpoint's own
		// lock, so snapshot them under it.
		h.mu.RLock()
		lastChecked := h.LastChecked
		responseTime := h.ResponseTime
		h.mu.RUnlock()
		stats = append(stats, map[string]interface{}{
			"index":             i,
			"url":               h.URL,
			"success_count":     success,
			"failure_count":     failure,
			"consecutive_fails": consecutive,
			"is_healthy":        healthy,
			"last_checked":      lastChecked,
			"response_time_ms":  responseTime.Milliseconds(),
		})
	}
	return stats
}

// SetRotationPolicy sets the rotation policy for endpoint selection.
func (rm *RPCManager) SetRotationPolicy(policy RotationPolicy) {
	rm.mu.Lock()
	defer rm.mu.Unlock()
	rm.rotationPolicy = policy
	rm.logger.Info(fmt.Sprintf("📊 RPC rotation policy set to: %s", policy))
}

// HealthCheckAll performs a health check on all endpoints concurrently.
func (rm *RPCManager) HealthCheckAll(ctx context.Context) error {
	rm.mu.RLock()
	endpoints := rm.endpoints
	rm.mu.RUnlock()

	var wg sync.WaitGroup
	errs := make([]error, 0)
	var errMu sync.Mutex

	for i, client := range endpoints {
		wg.Add(1)
		go func(idx int, cli *RateLimitedClient) {
			defer wg.Done()

			if err := rm.healthCheckEndpoint(ctx, idx, cli); err != nil {
				errMu.Lock()
				errs = append(errs, err)
				errMu.Unlock()
			}
		}(i, client)
	}

	wg.Wait()

	if len(errs) > 0 {
		return fmt.Errorf("health check failures: %v", errs)
	}

	return nil
}

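// HealthCheckAll is typically driven by a background ticker so the
// health-aware rotation policies act on fresh data. A minimal sketch
// (the 30-second interval is illustrative, not prescribed):
//
//	go func() {
//		t := time.NewTicker(30 * time.Second)
//		defer t.Stop()
//		for {
//			select {
//			case <-ctx.Done():
//				return
//			case <-t.C:
//				if err := rm.HealthCheckAll(ctx); err != nil {
//					log.Warn(fmt.Sprintf("RPC health check: %v", err))
//				}
//			}
//		}
//	}()
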
// healthCheckEndpoint performs a health check on a single endpoint by
// fetching the chain ID, a cheap call every endpoint must support.
func (rm *RPCManager) healthCheckEndpoint(ctx context.Context, index int, client *RateLimitedClient) error {
	checkCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
	defer cancel()

	if client == nil || client.Client == nil {
		rm.RecordFailure(index)
		rm.logger.Warn(fmt.Sprintf("⚠️ RPC endpoint %d is nil", index))
		return fmt.Errorf("endpoint %d is nil", index)
	}

	start := time.Now()
	_, err := client.Client.ChainID(checkCtx)
	responseTime := time.Since(start)

	if err != nil {
		rm.RecordFailure(index)
		return pkgerrors.WrapContextError(err, "RPCManager.healthCheckEndpoint",
			map[string]interface{}{
				"endpoint_index": index,
				"response_time":  responseTime.String(),
			})
	}

	rm.RecordSuccess(index, responseTime)
	return nil
}

// Close closes all RPC client connections and clears the endpoint lists.
func (rm *RPCManager) Close() error {
	rm.mu.Lock()
	defer rm.mu.Unlock()

	for i, client := range rm.endpoints {
		if client != nil && client.Client != nil {
			rm.logger.Debug(fmt.Sprintf("Closing RPC endpoint %d", i))
			client.Client.Close()
		}
	}

	rm.endpoints = nil
	rm.health = nil

	return nil
}

// GetStats returns an aggregate summary of all endpoint statistics.
func (rm *RPCManager) GetStats() map[string]interface{} {
	rm.mu.RLock()
	defer rm.mu.RUnlock()

	totalSuccess := int64(0)
	totalFailure := int64(0)
	healthyCount := 0

	for _, h := range rm.health {
		success, failure, _, healthy := h.GetStats()
		totalSuccess += success
		totalFailure += failure
		if healthy {
			healthyCount++
		}
	}

	totalRequests := totalSuccess + totalFailure
	successRate := 0.0
	if totalRequests > 0 {
		successRate = float64(totalSuccess) / float64(totalRequests) * 100
	}

	return map[string]interface{}{
		"total_endpoints": len(rm.endpoints),
		"healthy_count":   healthyCount,
		"total_requests":  totalRequests,
		"total_success":   totalSuccess,
		"total_failure":   totalFailure,
		"success_rate":    fmt.Sprintf("%.2f%%", successRate),
		"current_policy":  rm.rotationPolicy,
		// Use the locked helper: calling GetAllHealthStats here would try
		// to re-acquire rm.mu while it is already held.
		"endpoint_details": rm.allHealthStatsLocked(),
	}
}
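
// The snapshot is convenient for periodic operational logging, e.g. (sketch;
// the encoding/json import and log level are illustrative):
//
//	if b, err := json.Marshal(rm.GetStats()); err == nil {
//		log.Info(string(b))
//	}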