Files
2025-12-26 13:38:04 +01:00

2386 lines
71 KiB
Go

package main
import (
	"context"
	"crypto/rand"
	"database/sql"
	"encoding/hex"
	"encoding/json"
	"errors"
	"fmt"
	"log"
	"net"
	"net/http"
	"os"
	"os/signal"
	"regexp"
	"strings"
	"sync"
	"syscall"
	"time"
	"unicode"
	"unicode/utf8"

	"github.com/ethereum/go-ethereum/crypto"
	"github.com/golang-jwt/jwt/v5"
	_ "github.com/lib/pq"
	"golang.org/x/crypto/bcrypt"
)
// Rate limiting configuration.
const (
	rateLimitWindow    = 15 * time.Minute // Window for counting login attempts
	maxLoginAttempts   = 5                // Max failed login attempts per window
	maxRegisterPerHour = 10               // Max registrations per IP per hour
	lockoutDuration    = 30 * time.Minute // How long to lock out after max attempts
)

// Token expiration configuration.
const (
	accessTokenExpiry  = 15 * time.Minute   // Short-lived access tokens
	refreshTokenExpiry = 7 * 24 * time.Hour // 7 day refresh tokens
	refreshTokenLength = 32                 // 256-bit refresh token
	csrfTokenLength    = 32                 // 256-bit CSRF token
)

// rateLimiter counts attempts per key (a client IP, or "IP:identifier" for
// logins) in a fixed window that starts at the key's first attempt, and locks
// a key out for lockoutDuration once it exceeds its limit. Every method takes
// mu, so a rateLimiter is safe for concurrent use.
type rateLimiter struct {
	mu       sync.RWMutex
	attempts map[string]*attemptInfo
	// lastSweep is when stale entries were last pruned (see pruneStaleLocked).
	lastSweep time.Time
}

// attemptInfo is the per-key counter state.
type attemptInfo struct {
	count     int       // attempts counted in the current window
	firstTry  time.Time // start of the current window
	lockedOut bool      // true once the limit was exceeded
	lockUntil time.Time // lockout end; meaningful only while lockedOut
}

var loginLimiter = &rateLimiter{attempts: make(map[string]*attemptInfo)}
var registerLimiter = &rateLimiter{attempts: make(map[string]*attemptInfo)}

// checkRateLimit counts one attempt for key and reports whether the request
// should be blocked.
//
// Every call counts, whatever the outcome of the caller's operation. A key is
// blocked while its lockout is active, and becomes locked out on the call that
// pushes its count above maxAttempts within window. Once a lockout or a window
// has elapsed, the key starts over with this call as its first attempt.
func (rl *rateLimiter) checkRateLimit(key string, maxAttempts int, window time.Duration) bool {
	rl.mu.Lock()
	defer rl.mu.Unlock()

	now := time.Now()
	rl.pruneStaleLocked(now, window)

	info, exists := rl.attempts[key]
	if !exists {
		rl.attempts[key] = &attemptInfo{count: 1, firstTry: now}
		return false
	}

	if info.lockedOut {
		if now.Before(info.lockUntil) {
			return true
		}
		// Lockout elapsed. Using !Before (rather than After) also clears it at
		// the exact expiry instant instead of falling through to the window
		// logic with lockedOut still set.
		info.lockedOut = false
		info.count = 1
		info.firstTry = now
		return false
	}

	if now.Sub(info.firstTry) > window {
		info.count = 1
		info.firstTry = now
		return false
	}

	info.count++
	if info.count > maxAttempts {
		info.lockedOut = true
		info.lockUntil = now.Add(lockoutDuration)
		return true
	}
	return false
}

// pruneStaleLocked removes entries that can no longer influence a decision:
// not under an active lockout and past their counting window. Both
// checkRateLimit and recordFailedAttempt restart such an entry at count 1,
// exactly as for an absent key, so deleting it only bounds memory. Without
// this, every distinct key (e.g. attacker-chosen emails) stayed in the map
// forever. This assumes each limiter is always checked with the same window,
// which holds for loginLimiter/rateLimitWindow and registerLimiter/time.Hour
// in this file. The sweep runs at most once per window, so its O(n) cost is
// amortized over all calls in that window. rl.mu must be held by the caller.
func (rl *rateLimiter) pruneStaleLocked(now time.Time, window time.Duration) {
	if now.Sub(rl.lastSweep) < window {
		return
	}
	rl.lastSweep = now
	for key, info := range rl.attempts {
		if info.lockedOut && now.Before(info.lockUntil) {
			continue
		}
		if now.Sub(info.firstTry) > window {
			delete(rl.attempts, key)
		}
	}
}

// recordFailedAttempt counts a failed login for key and locks the key out for
// lockoutDuration once maxLoginAttempts failures accumulate within
// rateLimitWindow.
//
// NOTE(review): the login handlers call checkRateLimit (which also increments
// the same counter) before authenticating and then call this on failure, so
// each failed login currently consumes two slots — confirm that the resulting
// stricter limit is intended.
func (rl *rateLimiter) recordFailedAttempt(key string) {
	rl.mu.Lock()
	defer rl.mu.Unlock()

	now := time.Now()
	info, exists := rl.attempts[key]
	if !exists {
		rl.attempts[key] = &attemptInfo{count: 1, firstTry: now}
		return
	}

	switch {
	case info.lockedOut && !now.Before(info.lockUntil):
		// Lockout over: this failure opens a fresh window.
		info.lockedOut = false
		info.count = 1
		info.firstTry = now
		return
	case now.Sub(info.firstTry) > rateLimitWindow:
		// Window expired: restart counting (an active lockout is kept).
		info.count = 1
		info.firstTry = now
		return
	}

	info.count++
	if info.count >= maxLoginAttempts {
		info.lockedOut = true
		info.lockUntil = now.Add(lockoutDuration)
		// %q: the key embeds user-supplied input, so escape it to keep
		// control characters from forging log lines.
		log.Printf("SECURITY: IP/email %q locked out after %d failed attempts", key, info.count)
	}
}

// clearAttempts forgets all state for key (called after a successful login).
func (rl *rateLimiter) clearAttempts(key string) {
	rl.mu.Lock()
	defer rl.mu.Unlock()
	delete(rl.attempts, key)
}
// getClientIP returns the client address used to key the rate limiters.
//
// X-Forwarded-For (first entry) and then X-Real-IP are honored only when the
// direct peer (r.RemoteAddr) is a loopback or private-network address, i.e.
// when the request plausibly came through a reverse proxy in front of this
// service. Otherwise any client could send a fresh random header value on
// each request and never hit a rate limit. Header values that do not parse
// as an IP address are ignored, and parsed addresses are returned in
// canonical form. In every other case the peer's host is returned.
//
// NOTE(review): a proxy that appends to X-Forwarded-For leaves the leftmost
// entry client-controlled. Have the proxy overwrite the header or set
// X-Real-IP. Deployments whose proxy connects from a public address (e.g. a
// CDN hitting the origin directly) will now see the proxy's IP instead —
// confirm the topology.
func getClientIP(r *http.Request) string {
	peer, _, err := net.SplitHostPort(r.RemoteAddr)
	if err != nil {
		peer = r.RemoteAddr
	}

	peerIP := net.ParseIP(peer)
	if peerIP == nil || !(peerIP.IsLoopback() || peerIP.IsPrivate()) {
		// Not behind a trusted proxy: forwarding headers are client-controlled.
		return peer
	}

	if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
		first, _, _ := strings.Cut(xff, ",")
		if ip := net.ParseIP(strings.TrimSpace(first)); ip != nil {
			return ip.String()
		}
	}
	if xri := r.Header.Get("X-Real-IP"); xri != "" {
		if ip := net.ParseIP(strings.TrimSpace(xri)); ip != nil {
			return ip.String()
		}
	}
	return peer
}
// Input validation limits.
const (
	maxNameLength     = 100     // characters (runes), after trimming
	maxEmailLength    = 254     // RFC 5321 limit (emails are ASCII per emailRegex)
	minPasswordLength = 8       // characters (runes)
	maxPasswordLength = 72      // bytes: bcrypt's input limit
	maxRequestBody    = 1 << 20 // 1MB max request body size
)

// emailRegex is a simplified RFC 5322 address check: ASCII local part,
// dotted domain, and a top-level label of at least two letters.
var emailRegex = regexp.MustCompile(`^[a-zA-Z0-9._%+\-]+@[a-zA-Z0-9.\-]+\.[a-zA-Z]{2,}$`)

// ValidationError describes a single rejected input field. Handlers encode it
// as the JSON body of a 400 response.
type ValidationError struct {
	Field   string `json:"field"`
	Message string `json:"message"`
}

// validateEmail checks that email (ignoring surrounding whitespace) is
// present, at most maxEmailLength bytes, and matches emailRegex. It returns
// nil when the email is acceptable.
func validateEmail(email string) *ValidationError {
	email = strings.TrimSpace(email)
	if email == "" {
		return &ValidationError{Field: "email", Message: "Email is required"}
	}
	if len(email) > maxEmailLength {
		return &ValidationError{Field: "email", Message: fmt.Sprintf("Email must be less than %d characters", maxEmailLength)}
	}
	if !emailRegex.MatchString(email) {
		return &ValidationError{Field: "email", Message: "Invalid email format"}
	}
	return nil
}

// validatePassword enforces the password policy: at least minPasswordLength
// characters, at most maxPasswordLength bytes, and at least one uppercase
// letter, one lowercase letter and one number (Unicode-aware). It returns nil
// when the password is acceptable.
func validatePassword(password string) *ValidationError {
	// The minimum counts characters so multi-byte characters cannot satisfy
	// it with fewer symbols; the maximum stays in bytes because that is the
	// limit bcrypt imposes on its input.
	if utf8.RuneCountInString(password) < minPasswordLength {
		return &ValidationError{Field: "password", Message: fmt.Sprintf("Password must be at least %d characters", minPasswordLength)}
	}
	if len(password) > maxPasswordLength {
		return &ValidationError{Field: "password", Message: fmt.Sprintf("Password must be less than %d characters", maxPasswordLength)}
	}
	var hasUpper, hasLower, hasNumber bool
	for _, c := range password {
		switch {
		case unicode.IsUpper(c):
			hasUpper = true
		case unicode.IsLower(c):
			hasLower = true
		case unicode.IsNumber(c):
			hasNumber = true
		}
	}
	if !hasUpper {
		return &ValidationError{Field: "password", Message: "Password must contain at least one uppercase letter"}
	}
	if !hasLower {
		return &ValidationError{Field: "password", Message: "Password must contain at least one lowercase letter"}
	}
	if !hasNumber {
		return &ValidationError{Field: "password", Message: "Password must contain at least one number"}
	}
	return nil
}

// validateName checks that name (ignoring surrounding whitespace) is present
// and at most maxNameLength characters. Length is counted in runes, not
// bytes, so names in non-Latin scripts get the full advertised limit.
func validateName(name string) *ValidationError {
	name = strings.TrimSpace(name)
	if name == "" {
		return &ValidationError{Field: "name", Message: "Name is required"}
	}
	if utf8.RuneCountInString(name) > maxNameLength {
		return &ValidationError{Field: "name", Message: fmt.Sprintf("Name must be less than %d characters", maxNameLength)}
	}
	return nil
}
// User represents a user in the system (its JSON API representation).
//
// Email is a pointer so it can encode as null. The blockchain registration
// path inserts users without an email, so it is presumably nil for those
// accounts. Registration sets IsInitialSuperuser and IsProtected from the
// same flag, which is true only for the first account created. Roles
// presumably mirrors the user_roles rows for the user — confirm against the
// query that fills it.
type User struct {
	ID                 int       `json:"id"`
	Name               string    `json:"name"`
	Email              *string   `json:"email"`
	Roles              []string  `json:"roles"`
	IsInitialSuperuser bool      `json:"isInitialSuperuser"`
	IsProtected        bool      `json:"isProtected"`
	CreatedAt          time.Time `json:"createdAt"`
}

// Identity represents an authentication method for a user.
//
// Type is "email_password" (Identifier is the trimmed, lower-cased email) or
// "blockchain_address" (Identifier is the lower-cased address), as written by
// the registration handlers. The credential column those handlers store
// (a bcrypt hash for email identities) has no field here, so it is never
// serialized.
type Identity struct {
	ID             int       `json:"id"`
	UserID         int       `json:"userId"`
	Type           string    `json:"type"`
	Identifier     string    `json:"identifier"`
	IsPrimaryLogin bool      `json:"isPrimaryLogin"`
	CreatedAt      time.Time `json:"createdAt"`
}

// RegisterEmailPasswordRequest is the body for email/password registration.
// Role is decoded but the registration handler ignores it and assigns CLIENT
// (or SUPERUSER for the very first account).
type RegisterEmailPasswordRequest struct {
	Email    string `json:"email"`
	Password string `json:"password"`
	Name     string `json:"name"`
	Role     string `json:"role"`
}

// RegisterBlockchainRequest is the body for blockchain address registration.
// Signature must be a signature of Message by Address. Role is decoded but
// ignored by the registration handler, as with email registration.
type RegisterBlockchainRequest struct {
	Address   string `json:"address"`
	Signature string `json:"signature"`
	Message   string `json:"message"`
	Name      string `json:"name"`
	Role      string `json:"role"`
}

// LoginEmailPasswordRequest is the body for email/password login.
type LoginEmailPasswordRequest struct {
	Email    string `json:"email"`
	Password string `json:"password"`
}

// LoginBlockchainRequest is the body for blockchain signature login;
// Signature must be a signature of Message by Address.
type LoginBlockchainRequest struct {
	Address   string `json:"address"`
	Signature string `json:"signature"`
	Message   string `json:"message"`
}

