package main import ( "context" "crypto/rand" "database/sql" "encoding/hex" "encoding/json" "errors" "fmt" "log" "net" "net/http" "os" "os/signal" "regexp" "strings" "sync" "syscall" "time" "unicode" "github.com/ethereum/go-ethereum/crypto" "github.com/golang-jwt/jwt/v5" _ "github.com/lib/pq" "golang.org/x/crypto/bcrypt" ) // Rate limiting configuration const ( rateLimitWindow = 15 * time.Minute // Window for counting attempts maxLoginAttempts = 5 // Max failed login attempts per window maxRegisterPerHour = 10 // Max registrations per IP per hour lockoutDuration = 30 * time.Minute // How long to lock out after max attempts ) // Token expiration configuration const ( accessTokenExpiry = 15 * time.Minute // Short-lived access tokens refreshTokenExpiry = 7 * 24 * time.Hour // 7 day refresh tokens refreshTokenLength = 32 // 256-bit refresh token csrfTokenLength = 32 // 256-bit CSRF token ) // rateLimiter tracks login attempts per IP/email type rateLimiter struct { mu sync.RWMutex attempts map[string]*attemptInfo } type attemptInfo struct { count int firstTry time.Time lockedOut bool lockUntil time.Time } var loginLimiter = &rateLimiter{attempts: make(map[string]*attemptInfo)} var registerLimiter = &rateLimiter{attempts: make(map[string]*attemptInfo)} // checkRateLimit returns true if the request should be blocked func (rl *rateLimiter) checkRateLimit(key string, maxAttempts int, window time.Duration) bool { rl.mu.Lock() defer rl.mu.Unlock() now := time.Now() info, exists := rl.attempts[key] if !exists { rl.attempts[key] = &attemptInfo{count: 1, firstTry: now} return false } // Check if locked out if info.lockedOut && now.Before(info.lockUntil) { return true } // Reset lockout if time has passed if info.lockedOut && now.After(info.lockUntil) { info.lockedOut = false info.count = 1 info.firstTry = now return false } // Reset window if expired if now.Sub(info.firstTry) > window { info.count = 1 info.firstTry = now return false } // Increment and check info.count++ if 
info.count > maxAttempts { info.lockedOut = true info.lockUntil = now.Add(lockoutDuration) return true } return false } // recordFailedAttempt records a failed attempt (for login failures) func (rl *rateLimiter) recordFailedAttempt(key string) { rl.mu.Lock() defer rl.mu.Unlock() now := time.Now() info, exists := rl.attempts[key] if !exists { rl.attempts[key] = &attemptInfo{count: 1, firstTry: now} return } // Reset window if expired if now.Sub(info.firstTry) > rateLimitWindow { info.count = 1 info.firstTry = now return } info.count++ if info.count >= maxLoginAttempts { info.lockedOut = true info.lockUntil = now.Add(lockoutDuration) log.Printf("SECURITY: IP/email %s locked out after %d failed attempts", key, info.count) } } // clearAttempts clears attempts for a key (called on successful login) func (rl *rateLimiter) clearAttempts(key string) { rl.mu.Lock() defer rl.mu.Unlock() delete(rl.attempts, key) } // getClientIP extracts the client IP from the request func getClientIP(r *http.Request) string { // Check X-Forwarded-For header (for proxies) xff := r.Header.Get("X-Forwarded-For") if xff != "" { // Take the first IP in the chain if idx := strings.Index(xff, ","); idx != -1 { return strings.TrimSpace(xff[:idx]) } return strings.TrimSpace(xff) } // Check X-Real-IP header xri := r.Header.Get("X-Real-IP") if xri != "" { return strings.TrimSpace(xri) } // Fall back to RemoteAddr ip, _, err := net.SplitHostPort(r.RemoteAddr) if err != nil { return r.RemoteAddr } return ip } // Input validation constants const ( maxNameLength = 100 maxEmailLength = 254 // RFC 5321 limit minPasswordLength = 8 maxPasswordLength = 72 // bcrypt limit maxRequestBody = 1 << 20 // 1MB max request body size ) // Email validation regex (RFC 5322 simplified) var emailRegex = regexp.MustCompile(`^[a-zA-Z0-9._%+\-]+@[a-zA-Z0-9.\-]+\.[a-zA-Z]{2,}$`) // ValidationError holds field-specific validation errors type ValidationError struct { Field string `json:"field"` Message string `json:"message"` } // 
validateEmail checks email format and length func validateEmail(email string) *ValidationError { email = strings.TrimSpace(email) if email == "" { return &ValidationError{Field: "email", Message: "Email is required"} } if len(email) > maxEmailLength { return &ValidationError{Field: "email", Message: fmt.Sprintf("Email must be less than %d characters", maxEmailLength)} } if !emailRegex.MatchString(email) { return &ValidationError{Field: "email", Message: "Invalid email format"} } return nil } // validatePassword checks password requirements func validatePassword(password string) *ValidationError { if len(password) < minPasswordLength { return &ValidationError{Field: "password", Message: fmt.Sprintf("Password must be at least %d characters", minPasswordLength)} } if len(password) > maxPasswordLength { return &ValidationError{Field: "password", Message: fmt.Sprintf("Password must be less than %d characters", maxPasswordLength)} } var hasUpper, hasLower, hasNumber bool for _, c := range password { switch { case unicode.IsUpper(c): hasUpper = true case unicode.IsLower(c): hasLower = true case unicode.IsNumber(c): hasNumber = true } } if !hasUpper { return &ValidationError{Field: "password", Message: "Password must contain at least one uppercase letter"} } if !hasLower { return &ValidationError{Field: "password", Message: "Password must contain at least one lowercase letter"} } if !hasNumber { return &ValidationError{Field: "password", Message: "Password must contain at least one number"} } return nil } // validateName checks name length func validateName(name string) *ValidationError { name = strings.TrimSpace(name) if name == "" { return &ValidationError{Field: "name", Message: "Name is required"} } if len(name) > maxNameLength { return &ValidationError{Field: "name", Message: fmt.Sprintf("Name must be less than %d characters", maxNameLength)} } return nil } // User represents a user in the system. 
type User struct {
	ID                 int       `json:"id"`
	Name               string    `json:"name"`
	Email              *string   `json:"email"` // pointer, so an absent email encodes as JSON null
	Roles              []string  `json:"roles"`
	IsInitialSuperuser bool      `json:"isInitialSuperuser"`
	IsProtected        bool      `json:"isProtected"`
	CreatedAt          time.Time `json:"createdAt"`
}

// Identity is one way of authenticating as a user: an email/password pair or a
// blockchain address. A user may have several identities linked.
type Identity struct {
	ID             int       `json:"id"`
	UserID         int       `json:"userId"`
	Type           string    `json:"type"`
	Identifier     string    `json:"identifier"`
	IsPrimaryLogin bool      `json:"isPrimaryLogin"`
	CreatedAt      time.Time `json:"createdAt"`
}

// JSON request bodies accepted by the HTTP handlers.
type (
	// RegisterEmailPasswordRequest is the body of an email/password registration.
	RegisterEmailPasswordRequest struct {
		Email    string `json:"email"`
		Password string `json:"password"`
		Name     string `json:"name"`
		Role     string `json:"role"`
	}

	// RegisterBlockchainRequest is the body of a registration that proves
	// ownership of a blockchain address by signing Message.
	RegisterBlockchainRequest struct {
		Address   string `json:"address"`
		Signature string `json:"signature"`
		Message   string `json:"message"`
		Name      string `json:"name"`
		Role      string `json:"role"`
	}

	// LoginEmailPasswordRequest is the body of an email/password login.
	LoginEmailPasswordRequest struct {
		Email    string `json:"email"`
		Password string `json:"password"`
	}

	// LoginBlockchainRequest is the body of a blockchain-signature login.
	LoginBlockchainRequest struct {
		Address   string `json:"address"`
		Signature string `json:"signature"`
		Message   string `json:"message"`
	}

	// LinkIdentityRequest attaches a new identity to the caller's account.
	// Type selects which of the optional fields are used.
	LinkIdentityRequest struct {
		Type      string `json:"type"` // "email_password" or "blockchain_address"
		Email     string `json:"email,omitempty"`
		Password  string `json:"password,omitempty"`
		Address   string `json:"address,omitempty"`
		Signature string `json:"signature,omitempty"`
		Message   string `json:"message,omitempty"`
	}

	// RefreshTokenRequest carries the refresh token to exchange for new tokens.
	RefreshTokenRequest struct {
		RefreshToken string `json:"refreshToken"`
	}
)

