feat: create v2-prep branch with comprehensive planning

Restructured project for V2 refactor:

**Structure Changes:**
- Moved all V1 code to orig/ folder (preserved with git mv)
- Created docs/planning/ directory
- Added orig/README_V1.md explaining V1 preservation

**Planning Documents:**
- 00_V2_MASTER_PLAN.md: Complete architecture overview
  - Executive summary of critical V1 issues
  - High-level component architecture diagrams
  - 5-phase implementation roadmap
  - Success metrics and risk mitigation

- 07_TASK_BREAKDOWN.md: Atomic task breakdown
  - 99+ hours of detailed tasks
  - Every task < 2 hours (atomic)
  - Clear dependencies and success criteria
  - Organized by implementation phase

**V2 Key Improvements:**
- Per-exchange parsers (factory pattern)
- Multi-layer strict validation
- Multi-index pool cache
- Background validation pipeline
- Comprehensive observability

**Critical Issues Addressed:**
- Zero address tokens (strict validation + cache enrichment)
- Parsing accuracy (protocol-specific parsers)
- No audit trail (background validation channel)
- Inefficient lookups (multi-index cache)
- Stats disconnection (event-driven metrics)

Next Steps:
1. Review planning documents
2. Begin Phase 1: Foundation (P1-001 through P1-010)
3. Implement parsers in Phase 2
4. Build cache system in Phase 3
5. Add validation pipeline in Phase 4
6. Migrate and test in Phase 5

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
Administrator
2025-11-10 10:14:26 +01:00
parent 1773daffe7
commit 803de231ba
411 changed files with 20390 additions and 8680 deletions

View File

@@ -0,0 +1,661 @@
package lifecycle
import (
"context"
"fmt"
"reflect"
"sync"
)
// Container provides dependency injection functionality. Registrations are
// kept by service type and by explicit name, singleton instances are
// memoized, and containers can be nested: a child scope falls back to its
// parent during resolution.
type Container struct {
	services       map[reflect.Type]*ServiceDescriptor // registrations keyed by service type
	instances      map[reflect.Type]interface{}        // cached instances (used when config.CacheInstances)
	namedServices  map[string]*ServiceDescriptor       // registrations keyed by explicit name
	namedInstances map[string]interface{}              // cached named instances
	singletons     map[reflect.Type]interface{}        // materialized singleton instances
	factories      map[reflect.Type]FactoryFunc        // registered factory functions
	interceptors   []Interceptor                       // container-wide interceptors
	config         ContainerConfig
	mu             sync.RWMutex          // guards all maps and slices above
	parent         *Container            // parent container for scope fallback; nil at the root
	scoped         map[string]*Container // named child scopes created via CreateScope
}

// ServiceDescriptor describes how a service should be instantiated.
// Exactly one of Instance, Factory, or Implementation is normally set;
// createInstance consults them in that order.
type ServiceDescriptor struct {
	ServiceType    reflect.Type           // the type callers resolve by
	Implementation reflect.Type           // concrete type built via reflection
	Lifetime       ServiceLifetime        // transient, singleton, or scoped
	Factory        FactoryFunc            // optional factory; wins over Implementation
	Instance       interface{}            // optional pre-built instance; wins over Factory
	Name           string                 // optional name for ResolveNamed
	Dependencies   []reflect.Type         // declared dependencies checked by Validate
	Tags           []string               // tags for ResolveAll
	Metadata       map[string]interface{} // arbitrary per-service metadata
	Interceptors   []Interceptor          // service-specific interceptors
}

// ServiceLifetime defines the lifetime of a service.
type ServiceLifetime string

const (
	Transient ServiceLifetime = "transient" // New instance every time
	Singleton ServiceLifetime = "singleton" // Single instance for container lifetime
	Scoped    ServiceLifetime = "scoped"    // Single instance per scope
)

// FactoryFunc creates service instances; it receives the container so it can
// resolve its own dependencies.
type FactoryFunc func(container *Container) (interface{}, error)

// Interceptor can intercept service creation and method calls.
type Interceptor interface {
	Intercept(ctx context.Context, target interface{}, method string, args []interface{}) (interface{}, error)
}

// ContainerConfig configures the container behavior.
type ContainerConfig struct {
	EnableReflection        bool `json:"enable_reflection"`         // allow reflection-based struct creation
	EnableCircularDetection bool `json:"enable_circular_detection"` // detect cycles during resolution
	EnableInterception      bool `json:"enable_interception"`       // run the interceptor pipeline
	EnableValidation        bool `json:"enable_validation"`         // make Validate actually check registrations
	MaxDepth                int  `json:"max_depth"`                 // resolution depth limit (0 => default 10)
	CacheInstances          bool `json:"cache_instances"`           // consult the instances cache on resolve
}

// ServiceBuilder provides a fluent interface for service registration.
// It accumulates options and commits them to the container in Build.
type ServiceBuilder struct {
	container    *Container
	serviceType  reflect.Type
	implType     reflect.Type
	lifetime     ServiceLifetime
	factory      FactoryFunc
	instance     interface{}
	name         string
	tags         []string
	metadata     map[string]interface{}
	interceptors []Interceptor
}
// NewContainer builds an empty dependency injection container from the given
// configuration. A zero MaxDepth is replaced with a default of 10.
func NewContainer(config ContainerConfig) *Container {
	// Apply defaults before construction so the stored config is final.
	if config.MaxDepth == 0 {
		config.MaxDepth = 10
	}
	return &Container{
		services:       map[reflect.Type]*ServiceDescriptor{},
		instances:      map[reflect.Type]interface{}{},
		namedServices:  map[string]*ServiceDescriptor{},
		namedInstances: map[string]interface{}{},
		singletons:     map[reflect.Type]interface{}{},
		factories:      map[reflect.Type]FactoryFunc{},
		interceptors:   []Interceptor{},
		config:         config,
		scoped:         map[string]*Container{},
	}
}
// Register begins registration of a service type with its implementation and
// returns a fluent ServiceBuilder. Nothing is stored in the container until
// the builder's Build method is called.
//
// serviceType is typically passed as a pointer to an interface, e.g.
// (*MyService)(nil); the pointer is unwrapped to the interface type.
func (c *Container) Register(serviceType, implementationType interface{}) *ServiceBuilder {
	// No locking here: this method only inspects types and allocates a
	// builder. The container maps are mutated (under c.mu) in Build, so the
	// write lock previously taken here protected nothing and just added a
	// lock round-trip per registration.
	sType := reflect.TypeOf(serviceType)
	if sType.Kind() == reflect.Ptr {
		sType = sType.Elem()
	}
	if sType.Kind() == reflect.Interface {
		sType = reflect.TypeOf(serviceType).Elem()
	}
	implType := reflect.TypeOf(implementationType)
	if implType.Kind() == reflect.Ptr {
		implType = implType.Elem()
	}
	return &ServiceBuilder{
		container:    c,
		serviceType:  sType,
		implType:     implType,
		lifetime:     Transient, // default lifetime; override via AsSingleton/AsScoped
		tags:         make([]string, 0),
		metadata:     make(map[string]interface{}),
		interceptors: make([]Interceptor, 0),
	}
}
// RegisterInstance begins registration of a specific pre-built instance for
// the given service type (singleton by construction). As with Register, the
// registration only takes effect when Build is called on the returned builder.
func (c *Container) RegisterInstance(serviceType interface{}, instance interface{}) *ServiceBuilder {
	// No locking: nothing in the container is mutated until Build runs
	// (the previous implementation took the write lock for no shared state).
	sType := reflect.TypeOf(serviceType)
	if sType.Kind() == reflect.Ptr {
		sType = sType.Elem()
	}
	if sType.Kind() == reflect.Interface {
		sType = reflect.TypeOf(serviceType).Elem()
	}
	return &ServiceBuilder{
		container:    c,
		serviceType:  sType,
		instance:     instance,
		lifetime:     Singleton, // a fixed instance is inherently a singleton
		tags:         make([]string, 0),
		metadata:     make(map[string]interface{}),
		interceptors: make([]Interceptor, 0),
	}
}
// RegisterFactory begins registration of a factory function for the given
// service type. The registration is committed by Build on the returned
// builder.
func (c *Container) RegisterFactory(serviceType interface{}, factory FactoryFunc) *ServiceBuilder {
	// No locking: nothing in the container is mutated until Build runs
	// (the previous implementation took the write lock for no shared state).
	sType := reflect.TypeOf(serviceType)
	if sType.Kind() == reflect.Ptr {
		sType = sType.Elem()
	}
	if sType.Kind() == reflect.Interface {
		sType = reflect.TypeOf(serviceType).Elem()
	}
	return &ServiceBuilder{
		container:    c,
		serviceType:  sType,
		factory:      factory,
		lifetime:     Transient, // default; override via AsSingleton/AsScoped
		tags:         make([]string, 0),
		metadata:     make(map[string]interface{}),
		interceptors: make([]Interceptor, 0),
	}
}
// Resolve returns an instance of the requested service type, resolving the
// full dependency graph as needed. Pass a pointer-to-interface (e.g.
// (*MyService)(nil)) to resolve by interface type.
func (c *Container) Resolve(serviceType interface{}) (interface{}, error) {
	target := reflect.TypeOf(serviceType)
	if target.Kind() == reflect.Ptr {
		target = target.Elem()
	}
	if target.Kind() == reflect.Interface {
		target = reflect.TypeOf(serviceType).Elem()
	}
	// A fresh "resolving" set starts a new cycle-detection chain.
	return c.resolveType(target, map[reflect.Type]bool{}, 0)
}
// ResolveNamed resolves a service registered under an explicit name (set via
// ServiceBuilder.WithName).
//
// NOTE(review): c.namedInstances is read here but never written anywhere in
// this file, so the cache looks dead — confirm against the rest of the
// package. Also, createInstance can write c.singletons while only the read
// lock is held here; that is a data race under concurrent use.
func (c *Container) ResolveNamed(name string) (interface{}, error) {
	c.mu.RLock()
	defer c.mu.RUnlock()
	// Check if instance already exists
	if instance, exists := c.namedInstances[name]; exists {
		return instance, nil
	}
	// Get service descriptor
	descriptor, exists := c.namedServices[name]
	if !exists {
		return nil, fmt.Errorf("named service not found: %s", name)
	}
	// Named resolution starts a fresh cycle-detection chain at depth 0.
	return c.createInstance(descriptor, make(map[reflect.Type]bool), 0)
}

// ResolveAll resolves one instance for every registered service carrying the
// given tag; resolution stops at the first failure.
//
// NOTE(review): as in ResolveNamed, createInstance may mutate container maps
// while only the read lock is held here.
func (c *Container) ResolveAll(tag string) ([]interface{}, error) {
	c.mu.RLock()
	defer c.mu.RUnlock()
	var instances []interface{}
	for _, descriptor := range c.services {
		for _, serviceTag := range descriptor.Tags {
			if serviceTag == tag {
				instance, err := c.createInstance(descriptor, make(map[reflect.Type]bool), 0)
				if err != nil {
					return nil, fmt.Errorf("failed to resolve service with tag %s: %w", tag, err)
				}
				instances = append(instances, instance)
				// A service is resolved at most once even with duplicate tags.
				break
			}
		}
	}
	return instances, nil
}
// TryResolve is the non-failing variant of Resolve: a service that cannot be
// resolved yields nil instead of an error.
func (c *Container) TryResolve(serviceType interface{}) interface{} {
	if instance, err := c.Resolve(serviceType); err == nil {
		return instance
	}
	return nil
}
// IsRegistered reports whether the given service type has a registration in
// this container (parent containers are not consulted).
func (c *Container) IsRegistered(serviceType interface{}) bool {
	target := reflect.TypeOf(serviceType)
	if target.Kind() == reflect.Ptr {
		target = target.Elem()
	}
	if target.Kind() == reflect.Interface {
		target = reflect.TypeOf(serviceType).Elem()
	}
	c.mu.RLock()
	_, ok := c.services[target]
	c.mu.RUnlock()
	return ok
}
// CreateScope creates, registers, and returns a named child container. The
// child inherits the parent's configuration and falls back to the parent
// during resolution.
func (c *Container) CreateScope(name string) *Container {
	c.mu.Lock()
	defer c.mu.Unlock()
	child := &Container{
		services:       map[reflect.Type]*ServiceDescriptor{},
		instances:      map[reflect.Type]interface{}{},
		namedServices:  map[string]*ServiceDescriptor{},
		namedInstances: map[string]interface{}{},
		singletons:     map[reflect.Type]interface{}{},
		factories:      map[reflect.Type]FactoryFunc{},
		interceptors:   []Interceptor{},
		config:         c.config,
		parent:         c,
		scoped:         map[string]*Container{},
	}
	c.scoped[name] = child
	return child
}
// GetScope looks up a child scope previously created with CreateScope.
func (c *Container) GetScope(name string) (*Container, bool) {
	c.mu.RLock()
	defer c.mu.RUnlock()
	s, ok := c.scoped[name]
	return s, ok
}

// AddInterceptor appends a container-wide interceptor; it will be considered
// for every service this container creates.
func (c *Container) AddInterceptor(interceptor Interceptor) {
	c.mu.Lock()
	c.interceptors = append(c.interceptors, interceptor)
	c.mu.Unlock()
}
// GetRegistrations returns a shallow copy of the registration table; the
// descriptors themselves are shared, so callers must not mutate them.
func (c *Container) GetRegistrations() map[reflect.Type]*ServiceDescriptor {
	c.mu.RLock()
	defer c.mu.RUnlock()
	out := make(map[reflect.Type]*ServiceDescriptor, len(c.services))
	for t, d := range c.services {
		out[t] = d
	}
	return out
}
// Validate checks every registration when EnableValidation is set, returning
// the first failure encountered (or nil when validation is disabled).
//
// NOTE(review): the read lock is held for the entire scan, so any helper
// invoked from here must not re-acquire c.mu — sync.RWMutex is not reentrant
// and a queued writer between two RLocks deadlocks.
func (c *Container) Validate() error {
	c.mu.RLock()
	defer c.mu.RUnlock()
	// Validation is opt-in; a disabled container always passes.
	if !c.config.EnableValidation {
		return nil
	}
	for serviceType, descriptor := range c.services {
		if err := c.validateDescriptor(serviceType, descriptor); err != nil {
			return fmt.Errorf("validation failed for service %s: %w", serviceType.String(), err)
		}
	}
	return nil
}
// Dispose tears down the container: every child scope is disposed first, and
// then all internal registries are reset so held instances can be garbage
// collected. The first error reported by a child scope is propagated; child
// disposal continues regardless.
func (c *Container) Dispose() error {
	c.mu.Lock()
	defer c.mu.Unlock()
	// Dispose all scoped containers, keeping the first failure instead of
	// silently discarding child errors as before.
	var firstErr error
	for _, scope := range c.scoped {
		if err := scope.Dispose(); err != nil && firstErr == nil {
			firstErr = err
		}
	}
	// Reset all registries and caches.
	c.services = make(map[reflect.Type]*ServiceDescriptor)
	c.instances = make(map[reflect.Type]interface{})
	c.namedServices = make(map[string]*ServiceDescriptor)
	c.namedInstances = make(map[string]interface{})
	c.singletons = make(map[reflect.Type]interface{})
	c.factories = make(map[reflect.Type]FactoryFunc)
	c.scoped = make(map[string]*Container)
	// Also drop global interceptors; they were previously leaked on dispose.
	c.interceptors = nil
	return firstErr
}
// Private methods

