package main

import (
	"bytes"
	"encoding/json"
	"net/http"
	"net/http/httptest"
	"os"
	"strings"
	"testing"
	"time"
)

// setupTestEnv sets the JWT_SECRET environment variable to a fixed test value.
// It uses os.Setenv, so the value stays set for the remainder of the test
// process. It panics if the variable cannot be set.
func setupTestEnv() {
	if err := os.Setenv("JWT_SECRET", "test-jwt-secret-minimum-32-characters-long-for-security"); err != nil {
		panic("setting JWT_SECRET for tests: " + err.Error())
	}
}

// TestRateLimiter checks that checkRateLimit does not limit the first request,
// reports the 6th request for the same key as limited when the limit is 5, and
// does not limit a different key.
func TestRateLimiter(t *testing.T) {
	rl := &rateLimiter{attempts: make(map[string]*attemptInfo)}

	// Test that first request is not rate limited
	if rl.checkRateLimit("test-ip", 5, time.Minute) {
		t.Error("First request should not be rate limited")
	}

	// Fill up the limit (requests 2 through 5)
	for i := 0; i < 4; i++ {
		rl.checkRateLimit("test-ip", 5, time.Minute)
	}

	// 6th request should be rate limited
	if !rl.checkRateLimit("test-ip", 5, time.Minute) {
		t.Error("6th request should be rate limited")
	}

	// Different IP should not be rate limited
	if rl.checkRateLimit("other-ip", 5, time.Minute) {
		t.Error("Different IP should not be rate limited")
	}
}

// TestRateLimiterClearAttempts checks that a key is not rate limited on the
// next request after clearAttempts has been called for it.
func TestRateLimiterClearAttempts(t *testing.T) {
	rl := &rateLimiter{attempts: make(map[string]*attemptInfo)}

	// Add some attempts
	for i := 0; i < 3; i++ {
		rl.checkRateLimit("test-ip", 5, time.Minute)
	}

	// Clear attempts
	rl.clearAttempts("test-ip")

	// Should not be rate limited after clearing
	if rl.checkRateLimit("test-ip", 5, time.Minute) {
		t.Error("Should not be rate limited after clearing attempts")
	}
}

// TestGetClientIP checks the address getClientIP returns for requests carrying
// X-Forwarded-For, X-Real-IP, or no proxy headers at all.
func TestGetClientIP(t *testing.T) {
	tests := []struct {
		name       string
		headers    map[string]string
		remoteAddr string
		want       string
	}{
		{
			name:       "X-Forwarded-For single",
			headers:    map[string]string{"X-Forwarded-For": "192.168.1.1"},
			remoteAddr: "10.0.0.1:12345",
			want:       "192.168.1.1",
		},
		{
			name:       "X-Forwarded-For multiple",
			headers:    map[string]string{"X-Forwarded-For": "192.168.1.1, 10.0.0.2"},
			remoteAddr: "10.0.0.1:12345",
			want:       "192.168.1.1",
		},
		{
			name:       "X-Real-IP",
			headers:    map[string]string{"X-Real-IP": "192.168.1.2"},
			remoteAddr: "10.0.0.1:12345",
			want:       "192.168.1.2",
		},
		{
			name:       "Remote address only",
			headers:    map[string]string{},
			remoteAddr: "10.0.0.1:12345",
			want:       "10.0.0.1",
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			req := httptest.NewRequest(http.MethodGet, "/", nil)
			req.RemoteAddr = tt.remoteAddr
			for k, v := range tt.headers {
				req.Header.Set(k, v)
			}

			got := getClientIP(req)
			if got != tt.want {
				t.Errorf("getClientIP() = %v, want %v", got, tt.want)
			}
		})
	}
}