// AuthTokenResponse includes both access and refresh tokens.
type AuthTokenResponse struct { AccessToken string `json:"accessToken"` RefreshToken string `json:"refreshToken"` CsrfToken string `json:"csrfToken"` // CSRF token for state-changing requests ExpiresIn int64 `json:"expiresIn"` // Access token expiry in seconds TokenType string `json:"tokenType"` } // LogoutRequest for revoking refresh tokens. type LogoutRequest struct { RefreshToken string `json:"refreshToken"` } var ( jwtSecret []byte db *sql.DB defaultRole = "CLIENT" ) type contextKey string var userContextKey = contextKey("userClaims") func main() { loadConfig() db = initDB() defer db.Close() // Registration routes http.HandleFunc("/register-email-password", handleRegisterEmailPassword) http.HandleFunc("/register-blockchain", handleRegisterBlockchain) // Login routes http.HandleFunc("/login-email-password", handleLoginEmailPassword) http.HandleFunc("/login-blockchain", handleLoginBlockchain) // Token management routes http.HandleFunc("/auth/refresh", handleRefreshToken) http.HandleFunc("/auth/logout", handleLogout) http.HandleFunc("/auth/logout-all", authenticate(requireCSRF(handleLogoutAll))) // Identity management routes (protected with CSRF) http.HandleFunc("/link-identity", authenticate(requireCSRF(requireRole(handleLinkIdentity, "CLIENT", "STAFF", "ADMIN")))) http.HandleFunc("/unlink-identity", authenticate(requireCSRF(requireRole(handleUnlinkIdentity, "CLIENT", "STAFF", "ADMIN")))) http.HandleFunc("/identities", authenticate(handleGetIdentities)) // GET doesn't need CSRF // Profile route (protected) http.HandleFunc("/profile", authenticate(handleProfile)) // GET doesn't need CSRF // Admin routes (ADMIN only) - Note: SUPERUSER has implicit access http.HandleFunc("/admin/users", authenticate(requireRole(handleGetAllUsers, "ADMIN", "SUPERUSER"))) // GET doesn't need CSRF http.HandleFunc("/admin/users/promote-role", authenticate(requireCSRF(requireRole(handlePromoteUserRole, "ADMIN", "SUPERUSER")))) http.HandleFunc("/admin/users/demote-role", 
authenticate(requireCSRF(requireRole(handleDemoteUserRole, "ADMIN", "SUPERUSER")))) // Superuser routes (SUPERUSER only) - with CSRF protection http.HandleFunc("/superuser/promote", authenticate(requireCSRF(requireRole(handlePromoteSuperuser, "SUPERUSER")))) http.HandleFunc("/superuser/demote", authenticate(requireCSRF(requireRole(handleDemoteSuperuser, "SUPERUSER")))) http.HandleFunc("/superuser/transfer", authenticate(requireCSRF(requireRole(handleTransferInitialSuperuser, "SUPERUSER")))) // Health check http.HandleFunc("/healthz", func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) fmt.Fprintln(w, "ok") }) server := &http.Server{ Addr: ":8080", Handler: limitBodySize(corsMiddleware(http.DefaultServeMux)), ReadHeaderTimeout: 10 * time.Second, ReadTimeout: 15 * time.Second, WriteTimeout: 15 * time.Second, IdleTimeout: 60 * time.Second, } // Graceful shutdown done := make(chan bool, 1) quit := make(chan os.Signal, 1) signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM) go func() { <-quit log.Println("Auth Service shutting down...") ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) defer cancel() server.SetKeepAlivesEnabled(false) if err := server.Shutdown(ctx); err != nil { log.Printf("Could not gracefully shutdown: %v", err) } close(done) }() log.Println("Auth Service listening on :8080") if err := server.ListenAndServe(); err != nil && err != http.ErrServerClosed { log.Fatalf("Server error: %v", err) } <-done log.Println("Auth Service stopped") } func loadConfig() { jwtSecret = []byte(strings.TrimSpace(os.Getenv("JWT_SECRET"))) if len(jwtSecret) == 0 { log.Fatal("JWT_SECRET must be set") } if len(jwtSecret) < 32 { log.Fatal("JWT_SECRET must be at least 32 characters for security") } if len(jwtSecret) < 64 { log.Println("WARNING: JWT_SECRET is less than 64 characters. 
Consider using a longer secret for production.") } if envDefaultRole := strings.TrimSpace(os.Getenv("DEFAULT_USER_ROLE")); envDefaultRole != "" { defaultRole = strings.ToUpper(envDefaultRole) } } func initDB() *sql.DB { user := strings.TrimSpace(os.Getenv("DB_USER")) password := strings.TrimSpace(os.Getenv("DB_PASSWORD")) name := strings.TrimSpace(os.Getenv("DB_NAME")) host := strings.TrimSpace(os.Getenv("DB_HOST")) sslMode := strings.TrimSpace(os.Getenv("DB_SSL_MODE")) schema := strings.TrimSpace(os.Getenv("DB_SCHEMA")) if user == "" || password == "" || name == "" || host == "" { log.Fatal("Database configuration missing: DB_USER, DB_PASSWORD, DB_NAME, DB_HOST required") } // Secure default: require TLS for production if sslMode == "" { sslMode = "require" log.Println("WARNING: DB_SSL_MODE not set, defaulting to 'require' for security") } // Validate sslMode value validSSLModes := map[string]bool{ "disable": true, // Only for local development "require": true, // Minimum for production "verify-ca": true, // Better "verify-full": true, // Best } if !validSSLModes[sslMode] { log.Fatalf("Invalid DB_SSL_MODE '%s'. Must be: disable, require, verify-ca, or verify-full", sslMode) } // Warn if using insecure mode if sslMode == "disable" { log.Println("WARNING: Database SSL is DISABLED. This should only be used for local development!") } // Validate schema value if provided validSchemas := map[string]bool{ "": true, // Empty means use public (default) "public": true, "dev": true, "testing": true, "prod": true, } if !validSchemas[schema] { log.Fatalf("Invalid DB_SCHEMA '%s'. 
Must be: dev, testing, prod, or empty for public", schema) } // Build connection string with optional search_path connStr := fmt.Sprintf("user=%s password=%s dbname=%s host=%s sslmode=%s", user, password, name, host, sslMode) // Add search_path if schema is specified if schema != "" && schema != "public" { connStr += fmt.Sprintf(" search_path=%s,public", schema) } database, err := sql.Open("postgres", connStr) if err != nil { log.Fatalf("Error opening database: %v", err) } // Configure connection pool limits database.SetMaxOpenConns(25) // Max open connections to DB database.SetMaxIdleConns(5) // Max idle connections in pool database.SetConnMaxLifetime(5 * time.Minute) // Max lifetime of a connection database.SetConnMaxIdleTime(1 * time.Minute) // Max time a connection can be idle if err := database.Ping(); err != nil { log.Fatalf("Error connecting to database: %v", err) } schemaInfo := "public" if schema != "" && schema != "public" { schemaInfo = schema } log.Printf("Successfully connected to database (SSL mode: %s, schema: %s, max_conns: 25)", sslMode, schemaInfo) return database } // limitBodySize middleware limits the request body size to prevent DoS attacks func limitBodySize(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.Body != nil { r.Body = http.MaxBytesReader(w, r.Body, maxRequestBody) } next.ServeHTTP(w, r) }) } func corsMiddleware(next http.Handler) http.Handler { // Log CORS configuration on first request var corsConfigLogged bool return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { allowedOrigin := strings.TrimSpace(os.Getenv("CORS_ALLOW_ORIGIN")) if allowedOrigin == "" { // Default to restrictive in production - set CORS_ALLOW_ORIGIN explicitly allowedOrigin = "http://localhost:8090" if !corsConfigLogged { log.Println("WARNING: CORS_ALLOW_ORIGIN not set, defaulting to http://localhost:8090") log.Println("For production, set CORS_ALLOW_ORIGIN to your frontend domain 
(e.g., https://coppertone.tech)") corsConfigLogged = true } } else if allowedOrigin == "*" && !corsConfigLogged { log.Println("WARNING: CORS configured to allow ALL origins (*). This is insecure for production!") corsConfigLogged = true } // CORS headers w.Header().Set("Access-Control-Allow-Origin", allowedOrigin) w.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS") w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization, X-CSRF-Token") w.Header().Set("Access-Control-Allow-Credentials", "true") // Security headers w.Header().Set("X-Content-Type-Options", "nosniff") w.Header().Set("X-Frame-Options", "DENY") w.Header().Set("X-XSS-Protection", "1; mode=block") w.Header().Set("Referrer-Policy", "strict-origin-when-cross-origin") w.Header().Set("Content-Security-Policy", "default-src 'self'") w.Header().Set("Permissions-Policy", "geolocation=(), microphone=(), camera=()") if r.Method == "OPTIONS" { w.WriteHeader(http.StatusOK) return } next.ServeHTTP(w, r) }) } // ===== REGISTRATION HANDLERS ===== func handleRegisterEmailPassword(w http.ResponseWriter, r *http.Request) { if r.Method != http.MethodPost { http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) return } // Rate limit registrations by IP clientIP := getClientIP(r) if registerLimiter.checkRateLimit(clientIP, maxRegisterPerHour, time.Hour) { log.Printf("SECURITY: Registration rate limit exceeded for IP %s", clientIP) http.Error(w, "Too many registration attempts. 
Please try again later.", http.StatusTooManyRequests) return } var req RegisterEmailPasswordRequest if err := json.NewDecoder(r.Body).Decode(&req); err != nil { http.Error(w, err.Error(), http.StatusBadRequest) return } // Validate inputs if err := validateEmail(req.Email); err != nil { w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusBadRequest) json.NewEncoder(w).Encode(err) return } if err := validatePassword(req.Password); err != nil { w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusBadRequest) json.NewEncoder(w).Encode(err) return } if err := validateName(req.Name); err != nil { w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusBadRequest) json.NewEncoder(w).Encode(err) return } // Normalize email req.Email = strings.TrimSpace(strings.ToLower(req.Email)) req.Name = strings.TrimSpace(req.Name) // Security: Force CLIENT role for all public registrations // First user will be auto-promoted to SUPERUSER with initial_superuser flag // Staff/Admin roles can only be granted by existing ADMIN/SUPERUSER role := "CLIENT" isInitialSuperuser := false // Check if this is the first user (auto-promote to SUPERUSER) var userCount int err := db.QueryRow("SELECT COUNT(*) FROM users").Scan(&userCount) if err != nil { log.Printf("Error checking user count: %v", err) http.Error(w, "Database error", http.StatusInternalServerError) return } if userCount == 0 { role = "SUPERUSER" isInitialSuperuser = true log.Println("AUDIT: First user registration - creating INITIAL SUPERUSER (god-like, non-removable)") } // Hash password passwordHash, err := bcrypt.GenerateFromPassword([]byte(req.Password), bcrypt.DefaultCost) if err != nil { http.Error(w, "Failed to hash password", http.StatusInternalServerError) return } // Start transaction tx, err := db.Begin() if err != nil { http.Error(w, "Database error", http.StatusInternalServerError) return } defer tx.Rollback() // Create user var userID int err = 
tx.QueryRow( "INSERT INTO users (name, email, is_initial_superuser, is_protected, created_at) VALUES ($1, $2, $3, $4, NOW()) RETURNING id", req.Name, req.Email, isInitialSuperuser, isInitialSuperuser, ).Scan(&userID) if err != nil { if strings.Contains(err.Error(), "duplicate key") { http.Error(w, "Email already registered", http.StatusConflict) } else { log.Printf("Error creating user: %v", err) http.Error(w, "Failed to create user", http.StatusInternalServerError) } return } // Create identity _, err = tx.Exec(` INSERT INTO identities (user_id, type, identifier, credential, is_primary_login, created_at) VALUES ($1, 'email_password', $2, $3, true, NOW()) `, userID, req.Email, string(passwordHash)) if err != nil { http.Error(w, "Failed to create identity", http.StatusInternalServerError) return } // Assign role _, err = tx.Exec( "INSERT INTO user_roles (user_id, role, created_at) VALUES ($1, $2, NOW())", userID, role, ) if err != nil { http.Error(w, "Failed to assign role", http.StatusInternalServerError) return } if err = tx.Commit(); err != nil { http.Error(w, "Failed to complete registration", http.StatusInternalServerError) return } w.WriteHeader(http.StatusCreated) json.NewEncoder(w).Encode(map[string]interface{}{ "message": "User registered successfully", "userId": userID, }) } func handleRegisterBlockchain(w http.ResponseWriter, r *http.Request) { if r.Method != http.MethodPost { http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) return } // Rate limit registrations by IP clientIP := getClientIP(r) if registerLimiter.checkRateLimit(clientIP, maxRegisterPerHour, time.Hour) { log.Printf("SECURITY: Blockchain registration rate limit exceeded for IP %s", clientIP) http.Error(w, "Too many registration attempts. 
Please try again later.", http.StatusTooManyRequests) return } var req RegisterBlockchainRequest if err := json.NewDecoder(r.Body).Decode(&req); err != nil { http.Error(w, err.Error(), http.StatusBadRequest) return } // Validate inputs if err := validateName(req.Name); err != nil { w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusBadRequest) json.NewEncoder(w).Encode(err) return } if req.Address == "" { http.Error(w, "Blockchain address is required", http.StatusBadRequest) return } if req.Signature == "" || req.Message == "" { http.Error(w, "Signature and message are required for verification", http.StatusBadRequest) return } req.Name = strings.TrimSpace(req.Name) // Security: Force CLIENT role for all public registrations // First user will be auto-promoted to SUPERUSER with initial_superuser flag // Staff/Admin roles can only be granted by existing ADMIN/SUPERUSER role := "CLIENT" isInitialSuperuser := false // Check if this is the first user (auto-promote to SUPERUSER) var userCount int err := db.QueryRow("SELECT COUNT(*) FROM users").Scan(&userCount) if err != nil { log.Printf("Error checking user count: %v", err) http.Error(w, "Database error", http.StatusInternalServerError) return } if userCount == 0 { role = "SUPERUSER" isInitialSuperuser = true log.Println("AUDIT: First user registration (blockchain) - creating INITIAL SUPERUSER (god-like, non-removable)") } // Verify signature if !verifyEthereumSignature(req.Address, req.Message, req.Signature) { http.Error(w, "Invalid signature", http.StatusUnauthorized) return } // Start transaction tx, err := db.Begin() if err != nil { http.Error(w, "Database error", http.StatusInternalServerError) return } defer tx.Rollback() // Create user var userID int err = tx.QueryRow( "INSERT INTO users (name, is_initial_superuser, is_protected, created_at) VALUES ($1, $2, $3, NOW()) RETURNING id", req.Name, isInitialSuperuser, isInitialSuperuser, ).Scan(&userID) if err != nil { log.Printf("Error 
creating user: %v", err) http.Error(w, "Failed to create user", http.StatusInternalServerError) return } // Create identity _, err = tx.Exec(` INSERT INTO identities (user_id, type, identifier, is_primary_login, created_at) VALUES ($1, 'blockchain_address', $2, true, NOW()) `, userID, strings.ToLower(req.Address)) if err != nil { if strings.Contains(err.Error(), "duplicate key") { http.Error(w, "Blockchain address already registered", http.StatusConflict) } else { http.Error(w, "Failed to create identity", http.StatusInternalServerError) } return } // Assign role _, err = tx.Exec( "INSERT INTO user_roles (user_id, role, created_at) VALUES ($1, $2, NOW())", userID, role, ) if err != nil { http.Error(w, "Failed to assign role", http.StatusInternalServerError) return } if err = tx.Commit(); err != nil { http.Error(w, "Failed to complete registration", http.StatusInternalServerError) return } w.WriteHeader(http.StatusCreated) json.NewEncoder(w).Encode(map[string]interface{}{ "message": "User registered successfully", "userId": userID, }) } // ===== LOGIN HANDLERS ===== func handleLoginEmailPassword(w http.ResponseWriter, r *http.Request) { if r.Method != http.MethodPost { http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) return } clientIP := getClientIP(r) var req LoginEmailPasswordRequest if err := json.NewDecoder(r.Body).Decode(&req); err != nil { http.Error(w, err.Error(), http.StatusBadRequest) return } // Rate limit by IP and email combination rateLimitKey := clientIP + ":" + strings.ToLower(req.Email) if loginLimiter.checkRateLimit(rateLimitKey, maxLoginAttempts, rateLimitWindow) { log.Printf("SECURITY: Rate limit exceeded for %s", rateLimitKey) http.Error(w, "Too many login attempts. 
Please try again later.", http.StatusTooManyRequests) return } // Find identity var userID int var passwordHash string err := db.QueryRow(` SELECT user_id, credential FROM identities WHERE type = 'email_password' AND identifier = $1 `, req.Email).Scan(&userID, &passwordHash) if err == sql.ErrNoRows { loginLimiter.recordFailedAttempt(rateLimitKey) log.Printf("SECURITY: Failed login attempt for email %s from IP %s (user not found)", req.Email, clientIP) http.Error(w, "Invalid credentials", http.StatusUnauthorized) return } else if err != nil { http.Error(w, "Login failed", http.StatusInternalServerError) return } // Verify password if err := bcrypt.CompareHashAndPassword([]byte(passwordHash), []byte(req.Password)); err != nil { loginLimiter.recordFailedAttempt(rateLimitKey) log.Printf("SECURITY: Failed login attempt for email %s from IP %s (wrong password)", req.Email, clientIP) http.Error(w, "Invalid credentials", http.StatusUnauthorized) return } // Clear rate limit on successful login loginLimiter.clearAttempts(rateLimitKey) log.Printf("AUDIT: Successful login for user_id %d from IP %s", userID, clientIP) // Generate token pair (access + refresh) tokenResponse, err := generateTokenPair(userID, clientIP) if err != nil { log.Printf("Error generating token pair: %v", err) http.Error(w, "Failed to generate tokens", http.StatusInternalServerError) return } w.Header().Set("Content-Type", "application/json") json.NewEncoder(w).Encode(tokenResponse) } func handleLoginBlockchain(w http.ResponseWriter, r *http.Request) { if r.Method != http.MethodPost { http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) return } clientIP := getClientIP(r) var req LoginBlockchainRequest if err := json.NewDecoder(r.Body).Decode(&req); err != nil { http.Error(w, err.Error(), http.StatusBadRequest) return } // Rate limit by IP and address combination rateLimitKey := clientIP + ":" + strings.ToLower(req.Address) if loginLimiter.checkRateLimit(rateLimitKey, maxLoginAttempts, 
rateLimitWindow) { log.Printf("SECURITY: Rate limit exceeded for blockchain login %s", rateLimitKey) http.Error(w, "Too many login attempts. Please try again later.", http.StatusTooManyRequests) return } // Verify signature if !verifyEthereumSignature(req.Address, req.Message, req.Signature) { loginLimiter.recordFailedAttempt(rateLimitKey) log.Printf("SECURITY: Failed blockchain login for address %s from IP %s (invalid signature)", req.Address, clientIP) http.Error(w, "Invalid signature", http.StatusUnauthorized) return } // Find identity var userID int err := db.QueryRow(` SELECT user_id FROM identities WHERE type = 'blockchain_address' AND identifier = $1 `, strings.ToLower(req.Address)).Scan(&userID) if err == sql.ErrNoRows { loginLimiter.recordFailedAttempt(rateLimitKey) log.Printf("SECURITY: Failed blockchain login for address %s from IP %s (not registered)", req.Address, clientIP) http.Error(w, "Address not registered", http.StatusUnauthorized) return } else if err != nil { http.Error(w, "Login failed", http.StatusInternalServerError) return } // Clear rate limit on successful login loginLimiter.clearAttempts(rateLimitKey) log.Printf("AUDIT: Successful blockchain login for user_id %d from IP %s", userID, clientIP) // Generate token pair (access + refresh) tokenResponse, err := generateTokenPair(userID, clientIP) if err != nil { log.Printf("Error generating token pair: %v", err) http.Error(w, "Failed to generate tokens", http.StatusInternalServerError) return } w.Header().Set("Content-Type", "application/json") json.NewEncoder(w).Encode(tokenResponse) } // ===== TOKEN REFRESH AND LOGOUT HANDLERS ===== // handleRefreshToken exchanges a valid refresh token for a new access token func handleRefreshToken(w http.ResponseWriter, r *http.Request) { if r.Method != http.MethodPost { http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) return } clientIP := getClientIP(r) var req RefreshTokenRequest if err := json.NewDecoder(r.Body).Decode(&req); err != nil { 
http.Error(w, "Invalid request body", http.StatusBadRequest)
		return
	}

	if req.RefreshToken == "" {
		http.Error(w, "Refresh token is required", http.StatusBadRequest)
		return
	}

	// Validate the refresh token
	tokenID, userID, err := validateRefreshToken(req.RefreshToken)
	if err != nil {
		log.Printf("SECURITY: Invalid refresh token attempt from IP %s: %v", clientIP, err)
		http.Error(w, "Invalid or expired refresh token", http.StatusUnauthorized)
		return
	}

	// Revoke the old refresh token (rotation for security)
	if err := revokeRefreshToken(tokenID); err != nil {
		log.Printf("Warning: Failed to revoke old refresh token: %v", err)
	}

	// Generate new token pair
	tokenResponse, err := generateTokenPair(userID, clientIP)
	if err != nil {
		log.Printf("Error generating token pair during refresh: %v", err)
		http.Error(w, "Failed to refresh tokens", http.StatusInternalServerError)
		return
	}

	log.Printf("AUDIT: Token refresh for user_id %d from IP %s", userID, clientIP)
	w.Header().Set("Content-Type", "application/json")
	json.NewEncoder(w).Encode(tokenResponse)
}