// LinkIdentityRequest links a new identity to the authenticated account.
// Email/Password are used when Type is "email_password";
// Address/Signature/Message when Type is "blockchain_address".
type LinkIdentityRequest struct {
	Type      string `json:"type"` // "email_password" or "blockchain_address"
	Email     string `json:"email,omitempty"`
	Password  string `json:"password,omitempty"`
	Address   string `json:"address,omitempty"`
	Signature string `json:"signature,omitempty"`
	Message   string `json:"message,omitempty"`
}

// RefreshTokenRequest is the body for exchanging a refresh token for a new
// token pair.
type RefreshTokenRequest struct {
	RefreshToken string `json:"refreshToken"`
}

// AuthTokenResponse carries an access token together with its refresh and
// CSRF tokens.
type AuthTokenResponse struct {
	AccessToken  string `json:"accessToken"`
	RefreshToken string `json:"refreshToken"`
	CsrfToken    string `json:"csrfToken"` // CSRF token for state-changing requests
	ExpiresIn    int64  `json:"expiresIn"` // Access token expiry in seconds
	TokenType    string `json:"tokenType"`
}

// LogoutRequest is the body for revoking a single refresh token.
type LogoutRequest struct {
	RefreshToken string `json:"refreshToken"`
}

// Process-wide state, initialized during startup by main.
var (
	// jwtSecret is set by loadConfig from JWT_SECRET (at least 32 bytes).
	jwtSecret []byte
	// db is the connection pool opened by initDB.
	db *sql.DB
	// defaultRole may be overridden by DEFAULT_USER_ROLE (upper-cased) in
	// loadConfig. NOTE(review): the registration handlers hard-code CLIENT,
	// so check whether anything actually reads this.
	defaultRole = "CLIENT"
)

// contextKey is an unexported key type, so values stored in a request
// context under it cannot collide with keys defined by other packages.
type contextKey string

