637 lines
16 KiB
Go
637 lines
16 KiB
Go
package main
|
|
|
|
import (
|
|
"bytes"
|
|
"encoding/json"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"os"
|
|
"strings"
|
|
"testing"
|
|
"time"
|
|
)
|
|
|
|
// setupTestEnv initializes JWT_SECRET for tests with a fixed value that is
// longer than 32 characters. It uses os.Setenv, so the variable stays set for
// the remainder of the test binary's run and is never restored.
func setupTestEnv() {
	const testJWTSecret = "test-jwt-secret-minimum-32-characters-long-for-security"
	os.Setenv("JWT_SECRET", testJWTSecret)
}
|
|
|
|
// TestRateLimiter tests the rate limiting functionality
|
|
func TestRateLimiter(t *testing.T) {
|
|
rl := &rateLimiter{attempts: make(map[string]*attemptInfo)}
|
|
|
|
// Test that first request is not rate limited
|
|
if rl.checkRateLimit("test-ip", 5, time.Minute) {
|
|
t.Error("First request should not be rate limited")
|
|
}
|
|
|
|
// Fill up the limit
|
|
for i := 0; i < 4; i++ {
|
|
rl.checkRateLimit("test-ip", 5, time.Minute)
|
|
}
|
|
|
|
// 6th request should be rate limited
|
|
if !rl.checkRateLimit("test-ip", 5, time.Minute) {
|
|
t.Error("6th request should be rate limited")
|
|
}
|
|
|
|
// Different IP should not be rate limited
|
|
if rl.checkRateLimit("other-ip", 5, time.Minute) {
|
|
t.Error("Different IP should not be rate limited")
|
|
}
|
|
}
|
|
|
|
func TestRateLimiterClearAttempts(t *testing.T) {
|
|
rl := &rateLimiter{attempts: make(map[string]*attemptInfo)}
|
|
|
|
// Add some attempts
|
|
for i := 0; i < 3; i++ {
|
|
rl.checkRateLimit("test-ip", 5, time.Minute)
|
|
}
|
|
|
|
// Clear attempts
|
|
rl.clearAttempts("test-ip")
|
|
|
|
// Should not be rate limited after clearing
|
|
if rl.checkRateLimit("test-ip", 5, time.Minute) {
|
|
t.Error("Should not be rate limited after clearing attempts")
|
|
}
|
|
}
|
|
|
|
func TestGetClientIP(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
headers map[string]string
|
|
remoteAddr string
|
|
want string
|
|
}{
|
|
{
|
|
name: "X-Forwarded-For single",
|
|
headers: map[string]string{"X-Forwarded-For": "192.168.1.1"},
|
|
remoteAddr: "10.0.0.1:12345",
|
|
want: "192.168.1.1",
|
|
},
|
|
{
|
|
name: "X-Forwarded-For multiple",
|
|
headers: map[string]string{"X-Forwarded-For": "192.168.1.1, 10.0.0.2"},
|
|
remoteAddr: "10.0.0.1:12345",
|
|
want: "192.168.1.1",
|
|
},
|
|
{
|
|
name: "X-Real-IP",
|
|
headers: map[string]string{"X-Real-IP": "192.168.1.2"},
|
|
remoteAddr: "10.0.0.1:12345",
|
|
want: "192.168.1.2",
|
|
},
|
|
{
|
|
name: "Remote address only",
|
|
headers: map[string]string{},
|
|
remoteAddr: "10.0.0.1:12345",
|
|
want: "10.0.0.1",
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
req := httptest.NewRequest("GET", "/", nil)
|
|
req.RemoteAddr = tt.remoteAddr
|
|
for k, v := range tt.headers {
|
|
req.Header.Set(k, v)
|
|
}
|
|
|
|
got := getClientIP(req)
|
|
if got != tt.want {
|
|
t.Errorf("getClientIP() = %v, want %v", got, tt.want)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
// TestRegisterEmailPasswordValidation tests input validation for registration
|
|
func TestRegisterEmailPasswordValidation(t *testing.T) {
|
|
setupTestEnv()
|
|
|
|
tests := []struct {
|
|
name string
|
|
payload map[string]string
|
|
expectedStatus int
|
|
expectedError string
|
|
}{
|
|
{
|
|
name: "missing email",
|
|
payload: map[string]string{"password": "Password123", "name": "Test"},
|
|
expectedStatus: http.StatusBadRequest,
|
|
expectedError: "Email is required",
|
|
},
|
|
{
|
|
name: "invalid email format",
|
|
payload: map[string]string{"email": "notanemail", "password": "Password123", "name": "Test"},
|
|
expectedStatus: http.StatusBadRequest,
|
|
expectedError: "Invalid email format",
|
|
},
|
|
{
|
|
name: "password too short",
|
|
payload: map[string]string{"email": "test@example.com", "password": "Pass1", "name": "Test"},
|
|
expectedStatus: http.StatusBadRequest,
|
|
expectedError: "Password must be at least 8 characters",
|
|
},
|
|
{
|
|
name: "password no uppercase",
|
|
payload: map[string]string{"email": "test@example.com", "password": "password123", "name": "Test"},
|
|
expectedStatus: http.StatusBadRequest,
|
|
expectedError: "uppercase",
|
|
},
|
|
{
|
|
name: "missing name",
|
|
payload: map[string]string{"email": "test@example.com", "password": "Password123"},
|
|
expectedStatus: http.StatusBadRequest,
|
|
expectedError: "Name is required",
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
body, _ := json.Marshal(tt.payload)
|
|
req := httptest.NewRequest("POST", "/register-email-password", bytes.NewReader(body))
|
|
req.Header.Set("Content-Type", "application/json")
|
|
|
|
rr := httptest.NewRecorder()
|
|
handleRegisterEmailPassword(rr, req)
|
|
|
|
if rr.Code != tt.expectedStatus {
|
|
t.Errorf("handler returned wrong status code: got %v want %v", rr.Code, tt.expectedStatus)
|
|
}
|
|
|
|
if tt.expectedError != "" && !strings.Contains(rr.Body.String(), tt.expectedError) {
|
|
t.Errorf("handler returned unexpected body: got %v, want to contain %v", rr.Body.String(), tt.expectedError)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
// TestLoginEmailPasswordValidation tests input validation for login
|
|
// Note: Login handler requires DB for user lookup, so we skip these tests without DB
|
|
func TestLoginEmailPasswordValidation(t *testing.T) {
|
|
if db == nil {
|
|
t.Skip("Skipping login validation tests - requires database connection")
|
|
}
|
|
setupTestEnv()
|
|
|
|
tests := []struct {
|
|
name string
|
|
payload map[string]string
|
|
expectedStatus int
|
|
}{
|
|
{
|
|
name: "missing email",
|
|
payload: map[string]string{"password": "Password123"},
|
|
expectedStatus: http.StatusBadRequest,
|
|
},
|
|
{
|
|
name: "missing password",
|
|
payload: map[string]string{"email": "test@example.com"},
|
|
expectedStatus: http.StatusBadRequest,
|
|
},
|
|
{
|
|
name: "empty body",
|
|
payload: map[string]string{},
|
|
expectedStatus: http.StatusBadRequest,
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
body, _ := json.Marshal(tt.payload)
|
|
req := httptest.NewRequest("POST", "/login-email-password", bytes.NewReader(body))
|
|
req.Header.Set("Content-Type", "application/json")
|
|
|
|
rr := httptest.NewRecorder()
|
|
handleLoginEmailPassword(rr, req)
|
|
|
|
if rr.Code != tt.expectedStatus {
|
|
t.Errorf("handler returned wrong status code: got %v want %v", rr.Code, tt.expectedStatus)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
// TestRefreshTokenValidation tests refresh token endpoint validation (non-DB tests only)
|
|
func TestRefreshTokenValidation(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
method string
|
|
payload map[string]string
|
|
expectedStatus int
|
|
requiresDB bool
|
|
}{
|
|
{
|
|
name: "wrong method",
|
|
method: "GET",
|
|
payload: map[string]string{},
|
|
expectedStatus: http.StatusMethodNotAllowed,
|
|
requiresDB: false,
|
|
},
|
|
{
|
|
name: "missing refresh token",
|
|
method: "POST",
|
|
payload: map[string]string{},
|
|
expectedStatus: http.StatusBadRequest,
|
|
requiresDB: false,
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
if tt.requiresDB && db == nil {
|
|
t.Skip("Skipping test that requires database connection")
|
|
}
|
|
|
|
body, _ := json.Marshal(tt.payload)
|
|
req := httptest.NewRequest(tt.method, "/auth/refresh", bytes.NewReader(body))
|
|
req.Header.Set("Content-Type", "application/json")
|
|
|
|
rr := httptest.NewRecorder()
|
|
handleRefreshToken(rr, req)
|
|
|
|
if rr.Code != tt.expectedStatus {
|
|
t.Errorf("handler returned wrong status code: got %v want %v", rr.Code, tt.expectedStatus)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
// TestLogoutValidation tests logout endpoint validation (non-DB tests only)
|
|
func TestLogoutValidation(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
method string
|
|
payload map[string]string
|
|
expectedStatus int
|
|
requiresDB bool
|
|
}{
|
|
{
|
|
name: "wrong method",
|
|
method: "GET",
|
|
payload: map[string]string{},
|
|
expectedStatus: http.StatusMethodNotAllowed,
|
|
requiresDB: false,
|
|
},
|
|
{
|
|
name: "missing refresh token",
|
|
method: "POST",
|
|
payload: map[string]string{},
|
|
expectedStatus: http.StatusBadRequest,
|
|
requiresDB: false,
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
if tt.requiresDB && db == nil {
|
|
t.Skip("Skipping test that requires database connection")
|
|
}
|
|
|
|
body, _ := json.Marshal(tt.payload)
|
|
req := httptest.NewRequest(tt.method, "/auth/logout", bytes.NewReader(body))
|
|
req.Header.Set("Content-Type", "application/json")
|
|
|
|
rr := httptest.NewRecorder()
|
|
handleLogout(rr, req)
|
|
|
|
if rr.Code != tt.expectedStatus {
|
|
t.Errorf("handler returned wrong status code: got %v want %v", rr.Code, tt.expectedStatus)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
// TestExtractRoles tests role extraction from JWT claims
|
|
func TestExtractRoles(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
claims map[string]interface{}
|
|
wantRoles []string
|
|
wantErr bool
|
|
}{
|
|
{
|
|
name: "string array roles",
|
|
claims: map[string]interface{}{"roles": []interface{}{"CLIENT", "STAFF"}},
|
|
wantRoles: []string{"CLIENT", "STAFF"},
|
|
wantErr: false,
|
|
},
|
|
{
|
|
name: "empty roles",
|
|
claims: map[string]interface{}{"roles": []interface{}{}},
|
|
wantRoles: []string{},
|
|
wantErr: false,
|
|
},
|
|
{
|
|
name: "missing roles",
|
|
claims: map[string]interface{}{},
|
|
wantRoles: nil,
|
|
wantErr: true,
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
roles, err := extractRoles(tt.claims)
|
|
if (err != nil) != tt.wantErr {
|
|
t.Errorf("extractRoles() error = %v, wantErr %v", err, tt.wantErr)
|
|
}
|
|
if !tt.wantErr && len(roles) != len(tt.wantRoles) {
|
|
t.Errorf("extractRoles() = %v, want %v", roles, tt.wantRoles)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestVerifyEthereumSignature(t *testing.T) {
|
|
// Test with invalid inputs to ensure error handling
|
|
tests := []struct {
|
|
name string
|
|
address string
|
|
message string
|
|
signature string
|
|
want bool
|
|
}{
|
|
{
|
|
name: "empty inputs",
|
|
address: "",
|
|
message: "",
|
|
signature: "",
|
|
want: false,
|
|
},
|
|
{
|
|
name: "invalid address format",
|
|
address: "invalid",
|
|
message: "test message",
|
|
signature: "0x0000",
|
|
want: false,
|
|
},
|
|
{
|
|
name: "invalid signature format",
|
|
address: "0x1234567890123456789012345678901234567890",
|
|
message: "test message",
|
|
signature: "invalid",
|
|
want: false,
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
got := verifyEthereumSignature(tt.address, tt.message, tt.signature)
|
|
if got != tt.want {
|
|
t.Errorf("verifyEthereumSignature() = %v, want %v", got, tt.want)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestGenerateJWT(t *testing.T) {
|
|
user := User{
|
|
ID: 1,
|
|
Name: "Test User",
|
|
Email: ptr("test@example.com"),
|
|
}
|
|
roles := []string{"CLIENT"}
|
|
|
|
// This will fail if JWT_SECRET is not set, which is expected in test environment
|
|
// In production CI, JWT_SECRET should be set
|
|
token, err := generateJWT(user, roles)
|
|
|
|
// If JWT_SECRET is not set, we expect an error
|
|
if token == "" && err == nil {
|
|
t.Error("Expected either a token or an error")
|
|
}
|
|
|
|
// If we got a token, it should be a non-empty string
|
|
if token != "" && len(token) < 10 {
|
|
t.Error("Token is too short to be valid")
|
|
}
|
|
}
|
|
|
|
// ptr returns a pointer to a fresh copy of v. It is handy for filling
// pointer-typed struct fields, such as User.Email, from literals.
func ptr[T any](v T) *T {
	p := new(T)
	*p = v
	return p
}
|
|
|
|
func TestPasswordHashing(t *testing.T) {
|
|
password := "testPassword123"
|
|
|
|
// Test hashing
|
|
hashed, err := hashPassword(password)
|
|
if err != nil {
|
|
t.Fatalf("hashPassword() error = %v", err)
|
|
}
|
|
|
|
if hashed == "" {
|
|
t.Error("hashPassword() returned empty string")
|
|
}
|
|
|
|
if hashed == password {
|
|
t.Error("hashPassword() did not hash the password")
|
|
}
|
|
|
|
// Test verification with correct password
|
|
if !checkPasswordHash(password, hashed) {
|
|
t.Error("checkPasswordHash() failed to verify correct password")
|
|
}
|
|
|
|
// Test verification with incorrect password
|
|
if checkPasswordHash("wrongPassword", hashed) {
|
|
t.Error("checkPasswordHash() verified incorrect password")
|
|
}
|
|
}
|
|
|
|
func TestValidateEmail(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
email string
|
|
wantErr bool
|
|
errMsg string
|
|
}{
|
|
{
|
|
name: "valid email",
|
|
email: "test@example.com",
|
|
wantErr: false,
|
|
},
|
|
{
|
|
name: "valid email with subdomain",
|
|
email: "test@mail.example.com",
|
|
wantErr: false,
|
|
},
|
|
{
|
|
name: "valid email with plus",
|
|
email: "test+tag@example.com",
|
|
wantErr: false,
|
|
},
|
|
{
|
|
name: "empty email",
|
|
email: "",
|
|
wantErr: true,
|
|
errMsg: "Email is required",
|
|
},
|
|
{
|
|
name: "email without @",
|
|
email: "testexample.com",
|
|
wantErr: true,
|
|
errMsg: "Invalid email format",
|
|
},
|
|
{
|
|
name: "email without domain",
|
|
email: "test@",
|
|
wantErr: true,
|
|
errMsg: "Invalid email format",
|
|
},
|
|
{
|
|
name: "email with invalid TLD",
|
|
email: "test@example.c",
|
|
wantErr: true,
|
|
errMsg: "Invalid email format",
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
err := validateEmail(tt.email)
|
|
if (err != nil) != tt.wantErr {
|
|
t.Errorf("validateEmail(%s) error = %v, wantErr %v", tt.email, err, tt.wantErr)
|
|
}
|
|
if err != nil && tt.errMsg != "" && err.Message != tt.errMsg {
|
|
t.Errorf("validateEmail(%s) error message = %s, want %s", tt.email, err.Message, tt.errMsg)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestValidatePassword(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
password string
|
|
wantErr bool
|
|
errMsg string
|
|
}{
|
|
{
|
|
name: "valid password",
|
|
password: "Password1",
|
|
wantErr: false,
|
|
},
|
|
{
|
|
name: "valid complex password",
|
|
password: "MySecure123Password!",
|
|
wantErr: false,
|
|
},
|
|
{
|
|
name: "too short",
|
|
password: "Pass1",
|
|
wantErr: true,
|
|
errMsg: "Password must be at least 8 characters",
|
|
},
|
|
{
|
|
name: "no uppercase",
|
|
password: "password123",
|
|
wantErr: true,
|
|
errMsg: "Password must contain at least one uppercase letter",
|
|
},
|
|
{
|
|
name: "no lowercase",
|
|
password: "PASSWORD123",
|
|
wantErr: true,
|
|
errMsg: "Password must contain at least one lowercase letter",
|
|
},
|
|
{
|
|
name: "no number",
|
|
password: "PasswordABC",
|
|
wantErr: true,
|
|
errMsg: "Password must contain at least one number",
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
err := validatePassword(tt.password)
|
|
if (err != nil) != tt.wantErr {
|
|
t.Errorf("validatePassword() error = %v, wantErr %v", err, tt.wantErr)
|
|
}
|
|
if err != nil && tt.errMsg != "" && err.Message != tt.errMsg {
|
|
t.Errorf("validatePassword() error message = %s, want %s", err.Message, tt.errMsg)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestValidateName(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
input string
|
|
wantErr bool
|
|
errMsg string
|
|
}{
|
|
{
|
|
name: "valid name",
|
|
input: "John Doe",
|
|
wantErr: false,
|
|
},
|
|
{
|
|
name: "empty name",
|
|
input: "",
|
|
wantErr: true,
|
|
errMsg: "Name is required",
|
|
},
|
|
{
|
|
name: "whitespace only",
|
|
input: " ",
|
|
wantErr: true,
|
|
errMsg: "Name is required",
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
err := validateName(tt.input)
|
|
if (err != nil) != tt.wantErr {
|
|
t.Errorf("validateName() error = %v, wantErr %v", err, tt.wantErr)
|
|
}
|
|
if err != nil && tt.errMsg != "" && err.Message != tt.errMsg {
|
|
t.Errorf("validateName() error message = %s, want %s", err.Message, tt.errMsg)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestNormalizeEthereumAddress(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
address string
|
|
want string
|
|
}{
|
|
{
|
|
name: "lowercase address",
|
|
address: "0xabcdef1234567890abcdef1234567890abcdef12",
|
|
want: "0xabcdef1234567890abcdef1234567890abcdef12",
|
|
},
|
|
{
|
|
name: "uppercase address",
|
|
address: "0xABCDEF1234567890ABCDEF1234567890ABCDEF12",
|
|
want: "0xabcdef1234567890abcdef1234567890abcdef12",
|
|
},
|
|
{
|
|
name: "mixed case address",
|
|
address: "0xAbCdEf1234567890AbCdEf1234567890AbCdEf12",
|
|
want: "0xabcdef1234567890abcdef1234567890abcdef12",
|
|
},
|
|
{
|
|
name: "address without 0x prefix",
|
|
address: "abcdef1234567890abcdef1234567890abcdef12",
|
|
want: "0xabcdef1234567890abcdef1234567890abcdef12",
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
got := normalizeEthereumAddress(tt.address)
|
|
if got != tt.want {
|
|
t.Errorf("normalizeEthereumAddress() = %v, want %v", got, tt.want)
|
|
}
|
|
})
|
|
}
|
|
}
|