feat: create v2-prep branch with comprehensive planning
Restructured project for V2 refactor: **Structure Changes:** - Moved all V1 code to orig/ folder (preserved with git mv) - Created docs/planning/ directory - Added orig/README_V1.md explaining V1 preservation **Planning Documents:** - 00_V2_MASTER_PLAN.md: Complete architecture overview - Executive summary of critical V1 issues - High-level component architecture diagrams - 5-phase implementation roadmap - Success metrics and risk mitigation - 07_TASK_BREAKDOWN.md: Atomic task breakdown - 99+ hours of detailed tasks - Every task < 2 hours (atomic) - Clear dependencies and success criteria - Organized by implementation phase **V2 Key Improvements:** - Per-exchange parsers (factory pattern) - Multi-layer strict validation - Multi-index pool cache - Background validation pipeline - Comprehensive observability **Critical Issues Addressed:** - Zero address tokens (strict validation + cache enrichment) - Parsing accuracy (protocol-specific parsers) - No audit trail (background validation channel) - Inefficient lookups (multi-index cache) - Stats disconnection (event-driven metrics) Next Steps: 1. Review planning documents 2. Begin Phase 1: Foundation (P1-001 through P1-010) 3. Implement parsers in Phase 2 4. Build cache system in Phase 3 5. Add validation pipeline in Phase 4 6. Migrate and test in Phase 5 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
661
orig/pkg/lifecycle/dependency_injection.go
Normal file
661
orig/pkg/lifecycle/dependency_injection.go
Normal file
@@ -0,0 +1,661 @@
|
||||
package lifecycle
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// Container provides dependency injection functionality
|
||||
type Container struct {
|
||||
services map[reflect.Type]*ServiceDescriptor
|
||||
instances map[reflect.Type]interface{}
|
||||
namedServices map[string]*ServiceDescriptor
|
||||
namedInstances map[string]interface{}
|
||||
singletons map[reflect.Type]interface{}
|
||||
factories map[reflect.Type]FactoryFunc
|
||||
interceptors []Interceptor
|
||||
config ContainerConfig
|
||||
mu sync.RWMutex
|
||||
parent *Container
|
||||
scoped map[string]*Container
|
||||
}
|
||||
|
||||
// ServiceDescriptor describes how a service should be instantiated
|
||||
type ServiceDescriptor struct {
|
||||
ServiceType reflect.Type
|
||||
Implementation reflect.Type
|
||||
Lifetime ServiceLifetime
|
||||
Factory FactoryFunc
|
||||
Instance interface{}
|
||||
Name string
|
||||
Dependencies []reflect.Type
|
||||
Tags []string
|
||||
Metadata map[string]interface{}
|
||||
Interceptors []Interceptor
|
||||
}
|
||||
|
||||
// ServiceLifetime defines the lifetime of a service
|
||||
type ServiceLifetime string
|
||||
|
||||
const (
|
||||
Transient ServiceLifetime = "transient" // New instance every time
|
||||
Singleton ServiceLifetime = "singleton" // Single instance for container lifetime
|
||||
Scoped ServiceLifetime = "scoped" // Single instance per scope
|
||||
)
|
||||
|
||||
// FactoryFunc creates service instances
|
||||
type FactoryFunc func(container *Container) (interface{}, error)
|
||||
|
||||
// Interceptor can intercept service creation and method calls
|
||||
type Interceptor interface {
|
||||
Intercept(ctx context.Context, target interface{}, method string, args []interface{}) (interface{}, error)
|
||||
}
|
||||
|
||||
// ContainerConfig configures the container behavior
|
||||
type ContainerConfig struct {
|
||||
EnableReflection bool `json:"enable_reflection"`
|
||||
EnableCircularDetection bool `json:"enable_circular_detection"`
|
||||
EnableInterception bool `json:"enable_interception"`
|
||||
EnableValidation bool `json:"enable_validation"`
|
||||
MaxDepth int `json:"max_depth"`
|
||||
CacheInstances bool `json:"cache_instances"`
|
||||
}
|
||||
|
||||
// ServiceBuilder provides a fluent interface for service registration
|
||||
type ServiceBuilder struct {
|
||||
container *Container
|
||||
serviceType reflect.Type
|
||||
implType reflect.Type
|
||||
lifetime ServiceLifetime
|
||||
factory FactoryFunc
|
||||
instance interface{}
|
||||
name string
|
||||
tags []string
|
||||
metadata map[string]interface{}
|
||||
interceptors []Interceptor
|
||||
}
|
||||
|
||||
// NewContainer creates a new dependency injection container
|
||||
func NewContainer(config ContainerConfig) *Container {
|
||||
container := &Container{
|
||||
services: make(map[reflect.Type]*ServiceDescriptor),
|
||||
instances: make(map[reflect.Type]interface{}),
|
||||
namedServices: make(map[string]*ServiceDescriptor),
|
||||
namedInstances: make(map[string]interface{}),
|
||||
singletons: make(map[reflect.Type]interface{}),
|
||||
factories: make(map[reflect.Type]FactoryFunc),
|
||||
interceptors: make([]Interceptor, 0),
|
||||
config: config,
|
||||
scoped: make(map[string]*Container),
|
||||
}
|
||||
|
||||
// Set default configuration
|
||||
if container.config.MaxDepth == 0 {
|
||||
container.config.MaxDepth = 10
|
||||
}
|
||||
|
||||
return container
|
||||
}
|
||||
|
||||
// Register registers a service type with its implementation
|
||||
func (c *Container) Register(serviceType, implementationType interface{}) *ServiceBuilder {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
sType := reflect.TypeOf(serviceType)
|
||||
if sType.Kind() == reflect.Ptr {
|
||||
sType = sType.Elem()
|
||||
}
|
||||
if sType.Kind() == reflect.Interface {
|
||||
sType = reflect.TypeOf(serviceType).Elem()
|
||||
}
|
||||
|
||||
implType := reflect.TypeOf(implementationType)
|
||||
if implType.Kind() == reflect.Ptr {
|
||||
implType = implType.Elem()
|
||||
}
|
||||
|
||||
return &ServiceBuilder{
|
||||
container: c,
|
||||
serviceType: sType,
|
||||
implType: implType,
|
||||
lifetime: Transient,
|
||||
tags: make([]string, 0),
|
||||
metadata: make(map[string]interface{}),
|
||||
interceptors: make([]Interceptor, 0),
|
||||
}
|
||||
}
|
||||
|
||||
// RegisterInstance registers a specific instance
|
||||
func (c *Container) RegisterInstance(serviceType interface{}, instance interface{}) *ServiceBuilder {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
sType := reflect.TypeOf(serviceType)
|
||||
if sType.Kind() == reflect.Ptr {
|
||||
sType = sType.Elem()
|
||||
}
|
||||
if sType.Kind() == reflect.Interface {
|
||||
sType = reflect.TypeOf(serviceType).Elem()
|
||||
}
|
||||
|
||||
return &ServiceBuilder{
|
||||
container: c,
|
||||
serviceType: sType,
|
||||
instance: instance,
|
||||
lifetime: Singleton,
|
||||
tags: make([]string, 0),
|
||||
metadata: make(map[string]interface{}),
|
||||
interceptors: make([]Interceptor, 0),
|
||||
}
|
||||
}
|
||||
|
||||
// RegisterFactory registers a factory function for creating instances
|
||||
func (c *Container) RegisterFactory(serviceType interface{}, factory FactoryFunc) *ServiceBuilder {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
sType := reflect.TypeOf(serviceType)
|
||||
if sType.Kind() == reflect.Ptr {
|
||||
sType = sType.Elem()
|
||||
}
|
||||
if sType.Kind() == reflect.Interface {
|
||||
sType = reflect.TypeOf(serviceType).Elem()
|
||||
}
|
||||
|
||||
return &ServiceBuilder{
|
||||
container: c,
|
||||
serviceType: sType,
|
||||
factory: factory,
|
||||
lifetime: Transient,
|
||||
tags: make([]string, 0),
|
||||
metadata: make(map[string]interface{}),
|
||||
interceptors: make([]Interceptor, 0),
|
||||
}
|
||||
}
|
||||
|
||||
// Resolve resolves a service instance by type
|
||||
func (c *Container) Resolve(serviceType interface{}) (interface{}, error) {
|
||||
sType := reflect.TypeOf(serviceType)
|
||||
if sType.Kind() == reflect.Ptr {
|
||||
sType = sType.Elem()
|
||||
}
|
||||
if sType.Kind() == reflect.Interface {
|
||||
sType = reflect.TypeOf(serviceType).Elem()
|
||||
}
|
||||
|
||||
return c.resolveType(sType, make(map[reflect.Type]bool), 0)
|
||||
}
|
||||
|
||||
// ResolveNamed resolves a named service instance
|
||||
func (c *Container) ResolveNamed(name string) (interface{}, error) {
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
|
||||
// Check if instance already exists
|
||||
if instance, exists := c.namedInstances[name]; exists {
|
||||
return instance, nil
|
||||
}
|
||||
|
||||
// Get service descriptor
|
||||
descriptor, exists := c.namedServices[name]
|
||||
if !exists {
|
||||
return nil, fmt.Errorf("named service not found: %s", name)
|
||||
}
|
||||
|
||||
return c.createInstance(descriptor, make(map[reflect.Type]bool), 0)
|
||||
}
|
||||
|
||||
// ResolveAll resolves all services with a specific tag
|
||||
func (c *Container) ResolveAll(tag string) ([]interface{}, error) {
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
|
||||
var instances []interface{}
|
||||
|
||||
for _, descriptor := range c.services {
|
||||
for _, serviceTag := range descriptor.Tags {
|
||||
if serviceTag == tag {
|
||||
instance, err := c.createInstance(descriptor, make(map[reflect.Type]bool), 0)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to resolve service with tag %s: %w", tag, err)
|
||||
}
|
||||
instances = append(instances, instance)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return instances, nil
|
||||
}
|
||||
|
||||
// TryResolve attempts to resolve a service, returning nil if not found
|
||||
func (c *Container) TryResolve(serviceType interface{}) interface{} {
|
||||
instance, err := c.Resolve(serviceType)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
return instance
|
||||
}
|
||||
|
||||
// IsRegistered checks if a service type is registered
|
||||
func (c *Container) IsRegistered(serviceType interface{}) bool {
|
||||
sType := reflect.TypeOf(serviceType)
|
||||
if sType.Kind() == reflect.Ptr {
|
||||
sType = sType.Elem()
|
||||
}
|
||||
if sType.Kind() == reflect.Interface {
|
||||
sType = reflect.TypeOf(serviceType).Elem()
|
||||
}
|
||||
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
|
||||
_, exists := c.services[sType]
|
||||
return exists
|
||||
}
|
||||
|
||||
// CreateScope creates a new scoped container
|
||||
func (c *Container) CreateScope(name string) *Container {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
scope := &Container{
|
||||
services: make(map[reflect.Type]*ServiceDescriptor),
|
||||
instances: make(map[reflect.Type]interface{}),
|
||||
namedServices: make(map[string]*ServiceDescriptor),
|
||||
namedInstances: make(map[string]interface{}),
|
||||
singletons: make(map[reflect.Type]interface{}),
|
||||
factories: make(map[reflect.Type]FactoryFunc),
|
||||
interceptors: make([]Interceptor, 0),
|
||||
config: c.config,
|
||||
parent: c,
|
||||
scoped: make(map[string]*Container),
|
||||
}
|
||||
|
||||
c.scoped[name] = scope
|
||||
return scope
|
||||
}
|
||||
|
||||
// GetScope retrieves a named scope
|
||||
func (c *Container) GetScope(name string) (*Container, bool) {
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
|
||||
scope, exists := c.scoped[name]
|
||||
return scope, exists
|
||||
}
|
||||
|
||||
// AddInterceptor adds a global interceptor
|
||||
func (c *Container) AddInterceptor(interceptor Interceptor) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
c.interceptors = append(c.interceptors, interceptor)
|
||||
}
|
||||
|
||||
// GetRegistrations returns all service registrations
|
||||
func (c *Container) GetRegistrations() map[reflect.Type]*ServiceDescriptor {
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
|
||||
registrations := make(map[reflect.Type]*ServiceDescriptor)
|
||||
for t, desc := range c.services {
|
||||
registrations[t] = desc
|
||||
}
|
||||
return registrations
|
||||
}
|
||||
|
||||
// Validate validates all service registrations
|
||||
func (c *Container) Validate() error {
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
|
||||
if !c.config.EnableValidation {
|
||||
return nil
|
||||
}
|
||||
|
||||
for serviceType, descriptor := range c.services {
|
||||
if err := c.validateDescriptor(serviceType, descriptor); err != nil {
|
||||
return fmt.Errorf("validation failed for service %s: %w", serviceType.String(), err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Dispose cleans up the container and all instances
|
||||
func (c *Container) Dispose() error {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
// Dispose all scoped containers
|
||||
for _, scope := range c.scoped {
|
||||
scope.Dispose()
|
||||
}
|
||||
|
||||
// Clear all maps
|
||||
c.services = make(map[reflect.Type]*ServiceDescriptor)
|
||||
c.instances = make(map[reflect.Type]interface{})
|
||||
c.namedServices = make(map[string]*ServiceDescriptor)
|
||||
c.namedInstances = make(map[string]interface{})
|
||||
c.singletons = make(map[reflect.Type]interface{})
|
||||
c.factories = make(map[reflect.Type]FactoryFunc)
|
||||
c.scoped = make(map[string]*Container)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Private methods
|
||||
|
||||
func (c *Container) resolveType(serviceType reflect.Type, resolving map[reflect.Type]bool, depth int) (interface{}, error) {
|
||||
if depth > c.config.MaxDepth {
|
||||
return nil, fmt.Errorf("maximum resolution depth exceeded for type %s", serviceType.String())
|
||||
}
|
||||
|
||||
// Check for circular dependencies
|
||||
if c.config.EnableCircularDetection && resolving[serviceType] {
|
||||
return nil, fmt.Errorf("circular dependency detected for type %s", serviceType.String())
|
||||
}
|
||||
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
|
||||
// Check if singleton instance exists
|
||||
if instance, exists := c.singletons[serviceType]; exists {
|
||||
return instance, nil
|
||||
}
|
||||
|
||||
// Check if cached instance exists
|
||||
if c.config.CacheInstances {
|
||||
if instance, exists := c.instances[serviceType]; exists {
|
||||
return instance, nil
|
||||
}
|
||||
}
|
||||
|
||||
// Get service descriptor
|
||||
descriptor, exists := c.services[serviceType]
|
||||
if !exists {
|
||||
// Try parent container
|
||||
if c.parent != nil {
|
||||
return c.parent.resolveType(serviceType, resolving, depth+1)
|
||||
}
|
||||
return nil, fmt.Errorf("service not registered: %s", serviceType.String())
|
||||
}
|
||||
|
||||
resolving[serviceType] = true
|
||||
defer delete(resolving, serviceType)
|
||||
|
||||
return c.createInstance(descriptor, resolving, depth+1)
|
||||
}
|
||||
|
||||
func (c *Container) createInstance(descriptor *ServiceDescriptor, resolving map[reflect.Type]bool, depth int) (interface{}, error) {
|
||||
// Use existing instance if available
|
||||
if descriptor.Instance != nil {
|
||||
return descriptor.Instance, nil
|
||||
}
|
||||
|
||||
// Use factory if available
|
||||
if descriptor.Factory != nil {
|
||||
instance, err := descriptor.Factory(c)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("factory failed for %s: %w", descriptor.ServiceType.String(), err)
|
||||
}
|
||||
|
||||
if descriptor.Lifetime == Singleton {
|
||||
c.singletons[descriptor.ServiceType] = instance
|
||||
}
|
||||
|
||||
return c.applyInterceptors(instance, descriptor)
|
||||
}
|
||||
|
||||
// Create instance using reflection
|
||||
if descriptor.Implementation == nil {
|
||||
return nil, fmt.Errorf("no implementation or factory provided for %s", descriptor.ServiceType.String())
|
||||
}
|
||||
|
||||
instance, err := c.createInstanceByReflection(descriptor, resolving, depth)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Store singleton
|
||||
if descriptor.Lifetime == Singleton {
|
||||
c.singletons[descriptor.ServiceType] = instance
|
||||
}
|
||||
|
||||
return c.applyInterceptors(instance, descriptor)
|
||||
}
|
||||
|
||||
func (c *Container) createInstanceByReflection(descriptor *ServiceDescriptor, resolving map[reflect.Type]bool, depth int) (interface{}, error) {
|
||||
if !c.config.EnableReflection {
|
||||
return nil, fmt.Errorf("reflection is disabled")
|
||||
}
|
||||
|
||||
implType := descriptor.Implementation
|
||||
if implType.Kind() == reflect.Ptr {
|
||||
implType = implType.Elem()
|
||||
}
|
||||
|
||||
// Find constructor (assumes first constructor or struct creation)
|
||||
var constructorFunc reflect.Value
|
||||
|
||||
// Look for constructor function
|
||||
_ = "New" + implType.Name() // constructorName not used yet
|
||||
if implType.PkgPath() != "" {
|
||||
// Try to find package-level constructor
|
||||
// This is simplified - in a real implementation you'd use build tags or reflection
|
||||
// to find the actual constructor functions
|
||||
}
|
||||
|
||||
// Create instance
|
||||
if constructorFunc.IsValid() {
|
||||
// Use constructor function
|
||||
return c.callConstructor(constructorFunc, resolving, depth)
|
||||
} else {
|
||||
// Create struct directly
|
||||
return c.createStruct(implType, resolving, depth)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Container) createStruct(structType reflect.Type, resolving map[reflect.Type]bool, depth int) (interface{}, error) {
|
||||
// Create new instance
|
||||
instance := reflect.New(structType)
|
||||
elem := instance.Elem()
|
||||
|
||||
// Inject dependencies into fields
|
||||
for i := 0; i < elem.NumField(); i++ {
|
||||
field := elem.Field(i)
|
||||
fieldType := elem.Type().Field(i)
|
||||
|
||||
// Check for dependency injection tags
|
||||
if tag := fieldType.Tag.Get("inject"); tag != "" {
|
||||
if !field.CanSet() {
|
||||
continue
|
||||
}
|
||||
|
||||
var dependency interface{}
|
||||
var err error
|
||||
|
||||
if tag == "true" || tag == "" {
|
||||
// Inject by type
|
||||
dependency, err = c.resolveType(field.Type(), resolving, depth)
|
||||
} else {
|
||||
// Inject by name
|
||||
dependency, err = c.ResolveNamed(tag)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
// Check if injection is optional
|
||||
if optionalTag := fieldType.Tag.Get("optional"); optionalTag == "true" {
|
||||
continue
|
||||
}
|
||||
return nil, fmt.Errorf("failed to inject dependency for field %s: %w", fieldType.Name, err)
|
||||
}
|
||||
|
||||
field.Set(reflect.ValueOf(dependency))
|
||||
}
|
||||
}
|
||||
|
||||
return instance.Interface(), nil
|
||||
}
|
||||
|
||||
func (c *Container) callConstructor(constructor reflect.Value, resolving map[reflect.Type]bool, depth int) (interface{}, error) {
|
||||
constructorType := constructor.Type()
|
||||
args := make([]reflect.Value, constructorType.NumIn())
|
||||
|
||||
// Resolve constructor arguments
|
||||
for i := 0; i < constructorType.NumIn(); i++ {
|
||||
argType := constructorType.In(i)
|
||||
arg, err := c.resolveType(argType, resolving, depth)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to resolve constructor argument %d (%s): %w", i, argType.String(), err)
|
||||
}
|
||||
args[i] = reflect.ValueOf(arg)
|
||||
}
|
||||
|
||||
// Call constructor
|
||||
results := constructor.Call(args)
|
||||
if len(results) == 0 {
|
||||
return nil, fmt.Errorf("constructor returned no values")
|
||||
}
|
||||
|
||||
instance := results[0].Interface()
|
||||
|
||||
// Check for error result
|
||||
if len(results) > 1 && !results[1].IsNil() {
|
||||
if err, ok := results[1].Interface().(error); ok {
|
||||
return nil, fmt.Errorf("constructor error: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return instance, nil
|
||||
}
|
||||
|
||||
func (c *Container) applyInterceptors(instance interface{}, descriptor *ServiceDescriptor) (interface{}, error) {
|
||||
if !c.config.EnableInterception {
|
||||
return instance, nil
|
||||
}
|
||||
|
||||
// Apply service-specific interceptors
|
||||
for _, interceptor := range descriptor.Interceptors {
|
||||
// Apply interceptor (simplified - real implementation would create proxies)
|
||||
_ = interceptor
|
||||
}
|
||||
|
||||
// Apply global interceptors
|
||||
for _, interceptor := range c.interceptors {
|
||||
// Apply interceptor (simplified - real implementation would create proxies)
|
||||
_ = interceptor
|
||||
}
|
||||
|
||||
return instance, nil
|
||||
}
|
||||
|
||||
func (c *Container) validateDescriptor(serviceType reflect.Type, descriptor *ServiceDescriptor) error {
|
||||
// Validate that implementation implements the service interface
|
||||
if descriptor.Implementation != nil {
|
||||
if serviceType.Kind() == reflect.Interface {
|
||||
if !descriptor.Implementation.Implements(serviceType) {
|
||||
return fmt.Errorf("implementation %s does not implement interface %s",
|
||||
descriptor.Implementation.String(), serviceType.String())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Validate dependencies
|
||||
for _, depType := range descriptor.Dependencies {
|
||||
if !c.IsRegistered(depType) && (c.parent == nil || !c.parent.IsRegistered(depType)) {
|
||||
return fmt.Errorf("dependency %s is not registered", depType.String())
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ServiceBuilder methods
|
||||
|
||||
// AsSingleton sets the service lifetime to singleton
|
||||
func (sb *ServiceBuilder) AsSingleton() *ServiceBuilder {
|
||||
sb.lifetime = Singleton
|
||||
return sb
|
||||
}
|
||||
|
||||
// AsTransient sets the service lifetime to transient
|
||||
func (sb *ServiceBuilder) AsTransient() *ServiceBuilder {
|
||||
sb.lifetime = Transient
|
||||
return sb
|
||||
}
|
||||
|
||||
// AsScoped sets the service lifetime to scoped
|
||||
func (sb *ServiceBuilder) AsScoped() *ServiceBuilder {
|
||||
sb.lifetime = Scoped
|
||||
return sb
|
||||
}
|
||||
|
||||
// WithName sets a name for the service
|
||||
func (sb *ServiceBuilder) WithName(name string) *ServiceBuilder {
|
||||
sb.name = name
|
||||
return sb
|
||||
}
|
||||
|
||||
// WithTag adds a tag to the service
|
||||
func (sb *ServiceBuilder) WithTag(tag string) *ServiceBuilder {
|
||||
sb.tags = append(sb.tags, tag)
|
||||
return sb
|
||||
}
|
||||
|
||||
// WithMetadata adds metadata to the service
|
||||
func (sb *ServiceBuilder) WithMetadata(key string, value interface{}) *ServiceBuilder {
|
||||
sb.metadata[key] = value
|
||||
return sb
|
||||
}
|
||||
|
||||
// WithInterceptor adds an interceptor to the service
|
||||
func (sb *ServiceBuilder) WithInterceptor(interceptor Interceptor) *ServiceBuilder {
|
||||
sb.interceptors = append(sb.interceptors, interceptor)
|
||||
return sb
|
||||
}
|
||||
|
||||
// Build finalizes the service registration
|
||||
func (sb *ServiceBuilder) Build() error {
|
||||
sb.container.mu.Lock()
|
||||
defer sb.container.mu.Unlock()
|
||||
|
||||
descriptor := &ServiceDescriptor{
|
||||
ServiceType: sb.serviceType,
|
||||
Implementation: sb.implType,
|
||||
Lifetime: sb.lifetime,
|
||||
Factory: sb.factory,
|
||||
Instance: sb.instance,
|
||||
Name: sb.name,
|
||||
Tags: sb.tags,
|
||||
Metadata: sb.metadata,
|
||||
Interceptors: sb.interceptors,
|
||||
}
|
||||
|
||||
// Store by type
|
||||
sb.container.services[sb.serviceType] = descriptor
|
||||
|
||||
// Store by name if provided
|
||||
if sb.name != "" {
|
||||
sb.container.namedServices[sb.name] = descriptor
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// DefaultInterceptor provides basic interception functionality
|
||||
type DefaultInterceptor struct {
|
||||
name string
|
||||
}
|
||||
|
||||
func NewDefaultInterceptor(name string) *DefaultInterceptor {
|
||||
return &DefaultInterceptor{name: name}
|
||||
}
|
||||
|
||||
func (di *DefaultInterceptor) Intercept(ctx context.Context, target interface{}, method string, args []interface{}) (interface{}, error) {
|
||||
// Basic interception - could add logging, metrics, etc.
|
||||
return nil, nil
|
||||
}
|
||||
109
orig/pkg/lifecycle/error_enrichment.go
Normal file
109
orig/pkg/lifecycle/error_enrichment.go
Normal file
@@ -0,0 +1,109 @@
|
||||
package lifecycle
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"regexp"
|
||||
"strings"
|
||||
)
|
||||
|
||||
var txHashPattern = regexp.MustCompile(`0x[a-fA-F0-9]{64}`)
|
||||
|
||||
type RecordedError struct {
|
||||
Err error
|
||||
TxHash string
|
||||
}
|
||||
|
||||
func (re RecordedError) Error() string {
|
||||
if re.Err == nil {
|
||||
return ""
|
||||
}
|
||||
return re.Err.Error()
|
||||
}
|
||||
|
||||
func enrichErrorWithTxHash(message string, err error, attrs []interface{}) (error, string, []interface{}) {
|
||||
txHash, attrsWithTx := ensureTxHash(attrs, err)
|
||||
wrapped := fmt.Errorf("%s: %w", message, err)
|
||||
if txHash != "" {
|
||||
wrapped = fmt.Errorf("%s [tx_hash=%s]: %w", message, txHash, err)
|
||||
}
|
||||
return wrapped, txHash, attrsWithTx
|
||||
}
|
||||
|
||||
func ensureTxHash(attrs []interface{}, err error) (string, []interface{}) {
|
||||
txHash := extractTxHashFromAttrs(attrs)
|
||||
if txHash == "" {
|
||||
txHash = extractTxHashFromError(err)
|
||||
}
|
||||
|
||||
if txHash == "" {
|
||||
return "", attrs
|
||||
}
|
||||
|
||||
hasTxAttr := false
|
||||
for i := 0; i+1 < len(attrs); i += 2 {
|
||||
key, ok := attrs[i].(string)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
if key == "tx_hash" || key == "transaction_hash" || key == "tx" {
|
||||
hasTxAttr = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !hasTxAttr {
|
||||
attrs = append(attrs, "tx_hash", txHash)
|
||||
}
|
||||
|
||||
return txHash, attrs
|
||||
}
|
||||
|
||||
func extractTxHashFromAttrs(attrs []interface{}) string {
|
||||
for i := 0; i+1 < len(attrs); i += 2 {
|
||||
key, ok := attrs[i].(string)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
if key == "tx_hash" || key == "transaction_hash" || key == "tx" {
|
||||
if value, ok := attrs[i+1].(string); ok && isValidTxHash(value) {
|
||||
return strings.ToLower(value)
|
||||
}
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func extractTxHashFromError(err error) string {
|
||||
for err != nil {
|
||||
if match := txHashPattern.FindString(err.Error()); match != "" {
|
||||
return strings.ToLower(match)
|
||||
}
|
||||
err = errors.Unwrap(err)
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func isValidTxHash(value string) bool {
|
||||
if value == "" {
|
||||
return false
|
||||
}
|
||||
if len(value) != 66 {
|
||||
return false
|
||||
}
|
||||
if !strings.HasPrefix(value, "0x") {
|
||||
return false
|
||||
}
|
||||
for _, r := range value[2:] {
|
||||
if !isHexChar(r) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func isHexChar(r rune) bool {
|
||||
return (r >= '0' && r <= '9') ||
|
||||
(r >= 'a' && r <= 'f') ||
|
||||
(r >= 'A' && r <= 'F')
|
||||
}
|
||||
1011
orig/pkg/lifecycle/health_monitor.go
Normal file
1011
orig/pkg/lifecycle/health_monitor.go
Normal file
File diff suppressed because it is too large
Load Diff
106
orig/pkg/lifecycle/health_monitor_test.go
Normal file
106
orig/pkg/lifecycle/health_monitor_test.go
Normal file
@@ -0,0 +1,106 @@
|
||||
package lifecycle
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
type stubHealthNotifier struct {
|
||||
failUntil int
|
||||
attempts int
|
||||
txHash string
|
||||
}
|
||||
|
||||
func (s *stubHealthNotifier) NotifyHealthChange(moduleID string, oldHealth, newHealth ModuleHealth) error {
|
||||
s.attempts++
|
||||
if s.attempts <= s.failUntil {
|
||||
return fmt.Errorf("notify failure %d for tx %s", s.attempts, s.txHash)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *stubHealthNotifier) NotifySystemHealth(health OverallHealth) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *stubHealthNotifier) NotifyAlert(alert HealthAlert) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestHealthMonitorNotifyWithRetrySuccess(t *testing.T) {
|
||||
config := HealthMonitorConfig{
|
||||
NotificationRetries: 3,
|
||||
NotificationRetryDelay: time.Nanosecond,
|
||||
}
|
||||
hm := NewHealthMonitor(config)
|
||||
hm.logger = nil
|
||||
|
||||
notifier := &stubHealthNotifier{failUntil: 2, txHash: "0xdddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddd"}
|
||||
hm.notifier = notifier
|
||||
|
||||
err := hm.notifyWithRetry(func() error {
|
||||
return notifier.NotifyHealthChange("module", ModuleHealth{}, ModuleHealth{})
|
||||
}, "notify failure", "module_id", "module")
|
||||
if err != nil {
|
||||
t.Fatalf("expected notification to eventually succeed, got %v", err)
|
||||
}
|
||||
if notifier.attempts != 3 {
|
||||
t.Fatalf("expected 3 attempts, got %d", notifier.attempts)
|
||||
}
|
||||
if errs := hm.NotificationErrors(); len(errs) != 0 {
|
||||
t.Fatalf("expected no recorded notification errors, got %d", len(errs))
|
||||
}
|
||||
if hm.aggregatedNotificationError() != nil {
|
||||
t.Fatal("expected aggregated notification error to be nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHealthMonitorNotifyWithRetryFailure(t *testing.T) {
|
||||
config := HealthMonitorConfig{
|
||||
NotificationRetries: 2,
|
||||
NotificationRetryDelay: time.Nanosecond,
|
||||
}
|
||||
hm := NewHealthMonitor(config)
|
||||
hm.logger = nil
|
||||
|
||||
txHash := "0xeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee"
|
||||
notifier := &stubHealthNotifier{failUntil: 3, txHash: txHash}
|
||||
hm.notifier = notifier
|
||||
|
||||
err := hm.notifyWithRetry(func() error {
|
||||
return notifier.NotifyHealthChange("module", ModuleHealth{}, ModuleHealth{})
|
||||
}, "notify failure", "module_id", "module")
|
||||
if err == nil {
|
||||
t.Fatal("expected notification to fail after retries")
|
||||
}
|
||||
if notifier.attempts != 2 {
|
||||
t.Fatalf("expected 2 attempts, got %d", notifier.attempts)
|
||||
}
|
||||
exported := hm.NotificationErrors()
|
||||
if len(exported) != 1 {
|
||||
t.Fatalf("expected 1 recorded notification error, got %d", len(exported))
|
||||
}
|
||||
copyErrs := hm.NotificationErrors()
|
||||
if copyErrs[0] == nil {
|
||||
t.Fatal("expected copy of notification errors to retain value")
|
||||
}
|
||||
if got := exported[0].Error(); !strings.Contains(got, txHash) {
|
||||
t.Fatalf("recorded notification error should include tx hash, got %q", got)
|
||||
}
|
||||
details := hm.NotificationErrorDetails()
|
||||
if len(details) != 1 {
|
||||
t.Fatalf("expected notification error details entry, got %d", len(details))
|
||||
}
|
||||
if details[0].TxHash != txHash {
|
||||
t.Fatalf("expected notification error detail to track tx hash %s, got %s", txHash, details[0].TxHash)
|
||||
}
|
||||
agg := hm.aggregatedNotificationError()
|
||||
if agg == nil {
|
||||
t.Fatal("expected aggregated notification error to be returned")
|
||||
}
|
||||
if got := agg.Error(); !strings.Contains(got, "notify failure") || !strings.Contains(got, "notify failure 1") || !strings.Contains(got, txHash) {
|
||||
t.Fatalf("aggregated notification error should include failure details and tx hash, got %q", got)
|
||||
}
|
||||
}
|
||||
417
orig/pkg/lifecycle/interfaces.go
Normal file
417
orig/pkg/lifecycle/interfaces.go
Normal file
@@ -0,0 +1,417 @@
|
||||
package lifecycle
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// BaseModule provides a default implementation of the Module interface
|
||||
type BaseModule struct {
|
||||
id string
|
||||
name string
|
||||
version string
|
||||
dependencies []string
|
||||
state ModuleState
|
||||
health ModuleHealth
|
||||
metrics ModuleMetrics
|
||||
config ModuleConfig
|
||||
}
|
||||
|
||||
// NewBaseModule creates a new base module
|
||||
func NewBaseModule(id, name, version string, dependencies []string) *BaseModule {
|
||||
return &BaseModule{
|
||||
id: id,
|
||||
name: name,
|
||||
version: version,
|
||||
dependencies: dependencies,
|
||||
state: StateUninitialized,
|
||||
health: ModuleHealth{
|
||||
Status: HealthUnknown,
|
||||
},
|
||||
metrics: ModuleMetrics{
|
||||
CustomMetrics: make(map[string]interface{}),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// Core lifecycle methods
|
||||
|
||||
func (bm *BaseModule) Initialize(ctx context.Context, config ModuleConfig) error {
|
||||
bm.config = config
|
||||
bm.state = StateInitialized
|
||||
bm.health.Status = HealthHealthy
|
||||
return nil
|
||||
}
|
||||
|
||||
func (bm *BaseModule) Start(ctx context.Context) error {
|
||||
if bm.state != StateInitialized && bm.state != StateStopped {
|
||||
return fmt.Errorf("invalid state for start: %s", bm.state)
|
||||
}
|
||||
|
||||
startTime := time.Now()
|
||||
bm.state = StateRunning
|
||||
bm.metrics.StartupTime = time.Since(startTime)
|
||||
bm.metrics.LastActivity = time.Now()
|
||||
bm.health.Status = HealthHealthy
|
||||
return nil
|
||||
}
|
||||
|
||||
func (bm *BaseModule) Stop(ctx context.Context) error {
|
||||
if bm.state != StateRunning && bm.state != StatePaused {
|
||||
return fmt.Errorf("invalid state for stop: %s", bm.state)
|
||||
}
|
||||
|
||||
stopTime := time.Now()
|
||||
bm.state = StateStopped
|
||||
bm.metrics.ShutdownTime = time.Since(stopTime)
|
||||
bm.health.Status = HealthUnknown
|
||||
return nil
|
||||
}
|
||||
|
||||
func (bm *BaseModule) Pause(ctx context.Context) error {
|
||||
if bm.state != StateRunning {
|
||||
return fmt.Errorf("invalid state for pause: %s", bm.state)
|
||||
}
|
||||
|
||||
bm.state = StatePaused
|
||||
return nil
|
||||
}
|
||||
|
||||
func (bm *BaseModule) Resume(ctx context.Context) error {
|
||||
if bm.state != StatePaused {
|
||||
return fmt.Errorf("invalid state for resume: %s", bm.state)
|
||||
}
|
||||
|
||||
bm.state = StateRunning
|
||||
bm.metrics.LastActivity = time.Now()
|
||||
return nil
|
||||
}
|
||||
|
||||
// Module information
|
||||
|
||||
func (bm *BaseModule) GetID() string {
|
||||
return bm.id
|
||||
}
|
||||
|
||||
func (bm *BaseModule) GetName() string {
|
||||
return bm.name
|
||||
}
|
||||
|
||||
func (bm *BaseModule) GetVersion() string {
|
||||
return bm.version
|
||||
}
|
||||
|
||||
func (bm *BaseModule) GetDependencies() []string {
|
||||
return bm.dependencies
|
||||
}
|
||||
|
||||
// Health and status
|
||||
|
||||
func (bm *BaseModule) GetHealth() ModuleHealth {
|
||||
bm.health.LastCheck = time.Now()
|
||||
return bm.health
|
||||
}
|
||||
|
||||
func (bm *BaseModule) GetState() ModuleState {
|
||||
return bm.state
|
||||
}
|
||||
|
||||
func (bm *BaseModule) GetMetrics() ModuleMetrics {
|
||||
return bm.metrics
|
||||
}
|
||||
|
||||
// Protected methods for subclasses
|
||||
|
||||
func (bm *BaseModule) SetHealth(status HealthStatus, message string) {
|
||||
bm.health.Status = status
|
||||
bm.health.Message = message
|
||||
bm.health.LastCheck = time.Now()
|
||||
}
|
||||
|
||||
func (bm *BaseModule) SetState(state ModuleState) {
|
||||
bm.state = state
|
||||
}
|
||||
|
||||
func (bm *BaseModule) UpdateMetrics(updates map[string]interface{}) {
|
||||
for key, value := range updates {
|
||||
bm.metrics.CustomMetrics[key] = value
|
||||
}
|
||||
bm.metrics.LastActivity = time.Now()
|
||||
}
|
||||
|
||||
func (bm *BaseModule) IncrementMetric(name string, value int64) {
|
||||
if current, exists := bm.metrics.CustomMetrics[name]; exists {
|
||||
if currentVal, ok := current.(int64); ok {
|
||||
bm.metrics.CustomMetrics[name] = currentVal + value
|
||||
} else {
|
||||
bm.metrics.CustomMetrics[name] = value
|
||||
}
|
||||
} else {
|
||||
bm.metrics.CustomMetrics[name] = value
|
||||
}
|
||||
}
|
||||
|
||||
// SimpleEventBus provides a basic implementation of EventBus
|
||||
type SimpleEventBus struct {
|
||||
handlers map[EventType][]EventHandler
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
func NewSimpleEventBus() *SimpleEventBus {
|
||||
return &SimpleEventBus{
|
||||
handlers: make(map[EventType][]EventHandler),
|
||||
}
|
||||
}
|
||||
|
||||
func (seb *SimpleEventBus) Publish(event ModuleEvent) error {
|
||||
seb.mu.RLock()
|
||||
defer seb.mu.RUnlock()
|
||||
|
||||
handlers, exists := seb.handlers[event.Type]
|
||||
if !exists {
|
||||
return nil
|
||||
}
|
||||
|
||||
for _, handler := range handlers {
|
||||
go func(h EventHandler) {
|
||||
if err := h(event); err != nil {
|
||||
// Log error but don't fail the publish
|
||||
}
|
||||
}(handler)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (seb *SimpleEventBus) Subscribe(eventType EventType, handler EventHandler) error {
|
||||
seb.mu.Lock()
|
||||
defer seb.mu.Unlock()
|
||||
|
||||
if _, exists := seb.handlers[eventType]; !exists {
|
||||
seb.handlers[eventType] = make([]EventHandler, 0)
|
||||
}
|
||||
|
||||
seb.handlers[eventType] = append(seb.handlers[eventType], handler)
|
||||
return nil
|
||||
}
|
||||
|
||||
// LifecycleManager coordinates all lifecycle components
|
||||
type LifecycleManager struct {
|
||||
registry *ModuleRegistry
|
||||
healthMonitor *HealthMonitorImpl
|
||||
shutdownManager *ShutdownManager
|
||||
container *Container
|
||||
eventBus *SimpleEventBus
|
||||
config LifecycleConfig
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
// LifecycleConfig configures the lifecycle manager
|
||||
type LifecycleConfig struct {
|
||||
RegistryConfig RegistryConfig `json:"registry_config"`
|
||||
HealthMonitorConfig HealthMonitorConfig `json:"health_monitor_config"`
|
||||
ShutdownConfig ShutdownConfig `json:"shutdown_config"`
|
||||
ContainerConfig ContainerConfig `json:"container_config"`
|
||||
EnableEventBus bool `json:"enable_event_bus"`
|
||||
EnableHealthMonitor bool `json:"enable_health_monitor"`
|
||||
EnableShutdownManager bool `json:"enable_shutdown_manager"`
|
||||
EnableDependencyInjection bool `json:"enable_dependency_injection"`
|
||||
}
|
||||
|
||||
// NewLifecycleManager creates a new lifecycle manager
|
||||
func NewLifecycleManager(config LifecycleConfig) *LifecycleManager {
|
||||
lm := &LifecycleManager{
|
||||
config: config,
|
||||
}
|
||||
|
||||
// Create event bus
|
||||
if config.EnableEventBus {
|
||||
lm.eventBus = NewSimpleEventBus()
|
||||
}
|
||||
|
||||
// Create dependency injection container
|
||||
if config.EnableDependencyInjection {
|
||||
lm.container = NewContainer(config.ContainerConfig)
|
||||
}
|
||||
|
||||
// Create module registry
|
||||
lm.registry = NewModuleRegistry(config.RegistryConfig)
|
||||
if lm.eventBus != nil {
|
||||
lm.registry.SetEventBus(lm.eventBus)
|
||||
}
|
||||
|
||||
// Create health monitor
|
||||
if config.EnableHealthMonitor {
|
||||
lm.healthMonitor = NewHealthMonitor(config.HealthMonitorConfig)
|
||||
lm.registry.SetHealthMonitor(lm.healthMonitor)
|
||||
}
|
||||
|
||||
// Create shutdown manager
|
||||
if config.EnableShutdownManager {
|
||||
lm.shutdownManager = NewShutdownManager(lm.registry, config.ShutdownConfig)
|
||||
}
|
||||
|
||||
return lm
|
||||
}
|
||||
|
||||
// Initialize initializes the lifecycle manager
|
||||
func (lm *LifecycleManager) Initialize(ctx context.Context) error {
|
||||
lm.mu.Lock()
|
||||
defer lm.mu.Unlock()
|
||||
|
||||
// Validate container if enabled
|
||||
if lm.container != nil {
|
||||
if err := lm.container.Validate(); err != nil {
|
||||
return fmt.Errorf("container validation failed: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Initialize registry
|
||||
if err := lm.registry.Initialize(ctx); err != nil {
|
||||
return fmt.Errorf("registry initialization failed: %w", err)
|
||||
}
|
||||
|
||||
// Start health monitor
|
||||
if lm.healthMonitor != nil {
|
||||
if err := lm.healthMonitor.Start(); err != nil {
|
||||
return fmt.Errorf("health monitor start failed: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Start shutdown manager
|
||||
if lm.shutdownManager != nil {
|
||||
if err := lm.shutdownManager.Start(); err != nil {
|
||||
return fmt.Errorf("shutdown manager start failed: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Start starts all modules
|
||||
func (lm *LifecycleManager) Start(ctx context.Context) error {
|
||||
return lm.registry.StartAll(ctx)
|
||||
}
|
||||
|
||||
// Stop stops all modules
|
||||
func (lm *LifecycleManager) Stop(ctx context.Context) error {
|
||||
return lm.registry.StopAll(ctx)
|
||||
}
|
||||
|
||||
// Shutdown gracefully shuts down the entire system
|
||||
func (lm *LifecycleManager) Shutdown(ctx context.Context) error {
|
||||
if lm.shutdownManager != nil {
|
||||
return lm.shutdownManager.Shutdown(ctx)
|
||||
}
|
||||
return lm.registry.Shutdown(ctx)
|
||||
}
|
||||
|
||||
// RegisterModule registers a new module
|
||||
func (lm *LifecycleManager) RegisterModule(module Module, config ModuleConfig) error {
|
||||
return lm.registry.Register(module, config)
|
||||
}
|
||||
|
||||
// GetModule retrieves a module by ID
|
||||
func (lm *LifecycleManager) GetModule(moduleID string) (Module, error) {
|
||||
return lm.registry.Get(moduleID)
|
||||
}
|
||||
|
||||
// GetRegistry returns the module registry
|
||||
func (lm *LifecycleManager) GetRegistry() *ModuleRegistry {
|
||||
lm.mu.RLock()
|
||||
defer lm.mu.RUnlock()
|
||||
return lm.registry
|
||||
}
|
||||
|
||||
// GetHealthMonitor returns the health monitor
|
||||
func (lm *LifecycleManager) GetHealthMonitor() *HealthMonitorImpl {
|
||||
lm.mu.RLock()
|
||||
defer lm.mu.RUnlock()
|
||||
return lm.healthMonitor
|
||||
}
|
||||
|
||||
// GetShutdownManager returns the shutdown manager
|
||||
func (lm *LifecycleManager) GetShutdownManager() *ShutdownManager {
|
||||
lm.mu.RLock()
|
||||
defer lm.mu.RUnlock()
|
||||
return lm.shutdownManager
|
||||
}
|
||||
|
||||
// GetContainer returns the dependency injection container
|
||||
func (lm *LifecycleManager) GetContainer() *Container {
|
||||
lm.mu.RLock()
|
||||
defer lm.mu.RUnlock()
|
||||
return lm.container
|
||||
}
|
||||
|
||||
// GetEventBus returns the event bus
|
||||
func (lm *LifecycleManager) GetEventBus() *SimpleEventBus {
|
||||
lm.mu.RLock()
|
||||
defer lm.mu.RUnlock()
|
||||
return lm.eventBus
|
||||
}
|
||||
|
||||
// GetOverallHealth returns the overall system health
|
||||
func (lm *LifecycleManager) GetOverallHealth() (OverallHealth, error) {
|
||||
if lm.healthMonitor == nil {
|
||||
return OverallHealth{}, fmt.Errorf("health monitor not enabled")
|
||||
}
|
||||
return lm.healthMonitor.GetOverallHealth(), nil
|
||||
}
|
||||
|
||||
// CreateDefaultConfig creates a default lifecycle configuration
|
||||
func CreateDefaultConfig() LifecycleConfig {
|
||||
return LifecycleConfig{
|
||||
RegistryConfig: RegistryConfig{
|
||||
StartTimeout: 30 * time.Second,
|
||||
StopTimeout: 15 * time.Second,
|
||||
HealthCheckInterval: 30 * time.Second,
|
||||
EnableMetrics: true,
|
||||
EnableHealthMonitor: true,
|
||||
ParallelStartup: false,
|
||||
ParallelShutdown: true,
|
||||
FailureRecovery: true,
|
||||
AutoRestart: true,
|
||||
MaxRestartAttempts: 3,
|
||||
},
|
||||
HealthMonitorConfig: HealthMonitorConfig{
|
||||
CheckInterval: 30 * time.Second,
|
||||
CheckTimeout: 10 * time.Second,
|
||||
HistorySize: 100,
|
||||
FailureThreshold: 3,
|
||||
RecoveryThreshold: 3,
|
||||
EnableNotifications: true,
|
||||
EnableMetrics: true,
|
||||
EnableTrends: true,
|
||||
ParallelChecks: true,
|
||||
MaxConcurrentChecks: 10,
|
||||
},
|
||||
ShutdownConfig: ShutdownConfig{
|
||||
GracefulTimeout: 30 * time.Second,
|
||||
ForceTimeout: 60 * time.Second,
|
||||
SignalBufferSize: 10,
|
||||
MaxRetries: 3,
|
||||
RetryDelay: time.Second,
|
||||
ParallelShutdown: true,
|
||||
SaveState: true,
|
||||
CleanupTempFiles: true,
|
||||
NotifyExternal: false,
|
||||
WaitForConnections: true,
|
||||
EnableMetrics: true,
|
||||
},
|
||||
ContainerConfig: ContainerConfig{
|
||||
EnableReflection: true,
|
||||
EnableCircularDetection: true,
|
||||
EnableInterception: false,
|
||||
EnableValidation: true,
|
||||
MaxDepth: 10,
|
||||
CacheInstances: true,
|
||||
},
|
||||
EnableEventBus: true,
|
||||
EnableHealthMonitor: true,
|
||||
EnableShutdownManager: true,
|
||||
EnableDependencyInjection: true,
|
||||
}
|
||||
}
|
||||
1019
orig/pkg/lifecycle/module_registry.go
Normal file
1019
orig/pkg/lifecycle/module_registry.go
Normal file
File diff suppressed because it is too large
Load Diff
96
orig/pkg/lifecycle/module_registry_test.go
Normal file
96
orig/pkg/lifecycle/module_registry_test.go
Normal file
@@ -0,0 +1,96 @@
|
||||
package lifecycle
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
type stubEventBus struct {
|
||||
failUntil int
|
||||
attempts int
|
||||
txHash string
|
||||
}
|
||||
|
||||
func (s *stubEventBus) Publish(event ModuleEvent) error {
|
||||
s.attempts++
|
||||
if s.attempts <= s.failUntil {
|
||||
return fmt.Errorf("publish failure %d for tx %s", s.attempts, s.txHash)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *stubEventBus) Subscribe(eventType EventType, handler EventHandler) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestModuleRegistryPublishEventWithRetrySuccess(t *testing.T) {
|
||||
registry := NewModuleRegistry(RegistryConfig{
|
||||
EventPublishRetries: 3,
|
||||
EventPublishDelay: time.Nanosecond,
|
||||
})
|
||||
registry.logger = nil
|
||||
|
||||
bus := &stubEventBus{failUntil: 2, txHash: "0xbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb"}
|
||||
registry.eventBus = bus
|
||||
|
||||
err := registry.publishEventWithRetry(ModuleEvent{ModuleID: "module-A", Type: EventModuleStarted}, "publish failed")
|
||||
if err != nil {
|
||||
t.Fatalf("expected publish to eventually succeed, got %v", err)
|
||||
}
|
||||
if bus.attempts != 3 {
|
||||
t.Fatalf("expected 3 attempts, got %d", bus.attempts)
|
||||
}
|
||||
if err := registry.aggregatedErrors(); err != nil {
|
||||
t.Fatalf("expected no aggregated errors, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestModuleRegistryPublishEventWithRetryFailure(t *testing.T) {
|
||||
registry := NewModuleRegistry(RegistryConfig{
|
||||
EventPublishRetries: 2,
|
||||
EventPublishDelay: time.Nanosecond,
|
||||
})
|
||||
registry.logger = nil
|
||||
|
||||
txHash := "0xcccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccc"
|
||||
bus := &stubEventBus{failUntil: 3, txHash: txHash}
|
||||
registry.eventBus = bus
|
||||
|
||||
err := registry.publishEventWithRetry(ModuleEvent{ModuleID: "module-B", Type: EventModuleStopped}, "publish failed")
|
||||
if err == nil {
|
||||
t.Fatal("expected publish to fail after retries")
|
||||
}
|
||||
if bus.attempts != 2 {
|
||||
t.Fatalf("expected 2 attempts, got %d", bus.attempts)
|
||||
}
|
||||
|
||||
exported := registry.RegistryErrors()
|
||||
if len(exported) != 1 {
|
||||
t.Fatalf("expected 1 recorded registry error, got %d", len(exported))
|
||||
}
|
||||
if exported[0] == nil {
|
||||
t.Fatal("expected recorded error to be non-nil")
|
||||
}
|
||||
if copyErrs := registry.RegistryErrors(); copyErrs[0] == nil {
|
||||
t.Fatal("copy of registry errors should preserve values")
|
||||
}
|
||||
if got := exported[0].Error(); !strings.Contains(got, txHash) {
|
||||
t.Fatalf("recorded registry error should include tx hash, got %q", got)
|
||||
}
|
||||
details := registry.RegistryErrorDetails()
|
||||
if len(details) != 1 {
|
||||
t.Fatalf("expected registry error details to include entry, got %d", len(details))
|
||||
}
|
||||
if details[0].TxHash != txHash {
|
||||
t.Fatalf("expected registry error detail to track tx hash %s, got %s", txHash, details[0].TxHash)
|
||||
}
|
||||
agg := registry.aggregatedErrors()
|
||||
if agg == nil {
|
||||
t.Fatal("expected aggregated error to be returned")
|
||||
}
|
||||
if got := agg.Error(); !strings.Contains(got, "publish failed") || !strings.Contains(got, "publish failure 1") || !strings.Contains(got, txHash) {
|
||||
t.Fatalf("aggregated error should include failure details and tx hash, got %q", got)
|
||||
}
|
||||
}
|
||||
876
orig/pkg/lifecycle/shutdown_manager.go
Normal file
876
orig/pkg/lifecycle/shutdown_manager.go
Normal file
@@ -0,0 +1,876 @@
|
||||
package lifecycle
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/signal"
|
||||
"sync"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/fraktal/mev-beta/internal/logger"
|
||||
)
|
||||
|
||||
// ShutdownManager handles graceful shutdown of the application
|
||||
type ShutdownManager struct {
|
||||
registry *ModuleRegistry
|
||||
shutdownTasks []ShutdownTask
|
||||
shutdownHooks []ShutdownHook
|
||||
config ShutdownConfig
|
||||
signalChannel chan os.Signal
|
||||
shutdownChannel chan struct{}
|
||||
state ShutdownState
|
||||
startTime time.Time
|
||||
shutdownStarted time.Time
|
||||
mu sync.RWMutex
|
||||
wg sync.WaitGroup
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
logger *logger.Logger
|
||||
shutdownErrors []error
|
||||
shutdownErrorDetails []RecordedError
|
||||
errMu sync.Mutex
|
||||
exitFunc func(code int)
|
||||
emergencyHandler func(ctx context.Context, reason string, err error) error
|
||||
}
|
||||
|
||||
// ShutdownTask represents a task to be executed during shutdown
|
||||
type ShutdownTask struct {
|
||||
Name string
|
||||
Priority int
|
||||
Timeout time.Duration
|
||||
Task func(ctx context.Context) error
|
||||
OnError func(error)
|
||||
Critical bool
|
||||
Enabled bool
|
||||
}
|
||||
|
||||
// ShutdownHook is called at different stages of shutdown
|
||||
type ShutdownHook interface {
|
||||
OnShutdownStarted(ctx context.Context) error
|
||||
OnModulesStopped(ctx context.Context) error
|
||||
OnCleanupStarted(ctx context.Context) error
|
||||
OnShutdownCompleted(ctx context.Context) error
|
||||
OnShutdownFailed(ctx context.Context, err error) error
|
||||
}
|
||||
|
||||
// ShutdownConfig configures shutdown behavior
|
||||
type ShutdownConfig struct {
|
||||
GracefulTimeout time.Duration `json:"graceful_timeout"`
|
||||
ForceTimeout time.Duration `json:"force_timeout"`
|
||||
SignalBufferSize int `json:"signal_buffer_size"`
|
||||
MaxRetries int `json:"max_retries"`
|
||||
RetryDelay time.Duration `json:"retry_delay"`
|
||||
ParallelShutdown bool `json:"parallel_shutdown"`
|
||||
SaveState bool `json:"save_state"`
|
||||
CleanupTempFiles bool `json:"cleanup_temp_files"`
|
||||
NotifyExternal bool `json:"notify_external"`
|
||||
WaitForConnections bool `json:"wait_for_connections"`
|
||||
EnableMetrics bool `json:"enable_metrics"`
|
||||
}
|
||||
|
||||
// ShutdownState represents the current shutdown state
|
||||
type ShutdownState string
|
||||
|
||||
const (
|
||||
ShutdownStateRunning ShutdownState = "running"
|
||||
ShutdownStateInitiated ShutdownState = "initiated"
|
||||
ShutdownStateModuleStop ShutdownState = "stopping_modules"
|
||||
ShutdownStateCleanup ShutdownState = "cleanup"
|
||||
ShutdownStateCompleted ShutdownState = "completed"
|
||||
ShutdownStateFailed ShutdownState = "failed"
|
||||
ShutdownStateForced ShutdownState = "forced"
|
||||
)
|
||||
|
||||
// ShutdownMetrics tracks shutdown performance
|
||||
type ShutdownMetrics struct {
|
||||
ShutdownInitiated time.Time `json:"shutdown_initiated"`
|
||||
ModuleStopTime time.Duration `json:"module_stop_time"`
|
||||
CleanupTime time.Duration `json:"cleanup_time"`
|
||||
TotalShutdownTime time.Duration `json:"total_shutdown_time"`
|
||||
TasksExecuted int `json:"tasks_executed"`
|
||||
TasksSuccessful int `json:"tasks_successful"`
|
||||
TasksFailed int `json:"tasks_failed"`
|
||||
RetryAttempts int `json:"retry_attempts"`
|
||||
ForceShutdown bool `json:"force_shutdown"`
|
||||
Signal string `json:"signal"`
|
||||
}
|
||||
|
||||
// ShutdownProgress tracks shutdown progress
|
||||
type ShutdownProgress struct {
|
||||
State ShutdownState `json:"state"`
|
||||
Progress float64 `json:"progress"`
|
||||
CurrentTask string `json:"current_task"`
|
||||
CompletedTasks int `json:"completed_tasks"`
|
||||
TotalTasks int `json:"total_tasks"`
|
||||
ElapsedTime time.Duration `json:"elapsed_time"`
|
||||
EstimatedRemaining time.Duration `json:"estimated_remaining"`
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
// NewShutdownManager creates a new shutdown manager
|
||||
func NewShutdownManager(registry *ModuleRegistry, config ShutdownConfig) *ShutdownManager {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
sm := &ShutdownManager{
|
||||
registry: registry,
|
||||
shutdownTasks: make([]ShutdownTask, 0),
|
||||
shutdownHooks: make([]ShutdownHook, 0),
|
||||
config: config,
|
||||
signalChannel: make(chan os.Signal, config.SignalBufferSize),
|
||||
shutdownChannel: make(chan struct{}),
|
||||
state: ShutdownStateRunning,
|
||||
startTime: time.Now(),
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
shutdownErrors: make([]error, 0),
|
||||
shutdownErrorDetails: make([]RecordedError, 0),
|
||||
exitFunc: os.Exit,
|
||||
}
|
||||
|
||||
// Set default configuration
|
||||
if sm.config.GracefulTimeout == 0 {
|
||||
sm.config.GracefulTimeout = 30 * time.Second
|
||||
}
|
||||
if sm.config.ForceTimeout == 0 {
|
||||
sm.config.ForceTimeout = 60 * time.Second
|
||||
}
|
||||
if sm.config.SignalBufferSize == 0 {
|
||||
sm.config.SignalBufferSize = 10
|
||||
}
|
||||
if sm.config.MaxRetries == 0 {
|
||||
sm.config.MaxRetries = 3
|
||||
}
|
||||
if sm.config.RetryDelay == 0 {
|
||||
sm.config.RetryDelay = time.Second
|
||||
}
|
||||
|
||||
if err := os.MkdirAll("logs", 0o755); err != nil {
|
||||
fmt.Printf("failed to ensure logs directory: %v\n", err)
|
||||
}
|
||||
sm.logger = logger.New("info", "", "logs/lifecycle_shutdown.log")
|
||||
|
||||
// Setup default shutdown tasks
|
||||
sm.setupDefaultTasks()
|
||||
|
||||
// Setup signal handling
|
||||
sm.setupSignalHandling()
|
||||
|
||||
return sm
|
||||
}
|
||||
|
||||
// Start starts the shutdown manager
|
||||
func (sm *ShutdownManager) Start() error {
|
||||
sm.mu.Lock()
|
||||
defer sm.mu.Unlock()
|
||||
|
||||
if sm.state != ShutdownStateRunning {
|
||||
return fmt.Errorf("shutdown manager not in running state: %s", sm.state)
|
||||
}
|
||||
|
||||
// Start signal monitoring
|
||||
go sm.signalHandler()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Shutdown initiates graceful shutdown
|
||||
func (sm *ShutdownManager) Shutdown(ctx context.Context) error {
|
||||
sm.mu.Lock()
|
||||
defer sm.mu.Unlock()
|
||||
|
||||
if sm.state != ShutdownStateRunning {
|
||||
return fmt.Errorf("shutdown already initiated: %s", sm.state)
|
||||
}
|
||||
|
||||
sm.state = ShutdownStateInitiated
|
||||
sm.shutdownStarted = time.Now()
|
||||
|
||||
// Close shutdown channel to signal shutdown
|
||||
close(sm.shutdownChannel)
|
||||
|
||||
err := sm.performShutdown(ctx)
|
||||
|
||||
combined := sm.combinedShutdownError()
|
||||
if combined != nil {
|
||||
return combined
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// ForceShutdown forces immediate shutdown
|
||||
func (sm *ShutdownManager) ForceShutdown(ctx context.Context) error {
|
||||
sm.mu.Lock()
|
||||
defer sm.mu.Unlock()
|
||||
|
||||
sm.state = ShutdownStateForced
|
||||
sm.cancel() // Cancel all operations
|
||||
|
||||
// Force stop all modules immediately
|
||||
var forceErr error
|
||||
if sm.registry != nil {
|
||||
forceCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
if err := sm.registry.StopAll(forceCtx); err != nil {
|
||||
wrapped := fmt.Errorf("StopAll failed during force shutdown: %w", err)
|
||||
sm.recordShutdownError("StopAll failed in force shutdown", wrapped)
|
||||
forceErr = errors.Join(forceErr, wrapped)
|
||||
}
|
||||
}
|
||||
|
||||
if forceErr != nil {
|
||||
sm.recordShutdownError("Force shutdown encountered errors", forceErr)
|
||||
}
|
||||
|
||||
sm.exitFunc(1)
|
||||
return forceErr
|
||||
}
|
||||
|
||||
// AddShutdownTask adds a task to be executed during shutdown
|
||||
func (sm *ShutdownManager) AddShutdownTask(task ShutdownTask) {
|
||||
sm.mu.Lock()
|
||||
defer sm.mu.Unlock()
|
||||
|
||||
sm.shutdownTasks = append(sm.shutdownTasks, task)
|
||||
sm.sortTasksByPriority()
|
||||
}
|
||||
|
||||
// AddShutdownHook adds a hook to be called during shutdown phases
|
||||
func (sm *ShutdownManager) AddShutdownHook(hook ShutdownHook) {
|
||||
sm.mu.Lock()
|
||||
defer sm.mu.Unlock()
|
||||
|
||||
sm.shutdownHooks = append(sm.shutdownHooks, hook)
|
||||
}
|
||||
|
||||
// GetState returns the current shutdown state
|
||||
func (sm *ShutdownManager) GetState() ShutdownState {
|
||||
sm.mu.RLock()
|
||||
defer sm.mu.RUnlock()
|
||||
return sm.state
|
||||
}
|
||||
|
||||
// GetProgress returns the current shutdown progress
|
||||
func (sm *ShutdownManager) GetProgress() ShutdownProgress {
|
||||
sm.mu.RLock()
|
||||
defer sm.mu.RUnlock()
|
||||
|
||||
totalTasks := len(sm.shutdownTasks)
|
||||
if sm.registry != nil {
|
||||
totalTasks += len(sm.registry.List())
|
||||
}
|
||||
|
||||
var progress float64
|
||||
var completedTasks int
|
||||
var currentTask string
|
||||
|
||||
switch sm.state {
|
||||
case ShutdownStateRunning:
|
||||
progress = 0
|
||||
currentTask = "Running"
|
||||
case ShutdownStateInitiated:
|
||||
progress = 0.1
|
||||
currentTask = "Shutdown initiated"
|
||||
case ShutdownStateModuleStop:
|
||||
progress = 0.3
|
||||
currentTask = "Stopping modules"
|
||||
completedTasks = totalTasks / 3
|
||||
case ShutdownStateCleanup:
|
||||
progress = 0.7
|
||||
currentTask = "Cleanup"
|
||||
completedTasks = (totalTasks * 2) / 3
|
||||
case ShutdownStateCompleted:
|
||||
progress = 1.0
|
||||
currentTask = "Completed"
|
||||
completedTasks = totalTasks
|
||||
case ShutdownStateFailed:
|
||||
progress = 0.8
|
||||
currentTask = "Failed"
|
||||
case ShutdownStateForced:
|
||||
progress = 1.0
|
||||
currentTask = "Forced shutdown"
|
||||
completedTasks = totalTasks
|
||||
}
|
||||
|
||||
elapsedTime := time.Since(sm.shutdownStarted)
|
||||
var estimatedRemaining time.Duration
|
||||
if progress > 0 && progress < 1.0 {
|
||||
totalEstimated := time.Duration(float64(elapsedTime) / progress)
|
||||
estimatedRemaining = totalEstimated - elapsedTime
|
||||
}
|
||||
|
||||
return ShutdownProgress{
|
||||
State: sm.state,
|
||||
Progress: progress,
|
||||
CurrentTask: currentTask,
|
||||
CompletedTasks: completedTasks,
|
||||
TotalTasks: totalTasks,
|
||||
ElapsedTime: elapsedTime,
|
||||
EstimatedRemaining: estimatedRemaining,
|
||||
Message: fmt.Sprintf("Shutdown %s", sm.state),
|
||||
}
|
||||
}
|
||||
|
||||
// Wait waits for shutdown to complete
|
||||
func (sm *ShutdownManager) Wait() {
|
||||
<-sm.shutdownChannel
|
||||
sm.wg.Wait()
|
||||
}
|
||||
|
||||
// WaitWithTimeout waits for shutdown with timeout
|
||||
func (sm *ShutdownManager) WaitWithTimeout(timeout time.Duration) error {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
sm.Wait()
|
||||
close(done)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
return nil
|
||||
case <-time.After(timeout):
|
||||
return fmt.Errorf("shutdown timeout after %v", timeout)
|
||||
}
|
||||
}
|
||||
|
||||
// Private methods
|
||||
|
||||
func (sm *ShutdownManager) setupSignalHandling() {
|
||||
signal.Notify(sm.signalChannel,
|
||||
syscall.SIGINT,
|
||||
syscall.SIGTERM,
|
||||
syscall.SIGQUIT,
|
||||
syscall.SIGHUP,
|
||||
)
|
||||
}
|
||||
|
||||
func (sm *ShutdownManager) setupDefaultTasks() {
|
||||
// Task: Save application state
|
||||
sm.shutdownTasks = append(sm.shutdownTasks, ShutdownTask{
|
||||
Name: "save_state",
|
||||
Priority: 100,
|
||||
Timeout: 10 * time.Second,
|
||||
Task: func(ctx context.Context) error {
|
||||
if sm.config.SaveState {
|
||||
return sm.saveApplicationState(ctx)
|
||||
}
|
||||
return nil
|
||||
},
|
||||
Critical: false,
|
||||
Enabled: sm.config.SaveState,
|
||||
})
|
||||
|
||||
// Task: Close external connections
|
||||
sm.shutdownTasks = append(sm.shutdownTasks, ShutdownTask{
|
||||
Name: "close_connections",
|
||||
Priority: 90,
|
||||
Timeout: 5 * time.Second,
|
||||
Task: func(ctx context.Context) error {
|
||||
return sm.closeExternalConnections(ctx)
|
||||
},
|
||||
Critical: false,
|
||||
Enabled: sm.config.WaitForConnections,
|
||||
})
|
||||
|
||||
// Task: Cleanup temporary files
|
||||
sm.shutdownTasks = append(sm.shutdownTasks, ShutdownTask{
|
||||
Name: "cleanup_temp_files",
|
||||
Priority: 10,
|
||||
Timeout: 5 * time.Second,
|
||||
Task: func(ctx context.Context) error {
|
||||
if sm.config.CleanupTempFiles {
|
||||
return sm.cleanupTempFiles(ctx)
|
||||
}
|
||||
return nil
|
||||
},
|
||||
Critical: false,
|
||||
Enabled: sm.config.CleanupTempFiles,
|
||||
})
|
||||
|
||||
// Task: Notify external systems
|
||||
sm.shutdownTasks = append(sm.shutdownTasks, ShutdownTask{
|
||||
Name: "notify_external",
|
||||
Priority: 80,
|
||||
Timeout: 3 * time.Second,
|
||||
Task: func(ctx context.Context) error {
|
||||
if sm.config.NotifyExternal {
|
||||
return sm.notifyExternalSystems(ctx)
|
||||
}
|
||||
return nil
|
||||
},
|
||||
Critical: false,
|
||||
Enabled: sm.config.NotifyExternal,
|
||||
})
|
||||
}
|
||||
|
||||
func (sm *ShutdownManager) signalHandler() {
|
||||
for {
|
||||
select {
|
||||
case sig := <-sm.signalChannel:
|
||||
switch sig {
|
||||
case syscall.SIGINT, syscall.SIGTERM:
|
||||
// Graceful shutdown
|
||||
ctx, cancel := context.WithTimeout(context.Background(), sm.config.GracefulTimeout)
|
||||
if err := sm.Shutdown(ctx); err != nil {
|
||||
sm.recordShutdownError("Graceful shutdown failed from signal", err)
|
||||
cancel()
|
||||
// Force shutdown if graceful fails
|
||||
forceCtx, forceCancel := context.WithTimeout(context.Background(), sm.config.ForceTimeout)
|
||||
if err := sm.ForceShutdown(forceCtx); err != nil {
|
||||
sm.recordShutdownError("Force shutdown error in timeout scenario", err)
|
||||
// CRITICAL FIX: Escalate force shutdown failure to emergency protocols
|
||||
sm.triggerEmergencyShutdown("Force shutdown failed after graceful timeout", err)
|
||||
}
|
||||
forceCancel()
|
||||
}
|
||||
cancel()
|
||||
return
|
||||
case syscall.SIGQUIT:
|
||||
// Force shutdown
|
||||
ctx, cancel := context.WithTimeout(context.Background(), sm.config.ForceTimeout)
|
||||
if err := sm.ForceShutdown(ctx); err != nil {
|
||||
sm.recordShutdownError("Force shutdown error in SIGQUIT handler", err)
|
||||
// CRITICAL FIX: Escalate force shutdown failure to emergency protocols
|
||||
sm.triggerEmergencyShutdown("Force shutdown failed on SIGQUIT", err)
|
||||
}
|
||||
cancel()
|
||||
return
|
||||
case syscall.SIGHUP:
|
||||
// Reload signal - could be used for configuration reload
|
||||
// For now, just log it
|
||||
continue
|
||||
}
|
||||
case <-sm.ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (sm *ShutdownManager) performShutdown(ctx context.Context) error {
|
||||
sm.wg.Add(1)
|
||||
defer sm.wg.Done()
|
||||
|
||||
// Create timeout context for entire shutdown
|
||||
shutdownCtx, cancel := context.WithTimeout(ctx, sm.config.GracefulTimeout)
|
||||
defer cancel()
|
||||
|
||||
var phaseErrors []error
|
||||
|
||||
// Phase 1: Call shutdown started hooks
|
||||
if err := sm.callHooks(shutdownCtx, "OnShutdownStarted", nil); err != nil {
|
||||
wrapped := fmt.Errorf("shutdown start hooks failed: %w", err)
|
||||
sm.recordShutdownError("Shutdown started hook failure", wrapped)
|
||||
phaseErrors = append(phaseErrors, wrapped)
|
||||
}
|
||||
|
||||
// Phase 2: Stop modules
|
||||
sm.state = ShutdownStateModuleStop
|
||||
if sm.registry != nil {
|
||||
if err := sm.registry.StopAll(shutdownCtx); err != nil {
|
||||
wrapped := fmt.Errorf("failed to stop modules: %w", err)
|
||||
sm.recordShutdownError("Module stop failure", wrapped)
|
||||
phaseErrors = append(phaseErrors, wrapped)
|
||||
}
|
||||
}
|
||||
|
||||
// Call modules stopped hooks
|
||||
if err := sm.callHooks(shutdownCtx, "OnModulesStopped", nil); err != nil {
|
||||
wrapped := fmt.Errorf("modules stopped hooks failed: %w", err)
|
||||
sm.recordShutdownError("Modules stopped hook failure", wrapped)
|
||||
phaseErrors = append(phaseErrors, wrapped)
|
||||
}
|
||||
|
||||
// Phase 3: Execute shutdown tasks
|
||||
sm.state = ShutdownStateCleanup
|
||||
if err := sm.callHooks(shutdownCtx, "OnCleanupStarted", nil); err != nil {
|
||||
wrapped := fmt.Errorf("cleanup hooks failed: %w", err)
|
||||
sm.recordShutdownError("Cleanup hook failure", wrapped)
|
||||
phaseErrors = append(phaseErrors, wrapped)
|
||||
}
|
||||
|
||||
if err := sm.executeShutdownTasks(shutdownCtx); err != nil {
|
||||
wrapped := fmt.Errorf("shutdown tasks failed: %w", err)
|
||||
sm.recordShutdownError("Shutdown task execution failure", wrapped)
|
||||
phaseErrors = append(phaseErrors, wrapped)
|
||||
}
|
||||
|
||||
// Phase 4: Final cleanup
|
||||
if len(phaseErrors) > 0 {
|
||||
finalErr := errors.Join(phaseErrors...)
|
||||
sm.state = ShutdownStateFailed
|
||||
if err := sm.callHooks(shutdownCtx, "OnShutdownFailed", finalErr); err != nil {
|
||||
wrapped := fmt.Errorf("shutdown failed hook error: %w", err)
|
||||
sm.recordShutdownError("Shutdown failed hook error", wrapped)
|
||||
finalErr = errors.Join(finalErr, wrapped)
|
||||
// CRITICAL FIX: Escalate hook failure during shutdown failed state
|
||||
sm.triggerEmergencyShutdown("Shutdown failed hook error", wrapped)
|
||||
}
|
||||
return finalErr
|
||||
}
|
||||
|
||||
sm.state = ShutdownStateCompleted
|
||||
if err := sm.callHooks(shutdownCtx, "OnShutdownCompleted", nil); err != nil {
|
||||
wrapped := fmt.Errorf("shutdown completed hook error: %w", err)
|
||||
sm.recordShutdownError("Shutdown completed hook error", wrapped)
|
||||
// CRITICAL FIX: Log but don't fail shutdown for completion hook errors
|
||||
// These are non-critical notifications that shouldn't prevent successful shutdown
|
||||
sm.logger.Warn("Shutdown completed hook failed", "error", wrapped)
|
||||
// Don't return error for completion hook failures - shutdown was successful
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (sm *ShutdownManager) executeShutdownTasks(ctx context.Context) error {
|
||||
if sm.config.ParallelShutdown {
|
||||
return sm.executeTasksParallel(ctx)
|
||||
} else {
|
||||
return sm.executeTasksSequential(ctx)
|
||||
}
|
||||
}
|
||||
|
||||
func (sm *ShutdownManager) executeTasksSequential(ctx context.Context) error {
|
||||
var lastErr error
|
||||
|
||||
for _, task := range sm.shutdownTasks {
|
||||
if !task.Enabled {
|
||||
continue
|
||||
}
|
||||
|
||||
if err := sm.executeTask(ctx, task); err != nil {
|
||||
lastErr = err
|
||||
if task.Critical {
|
||||
return fmt.Errorf("critical task %s failed: %w", task.Name, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return lastErr
|
||||
}
|
||||
|
||||
func (sm *ShutdownManager) executeTasksParallel(ctx context.Context) error {
|
||||
var wg sync.WaitGroup
|
||||
errors := make(chan error, len(sm.shutdownTasks))
|
||||
|
||||
// Group tasks by priority
|
||||
priorityGroups := sm.groupTasksByPriority()
|
||||
|
||||
// Execute each priority group sequentially, but tasks within group in parallel
|
||||
for _, tasks := range priorityGroups {
|
||||
for _, task := range tasks {
|
||||
if !task.Enabled {
|
||||
continue
|
||||
}
|
||||
|
||||
wg.Add(1)
|
||||
go func(t ShutdownTask) {
|
||||
defer wg.Done()
|
||||
if err := sm.executeTask(ctx, t); err != nil {
|
||||
errors <- fmt.Errorf("task %s failed: %w", t.Name, err)
|
||||
}
|
||||
}(task)
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
close(errors)
|
||||
|
||||
// Collect errors
|
||||
var criticalErr error
|
||||
var lastErr error
|
||||
for err := range errors {
|
||||
lastErr = err
|
||||
// Check if this was from a critical task
|
||||
for _, task := range sm.shutdownTasks {
|
||||
if task.Critical && fmt.Sprintf("task %s failed:", task.Name) == err.Error()[:len(fmt.Sprintf("task %s failed:", task.Name))] {
|
||||
criticalErr = err
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if criticalErr != nil {
|
||||
return criticalErr
|
||||
}
|
||||
return lastErr
|
||||
}
|
||||
|
||||
func (sm *ShutdownManager) executeTask(ctx context.Context, task ShutdownTask) error {
|
||||
// Create timeout context for the task
|
||||
taskCtx, cancel := context.WithTimeout(ctx, task.Timeout)
|
||||
defer cancel()
|
||||
|
||||
// Execute task with retry
|
||||
var lastErr error
|
||||
for attempt := 0; attempt <= sm.config.MaxRetries; attempt++ {
|
||||
if attempt > 0 {
|
||||
select {
|
||||
case <-time.After(sm.config.RetryDelay):
|
||||
case <-taskCtx.Done():
|
||||
return taskCtx.Err()
|
||||
}
|
||||
}
|
||||
|
||||
err := task.Task(taskCtx)
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
lastErr = err
|
||||
attemptNumber := attempt + 1
|
||||
|
||||
sm.recordShutdownError(
|
||||
fmt.Sprintf("Shutdown task %s failed", task.Name),
|
||||
fmt.Errorf("attempt %d: %w", attemptNumber, err),
|
||||
"task", task.Name,
|
||||
"attempt", attemptNumber,
|
||||
)
|
||||
|
||||
// Call error handler if provided
|
||||
if task.OnError != nil {
|
||||
task.OnError(err)
|
||||
}
|
||||
}
|
||||
|
||||
return fmt.Errorf("task failed after %d attempts: %w", sm.config.MaxRetries, lastErr)
|
||||
}
|
||||
|
||||
func (sm *ShutdownManager) callHooks(ctx context.Context, hookMethod string, cause error) error {
|
||||
var hookErrors []error
|
||||
|
||||
for _, hook := range sm.shutdownHooks {
|
||||
hookName := fmt.Sprintf("%T", hook)
|
||||
var err error
|
||||
|
||||
switch hookMethod {
|
||||
case "OnShutdownStarted":
|
||||
err = hook.OnShutdownStarted(ctx)
|
||||
case "OnModulesStopped":
|
||||
err = hook.OnModulesStopped(ctx)
|
||||
case "OnCleanupStarted":
|
||||
err = hook.OnCleanupStarted(ctx)
|
||||
case "OnShutdownCompleted":
|
||||
err = hook.OnShutdownCompleted(ctx)
|
||||
case "OnShutdownFailed":
|
||||
err = hook.OnShutdownFailed(ctx, cause)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
recordContext := fmt.Sprintf("%s hook failure (%s)", hookMethod, hookName)
|
||||
sm.recordShutdownError(recordContext, err, "hook", hookName, "phase", hookMethod)
|
||||
hookErrors = append(hookErrors, fmt.Errorf("%s: %w", recordContext, err))
|
||||
}
|
||||
}
|
||||
|
||||
if len(hookErrors) > 0 {
|
||||
return errors.Join(hookErrors...)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (sm *ShutdownManager) sortTasksByPriority() {
|
||||
// Simple bubble sort by priority (descending)
|
||||
for i := 0; i < len(sm.shutdownTasks); i++ {
|
||||
for j := i + 1; j < len(sm.shutdownTasks); j++ {
|
||||
if sm.shutdownTasks[j].Priority > sm.shutdownTasks[i].Priority {
|
||||
sm.shutdownTasks[i], sm.shutdownTasks[j] = sm.shutdownTasks[j], sm.shutdownTasks[i]
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (sm *ShutdownManager) groupTasksByPriority() [][]ShutdownTask {
|
||||
groups := make(map[int][]ShutdownTask)
|
||||
|
||||
for _, task := range sm.shutdownTasks {
|
||||
groups[task.Priority] = append(groups[task.Priority], task)
|
||||
}
|
||||
|
||||
// Convert to sorted slice
|
||||
var priorities []int
|
||||
for priority := range groups {
|
||||
priorities = append(priorities, priority)
|
||||
}
|
||||
|
||||
// Sort priorities descending
|
||||
for i := 0; i < len(priorities); i++ {
|
||||
for j := i + 1; j < len(priorities); j++ {
|
||||
if priorities[j] > priorities[i] {
|
||||
priorities[i], priorities[j] = priorities[j], priorities[i]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var result [][]ShutdownTask
|
||||
for _, priority := range priorities {
|
||||
result = append(result, groups[priority])
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// Default task implementations
|
||||
|
||||
func (sm *ShutdownManager) saveApplicationState(ctx context.Context) error {
|
||||
// Save application state to disk
|
||||
// This would save things like current configuration, runtime state, etc.
|
||||
return nil
|
||||
}
|
||||
|
||||
func (sm *ShutdownManager) closeExternalConnections(ctx context.Context) error {
|
||||
// Close database connections, external API connections, etc.
|
||||
return nil
|
||||
}
|
||||
|
||||
func (sm *ShutdownManager) cleanupTempFiles(ctx context.Context) error {
|
||||
// Remove temporary files, logs, caches, etc.
|
||||
return nil
|
||||
}
|
||||
|
||||
func (sm *ShutdownManager) notifyExternalSystems(ctx context.Context) error {
|
||||
// Notify external systems that this instance is shutting down
|
||||
return nil
|
||||
}
|
||||
|
||||
func (sm *ShutdownManager) recordShutdownError(message string, err error, attrs ...interface{}) {
|
||||
if err == nil {
|
||||
return
|
||||
}
|
||||
|
||||
attrCopy := append([]interface{}{}, attrs...)
|
||||
wrapped, txHash, attrsWithTx := enrichErrorWithTxHash(message, err, attrCopy)
|
||||
|
||||
sm.errMu.Lock()
|
||||
sm.shutdownErrors = append(sm.shutdownErrors, wrapped)
|
||||
sm.shutdownErrorDetails = append(sm.shutdownErrorDetails, RecordedError{
|
||||
Err: wrapped,
|
||||
TxHash: txHash,
|
||||
})
|
||||
sm.errMu.Unlock()
|
||||
|
||||
if sm.logger != nil {
|
||||
kv := append([]interface{}{}, attrsWithTx...)
|
||||
kv = append(kv, "error", err)
|
||||
args := append([]interface{}{message}, kv...)
|
||||
sm.logger.Error(args...)
|
||||
}
|
||||
}
|
||||
|
||||
func (sm *ShutdownManager) combinedShutdownError() error {
|
||||
sm.errMu.Lock()
|
||||
defer sm.errMu.Unlock()
|
||||
|
||||
if len(sm.shutdownErrors) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
errs := make([]error, len(sm.shutdownErrors))
|
||||
copy(errs, sm.shutdownErrors)
|
||||
return errors.Join(errs...)
|
||||
}
|
||||
|
||||
// ShutdownErrors returns a copy of recorded shutdown errors for diagnostics.
|
||||
func (sm *ShutdownManager) ShutdownErrors() []error {
|
||||
sm.errMu.Lock()
|
||||
defer sm.errMu.Unlock()
|
||||
|
||||
if len(sm.shutdownErrors) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
errs := make([]error, len(sm.shutdownErrors))
|
||||
copy(errs, sm.shutdownErrors)
|
||||
return errs
|
||||
}
|
||||
|
||||
// ShutdownErrorDetails returns recorded errors with associated metadata such as tx hash.
|
||||
func (sm *ShutdownManager) ShutdownErrorDetails() []RecordedError {
|
||||
sm.errMu.Lock()
|
||||
defer sm.errMu.Unlock()
|
||||
|
||||
if len(sm.shutdownErrorDetails) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
details := make([]RecordedError, len(sm.shutdownErrorDetails))
|
||||
copy(details, sm.shutdownErrorDetails)
|
||||
return details
|
||||
}
|
||||
|
||||
// DefaultShutdownHook provides a basic implementation of ShutdownHook
|
||||
type DefaultShutdownHook struct {
|
||||
name string
|
||||
}
|
||||
|
||||
func NewDefaultShutdownHook(name string) *DefaultShutdownHook {
|
||||
return &DefaultShutdownHook{name: name}
|
||||
}
|
||||
|
||||
// triggerEmergencyShutdown performs emergency shutdown procedures when critical failures occur
|
||||
func (sm *ShutdownManager) triggerEmergencyShutdown(reason string, err error) {
|
||||
sm.logger.Error("EMERGENCY SHUTDOWN TRIGGERED",
|
||||
"reason", reason,
|
||||
"error", err,
|
||||
"state", sm.state,
|
||||
"timestamp", time.Now())
|
||||
|
||||
// Set emergency state
|
||||
sm.mu.Lock()
|
||||
sm.state = ShutdownStateFailed
|
||||
sm.mu.Unlock()
|
||||
|
||||
// Attempt to signal all processes to terminate immediately
|
||||
// This is a last-resort mechanism
|
||||
if sm.emergencyHandler != nil {
|
||||
go func() {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
if err := sm.emergencyHandler(ctx, reason, err); err != nil {
|
||||
sm.logger.Error("Emergency handler failed", "error", err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// Log to all available outputs
|
||||
sm.recordShutdownError("EMERGENCY_SHUTDOWN", fmt.Errorf("%s: %w", reason, err))
|
||||
|
||||
// Attempt to notify monitoring systems if available
|
||||
if len(sm.shutdownHooks) > 0 {
|
||||
go func() {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer cancel()
|
||||
|
||||
// CRITICAL FIX: Log emergency shutdown notification failures
|
||||
if err := sm.callHooks(ctx, "OnEmergencyShutdown", fmt.Errorf("%s: %w", reason, err)); err != nil {
|
||||
sm.logger.Warn("Failed to call emergency shutdown hooks",
|
||||
"error", err,
|
||||
"reason", reason)
|
||||
}
|
||||
}()
|
||||
}
|
||||
}
|
||||
|
||||
func (dsh *DefaultShutdownHook) OnShutdownStarted(ctx context.Context) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (dsh *DefaultShutdownHook) OnModulesStopped(ctx context.Context) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (dsh *DefaultShutdownHook) OnCleanupStarted(ctx context.Context) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (dsh *DefaultShutdownHook) OnShutdownCompleted(ctx context.Context) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (dsh *DefaultShutdownHook) OnShutdownFailed(ctx context.Context, err error) error {
|
||||
return nil
|
||||
}
|
||||
111
orig/pkg/lifecycle/shutdown_manager_test.go
Normal file
111
orig/pkg/lifecycle/shutdown_manager_test.go
Normal file
@@ -0,0 +1,111 @@
|
||||
package lifecycle
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
type testShutdownHook struct {
|
||||
errs map[string]error
|
||||
lastFailure error
|
||||
}
|
||||
|
||||
func (h *testShutdownHook) OnShutdownStarted(ctx context.Context) error {
|
||||
return h.errs["OnShutdownStarted"]
|
||||
}
|
||||
|
||||
func (h *testShutdownHook) OnModulesStopped(ctx context.Context) error {
|
||||
return h.errs["OnModulesStopped"]
|
||||
}
|
||||
|
||||
func (h *testShutdownHook) OnCleanupStarted(ctx context.Context) error {
|
||||
return h.errs["OnCleanupStarted"]
|
||||
}
|
||||
|
||||
func (h *testShutdownHook) OnShutdownCompleted(ctx context.Context) error {
|
||||
return h.errs["OnShutdownCompleted"]
|
||||
}
|
||||
|
||||
func (h *testShutdownHook) OnShutdownFailed(ctx context.Context, err error) error {
|
||||
h.lastFailure = err
|
||||
return h.errs["OnShutdownFailed"]
|
||||
}
|
||||
|
||||
func TestShutdownManagerErrorAggregation(t *testing.T) {
|
||||
sm := NewShutdownManager(nil, ShutdownConfig{})
|
||||
sm.logger = nil
|
||||
|
||||
txHash := "0xaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"
|
||||
firstErr := fmt.Errorf("first failure on %s", txHash)
|
||||
secondErr := errors.New("second")
|
||||
|
||||
sm.recordShutdownError("first error", firstErr)
|
||||
sm.recordShutdownError("second error", secondErr)
|
||||
|
||||
if got := len(sm.shutdownErrors); got != 2 {
|
||||
t.Fatalf("expected 2 recorded errors, got %d", got)
|
||||
}
|
||||
|
||||
combined := sm.combinedShutdownError()
|
||||
if combined == nil {
|
||||
t.Fatal("expected combined shutdown error, got nil")
|
||||
}
|
||||
if !errors.Is(combined, firstErr) || !errors.Is(combined, secondErr) {
|
||||
t.Fatalf("combined error does not contain original errors: %v", combined)
|
||||
}
|
||||
|
||||
exportedErrors := sm.ShutdownErrors()
|
||||
if len(exportedErrors) != 2 {
|
||||
t.Fatalf("expected exported error slice of length 2, got %d", len(exportedErrors))
|
||||
}
|
||||
|
||||
details := sm.ShutdownErrorDetails()
|
||||
if len(details) != 2 {
|
||||
t.Fatalf("expected error detail slice of length 2, got %d", len(details))
|
||||
}
|
||||
if details[0].TxHash != txHash {
|
||||
t.Fatalf("expected recorded error to track tx hash %s, got %s", txHash, details[0].TxHash)
|
||||
}
|
||||
if details[0].Err == nil || !strings.Contains(details[0].Err.Error(), "tx_hash="+txHash) {
|
||||
t.Fatalf("expected recorded error message to include tx hash, got %v", details[0].Err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestShutdownManagerCallHooksAggregatesErrors(t *testing.T) {
|
||||
sm := NewShutdownManager(nil, ShutdownConfig{})
|
||||
sm.logger = nil
|
||||
sm.shutdownErrors = nil
|
||||
|
||||
hookErrA := errors.New("hookA failure")
|
||||
hookErrB := errors.New("hookB failure")
|
||||
hookA := &testShutdownHook{
|
||||
errs: map[string]error{
|
||||
"OnShutdownFailed": hookErrA,
|
||||
},
|
||||
}
|
||||
hookB := &testShutdownHook{
|
||||
errs: map[string]error{
|
||||
"OnShutdownFailed": hookErrB,
|
||||
},
|
||||
}
|
||||
|
||||
sm.shutdownHooks = []ShutdownHook{hookA, hookB}
|
||||
|
||||
cause := errors.New("original failure")
|
||||
err := sm.callHooks(context.Background(), "OnShutdownFailed", cause)
|
||||
if err == nil {
|
||||
t.Fatal("expected aggregated error from hooks, got nil")
|
||||
}
|
||||
if !errors.Is(err, hookErrA) || !errors.Is(err, hookErrB) {
|
||||
t.Fatalf("expected aggregated error to contain hook failures, got %v", err)
|
||||
}
|
||||
if hookA.lastFailure != cause || hookB.lastFailure != cause {
|
||||
t.Fatal("expected hook to receive original failure cause")
|
||||
}
|
||||
if len(sm.ShutdownErrors()) != 2 {
|
||||
t.Fatalf("expected shutdown errors to be recorded for each hook failure, got %d", len(sm.ShutdownErrors()))
|
||||
}
|
||||
}
|
||||
660
orig/pkg/lifecycle/state_machine.go
Normal file
660
orig/pkg/lifecycle/state_machine.go
Normal file
@@ -0,0 +1,660 @@
|
||||
package lifecycle
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// StateMachine manages module state transitions and enforces valid state changes
|
||||
type StateMachine struct {
|
||||
currentState ModuleState
|
||||
transitions map[ModuleState][]ModuleState
|
||||
stateHandlers map[ModuleState]StateHandler
|
||||
transitionHooks map[string]TransitionHook
|
||||
history []StateTransition
|
||||
module Module
|
||||
config StateMachineConfig
|
||||
mu sync.RWMutex
|
||||
metrics StateMachineMetrics
|
||||
}
|
||||
|
||||
// StateHandler handles operations when entering a specific state
|
||||
type StateHandler func(ctx context.Context, machine *StateMachine) error
|
||||
|
||||
// TransitionHook is called before or after state transitions
|
||||
type TransitionHook func(ctx context.Context, from, to ModuleState, machine *StateMachine) error
|
||||
|
||||
// StateTransition represents a state change event
|
||||
type StateTransition struct {
|
||||
From ModuleState `json:"from"`
|
||||
To ModuleState `json:"to"`
|
||||
Timestamp time.Time `json:"timestamp"`
|
||||
Duration time.Duration `json:"duration"`
|
||||
Success bool `json:"success"`
|
||||
Error error `json:"error,omitempty"`
|
||||
Trigger string `json:"trigger"`
|
||||
Context map[string]interface{} `json:"context"`
|
||||
}
|
||||
|
||||
// StateMachineConfig configures state machine behavior
|
||||
type StateMachineConfig struct {
|
||||
InitialState ModuleState `json:"initial_state"`
|
||||
TransitionTimeout time.Duration `json:"transition_timeout"`
|
||||
MaxHistorySize int `json:"max_history_size"`
|
||||
EnableMetrics bool `json:"enable_metrics"`
|
||||
EnableValidation bool `json:"enable_validation"`
|
||||
AllowConcurrent bool `json:"allow_concurrent"`
|
||||
RetryFailedTransitions bool `json:"retry_failed_transitions"`
|
||||
MaxRetries int `json:"max_retries"`
|
||||
RetryDelay time.Duration `json:"retry_delay"`
|
||||
}
|
||||
|
||||
// StateMachineMetrics tracks state machine performance
|
||||
type StateMachineMetrics struct {
|
||||
TotalTransitions int64 `json:"total_transitions"`
|
||||
SuccessfulTransitions int64 `json:"successful_transitions"`
|
||||
FailedTransitions int64 `json:"failed_transitions"`
|
||||
StateDistribution map[ModuleState]int64 `json:"state_distribution"`
|
||||
TransitionTimes map[string]time.Duration `json:"transition_times"`
|
||||
AverageTransitionTime time.Duration `json:"average_transition_time"`
|
||||
LongestTransition time.Duration `json:"longest_transition"`
|
||||
LastTransition time.Time `json:"last_transition"`
|
||||
CurrentStateDuration time.Duration `json:"current_state_duration"`
|
||||
stateEnterTime time.Time
|
||||
}
|
||||
|
||||
// NewStateMachine creates a new state machine for a module
|
||||
func NewStateMachine(module Module, config StateMachineConfig) *StateMachine {
|
||||
sm := &StateMachine{
|
||||
currentState: config.InitialState,
|
||||
transitions: createDefaultTransitions(),
|
||||
stateHandlers: make(map[ModuleState]StateHandler),
|
||||
transitionHooks: make(map[string]TransitionHook),
|
||||
history: make([]StateTransition, 0),
|
||||
module: module,
|
||||
config: config,
|
||||
metrics: StateMachineMetrics{
|
||||
StateDistribution: make(map[ModuleState]int64),
|
||||
TransitionTimes: make(map[string]time.Duration),
|
||||
stateEnterTime: time.Now(),
|
||||
},
|
||||
}
|
||||
|
||||
// Set default config values
|
||||
if sm.config.TransitionTimeout == 0 {
|
||||
sm.config.TransitionTimeout = 30 * time.Second
|
||||
}
|
||||
if sm.config.MaxHistorySize == 0 {
|
||||
sm.config.MaxHistorySize = 100
|
||||
}
|
||||
if sm.config.MaxRetries == 0 {
|
||||
sm.config.MaxRetries = 3
|
||||
}
|
||||
if sm.config.RetryDelay == 0 {
|
||||
sm.config.RetryDelay = time.Second
|
||||
}
|
||||
|
||||
// Setup default state handlers
|
||||
sm.setupDefaultHandlers()
|
||||
|
||||
return sm
|
||||
}
|
||||
|
||||
// GetCurrentState returns the current state
|
||||
func (sm *StateMachine) GetCurrentState() ModuleState {
|
||||
sm.mu.RLock()
|
||||
defer sm.mu.RUnlock()
|
||||
return sm.currentState
|
||||
}
|
||||
|
||||
// CanTransition checks if a transition from current state to target state is valid
|
||||
func (sm *StateMachine) CanTransition(to ModuleState) bool {
|
||||
sm.mu.RLock()
|
||||
defer sm.mu.RUnlock()
|
||||
|
||||
validTransitions, exists := sm.transitions[sm.currentState]
|
||||
if !exists {
|
||||
return false
|
||||
}
|
||||
|
||||
for _, validState := range validTransitions {
|
||||
if validState == to {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// Transition performs a state transition
|
||||
func (sm *StateMachine) Transition(ctx context.Context, to ModuleState, trigger string) error {
|
||||
if !sm.config.AllowConcurrent {
|
||||
sm.mu.Lock()
|
||||
defer sm.mu.Unlock()
|
||||
} else {
|
||||
sm.mu.RLock()
|
||||
defer sm.mu.RUnlock()
|
||||
}
|
||||
|
||||
return sm.performTransition(ctx, to, trigger)
|
||||
}
|
||||
|
||||
// TransitionWithRetry performs a state transition with retry logic
|
||||
func (sm *StateMachine) TransitionWithRetry(ctx context.Context, to ModuleState, trigger string) error {
|
||||
var lastErr error
|
||||
|
||||
for attempt := 0; attempt <= sm.config.MaxRetries; attempt++ {
|
||||
if attempt > 0 {
|
||||
// Wait before retrying
|
||||
select {
|
||||
case <-time.After(sm.config.RetryDelay):
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
}
|
||||
}
|
||||
|
||||
err := sm.Transition(ctx, to, trigger)
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
lastErr = err
|
||||
|
||||
// Don't retry if it's a validation error
|
||||
if !sm.config.RetryFailedTransitions {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return fmt.Errorf("transition failed after %d attempts: %w", sm.config.MaxRetries, lastErr)
|
||||
}
|
||||
|
||||
// Initialize transitions to initialized state
|
||||
func (sm *StateMachine) Initialize(ctx context.Context) error {
|
||||
return sm.Transition(ctx, StateInitialized, "initialize")
|
||||
}
|
||||
|
||||
// Start transitions to running state
|
||||
func (sm *StateMachine) Start(ctx context.Context) error {
|
||||
return sm.Transition(ctx, StateRunning, "start")
|
||||
}
|
||||
|
||||
// Stop transitions to stopped state
|
||||
func (sm *StateMachine) Stop(ctx context.Context) error {
|
||||
return sm.Transition(ctx, StateStopped, "stop")
|
||||
}
|
||||
|
||||
// Pause transitions to paused state
|
||||
func (sm *StateMachine) Pause(ctx context.Context) error {
|
||||
return sm.Transition(ctx, StatePaused, "pause")
|
||||
}
|
||||
|
||||
// Resume transitions to running state from paused
|
||||
func (sm *StateMachine) Resume(ctx context.Context) error {
|
||||
return sm.Transition(ctx, StateRunning, "resume")
|
||||
}
|
||||
|
||||
// Fail transitions to failed state
|
||||
func (sm *StateMachine) Fail(ctx context.Context, reason string) error {
|
||||
return sm.Transition(ctx, StateFailed, fmt.Sprintf("fail: %s", reason))
|
||||
}
|
||||
|
||||
// SetStateHandler sets a custom handler for a specific state
|
||||
func (sm *StateMachine) SetStateHandler(state ModuleState, handler StateHandler) {
|
||||
sm.mu.Lock()
|
||||
defer sm.mu.Unlock()
|
||||
sm.stateHandlers[state] = handler
|
||||
}
|
||||
|
||||
// SetTransitionHook sets a hook for state transitions
|
||||
func (sm *StateMachine) SetTransitionHook(name string, hook TransitionHook) {
|
||||
sm.mu.Lock()
|
||||
defer sm.mu.Unlock()
|
||||
sm.transitionHooks[name] = hook
|
||||
}
|
||||
|
||||
// GetHistory returns the state transition history
|
||||
func (sm *StateMachine) GetHistory() []StateTransition {
|
||||
sm.mu.RLock()
|
||||
defer sm.mu.RUnlock()
|
||||
|
||||
history := make([]StateTransition, len(sm.history))
|
||||
copy(history, sm.history)
|
||||
return history
|
||||
}
|
||||
|
||||
// GetMetrics returns state machine metrics
|
||||
func (sm *StateMachine) GetMetrics() StateMachineMetrics {
|
||||
sm.mu.RLock()
|
||||
defer sm.mu.RUnlock()
|
||||
|
||||
// Update current state duration
|
||||
metrics := sm.metrics
|
||||
metrics.CurrentStateDuration = time.Since(sm.metrics.stateEnterTime)
|
||||
|
||||
return metrics
|
||||
}
|
||||
|
||||
// AddCustomTransition adds a custom state transition rule
|
||||
func (sm *StateMachine) AddCustomTransition(from, to ModuleState) {
|
||||
sm.mu.Lock()
|
||||
defer sm.mu.Unlock()
|
||||
|
||||
if _, exists := sm.transitions[from]; !exists {
|
||||
sm.transitions[from] = make([]ModuleState, 0)
|
||||
}
|
||||
|
||||
// Check if transition already exists
|
||||
for _, existing := range sm.transitions[from] {
|
||||
if existing == to {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
sm.transitions[from] = append(sm.transitions[from], to)
|
||||
}
|
||||
|
||||
// RemoveTransition removes a state transition rule
|
||||
func (sm *StateMachine) RemoveTransition(from, to ModuleState) {
|
||||
sm.mu.Lock()
|
||||
defer sm.mu.Unlock()
|
||||
|
||||
transitions, exists := sm.transitions[from]
|
||||
if !exists {
|
||||
return
|
||||
}
|
||||
|
||||
for i, transition := range transitions {
|
||||
if transition == to {
|
||||
sm.transitions[from] = append(transitions[:i], transitions[i+1:]...)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// GetValidTransitions returns all valid transitions from current state
|
||||
func (sm *StateMachine) GetValidTransitions() []ModuleState {
|
||||
sm.mu.RLock()
|
||||
defer sm.mu.RUnlock()
|
||||
|
||||
validTransitions, exists := sm.transitions[sm.currentState]
|
||||
if !exists {
|
||||
return []ModuleState{}
|
||||
}
|
||||
|
||||
result := make([]ModuleState, len(validTransitions))
|
||||
copy(result, validTransitions)
|
||||
return result
|
||||
}
|
||||
|
||||
// IsInState checks if the state machine is in a specific state
|
||||
func (sm *StateMachine) IsInState(state ModuleState) bool {
|
||||
sm.mu.RLock()
|
||||
defer sm.mu.RUnlock()
|
||||
return sm.currentState == state
|
||||
}
|
||||
|
||||
// IsInAnyState checks if the state machine is in any of the provided states
|
||||
func (sm *StateMachine) IsInAnyState(states ...ModuleState) bool {
|
||||
sm.mu.RLock()
|
||||
defer sm.mu.RUnlock()
|
||||
|
||||
for _, state := range states {
|
||||
if sm.currentState == state {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// WaitForState waits until the state machine reaches a specific state or times out
|
||||
func (sm *StateMachine) WaitForState(ctx context.Context, state ModuleState, timeout time.Duration) error {
|
||||
if sm.IsInState(state) {
|
||||
return nil
|
||||
}
|
||||
|
||||
timeoutCtx, cancel := context.WithTimeout(ctx, timeout)
|
||||
defer cancel()
|
||||
|
||||
ticker := time.NewTicker(100 * time.Millisecond)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-timeoutCtx.Done():
|
||||
return fmt.Errorf("timeout waiting for state %s", state)
|
||||
case <-ticker.C:
|
||||
if sm.IsInState(state) {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Reset resets the state machine to its initial state
|
||||
func (sm *StateMachine) Reset(ctx context.Context) error {
|
||||
sm.mu.Lock()
|
||||
defer sm.mu.Unlock()
|
||||
|
||||
// Clear history
|
||||
sm.history = make([]StateTransition, 0)
|
||||
|
||||
// Reset metrics
|
||||
sm.metrics = StateMachineMetrics{
|
||||
StateDistribution: make(map[ModuleState]int64),
|
||||
TransitionTimes: make(map[string]time.Duration),
|
||||
stateEnterTime: time.Now(),
|
||||
}
|
||||
|
||||
// Transition to initial state
|
||||
return sm.performTransition(ctx, sm.config.InitialState, "reset")
|
||||
}
|
||||
|
||||
// Private methods
|
||||
|
||||
func (sm *StateMachine) performTransition(ctx context.Context, to ModuleState, trigger string) error {
|
||||
startTime := time.Now()
|
||||
from := sm.currentState
|
||||
|
||||
// Validate transition
|
||||
if sm.config.EnableValidation && !sm.canTransitionUnsafe(to) {
|
||||
return fmt.Errorf("invalid transition from %s to %s", from, to)
|
||||
}
|
||||
|
||||
// Create transition context
|
||||
transitionCtx := map[string]interface{}{
|
||||
"trigger": trigger,
|
||||
"start_time": startTime,
|
||||
"module_id": sm.module.GetID(),
|
||||
}
|
||||
|
||||
// Execute pre-transition hooks
|
||||
for name, hook := range sm.transitionHooks {
|
||||
hookCtx, cancel := context.WithTimeout(ctx, sm.config.TransitionTimeout)
|
||||
err := func() error {
|
||||
defer cancel()
|
||||
if err := hook(hookCtx, from, to, sm); err != nil {
|
||||
return fmt.Errorf("pre-transition hook %s failed: %w", name, err)
|
||||
}
|
||||
return nil
|
||||
}()
|
||||
if err != nil {
|
||||
sm.recordFailedTransition(from, to, startTime, trigger, err, transitionCtx)
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// Execute state-specific logic
|
||||
if err := sm.executeStateTransition(ctx, from, to); err != nil {
|
||||
sm.recordFailedTransition(from, to, startTime, trigger, err, transitionCtx)
|
||||
return fmt.Errorf("state transition failed: %w", err)
|
||||
}
|
||||
|
||||
// Update current state
|
||||
sm.currentState = to
|
||||
duration := time.Since(startTime)
|
||||
|
||||
// Update metrics
|
||||
if sm.config.EnableMetrics {
|
||||
sm.updateMetrics(from, to, duration)
|
||||
}
|
||||
|
||||
// Record successful transition
|
||||
sm.recordSuccessfulTransition(from, to, startTime, duration, trigger, transitionCtx)
|
||||
|
||||
// Execute post-transition hooks
|
||||
for _, hook := range sm.transitionHooks {
|
||||
hookCtx, cancel := context.WithTimeout(ctx, sm.config.TransitionTimeout)
|
||||
func() {
|
||||
defer cancel()
|
||||
if err := hook(hookCtx, from, to, sm); err != nil {
|
||||
// Log error but don't fail the transition
|
||||
return
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// Execute state handler for new state
|
||||
if handler, exists := sm.stateHandlers[to]; exists {
|
||||
handlerCtx, cancel := context.WithTimeout(ctx, sm.config.TransitionTimeout)
|
||||
func() {
|
||||
defer cancel()
|
||||
if err := handler(handlerCtx, sm); err != nil {
|
||||
// Log error but don't fail the transition
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (sm *StateMachine) executeStateTransition(ctx context.Context, from, to ModuleState) error {
|
||||
// Create timeout context for the operation
|
||||
timeoutCtx, cancel := context.WithTimeout(ctx, sm.config.TransitionTimeout)
|
||||
defer cancel()
|
||||
|
||||
switch to {
|
||||
case StateInitialized:
|
||||
return sm.module.Initialize(timeoutCtx, ModuleConfig{})
|
||||
case StateRunning:
|
||||
if from == StatePaused {
|
||||
return sm.module.Resume(timeoutCtx)
|
||||
}
|
||||
return sm.module.Start(timeoutCtx)
|
||||
case StateStopped:
|
||||
return sm.module.Stop(timeoutCtx)
|
||||
case StatePaused:
|
||||
return sm.module.Pause(timeoutCtx)
|
||||
case StateFailed:
|
||||
// Failed state doesn't require module action
|
||||
return nil
|
||||
default:
|
||||
return fmt.Errorf("unknown target state: %s", to)
|
||||
}
|
||||
}
|
||||
|
||||
func (sm *StateMachine) canTransitionUnsafe(to ModuleState) bool {
|
||||
validTransitions, exists := sm.transitions[sm.currentState]
|
||||
if !exists {
|
||||
return false
|
||||
}
|
||||
|
||||
for _, validState := range validTransitions {
|
||||
if validState == to {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func (sm *StateMachine) recordSuccessfulTransition(from, to ModuleState, startTime time.Time, duration time.Duration, trigger string, context map[string]interface{}) {
|
||||
transition := StateTransition{
|
||||
From: from,
|
||||
To: to,
|
||||
Timestamp: startTime,
|
||||
Duration: duration,
|
||||
Success: true,
|
||||
Trigger: trigger,
|
||||
Context: context,
|
||||
}
|
||||
|
||||
sm.addToHistory(transition)
|
||||
}
|
||||
|
||||
func (sm *StateMachine) recordFailedTransition(from, to ModuleState, startTime time.Time, trigger string, err error, context map[string]interface{}) {
|
||||
transition := StateTransition{
|
||||
From: from,
|
||||
To: to,
|
||||
Timestamp: startTime,
|
||||
Duration: time.Since(startTime),
|
||||
Success: false,
|
||||
Error: err,
|
||||
Trigger: trigger,
|
||||
Context: context,
|
||||
}
|
||||
|
||||
sm.addToHistory(transition)
|
||||
|
||||
if sm.config.EnableMetrics {
|
||||
sm.metrics.FailedTransitions++
|
||||
}
|
||||
}
|
||||
|
||||
func (sm *StateMachine) addToHistory(transition StateTransition) {
|
||||
sm.history = append(sm.history, transition)
|
||||
|
||||
// Trim history if it exceeds max size
|
||||
if len(sm.history) > sm.config.MaxHistorySize {
|
||||
sm.history = sm.history[1:]
|
||||
}
|
||||
}
|
||||
|
||||
func (sm *StateMachine) updateMetrics(from, to ModuleState, duration time.Duration) {
|
||||
sm.metrics.TotalTransitions++
|
||||
sm.metrics.SuccessfulTransitions++
|
||||
sm.metrics.StateDistribution[to]++
|
||||
sm.metrics.LastTransition = time.Now()
|
||||
|
||||
// Update transition times
|
||||
transitionKey := fmt.Sprintf("%s->%s", from, to)
|
||||
sm.metrics.TransitionTimes[transitionKey] = duration
|
||||
|
||||
// Update average transition time
|
||||
if sm.metrics.TotalTransitions > 0 {
|
||||
total := time.Duration(0)
|
||||
for _, d := range sm.metrics.TransitionTimes {
|
||||
total += d
|
||||
}
|
||||
sm.metrics.AverageTransitionTime = total / time.Duration(len(sm.metrics.TransitionTimes))
|
||||
}
|
||||
|
||||
// Update longest transition
|
||||
if duration > sm.metrics.LongestTransition {
|
||||
sm.metrics.LongestTransition = duration
|
||||
}
|
||||
|
||||
// Update state enter time for duration tracking
|
||||
sm.metrics.stateEnterTime = time.Now()
|
||||
}
|
||||
|
||||
func (sm *StateMachine) setupDefaultHandlers() {
|
||||
// Default handlers for common states
|
||||
sm.stateHandlers[StateInitialized] = func(ctx context.Context, machine *StateMachine) error {
|
||||
// State entered successfully
|
||||
return nil
|
||||
}
|
||||
|
||||
sm.stateHandlers[StateRunning] = func(ctx context.Context, machine *StateMachine) error {
|
||||
// Module is now running
|
||||
return nil
|
||||
}
|
||||
|
||||
sm.stateHandlers[StateStopped] = func(ctx context.Context, machine *StateMachine) error {
|
||||
// Module has stopped
|
||||
return nil
|
||||
}
|
||||
|
||||
sm.stateHandlers[StatePaused] = func(ctx context.Context, machine *StateMachine) error {
|
||||
// Module is paused
|
||||
return nil
|
||||
}
|
||||
|
||||
sm.stateHandlers[StateFailed] = func(ctx context.Context, machine *StateMachine) error {
|
||||
// Handle failure state - could trigger recovery logic
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// createDefaultTransitions creates the standard state transition rules
|
||||
func createDefaultTransitions() map[ModuleState][]ModuleState {
|
||||
return map[ModuleState][]ModuleState{
|
||||
StateUninitialized: {StateInitialized, StateFailed},
|
||||
StateInitialized: {StateStarting, StateStopped, StateFailed},
|
||||
StateStarting: {StateRunning, StateFailed},
|
||||
StateRunning: {StatePausing, StateStopping, StateFailed},
|
||||
StatePausing: {StatePaused, StateFailed},
|
||||
StatePaused: {StateResuming, StateStopping, StateFailed},
|
||||
StateResuming: {StateRunning, StateFailed},
|
||||
StateStopping: {StateStopped, StateFailed},
|
||||
StateStopped: {StateInitialized, StateStarting, StateFailed},
|
||||
StateFailed: {StateInitialized, StateStopped}, // Recovery paths
|
||||
}
|
||||
}
|
||||
|
||||
// StateMachineBuilder provides a fluent interface for building state machines
|
||||
type StateMachineBuilder struct {
|
||||
config StateMachineConfig
|
||||
stateHandlers map[ModuleState]StateHandler
|
||||
transitionHooks map[string]TransitionHook
|
||||
customTransitions map[ModuleState][]ModuleState
|
||||
}
|
||||
|
||||
// NewStateMachineBuilder creates a new state machine builder
|
||||
func NewStateMachineBuilder() *StateMachineBuilder {
|
||||
return &StateMachineBuilder{
|
||||
config: StateMachineConfig{
|
||||
InitialState: StateUninitialized,
|
||||
TransitionTimeout: 30 * time.Second,
|
||||
MaxHistorySize: 100,
|
||||
EnableMetrics: true,
|
||||
EnableValidation: true,
|
||||
},
|
||||
stateHandlers: make(map[ModuleState]StateHandler),
|
||||
transitionHooks: make(map[string]TransitionHook),
|
||||
customTransitions: make(map[ModuleState][]ModuleState),
|
||||
}
|
||||
}
|
||||
|
||||
// WithConfig sets the state machine configuration
|
||||
func (smb *StateMachineBuilder) WithConfig(config StateMachineConfig) *StateMachineBuilder {
|
||||
smb.config = config
|
||||
return smb
|
||||
}
|
||||
|
||||
// WithStateHandler adds a state handler
|
||||
func (smb *StateMachineBuilder) WithStateHandler(state ModuleState, handler StateHandler) *StateMachineBuilder {
|
||||
smb.stateHandlers[state] = handler
|
||||
return smb
|
||||
}
|
||||
|
||||
// WithTransitionHook adds a transition hook
|
||||
func (smb *StateMachineBuilder) WithTransitionHook(name string, hook TransitionHook) *StateMachineBuilder {
|
||||
smb.transitionHooks[name] = hook
|
||||
return smb
|
||||
}
|
||||
|
||||
// WithCustomTransition adds a custom transition rule
|
||||
func (smb *StateMachineBuilder) WithCustomTransition(from, to ModuleState) *StateMachineBuilder {
|
||||
if _, exists := smb.customTransitions[from]; !exists {
|
||||
smb.customTransitions[from] = make([]ModuleState, 0)
|
||||
}
|
||||
smb.customTransitions[from] = append(smb.customTransitions[from], to)
|
||||
return smb
|
||||
}
|
||||
|
||||
// Build creates the state machine
|
||||
func (smb *StateMachineBuilder) Build(module Module) *StateMachine {
|
||||
sm := NewStateMachine(module, smb.config)
|
||||
|
||||
// Add state handlers
|
||||
for state, handler := range smb.stateHandlers {
|
||||
sm.SetStateHandler(state, handler)
|
||||
}
|
||||
|
||||
// Add transition hooks
|
||||
for name, hook := range smb.transitionHooks {
|
||||
sm.SetTransitionHook(name, hook)
|
||||
}
|
||||
|
||||
// Add custom transitions
|
||||
for from, toStates := range smb.customTransitions {
|
||||
for _, to := range toStates {
|
||||
sm.AddCustomTransition(from, to)
|
||||
}
|
||||
}
|
||||
|
||||
return sm
|
||||
}
|
||||
Reference in New Issue
Block a user