// Package main implements the forum service: an HTTP API for questions,
// answers and votes, stored in PostgreSQL and authenticated with HMAC-signed
// JWT bearer tokens.
package main

import (
	"context"
	"database/sql"
	"encoding/json"
	"errors"
	"fmt"
	"log"
	"net"
	"net/http"
	"os"
	"os/signal"
	"strconv"
	"strings"
	"sync"
	"syscall"
	"time"
	"unicode/utf8"

	"github.com/golang-jwt/jwt/v5"
	_ "github.com/lib/pq"
)

// Rate limiting configuration
const (
	rateLimitWindow  = 1 * time.Minute
	maxWriteRequests = 30      // Max write requests per minute per IP
	maxReadRequests  = 100     // Max read requests per minute per IP
	maxRequestBody   = 1 << 20 // 1MB max request body size

	// rateLimiterSweepThreshold is the number of tracked keys at which a
	// limiter starts dropping expired entries (at most once per window).
	rateLimiterSweepThreshold = 10000
)

// Input validation limits for forum content
const (
	maxQuestionTitleLength   = 200
	maxQuestionContentLength = 50000 // ~50KB for questions
	maxAnswerContentLength   = 50000 // ~50KB for answers
	maxTagLength             = 50
	maxTagsCount             = 10
)

// rateLimiter counts requests per key inside a fixed window that starts at
// the key's first request.
type rateLimiter struct {
	mu        sync.RWMutex
	requests  map[string]*requestInfo
	lastSweep time.Time // last time expired entries were removed
}

// requestInfo is the per-key state of a rateLimiter.
type requestInfo struct {
	count    int
	firstReq time.Time
}

var (
	writeLimiter = &rateLimiter{requests: make(map[string]*requestInfo)}
	readLimiter  = &rateLimiter{requests: make(map[string]*requestInfo)}
)

// checkRateLimit records one request for key and reports whether the key has
// now made more than maxRequests requests within the current window. A key's
// window restarts once more than window has elapsed since its first request.
func (rl *rateLimiter) checkRateLimit(key string, maxRequests int, window time.Duration) bool {
	rl.mu.Lock()
	defer rl.mu.Unlock()

	now := time.Now()
	if rl.requests == nil {
		rl.requests = make(map[string]*requestInfo)
	}

	// Drop expired entries once the map is large. The sweep is throttled to
	// once per window so a flood of new keys cannot make every call O(n).
	if len(rl.requests) >= rateLimiterSweepThreshold && now.Sub(rl.lastSweep) > window {
		for k, info := range rl.requests {
			if now.Sub(info.firstReq) > window {
				delete(rl.requests, k)
			}
		}
		rl.lastSweep = now
	}

	info, exists := rl.requests[key]
	if !exists || now.Sub(info.firstReq) > window {
		rl.requests[key] = &requestInfo{count: 1, firstReq: now}
		return false
	}
	info.count++
	return info.count > maxRequests
}

// getClientIP returns the client address used as the rate-limit key: the
// first X-Forwarded-For entry, else X-Real-IP, else the host part of
// r.RemoteAddr.
//
// NOTE(review): both headers are client-controlled unless a trusted proxy
// overwrites them, so a direct client can pick its own rate-limit key.
// Confirm this service is only reachable through such a proxy.
func getClientIP(r *http.Request) string {
	if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
		first, _, _ := strings.Cut(xff, ",")
		if ip := strings.TrimSpace(first); ip != "" {
			return ip
		}
	}
	if xri := strings.TrimSpace(r.Header.Get("X-Real-IP")); xri != "" {
		return xri
	}
	ip, _, err := net.SplitHostPort(r.RemoteAddr)
	if err != nil {
		return r.RemoteAddr
	}
	return ip
}

// Context keys
type contextKey string

var userContextKey = contextKey("userClaims")

// Question represents a forum question
type Question struct {
	ID          int       `json:"id"`
	Title       string    `json:"title"`
	Content     string    `json:"content"`
	AuthorID    int       `json:"authorId"`
	AuthorName  string    `json:"authorName"`
	Tags        []string  `json:"tags"`
	Upvotes     int       `json:"upvotes"`
	Downvotes   int       `json:"downvotes"`
	AnswerCount int       `json:"answerCount"`
	ViewCount   int       `json:"viewCount"`
	AcceptedID  *int      `json:"acceptedAnswerId,omitempty"`
	Status      string    `json:"status"` // OPEN, ANSWERED, CLOSED
	CreatedAt   time.Time `json:"createdAt"`
	UpdatedAt   time.Time `json:"updatedAt"`
}

// Answer represents an answer to a question
type Answer struct {
	ID         int        `json:"id"`
	QuestionID int        `json:"questionId"`
	Content    string     `json:"content"`
	AuthorID   int        `json:"authorId"`
	AuthorName string     `json:"authorName"`
	Upvotes    int        `json:"upvotes"`
	Downvotes  int        `json:"downvotes"`
	IsAccepted bool       `json:"isAccepted"`
	IsVerified bool       `json:"isVerified"` // Admin-verified correct answer
	VerifiedBy *int       `json:"verifiedBy,omitempty"`
	VerifiedAt *time.Time `json:"verifiedAt,omitempty"`
	CreatedAt  time.Time  `json:"createdAt"`
	UpdatedAt  time.Time  `json:"updatedAt"`
}

// Vote represents a user's vote on a question or answer
type Vote struct {
	ID         int       `json:"id"`
	UserID     int       `json:"userId"`
	TargetType string    `json:"targetType"` // "question" or "answer"
	TargetID   int       `json:"targetId"`
	VoteType   int       `json:"voteType"` // 1 = upvote, -1 = downvote
	CreatedAt  time.Time `json:"createdAt"`
}

// CreateQuestionRequest is the JSON body for creating or updating a question.
type CreateQuestionRequest struct {
	Title   string   `json:"title"`
	Content string   `json:"content"`
	Tags    []string `json:"tags"`
}

// CreateAnswerRequest is the JSON body for creating or updating an answer.
type CreateAnswerRequest struct {
	Content string `json:"content"`
}

// VoteRequest is the JSON body for the vote endpoints.
type VoteRequest struct {
	VoteType int `json:"voteType"` // 1 or -1
}

var (
	db        *sql.DB
	jwtSecret []byte
)

// Sentinel errors returned from transaction bodies and mapped to HTTP
// statuses by the handlers.
var (
	errNotFound       = errors.New("not found")
	errQuestionClosed = errors.New("question is closed")
)

// Authentication failures. The text is sent verbatim to clients by
// authMiddleware, which is why it is capitalized.
var (
	errAuthHeaderRequired = errors.New("Authorization header required")
	errAuthFormat         = errors.New("Invalid authorization format")
	errInvalidToken       = errors.New("Invalid or expired token")
	errInvalidClaims      = errors.New("Invalid token claims")
)

// questionColumns is the column list read by scanQuestion, in scan order.
const questionColumns = `id, title, content, author_id, author_name, tags, upvotes, downvotes,
	answer_count, view_count, accepted_answer_id, status, created_at, updated_at`

