2386 lines
71 KiB
Go
2386 lines
71 KiB
Go
package main
|
|
|
|
import (
|
|
"context"
|
|
"crypto/rand"
|
|
"database/sql"
|
|
"encoding/hex"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"log"
|
|
"net"
|
|
"net/http"
|
|
"os"
|
|
"os/signal"
|
|
"regexp"
|
|
"strings"
|
|
"sync"
|
|
"syscall"
|
|
"time"
|
|
"unicode"
|
|
|
|
"github.com/ethereum/go-ethereum/crypto"
|
|
"github.com/golang-jwt/jwt/v5"
|
|
_ "github.com/lib/pq"
|
|
"golang.org/x/crypto/bcrypt"
|
|
)
|
|
|
|
// Rate limiting configuration.
const (
	// rateLimitWindow is the window over which login attempts are counted.
	rateLimitWindow = 15 * time.Minute
	// lockoutDuration is how long a key stays blocked once it exceeds its limit.
	lockoutDuration = 30 * time.Minute
	// maxLoginAttempts is the maximum number of failed login attempts per window.
	maxLoginAttempts = 5
	// maxRegisterPerHour is the maximum number of registrations per IP per hour.
	maxRegisterPerHour = 10
)
|
|
|
|
// Token lifetime and size configuration.
const (
	// accessTokenExpiry keeps access tokens short-lived.
	accessTokenExpiry = 15 * time.Minute
	// refreshTokenExpiry is the refresh-token lifetime (7 days).
	refreshTokenExpiry = 7 * 24 * time.Hour
)