// handleLogout revokes the single refresh token supplied in the request body.
//
// An unknown, expired or already-revoked token still yields a success
// response so callers cannot probe which tokens are live. Only a database
// failure while revoking a valid token produces an error (500).
func handleLogout(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodPost {
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}

	clientIP := getClientIP(r)

	var req LogoutRequest
	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
		http.Error(w, "Invalid request body", http.StatusBadRequest)
		return
	}

	if req.RefreshToken == "" {
		http.Error(w, "Refresh token is required", http.StatusBadRequest)
		return
	}

	// Validate and get token info
	tokenID, userID, err := validateRefreshToken(req.RefreshToken)
	if err != nil {
		// Even if token is invalid, return success (don't leak info)
		log.Printf("SECURITY: Logout with invalid token from IP %s", clientIP)
		w.Header().Set("Content-Type", "application/json")
		json.NewEncoder(w).Encode(map[string]string{"message": "Logged out successfully"})
		return
	}

	// Revoke the refresh token
	if err := revokeRefreshToken(tokenID); err != nil {
		log.Printf("Error revoking refresh token: %v", err)
		http.Error(w, "Failed to logout", http.StatusInternalServerError)
		return
	}

	log.Printf("AUDIT: User %d logged out from IP %s", userID, clientIP)
	w.Header().Set("Content-Type", "application/json")
	json.NewEncoder(w).Encode(map[string]string{"message": "Logged out successfully"})
}

// handleLogoutAll revokes every refresh token of the authenticated user and
// also drops the user's CSRF tokens, so access tokens that are still within
// their short lifetime can no longer be used for state-changing requests
// guarded by requireCSRF. Expects authenticate to have run first.
func handleLogoutAll(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodPost {
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}

	claims := r.Context().Value(userContextKey).(jwt.MapClaims)
	userID := int(claims["userId"].(float64))
	clientIP := getClientIP(r)

	if err := revokeAllUserRefreshTokens(userID); err != nil {
		log.Printf("Error revoking all refresh tokens for user %d: %v", userID, err)
		http.Error(w, "Failed to logout from all devices", http.StatusInternalServerError)
		return
	}

	// Best-effort: refresh tokens are already revoked, which is the essential
	// part of "log out everywhere"; a CSRF cleanup failure is only logged.
	if err := revokeUserCSRFTokens(userID); err != nil {
		log.Printf("Warning: failed to revoke CSRF tokens for user %d: %v", userID, err)
	}

	log.Printf("AUDIT: User %d logged out from all devices (initiated from IP %s)", userID, clientIP)
	w.Header().Set("Content-Type", "application/json")
	json.NewEncoder(w).Encode(map[string]string{"message": "Logged out from all devices successfully"})
}