// quoteConnValue wraps v in single quotes and backslash-escapes backslashes
// and single quotes, so values containing spaces or quotes stay a single
// key=value pair in the connection string.
func quoteConnValue(v string) string {
	return "'" + strings.NewReplacer(`\`, `\\`, `'`, `\'`).Replace(v) + "'"
}

// initDB reads the DB_* environment variables, opens and pings the database,
// and creates the forum tables and indexes if they do not exist. Any failure
// terminates the process.
func initDB() {
	dbHost := strings.TrimSpace(os.Getenv("DB_HOST"))
	dbUser := strings.TrimSpace(os.Getenv("DB_USER"))
	dbPassword := strings.TrimSpace(os.Getenv("DB_PASSWORD"))
	dbName := strings.TrimSpace(os.Getenv("DB_NAME"))
	dbSSLMode := strings.TrimSpace(os.Getenv("DB_SSL_MODE"))
	dbSchema := strings.TrimSpace(os.Getenv("DB_SCHEMA"))

	if dbHost == "" || dbUser == "" || dbPassword == "" || dbName == "" {
		log.Fatal("Database configuration missing: DB_HOST, DB_USER, DB_PASSWORD, DB_NAME required")
	}
	if dbSSLMode == "" {
		dbSSLMode = "require"
	}

	// Validate schema value if provided
	validSchemas := map[string]bool{"": true, "public": true, "dev": true, "testing": true, "prod": true}
	if !validSchemas[dbSchema] {
		log.Fatalf("Invalid DB_SCHEMA '%s'. Must be: dev, testing, prod, or empty for public", dbSchema)
	}

	connStr := fmt.Sprintf("host=%s user=%s password=%s dbname=%s sslmode=%s",
		quoteConnValue(dbHost), quoteConnValue(dbUser), quoteConnValue(dbPassword),
		quoteConnValue(dbName), quoteConnValue(dbSSLMode))
	schemaInfo := "public"
	if dbSchema != "" && dbSchema != "public" {
		// dbSchema comes from the fixed whitelist above, so it needs no quoting.
		connStr += fmt.Sprintf(" search_path=%s,public", dbSchema)
		schemaInfo = dbSchema
	}

	var err error
	db, err = sql.Open("postgres", connStr)
	if err != nil {
		log.Fatal(err)
	}

	// Configure connection pool limits
	db.SetMaxOpenConns(25)
	db.SetMaxIdleConns(5)
	db.SetConnMaxLifetime(5 * time.Minute)
	db.SetConnMaxIdleTime(1 * time.Minute)

	if err = db.Ping(); err != nil {
		log.Fatal(err)
	}
	log.Printf("Connected to database (SSL mode: %s, schema: %s, max_conns: 25)", dbSSLMode, schemaInfo)

	// Create tables
	createTablesSQL := `
	CREATE TABLE IF NOT EXISTS forum_questions (
		id SERIAL PRIMARY KEY,
		title VARCHAR(500) NOT NULL,
		content TEXT NOT NULL,
		author_id INTEGER NOT NULL,
		author_name VARCHAR(255) NOT NULL,
		tags TEXT[] DEFAULT '{}',
		upvotes INTEGER DEFAULT 0,
		downvotes INTEGER DEFAULT 0,
		answer_count INTEGER DEFAULT 0,
		view_count INTEGER DEFAULT 0,
		accepted_answer_id INTEGER,
		status VARCHAR(50) DEFAULT 'OPEN',
		created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
		updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
	);

	CREATE TABLE IF NOT EXISTS forum_answers (
		id SERIAL PRIMARY KEY,
		question_id INTEGER NOT NULL REFERENCES forum_questions(id) ON DELETE CASCADE,
		content TEXT NOT NULL,
		author_id INTEGER NOT NULL,
		author_name VARCHAR(255) NOT NULL,
		upvotes INTEGER DEFAULT 0,
		downvotes INTEGER DEFAULT 0,
		is_accepted BOOLEAN DEFAULT FALSE,
		is_verified BOOLEAN DEFAULT FALSE,
		verified_by INTEGER,
		verified_at TIMESTAMP,
		created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
		updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
	);

	CREATE TABLE IF NOT EXISTS forum_votes (
		id SERIAL PRIMARY KEY,
		user_id INTEGER NOT NULL,
		target_type VARCHAR(20) NOT NULL,
		target_id INTEGER NOT NULL,
		vote_type INTEGER NOT NULL,
		created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
		UNIQUE(user_id, target_type, target_id)
	);

	CREATE INDEX IF NOT EXISTS idx_questions_author ON forum_questions(author_id);
	CREATE INDEX IF NOT EXISTS idx_questions_status ON forum_questions(status);
	CREATE INDEX IF NOT EXISTS idx_questions_created ON forum_questions(created_at DESC);
	CREATE INDEX IF NOT EXISTS idx_answers_question ON forum_answers(question_id);
	CREATE INDEX IF NOT EXISTS idx_answers_author ON forum_answers(author_id);
	CREATE INDEX IF NOT EXISTS idx_votes_user ON forum_votes(user_id);
	CREATE INDEX IF NOT EXISTS idx_votes_target ON forum_votes(target_type, target_id);
	`
	if _, err := db.Exec(createTablesSQL); err != nil {
		log.Fatalf("Failed to create tables: %v", err)
	}
	log.Println("Database tables initialized")
}

// loadConfig reads JWT_SECRET and terminates the process when it is shorter
// than 32 bytes.
func loadConfig() {
	jwtSecret = []byte(strings.TrimSpace(os.Getenv("JWT_SECRET")))
	if len(jwtSecret) < 32 {
		log.Fatal("JWT_SECRET must be set and at least 32 characters")
	}
}

// enableCORS sets the CORS and security response headers. The allowed origin
// is CORS_ALLOW_ORIGIN, defaulting to http://localhost:8090.
func enableCORS(w http.ResponseWriter) {
	corsOrigin := strings.TrimSpace(os.Getenv("CORS_ALLOW_ORIGIN"))
	if corsOrigin == "" {
		corsOrigin = "http://localhost:8090"
	}
	h := w.Header()
	h.Set("Access-Control-Allow-Origin", corsOrigin)
	h.Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, PATCH, OPTIONS")
	h.Set("Access-Control-Allow-Headers", "Content-Type, Authorization")
	// Security headers
	h.Set("X-Content-Type-Options", "nosniff")
	h.Set("X-Frame-Options", "DENY")
	h.Set("X-XSS-Protection", "1; mode=block")
	h.Set("Referrer-Policy", "strict-origin-when-cross-origin")
	h.Set("Content-Security-Policy", "default-src 'self'")
}

// authenticate extracts the bearer token from the Authorization header and
// returns its claims. The token must be HMAC-signed with jwtSecret and valid.
func authenticate(r *http.Request) (jwt.MapClaims, error) {
	authHeader := r.Header.Get("Authorization")
	if authHeader == "" {
		return nil, errAuthHeaderRequired
	}
	tokenString := strings.TrimPrefix(authHeader, "Bearer ")
	if tokenString == authHeader {
		return nil, errAuthFormat
	}

	token, err := jwt.Parse(tokenString, func(token *jwt.Token) (any, error) {
		if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
			return nil, fmt.Errorf("unexpected signing method")
		}
		return jwtSecret, nil
	})
	if err != nil || !token.Valid {
		return nil, errInvalidToken
	}
	claims, ok := token.Claims.(jwt.MapClaims)
	if !ok {
		return nil, errInvalidClaims
	}
	return claims, nil
}