// resolveType resolves serviceType against the registrations, honoring the
// depth limit and (optionally) circular-dependency detection. The resolving
// set tracks the types on the current resolution path.
//
// NOTE(review): the read lock acquired here remains held through
// createInstance, which writes c.singletons — a map write under RLock and
// therefore a data race under concurrent resolution. The parent fallback also
// recurses while this container's read lock is still held.
func (c *Container) resolveType(serviceType reflect.Type, resolving map[reflect.Type]bool, depth int) (interface{}, error) {
	if depth > c.config.MaxDepth {
		return nil, fmt.Errorf("maximum resolution depth exceeded for type %s", serviceType.String())
	}
	// Check for circular dependencies
	if c.config.EnableCircularDetection && resolving[serviceType] {
		return nil, fmt.Errorf("circular dependency detected for type %s", serviceType.String())
	}
	c.mu.RLock()
	defer c.mu.RUnlock()
	// Check if singleton instance exists
	if instance, exists := c.singletons[serviceType]; exists {
		return instance, nil
	}
	// Check if cached instance exists
	if c.config.CacheInstances {
		if instance, exists := c.instances[serviceType]; exists {
			return instance, nil
		}
	}
	// Get service descriptor
	descriptor, exists := c.services[serviceType]
	if !exists {
		// Fall back to the parent container (which takes its own lock).
		if c.parent != nil {
			return c.parent.resolveType(serviceType, resolving, depth+1)
		}
		return nil, fmt.Errorf("service not registered: %s", serviceType.String())
	}
	// Mark this type as in-flight for cycle detection and clear the mark once
	// its subtree has been resolved.
	resolving[serviceType] = true
	defer delete(resolving, serviceType)
	return c.createInstance(descriptor, resolving, depth+1)
}

// createInstance materializes a service from its descriptor, in priority
// order: a pre-built Instance, then a Factory, then reflection over the
// Implementation type. Singleton results are memoized in c.singletons and all
// results pass through the interceptor pipeline.
//
// NOTE(review): callers reach this while holding only c.mu.RLock, yet the
// singleton memoization below writes the map — see resolveType.
func (c *Container) createInstance(descriptor *ServiceDescriptor, resolving map[reflect.Type]bool, depth int) (interface{}, error) {
	// Use existing instance if available
	if descriptor.Instance != nil {
		return descriptor.Instance, nil
	}
	// Use factory if available
	if descriptor.Factory != nil {
		instance, err := descriptor.Factory(c)
		if err != nil {
			return nil, fmt.Errorf("factory failed for %s: %w", descriptor.ServiceType.String(), err)
		}
		if descriptor.Lifetime == Singleton {
			c.singletons[descriptor.ServiceType] = instance
		}
		return c.applyInterceptors(instance, descriptor)
	}
	// Create instance using reflection
	if descriptor.Implementation == nil {
		return nil, fmt.Errorf("no implementation or factory provided for %s", descriptor.ServiceType.String())
	}
	instance, err := c.createInstanceByReflection(descriptor, resolving, depth)
	if err != nil {
		return nil, err
	}
	// Store singleton
	if descriptor.Lifetime == Singleton {
		c.singletons[descriptor.ServiceType] = instance
	}
	return c.applyInterceptors(instance, descriptor)
}
// createInstanceByReflection instantiates descriptor.Implementation through
// reflection. Constructor discovery (finding a package-level "New<Type>"
// function) is not implemented — Go reflection cannot enumerate package
// functions — so the struct is always created directly and its `inject`-tagged
// fields populated by createStruct. The previous version carried dead
// scaffolding for a constructor path that could never be taken; behavior is
// unchanged with the scaffolding removed.
func (c *Container) createInstanceByReflection(descriptor *ServiceDescriptor, resolving map[reflect.Type]bool, depth int) (interface{}, error) {
	if !c.config.EnableReflection {
		return nil, fmt.Errorf("reflection is disabled")
	}
	implType := descriptor.Implementation
	if implType.Kind() == reflect.Ptr {
		implType = implType.Elem()
	}
	return c.createStruct(implType, resolving, depth)
}
// createStruct allocates a new value of structType and performs field
// injection: fields tagged `inject:"true"` are resolved by type, fields
// tagged `inject:"<name>"` by registered name, and `optional:"true"`
// downgrades a failed injection to a skip. It returns a pointer to the
// populated struct.
func (c *Container) createStruct(structType reflect.Type, resolving map[reflect.Type]bool, depth int) (interface{}, error) {
	// Create new instance
	instance := reflect.New(structType)
	elem := instance.Elem()
	// Inject dependencies into fields
	for i := 0; i < elem.NumField(); i++ {
		field := elem.Field(i)
		fieldType := elem.Type().Field(i)
		// Check for dependency injection tags
		if tag := fieldType.Tag.Get("inject"); tag != "" {
			if !field.CanSet() {
				// Unexported fields cannot be set via reflection; skip them.
				continue
			}
			var dependency interface{}
			var err error
			if tag == "true" || tag == "" {
				// Inject by type. (The tag == "" arm is unreachable: the
				// enclosing condition already requires a non-empty tag.)
				dependency, err = c.resolveType(field.Type(), resolving, depth)
			} else {
				// Inject by name
				dependency, err = c.ResolveNamed(tag)
			}
			if err != nil {
				// Check if injection is optional
				if optionalTag := fieldType.Tag.Get("optional"); optionalTag == "true" {
					continue
				}
				return nil, fmt.Errorf("failed to inject dependency for field %s: %w", fieldType.Name, err)
			}
			field.Set(reflect.ValueOf(dependency))
		}
	}
	return instance.Interface(), nil
}
// callConstructor invokes a constructor function, resolving each parameter
// from the container. The first result is the instance; a non-nil error in
// the second result (when present and nillable) aborts creation.
func (c *Container) callConstructor(constructor reflect.Value, resolving map[reflect.Type]bool, depth int) (interface{}, error) {
	constructorType := constructor.Type()
	args := make([]reflect.Value, constructorType.NumIn())
	// Resolve constructor arguments from the container.
	for i := 0; i < constructorType.NumIn(); i++ {
		argType := constructorType.In(i)
		arg, err := c.resolveType(argType, resolving, depth)
		if err != nil {
			return nil, fmt.Errorf("failed to resolve constructor argument %d (%s): %w", i, argType.String(), err)
		}
		args[i] = reflect.ValueOf(arg)
	}
	// Call constructor
	results := constructor.Call(args)
	if len(results) == 0 {
		return nil, fmt.Errorf("constructor returned no values")
	}
	// Check for an error result. The previous version called IsNil
	// unconditionally, which panics when the second result is a non-nillable
	// type (e.g. a struct); only nillable kinds may be IsNil-checked.
	if len(results) > 1 {
		second := results[1]
		switch second.Kind() {
		case reflect.Interface, reflect.Ptr, reflect.Map, reflect.Slice, reflect.Chan, reflect.Func:
			if !second.IsNil() {
				if err, ok := second.Interface().(error); ok {
					return nil, fmt.Errorf("constructor error: %w", err)
				}
			}
		}
	}
	return results[0].Interface(), nil
}
// applyInterceptors runs the (currently placeholder) interception stage over
// a freshly created instance. Proxy generation is not implemented, so the
// instance always passes through unchanged; the hook exists so callers are
// already wired for a real implementation.
func (c *Container) applyInterceptors(instance interface{}, descriptor *ServiceDescriptor) (interface{}, error) {
	if !c.config.EnableInterception {
		return instance, nil
	}
	// Placeholder: a complete implementation would wrap instance in a proxy
	// for each of descriptor.Interceptors, then for each of c.interceptors,
	// in that order. Until then the interceptors are intentionally unused.
	_ = descriptor.Interceptors
	_ = c.interceptors
	return instance, nil
}
// validateDescriptor checks a single registration: the implementation must
// satisfy an interface service type, and every declared dependency must be
// registered in this container or an ancestor.
//
// It is invoked from Validate while c.mu is read-locked, so it must not call
// methods that re-acquire c.mu. The previous version called c.IsRegistered,
// which (a) re-locked the same non-reentrant RWMutex — a latent deadlock when
// a writer is queued — and (b) passed a reflect.Type where an instance was
// expected, so reflect.TypeOf inside IsRegistered produced the wrong lookup
// key and dependencies were never found.
func (c *Container) validateDescriptor(serviceType reflect.Type, descriptor *ServiceDescriptor) error {
	// Validate that the implementation implements the service interface.
	if descriptor.Implementation != nil {
		if serviceType.Kind() == reflect.Interface {
			if !descriptor.Implementation.Implements(serviceType) {
				return fmt.Errorf("implementation %s does not implement interface %s",
					descriptor.Implementation.String(), serviceType.String())
			}
		}
	}
	// Validate dependencies against this container and its ancestors.
	for _, depType := range descriptor.Dependencies {
		if !c.dependencyRegistered(depType) {
			return fmt.Errorf("dependency %s is not registered", depType.String())
		}
	}
	return nil
}

// dependencyRegistered reports whether depType is registered here or in any
// ancestor. The receiver's map is read without locking because the caller
// already holds c.mu; each ancestor's map is read under that ancestor's own
// lock (a different mutex, so no reentrancy issue).
func (c *Container) dependencyRegistered(depType reflect.Type) bool {
	if _, ok := c.services[depType]; ok {
		return true
	}
	for p := c.parent; p != nil; p = p.parent {
		p.mu.RLock()
		_, ok := p.services[depType]
		p.mu.RUnlock()
		if ok {
			return true
		}
	}
	return false
}
// ServiceBuilder methods

// AsSingleton sets the service lifetime to singleton.
func (sb *ServiceBuilder) AsSingleton() *ServiceBuilder {
	sb.lifetime = Singleton
	return sb
}

// AsTransient sets the service lifetime to transient (the default).
func (sb *ServiceBuilder) AsTransient() *ServiceBuilder {
	sb.lifetime = Transient
	return sb
}

// AsScoped sets the service lifetime to scoped.
func (sb *ServiceBuilder) AsScoped() *ServiceBuilder {
	sb.lifetime = Scoped
	return sb
}

// WithName sets a name for the service, enabling ResolveNamed lookups.
func (sb *ServiceBuilder) WithName(name string) *ServiceBuilder {
	sb.name = name
	return sb
}

// WithTag adds a tag to the service, enabling ResolveAll lookups.
func (sb *ServiceBuilder) WithTag(tag string) *ServiceBuilder {
	sb.tags = append(sb.tags, tag)
	return sb
}

// WithMetadata attaches an arbitrary key/value pair to the registration.
func (sb *ServiceBuilder) WithMetadata(key string, value interface{}) *ServiceBuilder {
	sb.metadata[key] = value
	return sb
}

// WithInterceptor adds a service-specific interceptor.
func (sb *ServiceBuilder) WithInterceptor(interceptor Interceptor) *ServiceBuilder {
	sb.interceptors = append(sb.interceptors, interceptor)
	return sb
}

// Build finalizes the registration by writing the accumulated descriptor into
// the container, keyed by service type and — when a name was set — by name.
//
// NOTE(review): re-registering an existing type or name silently overwrites
// the previous descriptor; confirm that is the intended policy.
func (sb *ServiceBuilder) Build() error {
	sb.container.mu.Lock()
	defer sb.container.mu.Unlock()
	descriptor := &ServiceDescriptor{
		ServiceType:    sb.serviceType,
		Implementation: sb.implType,
		Lifetime:       sb.lifetime,
		Factory:        sb.factory,
		Instance:       sb.instance,
		Name:           sb.name,
		Tags:           sb.tags,
		Metadata:       sb.metadata,
		Interceptors:   sb.interceptors,
	}
	// Store by type
	sb.container.services[sb.serviceType] = descriptor
	// Store by name if provided
	if sb.name != "" {
		sb.container.namedServices[sb.name] = descriptor
	}
	return nil
}
// DefaultInterceptor provides basic interception functionality
type DefaultInterceptor struct {
name string
}
func NewDefaultInterceptor(name string) *DefaultInterceptor {
return &DefaultInterceptor{name: name}
}
func (di *DefaultInterceptor) Intercept(ctx context.Context, target interface{}, method string, args []interface{}) (interface{}, error) {
// Basic interception - could add logging, metrics, etc.
return nil, nil
}

View File

@@ -0,0 +1,109 @@
package lifecycle
import (
"errors"
"fmt"
"regexp"
"strings"
)
// txHashPattern matches a 0x-prefixed, 64-hex-digit (EVM-style) transaction
// hash anywhere inside a string.
var txHashPattern = regexp.MustCompile(`0x[a-fA-F0-9]{64}`)
type RecordedError struct {
Err error
TxHash string
}
func (re RecordedError) Error() string {
if re.Err == nil {
return ""
}
return re.Err.Error()
}
// enrichErrorWithTxHash wraps err with message and, when a transaction hash
// can be discovered (from attrs or from the error text), embeds the hash in
// the wrapped message and guarantees a tx_hash attribute is present. It
// returns the wrapped error, the discovered hash ("" when none), and the
// possibly-extended attribute list.
func enrichErrorWithTxHash(message string, err error, attrs []interface{}) (error, string, []interface{}) {
	txHash, enriched := ensureTxHash(attrs, err)
	var wrapped error
	if txHash == "" {
		wrapped = fmt.Errorf("%s: %w", message, err)
	} else {
		wrapped = fmt.Errorf("%s [tx_hash=%s]: %w", message, txHash, err)
	}
	return wrapped, txHash, enriched
}

// ensureTxHash locates a transaction hash in the structured attrs (preferred)
// or in the error chain, and appends a "tx_hash" attribute when none of the
// recognized hash keys is already present. It returns "" and the untouched
// attrs when nothing is found.
func ensureTxHash(attrs []interface{}, err error) (string, []interface{}) {
	txHash := extractTxHashFromAttrs(attrs)
	if txHash == "" {
		txHash = extractTxHashFromError(err)
	}
	if txHash == "" {
		return "", attrs
	}
	// Only append when no recognized hash key exists (even one holding an
	// invalid value), so callers' original attributes are never shadowed.
	alreadyTagged := false
	for i := 0; i+1 < len(attrs); i += 2 {
		if key, ok := attrs[i].(string); ok {
			if key == "tx_hash" || key == "transaction_hash" || key == "tx" {
				alreadyTagged = true
				break
			}
		}
	}
	if !alreadyTagged {
		attrs = append(attrs, "tx_hash", txHash)
	}
	return txHash, attrs
}
// extractTxHashFromAttrs scans alternating key/value attributes for a
// recognized hash key ("tx_hash", "transaction_hash", or "tx") whose string
// value is a valid hash, returning the first match lower-cased, or "".
func extractTxHashFromAttrs(attrs []interface{}) string {
	for i := 0; i+1 < len(attrs); i += 2 {
		key, ok := attrs[i].(string)
		if !ok || (key != "tx_hash" && key != "transaction_hash" && key != "tx") {
			continue
		}
		if value, ok := attrs[i+1].(string); ok && isValidTxHash(value) {
			return strings.ToLower(value)
		}
	}
	return ""
}

