293 lines
7.6 KiB
Go
293 lines
7.6 KiB
Go
package secure
|
|
|
|
import (
|
|
"crypto/aes"
|
|
"crypto/cipher"
|
|
"crypto/rand"
|
|
"crypto/sha256"
|
|
"encoding/base64"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"os"
|
|
"strings"
|
|
|
|
"github.com/fraktal/mev-beta/internal/logger"
|
|
)
|
|
|
|
// ConfigManager handles secure configuration management
|
|
type ConfigManager struct {
|
|
logger *logger.Logger
|
|
aesGCM cipher.AEAD
|
|
key []byte
|
|
}
|
|
|
|
// NewConfigManager creates a new secure configuration manager
|
|
func NewConfigManager(logger *logger.Logger) (*ConfigManager, error) {
|
|
// Get encryption key from environment or generate one
|
|
keyStr := os.Getenv("MEV_BOT_CONFIG_KEY")
|
|
if keyStr == "" {
|
|
return nil, errors.New("MEV_BOT_CONFIG_KEY environment variable not set")
|
|
}
|
|
|
|
// Create SHA-256 hash of the key for AES-256
|
|
key := sha256.Sum256([]byte(keyStr))
|
|
|
|
// Create AES cipher
|
|
block, err := aes.NewCipher(key[:])
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to create AES cipher: %w", err)
|
|
}
|
|
|
|
// Create GCM mode
|
|
aesGCM, err := cipher.NewGCM(block)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to create GCM mode: %w", err)
|
|
}
|
|
|
|
return &ConfigManager{
|
|
logger: logger,
|
|
aesGCM: aesGCM,
|
|
key: key[:],
|
|
}, nil
|
|
}
|
|
|
|
// EncryptValue encrypts a configuration value
|
|
func (cm *ConfigManager) EncryptValue(plaintext string) (string, error) {
|
|
// Create a random nonce
|
|
nonce := make([]byte, cm.aesGCM.NonceSize())
|
|
if _, err := io.ReadFull(rand.Reader, nonce); err != nil {
|
|
return "", fmt.Errorf("failed to generate nonce: %w", err)
|
|
}
|
|
|
|
// Encrypt the plaintext
|
|
ciphertext := cm.aesGCM.Seal(nonce, nonce, []byte(plaintext), nil)
|
|
|
|
// Encode to base64 for storage
|
|
return base64.StdEncoding.EncodeToString(ciphertext), nil
|
|
}
|
|
|
|
// DecryptValue decrypts a configuration value
|
|
func (cm *ConfigManager) DecryptValue(ciphertext string) (string, error) {
|
|
// Decode from base64
|
|
data, err := base64.StdEncoding.DecodeString(ciphertext)
|
|
if err != nil {
|
|
return "", fmt.Errorf("failed to decode base64: %w", err)
|
|
}
|
|
|
|
// Check minimum length (nonce size)
|
|
nonceSize := cm.aesGCM.NonceSize()
|
|
if len(data) < nonceSize {
|
|
return "", errors.New("ciphertext too short")
|
|
}
|
|
|
|
// Extract nonce and ciphertext
|
|
nonce, ciphertext_bytes := data[:nonceSize], data[nonceSize:]
|
|
|
|
// Decrypt
|
|
plaintext, err := cm.aesGCM.Open(nil, nonce, ciphertext_bytes, nil)
|
|
if err != nil {
|
|
return "", fmt.Errorf("failed to decrypt: %w", err)
|
|
}
|
|
|
|
return string(plaintext), nil
|
|
}
|
|
|
|
// GetSecureValue gets a secure value from environment with fallback to encrypted storage
|
|
func (cm *ConfigManager) GetSecureValue(key string) (string, error) {
|
|
// First try environment variable
|
|
if value := os.Getenv(key); value != "" {
|
|
return value, nil
|
|
}
|
|
|
|
// Try encrypted environment variable
|
|
encryptedKey := key + "_ENCRYPTED"
|
|
if encryptedValue := os.Getenv(encryptedKey); encryptedValue != "" {
|
|
return cm.DecryptValue(encryptedValue)
|
|
}
|
|
|
|
return "", fmt.Errorf("secure value not found for key: %s", key)
|
|
}
|
|
|
|
// SecureConfig holds encrypted configuration values
|
|
type SecureConfig struct {
|
|
manager *ConfigManager
|
|
values map[string]string
|
|
}
|
|
|
|
// NewSecureConfig creates a new secure configuration
|
|
func NewSecureConfig(manager *ConfigManager) *SecureConfig {
|
|
return &SecureConfig{
|
|
manager: manager,
|
|
values: make(map[string]string),
|
|
}
|
|
}
|
|
|
|
// Set stores a value securely
|
|
func (sc *SecureConfig) Set(key, value string) error {
|
|
encrypted, err := sc.manager.EncryptValue(value)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to encrypt value for key %s: %w", key, err)
|
|
}
|
|
|
|
sc.values[key] = encrypted
|
|
return nil
|
|
}
|
|
|
|
// Get retrieves a value securely
|
|
func (sc *SecureConfig) Get(key string) (string, error) {
|
|
// Check local encrypted storage first
|
|
if encrypted, exists := sc.values[key]; exists {
|
|
return sc.manager.DecryptValue(encrypted)
|
|
}
|
|
|
|
// Fallback to secure environment lookup
|
|
return sc.manager.GetSecureValue(key)
|
|
}
|
|
|
|
// GetRequired retrieves a required value, returning error if not found
|
|
func (sc *SecureConfig) GetRequired(key string) (string, error) {
|
|
value, err := sc.Get(key)
|
|
if err != nil {
|
|
return "", fmt.Errorf("required configuration value missing: %s", key)
|
|
}
|
|
|
|
if strings.TrimSpace(value) == "" {
|
|
return "", fmt.Errorf("required configuration value empty: %s", key)
|
|
}
|
|
|
|
return value, nil
|
|
}
|
|
|
|
// GetWithDefault retrieves a value with a default fallback
|
|
func (sc *SecureConfig) GetWithDefault(key, defaultValue string) string {
|
|
value, err := sc.Get(key)
|
|
if err != nil {
|
|
return defaultValue
|
|
}
|
|
return value
|
|
}
|
|
|
|
// LoadFromEnvironment loads configuration from environment variables
|
|
func (sc *SecureConfig) LoadFromEnvironment(keys []string) error {
|
|
for _, key := range keys {
|
|
value, err := sc.manager.GetSecureValue(key)
|
|
if err != nil {
|
|
sc.manager.logger.Warn(fmt.Sprintf("Could not load secure config for %s: %v", key, err))
|
|
continue
|
|
}
|
|
|
|
// Store encrypted in memory
|
|
if err := sc.Set(key, value); err != nil {
|
|
return fmt.Errorf("failed to store secure config for %s: %w", key, err)
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// Clear removes all stored values from memory
|
|
func (sc *SecureConfig) Clear() {
|
|
// Zero out the map entries before clearing
|
|
for key := range sc.values {
|
|
// Overwrite with zeros
|
|
sc.values[key] = strings.Repeat("0", len(sc.values[key]))
|
|
delete(sc.values, key)
|
|
}
|
|
}
|
|
|
|
// Validate checks that all required configuration is present
|
|
func (sc *SecureConfig) Validate(requiredKeys []string) error {
|
|
var missingKeys []string
|
|
|
|
for _, key := range requiredKeys {
|
|
if _, err := sc.GetRequired(key); err != nil {
|
|
missingKeys = append(missingKeys, key)
|
|
}
|
|
}
|
|
|
|
if len(missingKeys) > 0 {
|
|
return fmt.Errorf("missing required configuration keys: %s", strings.Join(missingKeys, ", "))
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// GenerateConfigKey returns 32 bytes from crypto/rand, encoded with standard
// padded base64 (44 characters).
// NOTE(review): NewConfigManager hashes the MEV_BOT_CONFIG_KEY string with
// SHA-256 instead of base64-decoding it, so this output is presumably meant
// to be used verbatim as that variable's value. Confirm.
func GenerateConfigKey() (string, error) {
	const keyBytes = 32 // 256 bits of randomness

	var raw [keyBytes]byte
	if _, err := rand.Read(raw[:]); err != nil {
		return "", fmt.Errorf("failed to generate random key: %w", err)
	}
	return base64.StdEncoding.EncodeToString(raw[:]), nil
}
|
|
|
|
// ConfigValidator provides validation utilities
|
|
type ConfigValidator struct {
|
|
logger *logger.Logger
|
|
}
|
|
|
|
// NewConfigValidator creates a new configuration validator
|
|
func NewConfigValidator(logger *logger.Logger) *ConfigValidator {
|
|
return &ConfigValidator{
|
|
logger: logger,
|
|
}
|
|
}
|
|
|
|
// ValidateURL validates that a URL is properly formatted and uses HTTPS
|
|
func (cv *ConfigValidator) ValidateURL(url string) error {
|
|
if url == "" {
|
|
return errors.New("URL cannot be empty")
|
|
}
|
|
|
|
if !strings.HasPrefix(url, "https://") && !strings.HasPrefix(url, "wss://") {
|
|
return errors.New("URL must use HTTPS or WSS protocol")
|
|
}
|
|
|
|
// Additional validation could go here (DNS lookup, connection test, etc.)
|
|
return nil
|
|
}
|
|
|
|
// ValidateAPIKey validates that an API key meets minimum security requirements
|
|
func (cv *ConfigValidator) ValidateAPIKey(key string) error {
|
|
if key == "" {
|
|
return errors.New("API key cannot be empty")
|
|
}
|
|
|
|
if len(key) < 32 {
|
|
return errors.New("API key must be at least 32 characters")
|
|
}
|
|
|
|
// Check for basic entropy (not all same character, contains mixed case, etc.)
|
|
if strings.Count(key, string(key[0])) == len(key) {
|
|
return errors.New("API key lacks sufficient entropy")
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// ValidateAddress validates an Ethereum address
|
|
func (cv *ConfigValidator) ValidateAddress(address string) error {
|
|
if address == "" {
|
|
return errors.New("address cannot be empty")
|
|
}
|
|
|
|
if !strings.HasPrefix(address, "0x") {
|
|
return errors.New("address must start with 0x")
|
|
}
|
|
|
|
if len(address) != 42 { // 0x + 40 hex chars
|
|
return errors.New("address must be 42 characters long")
|
|
}
|
|
|
|
// Validate hex format
|
|
for i, char := range address[2:] {
|
|
if !((char >= '0' && char <= '9') || (char >= 'a' && char <= 'f') || (char >= 'A' && char <= 'F')) {
|
|
return fmt.Errorf("invalid hex character at position %d: %c", i+2, char)
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|