// TestRegisterEmailPasswordValidation tests input validation for registration.
// Each case posts a JSON payload to handleRegisterEmailPassword and checks the
// response status code and that the body contains the expected error text.
func TestRegisterEmailPasswordValidation(t *testing.T) {
	setupTestEnv()

	tests := []struct {
		name           string
		payload        map[string]string
		expectedStatus int
		expectedError  string
	}{
		{
			name:           "missing email",
			payload:        map[string]string{"password": "Password123", "name": "Test"},
			expectedStatus: http.StatusBadRequest,
			expectedError:  "Email is required",
		},
		{
			name:           "invalid email format",
			payload:        map[string]string{"email": "notanemail", "password": "Password123", "name": "Test"},
			expectedStatus: http.StatusBadRequest,
			expectedError:  "Invalid email format",
		},
		{
			name:           "password too short",
			payload:        map[string]string{"email": "test@example.com", "password": "Pass1", "name": "Test"},
			expectedStatus: http.StatusBadRequest,
			expectedError:  "Password must be at least 8 characters",
		},
		{
			name:           "password no uppercase",
			payload:        map[string]string{"email": "test@example.com", "password": "password123", "name": "Test"},
			expectedStatus: http.StatusBadRequest,
			expectedError:  "uppercase",
		},
		{
			name:           "missing name",
			payload:        map[string]string{"email": "test@example.com", "password": "Password123"},
			expectedStatus: http.StatusBadRequest,
			expectedError:  "Name is required",
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			body, err := json.Marshal(tt.payload)
			if err != nil {
				t.Fatalf("json.Marshal(%v) error = %v", tt.payload, err)
			}
			req := httptest.NewRequest(http.MethodPost, "/register-email-password", bytes.NewReader(body))
			req.Header.Set("Content-Type", "application/json")

			rr := httptest.NewRecorder()
			handleRegisterEmailPassword(rr, req)

			if rr.Code != tt.expectedStatus {
				t.Errorf("handler returned wrong status code: got %v want %v", rr.Code, tt.expectedStatus)
			}

			if tt.expectedError != "" && !strings.Contains(rr.Body.String(), tt.expectedError) {
				t.Errorf("handler returned unexpected body: got %v, want to contain %v", rr.Body.String(), tt.expectedError)
			}
		})
	}
}

// TestLoginEmailPasswordValidation tests input validation for login.
// Note: Login handler requires DB for user lookup, so the whole test is
// skipped when the package-level db is nil.
func TestLoginEmailPasswordValidation(t *testing.T) {
	if db == nil {
		t.Skip("Skipping login validation tests - requires database connection")
	}

	setupTestEnv()

	tests := []struct {
		name           string
		payload        map[string]string
		expectedStatus int
	}{
		{
			name:           "missing email",
			payload:        map[string]string{"password": "Password123"},
			expectedStatus: http.StatusBadRequest,
		},
		{
			name:           "missing password",
			payload:        map[string]string{"email": "test@example.com"},
			expectedStatus: http.StatusBadRequest,
		},
		{
			name:           "empty body",
			payload:        map[string]string{},
			expectedStatus: http.StatusBadRequest,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			body, err := json.Marshal(tt.payload)
			if err != nil {
				t.Fatalf("json.Marshal(%v) error = %v", tt.payload, err)
			}
			req := httptest.NewRequest(http.MethodPost, "/login-email-password", bytes.NewReader(body))
			req.Header.Set("Content-Type", "application/json")

			rr := httptest.NewRecorder()
			handleLoginEmailPassword(rr, req)

			if rr.Code != tt.expectedStatus {
				t.Errorf("handler returned wrong status code: got %v want %v", rr.Code, tt.expectedStatus)
			}
		})
	}
}