// userContextKey is the context key under which the authenticated request's
// JWT claims are stored; handlers read them back as jwt.MapClaims.
var userContextKey = contextKey("userClaims")
// main is the Auth Service entry point. It loads configuration, opens the
// database pool, registers every route on http.DefaultServeMux, and serves on
// :8080 until SIGINT or SIGTERM triggers a graceful shutdown, which is given
// up to 30 seconds.
func main() {
	loadConfig()
	db = initDB()
	defer db.Close()

	// Route table. Wrappers are listed outermost first; GET-only endpoints
	// skip requireCSRF.
	routes := []struct {
		pattern string
		handler func(http.ResponseWriter, *http.Request)
	}{
		// Registration
		{"/register-email-password", handleRegisterEmailPassword},
		{"/register-blockchain", handleRegisterBlockchain},
		// Login
		{"/login-email-password", handleLoginEmailPassword},
		{"/login-blockchain", handleLoginBlockchain},
		// Token management
		{"/auth/refresh", handleRefreshToken},
		{"/auth/logout", handleLogout},
		{"/auth/logout-all", authenticate(requireCSRF(handleLogoutAll))},
		// Identity management (state-changing routes carry CSRF protection)
		{"/link-identity", authenticate(requireCSRF(requireRole(handleLinkIdentity, "CLIENT", "STAFF", "ADMIN")))},
		{"/unlink-identity", authenticate(requireCSRF(requireRole(handleUnlinkIdentity, "CLIENT", "STAFF", "ADMIN")))},
		{"/identities", authenticate(handleGetIdentities)},
		// Profile
		{"/profile", authenticate(handleProfile)},
		// Admin (ADMIN or SUPERUSER)
		{"/admin/users", authenticate(requireRole(handleGetAllUsers, "ADMIN", "SUPERUSER"))},
		{"/admin/users/promote-role", authenticate(requireCSRF(requireRole(handlePromoteUserRole, "ADMIN", "SUPERUSER")))},
		{"/admin/users/demote-role", authenticate(requireCSRF(requireRole(handleDemoteUserRole, "ADMIN", "SUPERUSER")))},
		// Superuser only
		{"/superuser/promote", authenticate(requireCSRF(requireRole(handlePromoteSuperuser, "SUPERUSER")))},
		{"/superuser/demote", authenticate(requireCSRF(requireRole(handleDemoteSuperuser, "SUPERUSER")))},
		{"/superuser/transfer", authenticate(requireCSRF(requireRole(handleTransferInitialSuperuser, "SUPERUSER")))},
		// Health check
		{"/healthz", func(w http.ResponseWriter, r *http.Request) {
			w.WriteHeader(http.StatusOK)
			fmt.Fprintln(w, "ok")
		}},
	}
	for _, route := range routes {
		http.HandleFunc(route.pattern, route.handler)
	}

	srv := &http.Server{
		Addr:              ":8080",
		Handler:           limitBodySize(corsMiddleware(http.DefaultServeMux)),
		ReadHeaderTimeout: 10 * time.Second,
		ReadTimeout:       15 * time.Second,
		WriteTimeout:      15 * time.Second,
		IdleTimeout:       60 * time.Second,
	}

	// On SIGINT/SIGTERM, stop accepting work and drain in-flight requests;
	// shutdownComplete is closed once that has finished (or timed out).
	shutdownComplete := make(chan struct{})
	signals := make(chan os.Signal, 1)
	signal.Notify(signals, syscall.SIGINT, syscall.SIGTERM)
	go func() {
		<-signals
		log.Println("Auth Service shutting down...")

		shutdownCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
		defer cancel()

		srv.SetKeepAlivesEnabled(false)
		if shutdownErr := srv.Shutdown(shutdownCtx); shutdownErr != nil {
			log.Printf("Could not gracefully shutdown: %v", shutdownErr)
		}
		close(shutdownComplete)
	}()

	log.Println("Auth Service listening on :8080")
	serveErr := srv.ListenAndServe()
	if serveErr != nil && serveErr != http.ErrServerClosed {
		log.Fatalf("Server error: %v", serveErr)
	}

	<-shutdownComplete
	log.Println("Auth Service stopped")
}
// loadConfig reads settings from the environment into package globals,
// exiting the process via log.Fatal when they are unusable.
//
//   - JWT_SECRET (required, whitespace-trimmed) must be at least 32 bytes;
//     a warning is logged when it is shorter than 64 bytes.
//   - DEFAULT_USER_ROLE (optional) is upper-cased and stored in defaultRole.
func loadConfig() {
	secret := strings.TrimSpace(os.Getenv("JWT_SECRET"))
	switch n := len(secret); {
	case n == 0:
		log.Fatal("JWT_SECRET must be set")
	case n < 32:
		log.Fatal("JWT_SECRET must be at least 32 characters for security")
	case n < 64:
		log.Println("WARNING: JWT_SECRET is less than 64 characters. Consider using a longer secret for production.")
	}
	jwtSecret = []byte(secret)

	role := strings.TrimSpace(os.Getenv("DEFAULT_USER_ROLE"))
	if role == "" {
		return
	}
	defaultRole = strings.ToUpper(role)
}
// initDB opens and verifies the PostgreSQL connection pool described by the
// DB_* environment variables. Any misconfiguration or connection failure
// exits the process via log.Fatal.
//
// Required: DB_USER, DB_PASSWORD, DB_NAME, DB_HOST.
// Optional: DB_SSL_MODE (disable | require | verify-ca | verify-full; defaults
// to "require") and DB_SCHEMA (public | dev | testing | prod). A non-public
// schema is placed ahead of public on the search_path.
func initDB() *sql.DB {
	user := strings.TrimSpace(os.Getenv("DB_USER"))
	password := strings.TrimSpace(os.Getenv("DB_PASSWORD"))
	name := strings.TrimSpace(os.Getenv("DB_NAME"))
	host := strings.TrimSpace(os.Getenv("DB_HOST"))
	sslMode := strings.TrimSpace(os.Getenv("DB_SSL_MODE"))
	schema := strings.TrimSpace(os.Getenv("DB_SCHEMA"))

	if user == "" || password == "" || name == "" || host == "" {
		log.Fatal("Database configuration missing: DB_USER, DB_PASSWORD, DB_NAME, DB_HOST required")
	}

	// Secure default: require TLS when no mode is configured.
	if sslMode == "" {
		sslMode = "require"
		log.Println("WARNING: DB_SSL_MODE not set, defaulting to 'require' for security")
	}
	validSSLModes := map[string]bool{
		"disable":     true, // Only for local development
		"require":     true, // Minimum for production
		"verify-ca":   true, // Better
		"verify-full": true, // Best
	}
	if !validSSLModes[sslMode] {
		log.Fatalf("Invalid DB_SSL_MODE '%s'. Must be: disable, require, verify-ca, or verify-full", sslMode)
	}
	if sslMode == "disable" {
		log.Println("WARNING: Database SSL is DISABLED. This should only be used for local development!")
	}

	// The schema is whitelisted, which also makes it safe to splice into the
	// connection string below without quoting.
	validSchemas := map[string]bool{
		"":        true, // Empty means use public (default)
		"public":  true,
		"dev":     true,
		"testing": true,
		"prod":    true,
	}
	if !validSchemas[schema] {
		log.Fatalf("Invalid DB_SCHEMA '%s'. Must be: dev, testing, prod, or empty for public", schema)
	}

	// In a libpq key=value connection string, values are separated by
	// whitespace and quotes/backslashes are special. Free-form values are
	// therefore single-quoted with \ and ' escaped. Otherwise a password
	// containing a space or quote yields a malformed DSN, and a backslash is
	// silently consumed as an escape.
	dsnEscaper := strings.NewReplacer(`\`, `\\`, `'`, `\'`)
	quoteDSN := func(v string) string { return "'" + dsnEscaper.Replace(v) + "'" }
	connStr := fmt.Sprintf("user=%s password=%s dbname=%s host=%s sslmode=%s",
		quoteDSN(user), quoteDSN(password), quoteDSN(name), quoteDSN(host), sslMode)
	if schema != "" && schema != "public" {
		connStr += fmt.Sprintf(" search_path=%s,public", schema)
	}

	database, err := sql.Open("postgres", connStr)
	if err != nil {
		log.Fatalf("Error opening database: %v", err)
	}

	// Connection pool limits.
	const maxOpenConns = 25
	database.SetMaxOpenConns(maxOpenConns)      // Max open connections to DB
	database.SetMaxIdleConns(5)                 // Max idle connections in pool
	database.SetConnMaxLifetime(5 * time.Minute) // Max lifetime of a connection
	database.SetConnMaxIdleTime(1 * time.Minute) // Max time a connection can be idle

	// sql.Open does not connect; Ping verifies credentials and reachability.
	if err := database.Ping(); err != nil {
		log.Fatalf("Error connecting to database: %v", err)
	}

	schemaInfo := "public"
	if schema != "" && schema != "public" {
		schemaInfo = schema
	}
	log.Printf("Successfully connected to database (SSL mode: %s, schema: %s, max_conns: %d)", sslMode, schemaInfo, maxOpenConns)
	return database
}
// limitBodySize wraps next so that each request body is capped at
// maxRequestBody bytes with http.MaxBytesReader; reads beyond the cap fail
// with an error rather than consuming unbounded memory. Requests without a
// body are passed through unchanged.
func limitBodySize(next http.Handler) http.Handler {
	capped := func(w http.ResponseWriter, r *http.Request) {
		if body := r.Body; body != nil {
			r.Body = http.MaxBytesReader(w, body, maxRequestBody)
		}
		next.ServeHTTP(w, r)
	}
	return http.HandlerFunc(capped)
}
// corsMiddleware adds CORS and security headers to every response. It answers
// every OPTIONS request itself with 200 (covering CORS preflights) without
// calling next.
//
// The allowed origin comes from CORS_ALLOW_ORIGIN and is read once, when the
// middleware is constructed. If unset it defaults to http://localhost:8090. A
// warning is logged (once, at construction) when the variable is unset or "*".
func corsMiddleware(next http.Handler) http.Handler {
	// Resolve the configuration here rather than per request. The handler
	// closure then shares no mutable state between the concurrently running
	// request goroutines; a per-request "already logged" flag written without
	// synchronization was a data race.
	allowedOrigin := strings.TrimSpace(os.Getenv("CORS_ALLOW_ORIGIN"))
	switch allowedOrigin {
	case "":
		// Default to restrictive in production - set CORS_ALLOW_ORIGIN explicitly.
		allowedOrigin = "http://localhost:8090"
		log.Println("WARNING: CORS_ALLOW_ORIGIN not set, defaulting to http://localhost:8090")
		log.Println("For production, set CORS_ALLOW_ORIGIN to your frontend domain (e.g., https://coppertone.tech)")
	case "*":
		log.Println("WARNING: CORS configured to allow ALL origins (*). This is insecure for production!")
	}

	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		h := w.Header()

		// CORS headers
		h.Set("Access-Control-Allow-Origin", allowedOrigin)
		h.Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS")
		h.Set("Access-Control-Allow-Headers", "Content-Type, Authorization, X-CSRF-Token")
		h.Set("Access-Control-Allow-Credentials", "true")

		// Security headers
		h.Set("X-Content-Type-Options", "nosniff")
		h.Set("X-Frame-Options", "DENY")
		h.Set("X-XSS-Protection", "1; mode=block")
		h.Set("Referrer-Policy", "strict-origin-when-cross-origin")
		h.Set("Content-Security-Policy", "default-src 'self'")
		h.Set("Permissions-Policy", "geolocation=(), microphone=(), camera=()")

		if r.Method == http.MethodOptions {
			w.WriteHeader(http.StatusOK)
			return
		}
		next.ServeHTTP(w, r)
	})
}
// ===== REGISTRATION HANDLERS =====
// handleRegisterEmailPassword registers a new account that logs in with an
// email address and password.
//
// Only POST is accepted, and attempts are rate limited per client IP
// (maxRegisterPerHour per hour). The body is a RegisterEmailPasswordRequest.
// Email, password and name are validated, and the email is stored trimmed and
// lower-cased. The request's Role field is ignored: public registrations
// always receive CLIENT, except that the very first account becomes the
// protected initial SUPERUSER.
//
// Responses:
//   - 201 with {"message","userId"} on success.
//   - 400 for an undecodable body (plain text) or a failed validation (JSON ValidationError).
//   - 409 when the email is already registered.
//   - 429 when rate limited.
//   - 500 on hashing or database errors.
//
// NOTE(review): the first-user check takes a SHARE ROW EXCLUSIVE lock on
// users, which PostgreSQL only grants to roles with UPDATE, DELETE or
// TRUNCATE privilege on the table — presumably already the case for this
// service; confirm.
func handleRegisterEmailPassword(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodPost {
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}

	// Rate limit registrations by IP.
	clientIP := getClientIP(r)
	if registerLimiter.checkRateLimit(clientIP, maxRegisterPerHour, time.Hour) {
		log.Printf("SECURITY: Registration rate limit exceeded for IP %s", clientIP)
		http.Error(w, "Too many registration attempts. Please try again later.", http.StatusTooManyRequests)
		return
	}

	var req RegisterEmailPasswordRequest
	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
		http.Error(w, err.Error(), http.StatusBadRequest)
		return
	}

	writeValidationError := func(verr *ValidationError) {
		w.Header().Set("Content-Type", "application/json")
		w.WriteHeader(http.StatusBadRequest)
		json.NewEncoder(w).Encode(verr)
	}
	if verr := validateEmail(req.Email); verr != nil {
		writeValidationError(verr)
		return
	}
	if verr := validatePassword(req.Password); verr != nil {
		writeValidationError(verr)
		return
	}
	if verr := validateName(req.Name); verr != nil {
		writeValidationError(verr)
		return
	}

	email := strings.TrimSpace(strings.ToLower(req.Email))
	name := strings.TrimSpace(req.Name)

	// Hash before opening the transaction so the users lock taken below is
	// not held during the deliberately slow bcrypt computation.
	passwordHash, err := bcrypt.GenerateFromPassword([]byte(req.Password), bcrypt.DefaultCost)
	if err != nil {
		http.Error(w, "Failed to hash password", http.StatusInternalServerError)
		return
	}

	tx, err := db.Begin()
	if err != nil {
		http.Error(w, "Database error", http.StatusInternalServerError)
		return
	}
	defer tx.Rollback() // no effect once Commit has succeeded

	// Decide "is this the first user?" under a table lock. Counting outside
	// the transaction let two concurrent first registrations both see zero
	// users and both become the protected initial SUPERUSER. SHARE ROW
	// EXCLUSIVE conflicts with itself and with the lock every INSERT takes,
	// so the count stays accurate until this transaction ends.
	if _, err := tx.Exec("LOCK TABLE users IN SHARE ROW EXCLUSIVE MODE"); err != nil {
		log.Printf("Error locking users table: %v", err)
		http.Error(w, "Database error", http.StatusInternalServerError)
		return
	}
	var userCount int
	if err := tx.QueryRow("SELECT COUNT(*) FROM users").Scan(&userCount); err != nil {
		log.Printf("Error checking user count: %v", err)
		http.Error(w, "Database error", http.StatusInternalServerError)
		return
	}

	// Security: public registrations always get CLIENT; Staff/Admin roles can
	// only be granted by an existing ADMIN/SUPERUSER. The first user is
	// auto-promoted to SUPERUSER with the initial_superuser flag.
	role := "CLIENT"
	isInitialSuperuser := userCount == 0
	if isInitialSuperuser {
		role = "SUPERUSER"
		log.Println("AUDIT: First user registration - creating INITIAL SUPERUSER (god-like, non-removable)")
	}

	var userID int
	err = tx.QueryRow(
		"INSERT INTO users (name, email, is_initial_superuser, is_protected, created_at) VALUES ($1, $2, $3, $4, NOW()) RETURNING id",
		name, email, isInitialSuperuser, isInitialSuperuser,
	).Scan(&userID)
	if err != nil {
		if strings.Contains(err.Error(), "duplicate key") {
			http.Error(w, "Email already registered", http.StatusConflict)
			return
		}
		log.Printf("Error creating user: %v", err)
		http.Error(w, "Failed to create user", http.StatusInternalServerError)
		return
	}

	_, err = tx.Exec(`
		INSERT INTO identities (user_id, type, identifier, credential, is_primary_login, created_at)
		VALUES ($1, 'email_password', $2, $3, true, NOW())
	`, userID, email, string(passwordHash))
	if err != nil {
		// The email can already exist as an identity linked to another account
		// via /link-identity, which writes only to identities, not users.email.
		// That is still "already registered", not a server error.
		if strings.Contains(err.Error(), "duplicate key") {
			http.Error(w, "Email already registered", http.StatusConflict)
			return
		}
		log.Printf("Error creating identity: %v", err)
		http.Error(w, "Failed to create identity", http.StatusInternalServerError)
		return
	}

	_, err = tx.Exec(
		"INSERT INTO user_roles (user_id, role, created_at) VALUES ($1, $2, NOW())",
		userID, role,
	)
	if err != nil {
		http.Error(w, "Failed to assign role", http.StatusInternalServerError)
		return
	}

	if err = tx.Commit(); err != nil {
		http.Error(w, "Failed to complete registration", http.StatusInternalServerError)
		return
	}

	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(http.StatusCreated)
	json.NewEncoder(w).Encode(map[string]interface{}{
		"message": "User registered successfully",
		"userId":  userID,
	})
}
// handleRegisterBlockchain registers a new account identified by a
// blockchain address.
//
// Only POST is accepted, and attempts are rate limited per client IP
// (maxRegisterPerHour per hour). The body is a RegisterBlockchainRequest.
// Signature must verify as a signature of Message by Address (via
// verifyEthereumSignature), and the address is stored lower-cased. Role is
// ignored: the account gets CLIENT, except that the very first account
// becomes the protected initial SUPERUSER.
//
// Responses:
//   - 201 with {"message","userId"} on success.
//   - 400 for a bad body, name, or missing fields.
//   - 401 for an invalid signature.
//   - 409 when the address is already registered.
//   - 429 when rate limited.
//   - 500 on database errors.
//
// NOTE(review): the first-user check takes a SHARE ROW EXCLUSIVE lock on
// users, which PostgreSQL only grants to roles with UPDATE, DELETE or
// TRUNCATE privilege on the table — confirm for this service's DB role.
func handleRegisterBlockchain(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodPost {
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}

	// Rate limit registrations by IP.
	clientIP := getClientIP(r)
	if registerLimiter.checkRateLimit(clientIP, maxRegisterPerHour, time.Hour) {
		log.Printf("SECURITY: Blockchain registration rate limit exceeded for IP %s", clientIP)
		http.Error(w, "Too many registration attempts. Please try again later.", http.StatusTooManyRequests)
		return
	}

	var req RegisterBlockchainRequest
	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
		http.Error(w, err.Error(), http.StatusBadRequest)
		return
	}

	if verr := validateName(req.Name); verr != nil {
		w.Header().Set("Content-Type", "application/json")
		w.WriteHeader(http.StatusBadRequest)
		json.NewEncoder(w).Encode(verr)
		return
	}
	if req.Address == "" {
		http.Error(w, "Blockchain address is required", http.StatusBadRequest)
		return
	}
	if req.Signature == "" || req.Message == "" {
		http.Error(w, "Signature and message are required for verification", http.StatusBadRequest)
		return
	}
	name := strings.TrimSpace(req.Name)

	// Prove ownership of the address before doing any database work.
	if !verifyEthereumSignature(req.Address, req.Message, req.Signature) {
		http.Error(w, "Invalid signature", http.StatusUnauthorized)
		return
	}
	address := strings.ToLower(req.Address)

	tx, err := db.Begin()
	if err != nil {
		http.Error(w, "Database error", http.StatusInternalServerError)
		return
	}
	defer tx.Rollback() // no effect once Commit has succeeded

	// Decide "is this the first user?" under a table lock. Counting outside
	// the transaction let two concurrent first registrations both see zero
	// users and both become the protected initial SUPERUSER. SHARE ROW
	// EXCLUSIVE conflicts with itself and with the lock every INSERT takes,
	// so the count stays accurate until this transaction ends.
	if _, err := tx.Exec("LOCK TABLE users IN SHARE ROW EXCLUSIVE MODE"); err != nil {
		log.Printf("Error locking users table: %v", err)
		http.Error(w, "Database error", http.StatusInternalServerError)
		return
	}
	var userCount int
	if err := tx.QueryRow("SELECT COUNT(*) FROM users").Scan(&userCount); err != nil {
		log.Printf("Error checking user count: %v", err)
		http.Error(w, "Database error", http.StatusInternalServerError)
		return
	}

	// Security: public registrations always get CLIENT; Staff/Admin roles can
	// only be granted by an existing ADMIN/SUPERUSER. The first user is
	// auto-promoted to SUPERUSER with the initial_superuser flag.
	role := "CLIENT"
	isInitialSuperuser := userCount == 0
	if isInitialSuperuser {
		role = "SUPERUSER"
		log.Println("AUDIT: First user registration (blockchain) - creating INITIAL SUPERUSER (god-like, non-removable)")
	}

	var userID int
	err = tx.QueryRow(
		"INSERT INTO users (name, is_initial_superuser, is_protected, created_at) VALUES ($1, $2, $3, NOW()) RETURNING id",
		name, isInitialSuperuser, isInitialSuperuser,
	).Scan(&userID)
	if err != nil {
		log.Printf("Error creating user: %v", err)
		http.Error(w, "Failed to create user", http.StatusInternalServerError)
		return
	}

	_, err = tx.Exec(`
		INSERT INTO identities (user_id, type, identifier, is_primary_login, created_at)
		VALUES ($1, 'blockchain_address', $2, true, NOW())
	`, userID, address)
	if err != nil {
		if strings.Contains(err.Error(), "duplicate key") {
			http.Error(w, "Blockchain address already registered", http.StatusConflict)
			return
		}
		log.Printf("Error creating identity: %v", err)
		http.Error(w, "Failed to create identity", http.StatusInternalServerError)
		return
	}

	_, err = tx.Exec(
		"INSERT INTO user_roles (user_id, role, created_at) VALUES ($1, $2, NOW())",
		userID, role,
	)
	if err != nil {
		http.Error(w, "Failed to assign role", http.StatusInternalServerError)
		return
	}

	if err = tx.Commit(); err != nil {
		http.Error(w, "Failed to complete registration", http.StatusInternalServerError)
		return
	}

	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(http.StatusCreated)
	json.NewEncoder(w).Encode(map[string]interface{}{
		"message": "User registered successfully",
		"userId":  userID,
	})
}
// ===== LOGIN HANDLERS =====
// handleLoginEmailPassword authenticates an email/password pair and, on
// success, responds with a new token pair from generateTokenPair.
//
// Only POST is accepted; the body is a LoginEmailPasswordRequest. Attempts
// are rate limited per (client IP, normalized email) via loginLimiter, and
// the limiter state for that key is cleared after a successful login.
//
// Responses:
//   - 200 JSON tokens on success.
//   - 400 for an undecodable body.
//   - 401 "Invalid credentials" for an unknown email or a wrong password (deliberately indistinguishable).
//   - 429 when rate limited.
//   - 500 on database or token errors.
func handleLoginEmailPassword(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodPost {
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}

	clientIP := getClientIP(r)

	var req LoginEmailPasswordRequest
	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
		http.Error(w, err.Error(), http.StatusBadRequest)
		return
	}

	// Registration stores emails trimmed and lower-cased, so look them up in
	// the same form (previously "User@Example.com" could never log in). The
	// rate-limit key uses the same form so case/spacing variants of one
	// address share a single attempt counter.
	email := strings.TrimSpace(strings.ToLower(req.Email))
	rateLimitKey := clientIP + ":" + email
	if loginLimiter.checkRateLimit(rateLimitKey, maxLoginAttempts, rateLimitWindow) {
		// %q throughout: logged values include user input, so escape them to
		// keep control characters from forging log lines.
		log.Printf("SECURITY: Rate limit exceeded for %q", rateLimitKey)
		http.Error(w, "Too many login attempts. Please try again later.", http.StatusTooManyRequests)
		return
	}

	var userID int
	var passwordHash string
	err := db.QueryRow(`
		SELECT user_id, credential
		FROM identities
		WHERE type = 'email_password' AND identifier = $1
	`, email).Scan(&userID, &passwordHash)
	switch {
	case errors.Is(err, sql.ErrNoRows):
		// Spend comparable bcrypt work to a real check so response time does
		// not reveal whether the email is registered. This assumes stored
		// hashes use bcrypt.DefaultCost, as this file's handlers create them.
		_, _ = bcrypt.GenerateFromPassword([]byte(req.Password), bcrypt.DefaultCost)
		loginLimiter.recordFailedAttempt(rateLimitKey)
		log.Printf("SECURITY: Failed login attempt for email %q from IP %s (user not found)", email, clientIP)
		http.Error(w, "Invalid credentials", http.StatusUnauthorized)
		return
	case err != nil:
		log.Printf("Error looking up login identity: %v", err)
		http.Error(w, "Login failed", http.StatusInternalServerError)
		return
	}

	if err := bcrypt.CompareHashAndPassword([]byte(passwordHash), []byte(req.Password)); err != nil {
		loginLimiter.recordFailedAttempt(rateLimitKey)
		log.Printf("SECURITY: Failed login attempt for email %q from IP %s (wrong password)", email, clientIP)
		http.Error(w, "Invalid credentials", http.StatusUnauthorized)
		return
	}

	// Clear rate limit on successful login.
	loginLimiter.clearAttempts(rateLimitKey)
	log.Printf("AUDIT: Successful login for user_id %d from IP %s", userID, clientIP)

	// Generate token pair (access + refresh).
	tokenResponse, err := generateTokenPair(userID, clientIP)
	if err != nil {
		log.Printf("Error generating token pair: %v", err)
		http.Error(w, "Failed to generate tokens", http.StatusInternalServerError)
		return
	}

	w.Header().Set("Content-Type", "application/json")
	json.NewEncoder(w).Encode(tokenResponse)
}
// handleLoginBlockchain authenticates a caller by an Ethereum signature and,
// on success, responds with a new token pair from generateTokenPair.
//
// Only POST is accepted; the body is a LoginBlockchainRequest whose Signature
// must verify as a signature of Message by Address. Attempts are rate limited
// per (client IP, lower-cased address), and the limiter state is cleared
// after a successful login.
//
// Responses:
//   - 200 JSON tokens on success.
//   - 400 for a bad body or missing fields.
//   - 401 for an invalid signature or an unregistered address.
//   - 429 when rate limited.
//   - 500 on database or token errors.
func handleLoginBlockchain(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodPost {
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}

	clientIP := getClientIP(r)

	var req LoginBlockchainRequest
	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
		http.Error(w, err.Error(), http.StatusBadRequest)
		return
	}
	// Reject incomplete requests up front, as the registration handler does,
	// instead of counting them as failed signature checks.
	if req.Address == "" || req.Signature == "" || req.Message == "" {
		http.Error(w, "Address, signature, and message required", http.StatusBadRequest)
		return
	}

	// Identities store addresses lower-cased; use that form for both the
	// rate-limit key and the lookup.
	address := strings.ToLower(req.Address)
	rateLimitKey := clientIP + ":" + address
	if loginLimiter.checkRateLimit(rateLimitKey, maxLoginAttempts, rateLimitWindow) {
		// %q throughout: logged values include user input, so escape them to
		// keep control characters from forging log lines.
		log.Printf("SECURITY: Rate limit exceeded for blockchain login %q", rateLimitKey)
		http.Error(w, "Too many login attempts. Please try again later.", http.StatusTooManyRequests)
		return
	}

	// NOTE(review): Message is chosen by the client and nothing here checks
	// it for a server-issued nonce or timestamp. A captured (message,
	// signature) pair could therefore be replayed to log in again unless
	// verifyEthereumSignature enforces freshness — confirm.
	if !verifyEthereumSignature(req.Address, req.Message, req.Signature) {
		loginLimiter.recordFailedAttempt(rateLimitKey)
		log.Printf("SECURITY: Failed blockchain login for address %q from IP %s (invalid signature)", req.Address, clientIP)
		http.Error(w, "Invalid signature", http.StatusUnauthorized)
		return
	}

	var userID int
	err := db.QueryRow(`
		SELECT user_id
		FROM identities
		WHERE type = 'blockchain_address' AND identifier = $1
	`, address).Scan(&userID)
	switch {
	case errors.Is(err, sql.ErrNoRows):
		loginLimiter.recordFailedAttempt(rateLimitKey)
		log.Printf("SECURITY: Failed blockchain login for address %q from IP %s (not registered)", req.Address, clientIP)
		http.Error(w, "Address not registered", http.StatusUnauthorized)
		return
	case err != nil:
		log.Printf("Error looking up blockchain identity: %v", err)
		http.Error(w, "Login failed", http.StatusInternalServerError)
		return
	}

	// Clear rate limit on successful login.
	loginLimiter.clearAttempts(rateLimitKey)
	log.Printf("AUDIT: Successful blockchain login for user_id %d from IP %s", userID, clientIP)

	// Generate token pair (access + refresh).
	tokenResponse, err := generateTokenPair(userID, clientIP)
	if err != nil {
		log.Printf("Error generating token pair: %v", err)
		http.Error(w, "Failed to generate tokens", http.StatusInternalServerError)
		return
	}

	w.Header().Set("Content-Type", "application/json")
	json.NewEncoder(w).Encode(tokenResponse)
}
// ===== TOKEN REFRESH AND LOGOUT HANDLERS =====
// handleRefreshToken exchanges a valid refresh token for a new token pair
// (refresh-token rotation).
//
// Only POST is accepted; the body is a RefreshTokenRequest. The presented
// token is validated and revoked, and only then replaced by a freshly
// generated pair for the same user. The body is whatever generateTokenPair
// returns, presumably an AuthTokenResponse.
//
// Responses:
//   - 200 JSON tokens on success.
//   - 400 for a bad body or missing token.
//   - 401 for an invalid or expired token.
//   - 500 when revocation or token generation fails.
//
// NOTE(review): validation and revocation are separate calls. Two concurrent
// requests presenting the same token could both pass validation before either
// revokes it, unless validateRefreshToken/revokeRefreshToken guard against
// that (not visible here).
func handleRefreshToken(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodPost {
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}

	clientIP := getClientIP(r)

	var req RefreshTokenRequest
	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
		http.Error(w, "Invalid request body", http.StatusBadRequest)
		return
	}
	if req.RefreshToken == "" {
		http.Error(w, "Refresh token is required", http.StatusBadRequest)
		return
	}

	tokenID, userID, err := validateRefreshToken(req.RefreshToken)
	if err != nil {
		log.Printf("SECURITY: Invalid refresh token attempt from IP %s: %v", clientIP, err)
		http.Error(w, "Invalid or expired refresh token", http.StatusUnauthorized)
		return
	}

	// Rotation only protects anything if the old token is really dead. If it
	// cannot be revoked, refuse to mint a new pair rather than leaving two
	// valid refresh tokens in circulation. The client may retry, because the
	// old token is still valid in that case.
	if err := revokeRefreshToken(tokenID); err != nil {
		log.Printf("Error revoking refresh token during rotation for user_id %d: %v", userID, err)
		http.Error(w, "Failed to refresh tokens", http.StatusInternalServerError)
		return
	}

	tokenResponse, err := generateTokenPair(userID, clientIP)
	if err != nil {
		log.Printf("Error generating token pair during refresh: %v", err)
		http.Error(w, "Failed to refresh tokens", http.StatusInternalServerError)
		return
	}

	log.Printf("AUDIT: Token refresh for user_id %d from IP %s", userID, clientIP)
	w.Header().Set("Content-Type", "application/json")
	json.NewEncoder(w).Encode(tokenResponse)
}
// handleLogout revokes the single refresh token given in the request body.
//
// Only POST is accepted; the body is a LogoutRequest. An invalid or unknown
// token still gets the normal success response, so callers cannot use this
// endpoint to probe token validity; the attempt is logged. Only a failure to
// revoke a valid token produces a 500.
func handleLogout(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodPost {
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}
	ip := getClientIP(r)

	var body LogoutRequest
	if decodeErr := json.NewDecoder(r.Body).Decode(&body); decodeErr != nil {
		http.Error(w, "Invalid request body", http.StatusBadRequest)
		return
	}
	if body.RefreshToken == "" {
		http.Error(w, "Refresh token is required", http.StatusBadRequest)
		return
	}

	respondLoggedOut := func() {
		w.Header().Set("Content-Type", "application/json")
		json.NewEncoder(w).Encode(map[string]string{"message": "Logged out successfully"})
	}

	tokenID, userID, validateErr := validateRefreshToken(body.RefreshToken)
	if validateErr != nil {
		log.Printf("SECURITY: Logout with invalid token from IP %s", ip)
		respondLoggedOut()
		return
	}

	if revokeErr := revokeRefreshToken(tokenID); revokeErr != nil {
		log.Printf("Error revoking refresh token: %v", revokeErr)
		http.Error(w, "Failed to logout", http.StatusInternalServerError)
		return
	}

	log.Printf("AUDIT: User %d logged out from IP %s", userID, ip)
	respondLoggedOut()
}
// handleLogoutAll revokes every refresh token of the authenticated user
// ("log out everywhere").
//
// Only POST is accepted. The user ID is read from the JWT claims stored in the
// request context under userContextKey (main registers this handler behind
// authenticate and requireCSRF). This handler only revokes refresh tokens;
// already-issued access tokens presumably stay valid until they expire —
// confirm in authenticate.
//
// Responses: 200 JSON message on success; 401 if the context holds no usable
// claims; 500 if revocation fails.
func handleLogoutAll(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodPost {
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}

	// Checked assertions: missing or malformed claims yield 401 instead of
	// panicking the request goroutine.
	claims, ok := r.Context().Value(userContextKey).(jwt.MapClaims)
	if !ok {
		http.Error(w, "Unauthorized", http.StatusUnauthorized)
		return
	}
	rawUserID, ok := claims["userId"].(float64) // JSON numbers decode as float64
	if !ok {
		http.Error(w, "Unauthorized", http.StatusUnauthorized)
		return
	}
	userID := int(rawUserID)
	clientIP := getClientIP(r)

	if err := revokeAllUserRefreshTokens(userID); err != nil {
		log.Printf("Error revoking all refresh tokens for user %d: %v", userID, err)
		http.Error(w, "Failed to logout from all devices", http.StatusInternalServerError)
		return
	}

	log.Printf("AUDIT: User %d logged out from all devices (initiated from IP %s)", userID, clientIP)
	w.Header().Set("Content-Type", "application/json")
	json.NewEncoder(w).Encode(map[string]string{"message": "Logged out from all devices successfully"})
}
// ===== IDENTITY MANAGEMENT =====

