package security

import (
	"fmt"
	"math/big"
	"sync"
	"time"

	"github.com/ethereum/go-ethereum/common"
	"github.com/ethereum/go-ethereum/core/types"

	"github.com/fraktal/mev-beta/internal/logger"
)

// ChainIDValidator provides comprehensive chain ID validation and EIP-155 replay protection.
//
// The allowlist and the statistics counters are guarded by mu. Replay-tracking state lives
// in replayAttackDetector and is guarded by that detector's own mutex. Where both locks are
// taken, mu is acquired first.
type ChainIDValidator struct {
	logger               *logger.Logger
	expectedChainID      *big.Int
	allowedChainIDs      map[uint64]bool
	replayAttackDetector *ReplayAttackDetector
	mu                   sync.RWMutex

	// Chain ID validation statistics
	validationCount    uint64
	mismatchCount      uint64
	replayAttemptCount uint64
	lastMismatchTime   time.Time
}

// normalizeChainID returns a fresh copy of the chain ID that validation should use.
//
// If txChainID is a real chain ID, it is returned as-is (copied). If it is missing or a
// placeholder (see isPlaceholderChainID), override is used when non-nil; otherwise the
// validator's expected chain ID is used.
func (cv *ChainIDValidator) normalizeChainID(txChainID *big.Int, override *big.Int) *big.Int {
	if !isPlaceholderChainID(txChainID) {
		return new(big.Int).Set(txChainID)
	}
	if override != nil {
		// Use override when transaction chain ID is missing or placeholder
		return new(big.Int).Set(override)
	}
	return new(big.Int).Set(cv.expectedChainID)
}

// isPlaceholderChainID reports whether id should be treated as "no chain ID".
//
// Nil and zero count as missing. Values of 62 bits or more are also treated as missing;
// such values are assumed to be legacy placeholders rather than real chain IDs.
func isPlaceholderChainID(id *big.Int) bool {
	return id == nil || id.Sign() == 0 || id.BitLen() >= 62
}

// ReplayAttackDetector tracks potential replay attacks
type ReplayAttackDetector struct {
	// Track transaction hashes across different chain IDs to detect replay attempts
	seenTransactions map[string]ChainIDRecord
	maxTrackingTime  time.Duration
	mu               sync.Mutex
}

// ChainIDRecord stores information about a transaction's chain ID usage
type ChainIDRecord struct {
	ChainID        uint64
	FirstSeen      time.Time
	Count          int
	From           common.Address
	AlertTriggered bool
}

// ChainValidationResult contains comprehensive chain ID validation results
type ChainValidationResult struct {
	Valid             bool                   `json:"valid"`
	ExpectedChainID   uint64                 `json:"expected_chain_id"`
	ActualChainID     uint64                 `json:"actual_chain_id"`
	IsEIP155Protected bool                   `json:"is_eip155_protected"`
	ReplayRisk        string                 `json:"replay_risk"` // NONE, LOW, MEDIUM, HIGH, CRITICAL
	Warnings          []string               `json:"warnings"`
	Errors            []string               `json:"errors"`
	SecurityMetadata  map[string]interface{} `json:"security_metadata"`
}

// NewChainIDValidator creates a new chain ID validator.
//
// The allowlist starts with Ethereum mainnet (1), Arbitrum One (42161) and Arbitrum
// Sepolia (421614). Replay tracking entries are retained for 24 hours.
func NewChainIDValidator(logger *logger.Logger, expectedChainID *big.Int) *ChainIDValidator {
	return &ChainIDValidator{
		logger:          logger,
		expectedChainID: expectedChainID,
		allowedChainIDs: map[uint64]bool{
			1:      true, // Ethereum mainnet (for testing)
			42161:  true, // Arbitrum One mainnet
			421614: true, // Arbitrum Sepolia testnet (for testing)
		},
		replayAttackDetector: &ReplayAttackDetector{
			seenTransactions: make(map[string]ChainIDRecord),
			maxTrackingTime:  24 * time.Hour, // Track for 24 hours
		},
	}
}