// TestRefreshTokenValidation tests refresh token endpoint validation (non-DB
// tests only). Cases flagged requiresDB are skipped when db is nil.
func TestRefreshTokenValidation(t *testing.T) {
	tests := []struct {
		name           string
		method         string
		payload        map[string]string
		expectedStatus int
		requiresDB     bool
	}{
		{
			name:           "wrong method",
			method:         http.MethodGet,
			payload:        map[string]string{},
			expectedStatus: http.StatusMethodNotAllowed,
			requiresDB:     false,
		},
		{
			name:           "missing refresh token",
			method:         http.MethodPost,
			payload:        map[string]string{},
			expectedStatus: http.StatusBadRequest,
			requiresDB:     false,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			if tt.requiresDB && db == nil {
				t.Skip("Skipping test that requires database connection")
			}

			body, err := json.Marshal(tt.payload)
			if err != nil {
				t.Fatalf("json.Marshal(%v) error = %v", tt.payload, err)
			}
			req := httptest.NewRequest(tt.method, "/auth/refresh", bytes.NewReader(body))
			req.Header.Set("Content-Type", "application/json")

			rr := httptest.NewRecorder()
			handleRefreshToken(rr, req)

			if rr.Code != tt.expectedStatus {
				t.Errorf("handler returned wrong status code: got %v want %v", rr.Code, tt.expectedStatus)
			}
		})
	}
}

// TestLogoutValidation tests logout endpoint validation (non-DB tests only).
// Cases flagged requiresDB are skipped when db is nil.
func TestLogoutValidation(t *testing.T) {
	tests := []struct {
		name           string
		method         string
		payload        map[string]string
		expectedStatus int
		requiresDB     bool
	}{
		{
			name:           "wrong method",
			method:         http.MethodGet,
			payload:        map[string]string{},
			expectedStatus: http.StatusMethodNotAllowed,
			requiresDB:     false,
		},
		{
			name:           "missing refresh token",
			method:         http.MethodPost,
			payload:        map[string]string{},
			expectedStatus: http.StatusBadRequest,
			requiresDB:     false,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			if tt.requiresDB && db == nil {
				t.Skip("Skipping test that requires database connection")
			}

			body, err := json.Marshal(tt.payload)
			if err != nil {
				t.Fatalf("json.Marshal(%v) error = %v", tt.payload, err)
			}
			req := httptest.NewRequest(tt.method, "/auth/logout", bytes.NewReader(body))
			req.Header.Set("Content-Type", "application/json")

			rr := httptest.NewRecorder()
			handleLogout(rr, req)

			if rr.Code != tt.expectedStatus {
				t.Errorf("handler returned wrong status code: got %v want %v", rr.Code, tt.expectedStatus)
			}
		})
	}
}

// TestExtractRoles tests role extraction from JWT claims. For cases that
// expect success it compares both the number of roles and each role value.
func TestExtractRoles(t *testing.T) {
	tests := []struct {
		name      string
		claims    map[string]any
		wantRoles []string
		wantErr   bool
	}{
		{
			name:      "string array roles",
			claims:    map[string]any{"roles": []any{"CLIENT", "STAFF"}},
			wantRoles: []string{"CLIENT", "STAFF"},
			wantErr:   false,
		},
		{
			name:      "empty roles",
			claims:    map[string]any{"roles": []any{}},
			wantRoles: []string{},
			wantErr:   false,
		},
		{
			name:      "missing roles",
			claims:    map[string]any{},
			wantRoles: nil,
			wantErr:   true,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			roles, err := extractRoles(tt.claims)
			if (err != nil) != tt.wantErr {
				t.Errorf("extractRoles() error = %v, wantErr %v", err, tt.wantErr)
			}
			if tt.wantErr {
				return
			}
			// Stop on a length mismatch so the element loop below cannot
			// index out of range.
			if len(roles) != len(tt.wantRoles) {
				t.Fatalf("extractRoles() = %v, want %v", roles, tt.wantRoles)
			}
			for i, want := range tt.wantRoles {
				if roles[i] != want {
					t.Errorf("extractRoles()[%d] = %v, want %v", i, roles[i], want)
				}
			}
		})
	}
}