// ===== IDENTITY MANAGEMENT =====

// handleLinkIdentity attaches an additional (non-primary) login identity to
// the authenticated user.
//
// Supported req.Type values:
//   - "email_password": email and password are validated, the email is
//     lower-cased/trimmed and the password stored as a bcrypt hash.
//   - "blockchain_address": the signature over req.Message must recover to
//     req.Address; the address is stored in the same normalized form
//     (normalizeEthereumAddress) that the signature check compared against.
//
// Responds 201 on success, 409 when the identity is already linked.
func handleLinkIdentity(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodPost {
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}

	claims := r.Context().Value(userContextKey).(jwt.MapClaims)
	userID := int(claims["userId"].(float64))

	var req LinkIdentityRequest
	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
		// Don't echo decoder internals back to the client.
		http.Error(w, "Invalid request body", http.StatusBadRequest)
		return
	}

	// reportInsertError maps a failed identities INSERT to an HTTP response.
	// lib/pq is only blank-imported here, so the unique violation is detected
	// from the error text rather than the SQLSTATE code.
	reportInsertError := func(err error) {
		if strings.Contains(err.Error(), "duplicate key") {
			http.Error(w, "Identity already linked", http.StatusConflict)
			return
		}
		log.Printf("Error linking identity for user %d: %v", userID, err)
		http.Error(w, "Failed to link identity", http.StatusInternalServerError)
	}

	switch req.Type {
	case "email_password":
		// Validate email
		if err := validateEmail(req.Email); err != nil {
			w.Header().Set("Content-Type", "application/json")
			w.WriteHeader(http.StatusBadRequest)
			json.NewEncoder(w).Encode(err)
			return
		}
		// Validate password
		if err := validatePassword(req.Password); err != nil {
			w.Header().Set("Content-Type", "application/json")
			w.WriteHeader(http.StatusBadRequest)
			json.NewEncoder(w).Encode(err)
			return
		}

		// Normalize email
		req.Email = strings.TrimSpace(strings.ToLower(req.Email))

		passwordHash, err := bcrypt.GenerateFromPassword([]byte(req.Password), bcrypt.DefaultCost)
		if err != nil {
			http.Error(w, "Failed to hash password", http.StatusInternalServerError)
			return
		}

		_, err = db.Exec(`
			INSERT INTO identities (user_id, type, identifier, credential, is_primary_login)
			VALUES ($1, 'email_password', $2, $3, false)
		`, userID, req.Email, string(passwordHash))
		if err != nil {
			reportInsertError(err)
			return
		}

	case "blockchain_address":
		if req.Address == "" || req.Signature == "" || req.Message == "" {
			http.Error(w, "Address, signature, and message required", http.StatusBadRequest)
			return
		}

		if !verifyEthereumSignature(req.Address, req.Message, req.Signature) {
			http.Error(w, "Invalid signature", http.StatusUnauthorized)
			return
		}

		// Store the canonical form (trimmed, lower-case, 0x-prefixed) that the
		// signature was verified against, so " 0xAB…" and "ab…" cannot be
		// linked as two different identities. Non-empty here because the
		// signature check above rejects addresses that fail to normalize.
		address := normalizeEthereumAddress(req.Address)

		_, err := db.Exec(`
			INSERT INTO identities (user_id, type, identifier, is_primary_login)
			VALUES ($1, 'blockchain_address', $2, false)
		`, userID, address)
		if err != nil {
			reportInsertError(err)
			return
		}

	default:
		http.Error(w, "Invalid identity type", http.StatusBadRequest)
		return
	}

	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(http.StatusCreated)
	json.NewEncoder(w).Encode(map[string]string{"message": "Identity linked successfully"})
}

// UnlinkIdentityRequest to unlink an identity from account
type UnlinkIdentityRequest struct {
	IdentityID int `json:"identityId"`
}

// handleUnlinkIdentity removes one of the authenticated user's identities.
//
// The identity must belong to the caller (404 if it does not exist, 403 if
// it belongs to someone else) and the caller must keep at least one identity.
// If the removed identity was the primary login, the user's remaining
// identity with the lowest id is promoted to primary. All checks and writes
// run in one transaction so the invariants hold under concurrent requests.
func handleUnlinkIdentity(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodPost {
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}

	claims := r.Context().Value(userContextKey).(jwt.MapClaims)
	userID := int(claims["userId"].(float64))

	var req UnlinkIdentityRequest
	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
		http.Error(w, "Invalid request body", http.StatusBadRequest)
		return
	}

	if req.IdentityID == 0 {
		http.Error(w, "Identity ID is required", http.StatusBadRequest)
		return
	}

	tx, err := db.Begin()
	if err != nil {
		log.Printf("Error starting transaction: %v", err)
		http.Error(w, "Database error", http.StatusInternalServerError)
		return
	}
	defer tx.Rollback() // no-op after a successful Commit

	// Lock the caller's users row so concurrent unlink requests for the same
	// user are serialized. Without it, two requests could each count two
	// identities, both pass the "last identity" check, and together delete
	// every login method.
	if _, err := tx.Exec("SELECT 1 FROM users WHERE id = $1 FOR UPDATE", userID); err != nil {
		log.Printf("Error locking user %d: %v", userID, err)
		http.Error(w, "Database error", http.StatusInternalServerError)
		return
	}

	// Check identity belongs to user
	var identityUserID int
	var isPrimary bool
	err = tx.QueryRow("SELECT user_id, is_primary_login FROM identities WHERE id = $1", req.IdentityID).Scan(&identityUserID, &isPrimary)
	if errors.Is(err, sql.ErrNoRows) {
		http.Error(w, "Identity not found", http.StatusNotFound)
		return
	} else if err != nil {
		log.Printf("Error checking identity: %v", err)
		http.Error(w, "Database error", http.StatusInternalServerError)
		return
	}

	if identityUserID != userID {
		http.Error(w, "Forbidden: identity does not belong to you", http.StatusForbidden)
		return
	}

	// Check user has at least 2 identities (can't remove last one)
	var identityCount int
	err = tx.QueryRow("SELECT COUNT(*) FROM identities WHERE user_id = $1", userID).Scan(&identityCount)
	if err != nil {
		log.Printf("Error counting identities: %v", err)
		http.Error(w, "Database error", http.StatusInternalServerError)
		return
	}

	if identityCount <= 1 {
		http.Error(w, "Cannot remove your last identity. You must have at least one login method.", http.StatusBadRequest)
		return
	}

	// If unlinking the primary identity, promote another one to primary.
	// PostgreSQL's UPDATE has no LIMIT clause, so the single replacement row
	// is chosen in a subquery. A failure here aborts the transaction, so it
	// is fatal rather than "continue anyway".
	if isPrimary {
		_, err = tx.Exec(`
			UPDATE identities SET is_primary_login = true
			WHERE id = (
				SELECT id FROM identities
				WHERE user_id = $1 AND id != $2
				ORDER BY id
				LIMIT 1
			)
		`, userID, req.IdentityID)
		if err != nil {
			log.Printf("Error promoting identity: %v", err)
			http.Error(w, "Failed to unlink identity", http.StatusInternalServerError)
			return
		}
	}

	// Delete the identity; scoping by user_id is defense in depth on top of
	// the ownership check above.
	result, err := tx.Exec("DELETE FROM identities WHERE id = $1 AND user_id = $2", req.IdentityID, userID)
	if err != nil {
		log.Printf("Error deleting identity: %v", err)
		http.Error(w, "Failed to unlink identity", http.StatusInternalServerError)
		return
	}

	rowsAffected, _ := result.RowsAffected() // driver-supported for DELETE; 0 means nothing matched
	if rowsAffected == 0 {
		http.Error(w, "Identity not found", http.StatusNotFound)
		return
	}

	if err := tx.Commit(); err != nil {
		log.Printf("Error committing identity unlink: %v", err)
		http.Error(w, "Failed to unlink identity", http.StatusInternalServerError)
		return
	}

	log.Printf("AUDIT: User %d unlinked identity %d", userID, req.IdentityID)

	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(http.StatusOK)
	json.NewEncoder(w).Encode(map[string]string{"message": "Identity unlinked successfully"})
}

// handleGetIdentities returns the authenticated user's identities as a JSON
// array ordered by id. An account with no identities yields [] (never null).
func handleGetIdentities(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodGet {
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}

	claims := r.Context().Value(userContextKey).(jwt.MapClaims)
	userID := int(claims["userId"].(float64))

	rows, err := db.Query(`
		SELECT id, user_id, type, identifier, is_primary_login, created_at
		FROM identities WHERE user_id = $1
		ORDER BY id
	`, userID)
	if err != nil {
		log.Printf("Error fetching identities for user %d: %v", userID, err)
		http.Error(w, "Failed to fetch identities", http.StatusInternalServerError)
		return
	}
	defer rows.Close()

	identities := []Identity{}
	for rows.Next() {
		var ident Identity
		if err := rows.Scan(&ident.ID, &ident.UserID, &ident.Type, &ident.Identifier, &ident.IsPrimaryLogin, &ident.CreatedAt); err != nil {
			log.Printf("Error scanning identity for user %d: %v", userID, err)
			http.Error(w, "Failed to parse identities", http.StatusInternalServerError)
			return
		}
		identities = append(identities, ident)
	}
	// Next() returning false can also mean the result stream broke mid-way.
	if err := rows.Err(); err != nil {
		log.Printf("Error reading identities for user %d: %v", userID, err)
		http.Error(w, "Failed to fetch identities", http.StatusInternalServerError)
		return
	}

	w.Header().Set("Content-Type", "application/json")
	json.NewEncoder(w).Encode(identities)
}

// handleProfile returns the authenticated user's profile together with
// their roles. Responds 404 if the user row no longer exists.
func handleProfile(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodGet {
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}

	claims := r.Context().Value(userContextKey).(jwt.MapClaims)
	userID := int(claims["userId"].(float64))

	var user User
	err := db.QueryRow(`
		SELECT id, name, email, created_at
		FROM users WHERE id = $1
	`, userID).Scan(&user.ID, &user.Name, &user.Email, &user.CreatedAt)
	if errors.Is(err, sql.ErrNoRows) {
		// A still-valid token for a deleted account.
		http.Error(w, "User not found", http.StatusNotFound)
		return
	}
	if err != nil {
		log.Printf("Error fetching profile for user %d: %v", userID, err)
		http.Error(w, "Failed to fetch profile", http.StatusInternalServerError)
		return
	}

	// Get roles
	rows, err := db.Query("SELECT role FROM user_roles WHERE user_id = $1", userID)
	if err != nil {
		http.Error(w, "Failed to fetch roles", http.StatusInternalServerError)
		return
	}
	defer rows.Close()

	var roles []string
	for rows.Next() {
		var role string
		if err := rows.Scan(&role); err != nil {
			log.Printf("Error scanning role for user %d: %v", userID, err)
			http.Error(w, "Failed to fetch roles", http.StatusInternalServerError)
			return
		}
		roles = append(roles, role)
	}
	if err := rows.Err(); err != nil {
		log.Printf("Error reading roles for user %d: %v", userID, err)
		http.Error(w, "Failed to fetch roles", http.StatusInternalServerError)
		return
	}
	user.Roles = roles

	w.Header().Set("Content-Type", "application/json")
	json.NewEncoder(w).Encode(user)
}