// ValidateChainID performs comprehensive chain ID validation.
//
// The chain ID is taken from tx, or from overrideChainID or the expected chain ID when
// the transaction's own value is a placeholder. The following checks are then run:
//   - the chain ID must equal the expected chain ID;
//   - EIP-155 / typed-transaction replay protection must be present;
//   - the chain ID must be allowlisted;
//   - cross-chain replays are detected;
//   - chain-specific rules are applied.
//
// All findings are collected into the returned result. It holds cv.mu for the duration,
// which also serialises the statistics updates made by the helpers it calls.
func (cv *ChainIDValidator) ValidateChainID(tx *types.Transaction, signerAddr common.Address, overrideChainID *big.Int) *ChainValidationResult {
	actualChainID := cv.normalizeChainID(tx.ChainId(), overrideChainID)
	expected := cv.expectedChainID.Uint64()
	actual := actualChainID.Uint64()

	result := &ChainValidationResult{
		Valid:            true,
		ExpectedChainID:  expected,
		ActualChainID:    actual,
		SecurityMetadata: make(map[string]interface{}),
	}

	cv.mu.Lock()
	defer cv.mu.Unlock()

	cv.validationCount++

	// 1. Basic Chain ID Validation
	if actual != expected {
		result.Valid = false
		result.Errors = append(result.Errors,
			fmt.Sprintf("Chain ID mismatch: expected %d, got %d", expected, actual))
		cv.mismatchCount++
		cv.lastMismatchTime = time.Now()

		// Log security alert
		cv.logger.Warn(fmt.Sprintf("SECURITY ALERT: Chain ID mismatch detected from %s - Expected: %d, Got: %d",
			signerAddr.Hex(), expected, actual))
	}

	// 2. EIP-155 Replay Protection Verification
	eip155Result := cv.validateEIP155Protection(tx, actualChainID)
	result.IsEIP155Protected = eip155Result.protected
	if !eip155Result.protected {
		result.Warnings = append(result.Warnings, "Transaction lacks EIP-155 replay protection")
		result.ReplayRisk = "HIGH"
	} else {
		result.ReplayRisk = "NONE"
	}

	// 3. Chain ID Allowlist Validation
	if !cv.allowedChainIDs[actual] {
		result.Valid = false
		result.Errors = append(result.Errors,
			fmt.Sprintf("Chain ID %d is not in the allowed list", actual))
		cv.logger.Error(fmt.Sprintf("SECURITY ALERT: Attempted transaction on unauthorized chain %d from %s",
			actual, signerAddr.Hex()))
	}

	// 4. Replay Attack Detection (a non-NONE detector verdict overrides the EIP-155 risk)
	replayResult := cv.detectReplayAttack(tx, signerAddr, actual)
	if replayResult.riskLevel != "NONE" {
		result.ReplayRisk = replayResult.riskLevel
		result.Warnings = append(result.Warnings, replayResult.warnings...)
		if replayResult.riskLevel == "CRITICAL" {
			result.Valid = false
			result.Errors = append(result.Errors, "Potential replay attack detected")
		}
	}

	// 5. Chain-specific Validation
	chainSpecificResult := cv.validateChainSpecificRules(tx, actual)
	if !chainSpecificResult.valid {
		result.Errors = append(result.Errors, chainSpecificResult.errors...)
		result.Valid = false
	}
	result.Warnings = append(result.Warnings, chainSpecificResult.warnings...)

	// 6. Add security metadata
	result.SecurityMetadata["validation_timestamp"] = time.Now().Unix()
	result.SecurityMetadata["total_validations"] = cv.validationCount
	result.SecurityMetadata["total_mismatches"] = cv.mismatchCount
	result.SecurityMetadata["signer_address"] = signerAddr.Hex()
	result.SecurityMetadata["transaction_hash"] = tx.Hash().Hex()

	// Log validation result for audit
	if !result.Valid {
		cv.logger.Error(fmt.Sprintf("Chain validation FAILED for tx %s from %s: %v",
			tx.Hash().Hex(), signerAddr.Hex(), result.Errors))
	}

	return result
}