// handleLinkIdentity attaches an additional login identity to the
// authenticated user's account. The JSON body is a LinkIdentityRequest whose
// Type selects the flow:
//
//   - "email_password": Email and Password are validated, the email is
//     trimmed and lower-cased, and only a bcrypt hash of the password is
//     stored.
//   - "blockchain_address": Signature must be a valid EIP-191 personal-message
//     signature of Message by Address. The address is stored in the same
//     normalized form the signature check compared against.
//
// Linked identities are never the primary login. Responds 201 on success,
// 400 on malformed input, 401 on a bad signature, and 409 when the database
// reports the identity as a duplicate.
func handleLinkIdentity(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodPost {
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}
	claims := r.Context().Value(userContextKey).(jwt.MapClaims)
	userID := int(claims["userId"].(float64))

	var req LinkIdentityRequest
	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
		// Don't echo JSON decoder internals back to the client.
		http.Error(w, "Invalid request body", http.StatusBadRequest)
		return
	}

	switch req.Type {
	case "email_password":
		// Validate email
		if err := validateEmail(req.Email); err != nil {
			w.Header().Set("Content-Type", "application/json")
			w.WriteHeader(http.StatusBadRequest)
			json.NewEncoder(w).Encode(err)
			return
		}
		// Validate password
		if err := validatePassword(req.Password); err != nil {
			w.Header().Set("Content-Type", "application/json")
			w.WriteHeader(http.StatusBadRequest)
			json.NewEncoder(w).Encode(err)
			return
		}
		// Normalize email so lookups are case- and whitespace-insensitive.
		req.Email = strings.TrimSpace(strings.ToLower(req.Email))
		passwordHash, err := bcrypt.GenerateFromPassword([]byte(req.Password), bcrypt.DefaultCost)
		if err != nil {
			http.Error(w, "Failed to hash password", http.StatusInternalServerError)
			return
		}
		_, err = db.Exec(`
			INSERT INTO identities (user_id, type, identifier, credential, is_primary_login)
			VALUES ($1, 'email_password', $2, $3, false)
		`, userID, req.Email, string(passwordHash))
		if err != nil {
			if strings.Contains(err.Error(), "duplicate key") {
				http.Error(w, "Identity already linked", http.StatusConflict)
			} else {
				log.Printf("Error linking email identity for user %d: %v", userID, err)
				http.Error(w, "Failed to link identity", http.StatusInternalServerError)
			}
			return
		}

	case "blockchain_address":
		if req.Address == "" || req.Signature == "" || req.Message == "" {
			http.Error(w, "Address, signature, and message required", http.StatusBadRequest)
			return
		}
		if !verifyEthereumSignature(req.Address, req.Message, req.Signature) {
			http.Error(w, "Invalid signature", http.StatusUnauthorized)
			return
		}
		// NOTE(review): Message is chosen by the client and is not checked
		// against a server-issued nonce here. Any signature previously
		// produced by this address could therefore be replayed to link it to
		// a different account. TODO confirm whether a challenge is enforced
		// elsewhere.
		//
		// Store the canonical form verifyEthereumSignature compared against,
		// so "abc…" and "0xABC…" map to the same identity.
		address := normalizeEthereumAddress(req.Address)
		_, err := db.Exec(`
			INSERT INTO identities (user_id, type, identifier, is_primary_login)
			VALUES ($1, 'blockchain_address', $2, false)
		`, userID, address)
		if err != nil {
			if strings.Contains(err.Error(), "duplicate key") {
				http.Error(w, "Identity already linked", http.StatusConflict)
			} else {
				log.Printf("Error linking blockchain identity for user %d: %v", userID, err)
				http.Error(w, "Failed to link identity", http.StatusInternalServerError)
			}
			return
		}

	default:
		http.Error(w, "Invalid identity type", http.StatusBadRequest)
		return
	}

	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(http.StatusCreated)
	json.NewEncoder(w).Encode(map[string]string{"message": "Identity linked successfully"})
}
// UnlinkIdentityRequest is the JSON body accepted by handleUnlinkIdentity.
// IdentityID is the id of the identities row to detach from the caller's
// account. The handler rejects a zero value.
type UnlinkIdentityRequest struct {
	IdentityID int `json:"identityId"`
}
// handleUnlinkIdentity detaches one of the authenticated user's identities.
//
// The JSON body is an UnlinkIdentityRequest. The identity must belong to the
// caller and must not be the caller's last identity. If it is the primary
// login, the user's oldest remaining identity is promoted to primary before
// the delete.
//
// All checks and writes run in one transaction with the user's row locked.
// This means concurrent unlinks cannot together remove every identity, and a
// failed promotion cannot leave the account without a primary login.
//
// Responds 200 on success, 400 for a missing ID or the last identity, 403
// when the identity belongs to someone else, 404 when it does not exist, and
// 500 on database errors.
func handleUnlinkIdentity(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodPost {
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}
	claims := r.Context().Value(userContextKey).(jwt.MapClaims)
	userID := int(claims["userId"].(float64))

	var req UnlinkIdentityRequest
	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
		http.Error(w, "Invalid request body", http.StatusBadRequest)
		return
	}
	if req.IdentityID == 0 {
		http.Error(w, "Identity ID is required", http.StatusBadRequest)
		return
	}

	tx, err := db.Begin()
	if err != nil {
		log.Printf("Error starting unlink transaction: %v", err)
		http.Error(w, "Database error", http.StatusInternalServerError)
		return
	}
	// After a successful Commit, Rollback returns sql.ErrTxDone, which is ignored.
	defer func() { _ = tx.Rollback() }()

	// Serialize unlink operations per user. Without this lock, two concurrent
	// requests could each count two identities and delete one apiece.
	var lockedUserID int
	if err := tx.QueryRow("SELECT id FROM users WHERE id = $1 FOR UPDATE", userID).Scan(&lockedUserID); err != nil {
		log.Printf("Error locking user %d: %v", userID, err)
		http.Error(w, "Database error", http.StatusInternalServerError)
		return
	}

	// Check the identity exists and belongs to the caller.
	var identityUserID int
	var isPrimary bool
	err = tx.QueryRow("SELECT user_id, is_primary_login FROM identities WHERE id = $1", req.IdentityID).Scan(&identityUserID, &isPrimary)
	if errors.Is(err, sql.ErrNoRows) {
		http.Error(w, "Identity not found", http.StatusNotFound)
		return
	}
	if err != nil {
		log.Printf("Error checking identity: %v", err)
		http.Error(w, "Database error", http.StatusInternalServerError)
		return
	}
	if identityUserID != userID {
		http.Error(w, "Forbidden: identity does not belong to you", http.StatusForbidden)
		return
	}

	// The caller must keep at least one login method.
	var identityCount int
	if err := tx.QueryRow("SELECT COUNT(*) FROM identities WHERE user_id = $1", userID).Scan(&identityCount); err != nil {
		log.Printf("Error counting identities: %v", err)
		http.Error(w, "Database error", http.StatusInternalServerError)
		return
	}
	if identityCount <= 1 {
		http.Error(w, "Cannot remove your last identity. You must have at least one login method.", http.StatusBadRequest)
		return
	}

	if isPrimary {
		// PostgreSQL has no UPDATE ... LIMIT, so pick the replacement with a
		// subquery. Failure aborts the transaction rather than deleting the
		// primary and leaving the account without one.
		_, err = tx.Exec(`
			UPDATE identities
			SET is_primary_login = true
			WHERE id = (
				SELECT id FROM identities
				WHERE user_id = $1 AND id != $2
				ORDER BY created_at, id
				LIMIT 1
			)
		`, userID, req.IdentityID)
		if err != nil {
			log.Printf("Error promoting identity: %v", err)
			http.Error(w, "Failed to unlink identity", http.StatusInternalServerError)
			return
		}
	}

	// user_id is repeated as defence in depth against deleting another user's identity.
	result, err := tx.Exec("DELETE FROM identities WHERE id = $1 AND user_id = $2", req.IdentityID, userID)
	if err != nil {
		log.Printf("Error deleting identity: %v", err)
		http.Error(w, "Failed to unlink identity", http.StatusInternalServerError)
		return
	}
	rowsAffected, err := result.RowsAffected()
	if err != nil {
		log.Printf("Error reading delete result: %v", err)
		http.Error(w, "Failed to unlink identity", http.StatusInternalServerError)
		return
	}
	if rowsAffected == 0 {
		http.Error(w, "Identity not found", http.StatusNotFound)
		return
	}

	if err := tx.Commit(); err != nil {
		log.Printf("Error committing identity unlink: %v", err)
		http.Error(w, "Failed to unlink identity", http.StatusInternalServerError)
		return
	}

	log.Printf("AUDIT: User %d unlinked identity %d", userID, req.IdentityID)
	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(http.StatusOK)
	json.NewEncoder(w).Encode(map[string]string{"message": "Identity unlinked successfully"})
}
// handleGetIdentities returns the authenticated user's linked identities as a
// JSON array ordered by creation time. The credential column is never
// selected. A user with no identities gets [] rather than null.
func handleGetIdentities(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodGet {
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}
	claims := r.Context().Value(userContextKey).(jwt.MapClaims)
	userID := int(claims["userId"].(float64))

	rows, err := db.Query(`
		SELECT id, user_id, type, identifier, is_primary_login, created_at
		FROM identities
		WHERE user_id = $1
		ORDER BY created_at, id
	`, userID)
	if err != nil {
		log.Printf("Error fetching identities for user %d: %v", userID, err)
		http.Error(w, "Failed to fetch identities", http.StatusInternalServerError)
		return
	}
	defer rows.Close()

	identities := []Identity{}
	for rows.Next() {
		var id Identity
		if err := rows.Scan(&id.ID, &id.UserID, &id.Type, &id.Identifier, &id.IsPrimaryLogin, &id.CreatedAt); err != nil {
			log.Printf("Error scanning identity for user %d: %v", userID, err)
			http.Error(w, "Failed to parse identities", http.StatusInternalServerError)
			return
		}
		identities = append(identities, id)
	}
	// Without this check, a mid-iteration error would return a truncated list as success.
	if err := rows.Err(); err != nil {
		log.Printf("Error iterating identities for user %d: %v", userID, err)
		http.Error(w, "Failed to fetch identities", http.StatusInternalServerError)
		return
	}

	w.Header().Set("Content-Type", "application/json")
	json.NewEncoder(w).Encode(identities)
}
// handleProfile returns the authenticated user's id, name, email, creation
// time and roles as JSON. Responds 404 if the token is valid but the user row
// no longer exists, and 500 on database errors.
func handleProfile(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodGet {
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}
	claims := r.Context().Value(userContextKey).(jwt.MapClaims)
	userID := int(claims["userId"].(float64))

	var user User
	err := db.QueryRow(`
		SELECT id, name, email, created_at
		FROM users
		WHERE id = $1
	`, userID).Scan(&user.ID, &user.Name, &user.Email, &user.CreatedAt)
	if errors.Is(err, sql.ErrNoRows) {
		http.Error(w, "User not found", http.StatusNotFound)
		return
	}
	if err != nil {
		log.Printf("Error fetching profile for user %d: %v", userID, err)
		http.Error(w, "Failed to fetch profile", http.StatusInternalServerError)
		return
	}

	rows, err := db.Query("SELECT role FROM user_roles WHERE user_id = $1", userID)
	if err != nil {
		log.Printf("Error fetching roles for user %d: %v", userID, err)
		http.Error(w, "Failed to fetch roles", http.StatusInternalServerError)
		return
	}
	defer rows.Close()

	var roles []string
	for rows.Next() {
		var role string
		if err := rows.Scan(&role); err != nil {
			log.Printf("Error scanning role for user %d: %v", userID, err)
			http.Error(w, "Failed to fetch roles", http.StatusInternalServerError)
			return
		}
		roles = append(roles, role)
	}
	if err := rows.Err(); err != nil {
		log.Printf("Error iterating roles for user %d: %v", userID, err)
		http.Error(w, "Failed to fetch roles", http.StatusInternalServerError)
		return
	}
	user.Roles = roles

	w.Header().Set("Content-Type", "application/json")
	json.NewEncoder(w).Encode(user)
}
// ===== MIDDLEWARE =====