// ===== MIDDLEWARE =====

// authenticate verifies the "Authorization: Bearer <jwt>" header (HMAC-signed
// with jwtSecret) and stores the token's jwt.MapClaims in the request context
// under userContextKey before calling next.
//
// Tokens without a numeric "userId" claim are rejected here, because
// downstream handlers read it with an unchecked float64 type assertion and
// would otherwise panic.
func authenticate(next http.HandlerFunc) http.HandlerFunc {
	return func(w http.ResponseWriter, r *http.Request) {
		authHeader := r.Header.Get("Authorization")
		if authHeader == "" {
			http.Error(w, "Missing authorization token", http.StatusUnauthorized)
			return
		}

		const bearerPrefix = "Bearer "
		if !strings.HasPrefix(authHeader, bearerPrefix) {
			http.Error(w, "Invalid authorization header format", http.StatusUnauthorized)
			return
		}
		tokenString := strings.TrimPrefix(authHeader, bearerPrefix)

		token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) {
			// Refuse anything but HMAC so an "alg: none"/RSA token cannot be
			// validated against the shared secret.
			if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
				return nil, fmt.Errorf("unexpected signing method")
			}
			return jwtSecret, nil
		})
		if err != nil || !token.Valid {
			http.Error(w, "Invalid or expired token", http.StatusUnauthorized)
			return
		}

		claims, ok := token.Claims.(jwt.MapClaims)
		if !ok {
			http.Error(w, "Invalid token claims", http.StatusUnauthorized)
			return
		}
		// JSON numbers decode to float64 in MapClaims.
		if _, ok := claims["userId"].(float64); !ok {
			http.Error(w, "Invalid token claims", http.StatusUnauthorized)
			return
		}

		ctx := context.WithValue(r.Context(), userContextKey, claims)
		next.ServeHTTP(w, r.WithContext(ctx))
	}
}

// requireRole only lets the request through to next when the claims stored
// by authenticate contain at least one of allowedRoles (exact,
// case-sensitive match).
//
// Missing claims, or a "roles" claim that is absent, null or not a list of
// strings, are treated as "no roles" and answered with 403 instead of
// panicking on a type assertion.
func requireRole(next http.HandlerFunc, allowedRoles ...string) http.HandlerFunc {
	return func(w http.ResponseWriter, r *http.Request) {
		var userRoles []string
		if claims, ok := r.Context().Value(userContextKey).(jwt.MapClaims); ok {
			// On error extractRoles returns nil, which denies access below.
			userRoles, _ = extractRoles(claims)
		}

		// Check if user has any of the allowed roles
		for _, userRole := range userRoles {
			for _, allowedRole := range allowedRoles {
				if userRole == allowedRole {
					next.ServeHTTP(w, r)
					return
				}
			}
		}

		http.Error(w, "Forbidden: insufficient permissions", http.StatusForbidden)
	}
}

// requireCSRF validates CSRF token for state-changing requests (POST, PUT, DELETE)
// This middleware should be used after authenticate() for protected endpoints.
//
// Non-GET/HEAD/OPTIONS requests must carry an X-CSRF-Token header that
// matches one of the user's unexpired stored CSRF tokens. When an Origin
// header is present it must equal CORS_ALLOW_ORIGIN (default
// "http://localhost:8090"), unless that variable is "*".
func requireCSRF(next http.HandlerFunc) http.HandlerFunc {
	return func(w http.ResponseWriter, r *http.Request) {
		// Only validate CSRF for state-changing methods
		if r.Method == http.MethodGet || r.Method == http.MethodHead || r.Method == http.MethodOptions {
			next.ServeHTTP(w, r)
			return
		}

		claims := r.Context().Value(userContextKey).(jwt.MapClaims)
		userID := int(claims["userId"].(float64))
		clientIP := getClientIP(r)

		// Get CSRF token from header
		csrfToken := r.Header.Get("X-CSRF-Token")
		if csrfToken == "" {
			log.Printf("SECURITY: CSRF token missing for user %d from IP %s on %s %s", userID, clientIP, r.Method, r.URL.Path)
			http.Error(w, "CSRF token required", http.StatusForbidden)
			return
		}

		// Validate CSRF token
		if !validateCSRFToken(userID, csrfToken) {
			log.Printf("SECURITY: Invalid CSRF token for user %d from IP %s on %s %s", userID, clientIP, r.Method, r.URL.Path)
			http.Error(w, "Invalid CSRF token", http.StatusForbidden)
			return
		}

		// Verify Origin header matches expected origin (additional CSRF protection)
		origin := r.Header.Get("Origin")
		allowedOrigin := strings.TrimSpace(os.Getenv("CORS_ALLOW_ORIGIN"))
		if allowedOrigin == "" {
			allowedOrigin = "http://localhost:8090"
		}

		// If Origin header is present, verify it matches
		if origin != "" && allowedOrigin != "*" && origin != allowedOrigin {
			log.Printf("SECURITY: Origin mismatch for user %d: expected %s, got %s", userID, allowedOrigin, origin)
			http.Error(w, "Origin not allowed", http.StatusForbidden)
			return
		}

		next.ServeHTTP(w, r)
	}
}

// ===== HELPER FUNCTIONS =====

// generateAccessToken creates a short-lived HS256 JWT access token for userID.
//
// Claims: "user_id" and "userId" (same value), "email", "roles" (always a
// JSON array, possibly empty), "isInitialSuperuser", "exp" (now +
// accessTokenExpiry) and "type": "access".
// Returns an error if the roles or the user row cannot be read, so a token
// is never minted for a user that does not exist.
func generateAccessToken(userID int) (string, error) {
	// Get user roles
	rows, err := db.Query("SELECT role FROM user_roles WHERE user_id = $1", userID)
	if err != nil {
		return "", err
	}
	defer rows.Close()

	// Start non-nil so the claim encodes as [] rather than null; a null
	// "roles" claim cannot be parsed as a role list.
	roles := []string{}
	for rows.Next() {
		var role string
		if err := rows.Scan(&role); err != nil {
			return "", fmt.Errorf("scanning role for user %d: %w", userID, err)
		}
		roles = append(roles, role)
	}
	if err := rows.Err(); err != nil {
		return "", fmt.Errorf("reading roles for user %d: %w", userID, err)
	}

	// Get user email and superuser status (email may be NULL in the DB)
	var email *string
	var isInitialSuperuser bool
	if err := db.QueryRow(`
		SELECT email, COALESCE(is_initial_superuser, false)
		FROM users WHERE id = $1
	`, userID).Scan(&email, &isInitialSuperuser); err != nil {
		return "", fmt.Errorf("loading user %d: %w", userID, err)
	}

	token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
		"user_id":            userID, // snake_case for downstream services
		"userId":             userID, // camelCase for compatibility
		"email":              email,
		"roles":              roles,
		"isInitialSuperuser": isInitialSuperuser,
		"exp":                time.Now().Add(accessTokenExpiry).Unix(),
		"type":               "access",
	})

	return token.SignedString(jwtSecret)
}

// generateRefreshToken creates a secure random refresh token and stores it in the database.
//
// The caller receives the plaintext hex token (refreshTokenLength random
// bytes); only its bcrypt hash is persisted, together with clientIP and an
// expiry of now + refreshTokenExpiry.
//
// NOTE(review): because the stored value is a salted bcrypt hash, tokens
// cannot be looked up by key and validateRefreshToken must compare against
// every live token. An unsalted SHA-256 in an indexed column would be
// sufficient for 256-bit random tokens, but switching requires a migration.
func generateRefreshToken(userID int, clientIP string) (string, error) {
	// Generate cryptographically secure random token
	tokenBytes := make([]byte, refreshTokenLength)
	if _, err := rand.Read(tokenBytes); err != nil {
		return "", fmt.Errorf("failed to generate random token: %w", err)
	}
	token := hex.EncodeToString(tokenBytes)

	// Hash the token before storing (we only store the hash)
	tokenHash, err := bcrypt.GenerateFromPassword([]byte(token), bcrypt.DefaultCost)
	if err != nil {
		return "", fmt.Errorf("failed to hash refresh token: %w", err)
	}

	expiresAt := time.Now().Add(refreshTokenExpiry)

	// Store hashed token in database
	_, err = db.Exec(`
		INSERT INTO refresh_tokens (user_id, token_hash, expires_at, client_ip, created_at)
		VALUES ($1, $2, $3, $4, NOW())
	`, userID, string(tokenHash), expiresAt, clientIP)
	if err != nil {
		return "", fmt.Errorf("failed to store refresh token: %w", err)
	}

	log.Printf("AUDIT: Refresh token created for user_id %d from IP %s, expires %s", userID, clientIP, expiresAt.Format(time.RFC3339))

	return token, nil
}

// validateRefreshToken checks token against all unexpired, unrevoked refresh
// tokens and returns (tokenID, userID, nil) for the matching row.
//
// It returns an error if no row matches or the query fails. Rows that fail
// to scan are skipped.
//
// NOTE(review): every call runs a bcrypt comparison per live refresh token
// across all users, so cost grows linearly with active sessions. See
// generateRefreshToken for the storage format that forces this.
func validateRefreshToken(token string) (int, int, error) {
	// Get all non-expired tokens and check against provided token
	rows, err := db.Query(`
		SELECT id, user_id, token_hash FROM refresh_tokens
		WHERE expires_at > NOW() AND revoked_at IS NULL
	`)
	if err != nil {
		return 0, 0, fmt.Errorf("failed to query refresh tokens: %w", err)
	}
	defer rows.Close()

	for rows.Next() {
		var tokenID, userID int
		var tokenHash string
		if err := rows.Scan(&tokenID, &userID, &tokenHash); err != nil {
			continue
		}

		// Compare token with hash
		if bcrypt.CompareHashAndPassword([]byte(tokenHash), []byte(token)) == nil {
			return tokenID, userID, nil
		}
	}
	// Distinguish "no match" from "the result set was cut short".
	if err := rows.Err(); err != nil {
		return 0, 0, fmt.Errorf("failed to read refresh tokens: %w", err)
	}

	return 0, 0, errors.New("invalid or expired refresh token")
}

// revokeRefreshToken marks the refresh token row with the given id as revoked.
func revokeRefreshToken(tokenID int) error {
	_, err := db.Exec(`
		UPDATE refresh_tokens SET revoked_at = NOW() WHERE id = $1
	`, tokenID)
	return err
}

