// Package circuit provides a circuit breaker, a two-step (allow/report)
// variant of it, and a Manager that keeps a registry of named breakers.
package circuit

import (
	"context"
	"fmt"
	"sync"
	"sync/atomic"
	"time"

	"github.com/fraktal/mev-beta/internal/logger"
)

// State represents the circuit breaker state.
type State int32

const (
	// StateClosed admits every request.
	StateClosed State = iota
	// StateHalfOpen admits at most Config.MaxRequests probe requests.
	StateHalfOpen
	// StateOpen rejects every request until Config.ResetTimeout has elapsed.
	StateOpen
)

// String returns the string representation of the state.
func (s State) String() string {
	switch s {
	case StateClosed:
		return "CLOSED"
	case StateHalfOpen:
		return "HALF_OPEN"
	case StateOpen:
		return "OPEN"
	default:
		return "UNKNOWN"
	}
}

// Config holds circuit breaker configuration. Zero-valued fields are replaced
// by defaults in NewCircuitBreaker.
type Config struct {
	// Name identifies the breaker in callbacks and log messages.
	Name string

	// MaxFailures is the number of consecutive failures in the closed state
	// that opens the breaker. Defaults to 5.
	MaxFailures uint64

	// ResetTimeout is how long the breaker stays open before it moves to
	// half-open. Defaults to 60 seconds.
	ResetTimeout time.Duration

	// MaxRequests is the number of requests admitted while half-open.
	// Defaults to 1 and is raised to SuccessThreshold when smaller, because
	// otherwise the breaker could never observe enough successes to close.
	MaxRequests uint64

	// SuccessThreshold is the number of consecutive successes in the
	// half-open state that closes the breaker. Defaults to 1.
	SuccessThreshold uint64

	// OnStateChange, if set, is called on every state transition. It runs
	// while the breaker's mutex is held, so it must not call Execute,
	// ExecuteContext, Counts or Reset on the same breaker.
	OnStateChange func(name string, from State, to State)

	// IsFailure decides whether a returned error counts as a failure.
	// Defaults to treating every non-nil error as a failure.
	IsFailure func(error) bool

	// Logger, if set, receives an Info message on every state transition.
	Logger *logger.Logger
}

// Counts holds the circuit breaker statistics for the current generation.
// All counters are cleared whenever the breaker changes state or is reset.
type Counts struct {
	Requests             uint64
	TotalSuccesses       uint64
	TotalFailures        uint64
	ConsecutiveSuccesses uint64
	ConsecutiveFailures  uint64
}

// CircuitBreaker implements the circuit breaker pattern.
type CircuitBreaker struct {
	config *Config
	mutex  sync.RWMutex

	// state is written under mutex with an atomic store so that State() can
	// read it without taking the lock.
	state int32

	// generation is bumped on every state change or reset; results reported
	// for an older generation are discarded.
	generation uint64

	counts Counts

	// expiry is the instant at which an open breaker becomes half-open.
	// It is the zero time in every other state.
	expiry time.Time
}

// NewCircuitBreaker creates a new circuit breaker in the closed state.
//
// The configuration is copied, so the caller's struct is never modified and
// may safely be reused for several breakers. A nil config selects all
// defaults.
func NewCircuitBreaker(config *Config) *CircuitBreaker {
	var cfg Config
	if config != nil {
		cfg = *config
	}

	if cfg.MaxFailures == 0 {
		cfg.MaxFailures = 5
	}
	if cfg.ResetTimeout == 0 {
		cfg.ResetTimeout = 60 * time.Second
	}
	if cfg.MaxRequests == 0 {
		cfg.MaxRequests = 1
	}
	if cfg.SuccessThreshold == 0 {
		cfg.SuccessThreshold = 1
	}
	// With fewer admitted probes than required successes the breaker would
	// stay half-open forever, rejecting everything with ErrTooManyRequests.
	if cfg.MaxRequests < cfg.SuccessThreshold {
		cfg.MaxRequests = cfg.SuccessThreshold
	}
	if cfg.IsFailure == nil {
		cfg.IsFailure = func(err error) bool { return err != nil }
	}

	return &CircuitBreaker{
		config: &cfg,
		state:  int32(StateClosed),
	}
}

// Execute executes the given function with circuit breaker protection.
//
// It returns ErrOpenState or ErrTooManyRequests without calling fn when the
// request is rejected. A panic inside fn is recorded as a failure and then
// re-raised.
func (cb *CircuitBreaker) Execute(fn func() (interface{}, error)) (interface{}, error) {
	generation, err := cb.beforeRequest()
	if err != nil {
		return nil, err
	}

	defer func() {
		if e := recover(); e != nil {
			cb.afterRequest(generation, fmt.Errorf("panic: %v", e))
			panic(e)
		}
	}()

	result, err := fn()
	cb.afterRequest(generation, err)
	return result, err
}

