Files
2025-12-26 13:38:04 +01:00

637 lines
16 KiB
Go

package main
import (
"bytes"
"encoding/json"
"net/http"
"net/http/httptest"
"os"
"strings"
"testing"
"time"
)
// setupTestEnv sets the JWT_SECRET environment variable to a fixed test value
// (longer than 32 characters) for tests that need a signing secret.
// The value is set process-wide and is not restored afterwards; the error
// returned by os.Setenv is intentionally ignored.
func setupTestEnv() {
	const (
		envKey     = "JWT_SECRET"
		testSecret = "test-jwt-secret-minimum-32-characters-long-for-security"
	)
	_ = os.Setenv(envKey, testSecret)
}
// TestRateLimiter checks checkRateLimit with a limit of 5 per minute:
// - the first request from an IP is not reported as limited;
// - the sixth request from that IP is reported as limited;
// - a different IP is tracked separately.
func TestRateLimiter(t *testing.T) {
	const (
		maxAttempts = 5
		window      = time.Minute
	)
	limiter := &rateLimiter{attempts: make(map[string]*attemptInfo)}

	// Request #1 must pass.
	if limited := limiter.checkRateLimit("test-ip", maxAttempts, window); limited {
		t.Error("First request should not be rate limited")
	}

	// Requests #2 through #5; their results are not asserted.
	for n := 2; n <= maxAttempts; n++ {
		limiter.checkRateLimit("test-ip", maxAttempts, window)
	}

	// Request #6 exceeds the budget.
	if limited := limiter.checkRateLimit("test-ip", maxAttempts, window); !limited {
		t.Error("6th request should be rate limited")
	}

	// A different key starts with a clean slate.
	if limited := limiter.checkRateLimit("other-ip", maxAttempts, window); limited {
		t.Error("Different IP should not be rate limited")
	}
}
// TestRateLimiterClearAttempts records three attempts for one IP and calls
// clearAttempts for it. It then expects the next checkRateLimit call for that IP
// not to be reported as rate limited.
func TestRateLimiterClearAttempts(t *testing.T) {
	const ip = "test-ip"
	limiter := &rateLimiter{attempts: make(map[string]*attemptInfo)}

	for n := 0; n < 3; n++ {
		limiter.checkRateLimit(ip, 5, time.Minute)
	}

	limiter.clearAttempts(ip)

	if limiter.checkRateLimit(ip, 5, time.Minute) {
		t.Error("Should not be rate limited after clearing attempts")
	}
}
// TestGetClientIP checks which address getClientIP reports for a request.
// Each case sets at most one proxy header:
// - with X-Forwarded-For, the first listed address is expected;
// - with X-Real-IP, its value is expected;
// - with no header, the host part of RemoteAddr is expected.
func TestGetClientIP(t *testing.T) {
	type ipCase struct {
		desc       string
		forwarded  string // X-Forwarded-For value; empty means the header is not set
		realIP     string // X-Real-IP value; empty means the header is not set
		remoteAddr string
		expected   string
	}

	cases := []ipCase{
		{desc: "X-Forwarded-For single", forwarded: "192.168.1.1", remoteAddr: "10.0.0.1:12345", expected: "192.168.1.1"},
		{desc: "X-Forwarded-For multiple", forwarded: "192.168.1.1, 10.0.0.2", remoteAddr: "10.0.0.1:12345", expected: "192.168.1.1"},
		{desc: "X-Real-IP", realIP: "192.168.1.2", remoteAddr: "10.0.0.1:12345", expected: "192.168.1.2"},
		{desc: "Remote address only", remoteAddr: "10.0.0.1:12345", expected: "10.0.0.1"},
	}

	for _, c := range cases {
		t.Run(c.desc, func(t *testing.T) {
			r := httptest.NewRequest("GET", "/", nil)
			r.RemoteAddr = c.remoteAddr
			if c.forwarded != "" {
				r.Header.Set("X-Forwarded-For", c.forwarded)
			}
			if c.realIP != "" {
				r.Header.Set("X-Real-IP", c.realIP)
			}

			if got := getClientIP(r); got != c.expected {
				t.Errorf("getClientIP() = %v, want %v", got, c.expected)
			}
		})
	}
}
// TestRegisterEmailPasswordValidation posts invalid registration payloads to
// handleRegisterEmailPassword. Each case expects HTTP 400 and a response body
// that contains the relevant validation message.
func TestRegisterEmailPasswordValidation(t *testing.T) {
	setupTestEnv()

	type regCase struct {
		name    string
		payload map[string]string
		status  int
		errText string // substring expected in the response body; empty skips the check
	}

	cases := []regCase{
		{
			name:    "missing email",
			payload: map[string]string{"password": "Password123", "name": "Test"},
			status:  http.StatusBadRequest,
			errText: "Email is required",
		},
		{
			name:    "invalid email format",
			payload: map[string]string{"email": "notanemail", "password": "Password123", "name": "Test"},
			status:  http.StatusBadRequest,
			errText: "Invalid email format",
		},
		{
			name:    "password too short",
			payload: map[string]string{"email": "test@example.com", "password": "Pass1", "name": "Test"},
			status:  http.StatusBadRequest,
			errText: "Password must be at least 8 characters",
		},
		{
			name:    "password no uppercase",
			payload: map[string]string{"email": "test@example.com", "password": "password123", "name": "Test"},
			status:  http.StatusBadRequest,
			errText: "uppercase",
		},
		{
			name:    "missing name",
			payload: map[string]string{"email": "test@example.com", "password": "Password123"},
			status:  http.StatusBadRequest,
			errText: "Name is required",
		},
	}

	for _, c := range cases {
		t.Run(c.name, func(t *testing.T) {
			// Marshalling a map[string]string cannot fail, so the error is ignored.
			raw, _ := json.Marshal(c.payload)
			req := httptest.NewRequest("POST", "/register-email-password", bytes.NewReader(raw))
			req.Header.Set("Content-Type", "application/json")
			rec := httptest.NewRecorder()

			handleRegisterEmailPassword(rec, req)

			if rec.Code != c.status {
				t.Errorf("handler returned wrong status code: got %v want %v", rec.Code, c.status)
			}
			if body := rec.Body.String(); c.errText != "" && !strings.Contains(body, c.errText) {
				t.Errorf("handler returned unexpected body: got %v, want to contain %v", body, c.errText)
			}
		})
	}
}
// TestLoginEmailPasswordValidation expects handleLoginEmailPassword to answer
// HTTP 400 when the email, the password, or both are absent from the body.
// The login handler needs the database to look up users, so the test is
// skipped when db is nil.
func TestLoginEmailPasswordValidation(t *testing.T) {
	if db == nil {
		t.Skip("Skipping login validation tests - requires database connection")
	}
	setupTestEnv()

	cases := []struct {
		label    string
		fields   map[string]string
		wantCode int
	}{
		{label: "missing email", fields: map[string]string{"password": "Password123"}, wantCode: http.StatusBadRequest},
		{label: "missing password", fields: map[string]string{"email": "test@example.com"}, wantCode: http.StatusBadRequest},
		{label: "empty body", fields: map[string]string{}, wantCode: http.StatusBadRequest},
	}

	for _, c := range cases {
		t.Run(c.label, func(t *testing.T) {
			// Marshalling a map[string]string cannot fail, so the error is ignored.
			raw, _ := json.Marshal(c.fields)
			req := httptest.NewRequest("POST", "/login-email-password", bytes.NewReader(raw))
			req.Header.Set("Content-Type", "application/json")
			rec := httptest.NewRecorder()

			handleLoginEmailPassword(rec, req)

			if got := rec.Code; got != c.wantCode {
				t.Errorf("handler returned wrong status code: got %v want %v", got, c.wantCode)
			}
		})
	}
}
// TestRefreshTokenValidation checks request validation in handleRefreshToken:
// - a GET request expects 405;
// - a POST without a refresh token expects 400.
// Cases flagged needsDB are skipped when db is nil; none of the current cases
// set that flag.
func TestRefreshTokenValidation(t *testing.T) {
	type refreshCase struct {
		desc    string
		verb    string
		fields  map[string]string
		want    int
		needsDB bool
	}

	cases := []refreshCase{
		{desc: "wrong method", verb: "GET", fields: map[string]string{}, want: http.StatusMethodNotAllowed},
		{desc: "missing refresh token", verb: "POST", fields: map[string]string{}, want: http.StatusBadRequest},
	}

	for _, c := range cases {
		t.Run(c.desc, func(t *testing.T) {
			if c.needsDB && db == nil {
				t.Skip("Skipping test that requires database connection")
			}

			// Marshalling a map[string]string cannot fail, so the error is ignored.
			raw, _ := json.Marshal(c.fields)
			req := httptest.NewRequest(c.verb, "/auth/refresh", bytes.NewReader(raw))
			req.Header.Set("Content-Type", "application/json")
			rec := httptest.NewRecorder()

			handleRefreshToken(rec, req)

			if rec.Code != c.want {
				t.Errorf("handler returned wrong status code: got %v want %v", rec.Code, c.want)
			}
		})
	}
}
// TestLogoutValidation checks request validation in handleLogout:
// - a GET request expects 405;
// - a POST without a refresh token expects 400.
// Cases flagged needsDB are skipped when db is nil; none of the current cases
// set that flag.
func TestLogoutValidation(t *testing.T) {
	type logoutCase struct {
		desc    string
		verb    string
		fields  map[string]string
		want    int
		needsDB bool
	}

	cases := []logoutCase{
		{desc: "wrong method", verb: "GET", fields: map[string]string{}, want: http.StatusMethodNotAllowed},
		{desc: "missing refresh token", verb: "POST", fields: map[string]string{}, want: http.StatusBadRequest},
	}

	for _, c := range cases {
		t.Run(c.desc, func(t *testing.T) {
			if c.needsDB && db == nil {
				t.Skip("Skipping test that requires database connection")
			}

			// Marshalling a map[string]string cannot fail, so the error is ignored.
			raw, _ := json.Marshal(c.fields)
			req := httptest.NewRequest(c.verb, "/auth/logout", bytes.NewReader(raw))
			req.Header.Set("Content-Type", "application/json")
			rec := httptest.NewRecorder()

			handleLogout(rec, req)

			if rec.Code != c.want {
				t.Errorf("handler returned wrong status code: got %v want %v", rec.Code, c.want)
			}
		})
	}
}
// TestExtractRoles checks extractRoles against JWT-style claims maps:
// - a "roles" list must come back with the same values in the same order;
// - an empty list must yield an empty result;
// - a missing "roles" claim must produce an error.
func TestExtractRoles(t *testing.T) {
	tests := []struct {
		name      string
		claims    map[string]interface{}
		wantRoles []string
		wantErr   bool
	}{
		{
			name:      "string array roles",
			claims:    map[string]interface{}{"roles": []interface{}{"CLIENT", "STAFF"}},
			wantRoles: []string{"CLIENT", "STAFF"},
			wantErr:   false,
		},
		{
			name:      "empty roles",
			claims:    map[string]interface{}{"roles": []interface{}{}},
			wantRoles: []string{},
			wantErr:   false,
		},
		{
			name:      "missing roles",
			claims:    map[string]interface{}{},
			wantRoles: nil,
			wantErr:   true,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			roles, err := extractRoles(tt.claims)
			if (err != nil) != tt.wantErr {
				t.Fatalf("extractRoles() error = %v, wantErr %v", err, tt.wantErr)
			}
			if tt.wantErr {
				return
			}
			// Compare contents, not just length. A length-only check would
			// accept the right number of wrong or reordered roles.
			if len(roles) != len(tt.wantRoles) {
				t.Fatalf("extractRoles() = %v, want %v", roles, tt.wantRoles)
			}
			for i := range tt.wantRoles {
				if roles[i] != tt.wantRoles[i] {
					t.Errorf("extractRoles() = %v, want %v", roles, tt.wantRoles)
					break
				}
			}
		})
	}
}
// TestVerifyEthereumSignature feeds malformed inputs to verifyEthereumSignature
// and expects every one to be rejected (false). The inputs are:
// - all fields empty;
// - a non-hex address;
// - a non-hex signature.
func TestVerifyEthereumSignature(t *testing.T) {
	// Every case here is invalid, so the expected result is always false.
	const want = false

	cases := []struct {
		name, address, message, signature string
	}{
		{"empty inputs", "", "", ""},
		{"invalid address format", "invalid", "test message", "0x0000"},
		{"invalid signature format", "0x1234567890123456789012345678901234567890", "test message", "invalid"},
	}

	for _, c := range cases {
		t.Run(c.name, func(t *testing.T) {
			if got := verifyEthereumSignature(c.address, c.message, c.signature); got != want {
				t.Errorf("verifyEthereumSignature() = %v, want %v", got, want)
			}
		})
	}
}
// TestGenerateJWT checks that generateJWT returns a well-formed token for a
// user once JWT_SECRET is configured.
//
// The secret is set here with t.Setenv, which restores the previous value when
// the test ends. Before, the test passed or failed depending on whether an
// earlier test had leaked JWT_SECRET into the process environment. It also
// passed silently when generateJWT returned an error.
func TestGenerateJWT(t *testing.T) {
	t.Setenv("JWT_SECRET", "test-jwt-secret-minimum-32-characters-long-for-security")

	user := User{
		ID:    1,
		Name:  "Test User",
		Email: ptr("test@example.com"),
	}
	roles := []string{"CLIENT"}

	token, err := generateJWT(user, roles)
	if err != nil {
		t.Fatalf("generateJWT() error = %v", err)
	}
	if token == "" {
		t.Fatal("generateJWT() returned an empty token")
	}
	// A compact-serialized JWT has exactly three dot-separated segments:
	// header.payload.signature.
	if segments := strings.Split(token, "."); len(segments) != 3 {
		t.Errorf("generateJWT() = %q, want three dot-separated segments, got %d", token, len(segments))
	}
}
// ptr returns a pointer to a fresh copy of v. Each call allocates a new value,
// which is handy for filling optional pointer fields such as User.Email.
func ptr[T any](v T) *T {
	p := new(T)
	*p = v
	return p
}
// TestPasswordHashing hashes a password and checks the hash is non-empty and
// differs from the plaintext. It then confirms checkPasswordHash accepts the
// original password and rejects a different one.
func TestPasswordHashing(t *testing.T) {
	const plain = "testPassword123"

	hash, err := hashPassword(plain)
	if err != nil {
		t.Fatalf("hashPassword() error = %v", err)
	}

	switch hash {
	case "":
		t.Error("hashPassword() returned empty string")
	case plain:
		t.Error("hashPassword() did not hash the password")
	}

	if ok := checkPasswordHash(plain, hash); !ok {
		t.Error("checkPasswordHash() failed to verify correct password")
	}
	if ok := checkPasswordHash("wrongPassword", hash); ok {
		t.Error("checkPasswordHash() verified incorrect password")
	}
}
// TestValidateEmail runs validateEmail over two groups of addresses:
// - accepted: plain, with a subdomain, and with a "+tag";
// - rejected: empty, no "@", no domain, and a one-letter TLD.
// For rejected addresses it also checks the returned Message.
func TestValidateEmail(t *testing.T) {
	type emailCase struct {
		name    string
		email   string
		wantErr bool
		errMsg  string
	}
	accept := func(name, email string) emailCase {
		return emailCase{name: name, email: email}
	}
	reject := func(name, email, msg string) emailCase {
		return emailCase{name: name, email: email, wantErr: true, errMsg: msg}
	}

	cases := []emailCase{
		accept("valid email", "test@example.com"),
		accept("valid email with subdomain", "test@mail.example.com"),
		accept("valid email with plus", "test+tag@example.com"),
		reject("empty email", "", "Email is required"),
		reject("email without @", "testexample.com", "Invalid email format"),
		reject("email without domain", "test@", "Invalid email format"),
		reject("email with invalid TLD", "test@example.c", "Invalid email format"),
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			err := validateEmail(tc.email)
			gotErr := err != nil
			if gotErr != tc.wantErr {
				t.Errorf("validateEmail(%s) error = %v, wantErr %v", tc.email, err, tc.wantErr)
			}
			if gotErr && tc.errMsg != "" && err.Message != tc.errMsg {
				t.Errorf("validateEmail(%s) error message = %s, want %s", tc.email, err.Message, tc.errMsg)
			}
		})
	}
}
// TestValidatePassword checks validatePassword against two accepted passwords
// and four rejected ones. The rejected passwords are too short, or have no
// uppercase letter, no lowercase letter, or no digit. For each rejection it
// also checks the exact Message.
func TestValidatePassword(t *testing.T) {
	cases := []struct {
		name     string
		password string
		wantErr  bool
		errMsg   string
	}{
		{"valid password", "Password1", false, ""},
		{"valid complex password", "MySecure123Password!", false, ""},
		{"too short", "Pass1", true, "Password must be at least 8 characters"},
		{"no uppercase", "password123", true, "Password must contain at least one uppercase letter"},
		{"no lowercase", "PASSWORD123", true, "Password must contain at least one lowercase letter"},
		{"no number", "PasswordABC", true, "Password must contain at least one number"},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			err := validatePassword(tc.password)
			failed := err != nil
			if failed != tc.wantErr {
				t.Errorf("validatePassword() error = %v, wantErr %v", err, tc.wantErr)
			}
			if !failed || tc.errMsg == "" {
				return
			}
			if err.Message != tc.errMsg {
				t.Errorf("validatePassword() error message = %s, want %s", err.Message, tc.errMsg)
			}
		})
	}
}
// TestValidateName checks that validateName accepts a normal name and rejects
// both an empty string and a whitespace-only string with "Name is required".
func TestValidateName(t *testing.T) {
	const required = "Name is required"

	cases := []struct {
		name    string
		input   string
		wantErr bool
		errMsg  string
	}{
		{"valid name", "John Doe", false, ""},
		{"empty name", "", true, required},
		{"whitespace only", " ", true, required},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			err := validateName(tc.input)
			failed := err != nil
			if failed != tc.wantErr {
				t.Errorf("validateName() error = %v, wantErr %v", err, tc.wantErr)
			}
			if failed && tc.errMsg != "" && err.Message != tc.errMsg {
				t.Errorf("validateName() error message = %s, want %s", err.Message, tc.errMsg)
			}
		})
	}
}
// TestNormalizeEthereumAddress checks that normalizeEthereumAddress maps
// several spellings of one address to the same canonical form: lower-case hex
// with a "0x" prefix. The spellings are lower-case, upper-case, mixed-case,
// and without a "0x" prefix.
func TestNormalizeEthereumAddress(t *testing.T) {
	// Every input below is a spelling of this one address.
	const canonical = "0xabcdef1234567890abcdef1234567890abcdef12"

	cases := []struct {
		name    string
		address string
	}{
		{"lowercase address", "0xabcdef1234567890abcdef1234567890abcdef12"},
		{"uppercase address", "0xABCDEF1234567890ABCDEF1234567890ABCDEF12"},
		{"mixed case address", "0xAbCdEf1234567890AbCdEf1234567890AbCdEf12"},
		{"address without 0x prefix", "abcdef1234567890abcdef1234567890abcdef12"},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			if got := normalizeEthereumAddress(tc.address); got != canonical {
				t.Errorf("normalizeEthereumAddress() = %v, want %v", got, canonical)
			}
		})
	}
}