Files
mev-beta/pkg/circuit/breaker.go

409 lines
9.0 KiB
Go

package circuit
import (
"context"
"fmt"
"sync"
"sync/atomic"
"time"
"github.com/fraktal/mev-beta/internal/logger"
)
// State is the operating mode of a circuit breaker. It is backed by int32 so
// that it can be stored and loaded with the sync/atomic helpers.
type State int32

// The three breaker states, numbered 0, 1 and 2 in this order.
const (
	// StateClosed lets requests through and tracks their outcome.
	StateClosed State = iota
	// StateHalfOpen lets a limited number of trial requests through.
	StateHalfOpen
	// StateOpen rejects requests.
	StateOpen
)

// String returns the upper-case name of the state ("CLOSED", "HALF_OPEN" or
// "OPEN"), or "UNKNOWN" for any value outside the declared constants.
func (s State) String() string {
	names := [...]string{
		StateClosed:   "CLOSED",
		StateHalfOpen: "HALF_OPEN",
		StateOpen:     "OPEN",
	}
	if s < 0 || int(s) >= len(names) {
		return "UNKNOWN"
	}
	return names[s]
}
// Config holds circuit breaker configuration.
//
// NewCircuitBreaker substitutes defaults for zero-valued numeric fields and
// for a nil IsFailure; the defaults are noted on each field.
type Config struct {
// Name identifies the breaker. It is passed to OnStateChange, included in
// the state-change log message and returned by CircuitBreaker.Name.
Name string
// MaxFailures is the number of consecutive failures in the closed state
// that opens the breaker. Default: 5.
MaxFailures uint64
// ResetTimeout is how long the breaker stays open before the next request
// moves it to half-open. Default: 60 seconds.
ResetTimeout time.Duration
// MaxRequests is the number of requests admitted while half-open; further
// requests are rejected with ErrTooManyRequests. Default: 1.
MaxRequests uint64
// SuccessThreshold is the number of consecutive successes in the half-open
// state that closes the breaker. Default: 1.
SuccessThreshold uint64
// OnStateChange, if non-nil, is called on every state transition with the
// breaker name, the previous state and the new state. It runs while the
// breaker's mutex is held, so it must not call back into methods of the
// same breaker that take that mutex (Execute, Counts, Reset, ...).
OnStateChange func(name string, from State, to State)
// IsFailure decides whether the error returned by a request counts as a
// failure. Default: any non-nil error is a failure.
IsFailure func(error) bool
// Logger, if non-nil, receives an Info message on every state transition.
Logger *logger.Logger
}
// Counts holds the circuit breaker statistics.
//
// All counters are reset to zero whenever the breaker changes state, so they
// describe only the period since the last transition.
type Counts struct {
// Requests is the number of requests admitted by the breaker.
Requests uint64
// TotalSuccesses is the number of requests recorded as successful.
TotalSuccesses uint64
// TotalFailures is the number of requests recorded as failed.
TotalFailures uint64
// ConsecutiveSuccesses is the length of the current run of successes; it
// is cleared by any failure.
ConsecutiveSuccesses uint64
// ConsecutiveFailures is the length of the current run of failures; it is
// cleared by any success.
ConsecutiveFailures uint64
}
// CircuitBreaker implements the circuit breaker pattern: it admits or rejects
// requests depending on its state and moves between closed, open and
// half-open based on the recorded outcomes.
type CircuitBreaker struct {
// config holds the settings supplied at construction time.
config *Config
// mutex is write-locked by beforeRequest, afterRequest and Reset, and
// read-locked by Counts.
mutex sync.RWMutex
// state is the current State stored as int32. It is written with
// atomic.StoreInt32 and read lock-free by the State method.
state int32
// generation is incremented on every state transition. afterRequest drops
// results whose generation no longer matches the current one.
generation uint64
// counts are the statistics gathered since the last state transition.
counts Counts
// expiry is, while open, the time after which the breaker may become
// half-open. Transitions to closed or half-open set it to the zero time;
// a new breaker starts with its construction time.
expiry time.Time
}
// NewCircuitBreaker creates a circuit breaker in the closed state.
//
// The breaker works on its own copy of config, so the caller's Config is not
// modified and may be reused for other breakers. A nil config is treated as
// an empty Config. Zero-valued settings are replaced by defaults:
//
//   - MaxFailures:      5
//   - ResetTimeout:     60 seconds
//   - MaxRequests:      1
//   - SuccessThreshold: 1
//   - IsFailure:        any non-nil error is a failure
func NewCircuitBreaker(config *Config) *CircuitBreaker {
	// Copy first: writing defaults through the caller's pointer would leak
	// them into every other breaker built from the same Config.
	var cfg Config
	if config != nil {
		cfg = *config
	}

	if cfg.MaxFailures == 0 {
		cfg.MaxFailures = 5
	}
	if cfg.ResetTimeout == 0 {
		cfg.ResetTimeout = 60 * time.Second
	}
	if cfg.MaxRequests == 0 {
		cfg.MaxRequests = 1
	}
	if cfg.SuccessThreshold == 0 {
		cfg.SuccessThreshold = 1
	}
	if cfg.IsFailure == nil {
		cfg.IsFailure = func(err error) bool { return err != nil }
	}

	// generation and counts start at their zero values.
	return &CircuitBreaker{
		config: &cfg,
		state:  int32(StateClosed),
		expiry: time.Now(),
	}
}
// Execute runs fn under circuit breaker protection.
//
// If the breaker rejects the request, fn is not called and the rejection
// error (ErrOpenState or ErrTooManyRequests) is returned with a nil result.
// Otherwise fn's result and error are returned unchanged after the outcome
// has been reported to the breaker. A panic in fn is reported to the breaker
// as an error of the form "panic: <value>" and then re-raised.
func (cb *CircuitBreaker) Execute(fn func() (interface{}, error)) (interface{}, error) {
	gen, admitErr := cb.beforeRequest()
	if admitErr != nil {
		return nil, admitErr
	}

	defer func() {
		r := recover()
		if r == nil {
			return
		}
		cb.afterRequest(gen, fmt.Errorf("panic: %v", r))
		panic(r)
	}()

	value, callErr := fn()
	cb.afterRequest(gen, callErr)
	return value, callErr
}
// ExecuteContext runs fn under circuit breaker protection, passing ctx
// through to it.
//
// If ctx is already done, its error is returned immediately and the breaker
// is left untouched: fn never ran, so the call neither takes a request slot
// nor counts as a failure. If the breaker rejects the request, fn is not
// called and ErrOpenState or ErrTooManyRequests is returned with a nil
// result. Otherwise fn's result and error are returned unchanged after the
// outcome has been reported to the breaker. A panic in fn is reported to the
// breaker as an error of the form "panic: <value>" and then re-raised.
func (cb *CircuitBreaker) ExecuteContext(ctx context.Context, fn func(context.Context) (interface{}, error)) (interface{}, error) {
	// Check cancellation before beforeRequest. Recording a caller-side
	// cancellation as a failure could open the breaker (or re-open it from
	// half-open) without the protected call having been attempted.
	if err := ctx.Err(); err != nil {
		return nil, err
	}

	generation, err := cb.beforeRequest()
	if err != nil {
		return nil, err
	}

	defer func() {
		if e := recover(); e != nil {
			cb.afterRequest(generation, fmt.Errorf("panic: %v", e))
			panic(e)
		}
	}()

	result, err := fn(ctx)
	cb.afterRequest(generation, err)
	return result, err
}
// beforeRequest decides whether a request may proceed and, if so, counts it.
//
// It returns the generation current at the time of the decision. The error is
// ErrOpenState when the breaker is open, ErrTooManyRequests when it is
// half-open and MaxRequests requests have already been admitted, and nil when
// the request is admitted.
func (cb *CircuitBreaker) beforeRequest() (uint64, error) {
	cb.mutex.Lock()
	defer cb.mutex.Unlock()

	// currentState may itself move an expired open breaker to half-open, so
	// the generation is read only after it has run.
	switch cb.currentState(time.Now()) {
	case StateOpen:
		return cb.generation, ErrOpenState
	case StateHalfOpen:
		if cb.counts.Requests >= cb.config.MaxRequests {
			return cb.generation, ErrTooManyRequests
		}
	}

	cb.counts.Requests++
	return cb.generation, nil
}
// afterRequest records the outcome of a request that was admitted in
// generation before. The outcome is ignored when the breaker has changed
// state since then; otherwise err is classified by the configured IsFailure
// and counted as a failure or a success.
func (cb *CircuitBreaker) afterRequest(before uint64, err error) {
	cb.mutex.Lock()
	defer cb.mutex.Unlock()

	ts := time.Now()
	// Evaluate the state first: an expired open breaker turns half-open here,
	// which bumps the generation and makes the comparison below fail.
	current := cb.currentState(ts)
	if cb.generation != before {
		return
	}

	if failed := cb.config.IsFailure(err); failed {
		cb.onFailure(current, ts)
		return
	}
	cb.onSuccess(current, ts)
}
// onFailure counts a failed request and opens the breaker when warranted:
// any failure while half-open, or MaxFailures consecutive failures while
// closed. The caller must hold cb.mutex.
func (cb *CircuitBreaker) onFailure(state State, now time.Time) {
	c := &cb.counts
	c.TotalFailures++
	c.ConsecutiveFailures++
	c.ConsecutiveSuccesses = 0

	tripped := state == StateHalfOpen ||
		(state == StateClosed && c.ConsecutiveFailures >= cb.config.MaxFailures)
	if tripped {
		cb.setState(StateOpen, now)
	}
}
// onSuccess counts a successful request and, while half-open, closes the
// breaker once SuccessThreshold consecutive successes have been seen. The
// caller must hold cb.mutex.
func (cb *CircuitBreaker) onSuccess(state State, now time.Time) {
	c := &cb.counts
	c.TotalSuccesses++
	c.ConsecutiveSuccesses++
	c.ConsecutiveFailures = 0

	if state != StateHalfOpen {
		return
	}
	if c.ConsecutiveSuccesses >= cb.config.SuccessThreshold {
		cb.setState(StateClosed, now)
	}
}
// currentState returns the state as of now, first applying any transition
// that has become due: an open breaker whose expiry has passed becomes
// half-open. The caller must hold cb.mutex.
func (cb *CircuitBreaker) currentState(now time.Time) State {
	observed := State(atomic.LoadInt32(&cb.state))
	expired := cb.expiry.Before(now)

	if observed == StateClosed && expired && !cb.expiry.IsZero() {
		// setState returns early when the target equals the current state,
		// so this call currently has no effect.
		cb.setState(StateClosed, now)
	} else if observed == StateOpen && expired {
		cb.setState(StateHalfOpen, now)
	}

	return State(atomic.LoadInt32(&cb.state))
}
// setState moves the breaker to state. It does nothing when the breaker is
// already in that state. Otherwise it stores the new state, starts a new
// generation, clears the counts, updates the expiry (now + ResetTimeout when
// opening, the zero time when closing or half-opening) and then notifies
// OnStateChange and the logger if they are configured.
//
// The caller must hold cb.mutex; the callback and the logger therefore run
// with the mutex held.
func (cb *CircuitBreaker) setState(state State, now time.Time) {
	prev := State(cb.state)
	if prev == state {
		return
	}

	atomic.StoreInt32(&cb.state, int32(state))
	cb.generation++
	cb.counts = Counts{}

	switch state {
	case StateOpen:
		cb.expiry = now.Add(cb.config.ResetTimeout)
	case StateClosed, StateHalfOpen:
		cb.expiry = time.Time{}
	}

	if notify := cb.config.OnStateChange; notify != nil {
		notify(cb.config.Name, prev, state)
	}
	if lg := cb.config.Logger; lg != nil {
		msg := fmt.Sprintf("Circuit breaker '%s' state changed from %s to %s",
			cb.config.Name, prev.String(), state.String())
		lg.Info(msg)
	}
}
// State returns the most recently stored state without taking the mutex.
//
// It does not apply pending transitions: an open breaker whose reset timeout
// has elapsed is still reported as open until the next request is processed.
func (cb *CircuitBreaker) State() State {
	raw := atomic.LoadInt32(&cb.state)
	return State(raw)
}
// Counts returns a snapshot of the statistics gathered since the last state
// transition. The snapshot is taken under the read lock.
func (cb *CircuitBreaker) Counts() Counts {
	cb.mutex.RLock()
	snapshot := cb.counts
	cb.mutex.RUnlock()
	return snapshot
}
// Name returns the name the breaker was configured with.
func (cb *CircuitBreaker) Name() string {
	cfg := cb.config
	return cfg.Name
}
// Reset forces the breaker into the closed state.
//
// When the breaker is open or half-open this clears the counts and starts a
// new generation. When it is already closed nothing changes, so previously
// accumulated counts are kept.
func (cb *CircuitBreaker) Reset() {
	cb.mutex.Lock()
	defer cb.mutex.Unlock()

	resetAt := time.Now()
	cb.setState(StateClosed, resetAt)
}
// Errors returned when the breaker rejects a request before running it.
var (
// ErrOpenState is returned when the breaker is open.
ErrOpenState = fmt.Errorf("circuit breaker is open")
// ErrTooManyRequests is returned when the breaker is half-open and has
// already admitted MaxRequests requests.
ErrTooManyRequests = fmt.Errorf("too many requests")
)
// TwoStepCircuitBreaker extends CircuitBreaker with two-step recovery: the
// caller first asks for permission with Allow and later reports the outcome
// with ReportResult, instead of handing a function to Execute.
type TwoStepCircuitBreaker struct {
// CircuitBreaker is embedded, so all of its methods are available too.
*CircuitBreaker
// failFast is set to true by NewTwoStepCircuitBreaker.
// NOTE(review): no code in this section reads it — confirm whether it is
// still needed.
failFast bool
}
// NewTwoStepCircuitBreaker builds a TwoStepCircuitBreaker around a breaker
// created by NewCircuitBreaker from config, with failFast enabled.
func NewTwoStepCircuitBreaker(config *Config) *TwoStepCircuitBreaker {
	inner := NewCircuitBreaker(config)
	return &TwoStepCircuitBreaker{CircuitBreaker: inner, failFast: true}
}
// Allow reports whether a request may proceed right now. A true result means
// the request has been counted as admitted, so the caller is expected to
// follow up with ReportResult.
func (cb *TwoStepCircuitBreaker) Allow() bool {
	if _, err := cb.beforeRequest(); err != nil {
		return false
	}
	return true
}
// ReportResult records the outcome of a request previously admitted by Allow.
// success=false is reported to the breaker as a "request failed" error, which
// the configured IsFailure then classifies; success=true is reported as a nil
// error.
//
// Allow does not hand out the generation it admitted the request in, so the
// outcome is attributed to the generation current at the time of the report.
func (cb *TwoStepCircuitBreaker) ReportResult(success bool) {
	var err error
	if !success {
		err = fmt.Errorf("request failed")
	}

	// generation is written under the mutex by setState; reading it without
	// the lock would be a data race with concurrent state transitions.
	cb.mutex.RLock()
	generation := cb.generation
	cb.mutex.RUnlock()

	cb.afterRequest(generation, err)
}
// Manager manages multiple circuit breakers, keyed by name.
type Manager struct {
// breakers maps a breaker name to its instance.
breakers map[string]*CircuitBreaker
// mutex guards breakers.
mutex sync.RWMutex
// logger is assigned to every breaker created through GetOrCreate.
logger *logger.Logger
}
// NewManager returns an empty Manager whose breakers will log through logger.
func NewManager(logger *logger.Logger) *Manager {
	m := &Manager{logger: logger}
	m.breakers = make(map[string]*CircuitBreaker)
	return m
}
// GetOrCreate returns the breaker registered under name, creating and
// registering it from config if none exists yet.
//
// config is only used when a new breaker is created. The breaker is built
// from a copy of config whose Name is set to name and whose Logger is set to
// the manager's logger; the caller's Config is not modified, so one Config
// can safely be reused for several names. A nil config is treated as an
// empty Config.
func (m *Manager) GetOrCreate(name string, config *Config) *CircuitBreaker {
	// Fast path: look up under the read lock only.
	m.mutex.RLock()
	breaker, exists := m.breakers[name]
	m.mutex.RUnlock()
	if exists {
		return breaker
	}

	m.mutex.Lock()
	defer m.mutex.Unlock()

	// Another goroutine may have created it between the two locks.
	if breaker, exists := m.breakers[name]; exists {
		return breaker
	}

	// Work on a private copy: writing Name/Logger through the caller's
	// pointer would rename every breaker previously built from that Config.
	var cfg Config
	if config != nil {
		cfg = *config
	}
	cfg.Name = name
	cfg.Logger = m.logger

	breaker = NewCircuitBreaker(&cfg)
	m.breakers[name] = breaker
	return breaker
}
// Get looks up the breaker registered under name. The boolean reports whether
// one was found.
func (m *Manager) Get(name string) (*CircuitBreaker, bool) {
	m.mutex.RLock()
	found, ok := m.breakers[name]
	m.mutex.RUnlock()
	return found, ok
}
// Remove unregisters the breaker stored under name. Removing a name that is
// not registered is a no-op.
func (m *Manager) Remove(name string) {
	m.mutex.Lock()
	delete(m.breakers, name)
	m.mutex.Unlock()
}
// List returns the names of all registered breakers. The order follows map
// iteration and is therefore unspecified.
func (m *Manager) List() []string {
	m.mutex.RLock()
	defer m.mutex.RUnlock()

	out := make([]string, len(m.breakers))
	i := 0
	for key := range m.breakers {
		out[i] = key
		i++
	}
	return out
}
// Stats returns, for every registered breaker, a map with two entries:
// "state" (the state name as a string) and "counts" (a Counts snapshot).
// The outer map is keyed by breaker name.
func (m *Manager) Stats() map[string]interface{} {
	m.mutex.RLock()
	defer m.mutex.RUnlock()

	out := make(map[string]interface{})
	for name, b := range m.breakers {
		entry := make(map[string]interface{}, 2)
		entry["state"] = b.State().String()
		entry["counts"] = b.Counts()
		out[name] = entry
	}
	return out
}
// Reset calls Reset on every registered breaker. The manager's read lock is
// held for the duration of the loop.
func (m *Manager) Reset() {
	m.mutex.RLock()
	defer m.mutex.RUnlock()

	for name := range m.breakers {
		m.breakers[name].Reset()
	}
}