// EIP155Result contains EIP-155 validation results
type EIP155Result struct {
	protected bool
	chainID   uint64
	warnings  []string
}

// validateEIP155Protection reports whether tx's signature is bound to a chain ID.
//
// Transactions whose own chain ID is missing or a placeholder are reported as unprotected.
//
// Typed (EIP-2718) transactions sign a payload that already contains the chain ID. For
// them, V is only the signature y-parity, so protection is accepted when V is 0 or 1.
//
// Legacy transactions must carry an EIP-155 V value of chainID*2+35 or chainID*2+36,
// where chainID is normalizedChainID.
func (cv *ChainIDValidator) validateEIP155Protection(tx *types.Transaction, normalizedChainID *big.Int) EIP155Result {
	result := EIP155Result{
		protected: false,
		warnings:  make([]string, 0),
	}

	// Check if transaction has a valid chain ID (EIP-155 requirement)
	if isPlaceholderChainID(tx.ChainId()) {
		result.warnings = append(result.warnings, "Transaction missing chain ID (pre-EIP155)")
		return result
	}

	chainID := normalizedChainID.Uint64()
	result.chainID = chainID

	v, _, _ := tx.RawSignatureValues()

	// Typed transactions (access-list, dynamic-fee, ...) are replay-protected by
	// construction. Applying the legacy formula below would wrongly flag all of them.
	if tx.Type() != types.LegacyTxType {
		if v != nil && v.Sign() >= 0 && v.Cmp(big.NewInt(1)) <= 0 {
			result.protected = true
		} else {
			result.warnings = append(result.warnings,
				fmt.Sprintf("Invalid signature parity for typed transaction (type %d): got v=%v, expected 0 or 1", tx.Type(), v))
		}
		return result
	}

	// Legacy transaction: EIP-155 requires v = CHAIN_ID * 2 + 35 or v = CHAIN_ID * 2 + 36.
	// chainID is below 2^62 (placeholders were rejected above), so these cannot overflow.
	expectedV1 := chainID*2 + 35
	expectedV2 := chainID*2 + 36
	actualV := v.Uint64()

	switch {
	case actualV == expectedV1 || actualV == expectedV2:
		result.protected = true
	case actualV == 27 || actualV == 28:
		// Pre-EIP-155 legacy signature (v = 27 or 28)
		result.warnings = append(result.warnings, "Legacy transaction format detected (not EIP-155 protected)")
	default:
		result.warnings = append(result.warnings,
			fmt.Sprintf("Invalid v value for EIP-155: got %d, expected %d or %d", actualV, expectedV1, expectedV2))
	}

	return result
}

// ReplayResult contains replay attack detection results
type ReplayResult struct {
	riskLevel string
	warnings  []string
}

// detectReplayAttack detects potential cross-chain replay attacks.
//
// The identifier from createTransactionIdentifier is looked up in the detector.
//   - Seen before on a different chain: the result is CRITICAL and the record is re-pointed
//     to the current chain.
//   - Seen on the same chain: the record's count is bumped, and the result becomes MEDIUM
//     once the count exceeds 3.
//   - Unseen: the transaction is recorded and the result is NONE.
//
// ValidateChainID calls this while holding cv.mu, which guards replayAttemptCount.
func (cv *ChainIDValidator) detectReplayAttack(tx *types.Transaction, signerAddr common.Address, normalizedChainID uint64) ReplayResult {
	result := ReplayResult{
		riskLevel: "NONE",
		warnings:  make([]string, 0),
	}

	// Clean old tracking data (takes and releases detector.mu itself)
	cv.cleanOldTrackingData()

	// Identify "the same" transaction by signer, nonce, to, value and a calldata prefix
	txIdentifier := cv.createTransactionIdentifier(tx, signerAddr)

	detector := cv.replayAttackDetector
	detector.mu.Lock()
	defer detector.mu.Unlock()

	record, exists := detector.seenTransactions[txIdentifier]
	if !exists {
		// First time seeing this transaction
		detector.seenTransactions[txIdentifier] = ChainIDRecord{
			ChainID:        normalizedChainID,
			FirstSeen:      time.Now(),
			Count:          1,
			From:           signerAddr,
			AlertTriggered: false,
		}
		return result
	}

	currentChainID := normalizedChainID
	if record.ChainID != currentChainID {
		// Same transaction on different chain - CRITICAL replay risk
		result.riskLevel = "CRITICAL"
		result.warnings = append(result.warnings,
			fmt.Sprintf("Identical transaction detected on chain %d and %d - possible replay attack",
				record.ChainID, currentChainID))

		detector.seenTransactions[txIdentifier] = ChainIDRecord{
			ChainID:        currentChainID,
			FirstSeen:      record.FirstSeen,
			Count:          record.Count + 1,
			From:           signerAddr,
			AlertTriggered: true,
		}

		cv.replayAttemptCount++
		cv.logger.Error(fmt.Sprintf("CRITICAL SECURITY ALERT: Potential replay attack detected! "+
			"Transaction %s from %s seen on chains %d and %d",
			txIdentifier, signerAddr.Hex(), record.ChainID, currentChainID))
		return result
	}

	// Same transaction on same chain - possible retry or duplicate
	record.Count++
	if record.Count > 3 {
		result.riskLevel = "MEDIUM"
		result.warnings = append(result.warnings, "Multiple identical transactions detected")
	}
	detector.seenTransactions[txIdentifier] = record

	return result
}