// authMiddleware rejects requests without a valid bearer token with 401 and
// otherwise calls next with the token claims stored under userContextKey.
func authMiddleware(next http.HandlerFunc) http.HandlerFunc {
	return func(w http.ResponseWriter, r *http.Request) {
		claims, err := authenticate(r)
		if err != nil {
			http.Error(w, err.Error(), http.StatusUnauthorized)
			return
		}
		ctx := context.WithValue(r.Context(), userContextKey, claims)
		next.ServeHTTP(w, r.WithContext(ctx))
	}
}

// optionalAuth stores the token claims in the request context when a valid
// bearer token is present and calls next unchanged otherwise.
func optionalAuth(next http.HandlerFunc) http.HandlerFunc {
	return func(w http.ResponseWriter, r *http.Request) {
		claims, err := authenticate(r)
		if err != nil {
			next.ServeHTTP(w, r)
			return
		}
		ctx := context.WithValue(r.Context(), userContextKey, claims)
		next.ServeHTTP(w, r.WithContext(ctx))
	}
}

// requireRole wraps next with authMiddleware and additionally requires the
// caller to hold one of allowedRoles. Roles are checked with hasRole, so
// SUPERUSER passes every check, consistent with the handlers.
func requireRole(next http.HandlerFunc, allowedRoles ...string) http.HandlerFunc {
	return authMiddleware(func(w http.ResponseWriter, r *http.Request) {
		claims, ok := r.Context().Value(userContextKey).(jwt.MapClaims)
		if !ok {
			http.Error(w, errInvalidClaims.Error(), http.StatusUnauthorized)
			return
		}
		if _, err := extractRoles(claims); err != nil {
			http.Error(w, "No roles found", http.StatusForbidden)
			return
		}
		for _, role := range allowedRoles {
			if hasRole(claims, role) {
				next.ServeHTTP(w, r)
				return
			}
		}
		http.Error(w, "Insufficient permissions", http.StatusForbidden)
	})
}

// extractRoles returns the "roles" claim as a string slice. It fails when the
// claim is missing, is not a list, or contains a non-string element.
func extractRoles(claims jwt.MapClaims) ([]string, error) {
	rawRoles, ok := claims["roles"]
	if !ok {
		return nil, errors.New("roles missing")
	}
	switch v := rawRoles.(type) {
	case []any:
		out := make([]string, 0, len(v))
		for _, r := range v {
			roleStr, ok := r.(string)
			if !ok {
				return nil, errors.New("role not string")
			}
			out = append(out, roleStr)
		}
		return out, nil
	case []string:
		return v, nil
	default:
		return nil, errors.New("invalid roles type")
	}
}

// hasRole reports whether claims contain role or SUPERUSER.
func hasRole(claims jwt.MapClaims, role string) bool {
	roles, err := extractRoles(claims)
	if err != nil {
		return false
	}
	for _, r := range roles {
		// SUPERUSER has all permissions
		if r == "SUPERUSER" || r == role {
			return true
		}
	}
	return false
}

// getUserID returns the numeric "userId" (or "user_id") claim, or 0 when
// neither is present as a number.
func getUserID(claims jwt.MapClaims) int {
	if id, ok := claims["userId"].(float64); ok {
		return int(id)
	}
	if id, ok := claims["user_id"].(float64); ok {
		return int(id)
	}
	return 0
}

// getUserName returns the "name" claim, else the part of the "email" claim
// before '@', else "Anonymous".
func getUserName(claims jwt.MapClaims) string {
	if name, ok := claims["name"].(string); ok && name != "" {
		return name
	}
	if email, ok := claims["email"].(string); ok && email != "" {
		if local, _, _ := strings.Cut(email, "@"); local != "" {
			return local
		}
	}
	return "Anonymous"
}

// requireUser returns the authenticated caller's claims and user ID. When the
// request context carries no claims, or the claims carry no positive user ID,
// it writes a 401 response and returns ok == false.
func requireUser(w http.ResponseWriter, r *http.Request) (claims jwt.MapClaims, userID int, ok bool) {
	claims, ok = r.Context().Value(userContextKey).(jwt.MapClaims)
	if !ok {
		respondError(w, http.StatusUnauthorized, errAuthHeaderRequired.Error())
		return nil, 0, false
	}
	userID = getUserID(claims)
	if userID <= 0 {
		// Without this check every token lacking a user ID would act as
		// author 0 and pass ownership checks against each other.
		respondError(w, http.StatusUnauthorized, errInvalidClaims.Error())
		return nil, 0, false
	}
	return claims, userID, true
}

// respondJSON writes data as a JSON response with the given status code.
func respondJSON(w http.ResponseWriter, status int, data any) {
	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(status)
	if err := json.NewEncoder(w).Encode(data); err != nil {
		log.Printf("Failed to write JSON response: %v", err)
	}
}

// respondError writes {"error": message} with the given status code.
func respondError(w http.ResponseWriter, status int, message string) {
	respondJSON(w, status, map[string]string{"error": message})
}

// respondDBError logs err and writes a 500 response carrying message.
func respondDBError(w http.ResponseWriter, err error, message string) {
	log.Println("Database error:", err)
	respondError(w, http.StatusInternalServerError, message)
}

// parseTags decodes a one-dimensional PostgreSQL text-array literal such as
// {go,"a,b"} into its elements. Double-quoted elements may contain commas and
// backslash escapes. An empty string or "{}" yields an empty slice.
func parseTags(tagsStr string) []string {
	inner := strings.TrimSuffix(strings.TrimPrefix(strings.TrimSpace(tagsStr), "{"), "}")
	if inner == "" {
		return []string{}
	}

	tags := []string{}
	var cur strings.Builder
	inQuotes := false
	for i := 0; i < len(inner); i++ {
		c := inner[i]
		switch {
		case c == '\\' && inQuotes && i+1 < len(inner):
			// Escaped character inside a quoted element: take it literally.
			i++
			cur.WriteByte(inner[i])
		case c == '"':
			inQuotes = !inQuotes
		case c == ',' && !inQuotes:
			tags = append(tags, cur.String())
			cur.Reset()
		default:
			cur.WriteByte(c)
		}
	}
	return append(tags, cur.String())
}

