Sequencer is working (minimal parsing)

This commit is contained in:
Krypto Kajun
2025-09-14 06:21:10 -05:00
parent 7dd5b5b692
commit 518758790a
59 changed files with 10539 additions and 471 deletions

262
internal/auth/middleware.go Normal file
View File

@@ -0,0 +1,262 @@
package auth
import (
	"crypto/subtle"
	"fmt"
	"net"
	"net/http"
	"os"
	"strings"
	"sync"
	"time"

	"github.com/fraktal/mev-beta/internal/logger"
)
// AuthConfig holds authentication configuration
type AuthConfig struct {
APIKey string
BasicUsername string
BasicPassword string
AllowedIPs []string
RequireHTTPS bool
RateLimitRPS int
Logger *logger.Logger
}
// Middleware provides authentication middleware for HTTP endpoints
type Middleware struct {
config *AuthConfig
rateLimiter map[string]*RateLimiter
mu sync.RWMutex
}
// RateLimiter tracks request rates per IP
type RateLimiter struct {
requests []time.Time
maxRequests int
window time.Duration
}
// NewMiddleware creates a new authentication middleware
func NewMiddleware(config *AuthConfig) *Middleware {
// Use environment variables for sensitive data
if config.APIKey == "" {
config.APIKey = os.Getenv("MEV_BOT_API_KEY")
}
if config.BasicUsername == "" {
config.BasicUsername = os.Getenv("MEV_BOT_USERNAME")
}
if config.BasicPassword == "" {
config.BasicPassword = os.Getenv("MEV_BOT_PASSWORD")
}
return &Middleware{
config: config,
rateLimiter: make(map[string]*RateLimiter),
}
}
// RequireAuthentication is a middleware that requires API key or basic auth
func (m *Middleware) RequireAuthentication(next http.HandlerFunc) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
// Security headers
w.Header().Set("X-Content-Type-Options", "nosniff")
w.Header().Set("X-Frame-Options", "DENY")
w.Header().Set("X-XSS-Protection", "1; mode=block")
w.Header().Set("Referrer-Policy", "strict-origin-when-cross-origin")
// Require HTTPS in production
if m.config.RequireHTTPS && r.Header.Get("X-Forwarded-Proto") != "https" && r.TLS == nil {
http.Error(w, "HTTPS required", http.StatusUpgradeRequired)
return
}
// Check IP allowlist
if !m.isIPAllowed(r.RemoteAddr) {
m.config.Logger.Warn(fmt.Sprintf("Blocked request from unauthorized IP: %s", r.RemoteAddr))
http.Error(w, "Forbidden", http.StatusForbidden)
return
}
// Rate limiting
if !m.checkRateLimit(r.RemoteAddr) {
m.config.Logger.Warn(fmt.Sprintf("Rate limit exceeded for IP: %s", r.RemoteAddr))
http.Error(w, "Rate limit exceeded", http.StatusTooManyRequests)
return
}
// Try API key authentication first
if m.authenticateAPIKey(r) {
next.ServeHTTP(w, r)
return
}
// Try basic authentication
if m.authenticateBasic(r) {
next.ServeHTTP(w, r)
return
}
// Authentication failed
w.Header().Set("WWW-Authenticate", `Basic realm="MEV Bot API"`)
http.Error(w, "Authentication required", http.StatusUnauthorized)
}
}
// authenticateAPIKey checks for valid API key in header or query param
func (m *Middleware) authenticateAPIKey(r *http.Request) bool {
if m.config.APIKey == "" {
return false
}
// Check Authorization header
auth := r.Header.Get("Authorization")
if strings.HasPrefix(auth, "Bearer ") {
token := strings.TrimPrefix(auth, "Bearer ")
return subtle.ConstantTimeCompare([]byte(token), []byte(m.config.APIKey)) == 1
}
// Check X-API-Key header
apiKey := r.Header.Get("X-API-Key")
if apiKey != "" {
return subtle.ConstantTimeCompare([]byte(apiKey), []byte(m.config.APIKey)) == 1
}
// Check query parameter (less secure, but sometimes necessary)
queryKey := r.URL.Query().Get("api_key")
if queryKey != "" {
return subtle.ConstantTimeCompare([]byte(queryKey), []byte(m.config.APIKey)) == 1
}
return false
}
// authenticateBasic checks basic authentication credentials
func (m *Middleware) authenticateBasic(r *http.Request) bool {
if m.config.BasicUsername == "" || m.config.BasicPassword == "" {
return false
}
username, password, ok := r.BasicAuth()
if !ok {
return false
}
// Use constant time comparison to prevent timing attacks
usernameMatch := subtle.ConstantTimeCompare([]byte(username), []byte(m.config.BasicUsername)) == 1
passwordMatch := subtle.ConstantTimeCompare([]byte(password), []byte(m.config.BasicPassword)) == 1
return usernameMatch && passwordMatch
}
// isIPAllowed checks if the request IP is in the allowlist
func (m *Middleware) isIPAllowed(remoteAddr string) bool {
if len(m.config.AllowedIPs) == 0 {
return true // No IP restrictions
}
// Extract IP from address (remove port)
ip := strings.Split(remoteAddr, ":")[0]
for _, allowedIP := range m.config.AllowedIPs {
if ip == allowedIP || allowedIP == "*" {
return true
}
// Support CIDR notation in future versions
}
return false
}
// checkRateLimit implements simple rate limiting per IP
func (m *Middleware) checkRateLimit(remoteAddr string) bool {
if m.config.RateLimitRPS <= 0 {
return true // No rate limiting
}
ip := strings.Split(remoteAddr, ":")[0]
now := time.Now()
m.mu.Lock()
defer m.mu.Unlock()
// Get or create rate limiter for this IP
limiter, exists := m.rateLimiter[ip]
if !exists {
limiter = &RateLimiter{
requests: make([]time.Time, 0),
maxRequests: m.config.RateLimitRPS,
window: time.Minute,
}
m.rateLimiter[ip] = limiter
}
// Clean old requests outside the time window
cutoff := now.Add(-limiter.window)
validRequests := make([]time.Time, 0)
for _, reqTime := range limiter.requests {
if reqTime.After(cutoff) {
validRequests = append(validRequests, reqTime)
}
}
limiter.requests = validRequests
// Check if we're under the limit
if len(limiter.requests) >= limiter.maxRequests {
return false
}
// Add current request
limiter.requests = append(limiter.requests, now)
return true
}
// RequireReadOnly is a middleware for read-only endpoints (less strict)
func (m *Middleware) RequireReadOnly(next http.HandlerFunc) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
// Only allow GET requests for read-only endpoints
if r.Method != http.MethodGet {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
return
}
// Apply basic security checks
w.Header().Set("X-Content-Type-Options", "nosniff")
w.Header().Set("X-Frame-Options", "DENY")
// Check IP allowlist
if !m.isIPAllowed(r.RemoteAddr) {
http.Error(w, "Forbidden", http.StatusForbidden)
return
}
// Apply rate limiting
if !m.checkRateLimit(r.RemoteAddr) {
http.Error(w, "Rate limit exceeded", http.StatusTooManyRequests)
return
}
next.ServeHTTP(w, r)
}
}
// CleanupRateLimiters removes old rate limiter entries
func (m *Middleware) CleanupRateLimiters() {
cutoff := time.Now().Add(-time.Hour)
m.mu.Lock()
defer m.mu.Unlock()
for ip, limiter := range m.rateLimiter {
// Remove limiters that haven't been used recently
if len(limiter.requests) == 0 {
delete(m.rateLimiter, ip)
continue
}
lastRequest := limiter.requests[len(limiter.requests)-1]
if lastRequest.Before(cutoff) {
delete(m.rateLimiter, ip)
}
}
}