// Random token sizes in bytes (32 bytes = 256 bits).
const (
	refreshTokenLength = 32
	csrfTokenLength    = 32
)
|
|
|
|
// rateLimiter tracks attempt counters per string key (for example
// "<client IP>:<email>" for logins or "<client IP>" for registrations).
// All access to attempts goes through mu.
//
// NOTE(review): in the code visible here, entries are only removed by
// clearAttempts, so keys that never log in successfully (and every
// registration key) accumulate for the life of the process — consider
// periodic pruning of expired entries.
type rateLimiter struct {
	// mu guards attempts; the methods visible here only take the write lock.
	mu sync.RWMutex
	// attempts maps a rate-limit key to its current counter state.
	attempts map[string]*attemptInfo
}
|
|
|
|
// attemptInfo is the counter state for one rate-limit key.
type attemptInfo struct {
	// count is the number of attempts recorded in the current window.
	count int
	// firstTry is when the current window started.
	firstTry time.Time
	// lockedOut is true once the key exceeded its limit; cleared lazily by
	// checkRateLimit after lockUntil passes.
	lockedOut bool
	// lockUntil is when the current lockout ends (meaningful only if lockedOut).
	lockUntil time.Time
}
|
|
|
|
var loginLimiter = &rateLimiter{attempts: make(map[string]*attemptInfo)}
|
|
var registerLimiter = &rateLimiter{attempts: make(map[string]*attemptInfo)}
|
|
|
|
// checkRateLimit returns true if the request should be blocked
|
|
func (rl *rateLimiter) checkRateLimit(key string, maxAttempts int, window time.Duration) bool {
|
|
rl.mu.Lock()
|
|
defer rl.mu.Unlock()
|
|
|
|
now := time.Now()
|
|
info, exists := rl.attempts[key]
|
|
|
|
if !exists {
|
|
rl.attempts[key] = &attemptInfo{count: 1, firstTry: now}
|
|
return false
|
|
}
|
|
|
|
// Check if locked out
|
|
if info.lockedOut && now.Before(info.lockUntil) {
|
|
return true
|
|
}
|
|
|
|
// Reset lockout if time has passed
|
|
if info.lockedOut && now.After(info.lockUntil) {
|
|
info.lockedOut = false
|
|
info.count = 1
|
|
info.firstTry = now
|
|
return false
|
|
}
|
|
|
|
// Reset window if expired
|
|
if now.Sub(info.firstTry) > window {
|
|
info.count = 1
|
|
info.firstTry = now
|
|
return false
|
|
}
|
|
|
|
// Increment and check
|
|
info.count++
|
|
if info.count > maxAttempts {
|
|
info.lockedOut = true
|
|
info.lockUntil = now.Add(lockoutDuration)
|
|
return true
|
|
}
|
|
|
|
return false
|
|
}
|
|
|
|
// recordFailedAttempt records a failed attempt (for login failures)
|
|
func (rl *rateLimiter) recordFailedAttempt(key string) {
|
|
rl.mu.Lock()
|
|
defer rl.mu.Unlock()
|
|
|
|
now := time.Now()
|
|
info, exists := rl.attempts[key]
|
|
|
|
if !exists {
|
|
rl.attempts[key] = &attemptInfo{count: 1, firstTry: now}
|
|
return
|
|
}
|
|
|
|
// Reset window if expired
|
|
if now.Sub(info.firstTry) > rateLimitWindow {
|
|
info.count = 1
|
|
info.firstTry = now
|
|
return
|
|
}
|
|
|
|
info.count++
|
|
if info.count >= maxLoginAttempts {
|
|
info.lockedOut = true
|
|
info.lockUntil = now.Add(lockoutDuration)
|
|
log.Printf("SECURITY: IP/email %s locked out after %d failed attempts", key, info.count)
|
|
}
|
|
}
|
|
|
|
// clearAttempts clears attempts for a key (called on successful login)
|
|
func (rl *rateLimiter) clearAttempts(key string) {
|
|
rl.mu.Lock()
|
|
defer rl.mu.Unlock()
|
|
delete(rl.attempts, key)
|
|
}
|
|
|
|
// getClientIP returns the client IP used for rate limiting and audit logs.
//
// Precedence:
//  1. the first entry of X-Forwarded-For, if it parses as an IP;
//  2. X-Real-IP, if it parses as an IP;
//  3. the host part of r.RemoteAddr (or RemoteAddr verbatim if it has no port).
//
// Header values that are not valid IPs are ignored instead of being used as
// keys. This stops clients from injecting arbitrary strings into the limiter
// map and logs. Parsed IPs are returned in canonical form, so different
// spellings of one address share a key.
//
// NOTE(review): proxy headers are trusted unconditionally. If the service is
// reachable without a proxy that overwrites them, a client can still choose
// its own IP and evade per-IP limits. Consider honouring them only behind a
// configured trusted proxy.
func getClientIP(r *http.Request) string {
	if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
		first, _, _ := strings.Cut(xff, ",")
		if ip := net.ParseIP(strings.TrimSpace(first)); ip != nil {
			return ip.String()
		}
	}

	if ip := net.ParseIP(strings.TrimSpace(r.Header.Get("X-Real-IP"))); ip != nil {
		return ip.String()
	}

	host, _, err := net.SplitHostPort(r.RemoteAddr)
	if err != nil {
		return r.RemoteAddr
	}
	return host
}
|
|
|
|
// Input validation limits.
const (
	// maxNameLength caps the display name.
	maxNameLength = 100
	// maxEmailLength is the RFC 5321 address length limit.
	maxEmailLength = 254
	// minPasswordLength is the shortest accepted password.
	minPasswordLength = 8
	// maxPasswordLength matches bcrypt's 72-byte input limit.
	maxPasswordLength = 72
	// maxRequestBody caps request bodies at 1 MiB.
	maxRequestBody = 1024 * 1024
)
|
|
|
|
// emailRegex is a deliberately simplified (not full RFC 5322) address check.
// It requires:
//   - an ASCII local part of letters, digits and ._%+-;
//   - an '@';
//   - a domain of letters, digits, dots and hyphens;
//   - a final alphabetic label of at least two letters.
//
// It is compiled once at package initialization.
var (
	emailRegex = regexp.MustCompile(`^[a-zA-Z0-9._%+\-]+@[a-zA-Z0-9.\-]+\.[a-zA-Z]{2,}$`)
)
|
|
|
|
// ValidationError describes why one request field failed validation. The
// registration handlers encode it as the JSON body of a 400 response.
type ValidationError struct {
	// Field is the JSON name of the offending input ("email", "password", "name").
	Field string `json:"field"`
	// Message is a human-readable explanation suitable for display.
	Message string `json:"message"`
}
|
|
|
|
// validateEmail checks email format and length
|
|
func validateEmail(email string) *ValidationError {
|
|
email = strings.TrimSpace(email)
|
|
if email == "" {
|
|
return &ValidationError{Field: "email", Message: "Email is required"}
|
|
}
|
|
if len(email) > maxEmailLength {
|
|
return &ValidationError{Field: "email", Message: fmt.Sprintf("Email must be less than %d characters", maxEmailLength)}
|
|
}
|
|
if !emailRegex.MatchString(email) {
|
|
return &ValidationError{Field: "email", Message: "Invalid email format"}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// validatePassword checks password requirements
|
|
func validatePassword(password string) *ValidationError {
|
|
if len(password) < minPasswordLength {
|
|
return &ValidationError{Field: "password", Message: fmt.Sprintf("Password must be at least %d characters", minPasswordLength)}
|
|
}
|
|
if len(password) > maxPasswordLength {
|
|
return &ValidationError{Field: "password", Message: fmt.Sprintf("Password must be less than %d characters", maxPasswordLength)}
|
|
}
|
|
|
|
var hasUpper, hasLower, hasNumber bool
|
|
for _, c := range password {
|
|
switch {
|
|
case unicode.IsUpper(c):
|
|
hasUpper = true
|
|
case unicode.IsLower(c):
|
|
hasLower = true
|
|
case unicode.IsNumber(c):
|
|
hasNumber = true
|
|
}
|
|
}
|
|
|
|
if !hasUpper {
|
|
return &ValidationError{Field: "password", Message: "Password must contain at least one uppercase letter"}
|
|
}
|
|
if !hasLower {
|
|
return &ValidationError{Field: "password", Message: "Password must contain at least one lowercase letter"}
|
|
}
|
|
if !hasNumber {
|
|
return &ValidationError{Field: "password", Message: "Password must contain at least one number"}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// validateName checks name length
|
|
func validateName(name string) *ValidationError {
|
|
name = strings.TrimSpace(name)
|
|
if name == "" {
|
|
return &ValidationError{Field: "name", Message: "Name is required"}
|
|
}
|
|
if len(name) > maxNameLength {
|
|
return &ValidationError{Field: "name", Message: fmt.Sprintf("Name must be less than %d characters", maxNameLength)}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// User represents a user in the system, as serialized to JSON.
type User struct {
	ID   int    `json:"id"`
	Name string `json:"name"`
	// Email is a pointer so it can be null: blockchain registration creates
	// users without an email.
	Email *string `json:"email"`
	// Roles holds role names such as "CLIENT", "STAFF", "ADMIN" or "SUPERUSER".
	Roles []string `json:"roles"`
	// IsInitialSuperuser marks the account created by the very first registration.
	IsInitialSuperuser bool `json:"isInitialSuperuser"`
	// IsProtected is set together with IsInitialSuperuser at registration.
	// Presumably it guards the account against demotion/removal — confirm in
	// the admin/superuser handlers.
	IsProtected bool      `json:"isProtected"`
	CreatedAt   time.Time `json:"createdAt"`
}
|
|
|
|
// Identity represents an authentication method for a user, as serialized to
// JSON. A user may have several identities.
type Identity struct {
	ID     int `json:"id"`
	UserID int `json:"userId"`
	// Type is "email_password" or "blockchain_address".
	Type string `json:"type"`
	// Identifier is the lower-cased email or blockchain address.
	Identifier string `json:"identifier"`
	// IsPrimaryLogin is true for the identity created at registration.
	IsPrimaryLogin bool      `json:"isPrimaryLogin"`
	CreatedAt      time.Time `json:"createdAt"`
}
|
|
|
|
// RegisterEmailPasswordRequest is the JSON body for email/password registration.
type RegisterEmailPasswordRequest struct {
	Email    string `json:"email"`
	Password string `json:"password"`
	Name     string `json:"name"`
	// Role is accepted for compatibility but ignored: public registrations are
	// always CLIENT (or SUPERUSER for the very first account).
	Role string `json:"role"`
}
|
|
|
|
// RegisterBlockchainRequest is the JSON body for blockchain-address
// registration. Ownership of Address is proven by Signature over Message.
type RegisterBlockchainRequest struct {
	Address   string `json:"address"`
	Signature string `json:"signature"`
	Message   string `json:"message"`
	Name      string `json:"name"`
	// Role is accepted for compatibility but ignored: public registrations are
	// always CLIENT (or SUPERUSER for the very first account).
	Role string `json:"role"`
}
|
|
|
|
// LoginEmailPasswordRequest is the JSON body for email/password login.
type LoginEmailPasswordRequest struct {
	Email    string `json:"email"`
	Password string `json:"password"`
}
|
|
|
|
// LoginBlockchainRequest is the JSON body for blockchain signature login:
// Signature must be a valid signature of Message by Address.
type LoginBlockchainRequest struct {
	Address   string `json:"address"`
	Signature string `json:"signature"`
	Message   string `json:"message"`
}
|
|
|
|
// LinkIdentityRequest links a new identity to the caller's existing account.
// Which optional fields are meaningful depends on Type.
type LinkIdentityRequest struct {
	Type string `json:"type"` // "email_password" or "blockchain_address"
	// Email and Password apply to Type "email_password".
	Email    string `json:"email,omitempty"`
	Password string `json:"password,omitempty"`
	// Address, Signature and Message apply to Type "blockchain_address".
	Address   string `json:"address,omitempty"`
	Signature string `json:"signature,omitempty"`
	Message   string `json:"message,omitempty"`
}
|
|
|
|
// RefreshTokenRequest is the JSON body for POST /auth/refresh.
type RefreshTokenRequest struct {
	RefreshToken string `json:"refreshToken"`
}
|
|
|
|
// AuthTokenResponse is returned by the login and refresh endpoints. It
// carries the access token, the refresh token and a CSRF token.
type AuthTokenResponse struct {
	AccessToken  string `json:"accessToken"`
	RefreshToken string `json:"refreshToken"`
	CsrfToken    string `json:"csrfToken"` // CSRF token for state-changing requests
	ExpiresIn    int64  `json:"expiresIn"` // Access token expiry in seconds
	// TokenType is set by generateTokenPair (presumably "Bearer" — confirm there).
	TokenType string `json:"tokenType"`
}
|
|
|
|
// LogoutRequest is the JSON body for POST /auth/logout; it names the
// refresh token to revoke.
type LogoutRequest struct {
	RefreshToken string `json:"refreshToken"`
}
|
|
|
|
// jwtSecret holds the JWT secret loaded from JWT_SECRET by loadConfig.
var jwtSecret []byte

// db is the shared Postgres pool opened by initDB in main.
var db *sql.DB

// defaultRole defaults to "CLIENT"; loadConfig replaces it with the
// upper-cased DEFAULT_USER_ROLE when that variable is set.
var defaultRole = "CLIENT"
|
|
|
|
// contextKey is an unexported type for request-context keys, so values this
// package stores cannot collide with keys set by other packages.
type contextKey string

// userContextKey is the context key named "userClaims" (presumably where
// authenticate stores the caller's token claims — confirm there).
var userContextKey contextKey = "userClaims"
|
|
|
|
func main() {
|
|
loadConfig()
|
|
|
|
db = initDB()
|
|
defer db.Close()
|
|
|
|
// Registration routes
|
|
http.HandleFunc("/register-email-password", handleRegisterEmailPassword)
|
|
http.HandleFunc("/register-blockchain", handleRegisterBlockchain)
|
|
|
|
// Login routes
|
|
http.HandleFunc("/login-email-password", handleLoginEmailPassword)
|
|
http.HandleFunc("/login-blockchain", handleLoginBlockchain)
|
|
|
|
// Token management routes
|
|
http.HandleFunc("/auth/refresh", handleRefreshToken)
|
|
http.HandleFunc("/auth/logout", handleLogout)
|
|
http.HandleFunc("/auth/logout-all", authenticate(requireCSRF(handleLogoutAll)))
|
|
|
|
// Identity management routes (protected with CSRF)
|
|
http.HandleFunc("/link-identity", authenticate(requireCSRF(requireRole(handleLinkIdentity, "CLIENT", "STAFF", "ADMIN"))))
|
|
http.HandleFunc("/unlink-identity", authenticate(requireCSRF(requireRole(handleUnlinkIdentity, "CLIENT", "STAFF", "ADMIN"))))
|
|
http.HandleFunc("/identities", authenticate(handleGetIdentities)) // GET doesn't need CSRF
|
|
|
|
// Profile route (protected)
|
|
http.HandleFunc("/profile", authenticate(handleProfile)) // GET doesn't need CSRF
|
|
|
|
// Admin routes (ADMIN only) - Note: SUPERUSER has implicit access
|
|
http.HandleFunc("/admin/users", authenticate(requireRole(handleGetAllUsers, "ADMIN", "SUPERUSER"))) // GET doesn't need CSRF
|
|
http.HandleFunc("/admin/users/promote-role", authenticate(requireCSRF(requireRole(handlePromoteUserRole, "ADMIN", "SUPERUSER"))))
|
|
http.HandleFunc("/admin/users/demote-role", authenticate(requireCSRF(requireRole(handleDemoteUserRole, "ADMIN", "SUPERUSER"))))
|
|
|
|
// Superuser routes (SUPERUSER only) - with CSRF protection
|
|
http.HandleFunc("/superuser/promote", authenticate(requireCSRF(requireRole(handlePromoteSuperuser, "SUPERUSER"))))
|
|
http.HandleFunc("/superuser/demote", authenticate(requireCSRF(requireRole(handleDemoteSuperuser, "SUPERUSER"))))
|
|
http.HandleFunc("/superuser/transfer", authenticate(requireCSRF(requireRole(handleTransferInitialSuperuser, "SUPERUSER"))))
|
|
|
|
// Health check
|
|
http.HandleFunc("/healthz", func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(http.StatusOK)
|
|
fmt.Fprintln(w, "ok")
|
|
})
|
|
|
|
server := &http.Server{
|
|
Addr: ":8080",
|
|
Handler: limitBodySize(corsMiddleware(http.DefaultServeMux)),
|
|
ReadHeaderTimeout: 10 * time.Second,
|
|
ReadTimeout: 15 * time.Second,
|
|
WriteTimeout: 15 * time.Second,
|
|
IdleTimeout: 60 * time.Second,
|
|
}
|
|
|
|
// Graceful shutdown
|
|
done := make(chan bool, 1)
|
|
quit := make(chan os.Signal, 1)
|
|
signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM)
|
|
|
|
go func() {
|
|
<-quit
|
|
log.Println("Auth Service shutting down...")
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
|
defer cancel()
|
|
|
|
server.SetKeepAlivesEnabled(false)
|
|
if err := server.Shutdown(ctx); err != nil {
|
|
log.Printf("Could not gracefully shutdown: %v", err)
|
|
}
|
|
close(done)
|
|
}()
|
|
|
|
log.Println("Auth Service listening on :8080")
|
|
if err := server.ListenAndServe(); err != nil && err != http.ErrServerClosed {
|
|
log.Fatalf("Server error: %v", err)
|
|
}
|
|
|
|
<-done
|
|
log.Println("Auth Service stopped")
|
|
}
|
|
|
|
func loadConfig() {
|
|
jwtSecret = []byte(strings.TrimSpace(os.Getenv("JWT_SECRET")))
|
|
if len(jwtSecret) == 0 {
|
|
log.Fatal("JWT_SECRET must be set")
|
|
}
|
|
if len(jwtSecret) < 32 {
|
|
log.Fatal("JWT_SECRET must be at least 32 characters for security")
|
|
}
|
|
if len(jwtSecret) < 64 {
|
|
log.Println("WARNING: JWT_SECRET is less than 64 characters. Consider using a longer secret for production.")
|
|
}
|
|
|
|
if envDefaultRole := strings.TrimSpace(os.Getenv("DEFAULT_USER_ROLE")); envDefaultRole != "" {
|
|
defaultRole = strings.ToUpper(envDefaultRole)
|
|
}
|
|
}
|
|
|
|
// initDB opens and verifies the Postgres connection pool described by the
// DB_* environment variables.
//
// DB_USER, DB_PASSWORD, DB_NAME and DB_HOST are required. DB_SSL_MODE
// defaults to "require" and must be a libpq sslmode value. DB_SCHEMA, if set,
// must be one of public/dev/testing/prod; a non-public schema is placed first
// on the search_path. Any configuration or connection problem is fatal.
func initDB() *sql.DB {
	user := strings.TrimSpace(os.Getenv("DB_USER"))
	password := strings.TrimSpace(os.Getenv("DB_PASSWORD"))
	name := strings.TrimSpace(os.Getenv("DB_NAME"))
	host := strings.TrimSpace(os.Getenv("DB_HOST"))
	sslMode := strings.TrimSpace(os.Getenv("DB_SSL_MODE"))
	schema := strings.TrimSpace(os.Getenv("DB_SCHEMA"))

	if user == "" || password == "" || name == "" || host == "" {
		log.Fatal("Database configuration missing: DB_USER, DB_PASSWORD, DB_NAME, DB_HOST required")
	}

	// Secure default: require TLS for production
	if sslMode == "" {
		sslMode = "require"
		log.Println("WARNING: DB_SSL_MODE not set, defaulting to 'require' for security")
	}

	validSSLModes := map[string]bool{
		"disable":     true, // Only for local development
		"require":     true, // Minimum for production
		"verify-ca":   true, // Better
		"verify-full": true, // Best
	}
	if !validSSLModes[sslMode] {
		log.Fatalf("Invalid DB_SSL_MODE '%s'. Must be: disable, require, verify-ca, or verify-full", sslMode)
	}
	if sslMode == "disable" {
		log.Println("WARNING: Database SSL is DISABLED. This should only be used for local development!")
	}

	validSchemas := map[string]bool{
		"":        true, // Empty means use public (default)
		"public":  true,
		"dev":     true,
		"testing": true,
		"prod":    true,
	}
	if !validSchemas[schema] {
		log.Fatalf("Invalid DB_SCHEMA '%s'. Must be: dev, testing, prod, or empty for public", schema)
	}
	customSchema := schema != "" && schema != "public"

	// Quote every value. lib/pq's key=value parser splits on whitespace, so an
	// unquoted password (or user/host/dbname) containing a space, quote or
	// backslash would corrupt the DSN or smuggle in extra parameters.
	connStr := fmt.Sprintf("user=%s password=%s dbname=%s host=%s sslmode=%s",
		quoteConnValue(user), quoteConnValue(password), quoteConnValue(name),
		quoteConnValue(host), quoteConnValue(sslMode))
	if customSchema {
		connStr += " search_path=" + quoteConnValue(schema+",public")
	}

	database, err := sql.Open("postgres", connStr)
	if err != nil {
		log.Fatalf("Error opening database: %v", err)
	}

	// Connection pool limits.
	const maxOpenConns = 25
	database.SetMaxOpenConns(maxOpenConns)
	database.SetMaxIdleConns(5)
	database.SetConnMaxLifetime(5 * time.Minute)
	database.SetConnMaxIdleTime(1 * time.Minute)

	if err := database.Ping(); err != nil {
		log.Fatalf("Error connecting to database: %v", err)
	}

	schemaInfo := "public"
	if customSchema {
		schemaInfo = schema
	}
	log.Printf("Successfully connected to database (SSL mode: %s, schema: %s, max_conns: %d)", sslMode, schemaInfo, maxOpenConns)
	return database
}

// quoteConnValue wraps v in single quotes for a libpq-style key=value
// connection string, backslash-escaping embedded backslashes and quotes.
func quoteConnValue(v string) string {
	return "'" + strings.NewReplacer(`\`, `\\`, `'`, `\'`).Replace(v) + "'"
}
|
|
|
|
// limitBodySize middleware limits the request body size to prevent DoS attacks
|
|
func limitBodySize(next http.Handler) http.Handler {
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
if r.Body != nil {
|
|
r.Body = http.MaxBytesReader(w, r.Body, maxRequestBody)
|
|
}
|
|
next.ServeHTTP(w, r)
|
|
})
|
|
}
|
|
|
|
// corsMiddleware sets CORS and security headers on every response. It
// answers OPTIONS (preflight) requests with 200 without calling next.
//
// The allowed origin comes from CORS_ALLOW_ORIGIN and defaults to
// http://localhost:8090. It is read once, when the middleware is built, and
// configuration warnings are logged at that time. The request path therefore
// touches no shared mutable state and is safe for concurrent use.
func corsMiddleware(next http.Handler) http.Handler {
	allowedOrigin := strings.TrimSpace(os.Getenv("CORS_ALLOW_ORIGIN"))
	switch allowedOrigin {
	case "":
		// Default to restrictive in production - set CORS_ALLOW_ORIGIN explicitly
		allowedOrigin = "http://localhost:8090"
		log.Println("WARNING: CORS_ALLOW_ORIGIN not set, defaulting to http://localhost:8090")
		log.Println("For production, set CORS_ALLOW_ORIGIN to your frontend domain (e.g., https://coppertone.tech)")
	case "*":
		log.Println("WARNING: CORS configured to allow ALL origins (*). This is insecure for production!")
	}

	// The CORS protocol forbids credentials with a wildcard origin, and
	// browsers reject such responses. Advertise credentials only for a
	// concrete origin.
	allowCredentials := allowedOrigin != "*"

	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		h := w.Header()

		// CORS headers
		h.Set("Access-Control-Allow-Origin", allowedOrigin)
		h.Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS")
		h.Set("Access-Control-Allow-Headers", "Content-Type, Authorization, X-CSRF-Token")
		if allowCredentials {
			h.Set("Access-Control-Allow-Credentials", "true")
		}

		// Security headers
		h.Set("X-Content-Type-Options", "nosniff")
		h.Set("X-Frame-Options", "DENY")
		h.Set("X-XSS-Protection", "1; mode=block")
		h.Set("Referrer-Policy", "strict-origin-when-cross-origin")
		h.Set("Content-Security-Policy", "default-src 'self'")
		h.Set("Permissions-Policy", "geolocation=(), microphone=(), camera=()")

		if r.Method == http.MethodOptions {
			w.WriteHeader(http.StatusOK)
			return
		}

		next.ServeHTTP(w, r)
	})
}
|
|
|
|
// ===== REGISTRATION HANDLERS =====
|
|
|
|
func handleRegisterEmailPassword(w http.ResponseWriter, r *http.Request) {
|
|
if r.Method != http.MethodPost {
|
|
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
|
return
|
|
}
|
|
|
|
// Rate limit registrations by IP
|
|
clientIP := getClientIP(r)
|
|
if registerLimiter.checkRateLimit(clientIP, maxRegisterPerHour, time.Hour) {
|
|
log.Printf("SECURITY: Registration rate limit exceeded for IP %s", clientIP)
|
|
http.Error(w, "Too many registration attempts. Please try again later.", http.StatusTooManyRequests)
|
|
return
|
|
}
|
|
|
|
var req RegisterEmailPasswordRequest
|
|
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
|
http.Error(w, err.Error(), http.StatusBadRequest)
|
|
return
|
|
}
|
|
|
|
// Validate inputs
|
|
if err := validateEmail(req.Email); err != nil {
|
|
w.Header().Set("Content-Type", "application/json")
|
|
w.WriteHeader(http.StatusBadRequest)
|
|
json.NewEncoder(w).Encode(err)
|
|
return
|
|
}
|
|
|
|
if err := validatePassword(req.Password); err != nil {
|
|
w.Header().Set("Content-Type", "application/json")
|
|
w.WriteHeader(http.StatusBadRequest)
|
|
json.NewEncoder(w).Encode(err)
|
|
return
|
|
}
|
|
|
|
if err := validateName(req.Name); err != nil {
|
|
w.Header().Set("Content-Type", "application/json")
|
|
w.WriteHeader(http.StatusBadRequest)
|
|
json.NewEncoder(w).Encode(err)
|
|
return
|
|
}
|
|
|
|
// Normalize email
|
|
req.Email = strings.TrimSpace(strings.ToLower(req.Email))
|
|
req.Name = strings.TrimSpace(req.Name)
|
|
|
|
// Security: Force CLIENT role for all public registrations
|
|
// First user will be auto-promoted to SUPERUSER with initial_superuser flag
|
|
// Staff/Admin roles can only be granted by existing ADMIN/SUPERUSER
|
|
role := "CLIENT"
|
|
isInitialSuperuser := false
|
|
|
|
// Check if this is the first user (auto-promote to SUPERUSER)
|
|
var userCount int
|
|
err := db.QueryRow("SELECT COUNT(*) FROM users").Scan(&userCount)
|
|
if err != nil {
|
|
log.Printf("Error checking user count: %v", err)
|
|
http.Error(w, "Database error", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
|
|
if userCount == 0 {
|
|
role = "SUPERUSER"
|
|
isInitialSuperuser = true
|
|
log.Println("AUDIT: First user registration - creating INITIAL SUPERUSER (god-like, non-removable)")
|
|
}
|
|
|
|
// Hash password
|
|
passwordHash, err := bcrypt.GenerateFromPassword([]byte(req.Password), bcrypt.DefaultCost)
|
|
if err != nil {
|
|
http.Error(w, "Failed to hash password", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
|
|
// Start transaction
|
|
tx, err := db.Begin()
|
|
if err != nil {
|
|
http.Error(w, "Database error", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
defer tx.Rollback()
|
|
|
|
// Create user
|
|
var userID int
|
|
err = tx.QueryRow(
|
|
"INSERT INTO users (name, email, is_initial_superuser, is_protected, created_at) VALUES ($1, $2, $3, $4, NOW()) RETURNING id",
|
|
req.Name, req.Email, isInitialSuperuser, isInitialSuperuser,
|
|
).Scan(&userID)
|
|
if err != nil {
|
|
if strings.Contains(err.Error(), "duplicate key") {
|
|
http.Error(w, "Email already registered", http.StatusConflict)
|
|
} else {
|
|
log.Printf("Error creating user: %v", err)
|
|
http.Error(w, "Failed to create user", http.StatusInternalServerError)
|
|
}
|
|
return
|
|
}
|
|
|
|
// Create identity
|
|
_, err = tx.Exec(`
|
|
INSERT INTO identities (user_id, type, identifier, credential, is_primary_login, created_at)
|
|
VALUES ($1, 'email_password', $2, $3, true, NOW())
|
|
`, userID, req.Email, string(passwordHash))
|
|
if err != nil {
|
|
http.Error(w, "Failed to create identity", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
|
|
// Assign role
|
|
_, err = tx.Exec(
|
|
"INSERT INTO user_roles (user_id, role, created_at) VALUES ($1, $2, NOW())",
|
|
userID, role,
|
|
)
|
|
if err != nil {
|
|
http.Error(w, "Failed to assign role", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
|
|
if err = tx.Commit(); err != nil {
|
|
http.Error(w, "Failed to complete registration", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
|
|
w.WriteHeader(http.StatusCreated)
|
|
json.NewEncoder(w).Encode(map[string]interface{}{
|
|
"message": "User registered successfully",
|
|
"userId": userID,
|
|
})
|
|
}
|
|
|
|
func handleRegisterBlockchain(w http.ResponseWriter, r *http.Request) {
|
|
if r.Method != http.MethodPost {
|
|
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
|
return
|
|
}
|
|
|
|
// Rate limit registrations by IP
|
|
clientIP := getClientIP(r)
|
|
if registerLimiter.checkRateLimit(clientIP, maxRegisterPerHour, time.Hour) {
|
|
log.Printf("SECURITY: Blockchain registration rate limit exceeded for IP %s", clientIP)
|
|
http.Error(w, "Too many registration attempts. Please try again later.", http.StatusTooManyRequests)
|
|
return
|
|
}
|
|
|
|
var req RegisterBlockchainRequest
|
|
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
|
http.Error(w, err.Error(), http.StatusBadRequest)
|
|
return
|
|
}
|
|
|
|
// Validate inputs
|
|
if err := validateName(req.Name); err != nil {
|
|
w.Header().Set("Content-Type", "application/json")
|
|
w.WriteHeader(http.StatusBadRequest)
|
|
json.NewEncoder(w).Encode(err)
|
|
return
|
|
}
|
|
|
|
if req.Address == "" {
|
|
http.Error(w, "Blockchain address is required", http.StatusBadRequest)
|
|
return
|
|
}
|
|
|
|
if req.Signature == "" || req.Message == "" {
|
|
http.Error(w, "Signature and message are required for verification", http.StatusBadRequest)
|
|
return
|
|
}
|
|
|
|
req.Name = strings.TrimSpace(req.Name)
|
|
|
|
// Security: Force CLIENT role for all public registrations
|
|
// First user will be auto-promoted to SUPERUSER with initial_superuser flag
|
|
// Staff/Admin roles can only be granted by existing ADMIN/SUPERUSER
|
|
role := "CLIENT"
|
|
isInitialSuperuser := false
|
|
|
|
// Check if this is the first user (auto-promote to SUPERUSER)
|
|
var userCount int
|
|
err := db.QueryRow("SELECT COUNT(*) FROM users").Scan(&userCount)
|
|
if err != nil {
|
|
log.Printf("Error checking user count: %v", err)
|
|
http.Error(w, "Database error", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
|
|
if userCount == 0 {
|
|
role = "SUPERUSER"
|
|
isInitialSuperuser = true
|
|
log.Println("AUDIT: First user registration (blockchain) - creating INITIAL SUPERUSER (god-like, non-removable)")
|
|
}
|
|
|
|
// Verify signature
|
|
if !verifyEthereumSignature(req.Address, req.Message, req.Signature) {
|
|
http.Error(w, "Invalid signature", http.StatusUnauthorized)
|
|
return
|
|
}
|
|
|
|
// Start transaction
|
|
tx, err := db.Begin()
|
|
if err != nil {
|
|
http.Error(w, "Database error", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
defer tx.Rollback()
|
|
|
|
// Create user
|
|
var userID int
|
|
err = tx.QueryRow(
|
|
"INSERT INTO users (name, is_initial_superuser, is_protected, created_at) VALUES ($1, $2, $3, NOW()) RETURNING id",
|
|
req.Name, isInitialSuperuser, isInitialSuperuser,
|
|
).Scan(&userID)
|
|
if err != nil {
|
|
log.Printf("Error creating user: %v", err)
|
|
http.Error(w, "Failed to create user", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
|
|
// Create identity
|
|
_, err = tx.Exec(`
|
|
INSERT INTO identities (user_id, type, identifier, is_primary_login, created_at)
|
|
VALUES ($1, 'blockchain_address', $2, true, NOW())
|
|
`, userID, strings.ToLower(req.Address))
|
|
if err != nil {
|
|
if strings.Contains(err.Error(), "duplicate key") {
|
|
http.Error(w, "Blockchain address already registered", http.StatusConflict)
|
|
} else {
|
|
http.Error(w, "Failed to create identity", http.StatusInternalServerError)
|
|
}
|
|
return
|
|
}
|
|
|
|
// Assign role
|
|
_, err = tx.Exec(
|
|
"INSERT INTO user_roles (user_id, role, created_at) VALUES ($1, $2, NOW())",
|
|
userID, role,
|
|
)
|
|
if err != nil {
|
|
http.Error(w, "Failed to assign role", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
|
|
if err = tx.Commit(); err != nil {
|
|
http.Error(w, "Failed to complete registration", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
|
|
w.WriteHeader(http.StatusCreated)
|
|
json.NewEncoder(w).Encode(map[string]interface{}{
|
|
"message": "User registered successfully",
|
|
"userId": userID,
|
|
})
|
|
}
|
|
|
|
// ===== LOGIN HANDLERS =====
|
|
|
|
func handleLoginEmailPassword(w http.ResponseWriter, r *http.Request) {
|
|
if r.Method != http.MethodPost {
|
|
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
|
return
|
|
}
|
|
|
|
clientIP := getClientIP(r)
|
|
|
|
var req LoginEmailPasswordRequest
|
|
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
|
http.Error(w, err.Error(), http.StatusBadRequest)
|
|
return
|
|
}
|
|
|
|
// Rate limit by IP and email combination
|
|
rateLimitKey := clientIP + ":" + strings.ToLower(req.Email)
|
|
if loginLimiter.checkRateLimit(rateLimitKey, maxLoginAttempts, rateLimitWindow) {
|
|
log.Printf("SECURITY: Rate limit exceeded for %s", rateLimitKey)
|
|
http.Error(w, "Too many login attempts. Please try again later.", http.StatusTooManyRequests)
|
|
return
|
|
}
|
|
|
|
// Find identity
|
|
var userID int
|
|
var passwordHash string
|
|
err := db.QueryRow(`
|
|
SELECT user_id, credential
|
|
FROM identities
|
|
WHERE type = 'email_password' AND identifier = $1
|
|
`, req.Email).Scan(&userID, &passwordHash)
|
|
|
|
if err == sql.ErrNoRows {
|
|
loginLimiter.recordFailedAttempt(rateLimitKey)
|
|
log.Printf("SECURITY: Failed login attempt for email %s from IP %s (user not found)", req.Email, clientIP)
|
|
http.Error(w, "Invalid credentials", http.StatusUnauthorized)
|
|
return
|
|
} else if err != nil {
|
|
http.Error(w, "Login failed", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
|
|
// Verify password
|
|
if err := bcrypt.CompareHashAndPassword([]byte(passwordHash), []byte(req.Password)); err != nil {
|
|
loginLimiter.recordFailedAttempt(rateLimitKey)
|
|
log.Printf("SECURITY: Failed login attempt for email %s from IP %s (wrong password)", req.Email, clientIP)
|
|
http.Error(w, "Invalid credentials", http.StatusUnauthorized)
|
|
return
|
|
}
|
|
|
|
// Clear rate limit on successful login
|
|
loginLimiter.clearAttempts(rateLimitKey)
|
|
log.Printf("AUDIT: Successful login for user_id %d from IP %s", userID, clientIP)
|
|
|
|
// Generate token pair (access + refresh)
|
|
tokenResponse, err := generateTokenPair(userID, clientIP)
|
|
if err != nil {
|
|
log.Printf("Error generating token pair: %v", err)
|
|
http.Error(w, "Failed to generate tokens", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
|
|
w.Header().Set("Content-Type", "application/json")
|
|
json.NewEncoder(w).Encode(tokenResponse)
|
|
}
|
|
|
|
func handleLoginBlockchain(w http.ResponseWriter, r *http.Request) {
|
|
if r.Method != http.MethodPost {
|
|
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
|
return
|
|
}
|
|
|
|
clientIP := getClientIP(r)
|
|
|
|
var req LoginBlockchainRequest
|
|
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
|
http.Error(w, err.Error(), http.StatusBadRequest)
|
|
return
|
|
}
|
|
|
|
// Rate limit by IP and address combination
|
|
rateLimitKey := clientIP + ":" + strings.ToLower(req.Address)
|
|
if loginLimiter.checkRateLimit(rateLimitKey, maxLoginAttempts, rateLimitWindow) {
|
|
log.Printf("SECURITY: Rate limit exceeded for blockchain login %s", rateLimitKey)
|
|
http.Error(w, "Too many login attempts. Please try again later.", http.StatusTooManyRequests)
|
|
return
|
|
}
|
|
|
|
// Verify signature
|
|
if !verifyEthereumSignature(req.Address, req.Message, req.Signature) {
|
|
loginLimiter.recordFailedAttempt(rateLimitKey)
|
|
log.Printf("SECURITY: Failed blockchain login for address %s from IP %s (invalid signature)", req.Address, clientIP)
|
|
http.Error(w, "Invalid signature", http.StatusUnauthorized)
|
|
return
|
|
}
|
|
|
|
// Find identity
|
|
var userID int
|
|
err := db.QueryRow(`
|
|
SELECT user_id
|
|
FROM identities
|
|
WHERE type = 'blockchain_address' AND identifier = $1
|
|
`, strings.ToLower(req.Address)).Scan(&userID)
|
|
|
|
if err == sql.ErrNoRows {
|
|
loginLimiter.recordFailedAttempt(rateLimitKey)
|
|
log.Printf("SECURITY: Failed blockchain login for address %s from IP %s (not registered)", req.Address, clientIP)
|
|
http.Error(w, "Address not registered", http.StatusUnauthorized)
|
|
return
|
|
} else if err != nil {
|
|
http.Error(w, "Login failed", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
|
|
// Clear rate limit on successful login
|
|
loginLimiter.clearAttempts(rateLimitKey)
|
|
log.Printf("AUDIT: Successful blockchain login for user_id %d from IP %s", userID, clientIP)
|
|
|
|
// Generate token pair (access + refresh)
|
|
tokenResponse, err := generateTokenPair(userID, clientIP)
|
|
if err != nil {
|
|
log.Printf("Error generating token pair: %v", err)
|
|
http.Error(w, "Failed to generate tokens", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
|
|
w.Header().Set("Content-Type", "application/json")
|
|
json.NewEncoder(w).Encode(tokenResponse)
|
|
}
|
|
|
|
// ===== TOKEN REFRESH AND LOGOUT HANDLERS =====
|
|
|
|
// handleRefreshToken exchanges a valid refresh token for a new access token
|
|
func handleRefreshToken(w http.ResponseWriter, r *http.Request) {
|
|
if r.Method != http.MethodPost {
|
|
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
|
return
|
|
}
|
|
|
|
clientIP := getClientIP(r)
|
|
|
|
var req RefreshTokenRequest
|
|
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
|
http.Error(w, "Invalid request body", http.StatusBadRequest)
|
|
return
|
|
}
|
|
|
|
if req.RefreshToken == "" {
|
|
http.Error(w, "Refresh token is required", http.StatusBadRequest)
|
|
return
|
|
}
|
|
|
|
// Validate the refresh token
|
|
tokenID, userID, err := validateRefreshToken(req.RefreshToken)
|
|
if err != nil {
|
|
log.Printf("SECURITY: Invalid refresh token attempt from IP %s: %v", clientIP, err)
|
|
http.Error(w, "Invalid or expired refresh token", http.StatusUnauthorized)
|
|
return
|
|
}
|
|
|
|
// Revoke the old refresh token (rotation for security)
|
|
if err := revokeRefreshToken(tokenID); err != nil {
|
|
log.Printf("Warning: Failed to revoke old refresh token: %v", err)
|
|
}
|
|
|
|
// Generate new token pair
|
|
tokenResponse, err := generateTokenPair(userID, clientIP)
|
|
if err != nil {
|
|
log.Printf("Error generating token pair during refresh: %v", err)
|
|
http.Error(w, "Failed to refresh tokens", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
|
|
log.Printf("AUDIT: Token refresh for user_id %d from IP %s", userID, clientIP)
|
|
|
|
w.Header().Set("Content-Type", "application/json")
|
|
json.NewEncoder(w).Encode(tokenResponse)
|
|
}
|
|
|
|
// handleLogout revokes a specific refresh token
|
|
func handleLogout(w http.ResponseWriter, r *http.Request) {
|
|
if r.Method != http.MethodPost {
|
|
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
|
return
|
|
}
|
|
|
|
clientIP := getClientIP(r)
|
|
|
|
var req LogoutRequest
|
|
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
|
http.Error(w, "Invalid request body", http.StatusBadRequest)
|
|
return
|
|
}
|
|
|
|
if req.RefreshToken == "" {
|
|
http.Error(w, "Refresh token is required", http.StatusBadRequest)
|
|
return
|
|
}
|
|
|
|
// Validate and get token info
|
|
tokenID, userID, err := validateRefreshToken(req.RefreshToken)
|
|
if err != nil {
|
|
// Even if token is invalid, return success (don't leak info)
|
|
log.Printf("SECURITY: Logout with invalid token from IP %s", clientIP)
|
|
w.Header().Set("Content-Type", "application/json")
|
|
json.NewEncoder(w).Encode(map[string]string{"message": "Logged out successfully"})
|
|
return
|
|
}
|
|
|
|
// Revoke the refresh token
|
|
if err := revokeRefreshToken(tokenID); err != nil {
|
|
log.Printf("Error revoking refresh token: %v", err)
|
|
http.Error(w, "Failed to logout", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
|
|
log.Printf("AUDIT: User %d logged out from IP %s", userID, clientIP)
|
|
|
|
w.Header().Set("Content-Type", "application/json")
|
|
json.NewEncoder(w).Encode(map[string]string{"message": "Logged out successfully"})
|
|
}
|
|
|
|
// handleLogoutAll revokes all refresh tokens for the authenticated user
|
|
func handleLogoutAll(w http.ResponseWriter, r *http.Request) {
|
|
if r.Method != http.MethodPost {
|
|
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
|
return
|
|
}
|
|
|
|
claims := r.Context().Value(userContextKey).(jwt.MapClaims)
|
|
userID := int(claims["userId"].(float64))
|
|
clientIP := getClientIP(r)
|
|
|
|
if err := revokeAllUserRefreshTokens(userID); err != nil {
|
|
log.Printf("Error revoking all refresh tokens for user %d: %v", userID, err)
|
|
http.Error(w, "Failed to logout from all devices", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
|
|
log.Printf("AUDIT: User %d logged out from all devices (initiated from IP %s)", userID, clientIP)
|
|
|
|
w.Header().Set("Content-Type", "application/json")
|
|
json.NewEncoder(w).Encode(map[string]string{"message": "Logged out from all devices successfully"})
|
|
}
|
|
|
|
// ===== IDENTITY MANAGEMENT =====
|
|
|
|
func handleLinkIdentity(w http.ResponseWriter, r *http.Request) {
|
|
if r.Method != http.MethodPost {
|
|
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
|
return
|
|
}
|
|
|
|
claims := r.Context().Value(userContextKey).(jwt.MapClaims)
|
|
userID := int(claims["userId"].(float64))
|
|
|
|
var req LinkIdentityRequest
|
|
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
|
http.Error(w, err.Error(), http.StatusBadRequest)
|
|
return
|
|
}
|
|
|
|
switch req.Type {
|
|
case "email_password":
|
|
// Validate email
|
|
if err := validateEmail(req.Email); err != nil {
|
|
w.Header().Set("Content-Type", "application/json")
|
|
w.WriteHeader(http.StatusBadRequest)
|
|
json.NewEncoder(w).Encode(err)
|
|
return
|
|
}
|
|
|
|
// Validate password
|
|
if err := validatePassword(req.Password); err != nil {
|
|
w.Header().Set("Content-Type", "application/json")
|
|
w.WriteHeader(http.StatusBadRequest)
|
|
json.NewEncoder(w).Encode(err)
|
|
return
|
|
}
|
|
|
|
// Normalize email
|
|
req.Email = strings.TrimSpace(strings.ToLower(req.Email))
|
|
|
|
passwordHash, err := bcrypt.GenerateFromPassword([]byte(req.Password), bcrypt.DefaultCost)
|
|
if err != nil {
|
|
http.Error(w, "Failed to hash password", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
|
|
_, err = db.Exec(`
|
|
INSERT INTO identities (user_id, type, identifier, credential, is_primary_login)
|
|
VALUES ($1, 'email_password', $2, $3, false)
|
|
`, userID, req.Email, string(passwordHash))
|
|
|
|
if err != nil {
|
|
if strings.Contains(err.Error(), "duplicate key") {
|
|
http.Error(w, "Identity already linked", http.StatusConflict)
|
|
} else {
|
|
http.Error(w, "Failed to link identity", http.StatusInternalServerError)
|
|
}
|
|
return
|
|
}
|
|
|
|
case "blockchain_address":
|
|
if req.Address == "" || req.Signature == "" || req.Message == "" {
|
|
http.Error(w, "Address, signature, and message required", http.StatusBadRequest)
|
|
return
|
|
}
|
|
|
|
if !verifyEthereumSignature(req.Address, req.Message, req.Signature) {
|
|
http.Error(w, "Invalid signature", http.StatusUnauthorized)
|
|
return
|
|
}
|
|
|
|
_, err := db.Exec(`
|
|
INSERT INTO identities (user_id, type, identifier, is_primary_login)
|
|
VALUES ($1, 'blockchain_address', $2, false)
|
|
`, userID, strings.ToLower(req.Address))
|
|
|
|
if err != nil {
|
|
if strings.Contains(err.Error(), "duplicate key") {
|
|
http.Error(w, "Identity already linked", http.StatusConflict)
|
|
} else {
|
|
http.Error(w, "Failed to link identity", http.StatusInternalServerError)
|
|
}
|
|
return
|
|
}
|
|
|
|
default:
|
|
http.Error(w, "Invalid identity type", http.StatusBadRequest)
|
|
return
|
|
}
|
|
|
|
w.WriteHeader(http.StatusCreated)
|
|
json.NewEncoder(w).Encode(map[string]string{"message": "Identity linked successfully"})
|
|
}
|
|
|
|
// UnlinkIdentityRequest is the JSON body of the unlink-identity endpoint. It
// names the identity row to detach from the caller's account by its ID
// (JSON key "identityId").
type UnlinkIdentityRequest struct {
	IdentityID int `json:"identityId"` // primary key of the identities row; 0 means "missing"
}
|
|
|
|
func handleUnlinkIdentity(w http.ResponseWriter, r *http.Request) {
|
|
if r.Method != http.MethodPost {
|
|
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
|
return
|
|
}
|
|
|
|
claims := r.Context().Value(userContextKey).(jwt.MapClaims)
|
|
userID := int(claims["userId"].(float64))
|
|
|
|
var req UnlinkIdentityRequest
|
|
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
|
http.Error(w, "Invalid request body", http.StatusBadRequest)
|
|
return
|
|
}
|
|
|
|
if req.IdentityID == 0 {
|
|
http.Error(w, "Identity ID is required", http.StatusBadRequest)
|
|
return
|
|
}
|
|
|
|
// Check identity belongs to user
|
|
var identityUserID int
|
|
var isPrimary bool
|
|
err := db.QueryRow("SELECT user_id, is_primary_login FROM identities WHERE id = $1", req.IdentityID).Scan(&identityUserID, &isPrimary)
|
|
if err == sql.ErrNoRows {
|
|
http.Error(w, "Identity not found", http.StatusNotFound)
|
|
return
|
|
} else if err != nil {
|
|
log.Printf("Error checking identity: %v", err)
|
|
http.Error(w, "Database error", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
|
|
if identityUserID != userID {
|
|
http.Error(w, "Forbidden: identity does not belong to you", http.StatusForbidden)
|
|
return
|
|
}
|
|
|
|
// Check user has at least 2 identities (can't remove last one)
|
|
var identityCount int
|
|
err = db.QueryRow("SELECT COUNT(*) FROM identities WHERE user_id = $1", userID).Scan(&identityCount)
|
|
if err != nil {
|
|
log.Printf("Error counting identities: %v", err)
|
|
http.Error(w, "Database error", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
|
|
if identityCount <= 1 {
|
|
http.Error(w, "Cannot remove your last identity. You must have at least one login method.", http.StatusBadRequest)
|
|
return
|
|
}
|
|
|
|
// If unlinking primary identity, promote another one to primary
|
|
if isPrimary {
|
|
_, err = db.Exec(`
|
|
UPDATE identities
|
|
SET is_primary_login = true
|
|
WHERE user_id = $1 AND id != $2
|
|
LIMIT 1
|
|
`, userID, req.IdentityID)
|
|
if err != nil {
|
|
log.Printf("Error promoting identity: %v", err)
|
|
// Continue anyway - identity will be deleted
|
|
}
|
|
}
|
|
|
|
// Delete the identity
|
|
result, err := db.Exec("DELETE FROM identities WHERE id = $1", req.IdentityID)
|
|
if err != nil {
|
|
log.Printf("Error deleting identity: %v", err)
|
|
http.Error(w, "Failed to unlink identity", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
|
|
rowsAffected, _ := result.RowsAffected()
|
|
if rowsAffected == 0 {
|
|
http.Error(w, "Identity not found", http.StatusNotFound)
|
|
return
|
|
}
|
|
|
|
log.Printf("AUDIT: User %d unlinked identity %d", userID, req.IdentityID)
|
|
|
|
w.Header().Set("Content-Type", "application/json")
|
|
w.WriteHeader(http.StatusOK)
|
|
json.NewEncoder(w).Encode(map[string]string{"message": "Identity unlinked successfully"})
|
|
}
|
|
|
|
func handleGetIdentities(w http.ResponseWriter, r *http.Request) {
|
|
if r.Method != http.MethodGet {
|
|
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
|
return
|
|
}
|
|
|
|
claims := r.Context().Value(userContextKey).(jwt.MapClaims)
|
|
userID := int(claims["userId"].(float64))
|
|
|
|
rows, err := db.Query(`
|
|
SELECT id, user_id, type, identifier, is_primary_login, created_at
|
|
FROM identities
|
|
WHERE user_id = $1
|
|
`, userID)
|
|
if err != nil {
|
|
http.Error(w, "Failed to fetch identities", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
defer rows.Close()
|
|
|
|
var identities []Identity
|
|
for rows.Next() {
|
|
var id Identity
|
|
err := rows.Scan(&id.ID, &id.UserID, &id.Type, &id.Identifier, &id.IsPrimaryLogin, &id.CreatedAt)
|
|
if err != nil {
|
|
http.Error(w, "Failed to parse identities", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
identities = append(identities, id)
|
|
}
|
|
|
|
json.NewEncoder(w).Encode(identities)
|
|
}
|
|
|
|
func handleProfile(w http.ResponseWriter, r *http.Request) {
|
|
if r.Method != http.MethodGet {
|
|
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
|
return
|
|
}
|
|
|
|
claims := r.Context().Value(userContextKey).(jwt.MapClaims)
|
|
userID := int(claims["userId"].(float64))
|
|
|
|
var user User
|
|
err := db.QueryRow(`
|
|
SELECT id, name, email, created_at
|
|
FROM users
|
|
WHERE id = $1
|
|
`, userID).Scan(&user.ID, &user.Name, &user.Email, &user.CreatedAt)
|
|
|
|
if err != nil {
|
|
http.Error(w, "Failed to fetch profile", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
|
|
// Get roles
|
|
rows, err := db.Query("SELECT role FROM user_roles WHERE user_id = $1", userID)
|
|
if err != nil {
|
|
http.Error(w, "Failed to fetch roles", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
defer rows.Close()
|
|
|
|
var roles []string
|
|
for rows.Next() {
|
|
var role string
|
|
rows.Scan(&role)
|
|
roles = append(roles, role)
|
|
}
|
|
user.Roles = roles
|
|
|
|
json.NewEncoder(w).Encode(user)
|
|
}
|
|
|
|
// ===== MIDDLEWARE =====
|
|
|
|
func authenticate(next http.HandlerFunc) http.HandlerFunc {
|
|
return func(w http.ResponseWriter, r *http.Request) {
|
|
authHeader := r.Header.Get("Authorization")
|
|
if authHeader == "" {
|
|
http.Error(w, "Missing authorization token", http.StatusUnauthorized)
|
|
return
|
|
}
|
|
|
|
const bearerPrefix = "Bearer "
|
|
if !strings.HasPrefix(authHeader, bearerPrefix) {
|
|
http.Error(w, "Invalid authorization header format", http.StatusUnauthorized)
|
|
return
|
|
}
|
|
|
|
tokenString := strings.TrimPrefix(authHeader, bearerPrefix)
|
|
|
|
token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) {
|
|
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
|
|
return nil, fmt.Errorf("unexpected signing method")
|
|
}
|
|
return jwtSecret, nil
|
|
})
|
|
|
|
if err != nil || !token.Valid {
|
|
http.Error(w, "Invalid or expired token", http.StatusUnauthorized)
|
|
return
|
|
}
|
|
|
|
claims, ok := token.Claims.(jwt.MapClaims)
|
|
if !ok {
|
|
http.Error(w, "Invalid token claims", http.StatusUnauthorized)
|
|
return
|
|
}
|
|
|
|
ctx := context.WithValue(r.Context(), userContextKey, claims)
|
|
next.ServeHTTP(w, r.WithContext(ctx))
|
|
}
|
|
}
|
|
|
|
func requireRole(next http.HandlerFunc, allowedRoles ...string) http.HandlerFunc {
|
|
return func(w http.ResponseWriter, r *http.Request) {
|
|
claims := r.Context().Value(userContextKey).(jwt.MapClaims)
|
|
userRoles := claims["roles"].([]interface{})
|
|
|
|
// Check if user has any of the allowed roles
|
|
for _, userRole := range userRoles {
|
|
for _, allowedRole := range allowedRoles {
|
|
if userRole.(string) == allowedRole {
|
|
next.ServeHTTP(w, r)
|
|
return
|
|
}
|
|
}
|
|
}
|
|
|
|
http.Error(w, "Forbidden: insufficient permissions", http.StatusForbidden)
|
|
}
|
|
}
|
|
|
|
// requireCSRF validates CSRF token for state-changing requests (POST, PUT, DELETE)
|
|
// This middleware should be used after authenticate() for protected endpoints
|
|
func requireCSRF(next http.HandlerFunc) http.HandlerFunc {
|
|
return func(w http.ResponseWriter, r *http.Request) {
|
|
// Only validate CSRF for state-changing methods
|
|
if r.Method == http.MethodGet || r.Method == http.MethodHead || r.Method == http.MethodOptions {
|
|
next.ServeHTTP(w, r)
|
|
return
|
|
}
|
|
|
|
claims := r.Context().Value(userContextKey).(jwt.MapClaims)
|
|
userID := int(claims["userId"].(float64))
|
|
clientIP := getClientIP(r)
|
|
|
|
// Get CSRF token from header
|
|
csrfToken := r.Header.Get("X-CSRF-Token")
|
|
if csrfToken == "" {
|
|
log.Printf("SECURITY: CSRF token missing for user %d from IP %s on %s %s",
|
|
userID, clientIP, r.Method, r.URL.Path)
|
|
http.Error(w, "CSRF token required", http.StatusForbidden)
|
|
return
|
|
}
|
|
|
|
// Validate CSRF token
|
|
if !validateCSRFToken(userID, csrfToken) {
|
|
log.Printf("SECURITY: Invalid CSRF token for user %d from IP %s on %s %s",
|
|
userID, clientIP, r.Method, r.URL.Path)
|
|
http.Error(w, "Invalid CSRF token", http.StatusForbidden)
|
|
return
|
|
}
|
|
|
|
// Verify Origin header matches expected origin (additional CSRF protection)
|
|
origin := r.Header.Get("Origin")
|
|
allowedOrigin := strings.TrimSpace(os.Getenv("CORS_ALLOW_ORIGIN"))
|
|
if allowedOrigin == "" {
|
|
allowedOrigin = "http://localhost:8090"
|
|
}
|
|
|
|
// If Origin header is present, verify it matches
|
|
if origin != "" && allowedOrigin != "*" && origin != allowedOrigin {
|
|
log.Printf("SECURITY: Origin mismatch for user %d: expected %s, got %s",
|
|
userID, allowedOrigin, origin)
|
|
http.Error(w, "Origin not allowed", http.StatusForbidden)
|
|
return
|
|
}
|
|
|
|
next.ServeHTTP(w, r)
|
|
}
|
|
}
|
|
|
|
// ===== HELPER FUNCTIONS =====
|
|
|
|
// generateAccessToken creates a short-lived JWT access token
|
|
func generateAccessToken(userID int) (string, error) {
|
|
// Get user roles
|
|
rows, err := db.Query("SELECT role FROM user_roles WHERE user_id = $1", userID)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
defer rows.Close()
|
|
|
|
var roles []string
|
|
for rows.Next() {
|
|
var role string
|
|
rows.Scan(&role)
|
|
roles = append(roles, role)
|
|
}
|
|
|
|
// Get user email and superuser status
|
|
var email *string
|
|
var isInitialSuperuser bool
|
|
db.QueryRow(`
|
|
SELECT email, COALESCE(is_initial_superuser, false)
|
|
FROM users WHERE id = $1
|
|
`, userID).Scan(&email, &isInitialSuperuser)
|
|
|
|
token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
|
|
"user_id": userID, // snake_case for downstream services
|
|
"userId": userID, // camelCase for compatibility
|
|
"email": email,
|
|
"roles": roles,
|
|
"isInitialSuperuser": isInitialSuperuser,
|
|
"exp": time.Now().Add(accessTokenExpiry).Unix(),
|
|
"type": "access",
|
|
})
|
|
|
|
return token.SignedString(jwtSecret)
|
|
}
|
|
|
|
// generateRefreshToken creates a secure random refresh token and stores it in the database
|
|
func generateRefreshToken(userID int, clientIP string) (string, error) {
|
|
// Generate cryptographically secure random token
|
|
tokenBytes := make([]byte, refreshTokenLength)
|
|
if _, err := rand.Read(tokenBytes); err != nil {
|
|
return "", fmt.Errorf("failed to generate random token: %w", err)
|
|
}
|
|
token := hex.EncodeToString(tokenBytes)
|
|
|
|
// Hash the token before storing (we only store the hash)
|
|
tokenHash, err := bcrypt.GenerateFromPassword([]byte(token), bcrypt.DefaultCost)
|
|
if err != nil {
|
|
return "", fmt.Errorf("failed to hash refresh token: %w", err)
|
|
}
|
|
|
|
expiresAt := time.Now().Add(refreshTokenExpiry)
|
|
|
|
// Store hashed token in database
|
|
_, err = db.Exec(`
|
|
INSERT INTO refresh_tokens (user_id, token_hash, expires_at, client_ip, created_at)
|
|
VALUES ($1, $2, $3, $4, NOW())
|
|
`, userID, string(tokenHash), expiresAt, clientIP)
|
|
if err != nil {
|
|
return "", fmt.Errorf("failed to store refresh token: %w", err)
|
|
}
|
|
|
|
log.Printf("AUDIT: Refresh token created for user_id %d from IP %s, expires %s",
|
|
userID, clientIP, expiresAt.Format(time.RFC3339))
|
|
|
|
return token, nil
|
|
}
|
|
|
|
// validateRefreshToken validates a refresh token and returns the user ID
|
|
func validateRefreshToken(token string) (int, int, error) {
|
|
// Get all non-expired tokens and check against provided token
|
|
rows, err := db.Query(`
|
|
SELECT id, user_id, token_hash
|
|
FROM refresh_tokens
|
|
WHERE expires_at > NOW() AND revoked_at IS NULL
|
|
`)
|
|
if err != nil {
|
|
return 0, 0, fmt.Errorf("failed to query refresh tokens: %w", err)
|
|
}
|
|
defer rows.Close()
|
|
|
|
for rows.Next() {
|
|
var tokenID, userID int
|
|
var tokenHash string
|
|
if err := rows.Scan(&tokenID, &userID, &tokenHash); err != nil {
|
|
continue
|
|
}
|
|
|
|
// Compare token with hash
|
|
if bcrypt.CompareHashAndPassword([]byte(tokenHash), []byte(token)) == nil {
|
|
return tokenID, userID, nil
|
|
}
|
|
}
|
|
|
|
return 0, 0, errors.New("invalid or expired refresh token")
|
|
}
|
|
|
|
// revokeRefreshToken marks a refresh token as revoked
|
|
func revokeRefreshToken(tokenID int) error {
|
|
_, err := db.Exec(`
|
|
UPDATE refresh_tokens
|
|
SET revoked_at = NOW()
|
|
WHERE id = $1
|
|
`, tokenID)
|
|
return err
|
|
}
|
|
|
|
// revokeAllUserRefreshTokens revokes all refresh tokens for a user
|
|
func revokeAllUserRefreshTokens(userID int) error {
|
|
result, err := db.Exec(`
|
|
UPDATE refresh_tokens
|
|
SET revoked_at = NOW()
|
|
WHERE user_id = $1 AND revoked_at IS NULL
|
|
`, userID)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
rowsAffected, _ := result.RowsAffected()
|
|
log.Printf("AUDIT: Revoked %d refresh tokens for user_id %d", rowsAffected, userID)
|
|
return nil
|
|
}
|
|
|
|
// ===== CSRF TOKEN MANAGEMENT =====
|
|
|
|
// generateCSRFToken creates a cryptographically secure CSRF token
|
|
func generateCSRFToken() (string, error) {
|
|
tokenBytes := make([]byte, csrfTokenLength)
|
|
if _, err := rand.Read(tokenBytes); err != nil {
|
|
return "", fmt.Errorf("failed to generate CSRF token: %w", err)
|
|
}
|
|
return hex.EncodeToString(tokenBytes), nil
|
|
}
|
|
|
|
// storeCSRFToken stores a CSRF token in the database linked to a user
|
|
func storeCSRFToken(userID int, csrfToken string, clientIP string) error {
|
|
// Hash the CSRF token before storing
|
|
tokenHash, err := bcrypt.GenerateFromPassword([]byte(csrfToken), bcrypt.DefaultCost)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to hash CSRF token: %w", err)
|
|
}
|
|
|
|
expiresAt := time.Now().Add(accessTokenExpiry) // CSRF token expires with access token
|
|
|
|
// Store hashed token in database
|
|
_, err = db.Exec(`
|
|
INSERT INTO csrf_tokens (user_id, token_hash, expires_at, client_ip, created_at)
|
|
VALUES ($1, $2, $3, $4, NOW())
|
|
`, userID, string(tokenHash), expiresAt, clientIP)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to store CSRF token: %w", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// validateCSRFToken validates a CSRF token for a user
|
|
func validateCSRFToken(userID int, token string) bool {
|
|
if token == "" {
|
|
return false
|
|
}
|
|
|
|
// Get all non-expired tokens for this user and check against provided token
|
|
rows, err := db.Query(`
|
|
SELECT token_hash
|
|
FROM csrf_tokens
|
|
WHERE user_id = $1 AND expires_at > NOW()
|
|
`, userID)
|
|
if err != nil {
|
|
log.Printf("Error querying CSRF tokens: %v", err)
|
|
return false
|
|
}
|
|
defer rows.Close()
|
|
|
|
for rows.Next() {
|
|
var tokenHash string
|
|
if err := rows.Scan(&tokenHash); err != nil {
|
|
continue
|
|
}
|
|
|
|
// Compare token with hash
|
|
if bcrypt.CompareHashAndPassword([]byte(tokenHash), []byte(token)) == nil {
|
|
return true
|
|
}
|
|
}
|
|
|
|
return false
|
|
}
|
|
|
|
// cleanupExpiredCSRFTokens removes expired CSRF tokens (call periodically)
|
|
func cleanupExpiredCSRFTokens() {
|
|
result, err := db.Exec(`DELETE FROM csrf_tokens WHERE expires_at < NOW()`)
|
|
if err != nil {
|
|
log.Printf("Error cleaning up expired CSRF tokens: %v", err)
|
|
return
|
|
}
|
|
rowsAffected, _ := result.RowsAffected()
|
|
if rowsAffected > 0 {
|
|
log.Printf("Cleaned up %d expired CSRF tokens", rowsAffected)
|
|
}
|
|
}
|
|
|
|
// revokeUserCSRFTokens revokes all CSRF tokens for a user (on logout)
|
|
func revokeUserCSRFTokens(userID int) error {
|
|
_, err := db.Exec(`DELETE FROM csrf_tokens WHERE user_id = $1`, userID)
|
|
return err
|
|
}
|
|
|
|
// generateTokenPair creates access, refresh, and CSRF tokens
|
|
func generateTokenPair(userID int, clientIP string) (*AuthTokenResponse, error) {
|
|
accessToken, err := generateAccessToken(userID)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to generate access token: %w", err)
|
|
}
|
|
|
|
refreshToken, err := generateRefreshToken(userID, clientIP)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to generate refresh token: %w", err)
|
|
}
|
|
|
|
// Generate CSRF token
|
|
csrfToken, err := generateCSRFToken()
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to generate CSRF token: %w", err)
|
|
}
|
|
|
|
// Store CSRF token in database
|
|
if err := storeCSRFToken(userID, csrfToken, clientIP); err != nil {
|
|
log.Printf("Warning: failed to store CSRF token: %v", err)
|
|
// Continue anyway - CSRF validation will fail but auth still works
|
|
}
|
|
|
|
return &AuthTokenResponse{
|
|
AccessToken: accessToken,
|
|
RefreshToken: refreshToken,
|
|
CsrfToken: csrfToken,
|
|
ExpiresIn: int64(accessTokenExpiry.Seconds()),
|
|
TokenType: "Bearer",
|
|
}, nil
|
|
}
|
|
|
|
// Legacy function for backward compatibility
|
|
func generateToken(userID int) (string, error) {
|
|
return generateAccessToken(userID)
|
|
}
|
|
|
|
func extractRoles(claims jwt.MapClaims) ([]string, error) {
|
|
rawRoles, ok := claims["roles"]
|
|
if !ok {
|
|
return nil, errors.New("roles not present in token")
|
|
}
|
|
|
|
switch v := rawRoles.(type) {
|
|
case []interface{}:
|
|
roles := make([]string, 0, len(v))
|
|
for _, roleVal := range v {
|
|
roleStr, ok := roleVal.(string)
|
|
if !ok {
|
|
return nil, errors.New("role value not a string")
|
|
}
|
|
roles = append(roles, roleStr)
|
|
}
|
|
return roles, nil
|
|
case []string:
|
|
return v, nil
|
|
default:
|
|
return nil, errors.New("roles claim has unexpected type")
|
|
}
|
|
}
|
|
|
|
func verifyEthereumSignature(address, message, signature string) bool {
|
|
address = normalizeEthereumAddress(address)
|
|
if address == "" {
|
|
return false
|
|
}
|
|
|
|
// Decode signature
|
|
sigBytes, err := hex.DecodeString(strings.TrimPrefix(signature, "0x"))
|
|
if err != nil || len(sigBytes) != 65 {
|
|
return false
|
|
}
|
|
|
|
// Ethereum specific: adjust v value
|
|
if sigBytes[64] == 27 || sigBytes[64] == 28 {
|
|
sigBytes[64] -= 27
|
|
}
|
|
|
|
// Hash message with Ethereum prefix
|
|
messageHash := crypto.Keccak256Hash([]byte(fmt.Sprintf("\x19Ethereum Signed Message:\n%d%s", len(message), message)))
|
|
|
|
// Recover public key
|
|
pubKey, err := crypto.SigToPub(messageHash.Bytes(), sigBytes)
|
|
if err != nil {
|
|
return false
|
|
}
|
|
|
|
// Get address from public key
|
|
recoveredAddr := crypto.PubkeyToAddress(*pubKey).Hex()
|
|
|
|
return strings.ToLower(recoveredAddr) == address
|
|
}
|
|
|
|
// === Test helpers and shared utilities ===
|
|
|
|
func generateJWT(user User, roles []string) (string, error) {
|
|
if len(jwtSecret) == 0 {
|
|
return "", errors.New("JWT_SECRET not configured")
|
|
}
|
|
token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
|
|
"user_id": user.ID,
|
|
"userId": user.ID,
|
|
"email": user.Email,
|
|
"roles": roles,
|
|
"exp": time.Now().Add(time.Hour).Unix(),
|
|
})
|
|
return token.SignedString(jwtSecret)
|
|
}
|
|
|
|
func hashPassword(password string) (string, error) {
|
|
hash, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
return string(hash), nil
|
|
}
|
|
|
|
func checkPasswordHash(password, hash string) bool {
|
|
return bcrypt.CompareHashAndPassword([]byte(hash), []byte(password)) == nil
|
|
}
|
|
|
|
// normalizeEthereumAddress canonicalises an Ethereum address for comparison
// and storage:
//   - surrounding whitespace is trimmed;
//   - all letters are lower-cased, which discards any EIP-55 checksum casing;
//   - "0x" is prepended if missing.
//
// It returns "" unless the result is "0x" followed by exactly 40 hexadecimal
// digits. Previously any 40 characters were accepted, so junk such as
// "0xzz..." passed as a valid address.
func normalizeEthereumAddress(addr string) string {
	addr = strings.ToLower(strings.TrimSpace(addr))
	if addr == "" {
		return ""
	}
	if !strings.HasPrefix(addr, "0x") {
		addr = "0x" + addr
	}
	if len(addr) != 42 { // 0x + 40 hex chars
		return ""
	}
	// Already lower-cased, so only 0-9 and a-f are valid digits.
	for _, c := range addr[2:] {
		if !(c >= '0' && c <= '9' || c >= 'a' && c <= 'f') {
			return ""
		}
	}
	return addr
}
|
|
|
|
// ===== ADMIN ENDPOINTS =====
|
|
|
|
// PromoteUserRoleRequest is the JSON body for granting a role: the target
// user's ID ("userId") and the role name ("role"). The promote handler
// upper-cases the role before validating it.
type PromoteUserRoleRequest struct {
	UserID int    `json:"userId"` // user receiving the role
	Role   string `json:"role"`   // role name, case-insensitive
}
|
|
|
|
// DemoteUserRoleRequest is the JSON body for removing a role: the target
// user's ID ("userId") and the role name ("role"). handleDemoteUserRole
// upper-cases the role before validating it.
type DemoteUserRoleRequest struct {
	UserID int    `json:"userId"` // user losing the role
	Role   string `json:"role"`   // role name, case-insensitive
}
|
|
|
|
// handleGetAllUsers returns all users (ADMIN/SUPERUSER only)
|
|
func handleGetAllUsers(w http.ResponseWriter, r *http.Request) {
|
|
if r.Method != http.MethodGet {
|
|
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
|
return
|
|
}
|
|
|
|
rows, err := db.Query(`
|
|
SELECT u.id, u.name, u.email,
|
|
COALESCE(u.is_initial_superuser, false),
|
|
COALESCE(u.is_protected, false),
|
|
u.created_at
|
|
FROM users u
|
|
ORDER BY u.created_at DESC
|
|
`)
|
|
if err != nil {
|
|
log.Printf("Error fetching users: %v", err)
|
|
http.Error(w, "Failed to fetch users", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
defer rows.Close()
|
|
|
|
var users []User
|
|
for rows.Next() {
|
|
var u User
|
|
err := rows.Scan(&u.ID, &u.Name, &u.Email, &u.IsInitialSuperuser, &u.IsProtected, &u.CreatedAt)
|
|
if err != nil {
|
|
log.Printf("Error scanning user: %v", err)
|
|
http.Error(w, "Failed to parse users", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
users = append(users, u)
|
|
}
|
|
|
|
// Fetch roles for each user
|
|
for i := range users {
|
|
roleRows, err := db.Query("SELECT role FROM user_roles WHERE user_id = $1", users[i].ID)
|
|
if err != nil {
|
|
log.Printf("Error fetching roles for user %d: %v", users[i].ID, err)
|
|
continue
|
|
}
|
|
var roles []string
|
|
for roleRows.Next() {
|
|
var role string
|
|
roleRows.Scan(&role)
|
|
roles = append(roles, role)
|
|
}
|
|
roleRows.Close()
|
|
users[i].Roles = roles
|
|
}
|
|
|
|
// Return empty array instead of null if no users
|
|
if users == nil {
|
|
users = []User{}
|
|
}
|
|
|
|
w.Header().Set("Content-Type", "application/json")
|
|
json.NewEncoder(w).Encode(users)
|
|
}
|
|
|
|
// handleDemoteUserRole allows ADMIN/SUPERUSER users to remove roles from other users
|
|
// ADMIN can only demote CLIENT, STAFF, ADMIN roles (cannot touch SUPERUSER)
|
|
// SUPERUSER can demote any role except SUPERUSER (use /superuser/demote for that)
|
|
func handleDemoteUserRole(w http.ResponseWriter, r *http.Request) {
|
|
if r.Method != http.MethodPost {
|
|
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
|
return
|
|
}
|
|
|
|
claims := r.Context().Value(userContextKey).(jwt.MapClaims)
|
|
adminUserID := int(claims["userId"].(float64))
|
|
callerRoles, _ := extractRoles(claims)
|
|
|
|
var req DemoteUserRoleRequest
|
|
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
|
http.Error(w, "Invalid request body", http.StatusBadRequest)
|
|
return
|
|
}
|
|
|
|
// Validate role - SUPERUSER demotion must go through /superuser/demote
|
|
if strings.ToUpper(req.Role) == "SUPERUSER" {
|
|
http.Error(w, "Use /superuser/demote endpoint to remove SUPERUSER role", http.StatusBadRequest)
|
|
return
|
|
}
|
|
|
|
validRoles := map[string]bool{"CLIENT": true, "STAFF": true, "ADMIN": true}
|
|
if !validRoles[strings.ToUpper(req.Role)] {
|
|
http.Error(w, "Invalid role. Must be CLIENT, STAFF, or ADMIN", http.StatusBadRequest)
|
|
return
|
|
}
|
|
req.Role = strings.ToUpper(req.Role)
|
|
|
|
// Check if caller is SUPERUSER
|
|
isSuperuser := false
|
|
for _, r := range callerRoles {
|
|
if r == "SUPERUSER" {
|
|
isSuperuser = true
|
|
break
|
|
}
|
|
}
|
|
|
|
// If not superuser, verify target is not a superuser (ADMINs cannot touch SUPERUSERs)
|
|
if !isSuperuser {
|
|
var targetHasSuperuser bool
|
|
err := db.QueryRow(`
|
|
SELECT EXISTS(SELECT 1 FROM user_roles WHERE user_id = $1 AND role = 'SUPERUSER')
|
|
`, req.UserID).Scan(&targetHasSuperuser)
|
|
if err == nil && targetHasSuperuser {
|
|
http.Error(w, "Forbidden: ADMINs cannot modify SUPERUSER accounts", http.StatusForbidden)
|
|
return
|
|
}
|
|
}
|
|
|
|
// Prevent admin from demoting themselves from ADMIN role
|
|
if req.UserID == adminUserID && req.Role == "ADMIN" && !isSuperuser {
|
|
http.Error(w, "Cannot remove your own ADMIN role", http.StatusForbidden)
|
|
return
|
|
}
|
|
|
|
// Validate user exists
|
|
var userName string
|
|
err := db.QueryRow("SELECT name FROM users WHERE id = $1", req.UserID).Scan(&userName)
|
|
if err == sql.ErrNoRows {
|
|
http.Error(w, "User not found", http.StatusNotFound)
|
|
return
|
|
} else if err != nil {
|
|
log.Printf("Error checking user existence: %v", err)
|
|
http.Error(w, "Database error", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
|
|
// Check if user has this role
|
|
var existingRoleID int
|
|
err = db.QueryRow(`
|
|
SELECT id FROM user_roles
|
|
WHERE user_id = $1 AND role = $2
|
|
`, req.UserID, req.Role).Scan(&existingRoleID)
|
|
|
|
if err == sql.ErrNoRows {
|
|
http.Error(w, fmt.Sprintf("User does not have %s role", req.Role), http.StatusNotFound)
|
|
return
|
|
} else if err != nil {
|
|
log.Printf("Error checking existing role: %v", err)
|
|
http.Error(w, "Database error", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
|
|
// Check if this is the user's only role - don't allow removal
|
|
var roleCount int
|
|
err = db.QueryRow("SELECT COUNT(*) FROM user_roles WHERE user_id = $1", req.UserID).Scan(&roleCount)
|
|
if err != nil {
|
|
log.Printf("Error counting roles: %v", err)
|
|
http.Error(w, "Database error", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
|
|
if roleCount <= 1 {
|
|
http.Error(w, "Cannot remove user's only role. Assign a different role first.", http.StatusBadRequest)
|
|
return
|
|
}
|
|
|
|
// Delete the role
|
|
_, err = db.Exec(`
|
|
DELETE FROM user_roles
|
|
WHERE user_id = $1 AND role = $2
|
|
`, req.UserID, req.Role)
|
|
|
|
if err != nil {
|
|
log.Printf("Error deleting role: %v", err)
|
|
http.Error(w, "Failed to remove role", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
|
|
// Log the role change for audit trail
|
|
log.Printf("AUDIT: Admin user %d removed %s role from user %d (%s)",
|
|
adminUserID, req.Role, req.UserID, userName)
|
|
|
|
w.WriteHeader(http.StatusOK)
|
|
json.NewEncoder(w).Encode(map[string]string{
|
|
"message": fmt.Sprintf("Successfully removed %s role from user %d", req.Role, req.UserID),
|
|
})
|
|
}
|
|
|
|
// handlePromoteUserRole allows ADMIN/SUPERUSER users to grant roles to other users
|
|
// ADMIN can only promote to CLIENT, STAFF, ADMIN
|
|
// SUPERUSER can promote to any role except SUPERUSER (use /superuser/promote for that)
|
|
func handlePromoteUserRole(w http.ResponseWriter, r *http.Request) {
|
|
if r.Method != http.MethodPost {
|
|
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
|
return
|
|
}
|
|
|
|
claims := r.Context().Value(userContextKey).(jwt.MapClaims)
|
|
adminUserID := int(claims["userId"].(float64))
|
|
callerRoles, _ := extractRoles(claims)
|
|
|
|
var req PromoteUserRoleRequest
|
|
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
|
http.Error(w, "Invalid request body", http.StatusBadRequest)
|
|
return
|
|
}
|
|
|
|
// Validate role - SUPERUSER promotion must go through /superuser/promote
|
|
validRoles := map[string]bool{"CLIENT": true, "STAFF": true, "ADMIN": true}
|
|
if !validRoles[strings.ToUpper(req.Role)] {
|
|
if strings.ToUpper(req.Role) == "SUPERUSER" {
|
|
http.Error(w, "Use /superuser/promote endpoint to promote to SUPERUSER", http.StatusBadRequest)
|
|
return
|
|
}
|
|
http.Error(w, "Invalid role. Must be CLIENT, STAFF, or ADMIN", http.StatusBadRequest)
|
|
return
|
|
}
|
|
req.Role = strings.ToUpper(req.Role)
|
|
|
|
// Check if caller is SUPERUSER (they can do anything except promote to SUPERUSER here)
|
|
isSuperuser := false
|
|
for _, r := range callerRoles {
|
|
if r == "SUPERUSER" {
|
|
isSuperuser = true
|
|
break
|
|
}
|
|
}
|
|
|
|
// If not superuser, verify target is not a superuser (ADMINs cannot touch SUPERUSERs)
|
|
if !isSuperuser {
|
|
var targetHasSuperuser bool
|
|
err := db.QueryRow(`
|
|
SELECT EXISTS(SELECT 1 FROM user_roles WHERE user_id = $1 AND role = 'SUPERUSER')
|
|
`, req.UserID).Scan(&targetHasSuperuser)
|
|
if err == nil && targetHasSuperuser {
|
|
http.Error(w, "Forbidden: ADMINs cannot modify SUPERUSER accounts", http.StatusForbidden)
|
|
return
|
|
}
|
|
}
|
|
|
|
// Validate user exists
|
|
var userName string
|
|
err := db.QueryRow("SELECT name FROM users WHERE id = $1", req.UserID).Scan(&userName)
|
|
if err == sql.ErrNoRows {
|
|
http.Error(w, "User not found", http.StatusNotFound)
|
|
return
|
|
} else if err != nil {
|
|
log.Printf("Error checking user existence: %v", err)
|
|
http.Error(w, "Database error", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
|
|
// Check if user already has this role
|
|
var existingRoleID int
|
|
err = db.QueryRow(`
|
|
SELECT id FROM user_roles
|
|
WHERE user_id = $1 AND role = $2
|
|
`, req.UserID, req.Role).Scan(&existingRoleID)
|
|
|
|
if err == nil {
|
|
http.Error(w, fmt.Sprintf("User already has %s role", req.Role), http.StatusConflict)
|
|
return
|
|
} else if err != sql.ErrNoRows {
|
|
log.Printf("Error checking existing role: %v", err)
|
|
http.Error(w, "Database error", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
|
|
// Insert new role
|
|
_, err = db.Exec(`
|
|
INSERT INTO user_roles (user_id, role, created_at)
|
|
VALUES ($1, $2, NOW())
|
|
`, req.UserID, req.Role)
|
|
|
|
if err != nil {
|
|
log.Printf("Error inserting role: %v", err)
|
|
http.Error(w, "Failed to assign role", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
|
|
// Log the role change for audit trail
|
|
log.Printf("AUDIT: Admin user %d granted %s role to user %d (%s)",
|
|
adminUserID, req.Role, req.UserID, userName)
|
|
|
|
w.WriteHeader(http.StatusOK)
|
|
json.NewEncoder(w).Encode(map[string]string{
|
|
"message": fmt.Sprintf("Successfully granted %s role to user %d", req.Role, req.UserID),
|
|
})
|
|
}
|
|
|
|
// =============================================================================
|
|
// SUPERUSER MANAGEMENT ENDPOINTS
|
|
// =============================================================================
|
|
|
|
// Request payloads for the SUPERUSER management endpoints.
type (
	// SuperuserPromoteRequest is the JSON body decoded by
	// handlePromoteSuperuser: the ID of the user to be granted SUPERUSER.
	SuperuserPromoteRequest struct {
		UserID int `json:"userId"`
	}

	// SuperuserDemoteRequest is the JSON body decoded by
	// handleDemoteSuperuser: the ID of the user to lose SUPERUSER.
	SuperuserDemoteRequest struct {
		UserID int `json:"userId"`
	}

	// TransferInitialSuperuserRequest is the JSON body decoded by
	// handleTransferInitialSuperuser. NewSuperuserID names the user who
	// receives initial-superuser status; Reason is an optional free-text
	// note that is stored and logged alongside the transfer.
	TransferInitialSuperuserRequest struct {
		NewSuperuserID int    `json:"newSuperuserId"`
		Reason         string `json:"reason,omitempty"`
	}
)
|
|
|
|
// handlePromoteSuperuser promotes a user to SUPERUSER (SUPERUSER only)
|
|
func handlePromoteSuperuser(w http.ResponseWriter, r *http.Request) {
|
|
if r.Method != http.MethodPost {
|
|
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
|
return
|
|
}
|
|
|
|
claims := r.Context().Value(userContextKey).(jwt.MapClaims)
|
|
callerUserID := int(claims["userId"].(float64))
|
|
|
|
var req SuperuserPromoteRequest
|
|
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
|
http.Error(w, "Invalid request body", http.StatusBadRequest)
|
|
return
|
|
}
|
|
|
|
// Validate user exists
|
|
var userName string
|
|
err := db.QueryRow("SELECT name FROM users WHERE id = $1", req.UserID).Scan(&userName)
|
|
if err == sql.ErrNoRows {
|
|
http.Error(w, "User not found", http.StatusNotFound)
|
|
return
|
|
} else if err != nil {
|
|
log.Printf("Error checking user existence: %v", err)
|
|
http.Error(w, "Database error", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
|
|
// Check if user already has SUPERUSER role
|
|
var existingRoleID int
|
|
err = db.QueryRow(`
|
|
SELECT id FROM user_roles
|
|
WHERE user_id = $1 AND role = 'SUPERUSER'
|
|
`, req.UserID).Scan(&existingRoleID)
|
|
|
|
if err == nil {
|
|
http.Error(w, "User is already a SUPERUSER", http.StatusConflict)
|
|
return
|
|
} else if err != sql.ErrNoRows {
|
|
log.Printf("Error checking existing role: %v", err)
|
|
http.Error(w, "Database error", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
|
|
// Insert SUPERUSER role
|
|
_, err = db.Exec(`
|
|
INSERT INTO user_roles (user_id, role, created_at)
|
|
VALUES ($1, 'SUPERUSER', NOW())
|
|
`, req.UserID)
|
|
|
|
if err != nil {
|
|
log.Printf("Error inserting SUPERUSER role: %v", err)
|
|
http.Error(w, "Failed to promote user to SUPERUSER", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
|
|
log.Printf("AUDIT: SUPERUSER %d promoted user %d (%s) to SUPERUSER",
|
|
callerUserID, req.UserID, userName)
|
|
|
|
w.WriteHeader(http.StatusOK)
|
|
json.NewEncoder(w).Encode(map[string]string{
|
|
"message": fmt.Sprintf("Successfully promoted user %d to SUPERUSER", req.UserID),
|
|
})
|
|
}
|
|
|
|
// handleDemoteSuperuser removes SUPERUSER role from a user (SUPERUSER only)
|
|
// Cannot demote initial superuser - they must transfer first
|
|
func handleDemoteSuperuser(w http.ResponseWriter, r *http.Request) {
|
|
if r.Method != http.MethodPost {
|
|
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
|
return
|
|
}
|
|
|
|
claims := r.Context().Value(userContextKey).(jwt.MapClaims)
|
|
callerUserID := int(claims["userId"].(float64))
|
|
|
|
var req SuperuserDemoteRequest
|
|
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
|
http.Error(w, "Invalid request body", http.StatusBadRequest)
|
|
return
|
|
}
|
|
|
|
// Cannot demote yourself
|
|
if req.UserID == callerUserID {
|
|
http.Error(w, "Cannot demote yourself. Have another SUPERUSER do it.", http.StatusForbidden)
|
|
return
|
|
}
|
|
|
|
// Check if target is the initial superuser
|
|
var isInitialSU bool
|
|
var userName string
|
|
err := db.QueryRow(`
|
|
SELECT name, COALESCE(is_initial_superuser, false)
|
|
FROM users WHERE id = $1
|
|
`, req.UserID).Scan(&userName, &isInitialSU)
|
|
|
|
if err == sql.ErrNoRows {
|
|
http.Error(w, "User not found", http.StatusNotFound)
|
|
return
|
|
} else if err != nil {
|
|
log.Printf("Error checking user: %v", err)
|
|
http.Error(w, "Database error", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
|
|
if isInitialSU {
|
|
http.Error(w, "Cannot demote the INITIAL SUPERUSER. They must transfer their status first using /superuser/transfer", http.StatusForbidden)
|
|
return
|
|
}
|
|
|
|
// Check if user has SUPERUSER role
|
|
var existingRoleID int
|
|
err = db.QueryRow(`
|
|
SELECT id FROM user_roles
|
|
WHERE user_id = $1 AND role = 'SUPERUSER'
|
|
`, req.UserID).Scan(&existingRoleID)
|
|
|
|
if err == sql.ErrNoRows {
|
|
http.Error(w, "User does not have SUPERUSER role", http.StatusNotFound)
|
|
return
|
|
} else if err != nil {
|
|
log.Printf("Error checking role: %v", err)
|
|
http.Error(w, "Database error", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
|
|
// Ensure user has at least one other role
|
|
var roleCount int
|
|
err = db.QueryRow("SELECT COUNT(*) FROM user_roles WHERE user_id = $1", req.UserID).Scan(&roleCount)
|
|
if err != nil {
|
|
log.Printf("Error counting roles: %v", err)
|
|
http.Error(w, "Database error", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
|
|
if roleCount <= 1 {
|
|
// Add CLIENT role before removing SUPERUSER
|
|
_, err = db.Exec(`
|
|
INSERT INTO user_roles (user_id, role, created_at)
|
|
VALUES ($1, 'CLIENT', NOW())
|
|
ON CONFLICT (user_id, role) DO NOTHING
|
|
`, req.UserID)
|
|
if err != nil {
|
|
log.Printf("Error adding fallback CLIENT role: %v", err)
|
|
http.Error(w, "Database error", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
}
|
|
|
|
// Remove SUPERUSER role
|
|
_, err = db.Exec(`
|
|
DELETE FROM user_roles
|
|
WHERE user_id = $1 AND role = 'SUPERUSER'
|
|
`, req.UserID)
|
|
|
|
if err != nil {
|
|
log.Printf("Error removing SUPERUSER role: %v", err)
|
|
http.Error(w, "Failed to demote SUPERUSER", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
|
|
log.Printf("AUDIT: SUPERUSER %d demoted user %d (%s) from SUPERUSER",
|
|
callerUserID, req.UserID, userName)
|
|
|
|
w.WriteHeader(http.StatusOK)
|
|
json.NewEncoder(w).Encode(map[string]string{
|
|
"message": fmt.Sprintf("Successfully removed SUPERUSER role from user %d", req.UserID),
|
|
})
|
|
}
|
|
|
|
// handleTransferInitialSuperuser transfers initial superuser status to another user
|
|
// ONLY the current initial superuser can call this (self-transfer only)
|
|
func handleTransferInitialSuperuser(w http.ResponseWriter, r *http.Request) {
|
|
if r.Method != http.MethodPost {
|
|
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
|
return
|
|
}
|
|
|
|
claims := r.Context().Value(userContextKey).(jwt.MapClaims)
|
|
callerUserID := int(claims["userId"].(float64))
|
|
|
|
// Verify caller is the initial superuser
|
|
var isInitialSU bool
|
|
err := db.QueryRow(`
|
|
SELECT COALESCE(is_initial_superuser, false)
|
|
FROM users WHERE id = $1
|
|
`, callerUserID).Scan(&isInitialSU)
|
|
|
|
if err != nil {
|
|
log.Printf("Error checking initial superuser status: %v", err)
|
|
http.Error(w, "Database error", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
|
|
if !isInitialSU {
|
|
http.Error(w, "Forbidden: Only the INITIAL SUPERUSER can transfer their status", http.StatusForbidden)
|
|
return
|
|
}
|
|
|
|
var req TransferInitialSuperuserRequest
|
|
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
|
http.Error(w, "Invalid request body", http.StatusBadRequest)
|
|
return
|
|
}
|
|
|
|
if req.NewSuperuserID == callerUserID {
|
|
http.Error(w, "Cannot transfer to yourself", http.StatusBadRequest)
|
|
return
|
|
}
|
|
|
|
// Validate new superuser exists
|
|
var newUserName string
|
|
err = db.QueryRow("SELECT name FROM users WHERE id = $1", req.NewSuperuserID).Scan(&newUserName)
|
|
if err == sql.ErrNoRows {
|
|
http.Error(w, "Target user not found", http.StatusNotFound)
|
|
return
|
|
} else if err != nil {
|
|
log.Printf("Error checking user existence: %v", err)
|
|
http.Error(w, "Database error", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
|
|
// Start transaction for transfer
|
|
tx, err := db.Begin()
|
|
if err != nil {
|
|
log.Printf("Error starting transaction: %v", err)
|
|
http.Error(w, "Database error", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
defer tx.Rollback()
|
|
|
|
// 1. Remove initial_superuser flag from current user
|
|
_, err = tx.Exec(`
|
|
UPDATE users SET is_initial_superuser = false
|
|
WHERE id = $1
|
|
`, callerUserID)
|
|
if err != nil {
|
|
log.Printf("Error removing initial superuser flag: %v", err)
|
|
http.Error(w, "Transfer failed", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
|
|
// 2. Set initial_superuser and is_protected on new user
|
|
_, err = tx.Exec(`
|
|
UPDATE users SET is_initial_superuser = true, is_protected = true
|
|
WHERE id = $1
|
|
`, req.NewSuperuserID)
|
|
if err != nil {
|
|
log.Printf("Error setting initial superuser flag: %v", err)
|
|
http.Error(w, "Transfer failed", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
|
|
// 3. Ensure new user has SUPERUSER role
|
|
_, err = tx.Exec(`
|
|
INSERT INTO user_roles (user_id, role, created_at)
|
|
VALUES ($1, 'SUPERUSER', NOW())
|
|
ON CONFLICT (user_id, role) DO NOTHING
|
|
`, req.NewSuperuserID)
|
|
if err != nil {
|
|
log.Printf("Error granting SUPERUSER role: %v", err)
|
|
http.Error(w, "Transfer failed", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
|
|
// 4. Record the transfer for audit
|
|
_, err = tx.Exec(`
|
|
INSERT INTO superuser_transfers (from_user_id, to_user_id, reason)
|
|
VALUES ($1, $2, $3)
|
|
`, callerUserID, req.NewSuperuserID, req.Reason)
|
|
if err != nil {
|
|
log.Printf("Error recording transfer: %v", err)
|
|
// Don't fail on audit record
|
|
}
|
|
|
|
if err = tx.Commit(); err != nil {
|
|
log.Printf("Error committing transfer: %v", err)
|
|
http.Error(w, "Transfer failed", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
|
|
log.Printf("AUDIT: INITIAL SUPERUSER transferred from user %d to user %d (%s). Reason: %s",
|
|
callerUserID, req.NewSuperuserID, newUserName, req.Reason)
|
|
|
|
w.WriteHeader(http.StatusOK)
|
|
json.NewEncoder(w).Encode(map[string]string{
|
|
"message": fmt.Sprintf("Successfully transferred INITIAL SUPERUSER status to user %d (%s)", req.NewSuperuserID, newUserName),
|
|
})
|
|
}
|