// tagsToPostgres encodes tags as a PostgreSQL array literal. Every element is
// double-quoted with backslashes and double quotes escaped, so tags may
// contain commas, braces and quotes. No tags yields "{}".
func tagsToPostgres(tags []string) string {
	if len(tags) == 0 {
		return "{}"
	}
	escaper := strings.NewReplacer(`\`, `\\`, `"`, `\"`)
	var b strings.Builder
	b.WriteByte('{')
	for i, tag := range tags {
		if i > 0 {
			b.WriteByte(',')
		}
		b.WriteByte('"')
		b.WriteString(escaper.Replace(tag))
		b.WriteByte('"')
	}
	b.WriteByte('}')
	return b.String()
}

// pathID extracts the numeric ID from path after removing prefix and suffix.
// The ID must be a positive integer that fits a 32-bit INTEGER column.
func pathID(path, prefix, suffix string) (int, error) {
	raw := strings.TrimSuffix(strings.TrimPrefix(path, prefix), suffix)
	n, err := strconv.ParseInt(raw, 10, 32)
	if err != nil {
		return 0, err
	}
	if n <= 0 {
		return 0, fmt.Errorf("id %d is not positive", n)
	}
	return int(n), nil
}

// validateQuestionInput trims the title and tags of req in place and checks
// the payload against the configured limits. The returned error text is
// meant to be sent to the client.
func validateQuestionInput(req *CreateQuestionRequest) error {
	req.Title = strings.TrimSpace(req.Title)
	if req.Title == "" || strings.TrimSpace(req.Content) == "" {
		return errors.New("Title and content are required")
	}
	if utf8.RuneCountInString(req.Title) > maxQuestionTitleLength {
		return fmt.Errorf("Title must be at most %d characters", maxQuestionTitleLength)
	}
	if utf8.RuneCountInString(req.Content) > maxQuestionContentLength {
		return fmt.Errorf("Content must be at most %d characters", maxQuestionContentLength)
	}
	if len(req.Tags) > maxTagsCount {
		return fmt.Errorf("At most %d tags are allowed", maxTagsCount)
	}
	tags := make([]string, 0, len(req.Tags))
	for _, tag := range req.Tags {
		tag = strings.TrimSpace(tag)
		if tag == "" {
			return errors.New("Tags must not be empty")
		}
		if utf8.RuneCountInString(tag) > maxTagLength {
			return fmt.Errorf("Tags must be at most %d characters", maxTagLength)
		}
		tags = append(tags, tag)
	}
	req.Tags = tags // never nil, so it encodes as [] rather than null
	return nil
}

// validateAnswerInput checks an answer payload against the configured limits.
// The returned error text is meant to be sent to the client.
func validateAnswerInput(req *CreateAnswerRequest) error {
	if strings.TrimSpace(req.Content) == "" {
		return errors.New("Content is required")
	}
	if utf8.RuneCountInString(req.Content) > maxAnswerContentLength {
		return fmt.Errorf("Content must be at most %d characters", maxAnswerContentLength)
	}
	return nil
}

// withTx runs fn inside a transaction, committing when fn returns nil and
// rolling back otherwise.
func withTx(ctx context.Context, fn func(tx *sql.Tx) error) error {
	tx, err := db.BeginTx(ctx, nil)
	if err != nil {
		return fmt.Errorf("begin transaction: %w", err)
	}
	if err := fn(tx); err != nil {
		if rbErr := tx.Rollback(); rbErr != nil && !errors.Is(rbErr, sql.ErrTxDone) {
			log.Printf("Rollback failed: %v", rbErr)
		}
		return err
	}
	if err := tx.Commit(); err != nil {
		return fmt.Errorf("commit transaction: %w", err)
	}
	return nil
}

// checkOwnership runs query (which must select one author_id for the row
// with ID $1) and reports whether the caller may modify that row: the caller
// must be the author or an admin. On any failure it writes the response
// itself (404, 500 or 403) and returns false.
func checkOwnership(w http.ResponseWriter, r *http.Request, query string, id, userID int, isAdmin bool, notFoundMsg, forbiddenMsg string) bool {
	var authorID int
	err := db.QueryRowContext(r.Context(), query, id).Scan(&authorID)
	switch {
	case errors.Is(err, sql.ErrNoRows):
		respondError(w, http.StatusNotFound, notFoundMsg)
		return false
	case err != nil:
		// Must not fall through: authorID is still zero here.
		respondDBError(w, err, "Database error")
		return false
	case authorID != userID && !isAdmin:
		respondError(w, http.StatusForbidden, forbiddenMsg)
		return false
	}
	return true
}

// rowScanner is satisfied by both *sql.Row and *sql.Rows.
type rowScanner interface {
	Scan(dest ...any) error
}

// scanQuestion reads one row laid out as questionColumns.
func scanQuestion(row rowScanner) (Question, error) {
	var (
		q          Question
		tags       sql.NullString // the tags column is nullable
		acceptedID sql.NullInt64
	)
	err := row.Scan(&q.ID, &q.Title, &q.Content, &q.AuthorID, &q.AuthorName, &tags,
		&q.Upvotes, &q.Downvotes, &q.AnswerCount, &q.ViewCount, &acceptedID, &q.Status,
		&q.CreatedAt, &q.UpdatedAt)
	if err != nil {
		return Question{}, err
	}
	q.Tags = parseTags(tags.String)
	if acceptedID.Valid {
		id := int(acceptedID.Int64)
		q.AcceptedID = &id
	}
	return q, nil
}

// ============ QUESTION HANDLERS ============

// GET /questions - List questions
//
// Query parameters: tag, status, sort ("newest" by default, "votes",
// "unanswered"), limit (1-100, default 20) and offset (default 0).
func listQuestionsHandler(w http.ResponseWriter, r *http.Request) {
	enableCORS(w)
	if r.Method == http.MethodOptions {
		w.WriteHeader(http.StatusOK)
		return
	}

	params := r.URL.Query()
	limit, offset := 20, 0
	if l, err := strconv.Atoi(params.Get("limit")); err == nil && l > 0 && l <= 100 {
		limit = l
	}
	if o, err := strconv.Atoi(params.Get("offset")); err == nil && o >= 0 {
		offset = o
	}

	var (
		query strings.Builder
		args  []any
	)
	query.WriteString("SELECT " + questionColumns + " FROM forum_questions WHERE 1=1")
	if tag := params.Get("tag"); tag != "" {
		args = append(args, tag)
		fmt.Fprintf(&query, " AND $%d = ANY(tags)", len(args))
	}
	if status := params.Get("status"); status != "" {
		args = append(args, status)
		fmt.Fprintf(&query, " AND status = $%d", len(args))
	}

	// Sort order
	switch params.Get("sort") {
	case "votes":
		query.WriteString(" ORDER BY (upvotes - downvotes) DESC, created_at DESC")
	case "unanswered":
		query.WriteString(" AND answer_count = 0 ORDER BY created_at DESC")
	default: // newest
		query.WriteString(" ORDER BY created_at DESC")
	}
	args = append(args, limit, offset)
	fmt.Fprintf(&query, " LIMIT $%d OFFSET $%d", len(args)-1, len(args))

	rows, err := db.QueryContext(r.Context(), query.String(), args...)
	if err != nil {
		respondDBError(w, err, "Failed to fetch questions")
		return
	}
	defer rows.Close()

	questions := []Question{}
	for rows.Next() {
		q, err := scanQuestion(rows)
		if err != nil {
			respondDBError(w, err, "Failed to fetch questions")
			return
		}
		questions = append(questions, q)
	}
	if err := rows.Err(); err != nil {
		respondDBError(w, err, "Failed to fetch questions")
		return
	}
	respondJSON(w, http.StatusOK, questions)
}

// GET /questions/:id - Get question with answers
//
// The view counter is incremented by the same statement that reads the
// question, so the returned count includes this view.
func getQuestionHandler(w http.ResponseWriter, r *http.Request) {
	enableCORS(w)
	if r.Method == http.MethodOptions {
		w.WriteHeader(http.StatusOK)
		return
	}

	id, err := pathID(r.URL.Path, "/questions/", "")
	if err != nil {
		respondError(w, http.StatusBadRequest, "Invalid question ID")
		return
	}
	ctx := r.Context()

	q, err := scanQuestion(db.QueryRowContext(ctx,
		"UPDATE forum_questions SET view_count = view_count + 1 WHERE id = $1 RETURNING "+questionColumns, id))
	if errors.Is(err, sql.ErrNoRows) {
		respondError(w, http.StatusNotFound, "Question not found")
		return
	}
	if err != nil {
		respondDBError(w, err, "Database error")
		return
	}

	// Get answers
	rows, err := db.QueryContext(ctx, `SELECT id, question_id, content, author_id, author_name,
		upvotes, downvotes, is_accepted, is_verified, verified_by, verified_at, created_at, updated_at
		FROM forum_answers WHERE question_id = $1
		ORDER BY is_accepted DESC, (upvotes - downvotes) DESC, created_at ASC`, id)
	if err != nil {
		respondDBError(w, err, "Failed to fetch answers")
		return
	}
	defer rows.Close()

	answers := []Answer{}
	for rows.Next() {
		var (
			a          Answer
			verifiedBy sql.NullInt64
			verifiedAt sql.NullTime
		)
		err := rows.Scan(&a.ID, &a.QuestionID, &a.Content, &a.AuthorID, &a.AuthorName,
			&a.Upvotes, &a.Downvotes, &a.IsAccepted, &a.IsVerified, &verifiedBy, &verifiedAt,
			&a.CreatedAt, &a.UpdatedAt)
		if err != nil {
			respondDBError(w, err, "Failed to fetch answers")
			return
		}
		if verifiedBy.Valid {
			vb := int(verifiedBy.Int64)
			a.VerifiedBy = &vb
		}
		if verifiedAt.Valid {
			at := verifiedAt.Time
			a.VerifiedAt = &at
		}
		answers = append(answers, a)
	}
	if err := rows.Err(); err != nil {
		respondDBError(w, err, "Failed to fetch answers")
		return
	}

	respondJSON(w, http.StatusOK, map[string]any{
		"question": q,
		"answers":  answers,
	})
}

// POST /questions - Create question
func createQuestionHandler(w http.ResponseWriter, r *http.Request) {
	enableCORS(w)
	claims, userID, ok := requireUser(w, r)
	if !ok {
		return
	}
	userName := getUserName(claims)

	var req CreateQuestionRequest
	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
		respondError(w, http.StatusBadRequest, "Invalid request body")
		return
	}
	if err := validateQuestionInput(&req); err != nil {
		respondError(w, http.StatusBadRequest, err.Error())
		return
	}

	var q Question
	err := db.QueryRowContext(r.Context(), `INSERT INTO forum_questions (title, content, author_id, author_name, tags)
		VALUES ($1, $2, $3, $4, $5) RETURNING id, created_at, updated_at`,
		req.Title, req.Content, userID, userName, tagsToPostgres(req.Tags)).
		Scan(&q.ID, &q.CreatedAt, &q.UpdatedAt)
	if err != nil {
		respondDBError(w, err, "Failed to create question")
		return
	}

	q.Title = req.Title
	q.Content = req.Content
	q.AuthorID = userID
	q.AuthorName = userName
	q.Tags = req.Tags
	q.Status = "OPEN"

	log.Printf("AUDIT: User %d created question %d", userID, q.ID)
	respondJSON(w, http.StatusCreated, q)
}