// ExecuteContext executes the given function with circuit breaker protection
// and context.
//
// If ctx is already done, its error is returned immediately: fn is not
// called and the breaker is left untouched, so a caller-side cancellation
// neither consumes a half-open probe slot nor counts as a failure. Otherwise
// the behaviour matches Execute, with ctx forwarded to fn.
func (cb *CircuitBreaker) ExecuteContext(ctx context.Context, fn func(context.Context) (interface{}, error)) (interface{}, error) {
	if err := ctx.Err(); err != nil {
		return nil, err
	}

	generation, err := cb.beforeRequest()
	if err != nil {
		return nil, err
	}

	defer func() {
		if e := recover(); e != nil {
			cb.afterRequest(generation, fmt.Errorf("panic: %v", e))
			panic(e)
		}
	}()

	result, err := fn(ctx)
	cb.afterRequest(generation, err)
	return result, err
}

// beforeRequest checks if the request can proceed. On admission it counts the
// request and returns the generation the result must later be reported for.
func (cb *CircuitBreaker) beforeRequest() (uint64, error) {
	cb.mutex.Lock()
	defer cb.mutex.Unlock()

	state := cb.currentState(time.Now())
	if state == StateOpen {
		return cb.generation, ErrOpenState
	}
	if state == StateHalfOpen && cb.counts.Requests >= cb.config.MaxRequests {
		return cb.generation, ErrTooManyRequests
	}

	cb.counts.Requests++
	return cb.generation, nil
}

// afterRequest processes the request result. Results that belong to a
// generation other than the current one are ignored.
func (cb *CircuitBreaker) afterRequest(before uint64, err error) {
	cb.mutex.Lock()
	defer cb.mutex.Unlock()

	now := time.Now()
	state := cb.currentState(now)
	if before != cb.generation {
		return // generation mismatch, ignore
	}

	if cb.config.IsFailure(err) {
		cb.onFailure(state, now)
		return
	}
	cb.onSuccess(state, now)
}

// onFailure records a failure. It opens the breaker when the closed state
// reaches MaxFailures consecutive failures, or on any failure while
// half-open. The caller must hold cb.mutex.
func (cb *CircuitBreaker) onFailure(state State, now time.Time) {
	cb.counts.TotalFailures++
	cb.counts.ConsecutiveFailures++
	cb.counts.ConsecutiveSuccesses = 0

	switch state {
	case StateClosed:
		if cb.counts.ConsecutiveFailures >= cb.config.MaxFailures {
			cb.setState(StateOpen, now)
		}
	case StateHalfOpen:
		cb.setState(StateOpen, now)
	}
}

// onSuccess records a success and closes a half-open breaker once
// SuccessThreshold consecutive successes have been seen. The caller must
// hold cb.mutex.
func (cb *CircuitBreaker) onSuccess(state State, now time.Time) {
	cb.counts.TotalSuccesses++
	cb.counts.ConsecutiveSuccesses++
	cb.counts.ConsecutiveFailures = 0

	if state == StateHalfOpen && cb.counts.ConsecutiveSuccesses >= cb.config.SuccessThreshold {
		cb.setState(StateClosed, now)
	}
}

// currentState returns the current state, first moving an open breaker to
// half-open when its reset timeout has elapsed. The caller must hold
// cb.mutex.
func (cb *CircuitBreaker) currentState(now time.Time) State {
	if cb.State() == StateOpen && cb.expiry.Before(now) {
		cb.setState(StateHalfOpen, now)
	}
	return cb.State()
}

// setState changes the state of the circuit breaker, starts a new generation
// and notifies OnStateChange and the logger. It does nothing when the breaker
// is already in the requested state. The caller must hold cb.mutex.
func (cb *CircuitBreaker) setState(state State, now time.Time) {
	prev := cb.State()
	if prev == state {
		return
	}

	atomic.StoreInt32(&cb.state, int32(state))
	cb.toNewGeneration(state, now)

	if cb.config.OnStateChange != nil {
		cb.config.OnStateChange(cb.config.Name, prev, state)
	}
	if cb.config.Logger != nil {
		cb.config.Logger.Info(fmt.Sprintf("Circuit breaker '%s' state changed from %s to %s", cb.config.Name, prev.String(), state.String()))
	}
}

// toNewGeneration bumps the generation, clears the counters and sets the
// expiry that belongs to state. The caller must hold cb.mutex.
func (cb *CircuitBreaker) toNewGeneration(state State, now time.Time) {
	cb.generation++
	cb.counts = Counts{}

	if state == StateOpen {
		cb.expiry = now.Add(cb.config.ResetTimeout)
		return
	}
	cb.expiry = time.Time{}
}

// State returns the last recorded state without taking the lock. An open
// breaker whose reset timeout has elapsed is still reported as open until the
// next request triggers the transition to half-open.
func (cb *CircuitBreaker) State() State {
	return State(atomic.LoadInt32(&cb.state))
}

// Counts returns a copy of the current counts.
func (cb *CircuitBreaker) Counts() Counts {
	cb.mutex.RLock()
	defer cb.mutex.RUnlock()
	return cb.counts
}

// Name returns the name of the circuit breaker.
func (cb *CircuitBreaker) Name() string {
	return cb.config.Name
}