// revokeAllUserRefreshTokens revokes every not-yet-revoked refresh token of
// userID and writes an audit line with the number of rows affected.
func revokeAllUserRefreshTokens(userID int) error {
	result, err := db.Exec(`
		UPDATE refresh_tokens SET revoked_at = NOW()
		WHERE user_id = $1 AND revoked_at IS NULL
	`, userID)
	if err != nil {
		return err
	}

	rowsAffected, _ := result.RowsAffected() // used only for the audit count
	log.Printf("AUDIT: Revoked %d refresh tokens for user_id %d", rowsAffected, userID)
	return nil
}

// ===== CSRF TOKEN MANAGEMENT =====

// generateCSRFToken returns csrfTokenLength bytes from crypto/rand, hex-encoded.
func generateCSRFToken() (string, error) {
	tokenBytes := make([]byte, csrfTokenLength)
	if _, err := rand.Read(tokenBytes); err != nil {
		return "", fmt.Errorf("failed to generate CSRF token: %w", err)
	}
	return hex.EncodeToString(tokenBytes), nil
}

// storeCSRFToken persists the bcrypt hash of csrfToken for userID. The token
// expires after accessTokenExpiry, the same lifetime as the access token
// issued alongside it.
func storeCSRFToken(userID int, csrfToken string, clientIP string) error {
	// Hash the CSRF token before storing
	tokenHash, err := bcrypt.GenerateFromPassword([]byte(csrfToken), bcrypt.DefaultCost)
	if err != nil {
		return fmt.Errorf("failed to hash CSRF token: %w", err)
	}

	expiresAt := time.Now().Add(accessTokenExpiry) // CSRF token expires with access token

	// Store hashed token in database
	_, err = db.Exec(`
		INSERT INTO csrf_tokens (user_id, token_hash, expires_at, client_ip, created_at)
		VALUES ($1, $2, $3, $4, NOW())
	`, userID, string(tokenHash), expiresAt, clientIP)
	if err != nil {
		return fmt.Errorf("failed to store CSRF token: %w", err)
	}

	return nil
}

// validateCSRFToken reports whether token matches any unexpired CSRF token
// stored for userID. Any database error is logged and treated as "invalid".
func validateCSRFToken(userID int, token string) bool {
	if token == "" {
		return false
	}

	// Get all non-expired tokens for this user and check against provided token
	rows, err := db.Query(`
		SELECT token_hash FROM csrf_tokens
		WHERE user_id = $1 AND expires_at > NOW()
	`, userID)
	if err != nil {
		log.Printf("Error querying CSRF tokens: %v", err)
		return false
	}
	defer rows.Close()

	for rows.Next() {
		var tokenHash string
		if err := rows.Scan(&tokenHash); err != nil {
			continue
		}

		// Compare token with hash
		if bcrypt.CompareHashAndPassword([]byte(tokenHash), []byte(token)) == nil {
			return true
		}
	}
	if err := rows.Err(); err != nil {
		log.Printf("Error reading CSRF tokens for user %d: %v", userID, err)
	}

	return false
}

// cleanupExpiredCSRFTokens deletes CSRF tokens past their expiry. It is
// meant to be called periodically; failures are only logged.
func cleanupExpiredCSRFTokens() {
	result, err := db.Exec(`DELETE FROM csrf_tokens WHERE expires_at < NOW()`)
	if err != nil {
		log.Printf("Error cleaning up expired CSRF tokens: %v", err)
		return
	}
	rowsAffected, _ := result.RowsAffected()
	if rowsAffected > 0 {
		log.Printf("Cleaned up %d expired CSRF tokens", rowsAffected)
	}
}

// revokeUserCSRFTokens deletes every CSRF token of userID (on logout).
func revokeUserCSRFTokens(userID int) error {
	_, err := db.Exec(`DELETE FROM csrf_tokens WHERE user_id = $1`, userID)
	return err
}

// generateTokenPair issues an access token, a refresh token and a CSRF token
// for userID.
//
// A failure to persist the CSRF token is logged but not returned: the
// caller still gets working access/refresh tokens, and requireCSRF will
// reject state-changing requests until the next refresh.
func generateTokenPair(userID int, clientIP string) (*AuthTokenResponse, error) {
	accessToken, err := generateAccessToken(userID)
	if err != nil {
		return nil, fmt.Errorf("failed to generate access token: %w", err)
	}

	refreshToken, err := generateRefreshToken(userID, clientIP)
	if err != nil {
		return nil, fmt.Errorf("failed to generate refresh token: %w", err)
	}

	// Generate CSRF token
	csrfToken, err := generateCSRFToken()
	if err != nil {
		return nil, fmt.Errorf("failed to generate CSRF token: %w", err)
	}

	// Store CSRF token in database
	if err := storeCSRFToken(userID, csrfToken, clientIP); err != nil {
		log.Printf("Warning: failed to store CSRF token: %v", err)
		// Continue anyway - CSRF validation will fail but auth still works
	}

	return &AuthTokenResponse{
		AccessToken:  accessToken,
		RefreshToken: refreshToken,
		CsrfToken:    csrfToken,
		ExpiresIn:    int64(accessTokenExpiry.Seconds()),
		TokenType:    "Bearer",
	}, nil
}

// generateToken is kept for backward compatibility.
//
// Deprecated: use generateAccessToken (or generateTokenPair).
func generateToken(userID int) (string, error) {
	return generateAccessToken(userID)
}

// extractRoles returns the "roles" claim as a []string. It accepts the
// []interface{} shape produced by JSON decoding as well as a []string, and
// errors when the claim is missing, null, of another type, or contains a
// non-string element.
func extractRoles(claims jwt.MapClaims) ([]string, error) {
	rawRoles, ok := claims["roles"]
	if !ok {
		return nil, errors.New("roles not present in token")
	}

	switch v := rawRoles.(type) {
	case []interface{}:
		roles := make([]string, 0, len(v))
		for _, roleVal := range v {
			roleStr, ok := roleVal.(string)
			if !ok {
				return nil, errors.New("role value not a string")
			}
			roles = append(roles, roleStr)
		}
		return roles, nil
	case []string:
		return v, nil
	default:
		return nil, errors.New("roles claim has unexpected type")
	}
}

// verifyEthereumSignature reports whether signature (65 bytes, hex, optional
// "0x") is a personal_sign-style signature of message whose recovered
// address equals address after normalizeEthereumAddress.
func verifyEthereumSignature(address, message, signature string) bool {
	address = normalizeEthereumAddress(address)
	if address == "" {
		return false
	}

	// Decode signature
	sigBytes, err := hex.DecodeString(strings.TrimPrefix(signature, "0x"))
	if err != nil || len(sigBytes) != 65 {
		return false
	}

	// Ethereum specific: adjust v value (27/28 -> 0/1 as SigToPub expects)
	if sigBytes[64] == 27 || sigBytes[64] == 28 {
		sigBytes[64] -= 27
	}

	// Hash message with Ethereum prefix
	messageHash := crypto.Keccak256Hash([]byte(fmt.Sprintf("\x19Ethereum Signed Message:\n%d%s", len(message), message)))

	// Recover public key
	pubKey, err := crypto.SigToPub(messageHash.Bytes(), sigBytes)
	if err != nil {
		return false
	}

	// Get address from public key
	recoveredAddr := crypto.PubkeyToAddress(*pubKey).Hex()
	return strings.ToLower(recoveredAddr) == address
}

// === Test helpers and shared utilities ===

// generateJWT signs a one-hour HS256 token for user with the given roles.
// Fails if jwtSecret is empty.
func generateJWT(user User, roles []string) (string, error) {
	if len(jwtSecret) == 0 {
		return "", errors.New("JWT_SECRET not configured")
	}

	token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
		"user_id": user.ID,
		"userId":  user.ID,
		"email":   user.Email,
		"roles":   roles,
		"exp":     time.Now().Add(time.Hour).Unix(),
	})

	return token.SignedString(jwtSecret)
}

// hashPassword returns the bcrypt hash of password at bcrypt.DefaultCost.
func hashPassword(password string) (string, error) {
	hash, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
	if err != nil {
		return "", err
	}
	return string(hash), nil
}

// checkPasswordHash reports whether password matches the bcrypt hash.
func checkPasswordHash(password, hash string) bool {
	return bcrypt.CompareHashAndPassword([]byte(hash), []byte(password)) == nil
}

// normalizeEthereumAddress trims and lower-cases addr and ensures a "0x"
// prefix. It returns "" when the input is empty or the result is not exactly
// 42 characters. Hex digits are not validated here.
func normalizeEthereumAddress(addr string) string {
	addr = strings.ToLower(strings.TrimSpace(addr))
	if addr == "" {
		return ""
	}
	if !strings.HasPrefix(addr, "0x") {
		addr = "0x" + addr
	}
	if len(addr) != 42 { // 0x + 40 hex chars
		return ""
	}
	return addr
}

// ===== ADMIN ENDPOINTS =====

// PromoteUserRoleRequest for promoting user roles
type PromoteUserRoleRequest struct {
	UserID int    `json:"userId"`
	Role   string `json:"role"`
}

// DemoteUserRoleRequest for demoting user roles
type DemoteUserRoleRequest struct {
	UserID int    `json:"userId"`
	Role   string `json:"role"`
}

// handleGetAllUsers returns all users, newest first, each with their roles
// (ADMIN/SUPERUSER only). Role loading is best-effort per user: a failure is
// logged and leaves that user's roles empty. An empty result is [] not null.
func handleGetAllUsers(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodGet {
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}

	rows, err := db.Query(`
		SELECT u.id, u.name, u.email,
		       COALESCE(u.is_initial_superuser, false),
		       COALESCE(u.is_protected, false),
		       u.created_at
		FROM users u
		ORDER BY u.created_at DESC
	`)
	if err != nil {
		log.Printf("Error fetching users: %v", err)
		http.Error(w, "Failed to fetch users", http.StatusInternalServerError)
		return
	}
	defer rows.Close()

	// Non-nil so an empty table encodes as [] instead of null.
	users := []User{}
	for rows.Next() {
		var u User
		if err := rows.Scan(&u.ID, &u.Name, &u.Email, &u.IsInitialSuperuser, &u.IsProtected, &u.CreatedAt); err != nil {
			log.Printf("Error scanning user: %v", err)
			http.Error(w, "Failed to parse users", http.StatusInternalServerError)
			return
		}
		users = append(users, u)
	}
	if err := rows.Err(); err != nil {
		log.Printf("Error fetching users: %v", err)
		http.Error(w, "Failed to fetch users", http.StatusInternalServerError)
		return
	}

	// Fetch roles for each user
	for i := range users {
		roleRows, err := db.Query("SELECT role FROM user_roles WHERE user_id = $1", users[i].ID)
		if err != nil {
			log.Printf("Error fetching roles for user %d: %v", users[i].ID, err)
			continue
		}

		var roles []string
		for roleRows.Next() {
			var role string
			if err := roleRows.Scan(&role); err != nil {
				log.Printf("Error scanning role for user %d: %v", users[i].ID, err)
				continue
			}
			roles = append(roles, role)
		}
		if err := roleRows.Err(); err != nil {
			log.Printf("Error reading roles for user %d: %v", users[i].ID, err)
		}
		roleRows.Close()
		users[i].Roles = roles
	}

	w.Header().Set("Content-Type", "application/json")
	json.NewEncoder(w).Encode(users)
}