// PUT /questions/:id - Update question (author or ADMIN)
func updateQuestionHandler(w http.ResponseWriter, r *http.Request) {
	enableCORS(w)
	claims, userID, ok := requireUser(w, r)
	if !ok {
		return
	}

	id, err := pathID(r.URL.Path, "/questions/", "")
	if err != nil {
		respondError(w, http.StatusBadRequest, "Invalid question ID")
		return
	}

	// Check ownership
	if !checkOwnership(w, r, "SELECT author_id FROM forum_questions WHERE id = $1", id, userID,
		hasRole(claims, "ADMIN"), "Question not found", "You can only edit your own questions") {
		return
	}

	var req CreateQuestionRequest
	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
		respondError(w, http.StatusBadRequest, "Invalid request body")
		return
	}
	if err := validateQuestionInput(&req); err != nil {
		respondError(w, http.StatusBadRequest, err.Error())
		return
	}

	_, err = db.ExecContext(r.Context(), `UPDATE forum_questions SET title = $1, content = $2, tags = $3,
		updated_at = CURRENT_TIMESTAMP WHERE id = $4`,
		req.Title, req.Content, tagsToPostgres(req.Tags), id)
	if err != nil {
		respondDBError(w, err, "Failed to update question")
		return
	}

	log.Printf("AUDIT: User %d updated question %d", userID, id)
	respondJSON(w, http.StatusOK, map[string]string{"message": "Question updated"})
}

// DELETE /questions/:id - Delete question (author or ADMIN)
func deleteQuestionHandler(w http.ResponseWriter, r *http.Request) {
	enableCORS(w)
	claims, userID, ok := requireUser(w, r)
	if !ok {
		return
	}

	id, err := pathID(r.URL.Path, "/questions/", "")
	if err != nil {
		respondError(w, http.StatusBadRequest, "Invalid question ID")
		return
	}

	// Check ownership
	if !checkOwnership(w, r, "SELECT author_id FROM forum_questions WHERE id = $1", id, userID,
		hasRole(claims, "ADMIN"), "Question not found", "You can only delete your own questions") {
		return
	}

	if _, err := db.ExecContext(r.Context(), "DELETE FROM forum_questions WHERE id = $1", id); err != nil {
		respondDBError(w, err, "Failed to delete question")
		return
	}

	log.Printf("AUDIT: User %d deleted question %d", userID, id)
	respondJSON(w, http.StatusOK, map[string]string{"message": "Question deleted"})
}

// ============ ANSWER HANDLERS ============

