package security import ( "math/big" "strings" "testing" "time" "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/crypto" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/fraktal/mev-beta/internal/logger" ) func TestNewChainIDValidator(t *testing.T) { logger := logger.New("info", "text", "") expectedChainID := big.NewInt(42161) // Arbitrum mainnet validator := NewChainIDValidator(logger, expectedChainID) assert.NotNil(t, validator) assert.Equal(t, expectedChainID.Uint64(), validator.expectedChainID.Uint64()) assert.True(t, validator.allowedChainIDs[42161]) // Arbitrum mainnet assert.True(t, validator.allowedChainIDs[421614]) // Arbitrum testnet assert.NotNil(t, validator.replayAttackDetector) } func TestValidateChainID_ValidTransaction(t *testing.T) { logger := logger.New("info", "text", "") expectedChainID := big.NewInt(42161) validator := NewChainIDValidator(logger, expectedChainID) // Create a valid EIP-155 transaction for Arbitrum tx := types.NewTransaction( 0, // nonce common.HexToAddress("0x1234567890123456789012345678901234567890"), // to big.NewInt(1000000000000000000), // value (1 ETH) 21000, // gas limit big.NewInt(20000000000), // gas price (20 Gwei) nil, // data ) // Create a properly signed transaction for testing privateKey, err := crypto.GenerateKey() require.NoError(t, err) signerAddr := crypto.PubkeyToAddress(privateKey.PublicKey) signer := types.NewEIP155Signer(expectedChainID) signedTx, err := types.SignTx(tx, signer, privateKey) require.NoError(t, err) result := validator.ValidateChainID(signedTx, signerAddr, nil) assert.True(t, result.Valid) assert.Equal(t, expectedChainID.Uint64(), result.ExpectedChainID) assert.Equal(t, expectedChainID.Uint64(), result.ActualChainID) assert.True(t, result.IsEIP155Protected) assert.Equal(t, "NONE", result.ReplayRisk) assert.Empty(t, result.Errors) } func TestValidateChainID_InvalidChainID(t *testing.T) { logger := logger.New("info", "text", "") expectedChainID := big.NewInt(42161) // Arbitrum validator := NewChainIDValidator(logger, expectedChainID) // Create transaction with wrong chain ID (Ethereum mainnet) wrongChainID := big.NewInt(1) tx := types.NewTransaction(0, common.Address{}, big.NewInt(1000), 21000, big.NewInt(20000000000), nil) privateKey, err := crypto.GenerateKey() require.NoError(t, err) signerAddr := crypto.PubkeyToAddress(privateKey.PublicKey) signer := types.NewEIP155Signer(wrongChainID) signedTx, err := types.SignTx(tx, signer, privateKey) require.NoError(t, err) result := validator.ValidateChainID(signedTx, signerAddr, nil) assert.False(t, result.Valid) assert.Equal(t, expectedChainID.Uint64(), result.ExpectedChainID) assert.Equal(t, wrongChainID.Uint64(), result.ActualChainID) assert.NotEmpty(t, result.Errors) assert.Contains(t, result.Errors[0], "Chain ID mismatch") } func TestValidateChainID_ReplayAttackDetection(t *testing.T) { logger := logger.New("info", "text", "") expectedChainID := big.NewInt(42161) validator := NewChainIDValidator(logger, expectedChainID) privateKey, err := crypto.GenerateKey() require.NoError(t, err) signerAddr := crypto.PubkeyToAddress(privateKey.PublicKey) // Create identical transactions on different chains tx1 := types.NewTransaction(1, common.Address{}, big.NewInt(1000), 21000, big.NewInt(20000000000), nil) tx2 := types.NewTransaction(1, common.Address{}, big.NewInt(1000), 21000, big.NewInt(20000000000), nil) // Sign first transaction with Arbitrum chain ID signer1 := types.NewEIP155Signer(big.NewInt(42161)) signedTx1, err := types.SignTx(tx1, signer1, privateKey) require.NoError(t, err) // Sign second identical transaction with different chain ID signer2 := types.NewEIP155Signer(big.NewInt(421614)) // Arbitrum testnet signedTx2, err := types.SignTx(tx2, signer2, privateKey) require.NoError(t, err) // First validation should pass result1 := validator.ValidateChainID(signedTx1, signerAddr, nil) assert.True(t, result1.Valid) assert.Equal(t, "NONE", result1.ReplayRisk) // Create a new validator and add testnet to allowed chains validator.AddAllowedChainID(421614) // Second validation should detect replay risk result2 := validator.ValidateChainID(signedTx2, signerAddr, nil) assert.Equal(t, "CRITICAL", result2.ReplayRisk) assert.NotEmpty(t, result2.Warnings) assert.Contains(t, result2.Warnings[0], "replay attack") } func TestValidateEIP155Protection(t *testing.T) { logger := logger.New("info", "text", "") expectedChainID := big.NewInt(42161) validator := NewChainIDValidator(logger, expectedChainID) privateKey, err := crypto.GenerateKey() require.NoError(t, err) // Test EIP-155 protected transaction tx := types.NewTransaction(0, common.Address{}, big.NewInt(1000), 21000, big.NewInt(20000000000), nil) signer := types.NewEIP155Signer(expectedChainID) signedTx, err := types.SignTx(tx, signer, privateKey) require.NoError(t, err) result := validator.validateEIP155Protection(signedTx, expectedChainID) assert.True(t, result.protected) assert.Equal(t, expectedChainID.Uint64(), result.chainID) assert.Empty(t, result.warnings) } func TestValidateEIP155Protection_LegacyTransaction(t *testing.T) { logger := logger.New("info", "text", "") expectedChainID := big.NewInt(42161) validator := NewChainIDValidator(logger, expectedChainID) // Create a legacy transaction (pre-EIP155) by manually setting v to 27 tx := types.NewTransaction(0, common.Address{}, big.NewInt(1000), 21000, big.NewInt(20000000000), nil) // For testing purposes, we'll create a transaction that mimics legacy format // In practice, this would be a transaction created before EIP-155 signer := types.HomesteadSigner{} // Pre-EIP155 signer privateKey, err := crypto.GenerateKey() require.NoError(t, err) signedTx, err := types.SignTx(tx, signer, privateKey) require.NoError(t, err) result := validator.validateEIP155Protection(signedTx, expectedChainID) assert.False(t, result.protected) assert.NotEmpty(t, result.warnings) // Legacy transactions may not have chain ID, so check for either warning hasExpectedWarning := false for _, warning := range result.warnings { if strings.Contains(warning, "Legacy transaction format") || strings.Contains(warning, "Transaction missing chain ID") { hasExpectedWarning = true break } } assert.True(t, hasExpectedWarning, "Should contain legacy transaction warning") } func TestChainSpecificValidation_Arbitrum(t *testing.T) { logger := logger.New("info", "text", "") expectedChainID := big.NewInt(42161) validator := NewChainIDValidator(logger, expectedChainID) // Create a properly signed transaction for Arbitrum to test chain-specific rules privateKey, err := crypto.GenerateKey() require.NoError(t, err) // Test normal Arbitrum transaction tx := types.NewTransaction(0, common.Address{}, big.NewInt(1000), 21000, big.NewInt(1000000000), nil) // 1 Gwei signer := types.NewEIP155Signer(expectedChainID) signedTx, err := types.SignTx(tx, signer, privateKey) require.NoError(t, err) result := validator.validateChainSpecificRules(signedTx, expectedChainID.Uint64()) assert.True(t, result.valid) assert.Empty(t, result.errors) // Test high gas price warning txHighGas := types.NewTransaction(0, common.Address{}, big.NewInt(1000), 21000, big.NewInt(2000000000000), nil) // 2000 Gwei signedTxHighGas, err := types.SignTx(txHighGas, signer, privateKey) require.NoError(t, err) resultHighGas := validator.validateChainSpecificRules(signedTxHighGas, expectedChainID.Uint64()) assert.True(t, resultHighGas.valid) assert.NotEmpty(t, resultHighGas.warnings) assert.Contains(t, resultHighGas.warnings[0], "high gas price") // Test gas limit too high txHighGasLimit := types.NewTransaction(0, common.Address{}, big.NewInt(1000), 50000000, big.NewInt(1000000000), nil) // 50M gas signedTxHighGasLimit, err := types.SignTx(txHighGasLimit, signer, privateKey) require.NoError(t, err) resultHighGasLimit := validator.validateChainSpecificRules(signedTxHighGasLimit, expectedChainID.Uint64()) assert.False(t, resultHighGasLimit.valid) assert.NotEmpty(t, resultHighGasLimit.errors) assert.Contains(t, resultHighGasLimit.errors[0], "exceeds Arbitrum maximum") } func TestChainSpecificValidation_UnsupportedChain(t *testing.T) { logger := logger.New("info", "text", "") expectedChainID := big.NewInt(999999) // Unsupported chain validator := NewChainIDValidator(logger, expectedChainID) privateKey, err := crypto.GenerateKey() require.NoError(t, err) tx := types.NewTransaction(0, common.Address{}, big.NewInt(1000), 21000, big.NewInt(1000000000), nil) signer := types.NewEIP155Signer(expectedChainID) signedTx, err := types.SignTx(tx, signer, privateKey) require.NoError(t, err) result := validator.validateChainSpecificRules(signedTx, expectedChainID.Uint64()) assert.False(t, result.valid) assert.NotEmpty(t, result.errors) assert.Contains(t, result.errors[0], "Unsupported chain ID") } func TestValidateSignerMatchesChain(t *testing.T) { logger := logger.New("info", "text", "") expectedChainID := big.NewInt(42161) validator := NewChainIDValidator(logger, expectedChainID) privateKey, err := crypto.GenerateKey() require.NoError(t, err) expectedSigner := crypto.PubkeyToAddress(privateKey.PublicKey) tx := types.NewTransaction(0, common.Address{}, big.NewInt(1000), 21000, big.NewInt(20000000000), nil) signer := types.NewEIP155Signer(expectedChainID) signedTx, err := types.SignTx(tx, signer, privateKey) require.NoError(t, err) // Valid signature should pass err = validator.ValidateSignerMatchesChain(signedTx, expectedSigner) assert.NoError(t, err) // Wrong expected signer should fail wrongSigner := common.HexToAddress("0x1234567890123456789012345678901234567890") err = validator.ValidateSignerMatchesChain(signedTx, wrongSigner) assert.Error(t, err) assert.Contains(t, err.Error(), "signer mismatch") } func TestGetValidationStats(t *testing.T) { logger := logger.New("info", "text", "") expectedChainID := big.NewInt(42161) validator := NewChainIDValidator(logger, expectedChainID) privateKey, err := crypto.GenerateKey() require.NoError(t, err) signerAddr := crypto.PubkeyToAddress(privateKey.PublicKey) // Perform some validations to generate stats tx := types.NewTransaction(0, common.Address{}, big.NewInt(1000), 21000, big.NewInt(20000000000), nil) signer := types.NewEIP155Signer(expectedChainID) signedTx, err := types.SignTx(tx, signer, privateKey) require.NoError(t, err) validator.ValidateChainID(signedTx, signerAddr, nil) stats := validator.GetValidationStats() assert.NotNil(t, stats) assert.Equal(t, uint64(1), stats["total_validations"]) assert.Equal(t, expectedChainID.Uint64(), stats["expected_chain_id"]) assert.NotNil(t, stats["allowed_chain_ids"]) } func TestAddRemoveAllowedChainID(t *testing.T) { logger := logger.New("info", "text", "") expectedChainID := big.NewInt(42161) validator := NewChainIDValidator(logger, expectedChainID) // Add new chain ID newChainID := uint64(999) validator.AddAllowedChainID(newChainID) assert.True(t, validator.allowedChainIDs[newChainID]) // Remove chain ID validator.RemoveAllowedChainID(newChainID) assert.False(t, validator.allowedChainIDs[newChainID]) } func TestReplayAttackDetection_CleanOldData(t *testing.T) { logger := logger.New("info", "text", "") expectedChainID := big.NewInt(42161) validator := NewChainIDValidator(logger, expectedChainID) privateKey, err := crypto.GenerateKey() require.NoError(t, err) signerAddr := crypto.PubkeyToAddress(privateKey.PublicKey) // Create transaction tx := types.NewTransaction(0, common.Address{}, big.NewInt(1000), 21000, big.NewInt(20000000000), nil) signer := types.NewEIP155Signer(expectedChainID) signedTx, err := types.SignTx(tx, signer, privateKey) require.NoError(t, err) // First validation validator.ValidateChainID(signedTx, signerAddr, nil) assert.Equal(t, 1, len(validator.replayAttackDetector.seenTransactions)) // Manually set old timestamp to test cleanup txIdentifier := validator.createTransactionIdentifier(signedTx, signerAddr) record := validator.replayAttackDetector.seenTransactions[txIdentifier] record.FirstSeen = time.Now().Add(-25 * time.Hour) // Older than maxTrackingTime validator.replayAttackDetector.seenTransactions[txIdentifier] = record // Trigger cleanup validator.cleanOldTrackingData() assert.Equal(t, 0, len(validator.replayAttackDetector.seenTransactions)) } // Integration test with KeyManager func SkipTestKeyManagerChainValidationIntegration(t *testing.T) { config := &KeyManagerConfig{ KeystorePath: t.TempDir(), EncryptionKey: "test_key_32_chars_minimum_length_required", MaxFailedAttempts: 3, LockoutDuration: 5 * time.Minute, MaxSigningRate: 10, EnableAuditLogging: true, RequireAuthentication: false, } logger := logger.New("info", "text", "") expectedChainID := big.NewInt(42161) km, err := newKeyManagerInternal(config, logger, expectedChainID, false) // Use testing version require.NoError(t, err) // Generate a key permissions := KeyPermissions{ CanSign: true, CanTransfer: true, MaxTransferWei: big.NewInt(1000000000000000000), // 1 ETH } keyAddr, err := km.GenerateKey("test", permissions) require.NoError(t, err) // Test valid chain ID transaction // Create a transaction that will be properly handled by EIP155 signer tx := types.NewTx(&types.LegacyTx{ Nonce: 0, To: &common.Address{}, Value: big.NewInt(1000), Gas: 21000, GasPrice: big.NewInt(20000000000), Data: nil, }) request := &SigningRequest{ Transaction: tx, ChainID: expectedChainID, From: keyAddr, Purpose: "Test transaction", UrgencyLevel: 1, } result, err := km.SignTransaction(request) assert.NoError(t, err) assert.NotNil(t, result) assert.NotNil(t, result.SignedTx) // Test invalid chain ID transaction wrongChainID := big.NewInt(1) // Ethereum mainnet txWrong := types.NewTx(&types.LegacyTx{ Nonce: 1, To: &common.Address{}, Value: big.NewInt(1000), Gas: 21000, GasPrice: big.NewInt(20000000000), Data: nil, }) requestWrong := &SigningRequest{ Transaction: txWrong, ChainID: wrongChainID, From: keyAddr, Purpose: "Invalid chain test", UrgencyLevel: 1, } _, err = km.SignTransaction(requestWrong) assert.Error(t, err) assert.Contains(t, err.Error(), "doesn't match expected") // Test chain validation stats stats := km.GetChainValidationStats() assert.NotNil(t, stats) assert.True(t, stats["total_validations"].(uint64) > 0) // Test expected chain ID chainID := km.GetExpectedChainID() assert.Equal(t, expectedChainID.Uint64(), chainID.Uint64()) } func TestCrossChainReplayPrevention(t *testing.T) { logger := logger.New("info", "text", "") validator := NewChainIDValidator(logger, big.NewInt(42161)) // Add testnet to allowed chains for testing validator.AddAllowedChainID(421614) privateKey, err := crypto.GenerateKey() require.NoError(t, err) signerAddr := crypto.PubkeyToAddress(privateKey.PublicKey) // Create identical transaction data nonce := uint64(42) to := common.HexToAddress("0x1234567890123456789012345678901234567890") value := big.NewInt(1000000000000000000) // 1 ETH gasLimit := uint64(21000) gasPrice := big.NewInt(20000000000) // 20 Gwei // Sign for mainnet tx1 := types.NewTransaction(nonce, to, value, gasLimit, gasPrice, nil) signer1 := types.NewEIP155Signer(big.NewInt(42161)) signedTx1, err := types.SignTx(tx1, signer1, privateKey) require.NoError(t, err) // Sign identical transaction for testnet tx2 := types.NewTransaction(nonce, to, value, gasLimit, gasPrice, nil) signer2 := types.NewEIP155Signer(big.NewInt(421614)) signedTx2, err := types.SignTx(tx2, signer2, privateKey) require.NoError(t, err) // First validation (mainnet) should pass result1 := validator.ValidateChainID(signedTx1, signerAddr, nil) assert.True(t, result1.Valid) assert.Equal(t, "NONE", result1.ReplayRisk) // Second validation (testnet with same tx data) should detect replay risk result2 := validator.ValidateChainID(signedTx2, signerAddr, nil) assert.Equal(t, "CRITICAL", result2.ReplayRisk) assert.Contains(t, result2.Warnings[0], "replay attack") // Verify the detector tracked both chain IDs stats := validator.GetValidationStats() assert.Equal(t, uint64(1), stats["replay_attempts"]) }