// handleDemoteUserRole allows ADMIN/SUPERUSER users to remove roles from other users
// ADMIN can only demote CLIENT, STAFF, ADMIN roles (cannot touch SUPERUSER)
// SUPERUSER can demote any role except SUPERUSER (use /superuser/demote for that)
//
// The role name is matched case-insensitively and a user's last remaining
// role can never be removed. When the caller is not a SUPERUSER and the
// target's SUPERUSER status cannot be determined, the request is refused.
func handleDemoteUserRole(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodPost {
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}

	claims := r.Context().Value(userContextKey).(jwt.MapClaims)
	adminUserID := int(claims["userId"].(float64))
	callerRoles, _ := extractRoles(claims) // unreadable roles => not a SUPERUSER

	var req DemoteUserRoleRequest
	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
		http.Error(w, "Invalid request body", http.StatusBadRequest)
		return
	}
	req.Role = strings.ToUpper(req.Role)

	// Validate role - SUPERUSER demotion must go through /superuser/demote
	if req.Role == "SUPERUSER" {
		http.Error(w, "Use /superuser/demote endpoint to remove SUPERUSER role", http.StatusBadRequest)
		return
	}
	validRoles := map[string]bool{"CLIENT": true, "STAFF": true, "ADMIN": true}
	if !validRoles[req.Role] {
		http.Error(w, "Invalid role. Must be CLIENT, STAFF, or ADMIN", http.StatusBadRequest)
		return
	}

	// Check if caller is SUPERUSER
	isSuperuser := false
	for _, role := range callerRoles {
		if role == "SUPERUSER" {
			isSuperuser = true
			break
		}
	}

	// If not superuser, verify target is not a superuser (ADMINs cannot touch
	// SUPERUSERs). Fail closed: a lookup error must not let the change through.
	if !isSuperuser {
		var targetHasSuperuser bool
		err := db.QueryRow(`
			SELECT EXISTS(SELECT 1 FROM user_roles WHERE user_id = $1 AND role = 'SUPERUSER')
		`, req.UserID).Scan(&targetHasSuperuser)
		if err != nil {
			log.Printf("Error checking target SUPERUSER status: %v", err)
			http.Error(w, "Database error", http.StatusInternalServerError)
			return
		}
		if targetHasSuperuser {
			http.Error(w, "Forbidden: ADMINs cannot modify SUPERUSER accounts", http.StatusForbidden)
			return
		}
	}

	// Prevent admin from demoting themselves from ADMIN role
	if req.UserID == adminUserID && req.Role == "ADMIN" && !isSuperuser {
		http.Error(w, "Cannot remove your own ADMIN role", http.StatusForbidden)
		return
	}

	// Validate user exists
	var userName string
	err := db.QueryRow("SELECT name FROM users WHERE id = $1", req.UserID).Scan(&userName)
	if errors.Is(err, sql.ErrNoRows) {
		http.Error(w, "User not found", http.StatusNotFound)
		return
	} else if err != nil {
		log.Printf("Error checking user existence: %v", err)
		http.Error(w, "Database error", http.StatusInternalServerError)
		return
	}

	// Check if user has this role
	var existingRoleID int
	err = db.QueryRow(`
		SELECT id FROM user_roles WHERE user_id = $1 AND role = $2
	`, req.UserID, req.Role).Scan(&existingRoleID)
	if errors.Is(err, sql.ErrNoRows) {
		http.Error(w, fmt.Sprintf("User does not have %s role", req.Role), http.StatusNotFound)
		return
	} else if err != nil {
		log.Printf("Error checking existing role: %v", err)
		http.Error(w, "Database error", http.StatusInternalServerError)
		return
	}

	// Check if this is the user's only role - don't allow removal
	var roleCount int
	err = db.QueryRow("SELECT COUNT(*) FROM user_roles WHERE user_id = $1", req.UserID).Scan(&roleCount)
	if err != nil {
		log.Printf("Error counting roles: %v", err)
		http.Error(w, "Database error", http.StatusInternalServerError)
		return
	}
	if roleCount <= 1 {
		http.Error(w, "Cannot remove user's only role. Assign a different role first.", http.StatusBadRequest)
		return
	}

	// Delete the role
	_, err = db.Exec(`
		DELETE FROM user_roles WHERE user_id = $1 AND role = $2
	`, req.UserID, req.Role)
	if err != nil {
		log.Printf("Error deleting role: %v", err)
		http.Error(w, "Failed to remove role", http.StatusInternalServerError)
		return
	}

	// Log the role change for audit trail
	log.Printf("AUDIT: Admin user %d removed %s role from user %d (%s)", adminUserID, req.Role, req.UserID, userName)

	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(http.StatusOK)
	json.NewEncoder(w).Encode(map[string]string{
		"message": fmt.Sprintf("Successfully removed %s role from user %d", req.Role, req.UserID),
	})
}

// handlePromoteUserRole allows ADMIN/SUPERUSER users to grant roles to other users
// ADMIN can only promote to CLIENT, STAFF, ADMIN
// SUPERUSER can promote to any role except SUPERUSER (use /superuser/promote for that)
//
// The role name is matched case-insensitively. When the caller is not a
// SUPERUSER and the target's SUPERUSER status cannot be determined, the
// request is refused.
func handlePromoteUserRole(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodPost {
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}

	claims := r.Context().Value(userContextKey).(jwt.MapClaims)
	adminUserID := int(claims["userId"].(float64))
	callerRoles, _ := extractRoles(claims) // unreadable roles => not a SUPERUSER

	var req PromoteUserRoleRequest
	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
		http.Error(w, "Invalid request body", http.StatusBadRequest)
		return
	}
	req.Role = strings.ToUpper(req.Role)

	// Validate role - SUPERUSER promotion must go through /superuser/promote
	validRoles := map[string]bool{"CLIENT": true, "STAFF": true, "ADMIN": true}
	if !validRoles[req.Role] {
		if req.Role == "SUPERUSER" {
			http.Error(w, "Use /superuser/promote endpoint to promote to SUPERUSER", http.StatusBadRequest)
			return
		}
		http.Error(w, "Invalid role. Must be CLIENT, STAFF, or ADMIN", http.StatusBadRequest)
		return
	}

	// Check if caller is SUPERUSER (they can do anything except promote to SUPERUSER here)
	isSuperuser := false
	for _, role := range callerRoles {
		if role == "SUPERUSER" {
			isSuperuser = true
			break
		}
	}

	// If not superuser, verify target is not a superuser (ADMINs cannot touch
	// SUPERUSERs). Fail closed: a lookup error must not let the change through.
	if !isSuperuser {
		var targetHasSuperuser bool
		err := db.QueryRow(`
			SELECT EXISTS(SELECT 1 FROM user_roles WHERE user_id = $1 AND role = 'SUPERUSER')
		`, req.UserID).Scan(&targetHasSuperuser)
		if err != nil {
			log.Printf("Error checking target SUPERUSER status: %v", err)
			http.Error(w, "Database error", http.StatusInternalServerError)
			return
		}
		if targetHasSuperuser {
			http.Error(w, "Forbidden: ADMINs cannot modify SUPERUSER accounts", http.StatusForbidden)
			return
		}
	}

	// Validate user exists
	var userName string
	err := db.QueryRow("SELECT name FROM users WHERE id = $1", req.UserID).Scan(&userName)
	if errors.Is(err, sql.ErrNoRows) {
		http.Error(w, "User not found", http.StatusNotFound)
		return
	} else if err != nil {
		log.Printf("Error checking user existence: %v", err)
		http.Error(w, "Database error", http.StatusInternalServerError)
		return
	}

	// Check if user already has this role
	var existingRoleID int
	err = db.QueryRow(`
		SELECT id FROM user_roles WHERE user_id = $1 AND role = $2
	`, req.UserID, req.Role).Scan(&existingRoleID)
	if err == nil {
		http.Error(w, fmt.Sprintf("User already has %s role", req.Role), http.StatusConflict)
		return
	} else if !errors.Is(err, sql.ErrNoRows) {
		log.Printf("Error checking existing role: %v", err)
		http.Error(w, "Database error", http.StatusInternalServerError)
		return
	}

	// Insert new role
	_, err = db.Exec(`
		INSERT INTO user_roles (user_id, role, created_at)
		VALUES ($1, $2, NOW())
	`, req.UserID, req.Role)
	if err != nil {
		log.Printf("Error inserting role: %v", err)
		http.Error(w, "Failed to assign role", http.StatusInternalServerError)
		return
	}

	// Log the role change for audit trail
	log.Printf("AUDIT: Admin user %d granted %s role to user %d (%s)", adminUserID, req.Role, req.UserID, userName)

	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(http.StatusOK)
	json.NewEncoder(w).Encode(map[string]string{
		"message": fmt.Sprintf("Successfully granted %s role to user %d", req.Role, req.UserID),
	})
}

// =============================================================================
// SUPERUSER MANAGEMENT ENDPOINTS
// =============================================================================

// SuperuserPromoteRequest for promoting user to SUPERUSER
type SuperuserPromoteRequest struct {
	UserID int `json:"userId"`
}

// SuperuserDemoteRequest for demoting a SUPERUSER
type SuperuserDemoteRequest struct {
	UserID int `json:"userId"`
}

// TransferInitialSuperuserRequest for transferring initial superuser status
type TransferInitialSuperuserRequest struct {
	NewSuperuserID int    `json:"newSuperuserId"`
	Reason         string `json:"reason,omitempty"`
}