// authenticate is middleware that requires an "Authorization: Bearer <jwt>"
// header. The token must:
//
//   - be signed with jwtSecret using an HMAC algorithm,
//   - carry an exp claim (the JWT library checks it against the current time),
//   - carry a numeric userId claim.
//
// On success, the token's jwt.MapClaims are stored in the request context
// under userContextKey. Otherwise the request is answered with 401.
//
// Handlers such as handleProfile read claims["userId"] with an unchecked
// float64 type assertion. Rejecting tokens without it here turns what would
// be a handler panic into a clean 401.
func authenticate(next http.HandlerFunc) http.HandlerFunc {
	return func(w http.ResponseWriter, r *http.Request) {
		authHeader := r.Header.Get("Authorization")
		if authHeader == "" {
			http.Error(w, "Missing authorization token", http.StatusUnauthorized)
			return
		}
		const bearerPrefix = "Bearer "
		if !strings.HasPrefix(authHeader, bearerPrefix) {
			http.Error(w, "Invalid authorization header format", http.StatusUnauthorized)
			return
		}
		tokenString := strings.TrimPrefix(authHeader, bearerPrefix)

		token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) {
			// Pin the algorithm family so a token cannot select "none" or an
			// asymmetric algorithm.
			if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
				return nil, fmt.Errorf("unexpected signing method %v", token.Header["alg"])
			}
			return jwtSecret, nil
		})
		if err != nil || !token.Valid {
			http.Error(w, "Invalid or expired token", http.StatusUnauthorized)
			return
		}
		claims, ok := token.Claims.(jwt.MapClaims)
		if !ok {
			http.Error(w, "Invalid token claims", http.StatusUnauthorized)
			return
		}
		// jwt/v5 only validates exp when it is present. Every token generated
		// in this file sets it, so a token without exp is rejected rather than
		// treated as non-expiring.
		if exp, err := claims.GetExpirationTime(); err != nil || exp == nil {
			http.Error(w, "Invalid or expired token", http.StatusUnauthorized)
			return
		}
		if _, ok := claims["userId"].(float64); !ok {
			http.Error(w, "Invalid token claims", http.StatusUnauthorized)
			return
		}

		ctx := context.WithValue(r.Context(), userContextKey, claims)
		next.ServeHTTP(w, r.WithContext(ctx))
	}
}
// requireRole is middleware that lets a request through only when the token's
// "roles" claim contains at least one of allowedRoles (exact, case-sensitive
// match). Otherwise it responds 403. It must run inside authenticate so that
// claims are present in the context; if they are missing it responds 401.
//
// Roles are read via extractRoles. A token whose roles claim is null or
// malformed is therefore rejected with 403 instead of panicking on a type
// assertion. generateAccessToken encodes a user with no user_roles rows as a
// null roles claim.
func requireRole(next http.HandlerFunc, allowedRoles ...string) http.HandlerFunc {
	return func(w http.ResponseWriter, r *http.Request) {
		claims, ok := r.Context().Value(userContextKey).(jwt.MapClaims)
		if !ok {
			http.Error(w, "Missing authorization token", http.StatusUnauthorized)
			return
		}
		userRoles, err := extractRoles(claims)
		if err != nil {
			http.Error(w, "Forbidden: insufficient permissions", http.StatusForbidden)
			return
		}
		// Check if user has any of the allowed roles
		for _, userRole := range userRoles {
			for _, allowedRole := range allowedRoles {
				if userRole == allowedRole {
					next.ServeHTTP(w, r)
					return
				}
			}
		}
		http.Error(w, "Forbidden: insufficient permissions", http.StatusForbidden)
	}
}
// requireCSRF is middleware that guards state-changing requests. GET, HEAD
// and OPTIONS pass straight through. Every other method must satisfy two
// checks:
//
//   - send an X-CSRF-Token header that matches one of the caller's stored,
//     unexpired CSRF tokens;
//   - if an Origin header is sent, it must equal CORS_ALLOW_ORIGIN (default
//     "http://localhost:8090"). A value of "*" disables the origin check.
//
// Failures respond 403 and are logged with a SECURITY prefix. The claims are
// read from the context set by authenticate, so this must be wrapped inside
// authenticate.
func requireCSRF(next http.HandlerFunc) http.HandlerFunc {
	return func(w http.ResponseWriter, r *http.Request) {
		switch r.Method {
		case http.MethodGet, http.MethodHead, http.MethodOptions:
			next.ServeHTTP(w, r)
			return
		}

		claims := r.Context().Value(userContextKey).(jwt.MapClaims)
		userID := int(claims["userId"].(float64))
		clientIP := getClientIP(r)

		presented := r.Header.Get("X-CSRF-Token")
		if presented == "" {
			log.Printf("SECURITY: CSRF token missing for user %d from IP %s on %s %s",
				userID, clientIP, r.Method, r.URL.Path)
			http.Error(w, "CSRF token required", http.StatusForbidden)
			return
		}
		if !validateCSRFToken(userID, presented) {
			log.Printf("SECURITY: Invalid CSRF token for user %d from IP %s on %s %s",
				userID, clientIP, r.Method, r.URL.Path)
			http.Error(w, "Invalid CSRF token", http.StatusForbidden)
			return
		}

		expected := strings.TrimSpace(os.Getenv("CORS_ALLOW_ORIGIN"))
		if expected == "" {
			expected = "http://localhost:8090"
		}
		// An absent Origin header is allowed through; only a present one is compared.
		if got := r.Header.Get("Origin"); got != "" && expected != "*" && got != expected {
			log.Printf("SECURITY: Origin mismatch for user %d: expected %s, got %s",
				userID, expected, got)
			http.Error(w, "Origin not allowed", http.StatusForbidden)
			return
		}

		next.ServeHTTP(w, r)
	}
}
// ===== HELPER FUNCTIONS =====