// extractTxHashFromError walks the Unwrap chain and returns the first
// embedded 0x-prefixed 64-hex-digit hash, lower-cased, or "".
func extractTxHashFromError(err error) string {
	for e := err; e != nil; e = errors.Unwrap(e) {
		if match := txHashPattern.FindString(e.Error()); match != "" {
			return strings.ToLower(match)
		}
	}
	return ""
}
func isValidTxHash(value string) bool {
if value == "" {
return false
}
if len(value) != 66 {
return false
}
if !strings.HasPrefix(value, "0x") {
return false
}
for _, r := range value[2:] {
if !isHexChar(r) {
return false
}
}
return true
}
func isHexChar(r rune) bool {
return (r >= '0' && r <= '9') ||
(r >= 'a' && r <= 'f') ||
(r >= 'A' && r <= 'F')
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,106 @@
package lifecycle
import (
"fmt"
"strings"
"testing"
"time"
)
// stubHealthNotifier is a test double for the health notifier: the first
// failUntil NotifyHealthChange calls fail (embedding txHash in the error
// text) and later calls succeed. Every attempt is counted.
type stubHealthNotifier struct {
	failUntil int    // number of leading NotifyHealthChange calls that fail
	attempts  int    // total NotifyHealthChange calls observed
	txHash    string // hash embedded in the synthetic failure message
}

// NotifyHealthChange fails while attempts <= failUntil, then succeeds.
func (s *stubHealthNotifier) NotifyHealthChange(moduleID string, oldHealth, newHealth ModuleHealth) error {
	s.attempts++
	if s.attempts <= s.failUntil {
		return fmt.Errorf("notify failure %d for tx %s", s.attempts, s.txHash)
	}
	return nil
}

// NotifySystemHealth always succeeds; system-health delivery is not under test.
func (s *stubHealthNotifier) NotifySystemHealth(health OverallHealth) error {
	return nil
}

// NotifyAlert always succeeds; alert delivery is not under test.
func (s *stubHealthNotifier) NotifyAlert(alert HealthAlert) error {
	return nil
}

// TestHealthMonitorNotifyWithRetrySuccess verifies that notifyWithRetry keeps
// retrying through transient failures (two failures, success on the third
// attempt within a three-retry budget) and records no notification errors
// once delivery eventually succeeds.
func TestHealthMonitorNotifyWithRetrySuccess(t *testing.T) {
	config := HealthMonitorConfig{
		NotificationRetries:    3,
		NotificationRetryDelay: time.Nanosecond, // keep retries effectively instant
	}
	hm := NewHealthMonitor(config)
	hm.logger = nil // silence retry logging in the unit under test
	notifier := &stubHealthNotifier{failUntil: 2, txHash: "0xdddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddd"}
	hm.notifier = notifier
	err := hm.notifyWithRetry(func() error {
		return notifier.NotifyHealthChange("module", ModuleHealth{}, ModuleHealth{})
	}, "notify failure", "module_id", "module")
	if err != nil {
		t.Fatalf("expected notification to eventually succeed, got %v", err)
	}
	// Two failures plus the final success = three attempts.
	if notifier.attempts != 3 {
		t.Fatalf("expected 3 attempts, got %d", notifier.attempts)
	}
	// A recovered notification must leave no recorded or aggregated errors.
	if errs := hm.NotificationErrors(); len(errs) != 0 {
		t.Fatalf("expected no recorded notification errors, got %d", len(errs))
	}
	if hm.aggregatedNotificationError() != nil {
		t.Fatal("expected aggregated notification error to be nil")
	}
}

// TestHealthMonitorNotifyWithRetryFailure verifies that an exhausted retry
// budget (two retries against a notifier that fails three times) surfaces an
// error, records it with its transaction hash, and exposes it through both
// the detailed and aggregated accessors.
func TestHealthMonitorNotifyWithRetryFailure(t *testing.T) {
	config := HealthMonitorConfig{
		NotificationRetries:    2,
		NotificationRetryDelay: time.Nanosecond, // keep retries effectively instant
	}
	hm := NewHealthMonitor(config)
	hm.logger = nil // silence retry logging in the unit under test
	txHash := "0xeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee"
	notifier := &stubHealthNotifier{failUntil: 3, txHash: txHash}
	hm.notifier = notifier
	err := hm.notifyWithRetry(func() error {
		return notifier.NotifyHealthChange("module", ModuleHealth{}, ModuleHealth{})
	}, "notify failure", "module_id", "module")
	if err == nil {
		t.Fatal("expected notification to fail after retries")
	}
	// The retry budget caps the attempts below the notifier's failure count.
	if notifier.attempts != 2 {
		t.Fatalf("expected 2 attempts, got %d", notifier.attempts)
	}
	exported := hm.NotificationErrors()
	if len(exported) != 1 {
		t.Fatalf("expected 1 recorded notification error, got %d", len(exported))
	}
	// NotificationErrors must return a usable copy on every call.
	copyErrs := hm.NotificationErrors()
	if copyErrs[0] == nil {
		t.Fatal("expected copy of notification errors to retain value")
	}
	if got := exported[0].Error(); !strings.Contains(got, txHash) {
		t.Fatalf("recorded notification error should include tx hash, got %q", got)
	}
	details := hm.NotificationErrorDetails()
	if len(details) != 1 {
		t.Fatalf("expected notification error details entry, got %d", len(details))
	}
	if details[0].TxHash != txHash {
		t.Fatalf("expected notification error detail to track tx hash %s, got %s", txHash, details[0].TxHash)
	}
	agg := hm.aggregatedNotificationError()
	if agg == nil {
		t.Fatal("expected aggregated notification error to be returned")
	}
	// The aggregate must carry the wrapper message, the first underlying
	// failure, and the transaction hash.
	if got := agg.Error(); !strings.Contains(got, "notify failure") || !strings.Contains(got, "notify failure 1") || !strings.Contains(got, txHash) {
		t.Fatalf("aggregated notification error should include failure details and tx hash, got %q", got)
	}
}

View File

@@ -0,0 +1,417 @@
package lifecycle
import (
"context"
"fmt"
"sync"
"time"
)
// BaseModule provides a default implementation of the Module interface:
// plain state/health/metrics storage with simple lifecycle transitions.
//
// NOTE(review): fields are read and written without any synchronization, so
// a BaseModule must be confined to one goroutine or externally serialized.
type BaseModule struct {
	id           string
	name         string
	version      string
	dependencies []string
	state        ModuleState
	health       ModuleHealth
	metrics      ModuleMetrics
	config       ModuleConfig
}

// NewBaseModule creates a base module in the uninitialized state, with
// unknown health and an empty custom-metrics map.
func NewBaseModule(id, name, version string, dependencies []string) *BaseModule {
	return &BaseModule{
		id:           id,
		name:         name,
		version:      version,
		dependencies: dependencies,
		state:        StateUninitialized,
		health: ModuleHealth{
			Status: HealthUnknown,
		},
		metrics: ModuleMetrics{
			CustomMetrics: make(map[string]interface{}),
		},
	}
}
// Core lifecycle methods

// Initialize stores the configuration and moves the module to the
// initialized/healthy state. The base implementation never fails.
func (bm *BaseModule) Initialize(ctx context.Context, config ModuleConfig) error {
	bm.config = config
	bm.state = StateInitialized
	bm.health.Status = HealthHealthy
	return nil
}

// Start transitions Initialized/Stopped -> Running and refreshes activity.
//
// NOTE(review): StartupTime is time.Since of a timestamp taken on the
// immediately preceding line, so it always records ~0ns; presumably it was
// meant to time real startup work in subclasses — confirm intent.
func (bm *BaseModule) Start(ctx context.Context) error {
	if bm.state != StateInitialized && bm.state != StateStopped {
		return fmt.Errorf("invalid state for start: %s", bm.state)
	}
	startTime := time.Now()
	bm.state = StateRunning
	bm.metrics.StartupTime = time.Since(startTime)
	bm.metrics.LastActivity = time.Now()
	bm.health.Status = HealthHealthy
	return nil
}

// Stop transitions Running/Paused -> Stopped and resets health to unknown.
//
// NOTE(review): ShutdownTime has the same ~0ns measurement issue as
// StartupTime in Start.
func (bm *BaseModule) Stop(ctx context.Context) error {
	if bm.state != StateRunning && bm.state != StatePaused {
		return fmt.Errorf("invalid state for stop: %s", bm.state)
	}
	stopTime := time.Now()
	bm.state = StateStopped
	bm.metrics.ShutdownTime = time.Since(stopTime)
	bm.health.Status = HealthUnknown
	return nil
}

// Pause transitions Running -> Paused.
func (bm *BaseModule) Pause(ctx context.Context) error {
	if bm.state != StateRunning {
		return fmt.Errorf("invalid state for pause: %s", bm.state)
	}
	bm.state = StatePaused
	return nil
}

// Resume transitions Paused -> Running and refreshes the activity timestamp.
func (bm *BaseModule) Resume(ctx context.Context) error {
	if bm.state != StatePaused {
		return fmt.Errorf("invalid state for resume: %s", bm.state)
	}
	bm.state = StateRunning
	bm.metrics.LastActivity = time.Now()
	return nil
}
// Module information

// GetID returns the module identifier.
func (bm *BaseModule) GetID() string {
	return bm.id
}

// GetName returns the human-readable module name.
func (bm *BaseModule) GetName() string {
	return bm.name
}

// GetVersion returns the module version string.
func (bm *BaseModule) GetVersion() string {
	return bm.version
}

// GetDependencies returns the IDs of modules this module depends on.
// NOTE(review): the internal slice is returned directly, so callers can
// mutate the module's dependency list; copy if isolation matters.
func (bm *BaseModule) GetDependencies() []string {
	return bm.dependencies
}

// Health and status

// GetHealth returns the current health snapshot.
// NOTE(review): this getter mutates state — it stamps LastCheck with the
// current time on every call.
func (bm *BaseModule) GetHealth() ModuleHealth {
	bm.health.LastCheck = time.Now()
	return bm.health
}

// GetState returns the current lifecycle state.
func (bm *BaseModule) GetState() ModuleState {
	return bm.state
}

// GetMetrics returns the current metrics struct (the CustomMetrics map
// inside it is shared, not copied).
func (bm *BaseModule) GetMetrics() ModuleMetrics {
	return bm.metrics
}

// Protected methods for subclasses

// SetHealth updates the health status and message and stamps LastCheck.
func (bm *BaseModule) SetHealth(status HealthStatus, message string) {
	bm.health.Status = status
	bm.health.Message = message
	bm.health.LastCheck = time.Now()
}

// SetState forces the lifecycle state without transition validation.
func (bm *BaseModule) SetState(state ModuleState) {
	bm.state = state
}

// UpdateMetrics merges the given key/value pairs into the custom metrics and
// refreshes the activity timestamp.
func (bm *BaseModule) UpdateMetrics(updates map[string]interface{}) {
	for key, value := range updates {
		bm.metrics.CustomMetrics[key] = value
	}
	bm.metrics.LastActivity = time.Now()
}

// IncrementMetric adds value to the named int64 counter. A missing counter —
// or one whose stored value is not an int64 — is (re)set to value.
func (bm *BaseModule) IncrementMetric(name string, value int64) {
	if current, exists := bm.metrics.CustomMetrics[name]; exists {
		if currentVal, ok := current.(int64); ok {
			bm.metrics.CustomMetrics[name] = currentVal + value
		} else {
			// Non-int64 values cannot be accumulated; restart the counter.
			bm.metrics.CustomMetrics[name] = value
		}
	} else {
		bm.metrics.CustomMetrics[name] = value
	}
}
// SimpleEventBus is a minimal EventBus implementation: handlers are grouped
// per event type and invoked asynchronously on publish.
type SimpleEventBus struct {
	handlers map[EventType][]EventHandler
	mu       sync.RWMutex // guards the handlers map
}

// NewSimpleEventBus constructs an event bus with no subscriptions.
func NewSimpleEventBus() *SimpleEventBus {
	return &SimpleEventBus{handlers: map[EventType][]EventHandler{}}
}

// Publish fans the event out to every handler subscribed to its type. Each
// handler runs in its own goroutine; handler errors are deliberately dropped
// so one failing subscriber cannot affect the publish or its siblings.
func (seb *SimpleEventBus) Publish(event ModuleEvent) error {
	seb.mu.RLock()
	defer seb.mu.RUnlock()
	for _, handler := range seb.handlers[event.Type] {
		go func(h EventHandler) {
			// Best-effort delivery: the handler's error is intentionally
			// ignored (there is no logger available here).
			_ = h(event)
		}(handler)
	}
	return nil
}

// Subscribe registers handler for the given event type. Appending to the nil
// slice of an absent map entry creates it, so no existence check is needed.
func (seb *SimpleEventBus) Subscribe(eventType EventType, handler EventHandler) error {
	seb.mu.Lock()
	defer seb.mu.Unlock()
	seb.handlers[eventType] = append(seb.handlers[eventType], handler)
	return nil
}
// LifecycleManager coordinates all lifecycle components: module registry,
// health monitoring, graceful shutdown, dependency injection, and eventing.
type LifecycleManager struct {
	registry        *ModuleRegistry    // always present
	healthMonitor   *HealthMonitorImpl // nil unless EnableHealthMonitor
	shutdownManager *ShutdownManager   // nil unless EnableShutdownManager
	container       *Container         // nil unless EnableDependencyInjection
	eventBus        *SimpleEventBus    // nil unless EnableEventBus
	config          LifecycleConfig
	mu              sync.RWMutex
}

// LifecycleConfig configures the lifecycle manager and each sub-component it
// owns; the Enable* flags control which optional components are created.
type LifecycleConfig struct {
	RegistryConfig            RegistryConfig      `json:"registry_config"`
	HealthMonitorConfig       HealthMonitorConfig `json:"health_monitor_config"`
	ShutdownConfig            ShutdownConfig      `json:"shutdown_config"`
	ContainerConfig           ContainerConfig     `json:"container_config"`
	EnableEventBus            bool                `json:"enable_event_bus"`
	EnableHealthMonitor       bool                `json:"enable_health_monitor"`
	EnableShutdownManager     bool                `json:"enable_shutdown_manager"`
	EnableDependencyInjection bool                `json:"enable_dependency_injection"`
}

// NewLifecycleManager creates a new lifecycle manager. Components are built
// in dependency order: the optional event bus and DI container first, then
// the registry (always created, wired to the event bus if present), then the
// health monitor and shutdown manager that hook into the registry.
func NewLifecycleManager(config LifecycleConfig) *LifecycleManager {
	lm := &LifecycleManager{
		config: config,
	}
	// Create event bus
	if config.EnableEventBus {
		lm.eventBus = NewSimpleEventBus()
	}
	// Create dependency injection container
	if config.EnableDependencyInjection {
		lm.container = NewContainer(config.ContainerConfig)
	}
	// Create module registry
	lm.registry = NewModuleRegistry(config.RegistryConfig)
	if lm.eventBus != nil {
		lm.registry.SetEventBus(lm.eventBus)
	}
	// Create health monitor
	if config.EnableHealthMonitor {
		lm.healthMonitor = NewHealthMonitor(config.HealthMonitorConfig)
		lm.registry.SetHealthMonitor(lm.healthMonitor)
	}
	// Create shutdown manager
	if config.EnableShutdownManager {
		lm.shutdownManager = NewShutdownManager(lm.registry, config.ShutdownConfig)
	}
	return lm
}
// Initialize initializes the lifecycle manager
func (lm *LifecycleManager) Initialize(ctx context.Context) error {
lm.mu.Lock()
defer lm.mu.Unlock()
// Validate container if enabled
if lm.container != nil {
if err := lm.container.Validate(); err != nil {
return fmt.Errorf("container validation failed: %w", err)
}
}
// Initialize registry
if err := lm.registry.Initialize(ctx); err != nil {
return fmt.Errorf("registry initialization failed: %w", err)
}
// Start health monitor
if lm.healthMonitor != nil {
if err := lm.healthMonitor.Start(); err != nil {
return fmt.Errorf("health monitor start failed: %w", err)
}
}
// Start shutdown manager
if lm.shutdownManager != nil {
if err := lm.shutdownManager.Start(); err != nil {
return fmt.Errorf("shutdown manager start failed: %w", err)
}
}
return nil
}
// Start starts all modules
func (lm *LifecycleManager) Start(ctx context.Context) error {
return lm.registry.StartAll(ctx)
}
// Stop stops all modules
func (lm *LifecycleManager) Stop(ctx context.Context) error {
return lm.registry.StopAll(ctx)
}
// Shutdown gracefully shuts down the entire system
func (lm *LifecycleManager) Shutdown(ctx context.Context) error {
if lm.shutdownManager != nil {
return lm.shutdownManager.Shutdown(ctx)
}
return lm.registry.Shutdown(ctx)
}
// RegisterModule registers a new module
func (lm *LifecycleManager) RegisterModule(module Module, config ModuleConfig) error {
return lm.registry.Register(module, config)
}
// GetModule retrieves a module by ID
func (lm *LifecycleManager) GetModule(moduleID string) (Module, error) {
return lm.registry.Get(moduleID)
}
// GetRegistry returns the module registry
func (lm *LifecycleManager) GetRegistry() *ModuleRegistry {
lm.mu.RLock()
defer lm.mu.RUnlock()
return lm.registry
}
// GetHealthMonitor returns the health monitor (nil when not enabled).
func (lm *LifecycleManager) GetHealthMonitor() *HealthMonitorImpl {
	lm.mu.RLock()
	hm := lm.healthMonitor
	lm.mu.RUnlock()
	return hm
}
// GetShutdownManager returns the shutdown manager (nil when not enabled).
func (lm *LifecycleManager) GetShutdownManager() *ShutdownManager {
	lm.mu.RLock()
	sd := lm.shutdownManager
	lm.mu.RUnlock()
	return sd
}
// GetContainer returns the dependency injection container (nil when DI is
// not enabled).
func (lm *LifecycleManager) GetContainer() *Container {
	lm.mu.RLock()
	c := lm.container
	lm.mu.RUnlock()
	return c
}
// GetEventBus returns the event bus (nil when not enabled).
func (lm *LifecycleManager) GetEventBus() *SimpleEventBus {
	lm.mu.RLock()
	bus := lm.eventBus
	lm.mu.RUnlock()
	return bus
}
// GetOverallHealth returns the overall system health.
//
// Fails with an error when the health monitor was not enabled in the
// lifecycle configuration.
func (lm *LifecycleManager) GetOverallHealth() (OverallHealth, error) {
	hm := lm.healthMonitor
	if hm == nil {
		return OverallHealth{}, fmt.Errorf("health monitor not enabled")
	}
	return hm.GetOverallHealth(), nil
}
// CreateDefaultConfig creates a default lifecycle configuration.
//
// All optional subsystems (event bus, health monitor, shutdown manager,
// dependency injection) are enabled, with conservative timeouts and retry
// limits for each component.
func CreateDefaultConfig() LifecycleConfig {
	registry := RegistryConfig{
		StartTimeout:        30 * time.Second,
		StopTimeout:         15 * time.Second,
		HealthCheckInterval: 30 * time.Second,
		EnableMetrics:       true,
		EnableHealthMonitor: true,
		ParallelStartup:     false, // start sequentially, stop in parallel
		ParallelShutdown:    true,
		FailureRecovery:     true,
		AutoRestart:         true,
		MaxRestartAttempts:  3,
	}
	health := HealthMonitorConfig{
		CheckInterval:       30 * time.Second,
		CheckTimeout:        10 * time.Second,
		HistorySize:         100,
		FailureThreshold:    3,
		RecoveryThreshold:   3,
		EnableNotifications: true,
		EnableMetrics:       true,
		EnableTrends:        true,
		ParallelChecks:      true,
		MaxConcurrentChecks: 10,
	}
	shutdown := ShutdownConfig{
		GracefulTimeout:    30 * time.Second,
		ForceTimeout:       60 * time.Second,
		SignalBufferSize:   10,
		MaxRetries:         3,
		RetryDelay:         time.Second,
		ParallelShutdown:   true,
		SaveState:          true,
		CleanupTempFiles:   true,
		NotifyExternal:     false,
		WaitForConnections: true,
		EnableMetrics:      true,
	}
	container := ContainerConfig{
		EnableReflection:        true,
		EnableCircularDetection: true,
		EnableInterception:      false,
		EnableValidation:        true,
		MaxDepth:                10,
		CacheInstances:          true,
	}
	return LifecycleConfig{
		RegistryConfig:            registry,
		HealthMonitorConfig:       health,
		ShutdownConfig:            shutdown,
		ContainerConfig:           container,
		EnableEventBus:            true,
		EnableHealthMonitor:       true,
		EnableShutdownManager:     true,
		EnableDependencyInjection: true,
	}
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,96 @@
package lifecycle
import (
"fmt"
"strings"
"testing"
"time"
)
// stubEventBus is a test double for the registry's event bus. It fails the
// first failUntil Publish calls (embedding txHash in each synthetic error
// message) and succeeds afterwards, counting every attempt.
type stubEventBus struct {
	failUntil int    // number of leading Publish calls that must fail
	attempts  int    // total Publish invocations observed
	txHash    string // hash embedded in synthetic failure messages
}

// Publish increments the attempt counter and fails until failUntil
// attempts have been made.
func (s *stubEventBus) Publish(event ModuleEvent) error {
	s.attempts++
	if s.attempts <= s.failUntil {
		return fmt.Errorf("publish failure %d for tx %s", s.attempts, s.txHash)
	}
	return nil
}

// Subscribe is a no-op; the retry tests only exercise Publish.
func (s *stubEventBus) Subscribe(eventType EventType, handler EventHandler) error {
	return nil
}
// TestModuleRegistryPublishEventWithRetrySuccess verifies that publishing
// retries through transient bus failures, stops as soon as an attempt
// succeeds, and records no aggregated registry errors on success.
func TestModuleRegistryPublishEventWithRetrySuccess(t *testing.T) {
	registry := NewModuleRegistry(RegistryConfig{
		EventPublishRetries: 3,
		EventPublishDelay:   time.Nanosecond,
	})
	registry.logger = nil

	// Fail the first two attempts; the third should succeed.
	bus := &stubEventBus{
		failUntil: 2,
		txHash:    "0xbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb",
	}
	registry.eventBus = bus

	event := ModuleEvent{ModuleID: "module-A", Type: EventModuleStarted}
	if err := registry.publishEventWithRetry(event, "publish failed"); err != nil {
		t.Fatalf("expected publish to eventually succeed, got %v", err)
	}
	if bus.attempts != 3 {
		t.Fatalf("expected 3 attempts, got %d", bus.attempts)
	}
	if err := registry.aggregatedErrors(); err != nil {
		t.Fatalf("expected no aggregated errors, got %v", err)
	}
}
// TestModuleRegistryPublishEventWithRetryFailure verifies that publishing
// gives up after the configured attempt budget, records a tx-hash-annotated
// registry error (exported as a defensive copy), and surfaces the failure
// details in the aggregated error.
func TestModuleRegistryPublishEventWithRetryFailure(t *testing.T) {
	registry := NewModuleRegistry(RegistryConfig{
		EventPublishRetries: 2,
		EventPublishDelay:   time.Nanosecond,
	})
	registry.logger = nil

	const txHash = "0xcccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccc"
	// The bus fails three times but only two attempts are allowed.
	bus := &stubEventBus{failUntil: 3, txHash: txHash}
	registry.eventBus = bus

	event := ModuleEvent{ModuleID: "module-B", Type: EventModuleStopped}
	err := registry.publishEventWithRetry(event, "publish failed")
	if err == nil {
		t.Fatal("expected publish to fail after retries")
	}
	if bus.attempts != 2 {
		t.Fatalf("expected 2 attempts, got %d", bus.attempts)
	}

	exported := registry.RegistryErrors()
	if len(exported) != 1 {
		t.Fatalf("expected 1 recorded registry error, got %d", len(exported))
	}
	if exported[0] == nil {
		t.Fatal("expected recorded error to be non-nil")
	}
	if copyErrs := registry.RegistryErrors(); copyErrs[0] == nil {
		t.Fatal("copy of registry errors should preserve values")
	}
	if msg := exported[0].Error(); !strings.Contains(msg, txHash) {
		t.Fatalf("recorded registry error should include tx hash, got %q", msg)
	}

	details := registry.RegistryErrorDetails()
	if len(details) != 1 {
		t.Fatalf("expected registry error details to include entry, got %d", len(details))
	}
	if details[0].TxHash != txHash {
		t.Fatalf("expected registry error detail to track tx hash %s, got %s", txHash, details[0].TxHash)
	}

	agg := registry.aggregatedErrors()
	if agg == nil {
		t.Fatal("expected aggregated error to be returned")
	}
	aggMsg := agg.Error()
	if !strings.Contains(aggMsg, "publish failed") || !strings.Contains(aggMsg, "publish failure 1") || !strings.Contains(aggMsg, txHash) {
		t.Fatalf("aggregated error should include failure details and tx hash, got %q", aggMsg)
	}
}

View File

@@ -0,0 +1,876 @@
package lifecycle
import (
	"context"
	"errors"
	"fmt"
	"os"
	"os/signal"
	"sort"
	"sync"
	"syscall"
	"time"

	"github.com/fraktal/mev-beta/internal/logger"
)
// ShutdownManager handles graceful shutdown of the application.
//
// It listens for OS signals, stops registered modules, runs prioritized
// cleanup tasks, invokes phase hooks, and aggregates every error observed
// during shutdown so callers receive a single combined error.
type ShutdownManager struct {
	registry             *ModuleRegistry // modules to stop during shutdown (nil allowed, e.g. in tests)
	shutdownTasks        []ShutdownTask  // cleanup tasks, kept in descending-priority order
	shutdownHooks        []ShutdownHook  // phase callbacks invoked via callHooks
	config               ShutdownConfig
	signalChannel        chan os.Signal // receives SIGINT/SIGTERM/SIGQUIT/SIGHUP (see setupSignalHandling)
	shutdownChannel      chan struct{}  // closed once Shutdown is initiated; Wait blocks on it
	state                ShutdownState
	startTime            time.Time // manager creation time
	shutdownStarted      time.Time // set when Shutdown begins; basis for progress estimates
	mu                   sync.RWMutex
	wg                   sync.WaitGroup // tracks the in-flight performShutdown
	ctx                  context.Context
	cancel               context.CancelFunc // cancelled by ForceShutdown to stop the signal loop
	logger               *logger.Logger     // may be nil (tests set it to nil); guard before use
	shutdownErrors       []error            // every error recorded during shutdown
	shutdownErrorDetails []RecordedError    // recorded errors with extracted tx-hash metadata
	errMu                sync.Mutex         // guards shutdownErrors / shutdownErrorDetails
	exitFunc             func(code int)     // os.Exit by default; injectable for tests
	emergencyHandler     func(ctx context.Context, reason string, err error) error // optional last-resort callback
}
// ShutdownTask represents a task to be executed during shutdown.
//
// Tasks run in descending Priority order. A Critical task's failure aborts
// the sequential task phase; non-critical failures are recorded and
// execution continues. Disabled tasks are skipped entirely.
type ShutdownTask struct {
	Name     string                          // identifier used in logs and error messages
	Priority int                             // higher priority runs earlier
	Timeout  time.Duration                   // per-task deadline (covers all retry attempts)
	Task     func(ctx context.Context) error // the work itself
	OnError  func(error)                     // optional callback invoked on each failed attempt
	Critical bool                            // failure aborts the task phase
	Enabled  bool                            // disabled tasks are skipped
}

// ShutdownHook is called at different stages of shutdown.
//
// OnShutdownFailed additionally receives the aggregated cause of the
// failure. Hook errors are recorded and joined but do not stop the other
// hooks from running.
type ShutdownHook interface {
	OnShutdownStarted(ctx context.Context) error
	OnModulesStopped(ctx context.Context) error
	OnCleanupStarted(ctx context.Context) error
	OnShutdownCompleted(ctx context.Context) error
	OnShutdownFailed(ctx context.Context, err error) error
}
// ShutdownConfig configures shutdown behavior.
type ShutdownConfig struct {
	GracefulTimeout    time.Duration `json:"graceful_timeout"`     // budget for the orderly shutdown path (default 30s)
	ForceTimeout       time.Duration `json:"force_timeout"`        // budget for forced shutdown (default 60s)
	SignalBufferSize   int           `json:"signal_buffer_size"`   // capacity of the OS signal channel (default 10)
	MaxRetries         int           `json:"max_retries"`          // extra attempts per shutdown task (default 3)
	RetryDelay         time.Duration `json:"retry_delay"`          // pause between task retries (default 1s)
	ParallelShutdown   bool          `json:"parallel_shutdown"`    // run same-priority tasks concurrently
	SaveState          bool          `json:"save_state"`           // enables the save_state task
	CleanupTempFiles   bool          `json:"cleanup_temp_files"`   // enables the cleanup_temp_files task
	NotifyExternal     bool          `json:"notify_external"`      // enables the notify_external task
	WaitForConnections bool          `json:"wait_for_connections"` // enables the close_connections task
	EnableMetrics      bool          `json:"enable_metrics"`       // not read in this file — presumably consumed elsewhere; confirm
}

// ShutdownState represents the current shutdown state.
type ShutdownState string

const (
	ShutdownStateRunning    ShutdownState = "running"          // normal operation, no shutdown requested
	ShutdownStateInitiated  ShutdownState = "initiated"        // Shutdown called, phases not yet started
	ShutdownStateModuleStop ShutdownState = "stopping_modules" // modules being stopped
	ShutdownStateCleanup    ShutdownState = "cleanup"          // cleanup hooks and shutdown tasks running
	ShutdownStateCompleted  ShutdownState = "completed"        // shutdown finished without phase errors
	ShutdownStateFailed     ShutdownState = "failed"           // one or more phases failed
	ShutdownStateForced     ShutdownState = "forced"           // ForceShutdown was invoked
)
// ShutdownMetrics tracks shutdown performance.
//
// NOTE(review): this struct is declared but never populated in this file —
// presumably filled elsewhere or planned; confirm before relying on it.
type ShutdownMetrics struct {
	ShutdownInitiated time.Time     `json:"shutdown_initiated"`
	ModuleStopTime    time.Duration `json:"module_stop_time"`
	CleanupTime       time.Duration `json:"cleanup_time"`
	TotalShutdownTime time.Duration `json:"total_shutdown_time"`
	TasksExecuted     int           `json:"tasks_executed"`
	TasksSuccessful   int           `json:"tasks_successful"`
	TasksFailed       int           `json:"tasks_failed"`
	RetryAttempts     int           `json:"retry_attempts"`
	ForceShutdown     bool          `json:"force_shutdown"`
	Signal            string        `json:"signal"`
}

// ShutdownProgress tracks shutdown progress.
//
// Produced by GetProgress. Progress is a coarse per-state estimate and
// EstimatedRemaining is a linear extrapolation from elapsed time.
type ShutdownProgress struct {
	State              ShutdownState `json:"state"`
	Progress           float64       `json:"progress"` // 0.0 .. 1.0
	CurrentTask        string        `json:"current_task"`
	CompletedTasks     int           `json:"completed_tasks"`
	TotalTasks         int           `json:"total_tasks"`
	ElapsedTime        time.Duration `json:"elapsed_time"`
	EstimatedRemaining time.Duration `json:"estimated_remaining"`
	Message            string        `json:"message"`
}
// NewShutdownManager creates a new shutdown manager.
//
// Zero-valued config fields are replaced with defaults BEFORE the manager is
// constructed. Fix: previously the signal channel was created from
// config.SignalBufferSize prior to defaulting, so relying on the default
// produced an unbuffered channel — signal.Notify does not block on a full
// channel, meaning signals could be silently dropped.
func NewShutdownManager(registry *ModuleRegistry, config ShutdownConfig) *ShutdownManager {
	// Apply defaults first so every subsequent use of config sees them.
	if config.GracefulTimeout == 0 {
		config.GracefulTimeout = 30 * time.Second
	}
	if config.ForceTimeout == 0 {
		config.ForceTimeout = 60 * time.Second
	}
	if config.SignalBufferSize == 0 {
		config.SignalBufferSize = 10
	}
	if config.MaxRetries == 0 {
		config.MaxRetries = 3
	}
	if config.RetryDelay == 0 {
		config.RetryDelay = time.Second
	}

	ctx, cancel := context.WithCancel(context.Background())
	sm := &ShutdownManager{
		registry:             registry,
		shutdownTasks:        make([]ShutdownTask, 0),
		shutdownHooks:        make([]ShutdownHook, 0),
		config:               config,
		signalChannel:        make(chan os.Signal, config.SignalBufferSize),
		shutdownChannel:      make(chan struct{}),
		state:                ShutdownStateRunning,
		startTime:            time.Now(),
		ctx:                  ctx,
		cancel:               cancel,
		shutdownErrors:       make([]error, 0),
		shutdownErrorDetails: make([]RecordedError, 0),
		exitFunc:             os.Exit,
	}
	// Best-effort: the logger below writes into logs/, ensure it exists.
	if err := os.MkdirAll("logs", 0o755); err != nil {
		fmt.Printf("failed to ensure logs directory: %v\n", err)
	}
	sm.logger = logger.New("info", "", "logs/lifecycle_shutdown.log")
	// Register built-in cleanup tasks and wire OS signal delivery.
	sm.setupDefaultTasks()
	sm.setupSignalHandling()
	return sm
}
// Start starts the shutdown manager's OS signal monitoring.
//
// It may only be called while the manager is still in the running state;
// otherwise an error describing the current state is returned.
func (sm *ShutdownManager) Start() error {
	sm.mu.Lock()
	defer sm.mu.Unlock()
	if sm.state != ShutdownStateRunning {
		return fmt.Errorf("shutdown manager not in running state: %s", sm.state)
	}
	// The signal loop exits on a terminating signal or when sm.ctx is cancelled.
	go sm.signalHandler()
	return nil
}
// Shutdown initiates graceful shutdown.
//
// The state transition happens under sm.mu, but the mutex is released before
// performShutdown runs. Fix: previously sm.mu was held for the entire
// shutdown, which (a) blocked GetState/GetProgress (both RLock) for the whole
// duration and (b) self-deadlocked when performShutdown escalated into
// triggerEmergencyShutdown, which acquires sm.mu.Lock itself.
//
// Any errors recorded during shutdown take precedence over the raw
// performShutdown result.
func (sm *ShutdownManager) Shutdown(ctx context.Context) error {
	sm.mu.Lock()
	if sm.state != ShutdownStateRunning {
		state := sm.state
		sm.mu.Unlock()
		return fmt.Errorf("shutdown already initiated: %s", state)
	}
	sm.state = ShutdownStateInitiated
	sm.shutdownStarted = time.Now()
	// Close shutdown channel to signal shutdown (unblocks Wait callers).
	close(sm.shutdownChannel)
	sm.mu.Unlock()

	err := sm.performShutdown(ctx)
	if combined := sm.combinedShutdownError(); combined != nil {
		return combined
	}
	return err
}
// ForceShutdown forces immediate shutdown.
//
// It cancels all in-flight operations, stops all modules with a short 5s
// deadline, records any failures, and then invokes exitFunc(1) — os.Exit in
// production, injectable in tests. The trailing return is therefore only
// reachable when exitFunc does not terminate the process.
func (sm *ShutdownManager) ForceShutdown(ctx context.Context) error {
	sm.mu.Lock()
	defer sm.mu.Unlock()
	sm.state = ShutdownStateForced
	sm.cancel() // Cancel all operations
	// Force stop all modules immediately
	var forceErr error
	if sm.registry != nil {
		// A fresh background context is used: the caller's ctx may already
		// be cancelled, and StopAll still deserves a bounded attempt.
		forceCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
		defer cancel()
		if err := sm.registry.StopAll(forceCtx); err != nil {
			wrapped := fmt.Errorf("StopAll failed during force shutdown: %w", err)
			sm.recordShutdownError("StopAll failed in force shutdown", wrapped)
			forceErr = errors.Join(forceErr, wrapped)
		}
	}
	if forceErr != nil {
		// NOTE(review): this records the already-recorded StopAll error a
		// second time, producing a duplicate entry — confirm intent.
		sm.recordShutdownError("Force shutdown encountered errors", forceErr)
	}
	sm.exitFunc(1)
	return forceErr
}
// AddShutdownTask adds a task to be executed during shutdown and re-sorts
// the task list so descending-priority order is maintained.
func (sm *ShutdownManager) AddShutdownTask(task ShutdownTask) {
	sm.mu.Lock()
	sm.shutdownTasks = append(sm.shutdownTasks, task)
	sm.sortTasksByPriority()
	sm.mu.Unlock()
}
// AddShutdownHook adds a hook to be called during shutdown phases.
func (sm *ShutdownManager) AddShutdownHook(hook ShutdownHook) {
	sm.mu.Lock()
	sm.shutdownHooks = append(sm.shutdownHooks, hook)
	sm.mu.Unlock()
}
// GetState returns the current shutdown state.
func (sm *ShutdownManager) GetState() ShutdownState {
	sm.mu.RLock()
	state := sm.state
	sm.mu.RUnlock()
	return state
}
// GetProgress returns a snapshot of shutdown progress.
//
// The progress percentage and completed-task counts are fixed per-state
// estimates (not measured); the remaining-time estimate is a linear
// extrapolation from the time elapsed since shutdown started.
func (sm *ShutdownManager) GetProgress() ShutdownProgress {
	sm.mu.RLock()
	defer sm.mu.RUnlock()

	total := len(sm.shutdownTasks)
	if sm.registry != nil {
		total += len(sm.registry.List())
	}

	var (
		pct       float64
		completed int
		current   string
	)
	switch sm.state {
	case ShutdownStateRunning:
		pct, current = 0, "Running"
	case ShutdownStateInitiated:
		pct, current = 0.1, "Shutdown initiated"
	case ShutdownStateModuleStop:
		pct, current = 0.3, "Stopping modules"
		completed = total / 3
	case ShutdownStateCleanup:
		pct, current = 0.7, "Cleanup"
		completed = (total * 2) / 3
	case ShutdownStateCompleted:
		pct, current = 1.0, "Completed"
		completed = total
	case ShutdownStateFailed:
		pct, current = 0.8, "Failed"
	case ShutdownStateForced:
		pct, current = 1.0, "Forced shutdown"
		completed = total
	}

	elapsed := time.Since(sm.shutdownStarted)
	var remaining time.Duration
	if pct > 0 && pct < 1.0 {
		// Linear extrapolation: total ≈ elapsed / fraction-done.
		remaining = time.Duration(float64(elapsed)/pct) - elapsed
	}

	return ShutdownProgress{
		State:              sm.state,
		Progress:           pct,
		CurrentTask:        current,
		CompletedTasks:     completed,
		TotalTasks:         total,
		ElapsedTime:        elapsed,
		EstimatedRemaining: remaining,
		Message:            fmt.Sprintf("Shutdown %s", sm.state),
	}
}
// Wait waits for shutdown to complete.
//
// It blocks until shutdownChannel is closed (shutdown initiated) and then
// until the performShutdown work tracked by wg has finished.
func (sm *ShutdownManager) Wait() {
	<-sm.shutdownChannel
	sm.wg.Wait()
}
// WaitWithTimeout waits for shutdown to complete, giving up after the
// supplied timeout. The waiting goroutine is leaked on timeout (it finishes
// harmlessly once shutdown eventually completes).
func (sm *ShutdownManager) WaitWithTimeout(timeout time.Duration) error {
	finished := make(chan struct{})
	go func() {
		defer close(finished)
		sm.Wait()
	}()
	select {
	case <-finished:
		return nil
	case <-time.After(timeout):
		return fmt.Errorf("shutdown timeout after %v", timeout)
	}
}
// Private methods
// setupSignalHandling forwards the standard termination and reload signals
// to the manager's signal channel.
func (sm *ShutdownManager) setupSignalHandling() {
	watched := []os.Signal{
		syscall.SIGINT,
		syscall.SIGTERM,
		syscall.SIGQUIT,
		syscall.SIGHUP,
	}
	signal.Notify(sm.signalChannel, watched...)
}
// setupDefaultTasks registers the built-in shutdown tasks (state saving,
// connection draining, external notification, temp-file cleanup) and sorts
// them into descending-priority order.
//
// Fix: the original appended tasks out of priority order and never sorted
// them — sortTasksByPriority was only invoked from AddShutdownTask — so the
// sequential executor ran notify_external (priority 80) AFTER
// cleanup_temp_files (priority 10). Sorting here restores the priority
// ordering the Priority field promises.
func (sm *ShutdownManager) setupDefaultTasks() {
	sm.shutdownTasks = append(sm.shutdownTasks,
		// Save application state before anything else is torn down.
		ShutdownTask{
			Name:     "save_state",
			Priority: 100,
			Timeout:  10 * time.Second,
			Task: func(ctx context.Context) error {
				if sm.config.SaveState {
					return sm.saveApplicationState(ctx)
				}
				return nil
			},
			Critical: false,
			Enabled:  sm.config.SaveState,
		},
		// Drain/close external connections early.
		ShutdownTask{
			Name:     "close_connections",
			Priority: 90,
			Timeout:  5 * time.Second,
			Task: func(ctx context.Context) error {
				return sm.closeExternalConnections(ctx)
			},
			Critical: false,
			Enabled:  sm.config.WaitForConnections,
		},
		// Tell external systems we are going away.
		ShutdownTask{
			Name:     "notify_external",
			Priority: 80,
			Timeout:  3 * time.Second,
			Task: func(ctx context.Context) error {
				if sm.config.NotifyExternal {
					return sm.notifyExternalSystems(ctx)
				}
				return nil
			},
			Critical: false,
			Enabled:  sm.config.NotifyExternal,
		},
		// Local temp-file cleanup runs last.
		ShutdownTask{
			Name:     "cleanup_temp_files",
			Priority: 10,
			Timeout:  5 * time.Second,
			Task: func(ctx context.Context) error {
				if sm.config.CleanupTempFiles {
					return sm.cleanupTempFiles(ctx)
				}
				return nil
			},
			Critical: false,
			Enabled:  sm.config.CleanupTempFiles,
		},
	)
	sm.sortTasksByPriority()
}
// signalHandler is the signal loop started by Start.
//
// SIGINT/SIGTERM trigger a graceful shutdown, escalating to ForceShutdown
// (and then to emergency protocols) if that fails. SIGQUIT forces shutdown
// immediately. SIGHUP is currently ignored (placeholder for config reload).
// The loop also exits when sm.ctx is cancelled (e.g. by ForceShutdown).
func (sm *ShutdownManager) signalHandler() {
	for {
		select {
		case sig := <-sm.signalChannel:
			switch sig {
			case syscall.SIGINT, syscall.SIGTERM:
				// Graceful shutdown
				ctx, cancel := context.WithTimeout(context.Background(), sm.config.GracefulTimeout)
				if err := sm.Shutdown(ctx); err != nil {
					sm.recordShutdownError("Graceful shutdown failed from signal", err)
					cancel()
					// Force shutdown if graceful fails
					forceCtx, forceCancel := context.WithTimeout(context.Background(), sm.config.ForceTimeout)
					if err := sm.ForceShutdown(forceCtx); err != nil {
						sm.recordShutdownError("Force shutdown error in timeout scenario", err)
						// CRITICAL FIX: Escalate force shutdown failure to emergency protocols
						sm.triggerEmergencyShutdown("Force shutdown failed after graceful timeout", err)
					}
					forceCancel()
				}
				// NOTE(review): cancel() runs twice on the failure path
				// (once above, once here) — harmless but redundant; confirm intent.
				cancel()
				return
			case syscall.SIGQUIT:
				// Force shutdown
				ctx, cancel := context.WithTimeout(context.Background(), sm.config.ForceTimeout)
				if err := sm.ForceShutdown(ctx); err != nil {
					sm.recordShutdownError("Force shutdown error in SIGQUIT handler", err)
					// CRITICAL FIX: Escalate force shutdown failure to emergency protocols
					sm.triggerEmergencyShutdown("Force shutdown failed on SIGQUIT", err)
				}
				cancel()
				return
			case syscall.SIGHUP:
				// Reload signal - could be used for configuration reload
				// For now, just log it
				continue
			}
		case <-sm.ctx.Done():
			return
		}
	}
}
// performShutdown drives the shutdown phases: started-hooks, module stop,
// cleanup hooks + shutdown tasks, then the completion or failure hooks.
// Phase errors are collected rather than short-circuiting, so every phase
// gets a chance to run; collected errors are joined into the return value.
func (sm *ShutdownManager) performShutdown(ctx context.Context) error {
	sm.wg.Add(1)
	defer sm.wg.Done()
	// Bound the entire shutdown sequence by the graceful timeout.
	shutdownCtx, cancel := context.WithTimeout(ctx, sm.config.GracefulTimeout)
	defer cancel()
	var phaseErrors []error
	// Phase 1: notify hooks that shutdown has begun.
	if err := sm.callHooks(shutdownCtx, "OnShutdownStarted", nil); err != nil {
		wrapped := fmt.Errorf("shutdown start hooks failed: %w", err)
		sm.recordShutdownError("Shutdown started hook failure", wrapped)
		phaseErrors = append(phaseErrors, wrapped)
	}
	// Phase 2: stop all registered modules.
	sm.state = ShutdownStateModuleStop
	if sm.registry != nil {
		if err := sm.registry.StopAll(shutdownCtx); err != nil {
			wrapped := fmt.Errorf("failed to stop modules: %w", err)
			sm.recordShutdownError("Module stop failure", wrapped)
			phaseErrors = append(phaseErrors, wrapped)
		}
	}
	if err := sm.callHooks(shutdownCtx, "OnModulesStopped", nil); err != nil {
		wrapped := fmt.Errorf("modules stopped hooks failed: %w", err)
		sm.recordShutdownError("Modules stopped hook failure", wrapped)
		phaseErrors = append(phaseErrors, wrapped)
	}
	// Phase 3: run cleanup hooks and the registered shutdown tasks.
	sm.state = ShutdownStateCleanup
	if err := sm.callHooks(shutdownCtx, "OnCleanupStarted", nil); err != nil {
		wrapped := fmt.Errorf("cleanup hooks failed: %w", err)
		sm.recordShutdownError("Cleanup hook failure", wrapped)
		phaseErrors = append(phaseErrors, wrapped)
	}
	if err := sm.executeShutdownTasks(shutdownCtx); err != nil {
		wrapped := fmt.Errorf("shutdown tasks failed: %w", err)
		sm.recordShutdownError("Shutdown task execution failure", wrapped)
		phaseErrors = append(phaseErrors, wrapped)
	}
	// Phase 4: report final status to hooks.
	if len(phaseErrors) > 0 {
		finalErr := errors.Join(phaseErrors...)
		sm.state = ShutdownStateFailed
		if err := sm.callHooks(shutdownCtx, "OnShutdownFailed", finalErr); err != nil {
			wrapped := fmt.Errorf("shutdown failed hook error: %w", err)
			sm.recordShutdownError("Shutdown failed hook error", wrapped)
			finalErr = errors.Join(finalErr, wrapped)
			// A failing OnShutdownFailed hook means even failure reporting
			// is broken — escalate to emergency protocols.
			sm.triggerEmergencyShutdown("Shutdown failed hook error", wrapped)
		}
		return finalErr
	}
	sm.state = ShutdownStateCompleted
	if err := sm.callHooks(shutdownCtx, "OnShutdownCompleted", nil); err != nil {
		wrapped := fmt.Errorf("shutdown completed hook error: %w", err)
		sm.recordShutdownError("Shutdown completed hook error", wrapped)
		// Completion hooks are non-critical notifications: log and keep the
		// successful result. Fix: guard the logger, which tests set to nil
		// (recordShutdownError already guards it the same way).
		if sm.logger != nil {
			sm.logger.Warn("Shutdown completed hook failed", "error", wrapped)
		}
	}
	return nil
}
// executeShutdownTasks dispatches to the parallel or sequential execution
// strategy based on the ParallelShutdown configuration flag.
func (sm *ShutdownManager) executeShutdownTasks(ctx context.Context) error {
	if !sm.config.ParallelShutdown {
		return sm.executeTasksSequential(ctx)
	}
	return sm.executeTasksParallel(ctx)
}
// executeTasksSequential runs enabled shutdown tasks one at a time in slice
// order. A critical task failure aborts immediately with a wrapped error;
// otherwise the last non-critical failure (if any) is returned after every
// task has run.
func (sm *ShutdownManager) executeTasksSequential(ctx context.Context) error {
	var lastErr error
	for _, t := range sm.shutdownTasks {
		if !t.Enabled {
			continue
		}
		err := sm.executeTask(ctx, t)
		if err == nil {
			continue
		}
		if t.Critical {
			return fmt.Errorf("critical task %s failed: %w", t.Name, err)
		}
		lastErr = err
	}
	return lastErr
}
// executeTasksParallel runs shutdown tasks grouped by priority: groups run
// sequentially in descending-priority order, tasks within a group run
// concurrently. A critical task's failure is returned in preference to
// non-critical failures; otherwise the last failure observed is returned.
//
// Fix: criticality is now carried alongside each failure instead of being
// reconstructed by comparing error-message prefixes — the original
// `err.Error()[:len(prefix)]` slice could panic when a message was shorter
// than another task's prefix, and could misattribute failures between
// similarly named tasks. The channel is also no longer named `errors`,
// which shadowed the stdlib package.
func (sm *ShutdownManager) executeTasksParallel(ctx context.Context) error {
	type taskResult struct {
		name     string
		critical bool
		err      error
	}
	results := make(chan taskResult, len(sm.shutdownTasks))
	var wg sync.WaitGroup
	for _, group := range sm.groupTasksByPriority() {
		for _, task := range group {
			if !task.Enabled {
				continue
			}
			wg.Add(1)
			go func(t ShutdownTask) {
				defer wg.Done()
				if err := sm.executeTask(ctx, t); err != nil {
					results <- taskResult{name: t.Name, critical: t.Critical, err: err}
				}
			}(task)
		}
		// Each priority group must finish before the next one starts.
		wg.Wait()
	}
	close(results)
	var criticalErr, lastErr error
	for res := range results {
		wrapped := fmt.Errorf("task %s failed: %w", res.name, res.err)
		lastErr = wrapped
		if res.critical && criticalErr == nil {
			criticalErr = wrapped
		}
	}
	if criticalErr != nil {
		return criticalErr
	}
	return lastErr
}
// executeTask runs a single shutdown task under its own timeout, retrying up
// to MaxRetries additional times with RetryDelay between attempts. Every
// failed attempt is recorded (with task name and attempt number) and passed
// to the task's optional OnError handler.
func (sm *ShutdownManager) executeTask(ctx context.Context, task ShutdownTask) error {
	// One deadline covers the initial attempt and all retries.
	taskCtx, cancel := context.WithTimeout(ctx, task.Timeout)
	defer cancel()
	var lastErr error
	for attempt := 0; attempt <= sm.config.MaxRetries; attempt++ {
		if attempt > 0 {
			// Back off before retrying, but respect the task deadline.
			select {
			case <-time.After(sm.config.RetryDelay):
			case <-taskCtx.Done():
				return taskCtx.Err()
			}
		}
		err := task.Task(taskCtx)
		if err == nil {
			return nil
		}
		lastErr = err
		attemptNumber := attempt + 1
		sm.recordShutdownError(
			fmt.Sprintf("Shutdown task %s failed", task.Name),
			fmt.Errorf("attempt %d: %w", attemptNumber, err),
			"task", task.Name,
			"attempt", attemptNumber,
		)
		// Call error handler if provided
		if task.OnError != nil {
			task.OnError(err)
		}
	}
	// Fix: the loop makes MaxRetries+1 attempts (initial try plus retries);
	// the previous message under-reported the count by one.
	return fmt.Errorf("task failed after %d attempts: %w", sm.config.MaxRetries+1, lastErr)
}
// callHooks invokes the named lifecycle method on every registered hook and
// joins all failures into a single error (nil when every hook succeeded).
// Each failure is also recorded for later aggregation, tagged with the
// hook's concrete type and the phase name. Unknown method names invoke
// nothing and return nil.
func (sm *ShutdownManager) callHooks(ctx context.Context, hookMethod string, cause error) error {
	var failures []error
	for _, h := range sm.shutdownHooks {
		name := fmt.Sprintf("%T", h)
		var callErr error
		switch hookMethod {
		case "OnShutdownStarted":
			callErr = h.OnShutdownStarted(ctx)
		case "OnModulesStopped":
			callErr = h.OnModulesStopped(ctx)
		case "OnCleanupStarted":
			callErr = h.OnCleanupStarted(ctx)
		case "OnShutdownCompleted":
			callErr = h.OnShutdownCompleted(ctx)
		case "OnShutdownFailed":
			callErr = h.OnShutdownFailed(ctx, cause)
		}
		if callErr == nil {
			continue
		}
		recordContext := fmt.Sprintf("%s hook failure (%s)", hookMethod, name)
		sm.recordShutdownError(recordContext, callErr, "hook", name, "phase", hookMethod)
		failures = append(failures, fmt.Errorf("%s: %w", recordContext, callErr))
	}
	if len(failures) == 0 {
		return nil
	}
	return errors.Join(failures...)
}
// sortTasksByPriority orders shutdownTasks by descending priority.
//
// Uses sort.SliceStable so tasks with equal priority keep their
// registration order (the previous hand-rolled selection sort was neither
// idiomatic nor stable).
func (sm *ShutdownManager) sortTasksByPriority() {
	sort.SliceStable(sm.shutdownTasks, func(i, j int) bool {
		return sm.shutdownTasks[i].Priority > sm.shutdownTasks[j].Priority
	})
}
// groupTasksByPriority buckets tasks by their Priority value and returns the
// buckets in descending-priority order, so callers can run groups
// sequentially while parallelizing within each group.
//
// The hand-rolled bubble sort of the priority keys is replaced with the
// stdlib sort (stability is irrelevant here: map keys are unique).
func (sm *ShutdownManager) groupTasksByPriority() [][]ShutdownTask {
	groups := make(map[int][]ShutdownTask)
	for _, task := range sm.shutdownTasks {
		groups[task.Priority] = append(groups[task.Priority], task)
	}
	priorities := make([]int, 0, len(groups))
	for p := range groups {
		priorities = append(priorities, p)
	}
	// Highest priority first.
	sort.Sort(sort.Reverse(sort.IntSlice(priorities)))
	result := make([][]ShutdownTask, 0, len(priorities))
	for _, p := range priorities {
		result = append(result, groups[p])
	}
	return result
}
// Default task implementations
//
// All four helpers below are currently stubs that return nil; the comments
// inside each describe the intended behavior to be implemented.

// saveApplicationState persists runtime state to disk during shutdown.
// Currently a no-op stub.
func (sm *ShutdownManager) saveApplicationState(ctx context.Context) error {
	// Save application state to disk
	// This would save things like current configuration, runtime state, etc.
	return nil
}

// closeExternalConnections drains and closes outbound connections.
// Currently a no-op stub.
func (sm *ShutdownManager) closeExternalConnections(ctx context.Context) error {
	// Close database connections, external API connections, etc.
	return nil
}

// cleanupTempFiles removes temporary artifacts created at runtime.
// Currently a no-op stub.
func (sm *ShutdownManager) cleanupTempFiles(ctx context.Context) error {
	// Remove temporary files, logs, caches, etc.
	return nil
}

// notifyExternalSystems announces this instance's shutdown to peers.
// Currently a no-op stub.
func (sm *ShutdownManager) notifyExternalSystems(ctx context.Context) error {
	// Notify external systems that this instance is shutting down
	return nil
}
// recordShutdownError stores err (wrapped with message and any tx-hash
// extracted by enrichErrorWithTxHash) in both the flat error list and the
// detailed list, then logs it with the enriched attributes when a logger is
// present. A nil err is ignored.
func (sm *ShutdownManager) recordShutdownError(message string, err error, attrs ...interface{}) {
	if err == nil {
		return
	}
	// Copy the caller's attrs so enrichment cannot alias their slice.
	attrsCopy := make([]interface{}, len(attrs))
	copy(attrsCopy, attrs)
	wrapped, txHash, attrsWithTx := enrichErrorWithTxHash(message, err, attrsCopy)

	sm.errMu.Lock()
	sm.shutdownErrors = append(sm.shutdownErrors, wrapped)
	sm.shutdownErrorDetails = append(sm.shutdownErrorDetails, RecordedError{
		Err:    wrapped,
		TxHash: txHash,
	})
	sm.errMu.Unlock()

	if sm.logger != nil {
		logArgs := make([]interface{}, 0, len(attrsWithTx)+3)
		logArgs = append(logArgs, message)
		logArgs = append(logArgs, attrsWithTx...)
		logArgs = append(logArgs, "error", err)
		sm.logger.Error(logArgs...)
	}
}
// combinedShutdownError joins all recorded shutdown errors into a single
// error, or returns nil when none were recorded. A snapshot copy is joined
// so the result is unaffected by errors recorded afterwards.
func (sm *ShutdownManager) combinedShutdownError() error {
	sm.errMu.Lock()
	defer sm.errMu.Unlock()
	if len(sm.shutdownErrors) == 0 {
		return nil
	}
	snapshot := append([]error(nil), sm.shutdownErrors...)
	return errors.Join(snapshot...)
}
// ShutdownErrors returns a copy of recorded shutdown errors for diagnostics.
// Returns nil when nothing has been recorded.
func (sm *ShutdownManager) ShutdownErrors() []error {
	sm.errMu.Lock()
	defer sm.errMu.Unlock()
	if len(sm.shutdownErrors) == 0 {
		return nil
	}
	return append([]error(nil), sm.shutdownErrors...)
}
// ShutdownErrorDetails returns a copy of the recorded errors together with
// associated metadata such as the extracted tx hash. Returns nil when
// nothing has been recorded.
func (sm *ShutdownManager) ShutdownErrorDetails() []RecordedError {
	sm.errMu.Lock()
	defer sm.errMu.Unlock()
	if len(sm.shutdownErrorDetails) == 0 {
		return nil
	}
	return append([]RecordedError(nil), sm.shutdownErrorDetails...)
}
// DefaultShutdownHook provides a basic implementation of ShutdownHook.
//
// Every callback is a no-op (see the methods below); embed or copy it when
// only a subset of the hook methods is of interest.
type DefaultShutdownHook struct {
	name string // identifier supplied at construction; not otherwise read in this file
}

// NewDefaultShutdownHook creates a no-op hook with the given name.
func NewDefaultShutdownHook(name string) *DefaultShutdownHook {
	return &DefaultShutdownHook{name: name}
}
// triggerEmergencyShutdown performs last-resort escalation when critical
// shutdown failures occur: it marks the manager failed, fires the optional
// emergency handler, records the failure, and (asynchronously) notifies the
// registered hooks.
//
// Fix: all logger calls are now nil-guarded — sm.logger is nil in tests
// (and recordShutdownError already guards it), so the unguarded
// sm.logger.Error/Warn here could panic exactly when things were already
// going wrong.
func (sm *ShutdownManager) triggerEmergencyShutdown(reason string, err error) {
	if sm.logger != nil {
		sm.logger.Error("EMERGENCY SHUTDOWN TRIGGERED",
			"reason", reason,
			"error", err,
			"state", sm.state,
			"timestamp", time.Now())
	}
	// Set emergency state.
	// NOTE(review): if a caller holds sm.mu while invoking this method, this
	// Lock self-deadlocks — confirm all call paths release sm.mu first.
	sm.mu.Lock()
	sm.state = ShutdownStateFailed
	sm.mu.Unlock()
	// Fire the optional last-resort handler without blocking the caller.
	if sm.emergencyHandler != nil {
		go func() {
			ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
			defer cancel()
			if err := sm.emergencyHandler(ctx, reason, err); err != nil {
				if sm.logger != nil {
					sm.logger.Error("Emergency handler failed", "error", err)
				}
			}
		}()
	}
	// Record the escalation in the shutdown error log.
	sm.recordShutdownError("EMERGENCY_SHUTDOWN", fmt.Errorf("%s: %w", reason, err))
	// Attempt to notify monitoring systems if available.
	if len(sm.shutdownHooks) > 0 {
		go func() {
			ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
			defer cancel()
			// NOTE(review): callHooks has no "OnEmergencyShutdown" case, so
			// this currently invokes nothing and returns nil — confirm which
			// hook method is actually intended here.
			if err := sm.callHooks(ctx, "OnEmergencyShutdown", fmt.Errorf("%s: %w", reason, err)); err != nil {
				if sm.logger != nil {
					sm.logger.Warn("Failed to call emergency shutdown hooks",
						"error", err,
						"reason", reason)
				}
			}
		}()
	}
}
// OnShutdownStarted is a no-op.
func (dsh *DefaultShutdownHook) OnShutdownStarted(ctx context.Context) error {
	return nil
}

// OnModulesStopped is a no-op.
func (dsh *DefaultShutdownHook) OnModulesStopped(ctx context.Context) error {
	return nil
}

// OnCleanupStarted is a no-op.
func (dsh *DefaultShutdownHook) OnCleanupStarted(ctx context.Context) error {
	return nil
}

// OnShutdownCompleted is a no-op.
func (dsh *DefaultShutdownHook) OnShutdownCompleted(ctx context.Context) error {
	return nil
}

// OnShutdownFailed is a no-op; the failure cause is ignored.
func (dsh *DefaultShutdownHook) OnShutdownFailed(ctx context.Context, err error) error {
	return nil
}

View File

@@ -0,0 +1,111 @@
package lifecycle
import (
"context"
"errors"
"fmt"
"strings"
"testing"
)
// testShutdownHook is a configurable ShutdownHook double: each callback
// returns the error mapped under its own method name (nil when absent), and
// OnShutdownFailed additionally captures the cause it was handed.
type testShutdownHook struct {
	errs        map[string]error // per-method error to return
	lastFailure error            // cause most recently passed to OnShutdownFailed
}

func (h *testShutdownHook) OnShutdownStarted(ctx context.Context) error {
	return h.errs["OnShutdownStarted"]
}

func (h *testShutdownHook) OnModulesStopped(ctx context.Context) error {
	return h.errs["OnModulesStopped"]
}

func (h *testShutdownHook) OnCleanupStarted(ctx context.Context) error {
	return h.errs["OnCleanupStarted"]
}

func (h *testShutdownHook) OnShutdownCompleted(ctx context.Context) error {
	return h.errs["OnShutdownCompleted"]
}

func (h *testShutdownHook) OnShutdownFailed(ctx context.Context, err error) error {
	h.lastFailure = err
	return h.errs["OnShutdownFailed"]
}
// TestShutdownManagerErrorAggregation checks that recorded shutdown errors
// are retained, combined via errors.Join, exported as copies, and annotated
// with any tx hash found in the original message.
func TestShutdownManagerErrorAggregation(t *testing.T) {
	sm := NewShutdownManager(nil, ShutdownConfig{})
	sm.logger = nil

	const txHash = "0xaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"
	firstErr := fmt.Errorf("first failure on %s", txHash)
	secondErr := errors.New("second")
	sm.recordShutdownError("first error", firstErr)
	sm.recordShutdownError("second error", secondErr)

	if got := len(sm.shutdownErrors); got != 2 {
		t.Fatalf("expected 2 recorded errors, got %d", got)
	}
	combined := sm.combinedShutdownError()
	if combined == nil {
		t.Fatal("expected combined shutdown error, got nil")
	}
	if !errors.Is(combined, firstErr) || !errors.Is(combined, secondErr) {
		t.Fatalf("combined error does not contain original errors: %v", combined)
	}
	if exported := sm.ShutdownErrors(); len(exported) != 2 {
		t.Fatalf("expected exported error slice of length 2, got %d", len(exported))
	}
	details := sm.ShutdownErrorDetails()
	if len(details) != 2 {
		t.Fatalf("expected error detail slice of length 2, got %d", len(details))
	}
	first := details[0]
	if first.TxHash != txHash {
		t.Fatalf("expected recorded error to track tx hash %s, got %s", txHash, first.TxHash)
	}
	if first.Err == nil || !strings.Contains(first.Err.Error(), "tx_hash="+txHash) {
		t.Fatalf("expected recorded error message to include tx hash, got %v", first.Err)
	}
}
// TestShutdownManagerCallHooksAggregatesErrors verifies that callHooks
// passes the original cause to every hook, joins all hook failures into one
// error, and records each failure individually.
func TestShutdownManagerCallHooksAggregatesErrors(t *testing.T) {
	sm := NewShutdownManager(nil, ShutdownConfig{})
	sm.logger = nil
	sm.shutdownErrors = nil

	hookErrA := errors.New("hookA failure")
	hookErrB := errors.New("hookB failure")
	hookA := &testShutdownHook{errs: map[string]error{"OnShutdownFailed": hookErrA}}
	hookB := &testShutdownHook{errs: map[string]error{"OnShutdownFailed": hookErrB}}
	sm.shutdownHooks = []ShutdownHook{hookA, hookB}

	cause := errors.New("original failure")
	err := sm.callHooks(context.Background(), "OnShutdownFailed", cause)
	if err == nil {
		t.Fatal("expected aggregated error from hooks, got nil")
	}
	if !errors.Is(err, hookErrA) || !errors.Is(err, hookErrB) {
		t.Fatalf("expected aggregated error to contain hook failures, got %v", err)
	}
	if hookA.lastFailure != cause || hookB.lastFailure != cause {
		t.Fatal("expected hook to receive original failure cause")
	}
	if got := len(sm.ShutdownErrors()); got != 2 {
		t.Fatalf("expected shutdown errors to be recorded for each hook failure, got %d", got)
	}
}

View File

@@ -0,0 +1,660 @@
package lifecycle
import (
"context"
"fmt"
"sync"
"time"
)
// StateMachine manages module state transitions and enforces valid state changes.
//
// It tracks the current state, the table of legal transitions, optional
// per-state handlers and named transition hooks, plus a transition history
// and runtime metrics for the module it governs.
type StateMachine struct {
	currentState    ModuleState                  // the machine's present state
	transitions     map[ModuleState][]ModuleState // legal target states per source state
	stateHandlers   map[ModuleState]StateHandler  // handler invoked for a given state
	transitionHooks map[string]TransitionHook     // named hooks around transitions
	history         []StateTransition             // past transitions (presumably capped at MaxHistorySize — trimming not visible here)
	module          Module                        // the module this machine governs
	config          StateMachineConfig
	mu              sync.RWMutex
	metrics         StateMachineMetrics
}

// StateHandler handles operations when entering a specific state.
type StateHandler func(ctx context.Context, machine *StateMachine) error

// TransitionHook is called before or after state transitions.
type TransitionHook func(ctx context.Context, from, to ModuleState, machine *StateMachine) error
// StateTransition represents a state change event
type StateTransition struct {
From ModuleState `json:"from"`
To ModuleState `json:"to"`
Timestamp time.Time `json:"timestamp"`
Duration time.Duration `json:"duration"`
Success bool `json:"success"`
Error error `json:"error,omitempty"`
Trigger string `json:"trigger"`
Context map[string]interface{} `json:"context"`
}
// StateMachineConfig configures state machine behavior. Zero values for
// TransitionTimeout, MaxHistorySize, MaxRetries, and RetryDelay are
// replaced with defaults by NewStateMachine.
type StateMachineConfig struct {
	InitialState           ModuleState   `json:"initial_state"`            // state the machine starts (and resets) in
	TransitionTimeout      time.Duration `json:"transition_timeout"`       // per-hook/handler/module-call deadline (default 30s)
	MaxHistorySize         int           `json:"max_history_size"`         // cap on retained history entries (default 100)
	EnableMetrics          bool          `json:"enable_metrics"`           // collect counters/timings per transition
	EnableValidation       bool          `json:"enable_validation"`        // enforce the transition table
	AllowConcurrent        bool          `json:"allow_concurrent"`         // selects the locking strategy in Transition
	RetryFailedTransitions bool          `json:"retry_failed_transitions"` // retry in TransitionWithRetry on failure
	MaxRetries             int           `json:"max_retries"`              // retry budget for TransitionWithRetry (default 3)
	RetryDelay             time.Duration `json:"retry_delay"`              // wait between retries (default 1s)
}
// StateMachineMetrics tracks state machine performance. Updated by
// updateMetrics / recordFailedTransition while holding the machine lock.
type StateMachineMetrics struct {
	TotalTransitions      int64                    `json:"total_transitions"`       // transitions counted by updateMetrics
	SuccessfulTransitions int64                    `json:"successful_transitions"`  // transitions that completed
	FailedTransitions     int64                    `json:"failed_transitions"`      // transitions that were aborted
	StateDistribution     map[ModuleState]int64    `json:"state_distribution"`      // how often each state was entered
	TransitionTimes       map[string]time.Duration `json:"transition_times"`        // latest duration per "from->to" pair
	AverageTransitionTime time.Duration            `json:"average_transition_time"` // mean over TransitionTimes entries (distinct pairs, not all events)
	LongestTransition     time.Duration            `json:"longest_transition"`      // slowest single transition observed
	LastTransition        time.Time                `json:"last_transition"`         // wall-clock time of the most recent success
	CurrentStateDuration  time.Duration            `json:"current_state_duration"`  // filled on read by GetMetrics
	stateEnterTime        time.Time                // when the current state was entered (internal, not serialized)
}
// NewStateMachine creates a new state machine for a module. Zero-valued
// timeout/retry/history settings in config are replaced with defaults,
// and no-op handlers are installed for the standard states.
func NewStateMachine(module Module, config StateMachineConfig) *StateMachine {
	// Backfill defaults on the local config copy before storing it.
	if config.TransitionTimeout == 0 {
		config.TransitionTimeout = 30 * time.Second
	}
	if config.MaxHistorySize == 0 {
		config.MaxHistorySize = 100
	}
	if config.MaxRetries == 0 {
		config.MaxRetries = 3
	}
	if config.RetryDelay == 0 {
		config.RetryDelay = time.Second
	}
	machine := &StateMachine{
		currentState:    config.InitialState,
		transitions:     createDefaultTransitions(),
		stateHandlers:   map[ModuleState]StateHandler{},
		transitionHooks: map[string]TransitionHook{},
		history:         []StateTransition{},
		module:          module,
		config:          config,
		metrics: StateMachineMetrics{
			StateDistribution: map[ModuleState]int64{},
			TransitionTimes:   map[string]time.Duration{},
			stateEnterTime:    time.Now(),
		},
	}
	machine.setupDefaultHandlers()
	return machine
}
// GetCurrentState returns the state the machine is currently in.
func (sm *StateMachine) GetCurrentState() ModuleState {
	sm.mu.RLock()
	state := sm.currentState
	sm.mu.RUnlock()
	return state
}
// CanTransition reports whether moving from the current state to the
// target state is permitted by the transition table.
func (sm *StateMachine) CanTransition(to ModuleState) bool {
	sm.mu.RLock()
	defer sm.mu.RUnlock()
	// Same table scan the transition path uses internally.
	return sm.canTransitionUnsafe(to)
}
// Transition performs a state transition to the target state, driving
// the module's lifecycle method and recording history/metrics.
//
// Fix: the previous implementation took only a read lock when
// config.AllowConcurrent was set, yet performTransition mutates
// currentState, history, and metrics — concurrent writers under an
// RLock are a data race under the Go memory model. Transitions always
// mutate shared state, so they are now always serialized under the
// exclusive lock regardless of AllowConcurrent.
func (sm *StateMachine) Transition(ctx context.Context, to ModuleState, trigger string) error {
	sm.mu.Lock()
	defer sm.mu.Unlock()
	return sm.performTransition(ctx, to, trigger)
}
// TransitionWithRetry performs a state transition, retrying failed
// attempts (after RetryDelay) when config.RetryFailedTransitions is
// enabled, up to config.MaxRetries retries. Context cancellation is
// honored while waiting between attempts.
//
// Fix: the error message previously always reported MaxRetries attempts
// even when the loop broke after a single attempt (RetryFailedTransitions
// disabled); it now reports the number of attempts actually made.
func (sm *StateMachine) TransitionWithRetry(ctx context.Context, to ModuleState, trigger string) error {
	var lastErr error
	attempts := 0
	for attempt := 0; attempt <= sm.config.MaxRetries; attempt++ {
		if attempt > 0 {
			// Wait before retrying, but abort promptly on cancellation.
			select {
			case <-time.After(sm.config.RetryDelay):
			case <-ctx.Done():
				return ctx.Err()
			}
		}
		err := sm.Transition(ctx, to, trigger)
		attempts++
		if err == nil {
			return nil
		}
		lastErr = err
		// Retrying is opt-in; otherwise fail after the first attempt.
		if !sm.config.RetryFailedTransitions {
			break
		}
	}
	return fmt.Errorf("transition failed after %d attempts: %w", attempts, lastErr)
}
// Initialize transitions to the initialized state, invoking the
// module's Initialize method via executeStateTransition.
func (sm *StateMachine) Initialize(ctx context.Context) error {
	return sm.Transition(ctx, StateInitialized, "initialize")
}

// Start transitions to the running state, invoking the module's Start
// method (or Resume when coming from paused).
func (sm *StateMachine) Start(ctx context.Context) error {
	return sm.Transition(ctx, StateRunning, "start")
}

// Stop transitions to the stopped state, invoking the module's Stop method.
func (sm *StateMachine) Stop(ctx context.Context) error {
	return sm.Transition(ctx, StateStopped, "stop")
}

// Pause transitions to the paused state, invoking the module's Pause method.
func (sm *StateMachine) Pause(ctx context.Context) error {
	return sm.Transition(ctx, StatePaused, "pause")
}

// Resume transitions to the running state from paused, invoking the
// module's Resume method.
func (sm *StateMachine) Resume(ctx context.Context) error {
	return sm.Transition(ctx, StateRunning, "resume")
}

// Fail transitions to the failed state; the reason is recorded in the
// transition trigger. No module method is invoked for this state.
func (sm *StateMachine) Fail(ctx context.Context, reason string) error {
	return sm.Transition(ctx, StateFailed, fmt.Sprintf("fail: %s", reason))
}
// SetStateHandler sets a custom handler for a specific state, replacing
// any default no-op handler installed at construction.
func (sm *StateMachine) SetStateHandler(state ModuleState, handler StateHandler) {
	sm.mu.Lock()
	defer sm.mu.Unlock()
	sm.stateHandlers[state] = handler
}

// SetTransitionHook registers (or replaces) a named hook that runs
// around every state transition.
func (sm *StateMachine) SetTransitionHook(name string, hook TransitionHook) {
	sm.mu.Lock()
	defer sm.mu.Unlock()
	sm.transitionHooks[name] = hook
}
// GetHistory returns a snapshot copy of the state transition history;
// the caller may retain or mutate it freely.
func (sm *StateMachine) GetHistory() []StateTransition {
	sm.mu.RLock()
	defer sm.mu.RUnlock()
	snapshot := make([]StateTransition, len(sm.history))
	copy(snapshot, sm.history)
	return snapshot
}
// GetMetrics returns a snapshot of the state machine metrics with
// CurrentStateDuration filled in as of now.
//
// Fix: the plain struct copy shared the StateDistribution and
// TransitionTimes map headers with the machine's internal metrics, so a
// caller reading the returned maps while transitions occur would race
// with updateMetrics. The maps are now deep-copied so the snapshot is
// fully independent.
func (sm *StateMachine) GetMetrics() StateMachineMetrics {
	sm.mu.RLock()
	defer sm.mu.RUnlock()
	metrics := sm.metrics
	metrics.StateDistribution = make(map[ModuleState]int64, len(sm.metrics.StateDistribution))
	for state, count := range sm.metrics.StateDistribution {
		metrics.StateDistribution[state] = count
	}
	metrics.TransitionTimes = make(map[string]time.Duration, len(sm.metrics.TransitionTimes))
	for key, d := range sm.metrics.TransitionTimes {
		metrics.TransitionTimes[key] = d
	}
	// Derived field: how long the machine has been in its current state.
	metrics.CurrentStateDuration = time.Since(sm.metrics.stateEnterTime)
	return metrics
}
// AddCustomTransition registers an additional allowed transition from
// one state to another. Duplicate rules are ignored.
func (sm *StateMachine) AddCustomTransition(from, to ModuleState) {
	sm.mu.Lock()
	defer sm.mu.Unlock()
	// Ranging over a missing key yields an empty list, and append on a
	// nil slice allocates, so no explicit existence check is needed.
	for _, known := range sm.transitions[from] {
		if known == to {
			return // rule already present
		}
	}
	sm.transitions[from] = append(sm.transitions[from], to)
}
// RemoveTransition removes a state transition rule; it is a no-op when
// the rule does not exist.
func (sm *StateMachine) RemoveTransition(from, to ModuleState) {
	sm.mu.Lock()
	defer sm.mu.Unlock()
	current := sm.transitions[from]
	for idx := range current {
		if current[idx] == to {
			// Splice the entry out in place (order of the rest preserved).
			sm.transitions[from] = append(current[:idx], current[idx+1:]...)
			return
		}
	}
}
// GetValidTransitions returns a copy of all valid target states
// reachable from the current state (empty, never nil logic change: a
// zero-length slice when there are none).
func (sm *StateMachine) GetValidTransitions() []ModuleState {
	sm.mu.RLock()
	defer sm.mu.RUnlock()
	valid := sm.transitions[sm.currentState]
	out := make([]ModuleState, len(valid))
	copy(out, valid)
	return out
}
// IsInState reports whether the machine is currently in the given state.
func (sm *StateMachine) IsInState(state ModuleState) bool {
	sm.mu.RLock()
	current := sm.currentState
	sm.mu.RUnlock()
	return current == state
}
// IsInAnyState reports whether the machine is currently in any of the
// provided states.
func (sm *StateMachine) IsInAnyState(states ...ModuleState) bool {
	sm.mu.RLock()
	current := sm.currentState
	sm.mu.RUnlock()
	for _, candidate := range states {
		if candidate == current {
			return true
		}
	}
	return false
}
// WaitForState polls (every 100ms) until the machine reaches the given
// state, the timeout elapses, or ctx is cancelled.
//
// Fix: the error previously discarded the context's cause, so callers
// could not distinguish a genuine timeout from cancellation of the
// parent context. The cause is now wrapped, so errors.Is(err,
// context.DeadlineExceeded) / context.Canceled both work.
func (sm *StateMachine) WaitForState(ctx context.Context, state ModuleState, timeout time.Duration) error {
	// Fast path: already there.
	if sm.IsInState(state) {
		return nil
	}
	timeoutCtx, cancel := context.WithTimeout(ctx, timeout)
	defer cancel()
	ticker := time.NewTicker(100 * time.Millisecond)
	defer ticker.Stop()
	for {
		select {
		case <-timeoutCtx.Done():
			return fmt.Errorf("timeout waiting for state %s: %w", state, timeoutCtx.Err())
		case <-ticker.C:
			if sm.IsInState(state) {
				return nil
			}
		}
	}
}
// Reset clears history and metrics and transitions the machine back to
// its configured initial state.
//
// NOTE(review): the final transition still goes through
// performTransition, so when EnableValidation is set it is subject to
// the transition table — resetting from a state with no path to the
// initial state will fail. Confirm whether Reset should bypass
// validation.
func (sm *StateMachine) Reset(ctx context.Context) error {
	sm.mu.Lock()
	defer sm.mu.Unlock()
	// Clear history
	sm.history = make([]StateTransition, 0)
	// Reset metrics (fresh maps, state-entry clock restarted).
	sm.metrics = StateMachineMetrics{
		StateDistribution: make(map[ModuleState]int64),
		TransitionTimes:   make(map[string]time.Duration),
		stateEnterTime:    time.Now(),
	}
	// Transition to initial state (lock already held, so call the
	// internal helper directly rather than Transition).
	return sm.performTransition(ctx, sm.config.InitialState, "reset")
}
// Private methods

// performTransition executes a single state change: validate the move,
// run every registered hook as a pre-transition hook, drive the module
// via executeStateTransition, flip currentState, record metrics and
// history, then run the same hooks again post-transition and finally
// the new state's handler. Callers must hold sm.mu.
//
// NOTE(review): transitionHooks is one map used for BOTH the pre and
// post phases, so every hook fires twice per transition — confirm this
// is intended rather than separate pre/post registries.
func (sm *StateMachine) performTransition(ctx context.Context, to ModuleState, trigger string) error {
	startTime := time.Now()
	from := sm.currentState
	// Validate transition against the transition table (optional).
	if sm.config.EnableValidation && !sm.canTransitionUnsafe(to) {
		return fmt.Errorf("invalid transition from %s to %s", from, to)
	}
	// Create transition context (attached to the history record).
	transitionCtx := map[string]interface{}{
		"trigger":    trigger,
		"start_time": startTime,
		"module_id":  sm.module.GetID(),
	}
	// Execute pre-transition hooks; the first failure aborts the
	// transition and records it as failed. Each hook gets its own
	// timeout-bounded context, cancelled immediately after the call.
	for name, hook := range sm.transitionHooks {
		hookCtx, cancel := context.WithTimeout(ctx, sm.config.TransitionTimeout)
		err := func() error {
			defer cancel()
			if err := hook(hookCtx, from, to, sm); err != nil {
				return fmt.Errorf("pre-transition hook %s failed: %w", name, err)
			}
			return nil
		}()
		if err != nil {
			sm.recordFailedTransition(from, to, startTime, trigger, err, transitionCtx)
			return err
		}
	}
	// Execute state-specific logic (calls into the module itself).
	if err := sm.executeStateTransition(ctx, from, to); err != nil {
		sm.recordFailedTransition(from, to, startTime, trigger, err, transitionCtx)
		return fmt.Errorf("state transition failed: %w", err)
	}
	// Update current state — from here on the transition is committed.
	sm.currentState = to
	duration := time.Since(startTime)
	// Update metrics (only when enabled).
	if sm.config.EnableMetrics {
		sm.updateMetrics(from, to, duration)
	}
	// Record successful transition
	sm.recordSuccessfulTransition(from, to, startTime, duration, trigger, transitionCtx)
	// Execute post-transition hooks; errors here are deliberately
	// swallowed — the state has already changed.
	for _, hook := range sm.transitionHooks {
		hookCtx, cancel := context.WithTimeout(ctx, sm.config.TransitionTimeout)
		func() {
			defer cancel()
			if err := hook(hookCtx, from, to, sm); err != nil {
				// Log error but don't fail the transition
				return
			}
		}()
	}
	// Execute state handler for the new state; handler errors are
	// likewise ignored.
	if handler, exists := sm.stateHandlers[to]; exists {
		handlerCtx, cancel := context.WithTimeout(ctx, sm.config.TransitionTimeout)
		func() {
			defer cancel()
			if err := handler(handlerCtx, sm); err != nil {
				// Log error but don't fail the transition
			}
		}()
	}
	return nil
}
// executeStateTransition invokes the module lifecycle method that
// corresponds to the target state, bounded by TransitionTimeout.
//
// NOTE(review): entering StateInitialized passes an empty ModuleConfig{}
// to Initialize — confirm the real module config should not be threaded
// through here.
func (sm *StateMachine) executeStateTransition(ctx context.Context, from, to ModuleState) error {
	// Create timeout context for the operation
	timeoutCtx, cancel := context.WithTimeout(ctx, sm.config.TransitionTimeout)
	defer cancel()
	switch to {
	case StateInitialized:
		return sm.module.Initialize(timeoutCtx, ModuleConfig{})
	case StateRunning:
		// Coming from paused means resume; otherwise a fresh start.
		if from == StatePaused {
			return sm.module.Resume(timeoutCtx)
		}
		return sm.module.Start(timeoutCtx)
	case StateStopped:
		return sm.module.Stop(timeoutCtx)
	case StatePaused:
		return sm.module.Pause(timeoutCtx)
	case StateFailed:
		// Failed state doesn't require module action
		return nil
	default:
		return fmt.Errorf("unknown target state: %s", to)
	}
}
// canTransitionUnsafe reports whether the current state may move to the
// target state per the transition table. Caller must hold sm.mu.
func (sm *StateMachine) canTransitionUnsafe(to ModuleState) bool {
	// A missing key yields an empty list, so no existence check needed.
	for _, allowed := range sm.transitions[sm.currentState] {
		if allowed == to {
			return true
		}
	}
	return false
}
// recordSuccessfulTransition appends a successful transition to the
// bounded history. Caller must hold sm.mu.
func (sm *StateMachine) recordSuccessfulTransition(from, to ModuleState, startTime time.Time, duration time.Duration, trigger string, context map[string]interface{}) {
	sm.addToHistory(StateTransition{
		From:      from,
		To:        to,
		Timestamp: startTime,
		Duration:  duration,
		Success:   true,
		Trigger:   trigger,
		Context:   context,
	})
}
// recordFailedTransition appends a failed transition to the history and
// bumps the failure counters. Caller must hold sm.mu.
//
// Fix: FailedTransitions was incremented without TotalTransitions, so
// TotalTransitions only ever counted successes and could never differ
// from SuccessfulTransitions. Failed attempts now count toward the
// total as well, making Total == Successful + Failed.
func (sm *StateMachine) recordFailedTransition(from, to ModuleState, startTime time.Time, trigger string, err error, context map[string]interface{}) {
	transition := StateTransition{
		From:      from,
		To:        to,
		Timestamp: startTime,
		Duration:  time.Since(startTime),
		Success:   false,
		Error:     err,
		Trigger:   trigger,
		Context:   context,
	}
	sm.addToHistory(transition)
	if sm.config.EnableMetrics {
		sm.metrics.TotalTransitions++
		sm.metrics.FailedTransitions++
	}
}
// addToHistory appends a transition record, discarding the oldest
// entries once the configured cap is exceeded. Caller must hold sm.mu.
func (sm *StateMachine) addToHistory(transition StateTransition) {
	sm.history = append(sm.history, transition)
	if excess := len(sm.history) - sm.config.MaxHistorySize; excess > 0 {
		sm.history = sm.history[excess:]
	}
}
// updateMetrics records a successful transition's counters and timings.
// Caller must hold sm.mu; only invoked when config.EnableMetrics is set.
func (sm *StateMachine) updateMetrics(from, to ModuleState, duration time.Duration) {
	now := time.Now()
	sm.metrics.TotalTransitions++
	sm.metrics.SuccessfulTransitions++
	sm.metrics.StateDistribution[to]++
	sm.metrics.LastTransition = now
	// Each from->to pair keeps only its most recent duration.
	sm.metrics.TransitionTimes[fmt.Sprintf("%s->%s", from, to)] = duration
	// Average over the distinct transition pairs seen so far (the map
	// we just wrote to is guaranteed non-empty here).
	if sm.metrics.TotalTransitions > 0 {
		var sum time.Duration
		for _, d := range sm.metrics.TransitionTimes {
			sum += d
		}
		sm.metrics.AverageTransitionTime = sum / time.Duration(len(sm.metrics.TransitionTimes))
	}
	// Track the slowest single transition observed.
	if duration > sm.metrics.LongestTransition {
		sm.metrics.LongestTransition = duration
	}
	// Restart the clock used for CurrentStateDuration.
	sm.metrics.stateEnterTime = now
}
// setupDefaultHandlers installs no-op handlers for the standard states;
// any of them can be overridden later via SetStateHandler. (The failed
// state's handler is a natural place to plug in recovery logic.)
func (sm *StateMachine) setupDefaultHandlers() {
	noop := func(ctx context.Context, machine *StateMachine) error {
		return nil
	}
	for _, state := range []ModuleState{
		StateInitialized,
		StateRunning,
		StateStopped,
		StatePaused,
		StateFailed,
	} {
		sm.stateHandlers[state] = noop
	}
}
// createDefaultTransitions creates the standard state transition rules:
// a linear lifecycle (uninitialized -> initialized -> starting ->
// running) with pause/resume and stop branches, where every state may
// fall into StateFailed and the failed/stopped states can re-enter the
// lifecycle.
func createDefaultTransitions() map[ModuleState][]ModuleState {
	return map[ModuleState][]ModuleState{
		StateUninitialized: {StateInitialized, StateFailed},
		StateInitialized:   {StateStarting, StateStopped, StateFailed},
		StateStarting:      {StateRunning, StateFailed},
		StateRunning:       {StatePausing, StateStopping, StateFailed},
		StatePausing:       {StatePaused, StateFailed},
		StatePaused:        {StateResuming, StateStopping, StateFailed},
		StateResuming:      {StateRunning, StateFailed},
		StateStopping:      {StateStopped, StateFailed},
		StateStopped:       {StateInitialized, StateStarting, StateFailed},
		StateFailed:        {StateInitialized, StateStopped}, // Recovery paths
	}
}
// StateMachineBuilder provides a fluent interface for building state
// machines: accumulate config, handlers, hooks, and extra transition
// rules, then call Build with the target module.
type StateMachineBuilder struct {
	config            StateMachineConfig            // config passed to NewStateMachine
	stateHandlers     map[ModuleState]StateHandler  // handlers to install after construction
	transitionHooks   map[string]TransitionHook     // named hooks to register
	customTransitions map[ModuleState][]ModuleState // extra transition rules beyond the defaults
}
// NewStateMachineBuilder creates a builder preloaded with sensible
// defaults (uninitialized start state, 30s transition timeout, 100
// history entries, metrics and validation enabled).
func NewStateMachineBuilder() *StateMachineBuilder {
	builder := &StateMachineBuilder{
		stateHandlers:     map[ModuleState]StateHandler{},
		transitionHooks:   map[string]TransitionHook{},
		customTransitions: map[ModuleState][]ModuleState{},
	}
	builder.config = StateMachineConfig{
		InitialState:      StateUninitialized,
		TransitionTimeout: 30 * time.Second,
		MaxHistorySize:    100,
		EnableMetrics:     true,
		EnableValidation:  true,
	}
	return builder
}
// WithConfig replaces the builder's entire configuration (including the
// defaults set by NewStateMachineBuilder). Returns the builder for chaining.
func (smb *StateMachineBuilder) WithConfig(config StateMachineConfig) *StateMachineBuilder {
	smb.config = config
	return smb
}

// WithStateHandler adds a state handler to install on Build, replacing
// the default no-op handler for that state.
func (smb *StateMachineBuilder) WithStateHandler(state ModuleState, handler StateHandler) *StateMachineBuilder {
	smb.stateHandlers[state] = handler
	return smb
}

// WithTransitionHook adds a named transition hook to register on Build.
func (smb *StateMachineBuilder) WithTransitionHook(name string, hook TransitionHook) *StateMachineBuilder {
	smb.transitionHooks[name] = hook
	return smb
}

// WithCustomTransition records an extra transition rule to add on Build.
// Duplicates recorded here are de-duplicated by AddCustomTransition.
func (smb *StateMachineBuilder) WithCustomTransition(from, to ModuleState) *StateMachineBuilder {
	if _, exists := smb.customTransitions[from]; !exists {
		smb.customTransitions[from] = make([]ModuleState, 0)
	}
	smb.customTransitions[from] = append(smb.customTransitions[from], to)
	return smb
}
// Build constructs the state machine for the given module, then applies
// every handler, hook, and custom transition registered on the builder.
func (smb *StateMachineBuilder) Build(module Module) *StateMachine {
	machine := NewStateMachine(module, smb.config)
	for state, handler := range smb.stateHandlers {
		machine.SetStateHandler(state, handler)
	}
	for name, hook := range smb.transitionHooks {
		machine.SetTransitionHook(name, hook)
	}
	for from, targets := range smb.customTransitions {
		for _, target := range targets {
			machine.AddCustomTransition(from, target)
		}
	}
	return machine
}