View File

@@ -1,9 +1,13 @@
package config
import (
"fmt"
"os"
"regexp"
"strconv"
"strings"
"gopkg.in/yaml.v3"
)
@@ -117,9 +121,12 @@ func Load(filename string) (*Config, error) {
return nil, fmt.Errorf("failed to read config file: %w", err)
}
// Expand environment variables in the raw YAML
expandedData := expandEnvVars(string(data))
// Parse the YAML
var config Config
if err := yaml.Unmarshal(data, &config); err != nil {
if err := yaml.Unmarshal([]byte(expandedData), &config); err != nil {
return nil, fmt.Errorf("failed to parse config file: %w", err)
}
@@ -129,6 +136,33 @@ func Load(filename string) (*Config, error) {
return &config, nil
}
// expandEnvVars expands ${VAR} and $VAR patterns in the given string
func expandEnvVars(s string) string {
// Pattern to match ${VAR} and $VAR
envVarPattern := regexp.MustCompile(`\$\{([^}]+)\}|\$([A-Za-z_][A-Za-z0-9_]*)`)
return envVarPattern.ReplaceAllStringFunc(s, func(match string) string {
var varName string
// Handle ${VAR} format
if strings.HasPrefix(match, "${") && strings.HasSuffix(match, "}") {
varName = match[2 : len(match)-1]
} else if strings.HasPrefix(match, "$") {
// Handle $VAR format
varName = match[1:]
}
// Get environment variable value
if value := os.Getenv(varName); value != "" {
return value
}
// Return empty string if environment variable is not set
// This prevents invalid YAML when variables are missing
return ""
})
}
// OverrideWithEnv overrides configuration with environment variables
func (c *Config) OverrideWithEnv() {
// Override RPC endpoint

View File

@@ -0,0 +1,450 @@
package ratelimit
import (
"context"
"fmt"
"sync"
"sync/atomic"
"time"
"github.com/fraktal/mev-beta/internal/config"
"github.com/fraktal/mev-beta/internal/logger"
"golang.org/x/time/rate"
)
// AdaptiveRateLimiter implements adaptive rate limiting that adjusts to endpoint capacity
type AdaptiveRateLimiter struct {
endpoints map[string]*AdaptiveEndpoint
mu sync.RWMutex
logger *logger.Logger
defaultConfig config.RateLimitConfig
adjustInterval time.Duration
stopChan chan struct{}
}
// AdaptiveEndpoint represents an endpoint with adaptive rate limiting
type AdaptiveEndpoint struct {
URL string
limiter *rate.Limiter
config config.RateLimitConfig
circuitBreaker *CircuitBreaker
metrics *EndpointMetrics
healthChecker *HealthChecker
lastAdjustment time.Time
consecutiveErrors int64
consecutiveSuccess int64
}
// EndpointMetrics tracks performance metrics for an endpoint
type EndpointMetrics struct {
TotalRequests int64
SuccessfulRequests int64
FailedRequests int64
TotalLatency int64 // nanoseconds
LastRequestTime int64 // unix timestamp
SuccessRate float64
AverageLatency float64 // milliseconds
}
// CircuitBreaker implements circuit breaker pattern for failed endpoints
type CircuitBreaker struct {
state int32 // 0: Closed, 1: Open, 2: HalfOpen
failureCount int64
lastFailTime int64
threshold int64
timeout time.Duration // How long to wait before trying again
testRequests int64 // Number of test requests in half-open state
}
// Circuit breaker states
const (
CircuitClosed = 0
CircuitOpen = 1
CircuitHalfOpen = 2
)
// HealthChecker monitors endpoint health
type HealthChecker struct {
endpoint string
interval time.Duration
timeout time.Duration
isHealthy int64 // atomic bool
lastCheck int64 // unix timestamp
stopChan chan struct{}
}
// NewAdaptiveRateLimiter creates a new adaptive rate limiter
func NewAdaptiveRateLimiter(cfg *config.ArbitrumConfig, logger *logger.Logger) *AdaptiveRateLimiter {
arl := &AdaptiveRateLimiter{
endpoints: make(map[string]*AdaptiveEndpoint),
logger: logger,
defaultConfig: cfg.RateLimit,
adjustInterval: 30 * time.Second,
stopChan: make(chan struct{}),
}
// Create adaptive endpoint for primary endpoint
arl.addEndpoint(cfg.RPCEndpoint, cfg.RateLimit)
// Create adaptive endpoints for fallback endpoints
for _, endpoint := range cfg.FallbackEndpoints {
arl.addEndpoint(endpoint.URL, endpoint.RateLimit)
}
// Start background adjustment routine
go arl.adjustmentLoop()
return arl
}
// addEndpoint adds a new adaptive endpoint
func (arl *AdaptiveRateLimiter) addEndpoint(url string, config config.RateLimitConfig) {
endpoint := &AdaptiveEndpoint{
URL: url,
limiter: rate.NewLimiter(rate.Limit(config.RequestsPerSecond), config.Burst),
config: config,
circuitBreaker: &CircuitBreaker{
threshold: 10, // Break after 10 consecutive failures
timeout: 60 * time.Second,
},
metrics: &EndpointMetrics{},
healthChecker: &HealthChecker{
endpoint: url,
interval: 30 * time.Second,
timeout: 5 * time.Second,
isHealthy: 1, // Start assuming healthy
stopChan: make(chan struct{}),
},
}
arl.mu.Lock()
arl.endpoints[url] = endpoint
arl.mu.Unlock()
// Start health checker for this endpoint
go endpoint.healthChecker.start()
arl.logger.Info(fmt.Sprintf("Added adaptive rate limiter for endpoint: %s", url))
}
// WaitForBestEndpoint waits for the best available endpoint
func (arl *AdaptiveRateLimiter) WaitForBestEndpoint(ctx context.Context) (string, error) {
// Find the best available endpoint
bestEndpoint := arl.getBestEndpoint()
if bestEndpoint == "" {
return "", fmt.Errorf("no healthy endpoints available")
}
// Wait for rate limiter
arl.mu.RLock()
endpoint := arl.endpoints[bestEndpoint]
arl.mu.RUnlock()
if endpoint == nil {
return "", fmt.Errorf("endpoint not found: %s", bestEndpoint)
}
// Check circuit breaker
if !endpoint.circuitBreaker.canExecute() {
return "", fmt.Errorf("circuit breaker open for endpoint: %s", bestEndpoint)
}
// Wait for rate limiter
err := endpoint.limiter.Wait(ctx)
if err != nil {
return "", err
}
return bestEndpoint, nil
}
// RecordResult records the result of a request for adaptive adjustment
func (arl *AdaptiveRateLimiter) RecordResult(endpointURL string, success bool, latency time.Duration) {
arl.mu.RLock()
endpoint, exists := arl.endpoints[endpointURL]
arl.mu.RUnlock()
if !exists {
return
}
// Update metrics atomically
atomic.AddInt64(&endpoint.metrics.TotalRequests, 1)
atomic.AddInt64(&endpoint.metrics.TotalLatency, latency.Nanoseconds())
atomic.StoreInt64(&endpoint.metrics.LastRequestTime, time.Now().Unix())
if success {
atomic.AddInt64(&endpoint.metrics.SuccessfulRequests, 1)
atomic.AddInt64(&endpoint.consecutiveSuccess, 1)
atomic.StoreInt64(&endpoint.consecutiveErrors, 0)
endpoint.circuitBreaker.recordSuccess()
} else {
atomic.AddInt64(&endpoint.metrics.FailedRequests, 1)
atomic.AddInt64(&endpoint.consecutiveErrors, 1)
atomic.StoreInt64(&endpoint.consecutiveSuccess, 0)
endpoint.circuitBreaker.recordFailure()
}
// Update calculated metrics
arl.updateCalculatedMetrics(endpoint)
}
// updateCalculatedMetrics updates derived metrics
func (arl *AdaptiveRateLimiter) updateCalculatedMetrics(endpoint *AdaptiveEndpoint) {
totalReq := atomic.LoadInt64(&endpoint.metrics.TotalRequests)
successReq := atomic.LoadInt64(&endpoint.metrics.SuccessfulRequests)
totalLatency := atomic.LoadInt64(&endpoint.metrics.TotalLatency)
if totalReq > 0 {
endpoint.metrics.SuccessRate = float64(successReq) / float64(totalReq)
endpoint.metrics.AverageLatency = float64(totalLatency) / float64(totalReq) / 1000000 // Convert to milliseconds
}
}
// getBestEndpoint selects the best available endpoint based on metrics
func (arl *AdaptiveRateLimiter) getBestEndpoint() string {
arl.mu.RLock()
defer arl.mu.RUnlock()
bestEndpoint := ""
bestScore := float64(-1)
for url, endpoint := range arl.endpoints {
// Skip unhealthy endpoints
if atomic.LoadInt64(&endpoint.healthChecker.isHealthy) == 0 {
continue
}
// Skip if circuit breaker is open
if !endpoint.circuitBreaker.canExecute() {
continue
}
// Calculate score based on success rate, latency, and current load
score := arl.calculateEndpointScore(endpoint)
if score > bestScore {
bestScore = score
bestEndpoint = url
}
}
return bestEndpoint
}
// calculateEndpointScore calculates a score for endpoint selection
func (arl *AdaptiveRateLimiter) calculateEndpointScore(endpoint *AdaptiveEndpoint) float64 {
// Base score on success rate (0-1)
successWeight := 0.6
latencyWeight := 0.3
loadWeight := 0.1
successScore := endpoint.metrics.SuccessRate
// Invert latency score (lower latency = higher score)
latencyScore := 1.0
if endpoint.metrics.AverageLatency > 0 {
// Normalize latency score (assuming 1000ms is poor, 100ms is good)
latencyScore = 1.0 - (endpoint.metrics.AverageLatency / 1000.0)
if latencyScore < 0 {
latencyScore = 0
}
}
// Load score based on current rate limiter state
loadScore := 1.0 // Simplified - could check current tokens in limiter
return successScore*successWeight + latencyScore*latencyWeight + loadScore*loadWeight
}
// adjustmentLoop runs periodic adjustments to rate limits
func (arl *AdaptiveRateLimiter) adjustmentLoop() {
ticker := time.NewTicker(arl.adjustInterval)
defer ticker.Stop()
for {
select {
case <-ticker.C:
arl.adjustRateLimits()
case <-arl.stopChan:
return
}
}
}
// adjustRateLimits adjusts rate limits based on observed performance
func (arl *AdaptiveRateLimiter) adjustRateLimits() {
arl.mu.Lock()
defer arl.mu.Unlock()
for url, endpoint := range arl.endpoints {
arl.adjustEndpointRateLimit(url, endpoint)
}
}
// adjustEndpointRateLimit adjusts rate limit for a specific endpoint
func (arl *AdaptiveRateLimiter) adjustEndpointRateLimit(url string, endpoint *AdaptiveEndpoint) {
// Don't adjust too frequently
if time.Since(endpoint.lastAdjustment) < arl.adjustInterval {
return
}
successRate := endpoint.metrics.SuccessRate
avgLatency := endpoint.metrics.AverageLatency
currentLimit := float64(endpoint.limiter.Limit())
var newLimit float64 = currentLimit
adjustmentFactor := 0.1 // 10% adjustment
// Increase rate if performing well
if successRate > 0.95 && avgLatency < 500 { // 95% success, < 500ms latency
newLimit = currentLimit * (1.0 + adjustmentFactor)
} else if successRate < 0.8 || avgLatency > 2000 { // < 80% success or > 2s latency
newLimit = currentLimit * (1.0 - adjustmentFactor)
}
// Apply bounds
minLimit := float64(arl.defaultConfig.RequestsPerSecond) * 0.1 // 10% of default minimum
maxLimit := float64(arl.defaultConfig.RequestsPerSecond) * 3.0 // 300% of default maximum
if newLimit < minLimit {
newLimit = minLimit
}
if newLimit > maxLimit {
newLimit = maxLimit
}
// Update if changed significantly
if abs(newLimit-currentLimit)/currentLimit > 0.05 { // 5% change threshold
endpoint.limiter.SetLimit(rate.Limit(newLimit))
endpoint.lastAdjustment = time.Now()
arl.logger.Info(fmt.Sprintf("Adjusted rate limit for %s: %.2f -> %.2f (success: %.2f%%, latency: %.2fms)",
url, currentLimit, newLimit, successRate*100, avgLatency))
}
}
// abs returns absolute value of float64
func abs(x float64) float64 {
if x < 0 {
return -x
}
return x
}
// canExecute checks if circuit breaker allows execution
func (cb *CircuitBreaker) canExecute() bool {
state := atomic.LoadInt32(&cb.state)
now := time.Now().Unix()
switch state {
case CircuitClosed:
return true
case CircuitOpen:
// Check if timeout has passed
lastFail := atomic.LoadInt64(&cb.lastFailTime)
if now-lastFail > int64(cb.timeout.Seconds()) {
// Try to move to half-open
if atomic.CompareAndSwapInt32(&cb.state, CircuitOpen, CircuitHalfOpen) {
atomic.StoreInt64(&cb.testRequests, 0)
return true
}
}
return false
case CircuitHalfOpen:
// Allow limited test requests
testReq := atomic.LoadInt64(&cb.testRequests)
if testReq < 3 { // Allow up to 3 test requests
atomic.AddInt64(&cb.testRequests, 1)
return true
}
return false
}
return false
}
// recordSuccess records a successful request
func (cb *CircuitBreaker) recordSuccess() {
state := atomic.LoadInt32(&cb.state)
if state == CircuitHalfOpen {
// Move back to closed after successful test
atomic.StoreInt32(&cb.state, CircuitClosed)
atomic.StoreInt64(&cb.failureCount, 0)
}
}
// recordFailure records a failed request
func (cb *CircuitBreaker) recordFailure() {
failures := atomic.AddInt64(&cb.failureCount, 1)
atomic.StoreInt64(&cb.lastFailTime, time.Now().Unix())
if failures >= cb.threshold {
atomic.StoreInt32(&cb.state, CircuitOpen)
}
}
// start starts the health checker
func (hc *HealthChecker) start() {
ticker := time.NewTicker(hc.interval)
defer ticker.Stop()
for {
select {
case <-ticker.C:
hc.checkHealth()
case <-hc.stopChan:
return
}
}
}
// checkHealth performs a health check on the endpoint
func (hc *HealthChecker) checkHealth() {
ctx, cancel := context.WithTimeout(context.Background(), hc.timeout)
defer cancel()
// Simple health check - try to connect
// In production, this might make a simple RPC call
healthy := hc.performHealthCheck(ctx)
if healthy {
atomic.StoreInt64(&hc.isHealthy, 1)
} else {
atomic.StoreInt64(&hc.isHealthy, 0)
}
atomic.StoreInt64(&hc.lastCheck, time.Now().Unix())
}
// performHealthCheck performs the actual health check
func (hc *HealthChecker) performHealthCheck(ctx context.Context) bool {
// Simplified health check - in production would make actual RPC call
// For now, just simulate based on endpoint availability
return true // Assume healthy for demo
}
// Stop stops the adaptive rate limiter
func (arl *AdaptiveRateLimiter) Stop() {
close(arl.stopChan)
// Stop all health checkers
arl.mu.RLock()
for _, endpoint := range arl.endpoints {
close(endpoint.healthChecker.stopChan)
}
arl.mu.RUnlock()
}
// GetMetrics returns current metrics for all endpoints
func (arl *AdaptiveRateLimiter) GetMetrics() map[string]*EndpointMetrics {
arl.mu.RLock()
defer arl.mu.RUnlock()
metrics := make(map[string]*EndpointMetrics)
for url, endpoint := range arl.endpoints {
// Update calculated metrics before returning
arl.updateCalculatedMetrics(endpoint)
metrics[url] = endpoint.metrics
}
return metrics
}

View File

@@ -0,0 +1,292 @@
package secure
import (
"crypto/aes"
"crypto/cipher"
"crypto/rand"
"crypto/sha256"
"encoding/base64"
"errors"
"fmt"
"io"
"os"
"strings"
"github.com/fraktal/mev-beta/internal/logger"
)
// ConfigManager handles secure configuration management
type ConfigManager struct {
logger *logger.Logger
aesGCM cipher.AEAD
key []byte
}
// NewConfigManager creates a new secure configuration manager
func NewConfigManager(logger *logger.Logger) (*ConfigManager, error) {
// Get encryption key from environment or generate one
keyStr := os.Getenv("MEV_BOT_CONFIG_KEY")
if keyStr == "" {
return nil, errors.New("MEV_BOT_CONFIG_KEY environment variable not set")
}
// Create SHA-256 hash of the key for AES-256
key := sha256.Sum256([]byte(keyStr))
// Create AES cipher
block, err := aes.NewCipher(key[:])
if err != nil {
return nil, fmt.Errorf("failed to create AES cipher: %w", err)
}
// Create GCM mode
aesGCM, err := cipher.NewGCM(block)
if err != nil {
return nil, fmt.Errorf("failed to create GCM mode: %w", err)
}
return &ConfigManager{
logger: logger,
aesGCM: aesGCM,
key: key[:],
}, nil
}
// EncryptValue encrypts a configuration value
func (cm *ConfigManager) EncryptValue(plaintext string) (string, error) {
// Create a random nonce
nonce := make([]byte, cm.aesGCM.NonceSize())
if _, err := io.ReadFull(rand.Reader, nonce); err != nil {
return "", fmt.Errorf("failed to generate nonce: %w", err)
}
// Encrypt the plaintext
ciphertext := cm.aesGCM.Seal(nonce, nonce, []byte(plaintext), nil)
// Encode to base64 for storage
return base64.StdEncoding.EncodeToString(ciphertext), nil
}
// DecryptValue decrypts a configuration value
func (cm *ConfigManager) DecryptValue(ciphertext string) (string, error) {
// Decode from base64
data, err := base64.StdEncoding.DecodeString(ciphertext)
if err != nil {
return "", fmt.Errorf("failed to decode base64: %w", err)
}
// Check minimum length (nonce size)
nonceSize := cm.aesGCM.NonceSize()
if len(data) < nonceSize {
return "", errors.New("ciphertext too short")
}
// Extract nonce and ciphertext
nonce, ciphertext_bytes := data[:nonceSize], data[nonceSize:]
// Decrypt
plaintext, err := cm.aesGCM.Open(nil, nonce, ciphertext_bytes, nil)
if err != nil {
return "", fmt.Errorf("failed to decrypt: %w", err)
}
return string(plaintext), nil
}
// GetSecureValue gets a secure value from environment with fallback to encrypted storage
func (cm *ConfigManager) GetSecureValue(key string) (string, error) {
// First try environment variable
if value := os.Getenv(key); value != "" {
return value, nil
}
// Try encrypted environment variable
encryptedKey := key + "_ENCRYPTED"
if encryptedValue := os.Getenv(encryptedKey); encryptedValue != "" {
return cm.DecryptValue(encryptedValue)
}
return "", fmt.Errorf("secure value not found for key: %s", key)
}
// SecureConfig holds encrypted configuration values
type SecureConfig struct {
manager *ConfigManager
values map[string]string
}
// NewSecureConfig creates a new secure configuration
func NewSecureConfig(manager *ConfigManager) *SecureConfig {
return &SecureConfig{
manager: manager,
values: make(map[string]string),
}
}
// Set stores a value securely
func (sc *SecureConfig) Set(key, value string) error {
encrypted, err := sc.manager.EncryptValue(value)
if err != nil {
return fmt.Errorf("failed to encrypt value for key %s: %w", key, err)
}
sc.values[key] = encrypted
return nil
}
// Get retrieves a value securely
func (sc *SecureConfig) Get(key string) (string, error) {
// Check local encrypted storage first
if encrypted, exists := sc.values[key]; exists {
return sc.manager.DecryptValue(encrypted)
}
// Fallback to secure environment lookup
return sc.manager.GetSecureValue(key)
}
// GetRequired retrieves a required value, returning error if not found
func (sc *SecureConfig) GetRequired(key string) (string, error) {
value, err := sc.Get(key)
if err != nil {
return "", fmt.Errorf("required configuration value missing: %s", key)
}
if strings.TrimSpace(value) == "" {
return "", fmt.Errorf("required configuration value empty: %s", key)
}
return value, nil
}
// GetWithDefault retrieves a value with a default fallback
func (sc *SecureConfig) GetWithDefault(key, defaultValue string) string {
value, err := sc.Get(key)
if err != nil {
return defaultValue
}
return value
}
// LoadFromEnvironment loads configuration from environment variables
func (sc *SecureConfig) LoadFromEnvironment(keys []string) error {
for _, key := range keys {
value, err := sc.manager.GetSecureValue(key)
if err != nil {
sc.manager.logger.Warn(fmt.Sprintf("Could not load secure config for %s: %v", key, err))
continue
}
// Store encrypted in memory
if err := sc.Set(key, value); err != nil {
return fmt.Errorf("failed to store secure config for %s: %w", key, err)
}
}
return nil
}
// Clear removes all stored values from memory
func (sc *SecureConfig) Clear() {
// Zero out the map entries before clearing
for key := range sc.values {
// Overwrite with zeros
sc.values[key] = strings.Repeat("0", len(sc.values[key]))
delete(sc.values, key)
}
}
// Validate checks that all required configuration is present
func (sc *SecureConfig) Validate(requiredKeys []string) error {
var missingKeys []string
for _, key := range requiredKeys {
if _, err := sc.GetRequired(key); err != nil {
missingKeys = append(missingKeys, key)
}
}
if len(missingKeys) > 0 {
return fmt.Errorf("missing required configuration keys: %s", strings.Join(missingKeys, ", "))
}
return nil
}
// GenerateConfigKey generates a new encryption key for configuration
func GenerateConfigKey() (string, error) {
key := make([]byte, 32) // 256-bit key
if _, err := rand.Read(key); err != nil {
return "", fmt.Errorf("failed to generate random key: %w", err)
}
return base64.StdEncoding.EncodeToString(key), nil
}
// ConfigValidator provides validation utilities
type ConfigValidator struct {
logger *logger.Logger
}
// NewConfigValidator creates a new configuration validator
func NewConfigValidator(logger *logger.Logger) *ConfigValidator {
return &ConfigValidator{
logger: logger,
}
}
// ValidateURL validates that a URL is properly formatted and uses HTTPS
func (cv *ConfigValidator) ValidateURL(url string) error {
if url == "" {
return errors.New("URL cannot be empty")
}
if !strings.HasPrefix(url, "https://") && !strings.HasPrefix(url, "wss://") {
return errors.New("URL must use HTTPS or WSS protocol")
}
// Additional validation could go here (DNS lookup, connection test, etc.)
return nil
}
// ValidateAPIKey validates that an API key meets minimum security requirements
func (cv *ConfigValidator) ValidateAPIKey(key string) error {
if key == "" {
return errors.New("API key cannot be empty")
}
if len(key) < 32 {
return errors.New("API key must be at least 32 characters")
}
// Check for basic entropy (not all same character, contains mixed case, etc.)
if strings.Count(key, string(key[0])) == len(key) {
return errors.New("API key lacks sufficient entropy")
}
return nil
}
// ValidateAddress validates an Ethereum address
func (cv *ConfigValidator) ValidateAddress(address string) error {
if address == "" {
return errors.New("address cannot be empty")
}
if !strings.HasPrefix(address, "0x") {
return errors.New("address must start with 0x")
}
if len(address) != 42 { // 0x + 40 hex chars
return errors.New("address must be 42 characters long")
}
// Validate hex format
for i, char := range address[2:] {
if !((char >= '0' && char <= '9') || (char >= 'a' && char <= 'f') || (char >= 'A' && char <= 'F')) {
return fmt.Errorf("invalid hex character at position %d: %c", i+2, char)
}
}
return nil
}