// POST /questions/:id/answers - Create answer
//
// The status check, the insert and the answer-count update run in one
// transaction with the question row locked, so a question cannot be closed
// between the check and the insert and the count cannot drift.
func createAnswerHandler(w http.ResponseWriter, r *http.Request) {
	enableCORS(w)
	claims, userID, ok := requireUser(w, r)
	if !ok {
		return
	}
	userName := getUserName(claims)

	// Extract question ID from path like /questions/123/answers
	questionID, err := pathID(r.URL.Path, "/questions/", "/answers")
	if err != nil {
		respondError(w, http.StatusBadRequest, "Invalid question ID")
		return
	}

	var req CreateAnswerRequest
	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
		respondError(w, http.StatusBadRequest, "Invalid request body")
		return
	}
	if err := validateAnswerInput(&req); err != nil {
		respondError(w, http.StatusBadRequest, err.Error())
		return
	}

	ctx := r.Context()
	var a Answer
	err = withTx(ctx, func(tx *sql.Tx) error {
		// Check question exists
		var status string
		err := tx.QueryRowContext(ctx,
			"SELECT status FROM forum_questions WHERE id = $1 FOR UPDATE", questionID).Scan(&status)
		if errors.Is(err, sql.ErrNoRows) {
			return errNotFound
		}
		if err != nil {
			return err
		}
		if status == "CLOSED" {
			return errQuestionClosed
		}

		err = tx.QueryRowContext(ctx, `INSERT INTO forum_answers (question_id, content, author_id, author_name)
			VALUES ($1, $2, $3, $4) RETURNING id, created_at, updated_at`,
			questionID, req.Content, userID, userName).
			Scan(&a.ID, &a.CreatedAt, &a.UpdatedAt)
		if err != nil {
			return err
		}

		// Update answer count
		_, err = tx.ExecContext(ctx,
			"UPDATE forum_questions SET answer_count = answer_count + 1, status = 'ANSWERED' WHERE id = $1",
			questionID)
		return err
	})
	switch {
	case errors.Is(err, errNotFound):
		respondError(w, http.StatusNotFound, "Question not found")
		return
	case errors.Is(err, errQuestionClosed):
		respondError(w, http.StatusBadRequest, "This question is closed")
		return
	case err != nil:
		respondDBError(w, err, "Failed to create answer")
		return
	}

	a.QuestionID = questionID
	a.Content = req.Content
	a.AuthorID = userID
	a.AuthorName = userName

	log.Printf("AUDIT: User %d answered question %d", userID, questionID)
	respondJSON(w, http.StatusCreated, a)
}

// PUT /answers/:id - Update answer (author or ADMIN)
func updateAnswerHandler(w http.ResponseWriter, r *http.Request) {
	enableCORS(w)
	claims, userID, ok := requireUser(w, r)
	if !ok {
		return
	}

	id, err := pathID(r.URL.Path, "/answers/", "")
	if err != nil {
		respondError(w, http.StatusBadRequest, "Invalid answer ID")
		return
	}

	// Check ownership
	if !checkOwnership(w, r, "SELECT author_id FROM forum_answers WHERE id = $1", id, userID,
		hasRole(claims, "ADMIN"), "Answer not found", "You can only edit your own answers") {
		return
	}

	var req CreateAnswerRequest
	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
		respondError(w, http.StatusBadRequest, "Invalid request body")
		return
	}
	if err := validateAnswerInput(&req); err != nil {
		respondError(w, http.StatusBadRequest, err.Error())
		return
	}

	_, err = db.ExecContext(r.Context(),
		`UPDATE forum_answers SET content = $1, updated_at = CURRENT_TIMESTAMP WHERE id = $2`,
		req.Content, id)
	if err != nil {
		respondDBError(w, err, "Failed to update answer")
		return
	}

	log.Printf("AUDIT: User %d updated answer %d", userID, id)
	respondJSON(w, http.StatusOK, map[string]string{"message": "Answer updated"})
}

// DELETE /answers/:id - Delete answer (author or ADMIN)
//
// The delete and the question bookkeeping run in one transaction. The
// question's answer count never drops below zero, a deleted accepted answer
// is unlinked, and an ANSWERED question whose last answer is removed goes
// back to OPEN.
func deleteAnswerHandler(w http.ResponseWriter, r *http.Request) {
	enableCORS(w)
	claims, userID, ok := requireUser(w, r)
	if !ok {
		return
	}
	isAdmin := hasRole(claims, "ADMIN")

	id, err := pathID(r.URL.Path, "/answers/", "")
	if err != nil {
		respondError(w, http.StatusBadRequest, "Invalid answer ID")
		return
	}
	ctx := r.Context()

	// Check ownership and get question ID
	var authorID, questionID int
	err = db.QueryRowContext(ctx,
		"SELECT author_id, question_id FROM forum_answers WHERE id = $1", id).Scan(&authorID, &questionID)
	if errors.Is(err, sql.ErrNoRows) {
		respondError(w, http.StatusNotFound, "Answer not found")
		return
	}
	if err != nil {
		respondDBError(w, err, "Database error")
		return
	}
	if authorID != userID && !isAdmin {
		respondError(w, http.StatusForbidden, "You can only delete your own answers")
		return
	}

	err = withTx(ctx, func(tx *sql.Tx) error {
		res, err := tx.ExecContext(ctx, "DELETE FROM forum_answers WHERE id = $1", id)
		if err != nil {
			return err
		}
		n, err := res.RowsAffected()
		if err != nil {
			return err
		}
		if n == 0 {
			// Deleted concurrently; do not decrement the count twice.
			return errNotFound
		}

		// Update answer count. SET expressions see the pre-update row, so
		// "answer_count <= 1" means this was the last answer.
		_, err = tx.ExecContext(ctx, `UPDATE forum_questions SET
			answer_count = GREATEST(answer_count - 1, 0),
			accepted_answer_id = CASE WHEN accepted_answer_id = $2 THEN NULL ELSE accepted_answer_id END,
			status = CASE WHEN status = 'ANSWERED' AND answer_count <= 1 THEN 'OPEN' ELSE status END
			WHERE id = $1`, questionID, id)
		return err
	})
	switch {
	case errors.Is(err, errNotFound):
		respondError(w, http.StatusNotFound, "Answer not found")
		return
	case err != nil:
		respondDBError(w, err, "Failed to delete answer")
		return
	}

	log.Printf("AUDIT: User %d deleted answer %d", userID, id)
	respondJSON(w, http.StatusOK, map[string]string{"message": "Answer deleted"})
}