// generateAccessToken builds a short-lived HS256 JWT for userID that expires
// after accessTokenExpiry. It sets these claims:
//
//   - user_id and userId: the same value under two spellings;
//   - email: may be null;
//   - roles: always a JSON array, possibly empty;
//   - isInitialSuperuser;
//   - exp;
//   - type: "access".
//
// It returns an error if the user's roles or user row cannot be read,
// including when the user does not exist.
func generateAccessToken(userID int) (string, error) {
	rows, err := db.Query("SELECT role FROM user_roles WHERE user_id = $1", userID)
	if err != nil {
		return "", fmt.Errorf("querying roles for user %d: %w", userID, err)
	}
	defer rows.Close()

	// Start non-nil so the claim encodes as [] rather than null. requireRole
	// and extractRoles expect an array.
	roles := []string{}
	for rows.Next() {
		var role string
		if err := rows.Scan(&role); err != nil {
			return "", fmt.Errorf("scanning role for user %d: %w", userID, err)
		}
		roles = append(roles, role)
	}
	if err := rows.Err(); err != nil {
		return "", fmt.Errorf("reading roles for user %d: %w", userID, err)
	}

	// email is nullable in the users table, hence *string.
	var email *string
	var isInitialSuperuser bool
	err = db.QueryRow(`
		SELECT email, COALESCE(is_initial_superuser, false)
		FROM users WHERE id = $1
	`, userID).Scan(&email, &isInitialSuperuser)
	if err != nil {
		return "", fmt.Errorf("loading user %d: %w", userID, err)
	}

	token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
		"user_id":            userID, // snake_case for downstream services
		"userId":             userID, // camelCase for compatibility
		"email":              email,
		"roles":              roles,
		"isInitialSuperuser": isInitialSuperuser,
		"exp":                time.Now().Add(accessTokenExpiry).Unix(),
		"type":               "access",
	})
	return token.SignedString(jwtSecret)
}
// generateRefreshToken issues a new opaque refresh token for userID.
//
// The token is refreshTokenLength bytes from crypto/rand, hex-encoded. Only a
// bcrypt hash of it is written to refresh_tokens, together with an expiry of
// now + refreshTokenExpiry and the issuing clientIP. The plaintext is
// returned to the caller and never persisted. Each issuance is recorded in
// an AUDIT log line.
func generateRefreshToken(userID int, clientIP string) (string, error) {
	raw := make([]byte, refreshTokenLength)
	if _, err := rand.Read(raw); err != nil {
		return "", fmt.Errorf("failed to generate random token: %w", err)
	}
	plaintext := hex.EncodeToString(raw)

	hashed, err := bcrypt.GenerateFromPassword([]byte(plaintext), bcrypt.DefaultCost)
	if err != nil {
		return "", fmt.Errorf("failed to hash refresh token: %w", err)
	}

	expiry := time.Now().Add(refreshTokenExpiry)
	if _, err := db.Exec(`
		INSERT INTO refresh_tokens (user_id, token_hash, expires_at, client_ip, created_at)
		VALUES ($1, $2, $3, $4, NOW())
	`, userID, string(hashed), expiry, clientIP); err != nil {
		return "", fmt.Errorf("failed to store refresh token: %w", err)
	}

	log.Printf("AUDIT: Refresh token created for user_id %d from IP %s, expires %s",
		userID, clientIP, expiry.Format(time.RFC3339))
	return plaintext, nil
}
// validateRefreshToken looks up a refresh token presented by a client. For a
// token that is unexpired and not revoked, it returns the refresh_tokens row
// id (tokenID) and the owning user id. It returns an error when no such token
// matches or when the database query fails.
//
// NOTE(review): tokens are stored as bcrypt hashes, which cannot be looked up
// directly. This function therefore compares against every active token of
// every user, and each comparison is a full bcrypt evaluation, so one refresh
// costs O(active tokens) bcrypt operations. The token already has 256 bits of
// entropy, so storing a SHA-256 digest would allow an indexed lookup. That
// change requires migrating generateRefreshToken and existing rows together.
func validateRefreshToken(token string) (int, int, error) {
	// Issued tokens are never empty, so an empty token can never match;
	// skip the table scan.
	if token == "" {
		return 0, 0, errors.New("invalid or expired refresh token")
	}

	rows, err := db.Query(`
		SELECT id, user_id, token_hash
		FROM refresh_tokens
		WHERE expires_at > NOW() AND revoked_at IS NULL
	`)
	if err != nil {
		return 0, 0, fmt.Errorf("failed to query refresh tokens: %w", err)
	}
	defer rows.Close()

	for rows.Next() {
		var tokenID, userID int
		var tokenHash string
		if err := rows.Scan(&tokenID, &userID, &tokenHash); err != nil {
			log.Printf("Error scanning refresh token row: %v", err)
			continue
		}
		if bcrypt.CompareHashAndPassword([]byte(tokenHash), []byte(token)) == nil {
			return tokenID, userID, nil
		}
	}
	// An iteration error means not every token was checked. Report it rather
	// than claiming the token is invalid.
	if err := rows.Err(); err != nil {
		return 0, 0, fmt.Errorf("failed to read refresh tokens: %w", err)
	}
	return 0, 0, errors.New("invalid or expired refresh token")
}
// revokeRefreshToken marks the refresh token with row id tokenID as revoked.
// Revoking an already-revoked token leaves its original revoked_at timestamp
// untouched, matching revokeAllUserRefreshTokens. An unknown id is not an
// error.
func revokeRefreshToken(tokenID int) error {
	_, err := db.Exec(`
		UPDATE refresh_tokens
		SET revoked_at = NOW()
		WHERE id = $1 AND revoked_at IS NULL
	`, tokenID)
	return err
}
// revokeAllUserRefreshTokens sets revoked_at on every not-yet-revoked refresh
// token of userID, whether or not it has expired. Tokens that were already
// revoked keep their original timestamp. The number of rows changed is
// written to an AUDIT log line.
func revokeAllUserRefreshTokens(userID int) error {
	res, execErr := db.Exec(`
		UPDATE refresh_tokens
		SET revoked_at = NOW()
		WHERE user_id = $1 AND revoked_at IS NULL
	`, userID)
	if execErr != nil {
		return execErr
	}
	// The count is informational only, so a driver error here is ignored.
	count, _ := res.RowsAffected()
	log.Printf("AUDIT: Revoked %d refresh tokens for user_id %d", count, userID)
	return nil
}
// ===== CSRF TOKEN MANAGEMENT =====

// generateCSRFToken returns csrfTokenLength bytes from crypto/rand encoded as
// lowercase hex, so the string is 2*csrfTokenLength characters long.
func generateCSRFToken() (string, error) {
	buf := make([]byte, csrfTokenLength)
	_, err := rand.Read(buf)
	if err != nil {
		return "", fmt.Errorf("failed to generate CSRF token: %w", err)
	}
	encoded := hex.EncodeToString(buf)
	return encoded, nil
}
// storeCSRFToken persists a bcrypt hash of csrfToken for userID in the
// csrf_tokens table, along with clientIP. The expiry is now +
// accessTokenExpiry, so the CSRF token lapses together with the access token
// issued alongside it. The plaintext token is never stored.
func storeCSRFToken(userID int, csrfToken string, clientIP string) error {
	digest, hashErr := bcrypt.GenerateFromPassword([]byte(csrfToken), bcrypt.DefaultCost)
	if hashErr != nil {
		return fmt.Errorf("failed to hash CSRF token: %w", hashErr)
	}

	validUntil := time.Now().Add(accessTokenExpiry)
	if _, execErr := db.Exec(`
		INSERT INTO csrf_tokens (user_id, token_hash, expires_at, client_ip, created_at)
		VALUES ($1, $2, $3, $4, NOW())
	`, userID, string(digest), validUntil, clientIP); execErr != nil {
		return fmt.Errorf("failed to store CSRF token: %w", execErr)
	}
	return nil
}
// validateCSRFToken reports whether token matches, via bcrypt, any unexpired
// CSRF token stored for userID. An empty token, a query failure, or no match
// all yield false. Query failures are logged, and unreadable rows are
// skipped.
func validateCSRFToken(userID int, token string) bool {
	if token == "" {
		return false
	}

	rows, err := db.Query(`
		SELECT token_hash
		FROM csrf_tokens
		WHERE user_id = $1 AND expires_at > NOW()
	`, userID)
	if err != nil {
		log.Printf("Error querying CSRF tokens: %v", err)
		return false
	}
	defer rows.Close()

	candidate := []byte(token)
	for rows.Next() {
		var stored string
		if rows.Scan(&stored) != nil {
			continue
		}
		if bcrypt.CompareHashAndPassword([]byte(stored), candidate) == nil {
			return true
		}
	}
	return false
}
// cleanupExpiredCSRFTokens deletes csrf_tokens rows whose expires_at is in
// the past. It logs how many were removed, but only when at least one was.
// Errors are logged rather than returned. Meant to be called periodically.
func cleanupExpiredCSRFTokens() {
	res, err := db.Exec(`DELETE FROM csrf_tokens WHERE expires_at < NOW()`)
	if err != nil {
		log.Printf("Error cleaning up expired CSRF tokens: %v", err)
		return
	}
	if removed, _ := res.RowsAffected(); removed > 0 {
		log.Printf("Cleaned up %d expired CSRF tokens", removed)
	}
}
// revokeUserCSRFTokens deletes every CSRF token stored for userID, so no
// previously issued CSRF token of that user validates any more. Intended for
// logout.
func revokeUserCSRFTokens(userID int) error {
	if _, err := db.Exec(`DELETE FROM csrf_tokens WHERE user_id = $1`, userID); err != nil {
		return err
	}
	return nil
}
// generateTokenPair issues a full credential set for userID: an access JWT, a
// refresh token (recorded with clientIP), and a CSRF token. ExpiresIn in the
// response is the access-token lifetime in seconds.
//
// Failing to persist the CSRF token is logged but not fatal. The caller still
// receives working access and refresh tokens, but requireCSRF will reject
// state-changing requests that present this CSRF token.
func generateTokenPair(userID int, clientIP string) (*AuthTokenResponse, error) {
	access, err := generateAccessToken(userID)
	if err != nil {
		return nil, fmt.Errorf("failed to generate access token: %w", err)
	}

	refresh, err := generateRefreshToken(userID, clientIP)
	if err != nil {
		return nil, fmt.Errorf("failed to generate refresh token: %w", err)
	}

	csrf, err := generateCSRFToken()
	if err != nil {
		return nil, fmt.Errorf("failed to generate CSRF token: %w", err)
	}
	if storeErr := storeCSRFToken(userID, csrf, clientIP); storeErr != nil {
		log.Printf("Warning: failed to store CSRF token: %v", storeErr)
	}

	resp := &AuthTokenResponse{
		TokenType:    "Bearer",
		AccessToken:  access,
		RefreshToken: refresh,
		CsrfToken:    csrf,
		ExpiresIn:    int64(accessTokenExpiry.Seconds()),
	}
	return resp, nil
}
// generateToken returns an access token for userID.
//
// Deprecated: kept for backward compatibility; use generateAccessToken, or
// generateTokenPair for a full credential set.
func generateToken(userID int) (string, error) {
	token, err := generateAccessToken(userID)
	return token, err
}
// extractRoles returns the "roles" claim as a []string. It accepts two shapes:
//
//   - []interface{}, as produced by decoding a JWT;
//   - []string, as used in claims built in-process.
//
// It returns an error when the claim is absent, has any other type (a JSON
// null included), or contains a non-string element.
func extractRoles(claims jwt.MapClaims) ([]string, error) {
	raw, present := claims["roles"]
	if !present {
		return nil, errors.New("roles not present in token")
	}
	if direct, isStrings := raw.([]string); isStrings {
		return direct, nil
	}
	items, isList := raw.([]interface{})
	if !isList {
		return nil, errors.New("roles claim has unexpected type")
	}
	out := make([]string, 0, len(items))
	for _, item := range items {
		s, isString := item.(string)
		if !isString {
			return nil, errors.New("role value not a string")
		}
		out = append(out, s)
	}
	return out, nil
}
// verifyEthereumSignature reports whether signature is a valid EIP-191
// personal-message signature of message by address.
//
// The signature is hex with an optional "0x" prefix and must decode to 65
// bytes, with the recovery id v in the last byte. The message is hashed as
// keccak256("\x19Ethereum Signed Message:\n" + len(message) + message). The
// signer's public key is recovered from the hash, and its address is compared
// with the normalized address argument, ignoring case.
//
// A v of 27 or 28 is shifted down by 27 before recovery.
func verifyEthereumSignature(address, message, signature string) bool {
	want := normalizeEthereumAddress(address)
	if want == "" {
		return false
	}

	sig, err := hex.DecodeString(strings.TrimPrefix(signature, "0x"))
	if err != nil {
		return false
	}
	if len(sig) != 65 {
		return false
	}
	switch sig[64] {
	case 27, 28:
		sig[64] -= 27
	}

	prefixed := fmt.Sprintf("\x19Ethereum Signed Message:\n%d%s", len(message), message)
	digest := crypto.Keccak256Hash([]byte(prefixed))

	pub, err := crypto.SigToPub(digest.Bytes(), sig)
	if err != nil {
		return false
	}
	got := strings.ToLower(crypto.PubkeyToAddress(*pub).Hex())
	return got == want
}
// === Test helpers and shared utilities ===

// generateJWT signs an HS256 token for user with the given roles and a
// one-hour expiry. Unlike generateAccessToken, it does not query the
// database, and it omits the isInitialSuperuser and type claims. It returns
// an error when jwtSecret is empty.
func generateJWT(user User, roles []string) (string, error) {
	if len(jwtSecret) == 0 {
		return "", errors.New("JWT_SECRET not configured")
	}
	claims := jwt.MapClaims{
		"user_id": user.ID,
		"userId":  user.ID,
		"email":   user.Email,
		"roles":   roles,
		"exp":     time.Now().Add(time.Hour).Unix(),
	}
	return jwt.NewWithClaims(jwt.SigningMethodHS256, claims).SignedString(jwtSecret)
}
// hashPassword returns the bcrypt hash of password, at bcrypt.DefaultCost, as
// a string. On failure it returns "" and bcrypt's error.
func hashPassword(password string) (string, error) {
	hashed, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
	if err == nil {
		return string(hashed), nil
	}
	return "", err
}
// checkPasswordHash reports whether password matches the bcrypt hash. Any
// bcrypt error, whether a mismatch or a malformed hash, yields false.
func checkPasswordHash(password, hash string) bool {
	mismatch := bcrypt.CompareHashAndPassword([]byte(hash), []byte(password))
	return mismatch == nil
}
// normalizeEthereumAddress canonicalizes an Ethereum address for storage and
// comparison. It trims surrounding whitespace, lower-cases the string, and
// adds a missing "0x" prefix. It returns "" unless the result is "0x"
// followed by exactly 40 hexadecimal digits.
func normalizeEthereumAddress(addr string) string {
	addr = strings.ToLower(strings.TrimSpace(addr))
	if addr == "" {
		return ""
	}
	if !strings.HasPrefix(addr, "0x") {
		addr = "0x" + addr
	}
	if len(addr) != 42 { // 0x + 40 hex chars
		return ""
	}
	// The length check alone accepts strings like "0xzz…". Require real hex
	// digits so callers never store or compare a non-address identifier.
	// Upper-case hex was already folded by ToLower above.
	for _, c := range addr[2:] {
		if !('0' <= c && c <= '9' || 'a' <= c && c <= 'f') {
			return ""
		}
	}
	return addr
}
// ===== ADMIN ENDPOINTS =====

