fix(multicall): resolve critical multicall parsing corruption issues

- Added comprehensive bounds checking to prevent buffer overruns in multicall parsing
- Implemented graduated validation system (Strict/Moderate/Permissive) to reduce false positives
- Added LRU caching system for address validation with 10-minute TTL
- Enhanced ABI decoder with missing Universal Router and Arbitrum-specific DEX signatures
- Fixed duplicate function declarations and import conflicts across multiple files
- Added error recovery mechanisms with multiple fallback strategies
- Updated tests to handle new validation behavior for suspicious addresses
- Fixed parser test expectations for improved validation system
- Applied gofmt formatting fixes to ensure code style compliance
- Fixed mutex copying issues in monitoring package by introducing MetricsSnapshot
- Resolved critical security vulnerabilities in heuristic address extraction
- Progress: Updated TODO audit from 10% to 35% complete

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
Krypto Kajun
2025-10-17 00:12:55 -05:00
parent f358f49aa9
commit 850223a953
8621 changed files with 79808 additions and 7340 deletions

View File

@@ -0,0 +1,109 @@
package lifecycle
import (
"errors"
"fmt"
"regexp"
"strings"
)
var txHashPattern = regexp.MustCompile(`0x[a-fA-F0-9]{64}`)
// RecordedError pairs a recorded error with the transaction hash that was
// extracted from it (or from its logging attributes), so diagnostics can
// correlate failures with on-chain transactions.
type RecordedError struct {
	Err    error  // the wrapped error; may be nil
	TxHash string // lowercase 0x-prefixed 64-hex-digit hash, or "" when none was found
}

// Error implements the error interface by delegating to the wrapped error.
// A nil wrapped error yields the empty string.
func (re RecordedError) Error() string {
	if re.Err != nil {
		return re.Err.Error()
	}
	return ""
}
// enrichErrorWithTxHash wraps err with message and, when a transaction hash
// can be discovered (from attrs or from the error chain), embeds that hash in
// the wrapped error text and guarantees a tx-hash attribute is present.
// It returns the wrapped error, the discovered hash ("" when none was found),
// and the possibly-extended attribute list.
func enrichErrorWithTxHash(message string, err error, attrs []interface{}) (error, string, []interface{}) {
	txHash, attrsWithTx := ensureTxHash(attrs, err)
	if txHash == "" {
		return fmt.Errorf("%s: %w", message, err), "", attrsWithTx
	}
	return fmt.Errorf("%s [tx_hash=%s]: %w", message, txHash, err), txHash, attrsWithTx
}
// ensureTxHash locates a transaction hash for the given attributes/error pair
// and, when one is found, makes sure the attribute list carries a tx-hash key.
// The hash is taken from the attributes first and from the error chain as a
// fallback. The (possibly extended) attribute list is returned alongside the
// hash; when no hash is found the inputs are returned unchanged.
func ensureTxHash(attrs []interface{}, err error) (string, []interface{}) {
	txHash := extractTxHashFromAttrs(attrs)
	if txHash == "" {
		txHash = extractTxHashFromError(err)
	}
	if txHash == "" {
		return "", attrs
	}
	// Only append a tx_hash attribute when no tx-style key already exists —
	// even one carrying an invalid value — to avoid duplicate keys.
	for i := 0; i+1 < len(attrs); i += 2 {
		// Non-string keys yield "" here and simply never match.
		switch key, _ := attrs[i].(string); key {
		case "tx_hash", "transaction_hash", "tx":
			return txHash, attrs
		}
	}
	return txHash, append(attrs, "tx_hash", txHash)
}
// extractTxHashFromAttrs scans a slog-style key/value attribute list for a
// tx-hash key ("tx_hash", "transaction_hash" or "tx") whose value is a valid
// transaction hash, returning the first match lowercased, or "" when none.
// A trailing key without a value is ignored (the loop requires a pair).
func extractTxHashFromAttrs(attrs []interface{}) string {
	for i := 0; i+1 < len(attrs); i += 2 {
		key, isString := attrs[i].(string)
		if !isString {
			continue
		}
		if key != "tx_hash" && key != "transaction_hash" && key != "tx" {
			continue
		}
		if value, valueIsString := attrs[i+1].(string); valueIsString && isValidTxHash(value) {
			return strings.ToLower(value)
		}
	}
	return ""
}
// extractTxHashFromError walks the error chain (via errors.Unwrap) and returns
// the first 0x-prefixed 64-hex-digit substring found in any message, lowercased.
// It returns "" when the chain contains no such substring or err is nil.
func extractTxHashFromError(err error) string {
	for current := err; current != nil; current = errors.Unwrap(current) {
		if match := txHashPattern.FindString(current.Error()); match != "" {
			return strings.ToLower(match)
		}
	}
	return ""
}
// isValidTxHash reports whether value is a canonical Ethereum transaction
// hash: the literal prefix "0x" followed by exactly 64 hexadecimal digits
// (66 bytes total). Both upper- and lower-case hex digits are accepted.
func isValidTxHash(value string) bool {
	// The fixed 66-byte length requirement subsumes the empty-string case,
	// so no separate value == "" check is needed.
	if len(value) != 66 || !strings.HasPrefix(value, "0x") {
		return false
	}
	for _, r := range value[2:] {
		if !isHexChar(r) {
			return false
		}
	}
	return true
}
// isHexChar reports whether r is an ASCII hexadecimal digit: 0-9, a-f or A-F.
func isHexChar(r rune) bool {
	switch {
	case '0' <= r && r <= '9':
		return true
	case 'a' <= r && r <= 'f':
		return true
	case 'A' <= r && r <= 'F':
		return true
	default:
		return false
	}
}

View File