// POST /answers/:id/accept - Accept answer (question author only)
//
// All three updates run in one transaction. A CLOSED question stays CLOSED.
func acceptAnswerHandler(w http.ResponseWriter, r *http.Request) {
	enableCORS(w)
	_, userID, ok := requireUser(w, r)
	if !ok {
		return
	}

	id, err := pathID(r.URL.Path, "/answers/", "/accept")
	if err != nil {
		respondError(w, http.StatusBadRequest, "Invalid answer ID")
		return
	}
	ctx := r.Context()

	// Get the answer's question and that question's author
	var questionID, questionAuthorID int
	err = db.QueryRowContext(ctx, `SELECT a.question_id, q.author_id
		FROM forum_answers a JOIN forum_questions q ON q.id = a.question_id
		WHERE a.id = $1`, id).Scan(&questionID, &questionAuthorID)
	if errors.Is(err, sql.ErrNoRows) {
		respondError(w, http.StatusNotFound, "Answer not found")
		return
	}
	if err != nil {
		respondDBError(w, err, "Database error")
		return
	}

	// Check if user owns the question
	if questionAuthorID != userID {
		respondError(w, http.StatusForbidden, "Only the question author can accept answers")
		return
	}

	err = withTx(ctx, func(tx *sql.Tx) error {
		// Accept this answer
		res, err := tx.ExecContext(ctx,
			"UPDATE forum_answers SET is_accepted = TRUE WHERE id = $1 AND question_id = $2", id, questionID)
		if err != nil {
			return err
		}
		n, err := res.RowsAffected()
		if err != nil {
			return err
		}
		if n == 0 {
			return errNotFound // answer was deleted after the lookup above
		}

		// Unaccept any previously accepted answer
		if _, err := tx.ExecContext(ctx,
			"UPDATE forum_answers SET is_accepted = FALSE WHERE question_id = $1 AND id <> $2",
			questionID, id); err != nil {
			return err
		}

		_, err = tx.ExecContext(ctx, `UPDATE forum_questions SET accepted_answer_id = $1,
			status = CASE WHEN status = 'CLOSED' THEN status ELSE 'ANSWERED' END
			WHERE id = $2`, id, questionID)
		return err
	})
	switch {
	case errors.Is(err, errNotFound):
		respondError(w, http.StatusNotFound, "Answer not found")
		return
	case err != nil:
		respondDBError(w, err, "Failed to accept answer")
		return
	}

	log.Printf("AUDIT: User %d accepted answer %d for question %d", userID, id, questionID)
	respondJSON(w, http.StatusOK, map[string]string{"message": "Answer accepted"})
}

// POST /answers/:id/verify - Verify answer (ADMIN only)
func verifyAnswerHandler(w http.ResponseWriter, r *http.Request) {
	enableCORS(w)
	_, userID, ok := requireUser(w, r)
	if !ok {
		return
	}

	id, err := pathID(r.URL.Path, "/answers/", "/verify")
	if err != nil {
		respondError(w, http.StatusBadRequest, "Invalid answer ID")
		return
	}

	res, err := db.ExecContext(r.Context(), `UPDATE forum_answers SET is_verified = TRUE, verified_by = $1,
		verified_at = CURRENT_TIMESTAMP WHERE id = $2`, userID, id)
	if err != nil {
		respondDBError(w, err, "Failed to verify answer")
		return
	}
	if n, err := res.RowsAffected(); err == nil && n == 0 {
		respondError(w, http.StatusNotFound, "Answer not found")
		return
	}

	log.Printf("AUDIT: Admin %d verified answer %d", userID, id)
	respondJSON(w, http.StatusOK, map[string]string{"message": "Answer verified"})
}

// ============ VOTING HANDLERS ============

// recordVote stores userID's vote on the target and adjusts the target's
// upvote/downvote counters in the same transaction. targetType must be
// "question" or "answer". A repeated identical vote changes nothing. It
// returns errNotFound when the target does not exist.
func recordVote(ctx context.Context, targetType string, targetID, userID, voteType int) error {
	// The table name is chosen from this fixed set, never from request data.
	var table string
	switch targetType {
	case "question":
		table = "forum_questions"
	case "answer":
		table = "forum_answers"
	default:
		return fmt.Errorf("unknown vote target type %q", targetType)
	}

	return withTx(ctx, func(tx *sql.Tx) error {
		// Lock the target row so concurrent votes on it are serialized; this
		// also keeps a double-submitted first vote from inserting twice.
		var one int
		err := tx.QueryRowContext(ctx, "SELECT 1 FROM "+table+" WHERE id = $1 FOR UPDATE", targetID).Scan(&one)
		if errors.Is(err, sql.ErrNoRows) {
			return errNotFound
		}
		if err != nil {
			return err
		}

		// Check if user already voted
		var existingVote int
		err = tx.QueryRowContext(ctx,
			"SELECT vote_type FROM forum_votes WHERE user_id = $1 AND target_type = $2 AND target_id = $3",
			userID, targetType, targetID).Scan(&existingVote)

		var upDelta, downDelta int
		switch {
		case errors.Is(err, sql.ErrNoRows):
			// New vote
			if _, err := tx.ExecContext(ctx,
				"INSERT INTO forum_votes (user_id, target_type, target_id, vote_type) VALUES ($1, $2, $3, $4)",
				userID, targetType, targetID, voteType); err != nil {
				return err
			}
			if voteType == 1 {
				upDelta = 1
			} else {
				downDelta = 1
			}
		case err != nil:
			return err
		case existingVote == voteType:
			// If same vote, do nothing
			return nil
		default:
			// Changing vote
			if _, err := tx.ExecContext(ctx,
				"UPDATE forum_votes SET vote_type = $1 WHERE user_id = $2 AND target_type = $3 AND target_id = $4",
				voteType, userID, targetType, targetID); err != nil {
				return err
			}
			if voteType == 1 {
				upDelta, downDelta = 1, -1
			} else {
				upDelta, downDelta = -1, 1
			}
		}

		_, err = tx.ExecContext(ctx,
			"UPDATE "+table+" SET upvotes = upvotes + $1, downvotes = downvotes + $2 WHERE id = $3",
			upDelta, downDelta, targetID)
		return err
	})
}

// handleVote is the shared body of the two vote endpoints: it parses the
// target ID from the path, validates the vote and records it.
func handleVote(w http.ResponseWriter, r *http.Request, targetType, pathPrefix, invalidIDMsg, invalidVoteMsg, notFoundMsg string) {
	enableCORS(w)
	_, userID, ok := requireUser(w, r)
	if !ok {
		return
	}

	id, err := pathID(r.URL.Path, pathPrefix, "/vote")
	if err != nil {
		respondError(w, http.StatusBadRequest, invalidIDMsg)
		return
	}

	var req VoteRequest
	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
		respondError(w, http.StatusBadRequest, "Invalid request body")
		return
	}
	if req.VoteType != 1 && req.VoteType != -1 {
		respondError(w, http.StatusBadRequest, invalidVoteMsg)
		return
	}

	err = recordVote(r.Context(), targetType, id, userID, req.VoteType)
	switch {
	case errors.Is(err, errNotFound):
		respondError(w, http.StatusNotFound, notFoundMsg)
	case err != nil:
		respondDBError(w, err, "Failed to record vote")
	default:
		respondJSON(w, http.StatusOK, map[string]string{"message": "Vote recorded"})
	}
}

// POST /questions/:id/vote - Vote on question
func voteQuestionHandler(w http.ResponseWriter, r *http.Request) {
	handleVote(w, r, "question", "/questions/", "Invalid question ID",
		"VoteType must be 1 (upvote) or -1 (downvote)", "Question not found")
}

// POST /answers/:id/vote - Vote on answer
func voteAnswerHandler(w http.ResponseWriter, r *http.Request) {
	handleVote(w, r, "answer", "/answers/", "Invalid answer ID",
		"VoteType must be 1 or -1", "Answer not found")
}