// TestVerifyEthereumSignature checks that verifyEthereumSignature returns
// false for empty or malformed inputs.
func TestVerifyEthereumSignature(t *testing.T) {
	// Test with invalid inputs to ensure error handling
	tests := []struct {
		name      string
		address   string
		message   string
		signature string
		want      bool
	}{
		{
			name:      "empty inputs",
			address:   "",
			message:   "",
			signature: "",
			want:      false,
		},
		{
			name:      "invalid address format",
			address:   "invalid",
			message:   "test message",
			signature: "0x0000",
			want:      false,
		},
		{
			name:      "invalid signature format",
			address:   "0x1234567890123456789012345678901234567890",
			message:   "test message",
			signature: "invalid",
			want:      false,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			got := verifyEthereumSignature(tt.address, tt.message, tt.signature)
			if got != tt.want {
				t.Errorf("verifyEthereumSignature() = %v, want %v", got, tt.want)
			}
		})
	}
}

// TestGenerateJWT checks that generateJWT returns either a token or an error,
// and that any returned token is at least 10 characters long. The test does
// not set JWT_SECRET itself, so both outcomes are accepted.
func TestGenerateJWT(t *testing.T) {
	user := User{
		ID:    1,
		Name:  "Test User",
		Email: ptr("test@example.com"),
	}
	roles := []string{"CLIENT"}

	// This will fail if JWT_SECRET is not set, which is expected in test environment
	// In production CI, JWT_SECRET should be set
	token, err := generateJWT(user, roles)

	// If JWT_SECRET is not set, we expect an error
	if token == "" && err == nil {
		t.Error("Expected either a token or an error")
	}

	// If we got a token, it should be a non-empty string
	if token != "" && len(token) < 10 {
		t.Error("Token is too short to be valid")
	}
}

// ptr returns a pointer to a copy of v. It is used to fill pointer-typed
// fields from literals.
func ptr[T any](v T) *T {
	return &v
}

// TestPasswordHashing checks that hashPassword yields a non-empty value that
// differs from its input, and that checkPasswordHash accepts the original
// password and rejects a different one.
func TestPasswordHashing(t *testing.T) {
	password := "testPassword123"

	// Test hashing
	hashed, err := hashPassword(password)
	if err != nil {
		t.Fatalf("hashPassword() error = %v", err)
	}

	if hashed == "" {
		t.Error("hashPassword() returned empty string")
	}

	if hashed == password {
		t.Error("hashPassword() did not hash the password")
	}

	// Test verification with correct password
	if !checkPasswordHash(password, hashed) {
		t.Error("checkPasswordHash() failed to verify correct password")
	}

	// Test verification with incorrect password
	if checkPasswordHash("wrongPassword", hashed) {
		t.Error("checkPasswordHash() verified incorrect password")
	}
}

// TestValidateEmail checks which addresses validateEmail accepts and, for
// rejected ones, the Message carried by the returned error.
func TestValidateEmail(t *testing.T) {
	tests := []struct {
		name    string
		email   string
		wantErr bool
		errMsg  string
	}{
		{
			name:    "valid email",
			email:   "test@example.com",
			wantErr: false,
		},
		{
			name:    "valid email with subdomain",
			email:   "test@mail.example.com",
			wantErr: false,
		},
		{
			name:    "valid email with plus",
			email:   "test+tag@example.com",
			wantErr: false,
		},
		{
			name:    "empty email",
			email:   "",
			wantErr: true,
			errMsg:  "Email is required",
		},
		{
			name:    "email without @",
			email:   "testexample.com",
			wantErr: true,
			errMsg:  "Invalid email format",
		},
		{
			name:    "email without domain",
			email:   "test@",
			wantErr: true,
			errMsg:  "Invalid email format",
		},
		{
			name:    "email with invalid TLD",
			email:   "test@example.c",
			wantErr: true,
			errMsg:  "Invalid email format",
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			err := validateEmail(tt.email)
			if (err != nil) != tt.wantErr {
				t.Errorf("validateEmail(%s) error = %v, wantErr %v", tt.email, err, tt.wantErr)
			}
			if err != nil && tt.errMsg != "" && err.Message != tt.errMsg {
				t.Errorf("validateEmail(%s) error message = %s, want %s", tt.email, err.Message, tt.errMsg)
			}
		})
	}
}