// handlePromoteSuperuser grants the SUPERUSER role to an existing user
// (SUPERUSER only). Responds 404 for an unknown user and 409 if the user is
// already a SUPERUSER.
func handlePromoteSuperuser(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodPost {
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}

	claims := r.Context().Value(userContextKey).(jwt.MapClaims)
	callerUserID := int(claims["userId"].(float64))

	var req SuperuserPromoteRequest
	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
		http.Error(w, "Invalid request body", http.StatusBadRequest)
		return
	}

	// Validate user exists
	var userName string
	err := db.QueryRow("SELECT name FROM users WHERE id = $1", req.UserID).Scan(&userName)
	if errors.Is(err, sql.ErrNoRows) {
		http.Error(w, "User not found", http.StatusNotFound)
		return
	} else if err != nil {
		log.Printf("Error checking user existence: %v", err)
		http.Error(w, "Database error", http.StatusInternalServerError)
		return
	}

	// Check if user already has SUPERUSER role
	var existingRoleID int
	err = db.QueryRow(`
		SELECT id FROM user_roles WHERE user_id = $1 AND role = 'SUPERUSER'
	`, req.UserID).Scan(&existingRoleID)
	if err == nil {
		http.Error(w, "User is already a SUPERUSER", http.StatusConflict)
		return
	} else if !errors.Is(err, sql.ErrNoRows) {
		log.Printf("Error checking existing role: %v", err)
		http.Error(w, "Database error", http.StatusInternalServerError)
		return
	}

	// Insert SUPERUSER role
	_, err = db.Exec(`
		INSERT INTO user_roles (user_id, role, created_at)
		VALUES ($1, 'SUPERUSER', NOW())
	`, req.UserID)
	if err != nil {
		log.Printf("Error inserting SUPERUSER role: %v", err)
		http.Error(w, "Failed to promote user to SUPERUSER", http.StatusInternalServerError)
		return
	}

	log.Printf("AUDIT: SUPERUSER %d promoted user %d (%s) to SUPERUSER", callerUserID, req.UserID, userName)

	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(http.StatusOK)
	json.NewEncoder(w).Encode(map[string]string{
		"message": fmt.Sprintf("Successfully promoted user %d to SUPERUSER", req.UserID),
	})
}

// handleDemoteSuperuser removes SUPERUSER role from a user (SUPERUSER only)
// Cannot demote initial superuser - they must transfer first
//
// Callers cannot demote themselves. If SUPERUSER is the target's only role,
// a CLIENT role is added first. The fallback insert and the SUPERUSER delete
// run in one transaction, so a failed delete never leaves a stray CLIENT role.
func handleDemoteSuperuser(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodPost {
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}

	claims := r.Context().Value(userContextKey).(jwt.MapClaims)
	callerUserID := int(claims["userId"].(float64))

	var req SuperuserDemoteRequest
	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
		http.Error(w, "Invalid request body", http.StatusBadRequest)
		return
	}

	// Cannot demote yourself
	if req.UserID == callerUserID {
		http.Error(w, "Cannot demote yourself. Have another SUPERUSER do it.", http.StatusForbidden)
		return
	}

	// Check if target is the initial superuser
	var isInitialSU bool
	var userName string
	err := db.QueryRow(`
		SELECT name, COALESCE(is_initial_superuser, false)
		FROM users WHERE id = $1
	`, req.UserID).Scan(&userName, &isInitialSU)
	if errors.Is(err, sql.ErrNoRows) {
		http.Error(w, "User not found", http.StatusNotFound)
		return
	} else if err != nil {
		log.Printf("Error checking user: %v", err)
		http.Error(w, "Database error", http.StatusInternalServerError)
		return
	}

	if isInitialSU {
		http.Error(w, "Cannot demote the INITIAL SUPERUSER. They must transfer their status first using /superuser/transfer", http.StatusForbidden)
		return
	}

	// Check if user has SUPERUSER role
	var existingRoleID int
	err = db.QueryRow(`
		SELECT id FROM user_roles WHERE user_id = $1 AND role = 'SUPERUSER'
	`, req.UserID).Scan(&existingRoleID)
	if errors.Is(err, sql.ErrNoRows) {
		http.Error(w, "User does not have SUPERUSER role", http.StatusNotFound)
		return
	} else if err != nil {
		log.Printf("Error checking role: %v", err)
		http.Error(w, "Database error", http.StatusInternalServerError)
		return
	}

	tx, err := db.Begin()
	if err != nil {
		log.Printf("Error starting transaction: %v", err)
		http.Error(w, "Database error", http.StatusInternalServerError)
		return
	}
	defer tx.Rollback() // no-op after a successful Commit

	// Ensure user has at least one other role
	var roleCount int
	err = tx.QueryRow("SELECT COUNT(*) FROM user_roles WHERE user_id = $1", req.UserID).Scan(&roleCount)
	if err != nil {
		log.Printf("Error counting roles: %v", err)
		http.Error(w, "Database error", http.StatusInternalServerError)
		return
	}

	if roleCount <= 1 {
		// Add CLIENT role before removing SUPERUSER
		_, err = tx.Exec(`
			INSERT INTO user_roles (user_id, role, created_at)
			VALUES ($1, 'CLIENT', NOW())
			ON CONFLICT (user_id, role) DO NOTHING
		`, req.UserID)
		if err != nil {
			log.Printf("Error adding fallback CLIENT role: %v", err)
			http.Error(w, "Database error", http.StatusInternalServerError)
			return
		}
	}

	// Remove SUPERUSER role
	_, err = tx.Exec(`
		DELETE FROM user_roles WHERE user_id = $1 AND role = 'SUPERUSER'
	`, req.UserID)
	if err != nil {
		log.Printf("Error removing SUPERUSER role: %v", err)
		http.Error(w, "Failed to demote SUPERUSER", http.StatusInternalServerError)
		return
	}

	if err = tx.Commit(); err != nil {
		log.Printf("Error committing SUPERUSER demotion: %v", err)
		http.Error(w, "Failed to demote SUPERUSER", http.StatusInternalServerError)
		return
	}

	log.Printf("AUDIT: SUPERUSER %d demoted user %d (%s) from SUPERUSER", callerUserID, req.UserID, userName)

	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(http.StatusOK)
	json.NewEncoder(w).Encode(map[string]string{
		"message": fmt.Sprintf("Successfully removed SUPERUSER role from user %d", req.UserID),
	})
}

// handleTransferInitialSuperuser transfers initial superuser status to another user
// ONLY the current initial superuser can call this (self-transfer only)
//
// In one transaction the caller's is_initial_superuser flag is cleared,
// the target gets is_initial_superuser and is_protected set plus the
// SUPERUSER role, and an audit row is written to superuser_transfers. The
// audit row is best-effort: it is isolated behind a savepoint because in
// PostgreSQL a failed statement would otherwise abort the whole transaction.
func handleTransferInitialSuperuser(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodPost {
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}

	claims := r.Context().Value(userContextKey).(jwt.MapClaims)
	callerUserID := int(claims["userId"].(float64))

	// Verify caller is the initial superuser. This is an early rejection; the
	// authoritative check is the conditional UPDATE inside the transaction.
	var isInitialSU bool
	err := db.QueryRow(`
		SELECT COALESCE(is_initial_superuser, false)
		FROM users WHERE id = $1
	`, callerUserID).Scan(&isInitialSU)
	if err != nil {
		log.Printf("Error checking initial superuser status: %v", err)
		http.Error(w, "Database error", http.StatusInternalServerError)
		return
	}

	if !isInitialSU {
		http.Error(w, "Forbidden: Only the INITIAL SUPERUSER can transfer their status", http.StatusForbidden)
		return
	}

	var req TransferInitialSuperuserRequest
	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
		http.Error(w, "Invalid request body", http.StatusBadRequest)
		return
	}

	if req.NewSuperuserID == callerUserID {
		http.Error(w, "Cannot transfer to yourself", http.StatusBadRequest)
		return
	}

	// Validate new superuser exists
	var newUserName string
	err = db.QueryRow("SELECT name FROM users WHERE id = $1", req.NewSuperuserID).Scan(&newUserName)
	if errors.Is(err, sql.ErrNoRows) {
		http.Error(w, "Target user not found", http.StatusNotFound)
		return
	} else if err != nil {
		log.Printf("Error checking user existence: %v", err)
		http.Error(w, "Database error", http.StatusInternalServerError)
		return
	}

	// Start transaction for transfer
	tx, err := db.Begin()
	if err != nil {
		log.Printf("Error starting transaction: %v", err)
		http.Error(w, "Database error", http.StatusInternalServerError)
		return
	}
	defer tx.Rollback() // no-op after a successful Commit

	// 1. Remove initial_superuser flag from current user. Re-checking the flag
	// in the WHERE clause (under the row lock) means that of two concurrent
	// transfers only one can succeed.
	result, err := tx.Exec(`
		UPDATE users SET is_initial_superuser = false
		WHERE id = $1 AND is_initial_superuser = true
	`, callerUserID)
	if err != nil {
		log.Printf("Error removing initial superuser flag: %v", err)
		http.Error(w, "Transfer failed", http.StatusInternalServerError)
		return
	}
	changed, err := result.RowsAffected()
	if err != nil {
		log.Printf("Error removing initial superuser flag: %v", err)
		http.Error(w, "Transfer failed", http.StatusInternalServerError)
		return
	}
	if changed == 0 {
		http.Error(w, "Forbidden: Only the INITIAL SUPERUSER can transfer their status", http.StatusForbidden)
		return
	}

	// 2. Set initial_superuser and is_protected on new user
	_, err = tx.Exec(`
		UPDATE users SET is_initial_superuser = true, is_protected = true
		WHERE id = $1
	`, req.NewSuperuserID)
	if err != nil {
		log.Printf("Error setting initial superuser flag: %v", err)
		http.Error(w, "Transfer failed", http.StatusInternalServerError)
		return
	}

	// 3. Ensure new user has SUPERUSER role
	_, err = tx.Exec(`
		INSERT INTO user_roles (user_id, role, created_at)
		VALUES ($1, 'SUPERUSER', NOW())
		ON CONFLICT (user_id, role) DO NOTHING
	`, req.NewSuperuserID)
	if err != nil {
		log.Printf("Error granting SUPERUSER role: %v", err)
		http.Error(w, "Transfer failed", http.StatusInternalServerError)
		return
	}

	// 4. Record the transfer for audit. Don't fail the transfer on the audit
	// record, but contain a failure with a savepoint: without it PostgreSQL
	// marks the transaction aborted and Commit would fail anyway.
	if _, err := tx.Exec("SAVEPOINT superuser_transfer_audit"); err != nil {
		log.Printf("Error recording transfer: %v", err)
		http.Error(w, "Transfer failed", http.StatusInternalServerError)
		return
	}
	if _, err := tx.Exec(`
		INSERT INTO superuser_transfers (from_user_id, to_user_id, reason)
		VALUES ($1, $2, $3)
	`, callerUserID, req.NewSuperuserID, req.Reason); err != nil {
		log.Printf("Error recording transfer: %v", err)
		if _, rbErr := tx.Exec("ROLLBACK TO SAVEPOINT superuser_transfer_audit"); rbErr != nil {
			log.Printf("Error recording transfer: %v", rbErr)
			http.Error(w, "Transfer failed", http.StatusInternalServerError)
			return
		}
	}

	if err = tx.Commit(); err != nil {
		log.Printf("Error committing transfer: %v", err)
		http.Error(w, "Transfer failed", http.StatusInternalServerError)
		return
	}

	// %q keeps a user-supplied reason from injecting newlines into the audit log.
	log.Printf("AUDIT: INITIAL SUPERUSER transferred from user %d to user %d (%s). Reason: %q",
		callerUserID, req.NewSuperuserID, newUserName, req.Reason)

	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(http.StatusOK)
	json.NewEncoder(w).Encode(map[string]string{
		"message": fmt.Sprintf("Successfully transferred INITIAL SUPERUSER status to user %d (%s)", req.NewSuperuserID, newUserName),
	})
}