@@ -2,23 +2,31 @@ package lifecycle
import (
"context"
"errors"
"fmt"
"os"
"sync"
"time"
"github.com/fraktal/mev-beta/internal/logger"
)
// HealthMonitorImpl implements comprehensive health monitoring for modules
type HealthMonitorImpl struct {
monitors map[string]*ModuleMonitor
config HealthMonitorConfig
aggregator HealthAggregator
notifier HealthNotifier
metrics HealthMetrics
rules []HealthRule
mu sync.RWMutex
ctx context.Context
cancel context.CancelFunc
running bool
monitors map[string]*ModuleMonitor
config HealthMonitorConfig
aggregator HealthAggregator
notifier HealthNotifier
metrics HealthMetrics
rules []HealthRule
notificationErrors []error
notificationErrorDetails []RecordedError
notifyMu sync.Mutex
mu sync.RWMutex
ctx context.Context
cancel context.CancelFunc
running bool
logger *logger.Logger
}
// ModuleMonitor monitors a specific module's health
@@ -38,16 +46,18 @@ type ModuleMonitor struct {
// HealthMonitorConfig configures the health monitoring system
type HealthMonitorConfig struct {
CheckInterval time.Duration `json:"check_interval"`
CheckTimeout time.Duration `json:"check_timeout"`
HistorySize int `json:"history_size"`
FailureThreshold int `json:"failure_threshold"`
RecoveryThreshold int `json:"recovery_threshold"`
EnableNotifications bool `json:"enable_notifications"`
EnableMetrics bool `json:"enable_metrics"`
EnableTrends bool `json:"enable_trends"`
ParallelChecks bool `json:"parallel_checks"`
MaxConcurrentChecks int `json:"max_concurrent_checks"`
CheckInterval time.Duration `json:"check_interval"`
CheckTimeout time.Duration `json:"check_timeout"`
HistorySize int `json:"history_size"`
FailureThreshold int `json:"failure_threshold"`
RecoveryThreshold int `json:"recovery_threshold"`
EnableNotifications bool `json:"enable_notifications"`
EnableMetrics bool `json:"enable_metrics"`
EnableTrends bool `json:"enable_trends"`
ParallelChecks bool `json:"parallel_checks"`
MaxConcurrentChecks int `json:"max_concurrent_checks"`
NotificationRetries int `json:"notification_retries"`
NotificationRetryDelay time.Duration `json:"notification_retry_delay"`
}
// ModuleHealthConfig configures health checking for a specific module
@@ -217,13 +227,15 @@ func NewHealthMonitor(config HealthMonitorConfig) *HealthMonitorImpl {
ctx, cancel := context.WithCancel(context.Background())
hm := &HealthMonitorImpl{
monitors: make(map[string]*ModuleMonitor),
config: config,
aggregator: NewDefaultHealthAggregator(),
notifier: NewDefaultHealthNotifier(),
rules: make([]HealthRule, 0),
ctx: ctx,
cancel: cancel,
monitors: make(map[string]*ModuleMonitor),
config: config,
aggregator: NewDefaultHealthAggregator(),
notifier: NewDefaultHealthNotifier(),
rules: make([]HealthRule, 0),
notificationErrors: make([]error, 0),
notificationErrorDetails: make([]RecordedError, 0),
ctx: ctx,
cancel: cancel,
metrics: HealthMetrics{
ModuleHealthScores: make(map[string]float64),
},
@@ -248,10 +260,21 @@ func NewHealthMonitor(config HealthMonitorConfig) *HealthMonitorImpl {
if hm.config.MaxConcurrentChecks == 0 {
hm.config.MaxConcurrentChecks = 10
}
if hm.config.NotificationRetries == 0 {
hm.config.NotificationRetries = 3
}
if hm.config.NotificationRetryDelay == 0 {
hm.config.NotificationRetryDelay = 500 * time.Millisecond
}
// Setup default health rules
hm.setupDefaultRules()
if err := os.MkdirAll("logs", 0o755); err != nil {
fmt.Printf("failed to ensure logs directory: %v\n", err)
}
hm.logger = logger.New("info", "", "logs/lifecycle_health.log")
return hm
}
@@ -441,7 +464,11 @@ func (hm *HealthMonitorImpl) performAllHealthChecks() {
// Update overall health and send notifications
overallHealth := hm.GetOverallHealth()
if hm.config.EnableNotifications {
hm.notifier.NotifySystemHealth(overallHealth)
_ = hm.notifyWithRetry(
func() error { return hm.notifier.NotifySystemHealth(overallHealth) },
"Failed to notify system health",
"overall_health_status", overallHealth.Status,
)
}
}
@@ -545,9 +572,22 @@ func (hm *HealthMonitorImpl) performHealthCheck(monitor *ModuleMonitor) ModuleHe
// Apply health rules
hm.applyHealthRules(monitor.moduleID, monitor.currentHealth)
// Send notifications if health changed
_ = hm.notifyWithRetry(
func() error {
return hm.notifier.NotifyHealthChange(monitor.moduleID, oldHealth, monitor.currentHealth)
},
"Failed to notify health change",
"module_id", monitor.moduleID,
)
if hm.config.EnableNotifications && oldHealth.Status != monitor.currentHealth.Status {
hm.notifier.NotifyHealthChange(monitor.moduleID, oldHealth, monitor.currentHealth)
_ = hm.notifyWithRetry(
func() error {
return hm.notifier.NotifyHealthChange(monitor.moduleID, oldHealth, monitor.currentHealth)
},
"Failed to notify health change (status transition)",
"module_id", monitor.moduleID,
"reason", "status_change",
)
}
// Update metrics
@@ -597,6 +637,114 @@ func (hm *HealthMonitorImpl) performCustomCheck(ctx context.Context, check Healt
return result
}
// notifyWithRetry invokes send up to the configured number of notification
// retries, sleeping the configured delay between attempts. Success after a
// retry is logged as a warning; exhausted attempts are recorded (tagged with
// attrs plus the attempt count) via recordNotificationError and returned as a
// wrapped error joining every per-attempt failure. A nil notifier makes the
// call a no-op.
func (hm *HealthMonitorImpl) notifyWithRetry(send func() error, failureMessage string, attrs ...interface{}) error {
	if hm.notifier == nil {
		return nil
	}
	maxAttempts := hm.config.NotificationRetries
	if maxAttempts <= 0 {
		maxAttempts = 1
	}
	retryDelay := hm.config.NotificationRetryDelay
	if retryDelay <= 0 {
		retryDelay = 500 * time.Millisecond
	}
	var failures []error
	for attempt := 1; ; attempt++ {
		err := send()
		if err == nil {
			if attempt > 1 && hm.logger != nil {
				kv := append([]interface{}{}, attrs...)
				kv = append(kv, "attempts", attempt)
				hm.logger.Warn(append([]interface{}{"Health notification succeeded after retry"}, kv...)...)
			}
			return nil
		}
		failures = append(failures, err)
		if attempt < maxAttempts {
			time.Sleep(retryDelay)
			continue
		}
		joined := errors.Join(failures...)
		kv := append([]interface{}{}, attrs...)
		kv = append(kv, "attempts", attempt)
		hm.recordNotificationError(failureMessage, joined, kv...)
		return fmt.Errorf("%s after %d attempts: %w", failureMessage, attempt, joined)
	}
}
// recordNotificationError stores a wrapped notification failure — enriched
// with any transaction hash found in err or attrs — for later aggregation,
// and logs it when a logger is configured. A nil err is ignored.
func (hm *HealthMonitorImpl) recordNotificationError(message string, err error, attrs ...interface{}) {
	if err == nil {
		return
	}
	// Copy attrs so enrichment cannot mutate the caller's slice.
	wrapped, txHash, attrsWithTx := enrichErrorWithTxHash(message, err, append([]interface{}{}, attrs...))
	hm.notifyMu.Lock()
	hm.notificationErrors = append(hm.notificationErrors, wrapped)
	hm.notificationErrorDetails = append(hm.notificationErrorDetails, RecordedError{Err: wrapped, TxHash: txHash})
	hm.notifyMu.Unlock()
	if hm.logger == nil {
		return
	}
	args := append([]interface{}{message}, attrsWithTx...)
	args = append(args, "error", err)
	hm.logger.Error(args...)
}
// aggregatedNotificationError joins every recorded notification error into a
// single error, or returns nil when none have been recorded.
func (hm *HealthMonitorImpl) aggregatedNotificationError() error {
	hm.notifyMu.Lock()
	defer hm.notifyMu.Unlock()
	// errors.Join copies its arguments into a fresh slice and returns nil for
	// an empty list, so no defensive copy of the stored slice is needed.
	return errors.Join(hm.notificationErrors...)
}
// NotificationErrors returns a copy of recorded notification errors for diagnostics.
// The copy prevents callers from mutating the monitor's internal slice; nil is
// returned when nothing has been recorded.
func (hm *HealthMonitorImpl) NotificationErrors() []error {
	hm.notifyMu.Lock()
	defer hm.notifyMu.Unlock()
	if len(hm.notificationErrors) == 0 {
		return nil
	}
	return append([]error(nil), hm.notificationErrors...)
}
// NotificationErrorDetails returns recorded notification errors with tx hash metadata.
// The returned slice is a copy; nil is returned when nothing has been recorded.
func (hm *HealthMonitorImpl) NotificationErrorDetails() []RecordedError {
	hm.notifyMu.Lock()
	defer hm.notifyMu.Unlock()
	if len(hm.notificationErrorDetails) == 0 {
		return nil
	}
	return append([]RecordedError(nil), hm.notificationErrorDetails...)
}
func (hm *HealthMonitorImpl) calculateHealthTrend(monitor *ModuleMonitor) HealthTrend {
if len(monitor.history) < 5 {
return HealthTrend{
@@ -681,50 +829,45 @@ func (hm *HealthMonitorImpl) updateMetrics(monitor *ModuleMonitor, result Health
hm.metrics.ModuleHealthScores[monitor.moduleID] = score
}
func (hm *HealthMonitorImpl) setupDefaultRules() {
// Rule: Alert on unhealthy critical modules
hm.rules = append(hm.rules, HealthRule{
Name: "critical_module_unhealthy",
Description: "Alert when a critical module becomes unhealthy",
func (hm *HealthMonitorImpl) createHealthRule(name, description, messageFormat string, status HealthStatus, severity AlertSeverity) HealthRule {
return HealthRule{
Name: name,
Description: description,
Condition: func(health ModuleHealth) bool {
return health.Status == HealthUnhealthy
return health.Status == status
},
Action: func(moduleID string, health ModuleHealth) error {
alert := HealthAlert{
ID: fmt.Sprintf("critical_%s_%d", moduleID, time.Now().Unix()),
ID: fmt.Sprintf("%s_%s_%d", name, moduleID, time.Now().Unix()),
ModuleID: moduleID,
Severity: SeverityCritical,
Severity: severity,
Type: AlertHealthChange,
Message: fmt.Sprintf("Critical module %s is unhealthy: %s", moduleID, health.Message),
Message: fmt.Sprintf(messageFormat, moduleID, health.Message),
Timestamp: time.Now(),
}
return hm.notifier.NotifyAlert(alert)
},
Severity: SeverityCritical,
Severity: severity,
Enabled: true,
})
}
}
// Rule: Alert on degraded performance
hm.rules = append(hm.rules, HealthRule{
Name: "degraded_performance",
Description: "Alert when module performance is degraded",
Condition: func(health ModuleHealth) bool {
return health.Status == HealthDegraded
},
Action: func(moduleID string, health ModuleHealth) error {
alert := HealthAlert{
ID: fmt.Sprintf("degraded_%s_%d", moduleID, time.Now().Unix()),
ModuleID: moduleID,
Severity: SeverityWarning,
Type: AlertHealthChange,
Message: fmt.Sprintf("Module %s performance is degraded: %s", moduleID, health.Message),
Timestamp: time.Now(),
}
return hm.notifier.NotifyAlert(alert)
},
Severity: SeverityWarning,
Enabled: true,
})
// setupDefaultRules installs the built-in health rules: a critical alert for
// modules that become unhealthy and a warning for degraded performance.
func (hm *HealthMonitorImpl) setupDefaultRules() {
	hm.rules = append(hm.rules,
		hm.createHealthRule(
			"critical_module_unhealthy",
			"Alert when a critical module becomes unhealthy",
			"Critical module %s is unhealthy: %s",
			HealthUnhealthy,
			SeverityCritical,
		),
		hm.createHealthRule(
			"degraded_performance",
			"Alert when module performance is degraded",
			"Module %s performance is degraded: %s",
			HealthDegraded,
			SeverityWarning,
		),
	)
}
// DefaultHealthAggregator implements basic health aggregation

View File

@@ -0,0 +1,106 @@
package lifecycle
import (
"fmt"
"strings"
"testing"
"time"
)
// stubHealthNotifier is a test double for HealthNotifier whose
// NotifyHealthChange fails for the first failUntil calls and succeeds
// afterwards, recording every attempt it observes.
type stubHealthNotifier struct {
	failUntil int    // number of leading calls that return an error
	attempts  int    // total NotifyHealthChange invocations observed
	txHash    string // hash embedded in synthesized failure messages
}

// NotifyHealthChange fails while the attempt count is <= failUntil, embedding
// the attempt number and the configured tx hash in the error message.
func (s *stubHealthNotifier) NotifyHealthChange(moduleID string, oldHealth, newHealth ModuleHealth) error {
	s.attempts++
	if s.attempts > s.failUntil {
		return nil
	}
	return fmt.Errorf("notify failure %d for tx %s", s.attempts, s.txHash)
}

// NotifySystemHealth always succeeds.
func (s *stubHealthNotifier) NotifySystemHealth(health OverallHealth) error {
	return nil
}

// NotifyAlert always succeeds.
func (s *stubHealthNotifier) NotifyAlert(alert HealthAlert) error {
	return nil
}
// TestHealthMonitorNotifyWithRetrySuccess exercises notifyWithRetry when the
// notifier fails twice and succeeds on the third (final) attempt: the call
// must return nil, all three attempts must be observed, and no notification
// errors may be recorded or aggregated.
func TestHealthMonitorNotifyWithRetrySuccess(t *testing.T) {
	config := HealthMonitorConfig{
		NotificationRetries:    3,
		NotificationRetryDelay: time.Nanosecond, // keep retries effectively instant
	}
	hm := NewHealthMonitor(config)
	// Disable the file logger so the test does not write to logs/.
	hm.logger = nil
	notifier := &stubHealthNotifier{failUntil: 2, txHash: "0xdddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddd"}
	hm.notifier = notifier
	err := hm.notifyWithRetry(func() error {
		return notifier.NotifyHealthChange("module", ModuleHealth{}, ModuleHealth{})
	}, "notify failure", "module_id", "module")
	if err != nil {
		t.Fatalf("expected notification to eventually succeed, got %v", err)
	}
	if notifier.attempts != 3 {
		t.Fatalf("expected 3 attempts, got %d", notifier.attempts)
	}
	// Eventual success must leave the error records empty.
	if errs := hm.NotificationErrors(); len(errs) != 0 {
		t.Fatalf("expected no recorded notification errors, got %d", len(errs))
	}
	if hm.aggregatedNotificationError() != nil {
		t.Fatal("expected aggregated notification error to be nil")
	}
}
// TestHealthMonitorNotifyWithRetryFailure exercises notifyWithRetry when the
// notifier keeps failing past the retry budget: the call must return an
// error after exactly the configured number of attempts, and exactly one
// recorded error — carrying the tx hash embedded in the failure message —
// must be exposed through NotificationErrors, NotificationErrorDetails and
// aggregatedNotificationError.
func TestHealthMonitorNotifyWithRetryFailure(t *testing.T) {
	config := HealthMonitorConfig{
		NotificationRetries:    2,
		NotificationRetryDelay: time.Nanosecond, // keep retries effectively instant
	}
	hm := NewHealthMonitor(config)
	// Disable the file logger so the test does not write to logs/.
	hm.logger = nil
	txHash := "0xeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee"
	// failUntil exceeds the retry budget, so every attempt fails.
	notifier := &stubHealthNotifier{failUntil: 3, txHash: txHash}
	hm.notifier = notifier
	err := hm.notifyWithRetry(func() error {
		return notifier.NotifyHealthChange("module", ModuleHealth{}, ModuleHealth{})
	}, "notify failure", "module_id", "module")
	if err == nil {
		t.Fatal("expected notification to fail after retries")
	}
	if notifier.attempts != 2 {
		t.Fatalf("expected 2 attempts, got %d", notifier.attempts)
	}
	exported := hm.NotificationErrors()
	if len(exported) != 1 {
		t.Fatalf("expected 1 recorded notification error, got %d", len(exported))
	}
	// A second call must return an independent copy with the value intact.
	copyErrs := hm.NotificationErrors()
	if copyErrs[0] == nil {
		t.Fatal("expected copy of notification errors to retain value")
	}
	// The tx hash from the stub's error message must have been extracted
	// into the wrapped error text.
	if got := exported[0].Error(); !strings.Contains(got, txHash) {
		t.Fatalf("recorded notification error should include tx hash, got %q", got)
	}
	details := hm.NotificationErrorDetails()
	if len(details) != 1 {
		t.Fatalf("expected notification error details entry, got %d", len(details))
	}
	if details[0].TxHash != txHash {
		t.Fatalf("expected notification error detail to track tx hash %s, got %s", txHash, details[0].TxHash)
	}
	agg := hm.aggregatedNotificationError()
	if agg == nil {
		t.Fatal("expected aggregated notification error to be returned")
	}
	// The aggregate must carry the failure message, the first attempt's
	// error text, and the tx hash.
	if got := agg.Error(); !strings.Contains(got, "notify failure") || !strings.Contains(got, "notify failure 1") || !strings.Contains(got, txHash) {
		t.Fatalf("aggregated notification error should include failure details and tx hash, got %q", got)
	}
}

View File

@@ -2,6 +2,7 @@ package lifecycle
import (
"context"
"errors"
"fmt"
"log/slog"
"reflect"
@@ -11,19 +12,22 @@ import (
// ModuleRegistry manages the registration, discovery, and lifecycle of system modules
type ModuleRegistry struct {
modules map[string]*RegisteredModule
modulesByType map[reflect.Type][]*RegisteredModule
dependencies map[string][]string
startOrder []string
stopOrder []string
state RegistryState
eventBus EventBus
healthMonitor HealthMonitor
config RegistryConfig
logger *slog.Logger
mu sync.RWMutex
ctx context.Context
cancel context.CancelFunc
modules map[string]*RegisteredModule
modulesByType map[reflect.Type][]*RegisteredModule
dependencies map[string][]string
startOrder []string
stopOrder []string
state RegistryState
eventBus EventBus
healthMonitor HealthMonitor
config RegistryConfig
logger *slog.Logger
registryErrors []error
registryErrorDetails []RecordedError
errMu sync.Mutex
mu sync.RWMutex
ctx context.Context
cancel context.CancelFunc
}
// RegisteredModule represents a module in the registry
@@ -151,6 +155,8 @@ type RegistryConfig struct {
FailureRecovery bool `json:"failure_recovery"`
AutoRestart bool `json:"auto_restart"`
MaxRestartAttempts int `json:"max_restart_attempts"`
EventPublishRetries int `json:"event_publish_retries"`
EventPublishDelay time.Duration `json:"event_publish_delay"`
}
// EventBus interface for module events
@@ -200,13 +206,15 @@ func NewModuleRegistry(config RegistryConfig) *ModuleRegistry {
ctx, cancel := context.WithCancel(context.Background())
registry := &ModuleRegistry{
modules: make(map[string]*RegisteredModule),
modulesByType: make(map[reflect.Type][]*RegisteredModule),
dependencies: make(map[string][]string),
config: config,
state: RegistryUninitialized,
ctx: ctx,
cancel: cancel,
modules: make(map[string]*RegisteredModule),
modulesByType: make(map[reflect.Type][]*RegisteredModule),
dependencies: make(map[string][]string),
config: config,
state: RegistryUninitialized,
ctx: ctx,
cancel: cancel,
registryErrors: make([]error, 0),
registryErrorDetails: make([]RecordedError, 0),
}
// Set default configuration
@@ -219,6 +227,16 @@ func NewModuleRegistry(config RegistryConfig) *ModuleRegistry {
if registry.config.HealthCheckInterval == 0 {
registry.config.HealthCheckInterval = 30 * time.Second
}
if registry.config.EventPublishRetries == 0 {
registry.config.EventPublishRetries = 3
}
if registry.config.EventPublishDelay == 0 {
registry.config.EventPublishDelay = 200 * time.Millisecond
}
if registry.logger == nil {
registry.logger = slog.Default()
}
return registry
}
@@ -253,19 +271,15 @@ func (mr *ModuleRegistry) Register(module Module, config ModuleConfig) error {
mr.dependencies[id] = module.GetDependencies()
// Publish event
if mr.eventBus != nil {
if err := mr.eventBus.Publish(ModuleEvent{
Type: EventModuleRegistered,
ModuleID: id,
Timestamp: time.Now(),
Data: map[string]interface{}{
"name": module.GetName(),
"version": module.GetVersion(),
},
}); err != nil {
mr.logger.Error("Failed to publish module registration event", "module_id", id, "error", err)
}
}
_ = mr.publishEventWithRetry(ModuleEvent{
Type: EventModuleRegistered,
ModuleID: id,
Timestamp: time.Now(),
Data: map[string]interface{}{
"name": module.GetName(),
"version": module.GetVersion(),
},
}, "Module registration event publish failed")
return nil
}
@@ -302,13 +316,11 @@ func (mr *ModuleRegistry) Unregister(moduleID string) error {
delete(mr.dependencies, moduleID)
// Publish event
if mr.eventBus != nil {
mr.eventBus.Publish(ModuleEvent{
Type: EventModuleUnregistered,
ModuleID: moduleID,
Timestamp: time.Now(),
})
}
_ = mr.publishEventWithRetry(ModuleEvent{
Type: EventModuleUnregistered,
ModuleID: moduleID,
Timestamp: time.Now(),
}, "Module unregistration event publish failed")
return nil
}
@@ -512,6 +524,15 @@ func (mr *ModuleRegistry) SetHealthMonitor(healthMonitor HealthMonitor) {
mr.healthMonitor = healthMonitor
}
// SetLogger overrides the default logger for the registry.
// A nil logger is ignored so the registry always keeps a usable logger.
func (mr *ModuleRegistry) SetLogger(logger *slog.Logger) {
	if logger == nil {
		return
	}
	mr.mu.Lock()
	defer mr.mu.Unlock()
	mr.logger = logger
}
// GetHealth returns the health status of all modules
func (mr *ModuleRegistry) GetHealth() map[string]ModuleHealth {
mr.mu.RLock()
@@ -538,6 +559,115 @@ func (mr *ModuleRegistry) GetMetrics() map[string]ModuleMetrics {
return metrics
}
// recordRegistryError stores a wrapped registry failure — enriched with any
// transaction hash found in err or attrs — for later aggregation, and logs it
// when a logger is configured. A nil err is ignored.
func (mr *ModuleRegistry) recordRegistryError(message string, err error, attrs ...interface{}) {
	if err == nil {
		return
	}
	// Copy attrs so enrichment cannot mutate the caller's slice.
	wrapped, txHash, attrsWithTx := enrichErrorWithTxHash(message, err, append([]interface{}{}, attrs...))
	mr.errMu.Lock()
	mr.registryErrors = append(mr.registryErrors, wrapped)
	mr.registryErrorDetails = append(mr.registryErrorDetails, RecordedError{Err: wrapped, TxHash: txHash})
	mr.errMu.Unlock()
	if mr.logger == nil {
		return
	}
	kv := append([]interface{}{}, attrsWithTx...)
	kv = append(kv, "error", err)
	mr.logger.Error(message, kv...)
}
// publishEventWithRetry publishes event on the registry's event bus, retrying
// up to the configured number of times with the configured delay between
// attempts. Success after a retry is logged as a warning; exhausted attempts
// are recorded via recordRegistryError (tagged with the module ID, event type
// and attempt count) and returned as a wrapped error joining every
// per-attempt failure. A nil event bus makes the call a no-op.
func (mr *ModuleRegistry) publishEventWithRetry(event ModuleEvent, failureMessage string) error {
	if mr.eventBus == nil {
		return nil
	}
	maxAttempts := mr.config.EventPublishRetries
	if maxAttempts <= 0 {
		maxAttempts = 1
	}
	backoff := mr.config.EventPublishDelay
	if backoff <= 0 {
		backoff = 200 * time.Millisecond
	}
	var failures []error
	for attempt := 1; ; attempt++ {
		err := mr.eventBus.Publish(event)
		if err == nil {
			if attempt > 1 && mr.logger != nil {
				mr.logger.Warn("Module event publish succeeded after retry", "module_id", event.ModuleID, "event_type", event.Type, "attempts", attempt)
			}
			return nil
		}
		failures = append(failures, err)
		if attempt < maxAttempts {
			time.Sleep(backoff)
			continue
		}
		joined := errors.Join(failures...)
		mr.recordRegistryError(
			failureMessage,
			joined,
			"module_id", event.ModuleID,
			"event_type", event.Type,
			"attempts", attempt,
		)
		return fmt.Errorf("%s after %d attempts: %w", failureMessage, attempt, joined)
	}
}
// aggregatedErrors joins every recorded registry error into a single error,
// or returns nil when none have been recorded.
func (mr *ModuleRegistry) aggregatedErrors() error {
	mr.errMu.Lock()
	defer mr.errMu.Unlock()
	// errors.Join copies its arguments into a fresh slice and returns nil for
	// an empty list, so no defensive copy of the stored slice is needed.
	return errors.Join(mr.registryErrors...)
}
// RegistryErrors exposes a copy of aggregated registry errors for diagnostics.
// The copy prevents callers from mutating the registry's internal slice; nil
// is returned when nothing has been recorded.
func (mr *ModuleRegistry) RegistryErrors() []error {
	mr.errMu.Lock()
	defer mr.errMu.Unlock()
	if len(mr.registryErrors) == 0 {
		return nil
	}
	return append([]error(nil), mr.registryErrors...)
}
// RegistryErrorDetails returns recorded registry errors with tx hash metadata.
// The returned slice is a copy; nil is returned when nothing has been recorded.
func (mr *ModuleRegistry) RegistryErrorDetails() []RecordedError {
	mr.errMu.Lock()
	defer mr.errMu.Unlock()
	if len(mr.registryErrorDetails) == 0 {
		return nil
	}
	return append([]RecordedError(nil), mr.registryErrorDetails...)
}
// Shutdown gracefully shuts down the registry
func (mr *ModuleRegistry) Shutdown(ctx context.Context) error {
if mr.state == RegistryRunning {
@@ -599,19 +729,20 @@ func (mr *ModuleRegistry) initializeModule(ctx context.Context, registered *Regi
registered.State = StateInitialized
if err := registered.Instance.Initialize(ctx, registered.Config); err != nil {
registered.State = StateFailed
return err
}
// Publish event
if mr.eventBus != nil {
mr.eventBus.Publish(ModuleEvent{
_ = mr.publishEventWithRetry(ModuleEvent{
Type: EventModuleInitialized,
ModuleID: registered.ID,
Timestamp: time.Now(),
})
}, "Module initialization event publish failed after error")
return err
}
_ = mr.publishEventWithRetry(ModuleEvent{
Type: EventModuleInitialized,
ModuleID: registered.ID,
Timestamp: time.Now(),
}, "Module initialization event publish failed")
return nil
}
@@ -638,20 +769,19 @@ func (mr *ModuleRegistry) startModule(ctx context.Context, registered *Registere
// Start health monitoring
if mr.healthMonitor != nil {
mr.healthMonitor.StartMonitoring(registered)
if err := mr.healthMonitor.StartMonitoring(registered); err != nil {
mr.recordRegistryError("Failed to start health monitoring", err, "module_id", registered.ID)
}
}
// Publish event
if mr.eventBus != nil {
mr.eventBus.Publish(ModuleEvent{
Type: EventModuleStarted,
ModuleID: registered.ID,
Timestamp: time.Now(),
Data: map[string]interface{}{
"startup_time": registered.Metrics.StartupTime,
},
})
}
_ = mr.publishEventWithRetry(ModuleEvent{
Type: EventModuleStarted,
ModuleID: registered.ID,
Timestamp: time.Now(),
Data: map[string]interface{}{
"startup_time": registered.Metrics.StartupTime,
},
}, "Module started event publish failed")
return nil
}
@@ -680,77 +810,79 @@ func (mr *ModuleRegistry) stopModule(registered *RegisteredModule) error {
// Stop health monitoring
if mr.healthMonitor != nil {
if err := mr.healthMonitor.StopMonitoring(registered.ID); err != nil {
mr.logger.Error("Failed to stop health monitoring", "module_id", registered.ID, "error", err)
mr.recordRegistryError("Failed to stop health monitoring", err, "module_id", registered.ID)
}
}
// Publish event
if mr.eventBus != nil {
if err := mr.eventBus.Publish(ModuleEvent{
Type: EventModuleStopped,
ModuleID: registered.ID,
Timestamp: time.Now(),
Data: map[string]interface{}{
"shutdown_time": registered.Metrics.ShutdownTime,
},
}); err != nil {
mr.logger.Error("Failed to publish module stopped event", "module_id", registered.ID, "error", err)
}
_ = mr.publishEventWithRetry(ModuleEvent{
Type: EventModuleStopped,
ModuleID: registered.ID,
Timestamp: time.Now(),
Data: map[string]interface{}{
"shutdown_time": registered.Metrics.ShutdownTime,
},
}, "Module stopped event publish failed")
return nil
}
// transitionModuleState drives a module through a guarded state transition:
// it verifies the module is in expectedInitialState, moves it to
// intermediateState, runs actionFunc, and on success settles it in finalState
// and publishes eventType. On failure the module is marked StateFailed and
// the action's error is returned. Event-publish failures are recorded by the
// retrying publisher but do not fail the transition itself.
func (mr *ModuleRegistry) transitionModuleState(
	ctx context.Context,
	registered *RegisteredModule,
	actionName string,
	expectedInitialState ModuleState,
	intermediateState ModuleState,
	finalState ModuleState,
	eventType EventType,
	actionFunc func(context.Context) error,
) error {
	if registered.State != expectedInitialState {
		return fmt.Errorf("invalid state for %s: %s", actionName, registered.State)
	}
	registered.State = intermediateState
	if err := actionFunc(ctx); err != nil {
		registered.State = StateFailed
		return err
	}
	registered.State = finalState
	// Best-effort event publication; the transition already succeeded.
	_ = mr.publishEventWithRetry(ModuleEvent{
		Type:      eventType,
		ModuleID:  registered.ID,
		Timestamp: time.Now(),
	}, "Module state transition event publish failed")
	return nil
}
func (mr *ModuleRegistry) pauseModule(ctx context.Context, registered *RegisteredModule) error {
if registered.State != StateRunning {
return fmt.Errorf("invalid state for pause: %s", registered.State)
}
registered.State = StatePausing
if err := registered.Instance.Pause(ctx); err != nil {
registered.State = StateFailed
return err
}
registered.State = StatePaused
// Publish event
if mr.eventBus != nil {
mr.eventBus.Publish(ModuleEvent{
Type: EventModulePaused,
ModuleID: registered.ID,
Timestamp: time.Now(),
})
}
return nil
return mr.transitionModuleState(
ctx,
registered,
"pause",
StateRunning,
StatePausing,
StatePaused,
EventModulePaused,
registered.Instance.Pause,
)
}
func (mr *ModuleRegistry) resumeModule(ctx context.Context, registered *RegisteredModule) error {
if registered.State != StatePaused {
return fmt.Errorf("invalid state for resume: %s", registered.State)
}
registered.State = StateResuming
if err := registered.Instance.Resume(ctx); err != nil {
registered.State = StateFailed
return err
}
registered.State = StateRunning
// Publish event
if mr.eventBus != nil {
mr.eventBus.Publish(ModuleEvent{
Type: EventModuleResumed,
ModuleID: registered.ID,
Timestamp: time.Now(),
})
}
return nil
return mr.transitionModuleState(
ctx,
registered,
"resume",
StatePaused,
StateResuming,
StateRunning,
EventModuleResumed,
registered.Instance.Resume,
)
}
func (mr *ModuleRegistry) startAllSequential(ctx context.Context) error {

View File

@@ -0,0 +1,96 @@
package lifecycle
import (
"fmt"
"strings"
"testing"
"time"
)
// stubEventBus is a test double for EventBus whose Publish fails for the
// first failUntil calls and succeeds afterwards, recording every attempt.
type stubEventBus struct {
	failUntil int    // number of leading Publish calls that return an error
	attempts  int    // total Publish invocations observed
	txHash    string // hash embedded in synthesized failure messages
}

// Publish fails while the attempt count is <= failUntil, embedding the
// attempt number and the configured tx hash in the error message.
func (s *stubEventBus) Publish(event ModuleEvent) error {
	s.attempts++
	if s.attempts > s.failUntil {
		return nil
	}
	return fmt.Errorf("publish failure %d for tx %s", s.attempts, s.txHash)
}

// Subscribe is a no-op that always succeeds.
func (s *stubEventBus) Subscribe(eventType EventType, handler EventHandler) error {
	return nil
}
// TestModuleRegistryPublishEventWithRetrySuccess exercises
// publishEventWithRetry when the bus fails twice and succeeds on the third
// (final) attempt: the call must return nil, all three attempts must be
// observed, and no registry errors may be aggregated.
func TestModuleRegistryPublishEventWithRetrySuccess(t *testing.T) {
	registry := NewModuleRegistry(RegistryConfig{
		EventPublishRetries: 3,
		EventPublishDelay:   time.Nanosecond, // keep retries effectively instant
	})
	// Disable logging so the test produces no output.
	registry.logger = nil
	bus := &stubEventBus{failUntil: 2, txHash: "0xbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb"}
	registry.eventBus = bus
	err := registry.publishEventWithRetry(ModuleEvent{ModuleID: "module-A", Type: EventModuleStarted}, "publish failed")
	if err != nil {
		t.Fatalf("expected publish to eventually succeed, got %v", err)
	}
	if bus.attempts != 3 {
		t.Fatalf("expected 3 attempts, got %d", bus.attempts)
	}
	// Eventual success must leave the error records empty.
	if err := registry.aggregatedErrors(); err != nil {
		t.Fatalf("expected no aggregated errors, got %v", err)
	}
}
// TestModuleRegistryPublishEventWithRetryFailure exercises
// publishEventWithRetry when the bus keeps failing past the retry budget:
// the call must return an error after exactly the configured number of
// attempts, and exactly one recorded error — carrying the tx hash embedded
// in the failure message — must be exposed through RegistryErrors,
// RegistryErrorDetails and aggregatedErrors.
func TestModuleRegistryPublishEventWithRetryFailure(t *testing.T) {
	registry := NewModuleRegistry(RegistryConfig{
		EventPublishRetries: 2,
		EventPublishDelay:   time.Nanosecond, // keep retries effectively instant
	})
	// Disable logging so the test produces no output.
	registry.logger = nil
	txHash := "0xcccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccc"
	// failUntil exceeds the retry budget, so every attempt fails.
	bus := &stubEventBus{failUntil: 3, txHash: txHash}
	registry.eventBus = bus
	err := registry.publishEventWithRetry(ModuleEvent{ModuleID: "module-B", Type: EventModuleStopped}, "publish failed")
	if err == nil {
		t.Fatal("expected publish to fail after retries")
	}
	if bus.attempts != 2 {
		t.Fatalf("expected 2 attempts, got %d", bus.attempts)
	}
	exported := registry.RegistryErrors()
	if len(exported) != 1 {
		t.Fatalf("expected 1 recorded registry error, got %d", len(exported))
	}
	if exported[0] == nil {
		t.Fatal("expected recorded error to be non-nil")
	}
	// A second call must return an independent copy with the value intact.
	if copyErrs := registry.RegistryErrors(); copyErrs[0] == nil {
		t.Fatal("copy of registry errors should preserve values")
	}
	// The tx hash from the stub's error message must have been extracted
	// into the wrapped error text.
	if got := exported[0].Error(); !strings.Contains(got, txHash) {
		t.Fatalf("recorded registry error should include tx hash, got %q", got)
	}
	details := registry.RegistryErrorDetails()
	if len(details) != 1 {
		t.Fatalf("expected registry error details to include entry, got %d", len(details))
	}
	if details[0].TxHash != txHash {
		t.Fatalf("expected registry error detail to track tx hash %s, got %s", txHash, details[0].TxHash)
	}
	agg := registry.aggregatedErrors()
	if agg == nil {
		t.Fatal("expected aggregated error to be returned")
	}
	// The aggregate must carry the failure message, the first attempt's
	// error text, and the tx hash.
	if got := agg.Error(); !strings.Contains(got, "publish failed") || !strings.Contains(got, "publish failure 1") || !strings.Contains(got, txHash) {
		t.Fatalf("aggregated error should include failure details and tx hash, got %q", got)
	}
}

View File

@@ -2,29 +2,37 @@ package lifecycle
import (
"context"
"errors"
"fmt"
"os"
"os/signal"
"sync"
"syscall"
"time"
"github.com/fraktal/mev-beta/internal/logger"
)
// ShutdownManager handles graceful shutdown of the application
type ShutdownManager struct {
registry *ModuleRegistry
shutdownTasks []ShutdownTask
shutdownHooks []ShutdownHook
config ShutdownConfig
signalChannel chan os.Signal
shutdownChannel chan struct{}
state ShutdownState
startTime time.Time
shutdownStarted time.Time
mu sync.RWMutex
wg sync.WaitGroup
ctx context.Context
cancel context.CancelFunc
registry *ModuleRegistry
shutdownTasks []ShutdownTask
shutdownHooks []ShutdownHook
config ShutdownConfig
signalChannel chan os.Signal
shutdownChannel chan struct{}
state ShutdownState
startTime time.Time
shutdownStarted time.Time
mu sync.RWMutex
wg sync.WaitGroup
ctx context.Context
cancel context.CancelFunc
logger *logger.Logger
shutdownErrors []error
shutdownErrorDetails []RecordedError
errMu sync.Mutex
exitFunc func(code int)
}
// ShutdownTask represents a task to be executed during shutdown
@@ -106,16 +114,19 @@ func NewShutdownManager(registry *ModuleRegistry, config ShutdownConfig) *Shutdo
ctx, cancel := context.WithCancel(context.Background())
sm := &ShutdownManager{
registry: registry,
shutdownTasks: make([]ShutdownTask, 0),
shutdownHooks: make([]ShutdownHook, 0),
config: config,
signalChannel: make(chan os.Signal, config.SignalBufferSize),
shutdownChannel: make(chan struct{}),
state: ShutdownStateRunning,
startTime: time.Now(),
ctx: ctx,
cancel: cancel,
registry: registry,
shutdownTasks: make([]ShutdownTask, 0),
shutdownHooks: make([]ShutdownHook, 0),
config: config,
signalChannel: make(chan os.Signal, config.SignalBufferSize),
shutdownChannel: make(chan struct{}),
state: ShutdownStateRunning,
startTime: time.Now(),
ctx: ctx,
cancel: cancel,
shutdownErrors: make([]error, 0),
shutdownErrorDetails: make([]RecordedError, 0),
exitFunc: os.Exit,
}
// Set default configuration
@@ -135,6 +146,11 @@ func NewShutdownManager(registry *ModuleRegistry, config ShutdownConfig) *Shutdo
sm.config.RetryDelay = time.Second
}
if err := os.MkdirAll("logs", 0o755); err != nil {
fmt.Printf("failed to ensure logs directory: %v\n", err)
}
sm.logger = logger.New("info", "", "logs/lifecycle_shutdown.log")
// Setup default shutdown tasks
sm.setupDefaultTasks()
@@ -174,7 +190,14 @@ func (sm *ShutdownManager) Shutdown(ctx context.Context) error {
// Close shutdown channel to signal shutdown
close(sm.shutdownChannel)
return sm.performShutdown(ctx)
err := sm.performShutdown(ctx)
combined := sm.combinedShutdownError()
if combined != nil {
return combined
}
return err
}
// ForceShutdown forces immediate shutdown
@@ -186,14 +209,23 @@ func (sm *ShutdownManager) ForceShutdown(ctx context.Context) error {
sm.cancel() // Cancel all operations
// Force stop all modules immediately
var forceErr error
if sm.registry != nil {
forceCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
sm.registry.StopAll(forceCtx)
if err := sm.registry.StopAll(forceCtx); err != nil {
wrapped := fmt.Errorf("StopAll failed during force shutdown: %w", err)
sm.recordShutdownError("StopAll failed in force shutdown", wrapped)
forceErr = errors.Join(forceErr, wrapped)
}
}
os.Exit(1)
return nil
if forceErr != nil {
sm.recordShutdownError("Force shutdown encountered errors", forceErr)
}
sm.exitFunc(1)
return forceErr
}
// AddShutdownTask adds a task to be executed during shutdown
@@ -382,10 +414,13 @@ func (sm *ShutdownManager) signalHandler() {
// Graceful shutdown
ctx, cancel := context.WithTimeout(context.Background(), sm.config.GracefulTimeout)
if err := sm.Shutdown(ctx); err != nil {
sm.recordShutdownError("Graceful shutdown failed from signal", err)
cancel()
// Force shutdown if graceful fails
forceCtx, forceCancel := context.WithTimeout(context.Background(), sm.config.ForceTimeout)
sm.ForceShutdown(forceCtx)
if err := sm.ForceShutdown(forceCtx); err != nil {
sm.recordShutdownError("Force shutdown error in timeout scenario", err)
}
forceCancel()
}
cancel()
@@ -393,7 +428,9 @@ func (sm *ShutdownManager) signalHandler() {
case syscall.SIGQUIT:
// Force shutdown
ctx, cancel := context.WithTimeout(context.Background(), sm.config.ForceTimeout)
sm.ForceShutdown(ctx)
if err := sm.ForceShutdown(ctx); err != nil {
sm.recordShutdownError("Force shutdown error in SIGQUIT handler", err)
}
cancel()
return
case syscall.SIGHUP:
@@ -415,52 +452,66 @@ func (sm *ShutdownManager) performShutdown(ctx context.Context) error {
shutdownCtx, cancel := context.WithTimeout(ctx, sm.config.GracefulTimeout)
defer cancel()
var shutdownErr error
var phaseErrors []error
// Phase 1: Call shutdown started hooks
if err := sm.callHooks(shutdownCtx, "OnShutdownStarted"); err != nil {
shutdownErr = fmt.Errorf("shutdown hooks failed: %w", err)
if err := sm.callHooks(shutdownCtx, "OnShutdownStarted", nil); err != nil {
wrapped := fmt.Errorf("shutdown start hooks failed: %w", err)
sm.recordShutdownError("Shutdown started hook failure", wrapped)
phaseErrors = append(phaseErrors, wrapped)
}
// Phase 2: Stop modules
sm.state = ShutdownStateModuleStop
if sm.registry != nil {
if err := sm.registry.StopAll(shutdownCtx); err != nil {
shutdownErr = fmt.Errorf("failed to stop modules: %w", err)
wrapped := fmt.Errorf("failed to stop modules: %w", err)
sm.recordShutdownError("Module stop failure", wrapped)
phaseErrors = append(phaseErrors, wrapped)
}
}
// Call modules stopped hooks
if err := sm.callHooks(shutdownCtx, "OnModulesStopped"); err != nil {
if shutdownErr == nil {
shutdownErr = fmt.Errorf("modules stopped hooks failed: %w", err)
}
if err := sm.callHooks(shutdownCtx, "OnModulesStopped", nil); err != nil {
wrapped := fmt.Errorf("modules stopped hooks failed: %w", err)
sm.recordShutdownError("Modules stopped hook failure", wrapped)
phaseErrors = append(phaseErrors, wrapped)
}
// Phase 3: Execute shutdown tasks
sm.state = ShutdownStateCleanup
if err := sm.callHooks(shutdownCtx, "OnCleanupStarted"); err != nil {
if shutdownErr == nil {
shutdownErr = fmt.Errorf("cleanup hooks failed: %w", err)
}
if err := sm.callHooks(shutdownCtx, "OnCleanupStarted", nil); err != nil {
wrapped := fmt.Errorf("cleanup hooks failed: %w", err)
sm.recordShutdownError("Cleanup hook failure", wrapped)
phaseErrors = append(phaseErrors, wrapped)
}
if err := sm.executeShutdownTasks(shutdownCtx); err != nil {
if shutdownErr == nil {
shutdownErr = fmt.Errorf("shutdown tasks failed: %w", err)
}
wrapped := fmt.Errorf("shutdown tasks failed: %w", err)
sm.recordShutdownError("Shutdown task execution failure", wrapped)
phaseErrors = append(phaseErrors, wrapped)
}
// Phase 4: Final cleanup
if shutdownErr != nil {
if len(phaseErrors) > 0 {
finalErr := errors.Join(phaseErrors...)
sm.state = ShutdownStateFailed
sm.callHooks(shutdownCtx, "OnShutdownFailed")
} else {
sm.state = ShutdownStateCompleted
sm.callHooks(shutdownCtx, "OnShutdownCompleted")
if err := sm.callHooks(shutdownCtx, "OnShutdownFailed", finalErr); err != nil {
wrapped := fmt.Errorf("shutdown failed hook error: %w", err)
sm.recordShutdownError("Shutdown failed hook error", wrapped)
finalErr = errors.Join(finalErr, wrapped)
}
return finalErr
}
return shutdownErr
sm.state = ShutdownStateCompleted
if err := sm.callHooks(shutdownCtx, "OnShutdownCompleted", nil); err != nil {
wrapped := fmt.Errorf("shutdown completed hook error: %w", err)
sm.recordShutdownError("Shutdown completed hook error", wrapped)
return wrapped
}
return nil
}
func (sm *ShutdownManager) executeShutdownTasks(ctx context.Context) error {
@@ -559,6 +610,14 @@ func (sm *ShutdownManager) executeTask(ctx context.Context, task ShutdownTask) e
}
lastErr = err
attemptNumber := attempt + 1
sm.recordShutdownError(
fmt.Sprintf("Shutdown task %s failed", task.Name),
fmt.Errorf("attempt %d: %w", attemptNumber, err),
"task", task.Name,
"attempt", attemptNumber,
)
// Call error handler if provided
if task.OnError != nil {
@@ -569,10 +628,11 @@ func (sm *ShutdownManager) executeTask(ctx context.Context, task ShutdownTask) e
return fmt.Errorf("task failed after %d attempts: %w", sm.config.MaxRetries, lastErr)
}
func (sm *ShutdownManager) callHooks(ctx context.Context, hookMethod string) error {
var lastErr error
func (sm *ShutdownManager) callHooks(ctx context.Context, hookMethod string, cause error) error {
var hookErrors []error
for _, hook := range sm.shutdownHooks {
hookName := fmt.Sprintf("%T", hook)
var err error
switch hookMethod {
@@ -585,15 +645,21 @@ func (sm *ShutdownManager) callHooks(ctx context.Context, hookMethod string) err
case "OnShutdownCompleted":
err = hook.OnShutdownCompleted(ctx)
case "OnShutdownFailed":
err = hook.OnShutdownFailed(ctx, lastErr)
err = hook.OnShutdownFailed(ctx, cause)
}
if err != nil {
lastErr = err
recordContext := fmt.Sprintf("%s hook failure (%s)", hookMethod, hookName)
sm.recordShutdownError(recordContext, err, "hook", hookName, "phase", hookMethod)
hookErrors = append(hookErrors, fmt.Errorf("%s: %w", recordContext, err))
}
}
return lastErr
if len(hookErrors) > 0 {
return errors.Join(hookErrors...)
}
return nil
}
func (sm *ShutdownManager) sortTasksByPriority() {
@@ -660,6 +726,71 @@ func (sm *ShutdownManager) notifyExternalSystems(ctx context.Context) error {
return nil
}
func (sm *ShutdownManager) recordShutdownError(message string, err error, attrs ...interface{}) {
if err == nil {
return
}
attrCopy := append([]interface{}{}, attrs...)
wrapped, txHash, attrsWithTx := enrichErrorWithTxHash(message, err, attrCopy)
sm.errMu.Lock()
sm.shutdownErrors = append(sm.shutdownErrors, wrapped)
sm.shutdownErrorDetails = append(sm.shutdownErrorDetails, RecordedError{
Err: wrapped,
TxHash: txHash,
})
sm.errMu.Unlock()
if sm.logger != nil {
kv := append([]interface{}{}, attrsWithTx...)
kv = append(kv, "error", err)
args := append([]interface{}{message}, kv...)
sm.logger.Error(args...)
}
}
func (sm *ShutdownManager) combinedShutdownError() error {
sm.errMu.Lock()
defer sm.errMu.Unlock()
if len(sm.shutdownErrors) == 0 {
return nil
}
errs := make([]error, len(sm.shutdownErrors))
copy(errs, sm.shutdownErrors)
return errors.Join(errs...)
}
// ShutdownErrors returns a copy of recorded shutdown errors for diagnostics.
func (sm *ShutdownManager) ShutdownErrors() []error {
sm.errMu.Lock()
defer sm.errMu.Unlock()
if len(sm.shutdownErrors) == 0 {
return nil
}
errs := make([]error, len(sm.shutdownErrors))
copy(errs, sm.shutdownErrors)
return errs
}
// ShutdownErrorDetails returns recorded errors with associated metadata such as tx hash.
func (sm *ShutdownManager) ShutdownErrorDetails() []RecordedError {
sm.errMu.Lock()
defer sm.errMu.Unlock()
if len(sm.shutdownErrorDetails) == 0 {
return nil
}
details := make([]RecordedError, len(sm.shutdownErrorDetails))
copy(details, sm.shutdownErrorDetails)
return details
}
// DefaultShutdownHook provides a basic implementation of ShutdownHook
type DefaultShutdownHook struct {
name string

View File

@@ -0,0 +1,111 @@
package lifecycle
import (
"context"
"errors"
"fmt"
"strings"
"testing"
)
type testShutdownHook struct {
errs map[string]error
lastFailure error
}
func (h *testShutdownHook) OnShutdownStarted(ctx context.Context) error {
return h.errs["OnShutdownStarted"]
}
func (h *testShutdownHook) OnModulesStopped(ctx context.Context) error {
return h.errs["OnModulesStopped"]
}
func (h *testShutdownHook) OnCleanupStarted(ctx context.Context) error {
return h.errs["OnCleanupStarted"]
}
func (h *testShutdownHook) OnShutdownCompleted(ctx context.Context) error {
return h.errs["OnShutdownCompleted"]
}
func (h *testShutdownHook) OnShutdownFailed(ctx context.Context, err error) error {
h.lastFailure = err
return h.errs["OnShutdownFailed"]
}
func TestShutdownManagerErrorAggregation(t *testing.T) {
sm := NewShutdownManager(nil, ShutdownConfig{})
sm.logger = nil
txHash := "0xaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"
firstErr := fmt.Errorf("first failure on %s", txHash)
secondErr := errors.New("second")
sm.recordShutdownError("first error", firstErr)
sm.recordShutdownError("second error", secondErr)
if got := len(sm.shutdownErrors); got != 2 {
t.Fatalf("expected 2 recorded errors, got %d", got)
}
combined := sm.combinedShutdownError()
if combined == nil {
t.Fatal("expected combined shutdown error, got nil")
}
if !errors.Is(combined, firstErr) || !errors.Is(combined, secondErr) {
t.Fatalf("combined error does not contain original errors: %v", combined)
}
exportedErrors := sm.ShutdownErrors()
if len(exportedErrors) != 2 {
t.Fatalf("expected exported error slice of length 2, got %d", len(exportedErrors))
}
details := sm.ShutdownErrorDetails()
if len(details) != 2 {
t.Fatalf("expected error detail slice of length 2, got %d", len(details))
}
if details[0].TxHash != txHash {
t.Fatalf("expected recorded error to track tx hash %s, got %s", txHash, details[0].TxHash)
}
if details[0].Err == nil || !strings.Contains(details[0].Err.Error(), "tx_hash="+txHash) {
t.Fatalf("expected recorded error message to include tx hash, got %v", details[0].Err)
}
}
func TestShutdownManagerCallHooksAggregatesErrors(t *testing.T) {
sm := NewShutdownManager(nil, ShutdownConfig{})
sm.logger = nil
sm.shutdownErrors = nil
hookErrA := errors.New("hookA failure")
hookErrB := errors.New("hookB failure")
hookA := &testShutdownHook{
errs: map[string]error{
"OnShutdownFailed": hookErrA,
},
}
hookB := &testShutdownHook{
errs: map[string]error{
"OnShutdownFailed": hookErrB,
},
}
sm.shutdownHooks = []ShutdownHook{hookA, hookB}
cause := errors.New("original failure")
err := sm.callHooks(context.Background(), "OnShutdownFailed", cause)
if err == nil {
t.Fatal("expected aggregated error from hooks, got nil")
}
if !errors.Is(err, hookErrA) || !errors.Is(err, hookErrB) {
t.Fatalf("expected aggregated error to contain hook failures, got %v", err)
}
if hookA.lastFailure != cause || hookB.lastFailure != cause {
t.Fatal("expected hook to receive original failure cause")
}
if len(sm.ShutdownErrors()) != 2 {
t.Fatalf("expected shutdown errors to be recorded for each hook failure, got %d", len(sm.ShutdownErrors()))
}
}

View File

@@ -372,13 +372,17 @@ func (sm *StateMachine) performTransition(ctx context.Context, to ModuleState, t
// Execute pre-transition hooks
for name, hook := range sm.transitionHooks {
if hookCtx, cancel := context.WithTimeout(ctx, sm.config.TransitionTimeout); hookCtx != nil {
hookCtx, cancel := context.WithTimeout(ctx, sm.config.TransitionTimeout)
err := func() error {
defer cancel()
if err := hook(hookCtx, from, to, sm); err != nil {
cancel()
sm.recordFailedTransition(from, to, startTime, trigger, err, transitionCtx)
return fmt.Errorf("pre-transition hook %s failed: %w", name, err)
}
cancel()
return nil
}()
if err != nil {
sm.recordFailedTransition(from, to, startTime, trigger, err, transitionCtx)
return err
}
}
@@ -402,26 +406,25 @@ func (sm *StateMachine) performTransition(ctx context.Context, to ModuleState, t
// Execute post-transition hooks
for _, hook := range sm.transitionHooks {
if hookCtx, cancel := context.WithTimeout(ctx, sm.config.TransitionTimeout); hookCtx != nil {
hookCtx, cancel := context.WithTimeout(ctx, sm.config.TransitionTimeout)
func() {
defer cancel()
if err := hook(hookCtx, from, to, sm); err != nil {
// Log error but don't fail the transition
cancel()
continue
return
}
cancel()
}
}()
}
// Execute state handler for new state
if handler, exists := sm.stateHandlers[to]; exists {
if handlerCtx, cancel := context.WithTimeout(ctx, sm.config.TransitionTimeout); handlerCtx != nil {
handlerCtx, cancel := context.WithTimeout(ctx, sm.config.TransitionTimeout)
func() {
defer cancel()
if err := handler(handlerCtx, sm); err != nil {
cancel()
// Log error but don't fail the transition
} else {
cancel()
}
}
}()
}
return nil