// POST /questions/:id/close - Close question (ADMIN only)
func closeQuestionHandler(w http.ResponseWriter, r *http.Request) {
	enableCORS(w)
	_, userID, ok := requireUser(w, r)
	if !ok {
		return
	}

	id, err := pathID(r.URL.Path, "/questions/", "/close")
	if err != nil {
		respondError(w, http.StatusBadRequest, "Invalid question ID")
		return
	}

	res, err := db.ExecContext(r.Context(),
		"UPDATE forum_questions SET status = 'CLOSED', updated_at = CURRENT_TIMESTAMP WHERE id = $1", id)
	if err != nil {
		respondDBError(w, err, "Failed to close question")
		return
	}
	if n, err := res.RowsAffected(); err == nil && n == 0 {
		respondError(w, http.StatusNotFound, "Question not found")
		return
	}

	log.Printf("AUDIT: Admin %d closed question %d", userID, id)
	respondJSON(w, http.StatusOK, map[string]string{"message": "Question closed"})
}

// ============ ROUTING ============

// postOnly calls h for POST requests and answers 405 for anything else.
func postOnly(w http.ResponseWriter, r *http.Request, h http.HandlerFunc) {
	if r.Method != http.MethodPost {
		respondError(w, http.StatusMethodNotAllowed, "Method not allowed")
		return
	}
	h(w, r)
}

// questionsRoute serves /questions: GET lists, POST creates (authenticated).
func questionsRoute(w http.ResponseWriter, r *http.Request) {
	enableCORS(w)
	switch r.Method {
	case http.MethodOptions:
		w.WriteHeader(http.StatusOK)
	case http.MethodGet:
		listQuestionsHandler(w, r)
	case http.MethodPost:
		authMiddleware(createQuestionHandler)(w, r)
	default:
		respondError(w, http.StatusMethodNotAllowed, "Method not allowed")
	}
}

// questionRoute serves /questions/{id} and its /answers, /vote and /close
// sub-resources.
func questionRoute(w http.ResponseWriter, r *http.Request) {
	enableCORS(w)
	if r.Method == http.MethodOptions {
		w.WriteHeader(http.StatusOK)
		return
	}

	path := r.URL.Path
	switch {
	case strings.HasSuffix(path, "/answers"):
		postOnly(w, r, authMiddleware(createAnswerHandler))
	case strings.HasSuffix(path, "/vote"):
		postOnly(w, r, authMiddleware(voteQuestionHandler))
	case strings.HasSuffix(path, "/close"):
		postOnly(w, r, requireRole(closeQuestionHandler, "ADMIN"))
	default:
		switch r.Method {
		case http.MethodGet:
			getQuestionHandler(w, r)
		case http.MethodPut:
			authMiddleware(updateQuestionHandler)(w, r)
		case http.MethodDelete:
			authMiddleware(deleteQuestionHandler)(w, r)
		default:
			respondError(w, http.StatusMethodNotAllowed, "Method not allowed")
		}
	}
}

// answerRoute serves /answers/{id} and its /accept, /verify and /vote
// sub-resources.
func answerRoute(w http.ResponseWriter, r *http.Request) {
	enableCORS(w)
	if r.Method == http.MethodOptions {
		w.WriteHeader(http.StatusOK)
		return
	}

	path := r.URL.Path
	switch {
	case strings.HasSuffix(path, "/accept"):
		postOnly(w, r, authMiddleware(acceptAnswerHandler))
	case strings.HasSuffix(path, "/verify"):
		postOnly(w, r, requireRole(verifyAnswerHandler, "ADMIN"))
	case strings.HasSuffix(path, "/vote"):
		postOnly(w, r, authMiddleware(voteAnswerHandler))
	default:
		switch r.Method {
		case http.MethodPut:
			authMiddleware(updateAnswerHandler)(w, r)
		case http.MethodDelete:
			authMiddleware(deleteAnswerHandler)(w, r)
		default:
			respondError(w, http.StatusMethodNotAllowed, "Method not allowed")
		}
	}
}

// withLimits caps the request body at maxRequestBody and applies the per-IP
// read limit to GET/HEAD and the write limit to every other method except
// OPTIONS before passing the request to next.
func withLimits(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// Limit request body size to prevent DoS
		if r.Body != nil {
			r.Body = http.MaxBytesReader(w, r.Body, maxRequestBody)
		}

		clientIP := getClientIP(r)
		switch r.Method {
		case http.MethodGet, http.MethodHead:
			if readLimiter.checkRateLimit(clientIP, maxReadRequests, rateLimitWindow) {
				log.Printf("SECURITY: Read rate limit exceeded for IP %s on %s", clientIP, r.URL.Path)
				http.Error(w, "Too many requests. Please slow down.", http.StatusTooManyRequests)
				return
			}
		case http.MethodOptions:
			// Preflight requests are not rate limited.
		default:
			if writeLimiter.checkRateLimit(clientIP, maxWriteRequests, rateLimitWindow) {
				log.Printf("SECURITY: Write rate limit exceeded for IP %s on %s", clientIP, r.URL.Path)
				http.Error(w, "Too many requests. Please slow down.", http.StatusTooManyRequests)
				return
			}
		}
		next.ServeHTTP(w, r)
	})
}

func main() {
	loadConfig()
	initDB()
	defer db.Close()

	mux := http.NewServeMux()
	mux.HandleFunc("/questions", questionsRoute)
	mux.HandleFunc("/questions/", questionRoute)
	mux.HandleFunc("/answers/", answerRoute)

	// Health check (both /health and /healthz for compatibility)
	healthHandler := func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
		fmt.Fprintln(w, "ok")
	}
	mux.HandleFunc("/health", healthHandler)
	mux.HandleFunc("/healthz", healthHandler)

	port := os.Getenv("PORT")
	if port == "" {
		port = "8080"
	}

	// Wrap all routes with rate limiting and body size limit
	server := &http.Server{
		Addr:              ":" + port,
		Handler:           withLimits(mux),
		ReadHeaderTimeout: 10 * time.Second,
		ReadTimeout:       15 * time.Second,
		WriteTimeout:      15 * time.Second,
		IdleTimeout:       60 * time.Second,
	}

	// Graceful shutdown
	done := make(chan struct{})
	quit := make(chan os.Signal, 1)
	signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM)

	go func() {
		<-quit
		log.Println("Forum Service shutting down...")

		ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
		defer cancel()

		server.SetKeepAlivesEnabled(false)
		if err := server.Shutdown(ctx); err != nil {
			log.Printf("Could not gracefully shutdown: %v", err)
		}
		close(done)
	}()

	log.Printf("Forum service starting on port %s\n", port)
	if err := server.ListenAndServe(); err != nil && err != http.ErrServerClosed {
		log.Fatalf("Server error: %v", err)
	}

	// ListenAndServe returns as soon as Shutdown starts; wait for it to finish.
	<-done
	log.Println("Forum Service stopped")
}