type (
	// PromoteUserRoleRequest is the JSON body for handlePromoteUserRole. It
	// asks to grant Role to UserID. The handler matches Role
	// case-insensitively and accepts only CLIENT, STAFF or ADMIN.
	PromoteUserRoleRequest struct {
		UserID int    `json:"userId"`
		Role   string `json:"role"`
	}

	// DemoteUserRoleRequest is the JSON body for handleDemoteUserRole. It
	// asks to remove Role from UserID. The handler matches Role
	// case-insensitively and accepts only CLIENT, STAFF or ADMIN.
	DemoteUserRoleRequest struct {
		UserID int    `json:"userId"`
		Role   string `json:"role"`
	}
)
// handleGetAllUsers returns every user, newest first, as a JSON array (never
// null) with each user's roles attached. The handler does not check the
// caller's roles itself. Restricting it to ADMIN/SUPERUSER is expected to be
// done by the route's middleware (e.g. requireRole).
//
// Roles for all users are fetched with a single query and grouped in memory,
// instead of one query per user. If loading roles fails, the error is logged
// and users are returned without roles.
func handleGetAllUsers(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodGet {
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}

	rows, err := db.Query(`
		SELECT u.id, u.name, u.email,
			COALESCE(u.is_initial_superuser, false),
			COALESCE(u.is_protected, false),
			u.created_at
		FROM users u
		ORDER BY u.created_at DESC
	`)
	if err != nil {
		log.Printf("Error fetching users: %v", err)
		http.Error(w, "Failed to fetch users", http.StatusInternalServerError)
		return
	}
	defer rows.Close()

	// Non-nil so an empty result encodes as [] instead of null.
	users := []User{}
	for rows.Next() {
		var u User
		if err := rows.Scan(&u.ID, &u.Name, &u.Email, &u.IsInitialSuperuser, &u.IsProtected, &u.CreatedAt); err != nil {
			log.Printf("Error scanning user: %v", err)
			http.Error(w, "Failed to parse users", http.StatusInternalServerError)
			return
		}
		users = append(users, u)
	}
	if err := rows.Err(); err != nil {
		log.Printf("Error iterating users: %v", err)
		http.Error(w, "Failed to fetch users", http.StatusInternalServerError)
		return
	}

	// Keyed by int64 so the lookup below compiles whatever integer type User.ID is.
	rolesByUser := make(map[int64][]string)
	roleRows, err := db.Query("SELECT user_id, role FROM user_roles ORDER BY user_id, id")
	if err != nil {
		log.Printf("Error fetching user roles: %v", err)
	} else {
		for roleRows.Next() {
			var uid int64
			var role string
			if err := roleRows.Scan(&uid, &role); err != nil {
				log.Printf("Error scanning user role: %v", err)
				continue
			}
			rolesByUser[uid] = append(rolesByUser[uid], role)
		}
		if err := roleRows.Err(); err != nil {
			// Don't attach a partial role list; return users without roles instead.
			log.Printf("Error iterating user roles: %v", err)
			rolesByUser = nil
		}
		roleRows.Close()
	}
	for i := range users {
		// Users without roles keep a nil Roles slice.
		users[i].Roles = rolesByUser[int64(users[i].ID)]
	}

	w.Header().Set("Content-Type", "application/json")
	json.NewEncoder(w).Encode(users)
}
// handleDemoteUserRole removes a CLIENT, STAFF or ADMIN role from a user. The
// role name is matched case-insensitively. SUPERUSER cannot be removed here;
// use /superuser/demote for that.
//
// Restrictions:
//   - Callers without the SUPERUSER role may not modify users who hold
//     SUPERUSER, and may not drop their own ADMIN role.
//   - A user's last remaining role is never removed.
//
// The caller's roles come from the access token; the target's roles are read
// from the database. If the database cannot confirm that the target is not a
// SUPERUSER, the request is refused (fail closed).
func handleDemoteUserRole(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodPost {
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}
	claims := r.Context().Value(userContextKey).(jwt.MapClaims)
	adminUserID := int(claims["userId"].(float64))
	// An unreadable roles claim yields no roles, so the caller is treated as
	// a non-SUPERUSER, which is the more restrictive path.
	callerRoles, _ := extractRoles(claims)

	var req DemoteUserRoleRequest
	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
		http.Error(w, "Invalid request body", http.StatusBadRequest)
		return
	}

	req.Role = strings.ToUpper(req.Role)
	if req.Role == "SUPERUSER" {
		http.Error(w, "Use /superuser/demote endpoint to remove SUPERUSER role", http.StatusBadRequest)
		return
	}
	validRoles := map[string]bool{"CLIENT": true, "STAFF": true, "ADMIN": true}
	if !validRoles[req.Role] {
		http.Error(w, "Invalid role. Must be CLIENT, STAFF, or ADMIN", http.StatusBadRequest)
		return
	}

	isSuperuser := false
	for _, role := range callerRoles {
		if role == "SUPERUSER" {
			isSuperuser = true
			break
		}
	}

	// ADMINs cannot touch SUPERUSERs. If the lookup fails we cannot tell
	// whether the target is a SUPERUSER, so refuse rather than proceed.
	if !isSuperuser {
		var targetHasSuperuser bool
		err := db.QueryRow(`
			SELECT EXISTS(SELECT 1 FROM user_roles WHERE user_id = $1 AND role = 'SUPERUSER')
		`, req.UserID).Scan(&targetHasSuperuser)
		if err != nil {
			log.Printf("Error checking SUPERUSER status of user %d: %v", req.UserID, err)
			http.Error(w, "Database error", http.StatusInternalServerError)
			return
		}
		if targetHasSuperuser {
			http.Error(w, "Forbidden: ADMINs cannot modify SUPERUSER accounts", http.StatusForbidden)
			return
		}
	}

	// Prevent an admin from demoting themselves from the ADMIN role.
	if req.UserID == adminUserID && req.Role == "ADMIN" && !isSuperuser {
		http.Error(w, "Cannot remove your own ADMIN role", http.StatusForbidden)
		return
	}

	var userName string
	err := db.QueryRow("SELECT name FROM users WHERE id = $1", req.UserID).Scan(&userName)
	if errors.Is(err, sql.ErrNoRows) {
		http.Error(w, "User not found", http.StatusNotFound)
		return
	}
	if err != nil {
		log.Printf("Error checking user existence: %v", err)
		http.Error(w, "Database error", http.StatusInternalServerError)
		return
	}

	var existingRoleID int
	err = db.QueryRow(`
		SELECT id FROM user_roles
		WHERE user_id = $1 AND role = $2
	`, req.UserID, req.Role).Scan(&existingRoleID)
	if errors.Is(err, sql.ErrNoRows) {
		http.Error(w, fmt.Sprintf("User does not have %s role", req.Role), http.StatusNotFound)
		return
	}
	if err != nil {
		log.Printf("Error checking existing role: %v", err)
		http.Error(w, "Database error", http.StatusInternalServerError)
		return
	}

	// Never remove a user's only role.
	var roleCount int
	if err := db.QueryRow("SELECT COUNT(*) FROM user_roles WHERE user_id = $1", req.UserID).Scan(&roleCount); err != nil {
		log.Printf("Error counting roles: %v", err)
		http.Error(w, "Database error", http.StatusInternalServerError)
		return
	}
	if roleCount <= 1 {
		http.Error(w, "Cannot remove user's only role. Assign a different role first.", http.StatusBadRequest)
		return
	}

	if _, err := db.Exec(`
		DELETE FROM user_roles
		WHERE user_id = $1 AND role = $2
	`, req.UserID, req.Role); err != nil {
		log.Printf("Error deleting role: %v", err)
		http.Error(w, "Failed to remove role", http.StatusInternalServerError)
		return
	}

	log.Printf("AUDIT: Admin user %d removed %s role from user %d (%s)",
		adminUserID, req.Role, req.UserID, userName)
	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(http.StatusOK)
	json.NewEncoder(w).Encode(map[string]string{
		"message": fmt.Sprintf("Successfully removed %s role from user %d", req.Role, req.UserID),
	})
}
// handlePromoteUserRole grants a CLIENT, STAFF or ADMIN role to a user. The
// role name is matched case-insensitively and stored upper-case. SUPERUSER
// must be granted via /superuser/promote instead.
//
// Callers without the SUPERUSER role may not modify users who hold
// SUPERUSER. If the database cannot confirm that the target is not a
// SUPERUSER, the request is refused (fail closed).
//
// Responds 200 on success, 400 for an invalid role or body, 403 when a
// non-SUPERUSER targets a SUPERUSER, 404 for an unknown user, and 409 when
// the user already holds the role.
func handlePromoteUserRole(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodPost {
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}
	claims := r.Context().Value(userContextKey).(jwt.MapClaims)
	adminUserID := int(claims["userId"].(float64))
	// An unreadable roles claim yields no roles, so the caller is treated as
	// a non-SUPERUSER, which is the more restrictive path.
	callerRoles, _ := extractRoles(claims)

	var req PromoteUserRoleRequest
	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
		http.Error(w, "Invalid request body", http.StatusBadRequest)
		return
	}

	req.Role = strings.ToUpper(req.Role)
	validRoles := map[string]bool{"CLIENT": true, "STAFF": true, "ADMIN": true}
	if !validRoles[req.Role] {
		if req.Role == "SUPERUSER" {
			http.Error(w, "Use /superuser/promote endpoint to promote to SUPERUSER", http.StatusBadRequest)
			return
		}
		http.Error(w, "Invalid role. Must be CLIENT, STAFF, or ADMIN", http.StatusBadRequest)
		return
	}

	isSuperuser := false
	for _, role := range callerRoles {
		if role == "SUPERUSER" {
			isSuperuser = true
			break
		}
	}

	// ADMINs cannot touch SUPERUSERs. If the lookup fails we cannot tell
	// whether the target is a SUPERUSER, so refuse rather than proceed.
	if !isSuperuser {
		var targetHasSuperuser bool
		err := db.QueryRow(`
			SELECT EXISTS(SELECT 1 FROM user_roles WHERE user_id = $1 AND role = 'SUPERUSER')
		`, req.UserID).Scan(&targetHasSuperuser)
		if err != nil {
			log.Printf("Error checking SUPERUSER status of user %d: %v", req.UserID, err)
			http.Error(w, "Database error", http.StatusInternalServerError)
			return
		}
		if targetHasSuperuser {
			http.Error(w, "Forbidden: ADMINs cannot modify SUPERUSER accounts", http.StatusForbidden)
			return
		}
	}

	var userName string
	err := db.QueryRow("SELECT name FROM users WHERE id = $1", req.UserID).Scan(&userName)
	if errors.Is(err, sql.ErrNoRows) {
		http.Error(w, "User not found", http.StatusNotFound)
		return
	}
	if err != nil {
		log.Printf("Error checking user existence: %v", err)
		http.Error(w, "Database error", http.StatusInternalServerError)
		return
	}

	var existingRoleID int
	err = db.QueryRow(`
		SELECT id FROM user_roles
		WHERE user_id = $1 AND role = $2
	`, req.UserID, req.Role).Scan(&existingRoleID)
	if err == nil {
		http.Error(w, fmt.Sprintf("User already has %s role", req.Role), http.StatusConflict)
		return
	}
	if !errors.Is(err, sql.ErrNoRows) {
		log.Printf("Error checking existing role: %v", err)
		http.Error(w, "Database error", http.StatusInternalServerError)
		return
	}

	_, err = db.Exec(`
		INSERT INTO user_roles (user_id, role, created_at)
		VALUES ($1, $2, NOW())
	`, req.UserID, req.Role)
	if err != nil {
		// A concurrent request may have granted the same role after the check
		// above. Report that as a conflict, not a server error.
		if strings.Contains(err.Error(), "duplicate key") {
			http.Error(w, fmt.Sprintf("User already has %s role", req.Role), http.StatusConflict)
			return
		}
		log.Printf("Error inserting role: %v", err)
		http.Error(w, "Failed to assign role", http.StatusInternalServerError)
		return
	}

	log.Printf("AUDIT: Admin user %d granted %s role to user %d (%s)",
		adminUserID, req.Role, req.UserID, userName)
	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(http.StatusOK)
	json.NewEncoder(w).Encode(map[string]string{
		"message": fmt.Sprintf("Successfully granted %s role to user %d", req.Role, req.UserID),
	})
}
// =============================================================================
// SUPERUSER MANAGEMENT ENDPOINTS
// =============================================================================