// ChainSpecificResult contains chain-specific validation results
type ChainSpecificResult struct {
	valid    bool
	warnings []string
	errors   []string
}

// validateChainSpecificRules applies chain-specific validation rules.
//
// Arbitrum One (42161):
//   - gas price above 1000 gwei produces a warning;
//   - gas limit above 32,000,000 is an error.
//
// Arbitrum Sepolia (421614):
//   - values above 100 ETH produce a warning.
//
// Every other chain ID is reported as unsupported.
//
// NOTE(review): Ethereum mainnet (1) is in the default allowlist but has no case here, so
// it is always rejected as unsupported — confirm whether that is intended.
func (cv *ChainIDValidator) validateChainSpecificRules(tx *types.Transaction, chainID uint64) ChainSpecificResult {
	result := ChainSpecificResult{
		valid:    true,
		warnings: make([]string, 0),
		errors:   make([]string, 0),
	}

	switch chainID {
	case 42161: // Arbitrum One
		// Arbitrum-specific validations
		if tx.GasPrice() != nil && tx.GasPrice().Cmp(big.NewInt(1000000000000)) > 0 { // 1000 Gwei
			result.warnings = append(result.warnings, "Unusually high gas price for Arbitrum")
		}

		// Check for Arbitrum-specific gas limits
		if tx.Gas() > 32000000 { // Arbitrum block gas limit
			result.valid = false
			result.errors = append(result.errors, "Gas limit exceeds Arbitrum maximum")
		}

	case 421614: // Arbitrum Sepolia testnet
		// Testnet-specific validations
		if tx.Value() != nil && tx.Value().Cmp(new(big.Int).Mul(big.NewInt(100), big.NewInt(1e18))) > 0 { // 100 ETH
			result.warnings = append(result.warnings, "Large value transfer on testnet")
		}

	default:
		// Unknown or unsupported chain
		result.valid = false
		result.errors = append(result.errors, fmt.Sprintf("Unsupported chain ID: %d", chainID))
	}

	return result
}
transaction tracking func (cv *ChainIDValidator) createTransactionIdentifier(tx *types.Transaction, signerAddr common.Address) string { // Create identifier from key transaction fields that would be identical in a replay var toAddr string if tx.To() != nil { toAddr = tx.To().Hex() } else { toAddr = "0x0" // Contract creation } // Combine nonce, to, value, and first 32 bytes of data dataPrefix := "" if len(tx.Data()) > 0 { end := 32 if len(tx.Data()) < 32 { end = len(tx.Data()) } dataPrefix = common.Bytes2Hex(tx.Data()[:end]) } return fmt.Sprintf("%s:%d:%s:%s:%s", signerAddr.Hex(), tx.Nonce(), toAddr, tx.Value().String(), dataPrefix) } // cleanOldTrackingData removes old transaction tracking data func (cv *ChainIDValidator) cleanOldTrackingData() { detector := cv.replayAttackDetector detector.mu.Lock() defer detector.mu.Unlock() cutoff := time.Now().Add(-detector.maxTrackingTime) for identifier, record := range detector.seenTransactions { if record.FirstSeen.Before(cutoff) { delete(detector.seenTransactions, identifier) } } } // GetValidationStats returns validation statistics func (cv *ChainIDValidator) GetValidationStats() map[string]interface{} { cv.mu.RLock() defer cv.mu.RUnlock() detector := cv.replayAttackDetector detector.mu.Lock() trackingEntries := len(detector.seenTransactions) detector.mu.Unlock() return map[string]interface{}{ "total_validations": cv.validationCount, "chain_id_mismatches": cv.mismatchCount, "replay_attempts": cv.replayAttemptCount, "last_mismatch_time": cv.lastMismatchTime.Unix(), "expected_chain_id": cv.expectedChainID.Uint64(), "allowed_chain_ids": cv.getAllowedChainIDs(), "tracking_entries": trackingEntries, } } // getAllowedChainIDs returns a slice of allowed chain IDs func (cv *ChainIDValidator) getAllowedChainIDs() []uint64 { cv.mu.RLock() defer cv.mu.RUnlock() chainIDs := make([]uint64, 0, len(cv.allowedChainIDs)) for chainID := range cv.allowedChainIDs { chainIDs = append(chainIDs, chainID) } return chainIDs } // 
AddAllowedChainID adds a chain ID to the allowed list func (cv *ChainIDValidator) AddAllowedChainID(chainID uint64) { cv.mu.Lock() defer cv.mu.Unlock() cv.allowedChainIDs[chainID] = true cv.logger.Info(fmt.Sprintf("Added chain ID %d to allowed list", chainID)) } // RemoveAllowedChainID removes a chain ID from the allowed list func (cv *ChainIDValidator) RemoveAllowedChainID(chainID uint64) { cv.mu.Lock() defer cv.mu.Unlock() delete(cv.allowedChainIDs, chainID) cv.logger.Info(fmt.Sprintf("Removed chain ID %d from allowed list", chainID)) } // ValidateSignerMatchesChain verifies that the signer's address matches the expected chain func (cv *ChainIDValidator) ValidateSignerMatchesChain(tx *types.Transaction, expectedSigner common.Address) error { // Create appropriate signer based on transaction type var signer types.Signer switch tx.Type() { case types.LegacyTxType: signer = types.NewEIP155Signer(tx.ChainId()) case types.DynamicFeeTxType: signer = types.NewLondonSigner(tx.ChainId()) default: return fmt.Errorf("unsupported transaction type: %d", tx.Type()) } // Recover the signer from the transaction recoveredSigner, err := types.Sender(signer, tx) if err != nil { return fmt.Errorf("failed to recover signer: %w", err) } // Verify the signer matches expected if recoveredSigner != expectedSigner { return fmt.Errorf("signer mismatch: expected %s, got %s", expectedSigner.Hex(), recoveredSigner.Hex()) } // Additional validation: ensure the signature is valid for this chain if !cv.verifySignatureForChain(tx, recoveredSigner) { return fmt.Errorf("signature invalid for chain ID %d", tx.ChainId().Uint64()) } return nil } // verifySignatureForChain verifies the signature is valid for the specific chain func (cv *ChainIDValidator) verifySignatureForChain(tx *types.Transaction, signer common.Address) bool { // Create appropriate signer based on transaction type var chainSigner types.Signer switch tx.Type() { case types.LegacyTxType: chainSigner = 
types.NewEIP155Signer(tx.ChainId()) case types.DynamicFeeTxType: chainSigner = types.NewLondonSigner(tx.ChainId()) default: return false // Unsupported transaction type } // Try to recover the signer - if it matches and doesn't error, signature is valid recoveredSigner, err := types.Sender(chainSigner, tx) if err != nil { return false } return recoveredSigner == signer }