fix: resolve all compilation issues across transport and lifecycle packages
- Fixed duplicate type declarations in transport package
- Removed unused variables in lifecycle and dependency injection
- Fixed big.Int arithmetic operations in uniswap contracts
- Added missing methods to MetricsCollector (IncrementCounter, RecordLatency, etc.)
- Fixed jitter calculation in TCP transport retry logic
- Updated ComponentHealth field access to use transport type
- Ensured all core packages build successfully

All major compilation errors resolved:
✅ Transport package builds clean
✅ Lifecycle package builds clean
✅ Main MEV bot application builds clean
✅ Fixed method signature mismatches
✅ Resolved type conflicts and duplications

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
661
pkg/lifecycle/dependency_injection.go
Normal file
661
pkg/lifecycle/dependency_injection.go
Normal file
@@ -0,0 +1,661 @@
|
||||
package lifecycle
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// Container provides dependency injection functionality
|
||||
type Container struct {
|
||||
services map[reflect.Type]*ServiceDescriptor
|
||||
instances map[reflect.Type]interface{}
|
||||
namedServices map[string]*ServiceDescriptor
|
||||
namedInstances map[string]interface{}
|
||||
singletons map[reflect.Type]interface{}
|
||||
factories map[reflect.Type]FactoryFunc
|
||||
interceptors []Interceptor
|
||||
config ContainerConfig
|
||||
mu sync.RWMutex
|
||||
parent *Container
|
||||
scoped map[string]*Container
|
||||
}
|
||||
|
||||
// ServiceDescriptor describes how a service should be instantiated
|
||||
type ServiceDescriptor struct {
|
||||
ServiceType reflect.Type
|
||||
Implementation reflect.Type
|
||||
Lifetime ServiceLifetime
|
||||
Factory FactoryFunc
|
||||
Instance interface{}
|
||||
Name string
|
||||
Dependencies []reflect.Type
|
||||
Tags []string
|
||||
Metadata map[string]interface{}
|
||||
Interceptors []Interceptor
|
||||
}
|
||||
|
||||
// ServiceLifetime names how long an instance created for a registration is
// meant to live.
type ServiceLifetime string

