//go:build integration && legacy && forked // +build integration,legacy,forked package test_main import ( "fmt" "math/big" "os" "strings" "testing" "time" "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/crypto" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/fraktal/mev-beta/internal/logger" "github.com/fraktal/mev-beta/pkg/security" ) // TestSecurityVulnerabilityFixes validates that critical security issues have been addressed func TestSecurityVulnerabilityFixes(t *testing.T) { // Set secure encryption key for testing os.Setenv("MEV_BOT_ENCRYPTION_KEY", "test-secure-encryption-key-32-chars") defer os.Unsetenv("MEV_BOT_ENCRYPTION_KEY") t.Run("NoHardcodedEncryptionKeys", func(t *testing.T) { // Verify that empty encryption key fails os.Unsetenv("MEV_BOT_ENCRYPTION_KEY") keyManagerConfig := &security.KeyManagerConfig{ KeystorePath: "test_keystore", EncryptionKey: "", // Should fail with empty key KeyRotationDays: 30, MaxSigningRate: 100, SessionTimeout: time.Hour, AuditLogPath: "test_audit.log", BackupPath: "test_backups", } log := logger.New("debug", "text", "") _, err := security.NewKeyManager(keyManagerConfig, log) // Should fail without encryption key assert.Error(t, err) assert.Contains(t, err.Error(), "encryption key") // Restore for other tests os.Setenv("MEV_BOT_ENCRYPTION_KEY", "test-secure-encryption-key-32-chars") }) t.Run("SecureKeyGeneration", func(t *testing.T) { keyManagerConfig := &security.KeyManagerConfig{ KeystorePath: "test_keystore", EncryptionKey: os.Getenv("MEV_BOT_ENCRYPTION_KEY"), KeyRotationDays: 30, MaxSigningRate: 100, SessionTimeout: time.Hour, AuditLogPath: "test_audit.log", BackupPath: "test_backups", } log := logger.New("debug", "text", "") keyManager, err := security.NewKeyManager(keyManagerConfig, log) require.NoError(t, err) // Test key generation with proper permissions permissions := security.KeyPermissions{ CanSign: true, CanTransfer: true, MaxTransferWei: big.NewInt(1000000000000000000), // 1 ETH AllowedContracts: []string{}, RequireConfirm: false, } address, err := keyManager.GenerateKey("trading", permissions) require.NoError(t, err) assert.NotEqual(t, common.Address{}, address) // Verify we can retrieve the active private key privateKey, err := keyManager.GetActivePrivateKey() require.NoError(t, err) assert.NotNil(t, privateKey) // Verify it's a valid ECDSA private key (privateKey is already *ecdsa.PrivateKey) assert.NotNil(t, privateKey.D, "Private key should have a valid D component") // Verify the key is not a hardcoded test key hardcodedKey, _ := crypto.HexToECDSA("0000000000000000000000000000000000000000000000000000000000000001") assert.NotEqual(t, hardcodedKey.D.String(), privateKey.D.String(), "Private key should not be hardcoded") // Clean up test files os.RemoveAll("test_keystore") os.Remove("test_audit.log") os.RemoveAll("test_backups") }) t.Run("RandomSaltGeneration", func(t *testing.T) { // This test would require exposing the salt generation function // or checking that keys generated with the same master key are different // due to different salts keyManagerConfig := &security.KeyManagerConfig{ KeystorePath: "test_keystore_1", EncryptionKey: "test-encryption-key-32-characters", KeyRotationDays: 30, MaxSigningRate: 100, SessionTimeout: time.Hour, AuditLogPath: "test_audit_1.log", BackupPath: "test_backups_1", } log := logger.New("debug", "text", "") keyManager1, err := security.NewKeyManager(keyManagerConfig, log) require.NoError(t, err) // Generate a key permissions := security.KeyPermissions{ CanSign: true, CanTransfer: true, MaxTransferWei: big.NewInt(1000000000000000000), AllowedContracts: []string{}, RequireConfirm: false, } address1, err := keyManager1.GenerateKey("trading", permissions) require.NoError(t, err) // Create second key manager with same master key keyManagerConfig.KeystorePath = "test_keystore_2" keyManagerConfig.AuditLogPath = "test_audit_2.log" keyManagerConfig.BackupPath = "test_backups_2" keyManager2, err := security.NewKeyManager(keyManagerConfig, log) require.NoError(t, err) address2, err := keyManager2.GenerateKey("trading", permissions) require.NoError(t, err) // Different key managers with same master key should generate different addresses // due to random salt usage assert.NotEqual(t, address1.Hex(), address2.Hex(), "Keys should be different due to random salt") // Clean up test files os.RemoveAll("test_keystore_1") os.RemoveAll("test_keystore_2") os.Remove("test_audit_1.log") os.Remove("test_audit_2.log") os.RemoveAll("test_backups_1") os.RemoveAll("test_backups_2") }) t.Run("KeyPermissionsEnforcement", func(t *testing.T) { keyManagerConfig := &security.KeyManagerConfig{ KeystorePath: "test_keystore", EncryptionKey: os.Getenv("MEV_BOT_ENCRYPTION_KEY"), KeyRotationDays: 30, MaxSigningRate: 100, SessionTimeout: time.Hour, AuditLogPath: "test_audit.log", BackupPath: "test_backups", } log := logger.New("debug", "text", "") keyManager, err := security.NewKeyManager(keyManagerConfig, log) require.NoError(t, err) // Create restrictive permissions restrictivePermissions := security.KeyPermissions{ CanSign: false, // Cannot sign CanTransfer: false, // Cannot transfer MaxTransferWei: big.NewInt(100000), // Very low limit AllowedContracts: []string{}, // No contracts allowed RequireConfirm: true, // Requires confirmation } address, err := keyManager.GenerateKey("restricted", restrictivePermissions) require.NoError(t, err) assert.NotEqual(t, common.Address{}, address) // Verify permissions are stored correctly // Note: This would require exposing permission checking methods // For now, we just verify the key was created successfully // Clean up test files os.RemoveAll("test_keystore") os.Remove("test_audit.log") os.RemoveAll("test_backups") }) } // TestInputValidationSecurity validates input validation fixes func TestInputValidationSecurity(t *testing.T) { t.Run("AmountValidation", func(t *testing.T) { // Test zero amount err := validateAmount(big.NewInt(0)) assert.Error(t, err) assert.Contains(t, err.Error(), "must be greater than zero") // Test negative amount err = validateAmount(big.NewInt(-1)) assert.Error(t, err) assert.Contains(t, err.Error(), "must be greater than zero") // Test excessive amount (potential overflow) excessiveAmount := new(big.Int).Exp(big.NewInt(10), big.NewInt(30), nil) err = validateAmount(excessiveAmount) assert.Error(t, err) assert.Contains(t, err.Error(), "exceeds maximum allowed value") // Test valid amount validAmount := big.NewInt(1000000000000000000) // 1 ETH err = validateAmount(validAmount) assert.NoError(t, err) }) t.Run("AddressValidation", func(t *testing.T) { // Test zero address zeroAddr := common.Address{} assert.False(t, isValidAddress(zeroAddr), "Zero address should be invalid") // Test valid address validAddr := common.HexToAddress("0x1234567890123456789012345678901234567890") assert.True(t, isValidAddress(validAddr), "Valid address should pass validation") }) } // Helper functions for validation (these should be implemented in the actual codebase) func validateAmount(amount *big.Int) error { if amount == nil || amount.Sign() <= 0 { return fmt.Errorf("amount must be greater than zero") } // Check for maximum amount to prevent overflow (more conservative limit) maxAmount := new(big.Int).Exp(big.NewInt(10), big.NewInt(28), nil) // 10^28 wei if amount.Cmp(maxAmount) > 0 { return fmt.Errorf("amount exceeds maximum allowed value") } return nil } func isValidAddress(addr common.Address) bool { return addr != (common.Address{}) } // TestRPCEndpointValidation validates RPC security fixes func TestRPCEndpointValidation(t *testing.T) { testCases := []struct { name string endpoint string shouldError bool errorMsg string }{ { name: "Valid HTTPS endpoint", endpoint: "https://arbitrum-mainnet.core.chainstack.com/test", shouldError: false, }, { name: "Valid WSS endpoint", endpoint: "wss://arbitrum-mainnet.core.chainstack.com/test", shouldError: false, }, { name: "Empty endpoint", endpoint: "", shouldError: true, errorMsg: "cannot be empty", }, { name: "Invalid scheme", endpoint: "ftp://invalid.com", shouldError: true, errorMsg: "invalid RPC scheme", }, { name: "Localhost without override", endpoint: "http://localhost:8545", shouldError: true, errorMsg: "localhost RPC endpoints not allowed", }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { // Clear localhost override os.Unsetenv("MEV_BOT_ALLOW_LOCALHOST") err := validateRPCEndpoint(tc.endpoint) if tc.shouldError { assert.Error(t, err) if tc.errorMsg != "" && err != nil { assert.Contains(t, err.Error(), tc.errorMsg) } } else { assert.NoError(t, err) } }) } // Test localhost with override t.Run("Localhost with override", func(t *testing.T) { os.Setenv("MEV_BOT_ALLOW_LOCALHOST", "true") defer os.Unsetenv("MEV_BOT_ALLOW_LOCALHOST") err := validateRPCEndpoint("http://localhost:8545") assert.NoError(t, err) }) } // validateRPCEndpoint - simplified version for testing func validateRPCEndpoint(endpoint string) error { if endpoint == "" { return fmt.Errorf("RPC endpoint cannot be empty") } // Parse the URL to validate the scheme if endpoint[0] == ':' || endpoint[0] == '/' { return fmt.Errorf("invalid URL scheme") } // Check for valid schemes if !(strings.HasPrefix(endpoint, "https://") || strings.HasPrefix(endpoint, "wss://") || strings.HasPrefix(endpoint, "http://") || strings.HasPrefix(endpoint, "ws://")) { return fmt.Errorf("invalid RPC scheme") } // Check for localhost restrictions if strings.Contains(endpoint, "localhost") || strings.Contains(endpoint, "127.0.0.1") { allowLocalhost := os.Getenv("MEV_BOT_ALLOW_LOCALHOST") if allowLocalhost != "true" { return fmt.Errorf("localhost RPC endpoints not allowed") } } return nil }