// TestValidatePassword checks validatePassword against passwords that are
// valid, too short, or missing an uppercase letter, lowercase letter, or
// number, including the Message carried by the returned error.
func TestValidatePassword(t *testing.T) {
	tests := []struct {
		name     string
		password string
		wantErr  bool
		errMsg   string
	}{
		{
			name:     "valid password",
			password: "Password1",
			wantErr:  false,
		},
		{
			name:     "valid complex password",
			password: "MySecure123Password!",
			wantErr:  false,
		},
		{
			name:     "too short",
			password: "Pass1",
			wantErr:  true,
			errMsg:   "Password must be at least 8 characters",
		},
		{
			name:     "no uppercase",
			password: "password123",
			wantErr:  true,
			errMsg:   "Password must contain at least one uppercase letter",
		},
		{
			name:     "no lowercase",
			password: "PASSWORD123",
			wantErr:  true,
			errMsg:   "Password must contain at least one lowercase letter",
		},
		{
			name:     "no number",
			password: "PasswordABC",
			wantErr:  true,
			errMsg:   "Password must contain at least one number",
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			err := validatePassword(tt.password)
			if (err != nil) != tt.wantErr {
				t.Errorf("validatePassword() error = %v, wantErr %v", err, tt.wantErr)
			}
			if err != nil && tt.errMsg != "" && err.Message != tt.errMsg {
				t.Errorf("validatePassword() error message = %s, want %s", err.Message, tt.errMsg)
			}
		})
	}
}

// TestValidateName checks that validateName accepts a regular name and rejects
// empty or whitespace-only input with the "Name is required" message.
func TestValidateName(t *testing.T) {
	tests := []struct {
		name    string
		input   string
		wantErr bool
		errMsg  string
	}{
		{
			name:    "valid name",
			input:   "John Doe",
			wantErr: false,
		},
		{
			name:    "empty name",
			input:   "",
			wantErr: true,
			errMsg:  "Name is required",
		},
		{
			name:    "whitespace only",
			input:   " ",
			wantErr: true,
			errMsg:  "Name is required",
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			err := validateName(tt.input)
			if (err != nil) != tt.wantErr {
				t.Errorf("validateName() error = %v, wantErr %v", err, tt.wantErr)
			}
			if err != nil && tt.errMsg != "" && err.Message != tt.errMsg {
				t.Errorf("validateName() error message = %s, want %s", err.Message, tt.errMsg)
			}
		})
	}
}

// TestNormalizeEthereumAddress checks that normalizeEthereumAddress returns a
// lowercase, 0x-prefixed address for lowercase, uppercase, mixed-case, and
// unprefixed inputs.
func TestNormalizeEthereumAddress(t *testing.T) {
	tests := []struct {
		name    string
		address string
		want    string
	}{
		{
			name:    "lowercase address",
			address: "0xabcdef1234567890abcdef1234567890abcdef12",
			want:    "0xabcdef1234567890abcdef1234567890abcdef12",
		},
		{
			name:    "uppercase address",
			address: "0xABCDEF1234567890ABCDEF1234567890ABCDEF12",
			want:    "0xabcdef1234567890abcdef1234567890abcdef12",
		},
		{
			name:    "mixed case address",
			address: "0xAbCdEf1234567890AbCdEf1234567890AbCdEf12",
			want:    "0xabcdef1234567890abcdef1234567890abcdef12",
		},
		{
			name:    "address without 0x prefix",
			address: "abcdef1234567890abcdef1234567890abcdef12",
			want:    "0xabcdef1234567890abcdef1234567890abcdef12",
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			got := normalizeEthereumAddress(tt.address)
			if got != tt.want {
				t.Errorf("normalizeEthereumAddress() = %v, want %v", got, tt.want)
			}
		})
	}
}