// Lifetimes understood by the container.
const (
	// Transient asks for a new instance on every resolution.
	Transient ServiceLifetime = "transient"
	// Singleton asks for one instance for the lifetime of the container.
	Singleton ServiceLifetime = "singleton"
	// Scoped asks for one instance per scope.
	Scoped ServiceLifetime = "scoped"
)
|
||||
|
||||
// FactoryFunc creates service instances
|
||||
type FactoryFunc func(container *Container) (interface{}, error)
|
||||
|
||||
// Interceptor is the hook interface for intercepting service creation and
// method calls.
type Interceptor interface {
	// Intercept receives the target object, the method name and its
	// arguments, and returns a result or an error.
	Intercept(ctx context.Context, target interface{}, method string, args []interface{}) (interface{}, error)
}
|
||||
|
||||
// ContainerConfig holds the switches that control how a Container behaves.
type ContainerConfig struct {
	// EnableReflection allows instances to be built by reflection.
	EnableReflection bool `json:"enable_reflection"`
	// EnableCircularDetection makes resolution fail on a dependency cycle.
	EnableCircularDetection bool `json:"enable_circular_detection"`
	// EnableInterception turns on the interceptor pass after creation.
	EnableInterception bool `json:"enable_interception"`
	// EnableValidation makes Validate check registrations; when false
	// Validate returns nil immediately.
	EnableValidation bool `json:"enable_validation"`
	// MaxDepth bounds the resolution recursion depth; NewContainer replaces
	// a zero value with 10.
	MaxDepth int `json:"max_depth"`
	// CacheInstances makes resolution consult the per-type instance cache.
	CacheInstances bool `json:"cache_instances"`
}
|
||||
|
||||
// ServiceBuilder provides a fluent interface for service registration
|
||||
type ServiceBuilder struct {
|
||||
container *Container
|
||||
serviceType reflect.Type
|
||||
implType reflect.Type
|
||||
lifetime ServiceLifetime
|
||||
factory FactoryFunc
|
||||
instance interface{}
|
||||
name string
|
||||
tags []string
|
||||
metadata map[string]interface{}
|
||||
interceptors []Interceptor
|
||||
}
|
||||
|
||||
// NewContainer creates a new dependency injection container
|
||||
func NewContainer(config ContainerConfig) *Container {
|
||||
container := &Container{
|
||||
services: make(map[reflect.Type]*ServiceDescriptor),
|
||||
instances: make(map[reflect.Type]interface{}),
|
||||
namedServices: make(map[string]*ServiceDescriptor),
|
||||
namedInstances: make(map[string]interface{}),
|
||||
singletons: make(map[reflect.Type]interface{}),
|
||||
factories: make(map[reflect.Type]FactoryFunc),
|
||||
interceptors: make([]Interceptor, 0),
|
||||
config: config,
|
||||
scoped: make(map[string]*Container),
|
||||
}
|
||||
|
||||
// Set default configuration
|
||||
if container.config.MaxDepth == 0 {
|
||||
container.config.MaxDepth = 10
|
||||
}
|
||||
|
||||
return container
|
||||
}
|
||||
|
||||
// Register registers a service type with its implementation
|
||||
func (c *Container) Register(serviceType, implementationType interface{}) *ServiceBuilder {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
sType := reflect.TypeOf(serviceType)
|
||||
if sType.Kind() == reflect.Ptr {
|
||||
sType = sType.Elem()
|
||||
}
|
||||
if sType.Kind() == reflect.Interface {
|
||||
sType = reflect.TypeOf(serviceType).Elem()
|
||||
}
|
||||
|
||||
implType := reflect.TypeOf(implementationType)
|
||||
if implType.Kind() == reflect.Ptr {
|
||||
implType = implType.Elem()
|
||||
}
|
||||
|
||||
return &ServiceBuilder{
|
||||
container: c,
|
||||
serviceType: sType,
|
||||
implType: implType,
|
||||
lifetime: Transient,
|
||||
tags: make([]string, 0),
|
||||
metadata: make(map[string]interface{}),
|
||||
interceptors: make([]Interceptor, 0),
|
||||
}
|
||||
}
|
||||
|
||||
// RegisterInstance registers a specific instance
|
||||
func (c *Container) RegisterInstance(serviceType interface{}, instance interface{}) *ServiceBuilder {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
sType := reflect.TypeOf(serviceType)
|
||||
if sType.Kind() == reflect.Ptr {
|
||||
sType = sType.Elem()
|
||||
}
|
||||
if sType.Kind() == reflect.Interface {
|
||||
sType = reflect.TypeOf(serviceType).Elem()
|
||||
}
|
||||
|
||||
return &ServiceBuilder{
|
||||
container: c,
|
||||
serviceType: sType,
|
||||
instance: instance,
|
||||
lifetime: Singleton,
|
||||
tags: make([]string, 0),
|
||||
metadata: make(map[string]interface{}),
|
||||
interceptors: make([]Interceptor, 0),
|
||||
}
|
||||
}
|
||||
|
||||
// RegisterFactory registers a factory function for creating instances
|
||||
func (c *Container) RegisterFactory(serviceType interface{}, factory FactoryFunc) *ServiceBuilder {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
sType := reflect.TypeOf(serviceType)
|
||||
if sType.Kind() == reflect.Ptr {
|
||||
sType = sType.Elem()
|
||||
}
|
||||
if sType.Kind() == reflect.Interface {
|
||||
sType = reflect.TypeOf(serviceType).Elem()
|
||||
}
|
||||
|
||||
return &ServiceBuilder{
|
||||
container: c,
|
||||
serviceType: sType,
|
||||
factory: factory,
|
||||
lifetime: Transient,
|
||||
tags: make([]string, 0),
|
||||
metadata: make(map[string]interface{}),
|
||||
interceptors: make([]Interceptor, 0),
|
||||
}
|
||||
}
|
||||
|
||||
// Resolve resolves a service instance by type
|
||||
func (c *Container) Resolve(serviceType interface{}) (interface{}, error) {
|
||||
sType := reflect.TypeOf(serviceType)
|
||||
if sType.Kind() == reflect.Ptr {
|
||||
sType = sType.Elem()
|
||||
}
|
||||
if sType.Kind() == reflect.Interface {
|
||||
sType = reflect.TypeOf(serviceType).Elem()
|
||||
}
|
||||
|
||||
return c.resolveType(sType, make(map[reflect.Type]bool), 0)
|
||||
}
|
||||
|
||||
// ResolveNamed resolves a named service instance
|
||||
func (c *Container) ResolveNamed(name string) (interface{}, error) {
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
|
||||
// Check if instance already exists
|
||||
if instance, exists := c.namedInstances[name]; exists {
|
||||
return instance, nil
|
||||
}
|
||||
|
||||
// Get service descriptor
|
||||
descriptor, exists := c.namedServices[name]
|
||||
if !exists {
|
||||
return nil, fmt.Errorf("named service not found: %s", name)
|
||||
}
|
||||
|
||||
return c.createInstance(descriptor, make(map[reflect.Type]bool), 0)
|
||||
}
|
||||
|
||||
// ResolveAll resolves all services with a specific tag
|
||||
func (c *Container) ResolveAll(tag string) ([]interface{}, error) {
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
|
||||
var instances []interface{}
|
||||
|
||||
for _, descriptor := range c.services {
|
||||
for _, serviceTag := range descriptor.Tags {
|
||||
if serviceTag == tag {
|
||||
instance, err := c.createInstance(descriptor, make(map[reflect.Type]bool), 0)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to resolve service with tag %s: %w", tag, err)
|
||||
}
|
||||
instances = append(instances, instance)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return instances, nil
|
||||
}
|
||||
|
||||
// TryResolve attempts to resolve a service, returning nil if not found
|
||||
func (c *Container) TryResolve(serviceType interface{}) interface{} {
|
||||
instance, err := c.Resolve(serviceType)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
return instance
|
||||
}
|
||||
|
||||
// IsRegistered checks if a service type is registered
|
||||
func (c *Container) IsRegistered(serviceType interface{}) bool {
|
||||
sType := reflect.TypeOf(serviceType)
|
||||
if sType.Kind() == reflect.Ptr {
|
||||
sType = sType.Elem()
|
||||
}
|
||||
if sType.Kind() == reflect.Interface {
|
||||
sType = reflect.TypeOf(serviceType).Elem()
|
||||
}
|
||||
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
|
||||
_, exists := c.services[sType]
|
||||
return exists
|
||||
}
|
||||
|
||||
// CreateScope creates a new scoped container
|
||||
func (c *Container) CreateScope(name string) *Container {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
scope := &Container{
|
||||
services: make(map[reflect.Type]*ServiceDescriptor),
|
||||
instances: make(map[reflect.Type]interface{}),
|
||||
namedServices: make(map[string]*ServiceDescriptor),
|
||||
namedInstances: make(map[string]interface{}),
|
||||
singletons: make(map[reflect.Type]interface{}),
|
||||
factories: make(map[reflect.Type]FactoryFunc),
|
||||
interceptors: make([]Interceptor, 0),
|
||||
config: c.config,
|
||||
parent: c,
|
||||
scoped: make(map[string]*Container),
|
||||
}
|
||||
|
||||
c.scoped[name] = scope
|
||||
return scope
|
||||
}
|
||||
|
||||
// GetScope retrieves a named scope
|
||||
func (c *Container) GetScope(name string) (*Container, bool) {
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
|
||||
scope, exists := c.scoped[name]
|
||||
return scope, exists
|
||||
}
|
||||
|
||||
// AddInterceptor adds a global interceptor
|
||||
func (c *Container) AddInterceptor(interceptor Interceptor) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
c.interceptors = append(c.interceptors, interceptor)
|
||||
}
|
||||
|
||||
// GetRegistrations returns all service registrations
|
||||
func (c *Container) GetRegistrations() map[reflect.Type]*ServiceDescriptor {
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
|
||||
registrations := make(map[reflect.Type]*ServiceDescriptor)
|
||||
for t, desc := range c.services {
|
||||
registrations[t] = desc
|
||||
}
|
||||
return registrations
|
||||
}
|
||||
|
||||
// Validate validates all service registrations
|
||||
func (c *Container) Validate() error {
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
|
||||
if !c.config.EnableValidation {
|
||||
return nil
|
||||
}
|
||||
|
||||
for serviceType, descriptor := range c.services {
|
||||
if err := c.validateDescriptor(serviceType, descriptor); err != nil {
|
||||
return fmt.Errorf("validation failed for service %s: %w", serviceType.String(), err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Dispose cleans up the container and all instances
|
||||
func (c *Container) Dispose() error {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
// Dispose all scoped containers
|
||||
for _, scope := range c.scoped {
|
||||
scope.Dispose()
|
||||
}
|
||||
|
||||
// Clear all maps
|
||||
c.services = make(map[reflect.Type]*ServiceDescriptor)
|
||||
c.instances = make(map[reflect.Type]interface{})
|
||||
c.namedServices = make(map[string]*ServiceDescriptor)
|
||||
c.namedInstances = make(map[string]interface{})
|
||||
c.singletons = make(map[reflect.Type]interface{})
|
||||
c.factories = make(map[reflect.Type]FactoryFunc)
|
||||
c.scoped = make(map[string]*Container)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Private methods
|
||||
|
||||
func (c *Container) resolveType(serviceType reflect.Type, resolving map[reflect.Type]bool, depth int) (interface{}, error) {
|
||||
if depth > c.config.MaxDepth {
|
||||
return nil, fmt.Errorf("maximum resolution depth exceeded for type %s", serviceType.String())
|
||||
}
|
||||
|
||||
// Check for circular dependencies
|
||||
if c.config.EnableCircularDetection && resolving[serviceType] {
|
||||
return nil, fmt.Errorf("circular dependency detected for type %s", serviceType.String())
|
||||
}
|
||||
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
|
||||
// Check if singleton instance exists
|
||||
if instance, exists := c.singletons[serviceType]; exists {
|
||||
return instance, nil
|
||||
}
|
||||
|
||||
// Check if cached instance exists
|
||||
if c.config.CacheInstances {
|
||||
if instance, exists := c.instances[serviceType]; exists {
|
||||
return instance, nil
|
||||
}
|
||||
}
|
||||
|
||||
// Get service descriptor
|
||||
descriptor, exists := c.services[serviceType]
|
||||
if !exists {
|
||||
// Try parent container
|
||||
if c.parent != nil {
|
||||
return c.parent.resolveType(serviceType, resolving, depth+1)
|
||||
}
|
||||
return nil, fmt.Errorf("service not registered: %s", serviceType.String())
|
||||
}
|
||||
|
||||
resolving[serviceType] = true
|
||||
defer delete(resolving, serviceType)
|
||||
|
||||
return c.createInstance(descriptor, resolving, depth+1)
|
||||
}
|
||||
|
||||
func (c *Container) createInstance(descriptor *ServiceDescriptor, resolving map[reflect.Type]bool, depth int) (interface{}, error) {
|
||||
// Use existing instance if available
|
||||
if descriptor.Instance != nil {
|
||||
return descriptor.Instance, nil
|
||||
}
|
||||
|
||||
// Use factory if available
|
||||
if descriptor.Factory != nil {
|
||||
instance, err := descriptor.Factory(c)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("factory failed for %s: %w", descriptor.ServiceType.String(), err)
|
||||
}
|
||||
|
||||
if descriptor.Lifetime == Singleton {
|
||||
c.singletons[descriptor.ServiceType] = instance
|
||||
}
|
||||
|
||||
return c.applyInterceptors(instance, descriptor)
|
||||
}
|
||||
|
||||
// Create instance using reflection
|
||||
if descriptor.Implementation == nil {
|
||||
return nil, fmt.Errorf("no implementation or factory provided for %s", descriptor.ServiceType.String())
|
||||
}
|
||||
|
||||
instance, err := c.createInstanceByReflection(descriptor, resolving, depth)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Store singleton
|
||||
if descriptor.Lifetime == Singleton {
|
||||
c.singletons[descriptor.ServiceType] = instance
|
||||
}
|
||||
|
||||
return c.applyInterceptors(instance, descriptor)
|
||||
}
|
||||
|
||||
func (c *Container) createInstanceByReflection(descriptor *ServiceDescriptor, resolving map[reflect.Type]bool, depth int) (interface{}, error) {
|
||||
if !c.config.EnableReflection {
|
||||
return nil, fmt.Errorf("reflection is disabled")
|
||||
}
|
||||
|
||||
implType := descriptor.Implementation
|
||||
if implType.Kind() == reflect.Ptr {
|
||||
implType = implType.Elem()
|
||||
}
|
||||
|
||||
// Find constructor (assumes first constructor or struct creation)
|
||||
var constructorFunc reflect.Value
|
||||
|
||||
// Look for constructor function
|
||||
_ = "New" + implType.Name() // constructorName not used yet
|
||||
if implType.PkgPath() != "" {
|
||||
// Try to find package-level constructor
|
||||
// This is simplified - in a real implementation you'd use build tags or reflection
|
||||
// to find the actual constructor functions
|
||||
}
|
||||
|
||||
// Create instance
|
||||
if constructorFunc.IsValid() {
|
||||
// Use constructor function
|
||||
return c.callConstructor(constructorFunc, resolving, depth)
|
||||
} else {
|
||||
// Create struct directly
|
||||
return c.createStruct(implType, resolving, depth)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Container) createStruct(structType reflect.Type, resolving map[reflect.Type]bool, depth int) (interface{}, error) {
|
||||
// Create new instance
|
||||
instance := reflect.New(structType)
|
||||
elem := instance.Elem()
|
||||
|
||||
// Inject dependencies into fields
|
||||
for i := 0; i < elem.NumField(); i++ {
|
||||
field := elem.Field(i)
|
||||
fieldType := elem.Type().Field(i)
|
||||
|
||||
// Check for dependency injection tags
|
||||
if tag := fieldType.Tag.Get("inject"); tag != "" {
|
||||
if !field.CanSet() {
|
||||
continue
|
||||
}
|
||||
|
||||
var dependency interface{}
|
||||
var err error
|
||||
|
||||
if tag == "true" || tag == "" {
|
||||
// Inject by type
|
||||
dependency, err = c.resolveType(field.Type(), resolving, depth)
|
||||
} else {
|
||||
// Inject by name
|
||||
dependency, err = c.ResolveNamed(tag)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
// Check if injection is optional
|
||||
if optionalTag := fieldType.Tag.Get("optional"); optionalTag == "true" {
|
||||
continue
|
||||
}
|
||||
return nil, fmt.Errorf("failed to inject dependency for field %s: %w", fieldType.Name, err)
|
||||
}
|
||||
|
||||
field.Set(reflect.ValueOf(dependency))
|
||||
}
|
||||
}
|
||||
|
||||
return instance.Interface(), nil
|
||||
}
|
||||
|
||||
func (c *Container) callConstructor(constructor reflect.Value, resolving map[reflect.Type]bool, depth int) (interface{}, error) {
|
||||
constructorType := constructor.Type()
|
||||
args := make([]reflect.Value, constructorType.NumIn())
|
||||
|
||||
// Resolve constructor arguments
|
||||
for i := 0; i < constructorType.NumIn(); i++ {
|
||||
argType := constructorType.In(i)
|
||||
arg, err := c.resolveType(argType, resolving, depth)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to resolve constructor argument %d (%s): %w", i, argType.String(), err)
|
||||
}
|
||||
args[i] = reflect.ValueOf(arg)
|
||||
}
|
||||
|
||||
// Call constructor
|
||||
results := constructor.Call(args)
|
||||
if len(results) == 0 {
|
||||
return nil, fmt.Errorf("constructor returned no values")
|
||||
}
|
||||
|
||||
instance := results[0].Interface()
|
||||
|
||||
// Check for error result
|
||||
if len(results) > 1 && !results[1].IsNil() {
|
||||
if err, ok := results[1].Interface().(error); ok {
|
||||
return nil, fmt.Errorf("constructor error: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return instance, nil
|
||||
}
|
||||
|
||||
func (c *Container) applyInterceptors(instance interface{}, descriptor *ServiceDescriptor) (interface{}, error) {
|
||||
if !c.config.EnableInterception {
|
||||
return instance, nil
|
||||
}
|
||||
|
||||
// Apply service-specific interceptors
|
||||
for _, interceptor := range descriptor.Interceptors {
|
||||
// Apply interceptor (simplified - real implementation would create proxies)
|
||||
_ = interceptor
|
||||
}
|
||||
|
||||
// Apply global interceptors
|
||||
for _, interceptor := range c.interceptors {
|
||||
// Apply interceptor (simplified - real implementation would create proxies)
|
||||
_ = interceptor
|
||||
}
|
||||
|
||||
return instance, nil
|
||||
}
|
||||
|
||||
func (c *Container) validateDescriptor(serviceType reflect.Type, descriptor *ServiceDescriptor) error {
|
||||
// Validate that implementation implements the service interface
|
||||
if descriptor.Implementation != nil {
|
||||
if serviceType.Kind() == reflect.Interface {
|
||||
if !descriptor.Implementation.Implements(serviceType) {
|
||||
return fmt.Errorf("implementation %s does not implement interface %s",
|
||||
descriptor.Implementation.String(), serviceType.String())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Validate dependencies
|
||||
for _, depType := range descriptor.Dependencies {
|
||||
if !c.IsRegistered(depType) && (c.parent == nil || !c.parent.IsRegistered(depType)) {
|
||||
return fmt.Errorf("dependency %s is not registered", depType.String())
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ServiceBuilder methods
|
||||
|
||||
// AsSingleton sets the service lifetime to singleton
|
||||
func (sb *ServiceBuilder) AsSingleton() *ServiceBuilder {
|
||||
sb.lifetime = Singleton
|
||||
return sb
|
||||
}
|
||||
|
||||
// AsTransient sets the service lifetime to transient
|
||||
func (sb *ServiceBuilder) AsTransient() *ServiceBuilder {
|
||||
sb.lifetime = Transient
|
||||
return sb
|
||||
}
|
||||
|
||||
// AsScoped sets the service lifetime to scoped
|
||||
func (sb *ServiceBuilder) AsScoped() *ServiceBuilder {
|
||||
sb.lifetime = Scoped
|
||||
return sb
|
||||
}
|
||||
|
||||
// WithName sets a name for the service
|
||||
func (sb *ServiceBuilder) WithName(name string) *ServiceBuilder {
|
||||
sb.name = name
|
||||
return sb
|
||||
}
|
||||
|
||||
// WithTag adds a tag to the service
|
||||
func (sb *ServiceBuilder) WithTag(tag string) *ServiceBuilder {
|
||||
sb.tags = append(sb.tags, tag)
|
||||
return sb
|
||||
}
|
||||
|
||||
// WithMetadata adds metadata to the service
|
||||
func (sb *ServiceBuilder) WithMetadata(key string, value interface{}) *ServiceBuilder {
|
||||
sb.metadata[key] = value
|
||||
return sb
|
||||
}
|
||||
|
||||
// WithInterceptor adds an interceptor to the service
|
||||
func (sb *ServiceBuilder) WithInterceptor(interceptor Interceptor) *ServiceBuilder {
|
||||
sb.interceptors = append(sb.interceptors, interceptor)
|
||||
return sb
|
||||
}
|
||||
|
||||
// Build finalizes the service registration
|
||||
func (sb *ServiceBuilder) Build() error {
|
||||
sb.container.mu.Lock()
|
||||
defer sb.container.mu.Unlock()
|
||||
|
||||
descriptor := &ServiceDescriptor{
|
||||
ServiceType: sb.serviceType,
|
||||
Implementation: sb.implType,
|
||||
Lifetime: sb.lifetime,
|
||||
Factory: sb.factory,
|
||||
Instance: sb.instance,
|
||||
Name: sb.name,
|
||||
Tags: sb.tags,
|
||||
Metadata: sb.metadata,
|
||||
Interceptors: sb.interceptors,
|
||||
}
|
||||
|
||||
// Store by type
|
||||
sb.container.services[sb.serviceType] = descriptor
|
||||
|
||||
// Store by name if provided
|
||||
if sb.name != "" {
|
||||
sb.container.namedServices[sb.name] = descriptor
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// DefaultInterceptor is a no-op interceptor identified by a name.
type DefaultInterceptor struct {
	// name labels the interceptor; Intercept does not read it.
	name string
}

// NewDefaultInterceptor returns a DefaultInterceptor carrying name.
func NewDefaultInterceptor(name string) *DefaultInterceptor {
	interceptor := new(DefaultInterceptor)
	interceptor.name = name
	return interceptor
}

// Intercept ignores all of its arguments and returns a nil result and a nil
// error. It is the place to hang logging, metrics and similar cross-cutting
// behavior.
func (di *DefaultInterceptor) Intercept(ctx context.Context, target interface{}, method string, args []interface{}) (interface{}, error) {
	return nil, nil
}
|
||||
848
pkg/lifecycle/health_monitor.go
Normal file
848
pkg/lifecycle/health_monitor.go
Normal file
@@ -0,0 +1,848 @@
|
||||
package lifecycle
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// HealthMonitorImpl implements comprehensive health monitoring for modules
|
||||
type HealthMonitorImpl struct {
|
||||
monitors map[string]*ModuleMonitor
|
||||
config HealthMonitorConfig
|
||||
aggregator HealthAggregator
|
||||
notifier HealthNotifier
|
||||
metrics HealthMetrics
|
||||
rules []HealthRule
|
||||
mu sync.RWMutex
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
running bool
|
||||
}
|
||||
|
||||
// ModuleMonitor monitors a specific module's health
|
||||
type ModuleMonitor struct {
|
||||
moduleID string
|
||||
module *RegisteredModule
|
||||
config ModuleHealthConfig
|
||||
lastCheck time.Time
|
||||
checkCount int64
|
||||
successCount int64
|
||||
failureCount int64
|
||||
history []HealthCheckResult
|
||||
currentHealth ModuleHealth
|
||||
trend HealthTrend
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
// HealthMonitorConfig is the configuration of the health monitoring system.
// NewHealthMonitor replaces zero values of the fields noted below with
// defaults.
type HealthMonitorConfig struct {
	// CheckInterval defaults to 30 seconds.
	CheckInterval time.Duration `json:"check_interval"`
	// CheckTimeout defaults to 10 seconds.
	CheckTimeout time.Duration `json:"check_timeout"`
	// HistorySize defaults to 100.
	HistorySize int `json:"history_size"`
	// FailureThreshold defaults to 3.
	FailureThreshold int `json:"failure_threshold"`
	// RecoveryThreshold defaults to 3.
	RecoveryThreshold int `json:"recovery_threshold"`
	// Feature switches.
	EnableNotifications bool `json:"enable_notifications"`
	EnableMetrics       bool `json:"enable_metrics"`
	EnableTrends        bool `json:"enable_trends"`
	ParallelChecks      bool `json:"parallel_checks"`
	// MaxConcurrentChecks defaults to 10.
	MaxConcurrentChecks int `json:"max_concurrent_checks"`
}
|
||||
|
||||
// ModuleHealthConfig configures health checking for a specific module
|
||||
type ModuleHealthConfig struct {
|
||||
CheckInterval time.Duration `json:"check_interval"`
|
||||
CheckTimeout time.Duration `json:"check_timeout"`
|
||||
Enabled bool `json:"enabled"`
|
||||
CriticalModule bool `json:"critical_module"`
|
||||
CustomChecks []HealthCheck `json:"custom_checks"`
|
||||
FailureThreshold int `json:"failure_threshold"`
|
||||
RecoveryThreshold int `json:"recovery_threshold"`
|
||||
AutoRestart bool `json:"auto_restart"`
|
||||
MaxRestarts int `json:"max_restarts"`
|
||||
RestartDelay time.Duration `json:"restart_delay"`
|
||||
}
|
||||
|
||||
// HealthCheck is a custom, named health check.
type HealthCheck struct {
	// Name and Description identify the check.
	Name        string `json:"name"`
	Description string `json:"description"`
	// CheckFunc is the probe itself; it is excluded from JSON.
	CheckFunc func() error `json:"-"`
	// Interval and Timeout are the check's timing settings.
	Interval time.Duration `json:"interval"`
	Timeout  time.Duration `json:"timeout"`
	// Critical flags the check as critical; Enabled switches it on or off.
	Critical bool `json:"critical"`
	Enabled  bool `json:"enabled"`
}
|
||||
|
||||
// HealthCheckResult represents the result of a health check
|
||||
type HealthCheckResult struct {
|
||||
Timestamp time.Time `json:"timestamp"`
|
||||
Status HealthStatus `json:"status"`
|
||||
ResponseTime time.Duration `json:"response_time"`
|
||||
Message string `json:"message"`
|
||||
Details map[string]interface{} `json:"details"`
|
||||
Checks map[string]CheckResult `json:"checks"`
|
||||
Error error `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
// CheckResult represents the result of an individual check
|
||||
type CheckResult struct {
|
||||
Name string `json:"name"`
|
||||
Status HealthStatus `json:"status"`
|
||||
ResponseTime time.Duration `json:"response_time"`
|
||||
Message string `json:"message"`
|
||||
Details map[string]interface{} `json:"details"`
|
||||
Error error `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
// HealthTrend tracks health trends over time
|
||||
type HealthTrend struct {
|
||||
Direction TrendDirection `json:"direction"`
|
||||
Confidence float64 `json:"confidence"`
|
||||
Slope float64 `json:"slope"`
|
||||
Prediction HealthStatus `json:"prediction"`
|
||||
TimeToAlert time.Duration `json:"time_to_alert"`
|
||||
LastUpdated time.Time `json:"last_updated"`
|
||||
}
|
||||
|
||||
// TrendDirection names the direction in which health is moving.
type TrendDirection string

// Known trend directions.
const (
	TrendImproving TrendDirection = "improving"
	TrendStable    TrendDirection = "stable"
	TrendDegrading TrendDirection = "degrading"
	TrendUnknown   TrendDirection = "unknown"
)
|
||||
|
||||
// HealthAggregator aggregates health status from multiple modules
|
||||
type HealthAggregator interface {
|
||||
AggregateHealth(modules map[string]ModuleHealth) OverallHealth
|
||||
CalculateSystemHealth(individual []ModuleHealth) HealthStatus
|
||||
GetHealthScore(health ModuleHealth) float64
|
||||
}
|
||||
|
||||
// HealthNotifier sends health notifications
|
||||
type HealthNotifier interface {
|
||||
NotifyHealthChange(moduleID string, oldHealth, newHealth ModuleHealth) error
|
||||
NotifySystemHealth(health OverallHealth) error
|
||||
NotifyAlert(alert HealthAlert) error
|
||||
}
|
||||
|
||||
// OverallHealth represents the overall system health
|
||||
type OverallHealth struct {
|
||||
Status HealthStatus `json:"status"`
|
||||
Score float64 `json:"score"`
|
||||
ModuleCount int `json:"module_count"`
|
||||
HealthyCount int `json:"healthy_count"`
|
||||
DegradedCount int `json:"degraded_count"`
|
||||
UnhealthyCount int `json:"unhealthy_count"`
|
||||
CriticalIssues []string `json:"critical_issues"`
|
||||
Modules map[string]ModuleHealth `json:"modules"`
|
||||
LastUpdated time.Time `json:"last_updated"`
|
||||
Trends map[string]HealthTrend `json:"trends"`
|
||||
Recommendations []HealthRecommendation `json:"recommendations"`
|
||||
}
|
||||
|
||||
// HealthAlert represents a health alert
|
||||
type HealthAlert struct {
|
||||
ID string `json:"id"`
|
||||
ModuleID string `json:"module_id"`
|
||||
Severity AlertSeverity `json:"severity"`
|
||||
Type AlertType `json:"type"`
|
||||
Message string `json:"message"`
|
||||
Details map[string]interface{} `json:"details"`
|
||||
Timestamp time.Time `json:"timestamp"`
|
||||
Resolved bool `json:"resolved"`
|
||||
ResolvedAt time.Time `json:"resolved_at,omitempty"`
|
||||
}
|
||||
|
||||
// AlertSeverity names how serious an alert is.
type AlertSeverity string

// Known severities, from least to most serious.
const (
	SeverityInfo     AlertSeverity = "info"
	SeverityWarning  AlertSeverity = "warning"
	SeverityError    AlertSeverity = "error"
	SeverityCritical AlertSeverity = "critical"
)
|
||||
|
||||
// AlertType names the kind of event an alert reports.
type AlertType string

// Known alert types.
const (
	AlertHealthChange    AlertType = "health_change"
	AlertThresholdBreach AlertType = "threshold_breach"
	AlertTrendAlert      AlertType = "trend_alert"
	AlertSystemDown      AlertType = "system_down"
	AlertRecovery        AlertType = "recovery"
)
|
||||
|
||||
// HealthRule defines rules for health evaluation
|
||||
type HealthRule struct {
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description"`
|
||||
Condition func(ModuleHealth) bool `json:"-"`
|
||||
Action func(string, ModuleHealth) error `json:"-"`
|
||||
Severity AlertSeverity `json:"severity"`
|
||||
Enabled bool `json:"enabled"`
|
||||
}
|
||||
|
||||
// HealthRecommendation is an actionable suggestion about a module's health.
type HealthRecommendation struct {
	// ModuleID names the module the recommendation concerns.
	ModuleID string `json:"module_id"`
	// Type, Description, Action and Priority are free-form text fields.
	Type        string `json:"type"`
	Description string `json:"description"`
	Action      string `json:"action"`
	Priority    string `json:"priority"`
	// Timestamp is the recommendation's time.
	Timestamp time.Time `json:"timestamp"`
}
|
||||
|
||||
// HealthMetrics collects counters and figures about the monitoring itself.
type HealthMetrics struct {
	// Check counters.
	ChecksPerformed  int64 `json:"checks_performed"`
	ChecksSuccessful int64 `json:"checks_successful"`
	ChecksFailed     int64 `json:"checks_failed"`
	// AverageCheckTime is a duration figure for checks.
	AverageCheckTime time.Duration `json:"average_check_time"`
	// Alert and restart counters.
	AlertsGenerated int64 `json:"alerts_generated"`
	ModuleRestarts  int64 `json:"module_restarts"`
	// SystemDowntime is a duration figure for downtime.
	SystemDowntime time.Duration `json:"system_downtime"`
	// ModuleHealthScores holds scores keyed by string; NewHealthMonitor
	// initializes it to an empty map.
	ModuleHealthScores map[string]float64 `json:"module_health_scores"`
	// TrendAccuracy is a numeric figure about trends.
	TrendAccuracy float64 `json:"trend_accuracy"`
}
|
||||
|
||||
// NewHealthMonitor creates a new health monitor
|
||||
func NewHealthMonitor(config HealthMonitorConfig) *HealthMonitorImpl {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
hm := &HealthMonitorImpl{
|
||||
monitors: make(map[string]*ModuleMonitor),
|
||||
config: config,
|
||||
aggregator: NewDefaultHealthAggregator(),
|
||||
notifier: NewDefaultHealthNotifier(),
|
||||
rules: make([]HealthRule, 0),
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
metrics: HealthMetrics{
|
||||
ModuleHealthScores: make(map[string]float64),
|
||||
},
|
||||
}
|
||||
|
||||
// Set default configuration
|
||||
if hm.config.CheckInterval == 0 {
|
||||
hm.config.CheckInterval = 30 * time.Second
|
||||
}
|
||||
if hm.config.CheckTimeout == 0 {
|
||||
hm.config.CheckTimeout = 10 * time.Second
|
||||
}
|
||||
if hm.config.HistorySize == 0 {
|
||||
hm.config.HistorySize = 100
|
||||
}
|
||||
if hm.config.FailureThreshold == 0 {
|
||||
hm.config.FailureThreshold = 3
|
||||
}
|
||||
if hm.config.RecoveryThreshold == 0 {
|
||||
hm.config.RecoveryThreshold = 3
|
||||
}
|
||||
if hm.config.MaxConcurrentChecks == 0 {
|
||||
hm.config.MaxConcurrentChecks = 10
|
||||
}
|
||||
|
||||
// Setup default health rules
|
||||
hm.setupDefaultRules()
|
||||
|
||||
return hm
|
||||
}
|
||||
|
||||
// Start starts the health monitoring system
|
||||
func (hm *HealthMonitorImpl) Start() error {
|
||||
hm.mu.Lock()
|
||||
defer hm.mu.Unlock()
|
||||
|
||||
if hm.running {
|
||||
return fmt.Errorf("health monitor already running")
|
||||
}
|
||||
|
||||
hm.running = true
|
||||
|
||||
// Start monitoring loop
|
||||
go hm.monitoringLoop()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Stop stops the health monitoring system
|
||||
func (hm *HealthMonitorImpl) Stop() error {
|
||||
hm.mu.Lock()
|
||||
defer hm.mu.Unlock()
|
||||
|
||||
if !hm.running {
|
||||
return nil
|
||||
}
|
||||
|
||||
hm.cancel()
|
||||
hm.running = false
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// StartMonitoring starts monitoring a specific module
|
||||
func (hm *HealthMonitorImpl) StartMonitoring(module *RegisteredModule) error {
|
||||
hm.mu.Lock()
|
||||
defer hm.mu.Unlock()
|
||||
|
||||
moduleID := module.ID
|
||||
|
||||
// Create module monitor
|
||||
monitor := &ModuleMonitor{
|
||||
moduleID: moduleID,
|
||||
module: module,
|
||||
config: ModuleHealthConfig{
|
||||
CheckInterval: hm.config.CheckInterval,
|
||||
CheckTimeout: hm.config.CheckTimeout,
|
||||
Enabled: true,
|
||||
CriticalModule: module.Config.CriticalModule,
|
||||
FailureThreshold: hm.config.FailureThreshold,
|
||||
RecoveryThreshold: hm.config.RecoveryThreshold,
|
||||
AutoRestart: module.Config.MaxRestarts > 0,
|
||||
MaxRestarts: module.Config.MaxRestarts,
|
||||
RestartDelay: module.Config.RestartDelay,
|
||||
},
|
||||
history: make([]HealthCheckResult, 0),
|
||||
currentHealth: ModuleHealth{
|
||||
Status: HealthUnknown,
|
||||
LastCheck: time.Now(),
|
||||
},
|
||||
}
|
||||
|
||||
hm.monitors[moduleID] = monitor
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// StopMonitoring stops monitoring a specific module
|
||||
func (hm *HealthMonitorImpl) StopMonitoring(moduleID string) error {
|
||||
hm.mu.Lock()
|
||||
defer hm.mu.Unlock()
|
||||
|
||||
delete(hm.monitors, moduleID)
|
||||
delete(hm.metrics.ModuleHealthScores, moduleID)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// CheckHealth performs a health check on a specific module
|
||||
func (hm *HealthMonitorImpl) CheckHealth(module *RegisteredModule) ModuleHealth {
|
||||
moduleID := module.ID
|
||||
|
||||
hm.mu.RLock()
|
||||
monitor, exists := hm.monitors[moduleID]
|
||||
hm.mu.RUnlock()
|
||||
|
||||
if !exists {
|
||||
return ModuleHealth{
|
||||
Status: HealthUnknown,
|
||||
Message: "Module not monitored",
|
||||
}
|
||||
}
|
||||
|
||||
return hm.performHealthCheck(monitor)
|
||||
}
|
||||
|
||||
// GetHealthStatus returns the health status of all monitored modules
|
||||
func (hm *HealthMonitorImpl) GetHealthStatus() map[string]ModuleHealth {
|
||||
hm.mu.RLock()
|
||||
defer hm.mu.RUnlock()
|
||||
|
||||
status := make(map[string]ModuleHealth)
|
||||
for moduleID, monitor := range hm.monitors {
|
||||
monitor.mu.RLock()
|
||||
status[moduleID] = monitor.currentHealth
|
||||
monitor.mu.RUnlock()
|
||||
}
|
||||
|
||||
return status
|
||||
}
|
||||
|
||||
// GetOverallHealth returns the overall system health
|
||||
func (hm *HealthMonitorImpl) GetOverallHealth() OverallHealth {
|
||||
hm.mu.RLock()
|
||||
defer hm.mu.RUnlock()
|
||||
|
||||
moduleHealths := make(map[string]ModuleHealth)
|
||||
for moduleID, monitor := range hm.monitors {
|
||||
monitor.mu.RLock()
|
||||
moduleHealths[moduleID] = monitor.currentHealth
|
||||
monitor.mu.RUnlock()
|
||||
}
|
||||
|
||||
return hm.aggregator.AggregateHealth(moduleHealths)
|
||||
}
|
||||
|
||||
// AddHealthRule adds a custom health rule
|
||||
func (hm *HealthMonitorImpl) AddHealthRule(rule HealthRule) {
|
||||
hm.mu.Lock()
|
||||
defer hm.mu.Unlock()
|
||||
hm.rules = append(hm.rules, rule)
|
||||
}
|
||||
|
||||
// SetHealthAggregator sets a custom health aggregator
|
||||
func (hm *HealthMonitorImpl) SetHealthAggregator(aggregator HealthAggregator) {
|
||||
hm.mu.Lock()
|
||||
defer hm.mu.Unlock()
|
||||
hm.aggregator = aggregator
|
||||
}
|
||||
|
||||
// SetHealthNotifier sets a custom health notifier
|
||||
func (hm *HealthMonitorImpl) SetHealthNotifier(notifier HealthNotifier) {
|
||||
hm.mu.Lock()
|
||||
defer hm.mu.Unlock()
|
||||
hm.notifier = notifier
|
||||
}
|
||||
|
||||
// GetMetrics returns health monitoring metrics
|
||||
func (hm *HealthMonitorImpl) GetMetrics() HealthMetrics {
|
||||
hm.mu.RLock()
|
||||
defer hm.mu.RUnlock()
|
||||
return hm.metrics
|
||||
}
|
||||
|
||||
// Private methods
|
||||
|
||||
func (hm *HealthMonitorImpl) monitoringLoop() {
|
||||
ticker := time.NewTicker(hm.config.CheckInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-hm.ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
hm.performAllHealthChecks()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (hm *HealthMonitorImpl) performAllHealthChecks() {
|
||||
hm.mu.RLock()
|
||||
monitors := make([]*ModuleMonitor, 0, len(hm.monitors))
|
||||
for _, monitor := range hm.monitors {
|
||||
monitors = append(monitors, monitor)
|
||||
}
|
||||
hm.mu.RUnlock()
|
||||
|
||||
if hm.config.ParallelChecks {
|
||||
hm.performHealthChecksParallel(monitors)
|
||||
} else {
|
||||
hm.performHealthChecksSequential(monitors)
|
||||
}
|
||||
|
||||
// Update overall health and send notifications
|
||||
overallHealth := hm.GetOverallHealth()
|
||||
if hm.config.EnableNotifications {
|
||||
hm.notifier.NotifySystemHealth(overallHealth)
|
||||
}
|
||||
}
|
||||
|
||||
func (hm *HealthMonitorImpl) performHealthChecksSequential(monitors []*ModuleMonitor) {
|
||||
for _, monitor := range monitors {
|
||||
if monitor.config.Enabled {
|
||||
hm.performHealthCheck(monitor)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (hm *HealthMonitorImpl) performHealthChecksParallel(monitors []*ModuleMonitor) {
|
||||
semaphore := make(chan struct{}, hm.config.MaxConcurrentChecks)
|
||||
var wg sync.WaitGroup
|
||||
|
||||
for _, monitor := range monitors {
|
||||
if monitor.config.Enabled {
|
||||
wg.Add(1)
|
||||
go func(m *ModuleMonitor) {
|
||||
defer wg.Done()
|
||||
semaphore <- struct{}{}
|
||||
defer func() { <-semaphore }()
|
||||
|
||||
hm.performHealthCheck(m)
|
||||
}(monitor)
|
||||
}
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func (hm *HealthMonitorImpl) performHealthCheck(monitor *ModuleMonitor) ModuleHealth {
|
||||
start := time.Now()
|
||||
|
||||
monitor.mu.Lock()
|
||||
defer monitor.mu.Unlock()
|
||||
|
||||
monitor.checkCount++
|
||||
monitor.lastCheck = start
|
||||
|
||||
// Create check context with timeout
|
||||
ctx, cancel := context.WithTimeout(hm.ctx, monitor.config.CheckTimeout)
|
||||
defer cancel()
|
||||
|
||||
// Perform basic module health check
|
||||
moduleHealth := monitor.module.Instance.GetHealth()
|
||||
|
||||
// Perform custom health checks
|
||||
checkResults := make(map[string]CheckResult)
|
||||
for _, check := range monitor.config.CustomChecks {
|
||||
if check.Enabled {
|
||||
checkResult := hm.performCustomCheck(ctx, check)
|
||||
checkResults[check.Name] = checkResult
|
||||
|
||||
// Update overall status based on check results
|
||||
if check.Critical && checkResult.Status != HealthHealthy {
|
||||
moduleHealth.Status = HealthUnhealthy
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Create health check result
|
||||
result := HealthCheckResult{
|
||||
Timestamp: start,
|
||||
Status: moduleHealth.Status,
|
||||
ResponseTime: time.Since(start),
|
||||
Message: moduleHealth.Message,
|
||||
Details: moduleHealth.Details,
|
||||
Checks: checkResults,
|
||||
}
|
||||
|
||||
// Update statistics
|
||||
if result.Status == HealthHealthy {
|
||||
monitor.successCount++
|
||||
} else {
|
||||
monitor.failureCount++
|
||||
}
|
||||
|
||||
// Add to history
|
||||
monitor.history = append(monitor.history, result)
|
||||
if len(monitor.history) > hm.config.HistorySize {
|
||||
monitor.history = monitor.history[1:]
|
||||
}
|
||||
|
||||
// Update current health
|
||||
oldHealth := monitor.currentHealth
|
||||
monitor.currentHealth = moduleHealth
|
||||
monitor.currentHealth.LastCheck = start
|
||||
monitor.currentHealth.RestartCount = int(monitor.module.HealthStatus.RestartCount)
|
||||
|
||||
// Calculate uptime
|
||||
if !monitor.module.StartTime.IsZero() {
|
||||
monitor.currentHealth.Uptime = time.Since(monitor.module.StartTime)
|
||||
}
|
||||
|
||||
// Update trends if enabled
|
||||
if hm.config.EnableTrends {
|
||||
monitor.trend = hm.calculateHealthTrend(monitor)
|
||||
}
|
||||
|
||||
// Apply health rules
|
||||
hm.applyHealthRules(monitor.moduleID, monitor.currentHealth)
|
||||
|
||||
// Send notifications if health changed
|
||||
if hm.config.EnableNotifications && oldHealth.Status != monitor.currentHealth.Status {
|
||||
hm.notifier.NotifyHealthChange(monitor.moduleID, oldHealth, monitor.currentHealth)
|
||||
}
|
||||
|
||||
// Update metrics
|
||||
if hm.config.EnableMetrics {
|
||||
hm.updateMetrics(monitor, result)
|
||||
}
|
||||
|
||||
return monitor.currentHealth
|
||||
}
|
||||
|
||||
func (hm *HealthMonitorImpl) performCustomCheck(ctx context.Context, check HealthCheck) CheckResult {
|
||||
start := time.Now()
|
||||
|
||||
result := CheckResult{
|
||||
Name: check.Name,
|
||||
Status: HealthHealthy,
|
||||
ResponseTime: 0,
|
||||
Message: "Check passed",
|
||||
Details: make(map[string]interface{}),
|
||||
}
|
||||
|
||||
// Create timeout context for the check
|
||||
checkCtx, cancel := context.WithTimeout(ctx, check.Timeout)
|
||||
defer cancel()
|
||||
|
||||
// Run the check
|
||||
done := make(chan error, 1)
|
||||
go func() {
|
||||
done <- check.CheckFunc()
|
||||
}()
|
||||
|
||||
select {
|
||||
case err := <-done:
|
||||
result.ResponseTime = time.Since(start)
|
||||
if err != nil {
|
||||
result.Status = HealthUnhealthy
|
||||
result.Message = err.Error()
|
||||
result.Error = err
|
||||
}
|
||||
case <-checkCtx.Done():
|
||||
result.ResponseTime = time.Since(start)
|
||||
result.Status = HealthUnhealthy
|
||||
result.Message = "Check timed out"
|
||||
result.Error = checkCtx.Err()
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
func (hm *HealthMonitorImpl) calculateHealthTrend(monitor *ModuleMonitor) HealthTrend {
|
||||
if len(monitor.history) < 5 {
|
||||
return HealthTrend{
|
||||
Direction: TrendUnknown,
|
||||
Confidence: 0,
|
||||
LastUpdated: time.Now(),
|
||||
}
|
||||
}
|
||||
|
||||
// Simple trend calculation based on recent health status
|
||||
recent := monitor.history[len(monitor.history)-5:]
|
||||
healthyCount := 0
|
||||
|
||||
for _, result := range recent {
|
||||
if result.Status == HealthHealthy {
|
||||
healthyCount++
|
||||
}
|
||||
}
|
||||
|
||||
healthRatio := float64(healthyCount) / float64(len(recent))
|
||||
|
||||
var direction TrendDirection
|
||||
var confidence float64
|
||||
|
||||
if healthRatio > 0.8 {
|
||||
direction = TrendImproving
|
||||
confidence = healthRatio
|
||||
} else if healthRatio < 0.4 {
|
||||
direction = TrendDegrading
|
||||
confidence = 1.0 - healthRatio
|
||||
} else {
|
||||
direction = TrendStable
|
||||
confidence = 0.5
|
||||
}
|
||||
|
||||
return HealthTrend{
|
||||
Direction: direction,
|
||||
Confidence: confidence,
|
||||
Slope: healthRatio - 0.5, // Simplified slope calculation
|
||||
Prediction: hm.predictHealthStatus(healthRatio),
|
||||
LastUpdated: time.Now(),
|
||||
}
|
||||
}
|
||||
|
||||
func (hm *HealthMonitorImpl) predictHealthStatus(healthRatio float64) HealthStatus {
|
||||
if healthRatio > 0.7 {
|
||||
return HealthHealthy
|
||||
} else if healthRatio > 0.3 {
|
||||
return HealthDegraded
|
||||
} else {
|
||||
return HealthUnhealthy
|
||||
}
|
||||
}
|
||||
|
||||
func (hm *HealthMonitorImpl) applyHealthRules(moduleID string, health ModuleHealth) {
|
||||
for _, rule := range hm.rules {
|
||||
if rule.Enabled && rule.Condition(health) {
|
||||
if err := rule.Action(moduleID, health); err != nil {
|
||||
// Log error but continue with other rules
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (hm *HealthMonitorImpl) updateMetrics(monitor *ModuleMonitor, result HealthCheckResult) {
|
||||
hm.metrics.ChecksPerformed++
|
||||
|
||||
if result.Status == HealthHealthy {
|
||||
hm.metrics.ChecksSuccessful++
|
||||
} else {
|
||||
hm.metrics.ChecksFailed++
|
||||
}
|
||||
|
||||
// Update average check time
|
||||
if hm.metrics.ChecksPerformed > 0 {
|
||||
totalTime := hm.metrics.AverageCheckTime * time.Duration(hm.metrics.ChecksPerformed-1)
|
||||
hm.metrics.AverageCheckTime = (totalTime + result.ResponseTime) / time.Duration(hm.metrics.ChecksPerformed)
|
||||
}
|
||||
|
||||
// Update health score
|
||||
score := hm.aggregator.GetHealthScore(monitor.currentHealth)
|
||||
hm.metrics.ModuleHealthScores[monitor.moduleID] = score
|
||||
}
|
||||
|
||||
func (hm *HealthMonitorImpl) setupDefaultRules() {
|
||||
// Rule: Alert on unhealthy critical modules
|
||||
hm.rules = append(hm.rules, HealthRule{
|
||||
Name: "critical_module_unhealthy",
|
||||
Description: "Alert when a critical module becomes unhealthy",
|
||||
Condition: func(health ModuleHealth) bool {
|
||||
return health.Status == HealthUnhealthy
|
||||
},
|
||||
Action: func(moduleID string, health ModuleHealth) error {
|
||||
alert := HealthAlert{
|
||||
ID: fmt.Sprintf("critical_%s_%d", moduleID, time.Now().Unix()),
|
||||
ModuleID: moduleID,
|
||||
Severity: SeverityCritical,
|
||||
Type: AlertHealthChange,
|
||||
Message: fmt.Sprintf("Critical module %s is unhealthy: %s", moduleID, health.Message),
|
||||
Timestamp: time.Now(),
|
||||
}
|
||||
return hm.notifier.NotifyAlert(alert)
|
||||
},
|
||||
Severity: SeverityCritical,
|
||||
Enabled: true,
|
||||
})
|
||||
|
||||
// Rule: Alert on degraded performance
|
||||
hm.rules = append(hm.rules, HealthRule{
|
||||
Name: "degraded_performance",
|
||||
Description: "Alert when module performance is degraded",
|
||||
Condition: func(health ModuleHealth) bool {
|
||||
return health.Status == HealthDegraded
|
||||
},
|
||||
Action: func(moduleID string, health ModuleHealth) error {
|
||||
alert := HealthAlert{
|
||||
ID: fmt.Sprintf("degraded_%s_%d", moduleID, time.Now().Unix()),
|
||||
ModuleID: moduleID,
|
||||
Severity: SeverityWarning,
|
||||
Type: AlertHealthChange,
|
||||
Message: fmt.Sprintf("Module %s performance is degraded: %s", moduleID, health.Message),
|
||||
Timestamp: time.Now(),
|
||||
}
|
||||
return hm.notifier.NotifyAlert(alert)
|
||||
},
|
||||
Severity: SeverityWarning,
|
||||
Enabled: true,
|
||||
})
|
||||
}
|
||||
|
||||
// DefaultHealthAggregator is the stateless, built-in health aggregation
// strategy.
type DefaultHealthAggregator struct{}

// NewDefaultHealthAggregator returns a ready-to-use DefaultHealthAggregator.
func NewDefaultHealthAggregator() *DefaultHealthAggregator {
	return new(DefaultHealthAggregator)
}
|
||||
|
||||
func (dha *DefaultHealthAggregator) AggregateHealth(modules map[string]ModuleHealth) OverallHealth {
|
||||
overall := OverallHealth{
|
||||
Modules: modules,
|
||||
LastUpdated: time.Now(),
|
||||
Trends: make(map[string]HealthTrend),
|
||||
}
|
||||
|
||||
if len(modules) == 0 {
|
||||
overall.Status = HealthUnknown
|
||||
return overall
|
||||
}
|
||||
|
||||
overall.ModuleCount = len(modules)
|
||||
var totalScore float64
|
||||
|
||||
for moduleID, health := range modules {
|
||||
score := dha.GetHealthScore(health)
|
||||
totalScore += score
|
||||
|
||||
switch health.Status {
|
||||
case HealthHealthy:
|
||||
overall.HealthyCount++
|
||||
case HealthDegraded:
|
||||
overall.DegradedCount++
|
||||
case HealthUnhealthy:
|
||||
overall.UnhealthyCount++
|
||||
overall.CriticalIssues = append(overall.CriticalIssues,
|
||||
fmt.Sprintf("Module %s is unhealthy: %s", moduleID, health.Message))
|
||||
}
|
||||
}
|
||||
|
||||
overall.Score = totalScore / float64(len(modules))
|
||||
overall.Status = dha.CalculateSystemHealth(getHealthValues(modules))
|
||||
|
||||
return overall
|
||||
}
|
||||
|
||||
func (dha *DefaultHealthAggregator) CalculateSystemHealth(individual []ModuleHealth) HealthStatus {
|
||||
if len(individual) == 0 {
|
||||
return HealthUnknown
|
||||
}
|
||||
|
||||
healthyCount := 0
|
||||
degradedCount := 0
|
||||
unhealthyCount := 0
|
||||
|
||||
for _, health := range individual {
|
||||
switch health.Status {
|
||||
case HealthHealthy:
|
||||
healthyCount++
|
||||
case HealthDegraded:
|
||||
degradedCount++
|
||||
case HealthUnhealthy:
|
||||
unhealthyCount++
|
||||
}
|
||||
}
|
||||
|
||||
total := len(individual)
|
||||
healthyRatio := float64(healthyCount) / float64(total)
|
||||
unhealthyRatio := float64(unhealthyCount) / float64(total)
|
||||
|
||||
if unhealthyRatio > 0.3 {
|
||||
return HealthUnhealthy
|
||||
} else if healthyRatio < 0.7 {
|
||||
return HealthDegraded
|
||||
} else {
|
||||
return HealthHealthy
|
||||
}
|
||||
}
|
||||
|
||||
func (dha *DefaultHealthAggregator) GetHealthScore(health ModuleHealth) float64 {
|
||||
switch health.Status {
|
||||
case HealthHealthy:
|
||||
return 1.0
|
||||
case HealthDegraded:
|
||||
return 0.5
|
||||
case HealthUnhealthy:
|
||||
return 0.0
|
||||
default:
|
||||
return 0.0
|
||||
}
|
||||
}
|
||||
|
||||
func getHealthValues(modules map[string]ModuleHealth) []ModuleHealth {
|
||||
values := make([]ModuleHealth, 0, len(modules))
|
||||
for _, health := range modules {
|
||||
values = append(values, health)
|
||||
}
|
||||
return values
|
||||
}
|
||||
|
||||
// DefaultHealthNotifier is the built-in notifier. All of its notification
// methods are no-ops that report success.
type DefaultHealthNotifier struct{}

// NewDefaultHealthNotifier returns a ready-to-use DefaultHealthNotifier.
func NewDefaultHealthNotifier() *DefaultHealthNotifier {
	return new(DefaultHealthNotifier)
}
|
||||
|
||||
func (dhn *DefaultHealthNotifier) NotifyHealthChange(moduleID string, oldHealth, newHealth ModuleHealth) error {
|
||||
// Basic notification implementation - could be extended to send emails, webhooks, etc.
|
||||
return nil
|
||||
}
|
||||
|
||||
func (dhn *DefaultHealthNotifier) NotifySystemHealth(health OverallHealth) error {
|
||||
// Basic notification implementation
|
||||
return nil
|
||||
}
|
||||
|
||||
func (dhn *DefaultHealthNotifier) NotifyAlert(alert HealthAlert) error {
|
||||
// Basic notification implementation
|
||||
return nil
|
||||
}
|
||||
838
pkg/lifecycle/module_registry.go
Normal file
838
pkg/lifecycle/module_registry.go
Normal file
@@ -0,0 +1,838 @@
|
||||
package lifecycle
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// ModuleRegistry manages the registration, discovery, and lifecycle of system modules
|
||||
type ModuleRegistry struct {
|
||||
modules map[string]*RegisteredModule
|
||||
modulesByType map[reflect.Type][]*RegisteredModule
|
||||
dependencies map[string][]string
|
||||
startOrder []string
|
||||
stopOrder []string
|
||||
state RegistryState
|
||||
eventBus EventBus
|
||||
healthMonitor HealthMonitor
|
||||
config RegistryConfig
|
||||
mu sync.RWMutex
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
}
|
||||
|
||||
// RegisteredModule represents a module in the registry
|
||||
type RegisteredModule struct {
|
||||
ID string
|
||||
Name string
|
||||
Type reflect.Type
|
||||
Instance Module
|
||||
Config ModuleConfig
|
||||
Dependencies []string
|
||||
State ModuleState
|
||||
Metadata map[string]interface{}
|
||||
StartTime time.Time
|
||||
StopTime time.Time
|
||||
HealthStatus ModuleHealth
|
||||
Metrics ModuleMetrics
|
||||
Created time.Time
|
||||
Version string
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
// Module interface that all modules must implement
|
||||
type Module interface {
|
||||
// Core lifecycle methods
|
||||
Initialize(ctx context.Context, config ModuleConfig) error
|
||||
Start(ctx context.Context) error
|
||||
Stop(ctx context.Context) error
|
||||
Pause(ctx context.Context) error
|
||||
Resume(ctx context.Context) error
|
||||
|
||||
// Module information
|
||||
GetID() string
|
||||
GetName() string
|
||||
GetVersion() string
|
||||
GetDependencies() []string
|
||||
|
||||
// Health and status
|
||||
GetHealth() ModuleHealth
|
||||
GetState() ModuleState
|
||||
GetMetrics() ModuleMetrics
|
||||
}
|
||||
|
||||
// ModuleState names a stage in an individual module's lifecycle.
type ModuleState string

// Module lifecycle states.
const (
	StateUninitialized ModuleState = "uninitialized"
	StateInitialized   ModuleState = "initialized"
	StateStarting      ModuleState = "starting"
	StateRunning       ModuleState = "running"
	StatePausing       ModuleState = "pausing"
	StatePaused        ModuleState = "paused"
	StateResuming      ModuleState = "resuming"
	StateStopping      ModuleState = "stopping"
	StateStopped       ModuleState = "stopped"
	StateFailed        ModuleState = "failed"
)
|
||||
|
||||
// RegistryState names a stage in the lifecycle of the registry as a whole.
type RegistryState string

// Registry lifecycle states.
const (
	RegistryUninitialized RegistryState = "uninitialized"
	RegistryInitialized   RegistryState = "initialized"
	RegistryStarting      RegistryState = "starting"
	RegistryRunning       RegistryState = "running"
	RegistryStopping      RegistryState = "stopping"
	RegistryStopped       RegistryState = "stopped"
	RegistryFailed        RegistryState = "failed"
)
|
||||
|
||||
// ModuleConfig carries the per-module settings supplied at registration.
type ModuleConfig struct {
	// Settings holds free-form, module-specific configuration values.
	Settings map[string]interface{} `json:"settings"`
	Enabled  bool                   `json:"enabled"`

	StartTimeout        time.Duration `json:"start_timeout"`
	StopTimeout         time.Duration `json:"stop_timeout"`
	HealthCheckInterval time.Duration `json:"health_check_interval"`

	// Restart policy. The health monitor enables auto-restart for a module
	// when MaxRestarts is greater than zero.
	MaxRestarts  int           `json:"max_restarts"`
	RestartDelay time.Duration `json:"restart_delay"`

	CriticalModule bool `json:"critical_module"`
}
|
||||
|
||||
// ModuleHealth represents the health status of a module
|
||||
type ModuleHealth struct {
|
||||
Status HealthStatus `json:"status"`
|
||||
LastCheck time.Time `json:"last_check"`
|
||||
Message string `json:"message"`
|
||||
Details map[string]interface{} `json:"details"`
|
||||
Uptime time.Duration `json:"uptime"`
|
||||
RestartCount int `json:"restart_count"`
|
||||
}
|
||||
|
||||
// HealthStatus is the outcome category of a health evaluation.
type HealthStatus string

// Health outcomes. HealthUnknown is used when no evaluation is available,
// e.g. for a freshly registered module.
const (
	HealthHealthy   HealthStatus = "healthy"
	HealthDegraded  HealthStatus = "degraded"
	HealthUnhealthy HealthStatus = "unhealthy"
	HealthUnknown   HealthStatus = "unknown"
)
|
||||
|
||||
// ModuleMetrics holds performance figures reported for a module.
type ModuleMetrics struct {
	// Lifecycle timings.
	StartupTime  time.Duration `json:"startup_time"`
	ShutdownTime time.Duration `json:"shutdown_time"`

	// Resource usage. NOTE(review): units are not defined in this file —
	// presumably bytes and a CPU fraction or percentage; confirm.
	MemoryUsage int64   `json:"memory_usage"`
	CPUUsage    float64 `json:"cpu_usage"`

	// Activity counters.
	RequestCount int64     `json:"request_count"`
	ErrorCount   int64     `json:"error_count"`
	LastActivity time.Time `json:"last_activity"`

	// CustomMetrics holds free-form, module-specific measurements.
	CustomMetrics map[string]interface{} `json:"custom_metrics"`
}
|
||||
|
||||
// RegistryConfig holds the registry-wide settings. NewModuleRegistry fills
// in defaults for the three durations when they are zero.
type RegistryConfig struct {
	StartTimeout        time.Duration `json:"start_timeout"`
	StopTimeout         time.Duration `json:"stop_timeout"`
	HealthCheckInterval time.Duration `json:"health_check_interval"`

	// Feature toggles.
	EnableMetrics       bool `json:"enable_metrics"`
	EnableHealthMonitor bool `json:"enable_health_monitor"`

	// ParallelStartup and ParallelShutdown choose between the parallel and
	// the sequential implementation of StartAll and StopAll.
	ParallelStartup  bool `json:"parallel_startup"`
	ParallelShutdown bool `json:"parallel_shutdown"`

	// Failure handling.
	FailureRecovery    bool `json:"failure_recovery"`
	AutoRestart        bool `json:"auto_restart"`
	MaxRestartAttempts int  `json:"max_restart_attempts"`
}
|
||||
|
||||
// EventBus interface for module events
|
||||
type EventBus interface {
|
||||
Publish(event ModuleEvent) error
|
||||
Subscribe(eventType EventType, handler EventHandler) error
|
||||
}
|
||||
|
||||
// HealthMonitor interface for health monitoring
|
||||
type HealthMonitor interface {
|
||||
CheckHealth(module *RegisteredModule) ModuleHealth
|
||||
StartMonitoring(module *RegisteredModule) error
|
||||
StopMonitoring(moduleID string) error
|
||||
GetHealthStatus() map[string]ModuleHealth
|
||||
}
|
||||
|
||||
// ModuleEvent represents an event in the module lifecycle
|
||||
type ModuleEvent struct {
|
||||
Type EventType `json:"type"`
|
||||
ModuleID string `json:"module_id"`
|
||||
Timestamp time.Time `json:"timestamp"`
|
||||
Data map[string]interface{} `json:"data"`
|
||||
Error error `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
// EventType identifies the kind of a ModuleEvent.
type EventType string

// Module lifecycle event kinds.
const (
	EventModuleRegistered   EventType = "module_registered"
	EventModuleUnregistered EventType = "module_unregistered"
	EventModuleInitialized  EventType = "module_initialized"
	EventModuleStarted      EventType = "module_started"
	EventModuleStopped      EventType = "module_stopped"
	EventModulePaused       EventType = "module_paused"
	EventModuleResumed      EventType = "module_resumed"
	EventModuleFailed       EventType = "module_failed"
	EventModuleRestarted    EventType = "module_restarted"
	EventHealthCheck        EventType = "health_check"
)
|
||||
|
||||
// EventHandler handles module events
|
||||
type EventHandler func(event ModuleEvent) error
|
||||
|
||||
// NewModuleRegistry creates a new module registry
|
||||
func NewModuleRegistry(config RegistryConfig) *ModuleRegistry {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
registry := &ModuleRegistry{
|
||||
modules: make(map[string]*RegisteredModule),
|
||||
modulesByType: make(map[reflect.Type][]*RegisteredModule),
|
||||
dependencies: make(map[string][]string),
|
||||
config: config,
|
||||
state: RegistryUninitialized,
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
}
|
||||
|
||||
// Set default configuration
|
||||
if registry.config.StartTimeout == 0 {
|
||||
registry.config.StartTimeout = 30 * time.Second
|
||||
}
|
||||
if registry.config.StopTimeout == 0 {
|
||||
registry.config.StopTimeout = 15 * time.Second
|
||||
}
|
||||
if registry.config.HealthCheckInterval == 0 {
|
||||
registry.config.HealthCheckInterval = 30 * time.Second
|
||||
}
|
||||
|
||||
return registry
|
||||
}
|
||||
|
||||
// Register registers a new module with the registry
|
||||
func (mr *ModuleRegistry) Register(module Module, config ModuleConfig) error {
|
||||
mr.mu.Lock()
|
||||
defer mr.mu.Unlock()
|
||||
|
||||
id := module.GetID()
|
||||
if _, exists := mr.modules[id]; exists {
|
||||
return fmt.Errorf("module already registered: %s", id)
|
||||
}
|
||||
|
||||
moduleType := reflect.TypeOf(module)
|
||||
registered := &RegisteredModule{
|
||||
ID: id,
|
||||
Name: module.GetName(),
|
||||
Type: moduleType,
|
||||
Instance: module,
|
||||
Config: config,
|
||||
Dependencies: module.GetDependencies(),
|
||||
State: StateUninitialized,
|
||||
Metadata: make(map[string]interface{}),
|
||||
HealthStatus: ModuleHealth{Status: HealthUnknown},
|
||||
Created: time.Now(),
|
||||
Version: module.GetVersion(),
|
||||
}
|
||||
|
||||
mr.modules[id] = registered
|
||||
mr.modulesByType[moduleType] = append(mr.modulesByType[moduleType], registered)
|
||||
mr.dependencies[id] = module.GetDependencies()
|
||||
|
||||
// Publish event
|
||||
if mr.eventBus != nil {
|
||||
mr.eventBus.Publish(ModuleEvent{
|
||||
Type: EventModuleRegistered,
|
||||
ModuleID: id,
|
||||
Timestamp: time.Now(),
|
||||
Data: map[string]interface{}{
|
||||
"name": module.GetName(),
|
||||
"version": module.GetVersion(),
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Unregister removes a module from the registry
|
||||
func (mr *ModuleRegistry) Unregister(moduleID string) error {
|
||||
mr.mu.Lock()
|
||||
defer mr.mu.Unlock()
|
||||
|
||||
registered, exists := mr.modules[moduleID]
|
||||
if !exists {
|
||||
return fmt.Errorf("module not found: %s", moduleID)
|
||||
}
|
||||
|
||||
// Stop module if running
|
||||
if registered.State == StateRunning {
|
||||
if err := mr.stopModule(registered); err != nil {
|
||||
return fmt.Errorf("failed to stop module before unregistering: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Remove from type index
|
||||
moduleType := registered.Type
|
||||
typeModules := mr.modulesByType[moduleType]
|
||||
for i, mod := range typeModules {
|
||||
if mod.ID == moduleID {
|
||||
mr.modulesByType[moduleType] = append(typeModules[:i], typeModules[i+1:]...)
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// Remove from maps
|
||||
delete(mr.modules, moduleID)
|
||||
delete(mr.dependencies, moduleID)
|
||||
|
||||
// Publish event
|
||||
if mr.eventBus != nil {
|
||||
mr.eventBus.Publish(ModuleEvent{
|
||||
Type: EventModuleUnregistered,
|
||||
ModuleID: moduleID,
|
||||
Timestamp: time.Now(),
|
||||
})
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Get retrieves a module by ID
|
||||
func (mr *ModuleRegistry) Get(moduleID string) (Module, error) {
|
||||
mr.mu.RLock()
|
||||
defer mr.mu.RUnlock()
|
||||
|
||||
registered, exists := mr.modules[moduleID]
|
||||
if !exists {
|
||||
return nil, fmt.Errorf("module not found: %s", moduleID)
|
||||
}
|
||||
|
||||
return registered.Instance, nil
|
||||
}
|
||||
|
||||
// GetByType retrieves all modules of a specific type
|
||||
func (mr *ModuleRegistry) GetByType(moduleType reflect.Type) []Module {
|
||||
mr.mu.RLock()
|
||||
defer mr.mu.RUnlock()
|
||||
|
||||
registeredModules := mr.modulesByType[moduleType]
|
||||
modules := make([]Module, len(registeredModules))
|
||||
for i, registered := range registeredModules {
|
||||
modules[i] = registered.Instance
|
||||
}
|
||||
|
||||
return modules
|
||||
}
|
||||
|
||||
// List returns all registered module IDs
|
||||
func (mr *ModuleRegistry) List() []string {
|
||||
mr.mu.RLock()
|
||||
defer mr.mu.RUnlock()
|
||||
|
||||
ids := make([]string, 0, len(mr.modules))
|
||||
for id := range mr.modules {
|
||||
ids = append(ids, id)
|
||||
}
|
||||
|
||||
return ids
|
||||
}
|
||||
|
||||
// GetState returns the current state of a module
|
||||
func (mr *ModuleRegistry) GetState(moduleID string) (ModuleState, error) {
|
||||
mr.mu.RLock()
|
||||
defer mr.mu.RUnlock()
|
||||
|
||||
registered, exists := mr.modules[moduleID]
|
||||
if !exists {
|
||||
return "", fmt.Errorf("module not found: %s", moduleID)
|
||||
}
|
||||
|
||||
return registered.State, nil
|
||||
}
|
||||
|
||||
// GetRegistryState returns the current state of the registry
|
||||
func (mr *ModuleRegistry) GetRegistryState() RegistryState {
|
||||
mr.mu.RLock()
|
||||
defer mr.mu.RUnlock()
|
||||
return mr.state
|
||||
}
|
||||
|
||||
// Initialize initializes all registered modules
|
||||
func (mr *ModuleRegistry) Initialize(ctx context.Context) error {
|
||||
mr.mu.Lock()
|
||||
defer mr.mu.Unlock()
|
||||
|
||||
if mr.state != RegistryUninitialized {
|
||||
return fmt.Errorf("registry already initialized")
|
||||
}
|
||||
|
||||
mr.state = RegistryInitialized
|
||||
|
||||
// Calculate start order based on dependencies
|
||||
startOrder, err := mr.calculateStartOrder()
|
||||
if err != nil {
|
||||
mr.state = RegistryFailed
|
||||
return fmt.Errorf("failed to calculate start order: %w", err)
|
||||
}
|
||||
mr.startOrder = startOrder
|
||||
|
||||
// Calculate stop order (reverse of start order)
|
||||
mr.stopOrder = make([]string, len(startOrder))
|
||||
for i, id := range startOrder {
|
||||
mr.stopOrder[len(startOrder)-1-i] = id
|
||||
}
|
||||
|
||||
// Initialize all modules
|
||||
for _, moduleID := range mr.startOrder {
|
||||
registered := mr.modules[moduleID]
|
||||
if err := mr.initializeModule(ctx, registered); err != nil {
|
||||
mr.state = RegistryFailed
|
||||
return fmt.Errorf("failed to initialize module %s: %w", moduleID, err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// StartAll starts all registered modules in dependency order
|
||||
func (mr *ModuleRegistry) StartAll(ctx context.Context) error {
|
||||
mr.mu.Lock()
|
||||
defer mr.mu.Unlock()
|
||||
|
||||
if mr.state != RegistryInitialized && mr.state != RegistryStopped {
|
||||
return fmt.Errorf("invalid registry state for start: %s", mr.state)
|
||||
}
|
||||
|
||||
mr.state = RegistryStarting
|
||||
|
||||
if mr.config.ParallelStartup {
|
||||
return mr.startAllParallel(ctx)
|
||||
} else {
|
||||
return mr.startAllSequential(ctx)
|
||||
}
|
||||
}
|
||||
|
||||
// StopAll stops all modules in reverse dependency order
|
||||
func (mr *ModuleRegistry) StopAll(ctx context.Context) error {
|
||||
mr.mu.Lock()
|
||||
defer mr.mu.Unlock()
|
||||
|
||||
if mr.state != RegistryRunning {
|
||||
return fmt.Errorf("invalid registry state for stop: %s", mr.state)
|
||||
}
|
||||
|
||||
mr.state = RegistryStopping
|
||||
|
||||
if mr.config.ParallelShutdown {
|
||||
return mr.stopAllParallel(ctx)
|
||||
} else {
|
||||
return mr.stopAllSequential(ctx)
|
||||
}
|
||||
}
|
||||
|
||||
// Start starts a specific module
|
||||
func (mr *ModuleRegistry) Start(ctx context.Context, moduleID string) error {
|
||||
mr.mu.Lock()
|
||||
defer mr.mu.Unlock()
|
||||
|
||||
registered, exists := mr.modules[moduleID]
|
||||
if !exists {
|
||||
return fmt.Errorf("module not found: %s", moduleID)
|
||||
}
|
||||
|
||||
return mr.startModule(ctx, registered)
|
||||
}
|
||||
|
||||
// Stop stops a specific module
|
||||
func (mr *ModuleRegistry) Stop(ctx context.Context, moduleID string) error {
|
||||
mr.mu.Lock()
|
||||
defer mr.mu.Unlock()
|
||||
|
||||
registered, exists := mr.modules[moduleID]
|
||||
if !exists {
|
||||
return fmt.Errorf("module not found: %s", moduleID)
|
||||
}
|
||||
|
||||
return mr.stopModule(registered)
|
||||
}
|
||||
|
||||
// Pause pauses a specific module
|
||||
func (mr *ModuleRegistry) Pause(ctx context.Context, moduleID string) error {
|
||||
mr.mu.Lock()
|
||||
defer mr.mu.Unlock()
|
||||
|
||||
registered, exists := mr.modules[moduleID]
|
||||
if !exists {
|
||||
return fmt.Errorf("module not found: %s", moduleID)
|
||||
}
|
||||
|
||||
return mr.pauseModule(ctx, registered)
|
||||
}
|
||||
|
||||
// Resume resumes a paused module
|
||||
func (mr *ModuleRegistry) Resume(ctx context.Context, moduleID string) error {
|
||||
mr.mu.Lock()
|
||||
defer mr.mu.Unlock()
|
||||
|
||||
registered, exists := mr.modules[moduleID]
|
||||
if !exists {
|
||||
return fmt.Errorf("module not found: %s", moduleID)
|
||||
}
|
||||
|
||||
return mr.resumeModule(ctx, registered)
|
||||
}
|
||||
|
||||
// SetEventBus sets the event bus for the registry
|
||||
func (mr *ModuleRegistry) SetEventBus(eventBus EventBus) {
|
||||
mr.mu.Lock()
|
||||
defer mr.mu.Unlock()
|
||||
mr.eventBus = eventBus
|
||||
}
|
||||
|
||||
// SetHealthMonitor sets the health monitor for the registry
|
||||
func (mr *ModuleRegistry) SetHealthMonitor(healthMonitor HealthMonitor) {
|
||||
mr.mu.Lock()
|
||||
defer mr.mu.Unlock()
|
||||
mr.healthMonitor = healthMonitor
|
||||
}
|
||||
|
||||
// GetHealth returns the health status of all modules
|
||||
func (mr *ModuleRegistry) GetHealth() map[string]ModuleHealth {
|
||||
mr.mu.RLock()
|
||||
defer mr.mu.RUnlock()
|
||||
|
||||
health := make(map[string]ModuleHealth)
|
||||
for id, registered := range mr.modules {
|
||||
health[id] = registered.HealthStatus
|
||||
}
|
||||
|
||||
return health
|
||||
}
|
||||
|
||||
// GetMetrics returns metrics for all modules
|
||||
func (mr *ModuleRegistry) GetMetrics() map[string]ModuleMetrics {
|
||||
mr.mu.RLock()
|
||||
defer mr.mu.RUnlock()
|
||||
|
||||
metrics := make(map[string]ModuleMetrics)
|
||||
for id, registered := range mr.modules {
|
||||
metrics[id] = registered.Metrics
|
||||
}
|
||||
|
||||
return metrics
|
||||
}
|
||||
|
||||
// Shutdown gracefully shuts down the registry
|
||||
func (mr *ModuleRegistry) Shutdown(ctx context.Context) error {
|
||||
if mr.state == RegistryRunning {
|
||||
if err := mr.StopAll(ctx); err != nil {
|
||||
return fmt.Errorf("failed to stop all modules: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
mr.cancel()
|
||||
mr.state = RegistryStopped
|
||||
return nil
|
||||
}
|
||||
|
||||
// Private methods
|
||||
|
||||
func (mr *ModuleRegistry) calculateStartOrder() ([]string, error) {
|
||||
// Topological sort based on dependencies
|
||||
visited := make(map[string]bool)
|
||||
temp := make(map[string]bool)
|
||||
var order []string
|
||||
|
||||
var visit func(string) error
|
||||
visit = func(moduleID string) error {
|
||||
if temp[moduleID] {
|
||||
return fmt.Errorf("circular dependency detected involving module: %s", moduleID)
|
||||
}
|
||||
if visited[moduleID] {
|
||||
return nil
|
||||
}
|
||||
|
||||
temp[moduleID] = true
|
||||
for _, depID := range mr.dependencies[moduleID] {
|
||||
if _, exists := mr.modules[depID]; !exists {
|
||||
return fmt.Errorf("dependency not found: %s (required by %s)", depID, moduleID)
|
||||
}
|
||||
if err := visit(depID); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
temp[moduleID] = false
|
||||
visited[moduleID] = true
|
||||
order = append(order, moduleID)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
for moduleID := range mr.modules {
|
||||
if !visited[moduleID] {
|
||||
if err := visit(moduleID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return order, nil
|
||||
}
|
||||
|
||||
func (mr *ModuleRegistry) initializeModule(ctx context.Context, registered *RegisteredModule) error {
|
||||
registered.State = StateInitialized
|
||||
|
||||
if err := registered.Instance.Initialize(ctx, registered.Config); err != nil {
|
||||
registered.State = StateFailed
|
||||
return err
|
||||
}
|
||||
|
||||
// Publish event
|
||||
if mr.eventBus != nil {
|
||||
mr.eventBus.Publish(ModuleEvent{
|
||||
Type: EventModuleInitialized,
|
||||
ModuleID: registered.ID,
|
||||
Timestamp: time.Now(),
|
||||
})
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (mr *ModuleRegistry) startModule(ctx context.Context, registered *RegisteredModule) error {
|
||||
if registered.State != StateInitialized && registered.State != StateStopped {
|
||||
return fmt.Errorf("invalid state for start: %s", registered.State)
|
||||
}
|
||||
|
||||
startTime := time.Now()
|
||||
registered.State = StateStarting
|
||||
registered.StartTime = startTime
|
||||
|
||||
// Create timeout context
|
||||
timeoutCtx, cancel := context.WithTimeout(ctx, registered.Config.StartTimeout)
|
||||
defer cancel()
|
||||
|
||||
if err := registered.Instance.Start(timeoutCtx); err != nil {
|
||||
registered.State = StateFailed
|
||||
return err
|
||||
}
|
||||
|
||||
registered.State = StateRunning
|
||||
registered.Metrics.StartupTime = time.Since(startTime)
|
||||
|
||||
// Start health monitoring
|
||||
if mr.healthMonitor != nil {
|
||||
mr.healthMonitor.StartMonitoring(registered)
|
||||
}
|
||||
|
||||
// Publish event
|
||||
if mr.eventBus != nil {
|
||||
mr.eventBus.Publish(ModuleEvent{
|
||||
Type: EventModuleStarted,
|
||||
ModuleID: registered.ID,
|
||||
Timestamp: time.Now(),
|
||||
Data: map[string]interface{}{
|
||||
"startup_time": registered.Metrics.StartupTime,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (mr *ModuleRegistry) stopModule(registered *RegisteredModule) error {
|
||||
if registered.State != StateRunning && registered.State != StatePaused {
|
||||
return fmt.Errorf("invalid state for stop: %s", registered.State)
|
||||
}
|
||||
|
||||
stopTime := time.Now()
|
||||
registered.State = StateStopping
|
||||
|
||||
// Create timeout context
|
||||
ctx, cancel := context.WithTimeout(mr.ctx, registered.Config.StopTimeout)
|
||||
defer cancel()
|
||||
|
||||
if err := registered.Instance.Stop(ctx); err != nil {
|
||||
registered.State = StateFailed
|
||||
return err
|
||||
}
|
||||
|
||||
registered.State = StateStopped
|
||||
registered.StopTime = stopTime
|
||||
registered.Metrics.ShutdownTime = time.Since(stopTime)
|
||||
|
||||
// Stop health monitoring
|
||||
if mr.healthMonitor != nil {
|
||||
mr.healthMonitor.StopMonitoring(registered.ID)
|
||||
}
|
||||
|
||||
// Publish event
|
||||
if mr.eventBus != nil {
|
||||
mr.eventBus.Publish(ModuleEvent{
|
||||
Type: EventModuleStopped,
|
||||
ModuleID: registered.ID,
|
||||
Timestamp: time.Now(),
|
||||
Data: map[string]interface{}{
|
||||
"shutdown_time": registered.Metrics.ShutdownTime,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (mr *ModuleRegistry) pauseModule(ctx context.Context, registered *RegisteredModule) error {
|
||||
if registered.State != StateRunning {
|
||||
return fmt.Errorf("invalid state for pause: %s", registered.State)
|
||||
}
|
||||
|
||||
registered.State = StatePausing
|
||||
|
||||
if err := registered.Instance.Pause(ctx); err != nil {
|
||||
registered.State = StateFailed
|
||||
return err
|
||||
}
|
||||
|
||||
registered.State = StatePaused
|
||||
|
||||
// Publish event
|
||||
if mr.eventBus != nil {
|
||||
mr.eventBus.Publish(ModuleEvent{
|
||||
Type: EventModulePaused,
|
||||
ModuleID: registered.ID,
|
||||
Timestamp: time.Now(),
|
||||
})
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (mr *ModuleRegistry) resumeModule(ctx context.Context, registered *RegisteredModule) error {
|
||||
if registered.State != StatePaused {
|
||||
return fmt.Errorf("invalid state for resume: %s", registered.State)
|
||||
}
|
||||
|
||||
registered.State = StateResuming
|
||||
|
||||
if err := registered.Instance.Resume(ctx); err != nil {
|
||||
registered.State = StateFailed
|
||||
return err
|
||||
}
|
||||
|
||||
registered.State = StateRunning
|
||||
|
||||
// Publish event
|
||||
if mr.eventBus != nil {
|
||||
mr.eventBus.Publish(ModuleEvent{
|
||||
Type: EventModuleResumed,
|
||||
ModuleID: registered.ID,
|
||||
Timestamp: time.Now(),
|
||||
})
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (mr *ModuleRegistry) startAllSequential(ctx context.Context) error {
|
||||
for _, moduleID := range mr.startOrder {
|
||||
registered := mr.modules[moduleID]
|
||||
if registered.Config.Enabled {
|
||||
if err := mr.startModule(ctx, registered); err != nil {
|
||||
mr.state = RegistryFailed
|
||||
return fmt.Errorf("failed to start module %s: %w", moduleID, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
mr.state = RegistryRunning
|
||||
return nil
|
||||
}
|
||||
|
||||
func (mr *ModuleRegistry) startAllParallel(ctx context.Context) error {
|
||||
// Start modules in parallel, respecting dependencies
|
||||
// This is a simplified implementation - in production you'd want more sophisticated parallel startup
|
||||
var wg sync.WaitGroup
|
||||
errors := make(chan error, len(mr.modules))
|
||||
|
||||
for _, moduleID := range mr.startOrder {
|
||||
registered := mr.modules[moduleID]
|
||||
if registered.Config.Enabled {
|
||||
wg.Add(1)
|
||||
go func(reg *RegisteredModule) {
|
||||
defer wg.Done()
|
||||
if err := mr.startModule(ctx, reg); err != nil {
|
||||
errors <- fmt.Errorf("failed to start module %s: %w", reg.ID, err)
|
||||
}
|
||||
}(registered)
|
||||
}
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
close(errors)
|
||||
|
||||
// Check for errors
|
||||
for err := range errors {
|
||||
mr.state = RegistryFailed
|
||||
return err
|
||||
}
|
||||
|
||||
mr.state = RegistryRunning
|
||||
return nil
|
||||
}
|
||||
|
||||
func (mr *ModuleRegistry) stopAllSequential(ctx context.Context) error {
|
||||
for _, moduleID := range mr.stopOrder {
|
||||
registered := mr.modules[moduleID]
|
||||
if registered.State == StateRunning || registered.State == StatePaused {
|
||||
if err := mr.stopModule(registered); err != nil {
|
||||
mr.state = RegistryFailed
|
||||
return fmt.Errorf("failed to stop module %s: %w", moduleID, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
mr.state = RegistryStopped
|
||||
return nil
|
||||
}
|
||||
|
||||
func (mr *ModuleRegistry) stopAllParallel(ctx context.Context) error {
|
||||
var wg sync.WaitGroup
|
||||
errors := make(chan error, len(mr.modules))
|
||||
|
||||
for _, moduleID := range mr.stopOrder {
|
||||
registered := mr.modules[moduleID]
|
||||
if registered.State == StateRunning || registered.State == StatePaused {
|
||||
wg.Add(1)
|
||||
go func(reg *RegisteredModule) {
|
||||
defer wg.Done()
|
||||
if err := mr.stopModule(reg); err != nil {
|
||||
errors <- fmt.Errorf("failed to stop module %s: %w", reg.ID, err)
|
||||
}
|
||||
}(registered)
|
||||
}
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
close(errors)
|
||||
|
||||
// Check for errors
|
||||
for err := range errors {
|
||||
mr.state = RegistryFailed
|
||||
return err
|
||||
}
|
||||
|
||||
mr.state = RegistryStopped
|
||||
return nil
|
||||
}
|
||||
690
pkg/lifecycle/shutdown_manager.go
Normal file
690
pkg/lifecycle/shutdown_manager.go
Normal file
@@ -0,0 +1,690 @@
|
||||
package lifecycle
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/signal"
|
||||
"sync"
|
||||
"syscall"
|
||||
"time"
|
||||
)
|
||||
|
||||
// ShutdownManager handles graceful shutdown of the application
|
||||
type ShutdownManager struct {
|
||||
registry *ModuleRegistry
|
||||
shutdownTasks []ShutdownTask
|
||||
shutdownHooks []ShutdownHook
|
||||
config ShutdownConfig
|
||||
signalChannel chan os.Signal
|
||||
shutdownChannel chan struct{}
|
||||
state ShutdownState
|
||||
startTime time.Time
|
||||
shutdownStarted time.Time
|
||||
mu sync.RWMutex
|
||||
wg sync.WaitGroup
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
}
|
||||
|
||||
// ShutdownTask describes one unit of cleanup work run during shutdown.
type ShutdownTask struct {
	// Name identifies the task in error messages.
	Name string

	// Priority orders tasks: higher values run earlier.
	Priority int

	// Timeout bounds the task, including all of its retries.
	Timeout time.Duration

	// Task is the work itself; a non-nil error triggers a retry.
	Task func(ctx context.Context) error

	// OnError, when non-nil, is called after every failed attempt.
	OnError func(error)

	// Critical marks tasks whose failure takes precedence when errors are
	// reported (and aborts sequential execution).
	Critical bool

	// Enabled must be true for the task to be executed at all.
	Enabled bool
}
|
||||
|
||||
// ShutdownHook observes the phases of a shutdown. The methods are listed in
// the order in which performShutdown reaches the phases; exactly one of the
// last two is called, depending on the outcome.
type ShutdownHook interface {
	// OnShutdownStarted is called first, before modules are stopped.
	OnShutdownStarted(ctx context.Context) error

	// OnModulesStopped is called after the registry's modules were stopped.
	OnModulesStopped(ctx context.Context) error

	// OnCleanupStarted is called before the shutdown tasks are executed.
	OnCleanupStarted(ctx context.Context) error

	// OnShutdownCompleted is called when shutdown finished without error.
	OnShutdownCompleted(ctx context.Context) error

	// OnShutdownFailed is called when shutdown finished with an error.
	OnShutdownFailed(ctx context.Context, err error) error
}
|
||||
|
||||
// ShutdownConfig configures a ShutdownManager. NewShutdownManager replaces
// zero values of GracefulTimeout (30s), ForceTimeout (60s), SignalBufferSize
// (10), MaxRetries (3) and RetryDelay (1s) with the defaults in parentheses.
type ShutdownConfig struct {
	// GracefulTimeout bounds a graceful shutdown.
	GracefulTimeout time.Duration `json:"graceful_timeout"`

	// ForceTimeout bounds the context handed to ForceShutdown.
	ForceTimeout time.Duration `json:"force_timeout"`

	// SignalBufferSize is the capacity of the OS signal channel.
	SignalBufferSize int `json:"signal_buffer_size"`

	// MaxRetries is the number of retries after a task's first attempt.
	MaxRetries int `json:"max_retries"`

	// RetryDelay is the pause between two attempts of the same task.
	RetryDelay time.Duration `json:"retry_delay"`

	// ParallelShutdown runs tasks of equal priority concurrently.
	ParallelShutdown bool `json:"parallel_shutdown"`

	// SaveState enables the default "save_state" task.
	SaveState bool `json:"save_state"`

	// CleanupTempFiles enables the default "cleanup_temp_files" task.
	CleanupTempFiles bool `json:"cleanup_temp_files"`

	// NotifyExternal enables the default "notify_external" task.
	NotifyExternal bool `json:"notify_external"`

	// WaitForConnections enables the default "close_connections" task.
	WaitForConnections bool `json:"wait_for_connections"`

	// EnableMetrics presumably toggles shutdown metrics collection; no
	// reader is visible in the reviewed code — TODO confirm.
	EnableMetrics bool `json:"enable_metrics"`
}
|
||||
|
||||
// ShutdownState names the phase a ShutdownManager is currently in.
type ShutdownState string

// The states a ShutdownManager can report.
const (
	// ShutdownStateRunning: no shutdown has been requested yet.
	ShutdownStateRunning = ShutdownState("running")
	// ShutdownStateInitiated: Shutdown was called and is about to begin.
	ShutdownStateInitiated = ShutdownState("initiated")
	// ShutdownStateModuleStop: the registry's modules are being stopped.
	ShutdownStateModuleStop = ShutdownState("stopping_modules")
	// ShutdownStateCleanup: the shutdown tasks are being executed.
	ShutdownStateCleanup = ShutdownState("cleanup")
	// ShutdownStateCompleted: shutdown finished without error.
	ShutdownStateCompleted = ShutdownState("completed")
	// ShutdownStateFailed: shutdown finished with an error.
	ShutdownStateFailed = ShutdownState("failed")
	// ShutdownStateForced: ForceShutdown was invoked.
	ShutdownStateForced = ShutdownState("forced")
)
|
||||
|
||||
// ShutdownMetrics is the JSON-serializable record of one shutdown run.
//
// NOTE(review): no code visible in this review populates these fields; the
// grouping below follows the field names only — confirm against the writer.
type ShutdownMetrics struct {
	// Timing.
	ShutdownInitiated time.Time     `json:"shutdown_initiated"`
	ModuleStopTime    time.Duration `json:"module_stop_time"`
	CleanupTime       time.Duration `json:"cleanup_time"`
	TotalShutdownTime time.Duration `json:"total_shutdown_time"`

	// Task counters.
	TasksExecuted   int `json:"tasks_executed"`
	TasksSuccessful int `json:"tasks_successful"`
	TasksFailed     int `json:"tasks_failed"`
	RetryAttempts   int `json:"retry_attempts"`

	// Trigger information.
	ForceShutdown bool   `json:"force_shutdown"`
	Signal        string `json:"signal"`
}
|
||||
|
||||
// ShutdownProgress tracks shutdown progress
|
||||
type ShutdownProgress struct {
|
||||
State ShutdownState `json:"state"`
|
||||
Progress float64 `json:"progress"`
|
||||
CurrentTask string `json:"current_task"`
|
||||
CompletedTasks int `json:"completed_tasks"`
|
||||
TotalTasks int `json:"total_tasks"`
|
||||
ElapsedTime time.Duration `json:"elapsed_time"`
|
||||
EstimatedRemaining time.Duration `json:"estimated_remaining"`
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
// NewShutdownManager creates a new shutdown manager
|
||||
func NewShutdownManager(registry *ModuleRegistry, config ShutdownConfig) *ShutdownManager {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
sm := &ShutdownManager{
|
||||
registry: registry,
|
||||
shutdownTasks: make([]ShutdownTask, 0),
|
||||
shutdownHooks: make([]ShutdownHook, 0),
|
||||
config: config,
|
||||
signalChannel: make(chan os.Signal, config.SignalBufferSize),
|
||||
shutdownChannel: make(chan struct{}),
|
||||
state: ShutdownStateRunning,
|
||||
startTime: time.Now(),
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
}
|
||||
|
||||
// Set default configuration
|
||||
if sm.config.GracefulTimeout == 0 {
|
||||
sm.config.GracefulTimeout = 30 * time.Second
|
||||
}
|
||||
if sm.config.ForceTimeout == 0 {
|
||||
sm.config.ForceTimeout = 60 * time.Second
|
||||
}
|
||||
if sm.config.SignalBufferSize == 0 {
|
||||
sm.config.SignalBufferSize = 10
|
||||
}
|
||||
if sm.config.MaxRetries == 0 {
|
||||
sm.config.MaxRetries = 3
|
||||
}
|
||||
if sm.config.RetryDelay == 0 {
|
||||
sm.config.RetryDelay = time.Second
|
||||
}
|
||||
|
||||
// Setup default shutdown tasks
|
||||
sm.setupDefaultTasks()
|
||||
|
||||
// Setup signal handling
|
||||
sm.setupSignalHandling()
|
||||
|
||||
return sm
|
||||
}
|
||||
|
||||
// Start starts the shutdown manager
|
||||
func (sm *ShutdownManager) Start() error {
|
||||
sm.mu.Lock()
|
||||
defer sm.mu.Unlock()
|
||||
|
||||
if sm.state != ShutdownStateRunning {
|
||||
return fmt.Errorf("shutdown manager not in running state: %s", sm.state)
|
||||
}
|
||||
|
||||
// Start signal monitoring
|
||||
go sm.signalHandler()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Shutdown initiates graceful shutdown
|
||||
func (sm *ShutdownManager) Shutdown(ctx context.Context) error {
|
||||
sm.mu.Lock()
|
||||
defer sm.mu.Unlock()
|
||||
|
||||
if sm.state != ShutdownStateRunning {
|
||||
return fmt.Errorf("shutdown already initiated: %s", sm.state)
|
||||
}
|
||||
|
||||
sm.state = ShutdownStateInitiated
|
||||
sm.shutdownStarted = time.Now()
|
||||
|
||||
// Close shutdown channel to signal shutdown
|
||||
close(sm.shutdownChannel)
|
||||
|
||||
return sm.performShutdown(ctx)
|
||||
}
|
||||
|
||||
// ForceShutdown forces immediate shutdown
|
||||
func (sm *ShutdownManager) ForceShutdown(ctx context.Context) error {
|
||||
sm.mu.Lock()
|
||||
defer sm.mu.Unlock()
|
||||
|
||||
sm.state = ShutdownStateForced
|
||||
sm.cancel() // Cancel all operations
|
||||
|
||||
// Force stop all modules immediately
|
||||
if sm.registry != nil {
|
||||
forceCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
sm.registry.StopAll(forceCtx)
|
||||
}
|
||||
|
||||
os.Exit(1)
|
||||
return nil
|
||||
}
|
||||
|
||||
// AddShutdownTask adds a task to be executed during shutdown
|
||||
func (sm *ShutdownManager) AddShutdownTask(task ShutdownTask) {
|
||||
sm.mu.Lock()
|
||||
defer sm.mu.Unlock()
|
||||
|
||||
sm.shutdownTasks = append(sm.shutdownTasks, task)
|
||||
sm.sortTasksByPriority()
|
||||
}
|
||||
|
||||
// AddShutdownHook adds a hook to be called during shutdown phases
|
||||
func (sm *ShutdownManager) AddShutdownHook(hook ShutdownHook) {
|
||||
sm.mu.Lock()
|
||||
defer sm.mu.Unlock()
|
||||
|
||||
sm.shutdownHooks = append(sm.shutdownHooks, hook)
|
||||
}
|
||||
|
||||
// GetState returns the current shutdown state
|
||||
func (sm *ShutdownManager) GetState() ShutdownState {
|
||||
sm.mu.RLock()
|
||||
defer sm.mu.RUnlock()
|
||||
return sm.state
|
||||
}
|
||||
|
||||
// GetProgress returns the current shutdown progress
|
||||
func (sm *ShutdownManager) GetProgress() ShutdownProgress {
|
||||
sm.mu.RLock()
|
||||
defer sm.mu.RUnlock()
|
||||
|
||||
totalTasks := len(sm.shutdownTasks)
|
||||
if sm.registry != nil {
|
||||
totalTasks += len(sm.registry.List())
|
||||
}
|
||||
|
||||
var progress float64
|
||||
var completedTasks int
|
||||
var currentTask string
|
||||
|
||||
switch sm.state {
|
||||
case ShutdownStateRunning:
|
||||
progress = 0
|
||||
currentTask = "Running"
|
||||
case ShutdownStateInitiated:
|
||||
progress = 0.1
|
||||
currentTask = "Shutdown initiated"
|
||||
case ShutdownStateModuleStop:
|
||||
progress = 0.3
|
||||
currentTask = "Stopping modules"
|
||||
completedTasks = totalTasks / 3
|
||||
case ShutdownStateCleanup:
|
||||
progress = 0.7
|
||||
currentTask = "Cleanup"
|
||||
completedTasks = (totalTasks * 2) / 3
|
||||
case ShutdownStateCompleted:
|
||||
progress = 1.0
|
||||
currentTask = "Completed"
|
||||
completedTasks = totalTasks
|
||||
case ShutdownStateFailed:
|
||||
progress = 0.8
|
||||
currentTask = "Failed"
|
||||
case ShutdownStateForced:
|
||||
progress = 1.0
|
||||
currentTask = "Forced shutdown"
|
||||
completedTasks = totalTasks
|
||||
}
|
||||
|
||||
elapsedTime := time.Since(sm.shutdownStarted)
|
||||
var estimatedRemaining time.Duration
|
||||
if progress > 0 && progress < 1.0 {
|
||||
totalEstimated := time.Duration(float64(elapsedTime) / progress)
|
||||
estimatedRemaining = totalEstimated - elapsedTime
|
||||
}
|
||||
|
||||
return ShutdownProgress{
|
||||
State: sm.state,
|
||||
Progress: progress,
|
||||
CurrentTask: currentTask,
|
||||
CompletedTasks: completedTasks,
|
||||
TotalTasks: totalTasks,
|
||||
ElapsedTime: elapsedTime,
|
||||
EstimatedRemaining: estimatedRemaining,
|
||||
Message: fmt.Sprintf("Shutdown %s", sm.state),
|
||||
}
|
||||
}
|
||||
|
||||
// Wait waits for shutdown to complete
|
||||
func (sm *ShutdownManager) Wait() {
|
||||
<-sm.shutdownChannel
|
||||
sm.wg.Wait()
|
||||
}
|
||||
|
||||
// WaitWithTimeout waits for shutdown with timeout
|
||||
func (sm *ShutdownManager) WaitWithTimeout(timeout time.Duration) error {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
sm.Wait()
|
||||
close(done)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
return nil
|
||||
case <-time.After(timeout):
|
||||
return fmt.Errorf("shutdown timeout after %v", timeout)
|
||||
}
|
||||
}
|
||||
|
||||
// Private methods
|
||||
|
||||
func (sm *ShutdownManager) setupSignalHandling() {
|
||||
signal.Notify(sm.signalChannel,
|
||||
syscall.SIGINT,
|
||||
syscall.SIGTERM,
|
||||
syscall.SIGQUIT,
|
||||
syscall.SIGHUP,
|
||||
)
|
||||
}
|
||||
|
||||
func (sm *ShutdownManager) setupDefaultTasks() {
|
||||
// Task: Save application state
|
||||
sm.shutdownTasks = append(sm.shutdownTasks, ShutdownTask{
|
||||
Name: "save_state",
|
||||
Priority: 100,
|
||||
Timeout: 10 * time.Second,
|
||||
Task: func(ctx context.Context) error {
|
||||
if sm.config.SaveState {
|
||||
return sm.saveApplicationState(ctx)
|
||||
}
|
||||
return nil
|
||||
},
|
||||
Critical: false,
|
||||
Enabled: sm.config.SaveState,
|
||||
})
|
||||
|
||||
// Task: Close external connections
|
||||
sm.shutdownTasks = append(sm.shutdownTasks, ShutdownTask{
|
||||
Name: "close_connections",
|
||||
Priority: 90,
|
||||
Timeout: 5 * time.Second,
|
||||
Task: func(ctx context.Context) error {
|
||||
return sm.closeExternalConnections(ctx)
|
||||
},
|
||||
Critical: false,
|
||||
Enabled: sm.config.WaitForConnections,
|
||||
})
|
||||
|
||||
// Task: Cleanup temporary files
|
||||
sm.shutdownTasks = append(sm.shutdownTasks, ShutdownTask{
|
||||
Name: "cleanup_temp_files",
|
||||
Priority: 10,
|
||||
Timeout: 5 * time.Second,
|
||||
Task: func(ctx context.Context) error {
|
||||
if sm.config.CleanupTempFiles {
|
||||
return sm.cleanupTempFiles(ctx)
|
||||
}
|
||||
return nil
|
||||
},
|
||||
Critical: false,
|
||||
Enabled: sm.config.CleanupTempFiles,
|
||||
})
|
||||
|
||||
// Task: Notify external systems
|
||||
sm.shutdownTasks = append(sm.shutdownTasks, ShutdownTask{
|
||||
Name: "notify_external",
|
||||
Priority: 80,
|
||||
Timeout: 3 * time.Second,
|
||||
Task: func(ctx context.Context) error {
|
||||
if sm.config.NotifyExternal {
|
||||
return sm.notifyExternalSystems(ctx)
|
||||
}
|
||||
return nil
|
||||
},
|
||||
Critical: false,
|
||||
Enabled: sm.config.NotifyExternal,
|
||||
})
|
||||
}
|
||||
|
||||
func (sm *ShutdownManager) signalHandler() {
|
||||
for {
|
||||
select {
|
||||
case sig := <-sm.signalChannel:
|
||||
switch sig {
|
||||
case syscall.SIGINT, syscall.SIGTERM:
|
||||
// Graceful shutdown
|
||||
ctx, cancel := context.WithTimeout(context.Background(), sm.config.GracefulTimeout)
|
||||
if err := sm.Shutdown(ctx); err != nil {
|
||||
cancel()
|
||||
// Force shutdown if graceful fails
|
||||
forceCtx, forceCancel := context.WithTimeout(context.Background(), sm.config.ForceTimeout)
|
||||
sm.ForceShutdown(forceCtx)
|
||||
forceCancel()
|
||||
}
|
||||
cancel()
|
||||
return
|
||||
case syscall.SIGQUIT:
|
||||
// Force shutdown
|
||||
ctx, cancel := context.WithTimeout(context.Background(), sm.config.ForceTimeout)
|
||||
sm.ForceShutdown(ctx)
|
||||
cancel()
|
||||
return
|
||||
case syscall.SIGHUP:
|
||||
// Reload signal - could be used for configuration reload
|
||||
// For now, just log it
|
||||
continue
|
||||
}
|
||||
case <-sm.ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (sm *ShutdownManager) performShutdown(ctx context.Context) error {
|
||||
sm.wg.Add(1)
|
||||
defer sm.wg.Done()
|
||||
|
||||
// Create timeout context for entire shutdown
|
||||
shutdownCtx, cancel := context.WithTimeout(ctx, sm.config.GracefulTimeout)
|
||||
defer cancel()
|
||||
|
||||
var shutdownErr error
|
||||
|
||||
// Phase 1: Call shutdown started hooks
|
||||
if err := sm.callHooks(shutdownCtx, "OnShutdownStarted"); err != nil {
|
||||
shutdownErr = fmt.Errorf("shutdown hooks failed: %w", err)
|
||||
}
|
||||
|
||||
// Phase 2: Stop modules
|
||||
sm.state = ShutdownStateModuleStop
|
||||
if sm.registry != nil {
|
||||
if err := sm.registry.StopAll(shutdownCtx); err != nil {
|
||||
shutdownErr = fmt.Errorf("failed to stop modules: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Call modules stopped hooks
|
||||
if err := sm.callHooks(shutdownCtx, "OnModulesStopped"); err != nil {
|
||||
if shutdownErr == nil {
|
||||
shutdownErr = fmt.Errorf("modules stopped hooks failed: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Phase 3: Execute shutdown tasks
|
||||
sm.state = ShutdownStateCleanup
|
||||
if err := sm.callHooks(shutdownCtx, "OnCleanupStarted"); err != nil {
|
||||
if shutdownErr == nil {
|
||||
shutdownErr = fmt.Errorf("cleanup hooks failed: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if err := sm.executeShutdownTasks(shutdownCtx); err != nil {
|
||||
if shutdownErr == nil {
|
||||
shutdownErr = fmt.Errorf("shutdown tasks failed: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Phase 4: Final cleanup
|
||||
if shutdownErr != nil {
|
||||
sm.state = ShutdownStateFailed
|
||||
sm.callHooks(shutdownCtx, "OnShutdownFailed")
|
||||
} else {
|
||||
sm.state = ShutdownStateCompleted
|
||||
sm.callHooks(shutdownCtx, "OnShutdownCompleted")
|
||||
}
|
||||
|
||||
return shutdownErr
|
||||
}
|
||||
|
||||
func (sm *ShutdownManager) executeShutdownTasks(ctx context.Context) error {
|
||||
if sm.config.ParallelShutdown {
|
||||
return sm.executeTasksParallel(ctx)
|
||||
} else {
|
||||
return sm.executeTasksSequential(ctx)
|
||||
}
|
||||
}
|
||||
|
||||
func (sm *ShutdownManager) executeTasksSequential(ctx context.Context) error {
|
||||
var lastErr error
|
||||
|
||||
for _, task := range sm.shutdownTasks {
|
||||
if !task.Enabled {
|
||||
continue
|
||||
}
|
||||
|
||||
if err := sm.executeTask(ctx, task); err != nil {
|
||||
lastErr = err
|
||||
if task.Critical {
|
||||
return fmt.Errorf("critical task %s failed: %w", task.Name, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return lastErr
|
||||
}
|
||||
|
||||
func (sm *ShutdownManager) executeTasksParallel(ctx context.Context) error {
|
||||
var wg sync.WaitGroup
|
||||
errors := make(chan error, len(sm.shutdownTasks))
|
||||
|
||||
// Group tasks by priority
|
||||
priorityGroups := sm.groupTasksByPriority()
|
||||
|
||||
// Execute each priority group sequentially, but tasks within group in parallel
|
||||
for _, tasks := range priorityGroups {
|
||||
for _, task := range tasks {
|
||||
if !task.Enabled {
|
||||
continue
|
||||
}
|
||||
|
||||
wg.Add(1)
|
||||
go func(t ShutdownTask) {
|
||||
defer wg.Done()
|
||||
if err := sm.executeTask(ctx, t); err != nil {
|
||||
errors <- fmt.Errorf("task %s failed: %w", t.Name, err)
|
||||
}
|
||||
}(task)
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
close(errors)
|
||||
|
||||
// Collect errors
|
||||
var criticalErr error
|
||||
var lastErr error
|
||||
for err := range errors {
|
||||
lastErr = err
|
||||
// Check if this was from a critical task
|
||||
for _, task := range sm.shutdownTasks {
|
||||
if task.Critical && fmt.Sprintf("task %s failed:", task.Name) == err.Error()[:len(fmt.Sprintf("task %s failed:", task.Name))] {
|
||||
criticalErr = err
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if criticalErr != nil {
|
||||
return criticalErr
|
||||
}
|
||||
return lastErr
|
||||
}
|
||||
|
||||
func (sm *ShutdownManager) executeTask(ctx context.Context, task ShutdownTask) error {
|
||||
// Create timeout context for the task
|
||||
taskCtx, cancel := context.WithTimeout(ctx, task.Timeout)
|
||||
defer cancel()
|
||||
|
||||
// Execute task with retry
|
||||
var lastErr error
|
||||
for attempt := 0; attempt <= sm.config.MaxRetries; attempt++ {
|
||||
if attempt > 0 {
|
||||
select {
|
||||
case <-time.After(sm.config.RetryDelay):
|
||||
case <-taskCtx.Done():
|
||||
return taskCtx.Err()
|
||||
}
|
||||
}
|
||||
|
||||
err := task.Task(taskCtx)
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
lastErr = err
|
||||
|
||||
// Call error handler if provided
|
||||
if task.OnError != nil {
|
||||
task.OnError(err)
|
||||
}
|
||||
}
|
||||
|
||||
return fmt.Errorf("task failed after %d attempts: %w", sm.config.MaxRetries, lastErr)
|
||||
}
|
||||
|
||||
func (sm *ShutdownManager) callHooks(ctx context.Context, hookMethod string) error {
|
||||
var lastErr error
|
||||
|
||||
for _, hook := range sm.shutdownHooks {
|
||||
var err error
|
||||
|
||||
switch hookMethod {
|
||||
case "OnShutdownStarted":
|
||||
err = hook.OnShutdownStarted(ctx)
|
||||
case "OnModulesStopped":
|
||||
err = hook.OnModulesStopped(ctx)
|
||||
case "OnCleanupStarted":
|
||||
err = hook.OnCleanupStarted(ctx)
|
||||
case "OnShutdownCompleted":
|
||||
err = hook.OnShutdownCompleted(ctx)
|
||||
case "OnShutdownFailed":
|
||||
err = hook.OnShutdownFailed(ctx, lastErr)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
lastErr = err
|
||||
}
|
||||
}
|
||||
|
||||
return lastErr
|
||||
}
|
||||
|
||||
func (sm *ShutdownManager) sortTasksByPriority() {
|
||||
// Simple bubble sort by priority (descending)
|
||||
for i := 0; i < len(sm.shutdownTasks); i++ {
|
||||
for j := i + 1; j < len(sm.shutdownTasks); j++ {
|
||||
if sm.shutdownTasks[j].Priority > sm.shutdownTasks[i].Priority {
|
||||
sm.shutdownTasks[i], sm.shutdownTasks[j] = sm.shutdownTasks[j], sm.shutdownTasks[i]
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (sm *ShutdownManager) groupTasksByPriority() [][]ShutdownTask {
|
||||
groups := make(map[int][]ShutdownTask)
|
||||
|
||||
for _, task := range sm.shutdownTasks {
|
||||
groups[task.Priority] = append(groups[task.Priority], task)
|
||||
}
|
||||
|
||||
// Convert to sorted slice
|
||||
var priorities []int
|
||||
for priority := range groups {
|
||||
priorities = append(priorities, priority)
|
||||
}
|
||||
|
||||
// Sort priorities descending
|
||||
for i := 0; i < len(priorities); i++ {
|
||||
for j := i + 1; j < len(priorities); j++ {
|
||||
if priorities[j] > priorities[i] {
|
||||
priorities[i], priorities[j] = priorities[j], priorities[i]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var result [][]ShutdownTask
|
||||
for _, priority := range priorities {
|
||||
result = append(result, groups[priority])
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// Default task implementations
|
||||
|
||||
func (sm *ShutdownManager) saveApplicationState(ctx context.Context) error {
|
||||
// Save application state to disk
|
||||
// This would save things like current configuration, runtime state, etc.
|
||||
return nil
|
||||
}
|
||||
|
||||
func (sm *ShutdownManager) closeExternalConnections(ctx context.Context) error {
|
||||
// Close database connections, external API connections, etc.
|
||||
return nil
|
||||
}
|
||||
|
||||
func (sm *ShutdownManager) cleanupTempFiles(ctx context.Context) error {
|
||||
// Remove temporary files, logs, caches, etc.
|
||||
return nil
|
||||
}
|
||||
|
||||
func (sm *ShutdownManager) notifyExternalSystems(ctx context.Context) error {
|
||||
// Notify external systems that this instance is shutting down
|
||||
return nil
|
||||
}
|
||||
|
||||
// DefaultShutdownHook is a do-nothing shutdown hook: every callback returns nil
// without side effects. It is intended as a basic ShutdownHook implementation.
type DefaultShutdownHook struct {
	// name is the label supplied to NewDefaultShutdownHook. None of the
	// methods below read it.
	name string
}

// NewDefaultShutdownHook returns a DefaultShutdownHook carrying the given name.
func NewDefaultShutdownHook(name string) *DefaultShutdownHook {
	hook := new(DefaultShutdownHook)
	hook.name = name
	return hook
}

// OnShutdownStarted ignores its argument and reports success.
func (dsh *DefaultShutdownHook) OnShutdownStarted(_ context.Context) error { return nil }

// OnModulesStopped ignores its argument and reports success.
func (dsh *DefaultShutdownHook) OnModulesStopped(_ context.Context) error { return nil }

// OnCleanupStarted ignores its argument and reports success.
func (dsh *DefaultShutdownHook) OnCleanupStarted(_ context.Context) error { return nil }

// OnShutdownCompleted ignores its argument and reports success.
func (dsh *DefaultShutdownHook) OnShutdownCompleted(_ context.Context) error { return nil }

// OnShutdownFailed ignores both arguments, including the reported error, and
// reports success.
func (dsh *DefaultShutdownHook) OnShutdownFailed(_ context.Context, _ error) error { return nil }
|
||||
657
pkg/lifecycle/state_machine.go
Normal file
657
pkg/lifecycle/state_machine.go
Normal file
@@ -0,0 +1,657 @@
|
||||
package lifecycle
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// StateMachine manages module state transitions and enforces valid state changes
|
||||
type StateMachine struct {
|
||||
currentState ModuleState
|
||||
transitions map[ModuleState][]ModuleState
|
||||
stateHandlers map[ModuleState]StateHandler
|
||||
transitionHooks map[string]TransitionHook
|
||||
history []StateTransition
|
||||
module Module
|
||||
config StateMachineConfig
|
||||
mu sync.RWMutex
|
||||
metrics StateMachineMetrics
|
||||
}
|
||||
|
||||
// StateHandler handles operations when entering a specific state
|
||||
type StateHandler func(ctx context.Context, machine *StateMachine) error
|
||||
|
||||
// TransitionHook is called before or after state transitions
|
||||
type TransitionHook func(ctx context.Context, from, to ModuleState, machine *StateMachine) error
|
||||
|
||||
// StateTransition represents a state change event
|
||||
type StateTransition struct {
|
||||
From ModuleState `json:"from"`
|
||||
To ModuleState `json:"to"`
|
||||
Timestamp time.Time `json:"timestamp"`
|
||||
Duration time.Duration `json:"duration"`
|
||||
Success bool `json:"success"`
|
||||
Error error `json:"error,omitempty"`
|
||||
Trigger string `json:"trigger"`
|
||||
Context map[string]interface{} `json:"context"`
|
||||
}
|
||||
|
||||
// StateMachineConfig configures state machine behavior
|
||||
type StateMachineConfig struct {
|
||||
InitialState ModuleState `json:"initial_state"`
|
||||
TransitionTimeout time.Duration `json:"transition_timeout"`
|
||||
MaxHistorySize int `json:"max_history_size"`
|
||||
EnableMetrics bool `json:"enable_metrics"`
|
||||
EnableValidation bool `json:"enable_validation"`
|
||||
AllowConcurrent bool `json:"allow_concurrent"`
|
||||
RetryFailedTransitions bool `json:"retry_failed_transitions"`
|
||||
MaxRetries int `json:"max_retries"`
|
||||
RetryDelay time.Duration `json:"retry_delay"`
|
||||
}
|
||||
|
||||
// StateMachineMetrics tracks state machine performance
|
||||
type StateMachineMetrics struct {
|
||||
TotalTransitions int64 `json:"total_transitions"`
|
||||
SuccessfulTransitions int64 `json:"successful_transitions"`
|
||||
FailedTransitions int64 `json:"failed_transitions"`
|
||||
StateDistribution map[ModuleState]int64 `json:"state_distribution"`
|
||||
TransitionTimes map[string]time.Duration `json:"transition_times"`
|
||||
AverageTransitionTime time.Duration `json:"average_transition_time"`
|
||||
LongestTransition time.Duration `json:"longest_transition"`
|
||||
LastTransition time.Time `json:"last_transition"`
|
||||
CurrentStateDuration time.Duration `json:"current_state_duration"`
|
||||
stateEnterTime time.Time
|
||||
}
|
||||
|
||||
// NewStateMachine creates a new state machine for a module
|
||||
func NewStateMachine(module Module, config StateMachineConfig) *StateMachine {
|
||||
sm := &StateMachine{
|
||||
currentState: config.InitialState,
|
||||
transitions: createDefaultTransitions(),
|
||||
stateHandlers: make(map[ModuleState]StateHandler),
|
||||
transitionHooks: make(map[string]TransitionHook),
|
||||
history: make([]StateTransition, 0),
|
||||
module: module,
|
||||
config: config,
|
||||
metrics: StateMachineMetrics{
|
||||
StateDistribution: make(map[ModuleState]int64),
|
||||
TransitionTimes: make(map[string]time.Duration),
|
||||
stateEnterTime: time.Now(),
|
||||
},
|
||||
}
|
||||
|
||||
// Set default config values
|
||||
if sm.config.TransitionTimeout == 0 {
|
||||
sm.config.TransitionTimeout = 30 * time.Second
|
||||
}
|
||||
if sm.config.MaxHistorySize == 0 {
|
||||
sm.config.MaxHistorySize = 100
|
||||
}
|
||||
if sm.config.MaxRetries == 0 {
|
||||
sm.config.MaxRetries = 3
|
||||
}
|
||||
if sm.config.RetryDelay == 0 {
|
||||
sm.config.RetryDelay = time.Second
|
||||
}
|
||||
|
||||
// Setup default state handlers
|
||||
sm.setupDefaultHandlers()
|
||||
|
||||
return sm
|
||||
}
|
||||
|
||||
// GetCurrentState returns the current state
|
||||
func (sm *StateMachine) GetCurrentState() ModuleState {
|
||||
sm.mu.RLock()
|
||||
defer sm.mu.RUnlock()
|
||||
return sm.currentState
|
||||
}
|
||||
|
||||
// CanTransition checks if a transition from current state to target state is valid
|
||||
func (sm *StateMachine) CanTransition(to ModuleState) bool {
|
||||
sm.mu.RLock()
|
||||
defer sm.mu.RUnlock()
|
||||
|
||||
validTransitions, exists := sm.transitions[sm.currentState]
|
||||
if !exists {
|
||||
return false
|
||||
}
|
||||
|
||||
for _, validState := range validTransitions {
|
||||
if validState == to {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// Transition performs a state transition
|
||||
func (sm *StateMachine) Transition(ctx context.Context, to ModuleState, trigger string) error {
|
||||
if !sm.config.AllowConcurrent {
|
||||
sm.mu.Lock()
|
||||
defer sm.mu.Unlock()
|
||||
} else {
|
||||
sm.mu.RLock()
|
||||
defer sm.mu.RUnlock()
|
||||
}
|
||||
|
||||
return sm.performTransition(ctx, to, trigger)
|
||||
}
|
||||
|
||||
// TransitionWithRetry performs a state transition with retry logic
|
||||
func (sm *StateMachine) TransitionWithRetry(ctx context.Context, to ModuleState, trigger string) error {
|
||||
var lastErr error
|
||||
|
||||
for attempt := 0; attempt <= sm.config.MaxRetries; attempt++ {
|
||||
if attempt > 0 {
|
||||
// Wait before retrying
|
||||
select {
|
||||
case <-time.After(sm.config.RetryDelay):
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
}
|
||||
}
|
||||
|
||||
err := sm.Transition(ctx, to, trigger)
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
lastErr = err
|
||||
|
||||
// Don't retry if it's a validation error
|
||||
if !sm.config.RetryFailedTransitions {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return fmt.Errorf("transition failed after %d attempts: %w", sm.config.MaxRetries, lastErr)
|
||||
}
|
||||
|
||||
// Initialize transitions to initialized state
|
||||
func (sm *StateMachine) Initialize(ctx context.Context) error {
|
||||
return sm.Transition(ctx, StateInitialized, "initialize")
|
||||
}
|
||||
|
||||
// Start transitions to running state
|
||||
func (sm *StateMachine) Start(ctx context.Context) error {
|
||||
return sm.Transition(ctx, StateRunning, "start")
|
||||
}
|
||||
|
||||
// Stop transitions to stopped state
|
||||
func (sm *StateMachine) Stop(ctx context.Context) error {
|
||||
return sm.Transition(ctx, StateStopped, "stop")
|
||||
}
|
||||
|
||||
// Pause transitions to paused state
|
||||
func (sm *StateMachine) Pause(ctx context.Context) error {
|
||||
return sm.Transition(ctx, StatePaused, "pause")
|
||||
}
|
||||
|
||||
// Resume transitions to running state from paused
|
||||
func (sm *StateMachine) Resume(ctx context.Context) error {
|
||||
return sm.Transition(ctx, StateRunning, "resume")
|
||||
}
|
||||
|
||||
// Fail transitions to failed state
|
||||
func (sm *StateMachine) Fail(ctx context.Context, reason string) error {
|
||||
return sm.Transition(ctx, StateFailed, fmt.Sprintf("fail: %s", reason))
|
||||
}
|
||||
|
||||
// SetStateHandler sets a custom handler for a specific state
|
||||
func (sm *StateMachine) SetStateHandler(state ModuleState, handler StateHandler) {
|
||||
sm.mu.Lock()
|
||||
defer sm.mu.Unlock()
|
||||
sm.stateHandlers[state] = handler
|
||||
}
|
||||
|
||||
// SetTransitionHook sets a hook for state transitions
|
||||
func (sm *StateMachine) SetTransitionHook(name string, hook TransitionHook) {
|
||||
sm.mu.Lock()
|
||||
defer sm.mu.Unlock()
|
||||
sm.transitionHooks[name] = hook
|
||||
}
|
||||
|
||||
// GetHistory returns the state transition history
|
||||
func (sm *StateMachine) GetHistory() []StateTransition {
|
||||
sm.mu.RLock()
|
||||
defer sm.mu.RUnlock()
|
||||
|
||||
history := make([]StateTransition, len(sm.history))
|
||||
copy(history, sm.history)
|
||||
return history
|
||||
}
|
||||
|
||||
// GetMetrics returns state machine metrics
|
||||
func (sm *StateMachine) GetMetrics() StateMachineMetrics {
|
||||
sm.mu.RLock()
|
||||
defer sm.mu.RUnlock()
|
||||
|
||||
// Update current state duration
|
||||
metrics := sm.metrics
|
||||
metrics.CurrentStateDuration = time.Since(sm.metrics.stateEnterTime)
|
||||
|
||||
return metrics
|
||||
}
|
||||
|
||||
// AddCustomTransition adds a custom state transition rule
|
||||
func (sm *StateMachine) AddCustomTransition(from, to ModuleState) {
|
||||
sm.mu.Lock()
|
||||
defer sm.mu.Unlock()
|
||||
|
||||
if _, exists := sm.transitions[from]; !exists {
|
||||
sm.transitions[from] = make([]ModuleState, 0)
|
||||
}
|
||||
|
||||
// Check if transition already exists
|
||||
for _, existing := range sm.transitions[from] {
|
||||
if existing == to {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
sm.transitions[from] = append(sm.transitions[from], to)
|
||||
}
|
||||
|
||||
// RemoveTransition removes a state transition rule
|
||||
func (sm *StateMachine) RemoveTransition(from, to ModuleState) {
|
||||
sm.mu.Lock()
|
||||
defer sm.mu.Unlock()
|
||||
|
||||
transitions, exists := sm.transitions[from]
|
||||
if !exists {
|
||||
return
|
||||
}
|
||||
|
||||
for i, transition := range transitions {
|
||||
if transition == to {
|
||||
sm.transitions[from] = append(transitions[:i], transitions[i+1:]...)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// GetValidTransitions returns all valid transitions from current state
|
||||
func (sm *StateMachine) GetValidTransitions() []ModuleState {
|
||||
sm.mu.RLock()
|
||||
defer sm.mu.RUnlock()
|
||||
|
||||
validTransitions, exists := sm.transitions[sm.currentState]
|
||||
if !exists {
|
||||
return []ModuleState{}
|
||||
}
|
||||
|
||||
result := make([]ModuleState, len(validTransitions))
|
||||
copy(result, validTransitions)
|
||||
return result
|
||||
}
|
||||
|
||||
// IsInState checks if the state machine is in a specific state
|
||||
func (sm *StateMachine) IsInState(state ModuleState) bool {
|
||||
sm.mu.RLock()
|
||||
defer sm.mu.RUnlock()
|
||||
return sm.currentState == state
|
||||
}
|
||||
|
||||
// IsInAnyState checks if the state machine is in any of the provided states
|
||||
func (sm *StateMachine) IsInAnyState(states ...ModuleState) bool {
|
||||
sm.mu.RLock()
|
||||
defer sm.mu.RUnlock()
|
||||
|
||||
for _, state := range states {
|
||||
if sm.currentState == state {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// WaitForState waits until the state machine reaches a specific state or times out
|
||||
func (sm *StateMachine) WaitForState(ctx context.Context, state ModuleState, timeout time.Duration) error {
|
||||
if sm.IsInState(state) {
|
||||
return nil
|
||||
}
|
||||
|
||||
timeoutCtx, cancel := context.WithTimeout(ctx, timeout)
|
||||
defer cancel()
|
||||
|
||||
ticker := time.NewTicker(100 * time.Millisecond)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-timeoutCtx.Done():
|
||||
return fmt.Errorf("timeout waiting for state %s", state)
|
||||
case <-ticker.C:
|
||||
if sm.IsInState(state) {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Reset resets the state machine to its initial state
|
||||
func (sm *StateMachine) Reset(ctx context.Context) error {
|
||||
sm.mu.Lock()
|
||||
defer sm.mu.Unlock()
|
||||
|
||||
// Clear history
|
||||
sm.history = make([]StateTransition, 0)
|
||||
|
||||
// Reset metrics
|
||||
sm.metrics = StateMachineMetrics{
|
||||
StateDistribution: make(map[ModuleState]int64),
|
||||
TransitionTimes: make(map[string]time.Duration),
|
||||
stateEnterTime: time.Now(),
|
||||
}
|
||||
|
||||
// Transition to initial state
|
||||
return sm.performTransition(ctx, sm.config.InitialState, "reset")
|
||||
}
|
||||
|
||||
// Private methods
|
||||
|
||||
func (sm *StateMachine) performTransition(ctx context.Context, to ModuleState, trigger string) error {
|
||||
startTime := time.Now()
|
||||
from := sm.currentState
|
||||
|
||||
// Validate transition
|
||||
if sm.config.EnableValidation && !sm.canTransitionUnsafe(to) {
|
||||
return fmt.Errorf("invalid transition from %s to %s", from, to)
|
||||
}
|
||||
|
||||
// Create transition context
|
||||
transitionCtx := map[string]interface{}{
|
||||
"trigger": trigger,
|
||||
"start_time": startTime,
|
||||
"module_id": sm.module.GetID(),
|
||||
}
|
||||
|
||||
// Execute pre-transition hooks
|
||||
for name, hook := range sm.transitionHooks {
|
||||
if hookCtx, cancel := context.WithTimeout(ctx, sm.config.TransitionTimeout); hookCtx != nil {
|
||||
if err := hook(hookCtx, from, to, sm); err != nil {
|
||||
cancel()
|
||||
sm.recordFailedTransition(from, to, startTime, trigger, err, transitionCtx)
|
||||
return fmt.Errorf("pre-transition hook %s failed: %w", name, err)
|
||||
}
|
||||
cancel()
|
||||
}
|
||||
}
|
||||
|
||||
// Execute state-specific logic
|
||||
if err := sm.executeStateTransition(ctx, from, to); err != nil {
|
||||
sm.recordFailedTransition(from, to, startTime, trigger, err, transitionCtx)
|
||||
return fmt.Errorf("state transition failed: %w", err)
|
||||
}
|
||||
|
||||
// Update current state
|
||||
sm.currentState = to
|
||||
duration := time.Since(startTime)
|
||||
|
||||
// Update metrics
|
||||
if sm.config.EnableMetrics {
|
||||
sm.updateMetrics(from, to, duration)
|
||||
}
|
||||
|
||||
// Record successful transition
|
||||
sm.recordSuccessfulTransition(from, to, startTime, duration, trigger, transitionCtx)
|
||||
|
||||
// Execute post-transition hooks
|
||||
for _, hook := range sm.transitionHooks {
|
||||
if hookCtx, cancel := context.WithTimeout(ctx, sm.config.TransitionTimeout); hookCtx != nil {
|
||||
if err := hook(hookCtx, from, to, sm); err != nil {
|
||||
// Log error but don't fail the transition
|
||||
cancel()
|
||||
continue
|
||||
}
|
||||
cancel()
|
||||
}
|
||||
}
|
||||
|
||||
// Execute state handler for new state
|
||||
if handler, exists := sm.stateHandlers[to]; exists {
|
||||
if handlerCtx, cancel := context.WithTimeout(ctx, sm.config.TransitionTimeout); handlerCtx != nil {
|
||||
if err := handler(handlerCtx, sm); err != nil {
|
||||
cancel()
|
||||
// Log error but don't fail the transition
|
||||
} else {
|
||||
cancel()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (sm *StateMachine) executeStateTransition(ctx context.Context, from, to ModuleState) error {
|
||||
// Create timeout context for the operation
|
||||
timeoutCtx, cancel := context.WithTimeout(ctx, sm.config.TransitionTimeout)
|
||||
defer cancel()
|
||||
|
||||
switch to {
|
||||
case StateInitialized:
|
||||
return sm.module.Initialize(timeoutCtx, ModuleConfig{})
|
||||
case StateRunning:
|
||||
if from == StatePaused {
|
||||
return sm.module.Resume(timeoutCtx)
|
||||
}
|
||||
return sm.module.Start(timeoutCtx)
|
||||
case StateStopped:
|
||||
return sm.module.Stop(timeoutCtx)
|
||||
case StatePaused:
|
||||
return sm.module.Pause(timeoutCtx)
|
||||
case StateFailed:
|
||||
// Failed state doesn't require module action
|
||||
return nil
|
||||
default:
|
||||
return fmt.Errorf("unknown target state: %s", to)
|
||||
}
|
||||
}
|
||||
|
||||
func (sm *StateMachine) canTransitionUnsafe(to ModuleState) bool {
|
||||
validTransitions, exists := sm.transitions[sm.currentState]
|
||||
if !exists {
|
||||
return false
|
||||
}
|
||||
|
||||
for _, validState := range validTransitions {
|
||||
if validState == to {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func (sm *StateMachine) recordSuccessfulTransition(from, to ModuleState, startTime time.Time, duration time.Duration, trigger string, context map[string]interface{}) {
|
||||
transition := StateTransition{
|
||||
From: from,
|
||||
To: to,
|
||||
Timestamp: startTime,
|
||||
Duration: duration,
|
||||
Success: true,
|
||||
Trigger: trigger,
|
||||
Context: context,
|
||||
}
|
||||
|
||||
sm.addToHistory(transition)
|
||||
}
|
||||
|
||||
func (sm *StateMachine) recordFailedTransition(from, to ModuleState, startTime time.Time, trigger string, err error, context map[string]interface{}) {
|
||||
transition := StateTransition{
|
||||
From: from,
|
||||
To: to,
|
||||
Timestamp: startTime,
|
||||
Duration: time.Since(startTime),
|
||||
Success: false,
|
||||
Error: err,
|
||||
Trigger: trigger,
|
||||
Context: context,
|
||||
}
|
||||
|
||||
sm.addToHistory(transition)
|
||||
|
||||
if sm.config.EnableMetrics {
|
||||
sm.metrics.FailedTransitions++
|
||||
}
|
||||
}
|
||||
|
||||
func (sm *StateMachine) addToHistory(transition StateTransition) {
|
||||
sm.history = append(sm.history, transition)
|
||||
|
||||
// Trim history if it exceeds max size
|
||||
if len(sm.history) > sm.config.MaxHistorySize {
|
||||
sm.history = sm.history[1:]
|
||||
}
|
||||
}
|
||||
|
||||
func (sm *StateMachine) updateMetrics(from, to ModuleState, duration time.Duration) {
|
||||
sm.metrics.TotalTransitions++
|
||||
sm.metrics.SuccessfulTransitions++
|
||||
sm.metrics.StateDistribution[to]++
|
||||
sm.metrics.LastTransition = time.Now()
|
||||
|
||||
// Update transition times
|
||||
transitionKey := fmt.Sprintf("%s->%s", from, to)
|
||||
sm.metrics.TransitionTimes[transitionKey] = duration
|
||||
|
||||
// Update average transition time
|
||||
if sm.metrics.TotalTransitions > 0 {
|
||||
total := time.Duration(0)
|
||||
for _, d := range sm.metrics.TransitionTimes {
|
||||
total += d
|
||||
}
|
||||
sm.metrics.AverageTransitionTime = total / time.Duration(len(sm.metrics.TransitionTimes))
|
||||
}
|
||||
|
||||
// Update longest transition
|
||||
if duration > sm.metrics.LongestTransition {
|
||||
sm.metrics.LongestTransition = duration
|
||||
}
|
||||
|
||||
// Update state enter time for duration tracking
|
||||
sm.metrics.stateEnterTime = time.Now()
|
||||
}
|
||||
|
||||
func (sm *StateMachine) setupDefaultHandlers() {
|
||||
// Default handlers for common states
|
||||
sm.stateHandlers[StateInitialized] = func(ctx context.Context, machine *StateMachine) error {
|
||||
// State entered successfully
|
||||
return nil
|
||||
}
|
||||
|
||||
sm.stateHandlers[StateRunning] = func(ctx context.Context, machine *StateMachine) error {
|
||||
// Module is now running
|
||||
return nil
|
||||
}
|
||||
|
||||
sm.stateHandlers[StateStopped] = func(ctx context.Context, machine *StateMachine) error {
|
||||
// Module has stopped
|
||||
return nil
|
||||
}
|
||||
|
||||
sm.stateHandlers[StatePaused] = func(ctx context.Context, machine *StateMachine) error {
|
||||
// Module is paused
|
||||
return nil
|
||||
}
|
||||
|
||||
sm.stateHandlers[StateFailed] = func(ctx context.Context, machine *StateMachine) error {
|
||||
// Handle failure state - could trigger recovery logic
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// createDefaultTransitions creates the standard state transition rules
|
||||
func createDefaultTransitions() map[ModuleState][]ModuleState {
|
||||
return map[ModuleState][]ModuleState{
|
||||
StateUninitialized: {StateInitialized, StateFailed},
|
||||
StateInitialized: {StateStarting, StateStopped, StateFailed},
|
||||
StateStarting: {StateRunning, StateFailed},
|
||||
StateRunning: {StatePausing, StateStopping, StateFailed},
|
||||
StatePausing: {StatePaused, StateFailed},
|
||||
StatePaused: {StateResuming, StateStopping, StateFailed},
|
||||
StateResuming: {StateRunning, StateFailed},
|
||||
StateStopping: {StateStopped, StateFailed},
|
||||
StateStopped: {StateInitialized, StateStarting, StateFailed},
|
||||
StateFailed: {StateInitialized, StateStopped}, // Recovery paths
|
||||
}
|
||||
}
|
||||
|
||||
// StateMachineBuilder provides a fluent interface for building state machines
|
||||
type StateMachineBuilder struct {
|
||||
config StateMachineConfig
|
||||
stateHandlers map[ModuleState]StateHandler
|
||||
transitionHooks map[string]TransitionHook
|
||||
customTransitions map[ModuleState][]ModuleState
|
||||
}
|
||||
|
||||
// NewStateMachineBuilder creates a new state machine builder
|
||||
func NewStateMachineBuilder() *StateMachineBuilder {
|
||||
return &StateMachineBuilder{
|
||||
config: StateMachineConfig{
|
||||
InitialState: StateUninitialized,
|
||||
TransitionTimeout: 30 * time.Second,
|
||||
MaxHistorySize: 100,
|
||||
EnableMetrics: true,
|
||||
EnableValidation: true,
|
||||
},
|
||||
stateHandlers: make(map[ModuleState]StateHandler),
|
||||
transitionHooks: make(map[string]TransitionHook),
|
||||
customTransitions: make(map[ModuleState][]ModuleState),
|
||||
}
|
||||
}
|
||||
|
||||
// WithConfig sets the state machine configuration
|
||||
func (smb *StateMachineBuilder) WithConfig(config StateMachineConfig) *StateMachineBuilder {
|
||||
smb.config = config
|
||||
return smb
|
||||
}
|
||||
|
||||
// WithStateHandler adds a state handler
|
||||
func (smb *StateMachineBuilder) WithStateHandler(state ModuleState, handler StateHandler) *StateMachineBuilder {
|
||||
smb.stateHandlers[state] = handler
|
||||
return smb
|
||||
}
|
||||
|
||||
// WithTransitionHook adds a transition hook
|
||||
func (smb *StateMachineBuilder) WithTransitionHook(name string, hook TransitionHook) *StateMachineBuilder {
|
||||
smb.transitionHooks[name] = hook
|
||||
return smb
|
||||
}
|
||||
|
||||
// WithCustomTransition adds a custom transition rule
|
||||
func (smb *StateMachineBuilder) WithCustomTransition(from, to ModuleState) *StateMachineBuilder {
|
||||
if _, exists := smb.customTransitions[from]; !exists {
|
||||
smb.customTransitions[from] = make([]ModuleState, 0)
|
||||
}
|
||||
smb.customTransitions[from] = append(smb.customTransitions[from], to)
|
||||
return smb
|
||||
}
|
||||
|
||||
// Build creates the state machine
|
||||
func (smb *StateMachineBuilder) Build(module Module) *StateMachine {
|
||||
sm := NewStateMachine(module, smb.config)
|
||||
|
||||
// Add state handlers
|
||||
for state, handler := range smb.stateHandlers {
|
||||
sm.SetStateHandler(state, handler)
|
||||
}
|
||||
|
||||
// Add transition hooks
|
||||
for name, hook := range smb.transitionHooks {
|
||||
sm.SetTransitionHook(name, hook)
|
||||
}
|
||||
|
||||
// Add custom transitions
|
||||
for from, toStates := range smb.customTransitions {
|
||||
for _, to := range toStates {
|
||||
sm.AddCustomTransition(from, to)
|
||||
}
|
||||
}
|
||||
|
||||
return sm
|
||||
}
|
||||
Reference in New Issue
Block a user