// Reset resets the circuit breaker to closed state and clears its counters.
// Results of requests still in flight are discarded because a new generation
// is started, also when the breaker was already closed.
func (cb *CircuitBreaker) Reset() {
	cb.mutex.Lock()
	defer cb.mutex.Unlock()

	now := time.Now()
	if cb.State() == StateClosed {
		// setState is a no-op for an unchanged state, which would leave the
		// accumulated failure counts in place.
		cb.toNewGeneration(StateClosed, now)
		return
	}
	cb.setState(StateClosed, now)
}

// Errors returned when a request is rejected by the breaker.
var (
	// ErrOpenState is returned while the breaker is open.
	ErrOpenState = fmt.Errorf("circuit breaker is open")

	// ErrTooManyRequests is returned while the breaker is half-open and
	// MaxRequests probes have already been admitted.
	ErrTooManyRequests = fmt.Errorf("too many requests")
)

// TwoStepCircuitBreaker extends CircuitBreaker with two-step recovery: the
// caller asks for permission with Allow and reports the outcome separately
// with ReportResult.
type TwoStepCircuitBreaker struct {
	*CircuitBreaker
	failFast bool
}

// NewTwoStepCircuitBreaker creates a two-step circuit breaker.
func NewTwoStepCircuitBreaker(config *Config) *TwoStepCircuitBreaker {
	return &TwoStepCircuitBreaker{
		CircuitBreaker: NewCircuitBreaker(config),
		failFast:       true,
	}
}

// Allow checks if a request is allowed (non-blocking). A true result counts
// as an admitted request.
func (cb *TwoStepCircuitBreaker) Allow() bool {
	_, err := cb.beforeRequest()
	return err == nil
}

// ReportResult reports the result of a request. The result is attributed to
// the generation that is current at the time of the call.
func (cb *TwoStepCircuitBreaker) ReportResult(success bool) {
	var err error
	if !success {
		err = fmt.Errorf("request failed")
	}

	// generation is guarded by the mutex; reading it unlocked would race with
	// concurrent state changes.
	cb.mutex.RLock()
	generation := cb.generation
	cb.mutex.RUnlock()

	cb.afterRequest(generation, err)
}

// Manager manages multiple circuit breakers.
type Manager struct {
	breakers map[string]*CircuitBreaker
	mutex    sync.RWMutex
	logger   *logger.Logger
}

// NewManager creates a new circuit breaker manager.
func NewManager(logger *logger.Logger) *Manager {
	return &Manager{
		breakers: make(map[string]*CircuitBreaker),
		logger:   logger,
	}
}

// GetOrCreate gets an existing circuit breaker or creates a new one.
//
// A new breaker is built from a copy of config with Name set to name and
// Logger set to the manager's logger; the caller's config is not modified.
// config is ignored when a breaker with that name already exists, and a nil
// config selects all defaults.
func (m *Manager) GetOrCreate(name string, config *Config) *CircuitBreaker {
	m.mutex.RLock()
	breaker, exists := m.breakers[name]
	m.mutex.RUnlock()
	if exists {
		return breaker
	}

	m.mutex.Lock()
	defer m.mutex.Unlock()

	// Another goroutine may have created it between the two locks.
	if breaker, exists := m.breakers[name]; exists {
		return breaker
	}

	var cfg Config
	if config != nil {
		cfg = *config
	}
	cfg.Name = name
	cfg.Logger = m.logger

	breaker = NewCircuitBreaker(&cfg)
	m.breakers[name] = breaker
	return breaker
}

// Get gets a circuit breaker by name.
func (m *Manager) Get(name string) (*CircuitBreaker, bool) {
	m.mutex.RLock()
	defer m.mutex.RUnlock()

	breaker, exists := m.breakers[name]
	return breaker, exists
}

// Remove removes a circuit breaker.
func (m *Manager) Remove(name string) {
	m.mutex.Lock()
	defer m.mutex.Unlock()

	delete(m.breakers, name)
}

// List returns all circuit breaker names in unspecified order.
func (m *Manager) List() []string {
	m.mutex.RLock()
	defer m.mutex.RUnlock()

	names := make([]string, 0, len(m.breakers))
	for name := range m.breakers {
		names = append(names, name)
	}
	return names
}

// Stats returns statistics for all circuit breakers, keyed by breaker name.
// Each value is a map with the keys "state" (string) and "counts" (Counts).
func (m *Manager) Stats() map[string]interface{} {
	m.mutex.RLock()
	defer m.mutex.RUnlock()

	stats := make(map[string]interface{}, len(m.breakers))
	for name, breaker := range m.breakers {
		stats[name] = map[string]interface{}{
			"state":  breaker.State().String(),
			"counts": breaker.Counts(),
		}
	}
	return stats
}

// Reset resets all circuit breakers.
func (m *Manager) Reset() {
	m.mutex.RLock()
	defer m.mutex.RUnlock()

	for _, breaker := range m.breakers {
		breaker.Reset()
	}
}