type (
	// SuperuserPromoteRequest is the JSON body for handlePromoteSuperuser:
	// the id of the user to grant SUPERUSER to.
	SuperuserPromoteRequest struct {
		UserID int `json:"userId"`
	}

	// SuperuserDemoteRequest is the JSON body for handleDemoteSuperuser: the
	// id of the user whose SUPERUSER role should be removed.
	SuperuserDemoteRequest struct {
		UserID int `json:"userId"`
	}

	// TransferInitialSuperuserRequest is the JSON body for
	// handleTransferInitialSuperuser. NewSuperuserID is the user who should
	// receive initial-superuser status. Reason is omitted from encoded JSON
	// when it is empty.
	TransferInitialSuperuserRequest struct {
		NewSuperuserID int    `json:"newSuperuserId"`
		Reason         string `json:"reason,omitempty"`
	}
)
// handlePromoteSuperuser grants the SUPERUSER role to an existing user. The
// handler does not check the caller's own roles; restricting it to
// SUPERUSERs is expected to be done by the route's middleware.
//
// Responds 200 on success, 404 for an unknown user, and 409 when the user
// already holds SUPERUSER.
func handlePromoteSuperuser(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodPost {
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}
	claims := r.Context().Value(userContextKey).(jwt.MapClaims)
	callerUserID := int(claims["userId"].(float64))

	var req SuperuserPromoteRequest
	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
		http.Error(w, "Invalid request body", http.StatusBadRequest)
		return
	}

	var userName string
	err := db.QueryRow("SELECT name FROM users WHERE id = $1", req.UserID).Scan(&userName)
	if errors.Is(err, sql.ErrNoRows) {
		http.Error(w, "User not found", http.StatusNotFound)
		return
	}
	if err != nil {
		log.Printf("Error checking user existence: %v", err)
		http.Error(w, "Database error", http.StatusInternalServerError)
		return
	}

	var existingRoleID int
	err = db.QueryRow(`
		SELECT id FROM user_roles
		WHERE user_id = $1 AND role = 'SUPERUSER'
	`, req.UserID).Scan(&existingRoleID)
	if err == nil {
		http.Error(w, "User is already a SUPERUSER", http.StatusConflict)
		return
	}
	if !errors.Is(err, sql.ErrNoRows) {
		log.Printf("Error checking existing role: %v", err)
		http.Error(w, "Database error", http.StatusInternalServerError)
		return
	}

	_, err = db.Exec(`
		INSERT INTO user_roles (user_id, role, created_at)
		VALUES ($1, 'SUPERUSER', NOW())
	`, req.UserID)
	if err != nil {
		// A concurrent promotion can win the race after the check above.
		// Report it as a conflict rather than a server error.
		if strings.Contains(err.Error(), "duplicate key") {
			http.Error(w, "User is already a SUPERUSER", http.StatusConflict)
			return
		}
		log.Printf("Error inserting SUPERUSER role: %v", err)
		http.Error(w, "Failed to promote user to SUPERUSER", http.StatusInternalServerError)
		return
	}

	log.Printf("AUDIT: SUPERUSER %d promoted user %d (%s) to SUPERUSER",
		callerUserID, req.UserID, userName)
	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(http.StatusOK)
	json.NewEncoder(w).Encode(map[string]string{
		"message": fmt.Sprintf("Successfully promoted user %d to SUPERUSER", req.UserID),
	})
}
// handleDemoteSuperuser removes the SUPERUSER role from another user.
//
// Restrictions:
//   - Callers cannot demote themselves.
//   - The initial superuser cannot be demoted; they must first transfer that
//     status via /superuser/transfer.
//
// If SUPERUSER is the target's only role, CLIENT is granted first. The grant
// and the removal run in one transaction, so the user can never end up with
// no role, or with an extra CLIENT role while still a SUPERUSER.
//
// NOTE(review): access tokens already issued to the target still carry the
// SUPERUSER role claim until they expire (accessTokenExpiry). They are not
// invalidated here.
func handleDemoteSuperuser(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodPost {
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}
	claims := r.Context().Value(userContextKey).(jwt.MapClaims)
	callerUserID := int(claims["userId"].(float64))

	var req SuperuserDemoteRequest
	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
		http.Error(w, "Invalid request body", http.StatusBadRequest)
		return
	}
	if req.UserID == callerUserID {
		http.Error(w, "Cannot demote yourself. Have another SUPERUSER do it.", http.StatusForbidden)
		return
	}

	var isInitialSU bool
	var userName string
	err := db.QueryRow(`
		SELECT name, COALESCE(is_initial_superuser, false)
		FROM users WHERE id = $1
	`, req.UserID).Scan(&userName, &isInitialSU)
	if errors.Is(err, sql.ErrNoRows) {
		http.Error(w, "User not found", http.StatusNotFound)
		return
	}
	if err != nil {
		log.Printf("Error checking user: %v", err)
		http.Error(w, "Database error", http.StatusInternalServerError)
		return
	}
	if isInitialSU {
		http.Error(w, "Cannot demote the INITIAL SUPERUSER. They must transfer their status first using /superuser/transfer", http.StatusForbidden)
		return
	}

	var existingRoleID int
	err = db.QueryRow(`
		SELECT id FROM user_roles
		WHERE user_id = $1 AND role = 'SUPERUSER'
	`, req.UserID).Scan(&existingRoleID)
	if errors.Is(err, sql.ErrNoRows) {
		http.Error(w, "User does not have SUPERUSER role", http.StatusNotFound)
		return
	}
	if err != nil {
		log.Printf("Error checking role: %v", err)
		http.Error(w, "Database error", http.StatusInternalServerError)
		return
	}

	tx, err := db.Begin()
	if err != nil {
		log.Printf("Error starting SUPERUSER demotion transaction: %v", err)
		http.Error(w, "Database error", http.StatusInternalServerError)
		return
	}
	// After a successful Commit, Rollback returns sql.ErrTxDone, which is ignored.
	defer func() { _ = tx.Rollback() }()

	var roleCount int
	if err := tx.QueryRow("SELECT COUNT(*) FROM user_roles WHERE user_id = $1", req.UserID).Scan(&roleCount); err != nil {
		log.Printf("Error counting roles: %v", err)
		http.Error(w, "Database error", http.StatusInternalServerError)
		return
	}
	if roleCount <= 1 {
		// SUPERUSER is the only role: add CLIENT so the user keeps one.
		if _, err := tx.Exec(`
			INSERT INTO user_roles (user_id, role, created_at)
			VALUES ($1, 'CLIENT', NOW())
			ON CONFLICT (user_id, role) DO NOTHING
		`, req.UserID); err != nil {
			log.Printf("Error adding fallback CLIENT role: %v", err)
			http.Error(w, "Database error", http.StatusInternalServerError)
			return
		}
	}

	if _, err := tx.Exec(`
		DELETE FROM user_roles
		WHERE user_id = $1 AND role = 'SUPERUSER'
	`, req.UserID); err != nil {
		log.Printf("Error removing SUPERUSER role: %v", err)
		http.Error(w, "Failed to demote SUPERUSER", http.StatusInternalServerError)
		return
	}

	if err := tx.Commit(); err != nil {
		log.Printf("Error committing SUPERUSER demotion: %v", err)
		http.Error(w, "Failed to demote SUPERUSER", http.StatusInternalServerError)
		return
	}

	log.Printf("AUDIT: SUPERUSER %d demoted user %d (%s) from SUPERUSER",
		callerUserID, req.UserID, userName)
	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(http.StatusOK)
	json.NewEncoder(w).Encode(map[string]string{
		"message": fmt.Sprintf("Successfully removed SUPERUSER role from user %d", req.UserID),
	})
}
// handleTransferInitialSuperuser hands INITIAL SUPERUSER status from the caller
// to another user. ONLY the current initial superuser can call this, and only
// to transfer their own status (self-transfer only).
//
// Within a single transaction the handler:
//  1. locks the target user's row and reads their name (404 if missing);
//  2. clears the caller's is_initial_superuser flag only if it is still set.
//     This compare-and-set means two concurrent transfers by the same caller
//     cannot both succeed and leave two initial superusers;
//  3. sets is_initial_superuser and is_protected on the target;
//  4. grants the target the SUPERUSER role if they lack it;
//  5. records the transfer in superuser_transfers on a best-effort basis.
//
// The audit insert runs inside a SAVEPOINT. In PostgreSQL any failed statement
// aborts the enclosing transaction, so without the savepoint an audit failure
// would make Commit fail and undo the whole transfer.
//
// Responses: 405 for non-POST, 403 if the caller is not (or no longer) the
// initial superuser, 400 for a bad body or self-transfer, 404 for an unknown
// target, 500 on database errors, and 200 with a JSON message on success.
func handleTransferInitialSuperuser(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodPost {
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}
	// NOTE(review): assumes the auth middleware always stores jwt.MapClaims with
	// a numeric "userId"; these type assertions panic otherwise.
	claims := r.Context().Value(userContextKey).(jwt.MapClaims)
	callerUserID := int(claims["userId"].(float64))

	// Early authorization check, before reading the body. It is re-verified
	// atomically inside the transaction below, because this read can go stale.
	var isInitialSU bool
	err := db.QueryRow(`
		SELECT COALESCE(is_initial_superuser, false)
		FROM users WHERE id = $1
	`, callerUserID).Scan(&isInitialSU)
	if err != nil && !errors.Is(err, sql.ErrNoRows) {
		log.Printf("Error checking initial superuser status: %v", err)
		http.Error(w, "Database error", http.StatusInternalServerError)
		return
	}
	// A missing caller row leaves isInitialSU false and is rejected here.
	if !isInitialSU {
		http.Error(w, "Forbidden: Only the INITIAL SUPERUSER can transfer their status", http.StatusForbidden)
		return
	}

	var req TransferInitialSuperuserRequest
	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
		http.Error(w, "Invalid request body", http.StatusBadRequest)
		return
	}
	if req.NewSuperuserID == callerUserID {
		http.Error(w, "Cannot transfer to yourself", http.StatusBadRequest)
		return
	}

	// Start transaction for transfer.
	tx, err := db.Begin()
	if err != nil {
		log.Printf("Error starting transaction: %v", err)
		http.Error(w, "Database error", http.StatusInternalServerError)
		return
	}
	// Undo every change on any early return; a no-op once Commit succeeds.
	defer tx.Rollback()

	// Validate that the new superuser exists, and lock the row so it cannot
	// change or disappear before the transfer commits.
	var newUserName string
	err = tx.QueryRow("SELECT name FROM users WHERE id = $1 FOR UPDATE", req.NewSuperuserID).Scan(&newUserName)
	switch {
	case errors.Is(err, sql.ErrNoRows):
		http.Error(w, "Target user not found", http.StatusNotFound)
		return
	case err != nil:
		log.Printf("Error checking user existence: %v", err)
		http.Error(w, "Database error", http.StatusInternalServerError)
		return
	}

	// 1. Remove initial_superuser flag from current user, but only if it is
	// still set. Zero affected rows means another transfer already moved it.
	res, err := tx.Exec(`
		UPDATE users SET is_initial_superuser = false
		WHERE id = $1 AND is_initial_superuser = true
	`, callerUserID)
	if err != nil {
		log.Printf("Error removing initial superuser flag: %v", err)
		http.Error(w, "Transfer failed", http.StatusInternalServerError)
		return
	}
	cleared, err := res.RowsAffected()
	if err != nil {
		log.Printf("Error reading affected rows for initial superuser flag: %v", err)
		http.Error(w, "Transfer failed", http.StatusInternalServerError)
		return
	}
	if cleared != 1 {
		http.Error(w, "Forbidden: Only the INITIAL SUPERUSER can transfer their status", http.StatusForbidden)
		return
	}

	// 2. Set initial_superuser and is_protected on new user.
	_, err = tx.Exec(`
		UPDATE users SET is_initial_superuser = true, is_protected = true
		WHERE id = $1
	`, req.NewSuperuserID)
	if err != nil {
		log.Printf("Error setting initial superuser flag: %v", err)
		http.Error(w, "Transfer failed", http.StatusInternalServerError)
		return
	}

	// 3. Ensure new user has SUPERUSER role.
	_, err = tx.Exec(`
		INSERT INTO user_roles (user_id, role, created_at)
		VALUES ($1, 'SUPERUSER', NOW())
		ON CONFLICT (user_id, role) DO NOTHING
	`, req.NewSuperuserID)
	if err != nil {
		log.Printf("Error granting SUPERUSER role: %v", err)
		http.Error(w, "Transfer failed", http.StatusInternalServerError)
		return
	}

	// 4. Record the transfer for audit (best effort). The savepoint lets us
	// discard a failed insert without aborting the surrounding transaction.
	if _, err = tx.Exec("SAVEPOINT transfer_audit"); err != nil {
		log.Printf("Error creating audit savepoint: %v", err)
		http.Error(w, "Transfer failed", http.StatusInternalServerError)
		return
	}
	_, err = tx.Exec(`
		INSERT INTO superuser_transfers (from_user_id, to_user_id, reason)
		VALUES ($1, $2, $3)
	`, callerUserID, req.NewSuperuserID, req.Reason)
	if err != nil {
		log.Printf("Error recording transfer: %v", err)
		// Don't fail on audit record: roll back just the failed insert.
		if _, rbErr := tx.Exec("ROLLBACK TO SAVEPOINT transfer_audit"); rbErr != nil {
			log.Printf("Error rolling back audit savepoint: %v", rbErr)
			http.Error(w, "Transfer failed", http.StatusInternalServerError)
			return
		}
	}

	if err = tx.Commit(); err != nil {
		log.Printf("Error committing transfer: %v", err)
		http.Error(w, "Transfer failed", http.StatusInternalServerError)
		return
	}

	log.Printf("AUDIT: INITIAL SUPERUSER transferred from user %d to user %d (%s). Reason: %s",
		callerUserID, req.NewSuperuserID, newUserName, req.Reason)
	w.WriteHeader(http.StatusOK)
	json.NewEncoder(w).Encode(map[string]string{
		"message": fmt.Sprintf("Successfully transferred INITIAL SUPERUSER status to user %d (%s)", req.NewSuperuserID, newUserName),
	})
}