Files
web-hosts/domains/coppertone.tech/backend/functions/forum-service/main.go
2025-12-26 13:38:04 +01:00

1310 lines
38 KiB
Go

package main
import (
"context"
"database/sql"
"encoding/json"
"errors"
"fmt"
"log"
"net"
"net/http"
"os"
"os/signal"
"strconv"
"strings"
"sync"
"syscall"
"time"
"github.com/golang-jwt/jwt/v5"
_ "github.com/lib/pq"
)
// Request throttling and body-size limits.
const (
	// rateLimitWindow is the length of each per-key counting window.
	rateLimitWindow = time.Minute
	// maxWriteRequests caps write requests per IP within one window.
	maxWriteRequests = 30
	// maxReadRequests caps read requests per IP within one window.
	maxReadRequests = 100
	// maxRequestBody is the largest accepted request body in bytes (1 MiB).
	maxRequestBody = 1024 * 1024
)

// Input validation limits for forum content.
const (
	maxQuestionTitleLength   = 200
	maxQuestionContentLength = 50000 // roughly 50KB of question text
	maxAnswerContentLength   = 50000 // roughly 50KB of answer text
	maxTagLength             = 50
	maxTagsCount             = 10
)
type (
	// rateLimiter counts requests per caller-chosen key (e.g. a client IP).
	// mu guards requests.
	rateLimiter struct {
		mu       sync.RWMutex
		requests map[string]*requestInfo
	}

	// requestInfo is the request count for one key and the start of that
	// key's current window.
	requestInfo struct {
		count    int
		firstReq time.Time
	}
)

// Independent limiters for write and read traffic.
var (
	writeLimiter = &rateLimiter{requests: map[string]*requestInfo{}}
	readLimiter  = &rateLimiter{requests: map[string]*requestInfo{}}
)
// checkRateLimit counts one request for key and reports whether that request
// pushes the key over maxRequests in its current window (true = reject).
//
// Each key has a fixed window that opens on its first request. Once more than
// window has elapsed since then, the window restarts with this request as its
// first. The first request of any window is always allowed.
//
// NOTE(review): entries are never removed here, so the map grows with every
// distinct key; confirm that something else prunes it.
func (rl *rateLimiter) checkRateLimit(key string, maxRequests int, window time.Duration) bool {
	rl.mu.Lock()
	defer rl.mu.Unlock()

	now := time.Now()
	entry, seen := rl.requests[key]
	switch {
	case !seen:
		rl.requests[key] = &requestInfo{count: 1, firstReq: now}
	case now.Sub(entry.firstReq) > window:
		entry.count, entry.firstReq = 1, now
	default:
		entry.count++
		return entry.count > maxRequests
	}
	return false
}
// getClientIP picks the caller address used to key rate limits. Precedence:
// the first entry of X-Forwarded-For, then X-Real-IP (both whitespace-trimmed),
// then the host part of RemoteAddr, or RemoteAddr verbatim if it has no port.
//
// NOTE(review): both headers are supplied by the client unless a trusted proxy
// overwrites them, so the result is spoofable when the service is reachable
// directly.
func getClientIP(r *http.Request) string {
	if forwarded := r.Header.Get("X-Forwarded-For"); forwarded != "" {
		first := forwarded
		if comma := strings.IndexByte(forwarded, ','); comma >= 0 {
			first = forwarded[:comma]
		}
		return strings.TrimSpace(first)
	}
	if realIP := r.Header.Get("X-Real-IP"); realIP != "" {
		return strings.TrimSpace(realIP)
	}
	host, _, splitErr := net.SplitHostPort(r.RemoteAddr)
	if splitErr != nil {
		return r.RemoteAddr
	}
	return host
}
// contextKey is an unexported type for request-context keys, so values stored
// by this package cannot collide with keys defined by other packages.
type contextKey string

// userContextKey is the context key under which the auth middlewares store the
// caller's JWT claims.
var userContextKey contextKey = "userClaims"
// Question represents a forum question. Handlers fill it from a row of
// forum_questions and serialize it with the JSON names in the struct tags.
type Question struct {
	ID          int       `json:"id"`
	Title       string    `json:"title"`
	Content     string    `json:"content"`
	AuthorID    int       `json:"authorId"`   // user id read from the creator's JWT claims
	AuthorName  string    `json:"authorName"` // display name captured when the question was created
	Tags        []string  `json:"tags"`
	Upvotes     int       `json:"upvotes"`     // counter kept in step with forum_votes by the vote handler
	Downvotes   int       `json:"downvotes"`   // counter kept in step with forum_votes by the vote handler
	AnswerCount int       `json:"answerCount"` // counter maintained by the answer create/delete handlers
	ViewCount   int       `json:"viewCount"`   // incremented on each GET /questions/:id
	AcceptedID  *int      `json:"acceptedAnswerId,omitempty"` // nil (omitted) until an answer is accepted
	Status      string    `json:"status"`                     // OPEN, ANSWERED, CLOSED
	CreatedAt   time.Time `json:"createdAt"`
	UpdatedAt   time.Time `json:"updatedAt"`
}
// Answer represents an answer to a question. Handlers fill it from a row of
// forum_answers and serialize it with the JSON names in the struct tags.
type Answer struct {
	ID         int        `json:"id"`
	QuestionID int        `json:"questionId"`
	Content    string     `json:"content"`
	AuthorID   int        `json:"authorId"`   // user id read from the author's JWT claims
	AuthorName string     `json:"authorName"` // display name captured when the answer was created
	Upvotes    int        `json:"upvotes"`    // counter kept in step with forum_votes by the vote handler
	Downvotes  int        `json:"downvotes"`  // counter kept in step with forum_votes by the vote handler
	IsAccepted bool       `json:"isAccepted"` // set when the question's author accepts this answer
	IsVerified bool       `json:"isVerified"` // Admin-verified correct answer
	VerifiedBy *int       `json:"verifiedBy,omitempty"` // id of the verifying admin; nil (omitted) if unverified
	VerifiedAt *time.Time `json:"verifiedAt,omitempty"` // verification time; nil (omitted) if unverified
	CreatedAt  time.Time  `json:"createdAt"`
	UpdatedAt  time.Time  `json:"updatedAt"`
}
// Vote represents a user's vote on a question or answer and mirrors a row of
// forum_votes, which allows one vote per (user, target type, target id).
//
// NOTE(review): the visible handlers query forum_votes directly and never
// build this type; confirm whether it is used elsewhere or is vestigial.
type Vote struct {
	ID         int       `json:"id"`
	UserID     int       `json:"userId"`
	TargetType string    `json:"targetType"` // "question" or "answer"
	TargetID   int       `json:"targetId"`
	VoteType   int       `json:"voteType"` // 1 = upvote, -1 = downvote
	CreatedAt  time.Time `json:"createdAt"`
}
// JSON request payloads accepted by the forum handlers.
type (
	// CreateQuestionRequest is the body of POST /questions and PUT /questions/:id.
	CreateQuestionRequest struct {
		Title   string   `json:"title"`
		Content string   `json:"content"`
		Tags    []string `json:"tags"`
	}

	// CreateAnswerRequest is the body of POST /questions/:id/answers and PUT /answers/:id.
	CreateAnswerRequest struct {
		Content string `json:"content"`
	}

	// VoteRequest is the body of the vote endpoints; handlers accept only 1 or -1.
	VoteRequest struct {
		VoteType int `json:"voteType"`
	}
)
// Process-wide dependencies populated once at startup.
var (
	// db is the shared PostgreSQL connection pool opened by initDB.
	db *sql.DB
	// jwtSecret is the HMAC key used to verify bearer tokens; set by loadConfig.
	jwtSecret []byte
)
// initDB opens the shared PostgreSQL pool in db, verifies connectivity, and
// creates the forum tables and indexes if they do not exist. Any failure is
// fatal to the process.
//
// Environment:
//   - DB_HOST, DB_USER, DB_PASSWORD, DB_NAME: required.
//   - DB_SSL_MODE: optional, defaults to "require".
//   - DB_SCHEMA: optional; one of "", "public", "dev", "testing", "prod".
//     A non-public schema is placed ahead of public on the search_path.
func initDB() {
	var err error
	dbHost := strings.TrimSpace(os.Getenv("DB_HOST"))
	dbUser := strings.TrimSpace(os.Getenv("DB_USER"))
	dbPassword := strings.TrimSpace(os.Getenv("DB_PASSWORD"))
	dbName := strings.TrimSpace(os.Getenv("DB_NAME"))
	dbSSLMode := strings.TrimSpace(os.Getenv("DB_SSL_MODE"))
	dbSchema := strings.TrimSpace(os.Getenv("DB_SCHEMA"))
	if dbHost == "" || dbUser == "" || dbPassword == "" || dbName == "" {
		log.Fatal("Database configuration missing: DB_HOST, DB_USER, DB_PASSWORD, DB_NAME required")
	}
	if dbSSLMode == "" {
		dbSSLMode = "require"
	}
	// Validate schema value if provided
	validSchemas := map[string]bool{"": true, "public": true, "dev": true, "testing": true, "prod": true}
	if !validSchemas[dbSchema] {
		log.Fatalf("Invalid DB_SCHEMA '%s'. Must be: dev, testing, prod, or empty for public", dbSchema)
	}

	// quoteConnValue single-quotes a key=value DSN value, escaping backslashes
	// and single quotes. Without quoting, a value containing a space or quote
	// (typically a password) would break the DSN or inject extra parameters.
	quoteConnValue := func(v string) string {
		v = strings.ReplaceAll(v, `\`, `\\`)
		v = strings.ReplaceAll(v, `'`, `\'`)
		return "'" + v + "'"
	}
	connStr := fmt.Sprintf("host=%s user=%s password=%s dbname=%s sslmode=%s",
		quoteConnValue(dbHost), quoteConnValue(dbUser), quoteConnValue(dbPassword),
		quoteConnValue(dbName), quoteConnValue(dbSSLMode))
	if dbSchema != "" && dbSchema != "public" {
		connStr += " search_path=" + quoteConnValue(dbSchema+",public")
	}

	db, err = sql.Open("postgres", connStr)
	if err != nil {
		log.Fatal(err)
	}
	// Configure connection pool limits
	db.SetMaxOpenConns(25)
	db.SetMaxIdleConns(5)
	db.SetConnMaxLifetime(5 * time.Minute)
	db.SetConnMaxIdleTime(1 * time.Minute)
	if err = db.Ping(); err != nil {
		log.Fatal(err)
	}
	schemaInfo := "public"
	if dbSchema != "" && dbSchema != "public" {
		schemaInfo = dbSchema
	}
	log.Printf("Connected to database (SSL mode: %s, schema: %s, max_conns: 25)", dbSSLMode, schemaInfo)

	// Create tables
	createTablesSQL := `
CREATE TABLE IF NOT EXISTS forum_questions (
id SERIAL PRIMARY KEY,
title VARCHAR(500) NOT NULL,
content TEXT NOT NULL,
author_id INTEGER NOT NULL,
author_name VARCHAR(255) NOT NULL,
tags TEXT[] DEFAULT '{}',
upvotes INTEGER DEFAULT 0,
downvotes INTEGER DEFAULT 0,
answer_count INTEGER DEFAULT 0,
view_count INTEGER DEFAULT 0,
accepted_answer_id INTEGER,
status VARCHAR(50) DEFAULT 'OPEN',
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
);
CREATE TABLE IF NOT EXISTS forum_answers (
id SERIAL PRIMARY KEY,
question_id INTEGER NOT NULL REFERENCES forum_questions(id) ON DELETE CASCADE,
content TEXT NOT NULL,
author_id INTEGER NOT NULL,
author_name VARCHAR(255) NOT NULL,
upvotes INTEGER DEFAULT 0,
downvotes INTEGER DEFAULT 0,
is_accepted BOOLEAN DEFAULT FALSE,
is_verified BOOLEAN DEFAULT FALSE,
verified_by INTEGER,
verified_at TIMESTAMP,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
);
CREATE TABLE IF NOT EXISTS forum_votes (
id SERIAL PRIMARY KEY,
user_id INTEGER NOT NULL,
target_type VARCHAR(20) NOT NULL,
target_id INTEGER NOT NULL,
vote_type INTEGER NOT NULL,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
UNIQUE(user_id, target_type, target_id)
);
CREATE INDEX IF NOT EXISTS idx_questions_author ON forum_questions(author_id);
CREATE INDEX IF NOT EXISTS idx_questions_status ON forum_questions(status);
CREATE INDEX IF NOT EXISTS idx_questions_created ON forum_questions(created_at DESC);
CREATE INDEX IF NOT EXISTS idx_answers_question ON forum_answers(question_id);
CREATE INDEX IF NOT EXISTS idx_answers_author ON forum_answers(author_id);
CREATE INDEX IF NOT EXISTS idx_votes_user ON forum_votes(user_id);
CREATE INDEX IF NOT EXISTS idx_votes_target ON forum_votes(target_type, target_id);
`
	if _, err := db.Exec(createTablesSQL); err != nil {
		log.Fatalf("Failed to create tables: %v", err)
	}
	log.Println("Database tables initialized")
}
// loadConfig loads process configuration from the environment. It reads
// JWT_SECRET (surrounding whitespace removed) into jwtSecret and exits via
// log.Fatal when the secret is shorter than 32 bytes.
func loadConfig() {
	secret := []byte(strings.TrimSpace(os.Getenv("JWT_SECRET")))
	jwtSecret = secret
	if len(secret) >= 32 {
		return
	}
	log.Fatal("JWT_SECRET must be set and at least 32 characters")
}
// enableCORS sets the CORS and security response headers. The allowed origin
// is CORS_ALLOW_ORIGIN (trimmed), or http://localhost:8090 when unset/blank.
// Headers only take effect if set before the status line is written.
func enableCORS(w http.ResponseWriter) {
	allowOrigin := strings.TrimSpace(os.Getenv("CORS_ALLOW_ORIGIN"))
	if allowOrigin == "" {
		allowOrigin = "http://localhost:8090"
	}

	headers := w.Header()
	for _, kv := range [...][2]string{
		// CORS policy.
		{"Access-Control-Allow-Origin", allowOrigin},
		{"Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, PATCH, OPTIONS"},
		{"Access-Control-Allow-Headers", "Content-Type, Authorization"},
		// Security headers.
		{"X-Content-Type-Options", "nosniff"},
		{"X-Frame-Options", "DENY"},
		{"X-XSS-Protection", "1; mode=block"},
		{"Referrer-Policy", "strict-origin-when-cross-origin"},
		{"Content-Security-Policy", "default-src 'self'"},
	} {
		headers.Set(kv[0], kv[1])
	}
}
func authMiddleware(next http.HandlerFunc) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
authHeader := r.Header.Get("Authorization")
if authHeader == "" {
http.Error(w, "Authorization header required", http.StatusUnauthorized)
return
}
tokenString := strings.TrimPrefix(authHeader, "Bearer ")
if tokenString == authHeader {
http.Error(w, "Invalid authorization format", http.StatusUnauthorized)
return
}
token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) {
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
return nil, fmt.Errorf("unexpected signing method")
}
return jwtSecret, nil
})
if err != nil || !token.Valid {
http.Error(w, "Invalid or expired token", http.StatusUnauthorized)
return
}
claims, ok := token.Claims.(jwt.MapClaims)
if !ok {
http.Error(w, "Invalid token claims", http.StatusUnauthorized)
return
}
ctx := context.WithValue(r.Context(), userContextKey, claims)
next.ServeHTTP(w, r.WithContext(ctx))
}
}
func optionalAuth(next http.HandlerFunc) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
authHeader := r.Header.Get("Authorization")
if authHeader == "" {
next.ServeHTTP(w, r)
return
}
tokenString := strings.TrimPrefix(authHeader, "Bearer ")
if tokenString == authHeader {
next.ServeHTTP(w, r)
return
}
token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) {
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
return nil, fmt.Errorf("unexpected signing method")
}
return jwtSecret, nil
})
if err != nil || !token.Valid {
next.ServeHTTP(w, r)
return
}
claims, ok := token.Claims.(jwt.MapClaims)
if !ok {
next.ServeHTTP(w, r)
return
}
ctx := context.WithValue(r.Context(), userContextKey, claims)
next.ServeHTTP(w, r.WithContext(ctx))
}
}
// requireRole wraps next so it only runs for a caller authenticated by
// authMiddleware who holds at least one of allowedRoles. As in hasRole, a
// SUPERUSER role satisfies any requirement.
//
// Responses on rejection:
//   - 401 "Invalid token claims" if no claims are in the context
//   - 403 "No roles found" if the roles claim is missing or malformed
//   - 403 "Insufficient permissions" if no role matches
func requireRole(next http.HandlerFunc, allowedRoles ...string) http.HandlerFunc {
	return authMiddleware(func(w http.ResponseWriter, r *http.Request) {
		claims, ok := r.Context().Value(userContextKey).(jwt.MapClaims)
		if !ok {
			http.Error(w, "Invalid token claims", http.StatusUnauthorized)
			return
		}
		userRoles, err := extractRoles(claims)
		if err != nil {
			http.Error(w, "No roles found", http.StatusForbidden)
			return
		}
		for _, userRole := range userRoles {
			// Keep role semantics consistent with hasRole, which the
			// ownership checks use: SUPERUSER has all permissions.
			if userRole == "SUPERUSER" {
				next.ServeHTTP(w, r)
				return
			}
			for _, allowedRole := range allowedRoles {
				if userRole == allowedRole {
					next.ServeHTTP(w, r)
					return
				}
			}
		}
		http.Error(w, "Insufficient permissions", http.StatusForbidden)
	})
}
func extractRoles(claims jwt.MapClaims) ([]string, error) {
rawRoles, ok := claims["roles"]
if !ok {
return nil, errors.New("roles missing")
}
switch v := rawRoles.(type) {
case []interface{}:
out := make([]string, 0, len(v))
for _, r := range v {
roleStr, ok := r.(string)
if !ok {
return nil, errors.New("role not string")
}
out = append(out, roleStr)
}
return out, nil
case []string:
return v, nil
default:
return nil, errors.New("invalid roles type")
}
}
// hasRole reports whether the claims grant role. Holding SUPERUSER grants
// every role. A missing or malformed roles claim grants nothing.
func hasRole(claims jwt.MapClaims, role string) bool {
	roles, err := extractRoles(claims)
	if err != nil {
		return false
	}
	for _, granted := range roles {
		switch granted {
		case "SUPERUSER", role:
			return true
		}
	}
	return false
}
// getUserID returns the caller's numeric id from the "userId" claim, or from
// "user_id" when "userId" is absent or not a float64. Only float64 values are
// recognised (presumably what the JWT library produces for JSON numbers;
// confirm if the decoder is configured differently), truncated to int.
// Returns 0 when neither claim is usable.
//
// NOTE(review): callers cannot distinguish "no id" from a real id of 0.
func getUserID(claims jwt.MapClaims) int {
	for _, claim := range []string{"userId", "user_id"} {
		if id, isNumber := claims[claim].(float64); isNumber {
			return int(id)
		}
	}
	return 0
}
func getUserName(claims jwt.MapClaims) string {
if name, ok := claims["name"].(string); ok && name != "" {
return name
}
if email, ok := claims["email"].(string); ok && email != "" {
parts := strings.Split(email, "@")
return parts[0]
}
return "Anonymous"
}
// respondJSON writes data as a JSON response with the given status code and
// Content-Type application/json. The value is encoded before anything is
// written. An unencodable value is logged and answered with a 500 and a
// generic JSON error, instead of a success status followed by a truncated body.
func respondJSON(w http.ResponseWriter, status int, data interface{}) {
	body, err := json.Marshal(data)
	if err != nil {
		log.Printf("Failed to encode JSON response: %v", err)
		status = http.StatusInternalServerError
		body = []byte(`{"error":"Internal server error"}`)
	}
	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(status)
	// Keep json.Encoder's trailing newline. A write error means the client has
	// gone away; there is nothing useful left to do with it.
	_, _ = w.Write(append(body, '\n'))
}
// respondError sends a JSON error payload of the form {"error": message}
// with the given HTTP status.
func respondError(w http.ResponseWriter, status int, message string) {
	payload := map[string]string{"error": message}
	respondJSON(w, status, payload)
}
func parseTags(tagsStr string) []string {
if tagsStr == "" || tagsStr == "{}" {
return []string{}
}
tagsStr = strings.Trim(tagsStr, "{}")
if tagsStr == "" {
return []string{}
}
return strings.Split(tagsStr, ",")
}
func tagsToPostgres(tags []string) string {
if len(tags) == 0 {
return "{}"
}
return "{" + strings.Join(tags, ",") + "}"
}
// ============ QUESTION HANDLERS ============
// listQuestionsHandler serves GET /questions and responds with a JSON array of
// questions (an empty array, never null, when nothing matches).
//
// Query parameters (all optional):
//   - tag:    keep questions whose tags contain this exact value
//   - status: keep questions with this exact status value
//   - sort:   "votes" (net score, then newest) or "unanswered" (only
//     answer_count = 0, newest first); anything else sorts newest first
//   - limit:  page size in 1..100, default 20 (invalid values are ignored)
//   - offset: rows to skip, >= 0, default 0 (invalid values are ignored)
//
// Filter values are always passed as bind parameters, never spliced into SQL.
func listQuestionsHandler(w http.ResponseWriter, r *http.Request) {
	enableCORS(w)
	if r.Method == http.MethodOptions {
		w.WriteHeader(http.StatusOK)
		return
	}

	params := r.URL.Query()
	tag := params.Get("tag")
	status := params.Get("status")
	sortBy := params.Get("sort")

	limit := 20
	if l, err := strconv.Atoi(params.Get("limit")); err == nil && l > 0 && l <= 100 {
		limit = l
	}
	offset := 0
	if o, err := strconv.Atoi(params.Get("offset")); err == nil && o >= 0 {
		offset = o
	}

	query := `SELECT id, title, content, author_id, author_name, tags, upvotes, downvotes,
answer_count, view_count, accepted_answer_id, status, created_at, updated_at
FROM forum_questions WHERE 1=1`
	args := []interface{}{}
	argNum := 1
	if tag != "" {
		query += fmt.Sprintf(" AND $%d = ANY(tags)", argNum)
		args = append(args, tag)
		argNum++
	}
	if status != "" {
		query += fmt.Sprintf(" AND status = $%d", argNum)
		args = append(args, status)
		argNum++
	}
	switch sortBy {
	case "votes":
		query += " ORDER BY (upvotes - downvotes) DESC, created_at DESC"
	case "unanswered":
		query += " AND answer_count = 0 ORDER BY created_at DESC"
	default: // newest
		query += " ORDER BY created_at DESC"
	}
	query += fmt.Sprintf(" LIMIT $%d OFFSET $%d", argNum, argNum+1)
	args = append(args, limit, offset)

	rows, err := db.Query(query, args...)
	if err != nil {
		log.Println("Database error:", err)
		respondError(w, http.StatusInternalServerError, "Failed to fetch questions")
		return
	}
	defer rows.Close()

	questions := []Question{}
	for rows.Next() {
		var q Question
		var tags string
		var acceptedID sql.NullInt64
		err := rows.Scan(&q.ID, &q.Title, &q.Content, &q.AuthorID, &q.AuthorName, &tags,
			&q.Upvotes, &q.Downvotes, &q.AnswerCount, &q.ViewCount, &acceptedID, &q.Status,
			&q.CreatedAt, &q.UpdatedAt)
		if err != nil {
			log.Println("Scan error:", err)
			continue
		}
		q.Tags = parseTags(tags)
		if acceptedID.Valid {
			id := int(acceptedID.Int64)
			q.AcceptedID = &id
		}
		questions = append(questions, q)
	}
	// A failure during iteration means the page is incomplete. Report it
	// instead of returning a silently truncated list with 200 OK.
	if err := rows.Err(); err != nil {
		log.Println("Rows error:", err)
		respondError(w, http.StatusInternalServerError, "Failed to fetch questions")
		return
	}
	respondJSON(w, http.StatusOK, questions)
}
// getQuestionHandler serves GET /questions/:id. It bumps the question's view
// count, then responds with {"question": Question, "answers": []Answer}.
// Answers are ordered accepted first, then by net score (upvotes - downvotes),
// then oldest first. Responds 400 for a non-numeric id and 404 if the
// question does not exist.
func getQuestionHandler(w http.ResponseWriter, r *http.Request) {
	enableCORS(w)
	if r.Method == http.MethodOptions {
		w.WriteHeader(http.StatusOK)
		return
	}
	id, err := strconv.Atoi(strings.TrimPrefix(r.URL.Path, "/questions/"))
	if err != nil {
		respondError(w, http.StatusBadRequest, "Invalid question ID")
		return
	}

	// Best effort: failing to count a view should not block reading, but it
	// should be visible in the logs.
	if _, err := db.Exec("UPDATE forum_questions SET view_count = view_count + 1 WHERE id = $1", id); err != nil {
		log.Println("View count update error:", err)
	}

	var q Question
	var tags string
	var acceptedID sql.NullInt64
	err = db.QueryRow(`SELECT id, title, content, author_id, author_name, tags, upvotes, downvotes,
answer_count, view_count, accepted_answer_id, status, created_at, updated_at
FROM forum_questions WHERE id = $1`, id).
		Scan(&q.ID, &q.Title, &q.Content, &q.AuthorID, &q.AuthorName, &tags,
			&q.Upvotes, &q.Downvotes, &q.AnswerCount, &q.ViewCount, &acceptedID, &q.Status,
			&q.CreatedAt, &q.UpdatedAt)
	if errors.Is(err, sql.ErrNoRows) {
		respondError(w, http.StatusNotFound, "Question not found")
		return
	}
	if err != nil {
		log.Println("Database error:", err)
		respondError(w, http.StatusInternalServerError, "Database error")
		return
	}
	q.Tags = parseTags(tags)
	if acceptedID.Valid {
		aid := int(acceptedID.Int64)
		q.AcceptedID = &aid
	}

	rows, err := db.Query(`SELECT id, question_id, content, author_id, author_name, upvotes, downvotes,
is_accepted, is_verified, verified_by, verified_at, created_at, updated_at
FROM forum_answers WHERE question_id = $1 ORDER BY is_accepted DESC, (upvotes - downvotes) DESC, created_at ASC`, id)
	if err != nil {
		log.Println("Database error:", err)
		respondError(w, http.StatusInternalServerError, "Failed to fetch answers")
		return
	}
	defer rows.Close()

	answers := []Answer{}
	for rows.Next() {
		var a Answer
		var verifiedBy sql.NullInt64
		var verifiedAt sql.NullTime
		err := rows.Scan(&a.ID, &a.QuestionID, &a.Content, &a.AuthorID, &a.AuthorName,
			&a.Upvotes, &a.Downvotes, &a.IsAccepted, &a.IsVerified, &verifiedBy, &verifiedAt,
			&a.CreatedAt, &a.UpdatedAt)
		if err != nil {
			log.Println("Scan error:", err)
			continue
		}
		if verifiedBy.Valid {
			vb := int(verifiedBy.Int64)
			a.VerifiedBy = &vb
		}
		if verifiedAt.Valid {
			a.VerifiedAt = &verifiedAt.Time
		}
		answers = append(answers, a)
	}
	if err := rows.Err(); err != nil {
		log.Println("Rows error:", err)
		respondError(w, http.StatusInternalServerError, "Failed to fetch answers")
		return
	}

	respondJSON(w, http.StatusOK, map[string]interface{}{
		"question": q,
		"answers":  answers,
	})
}
// createQuestionHandler serves POST /questions for an authenticated caller
// whose JWT claims are stored in the context under userContextKey.
//
// The body (at most maxRequestBody bytes) is a CreateQuestionRequest:
//   - title:   trimmed; required; at most maxQuestionTitleLength characters
//   - content: must not be blank; at most maxQuestionContentLength characters
//   - tags:    at most maxTagsCount entries; each is trimmed, blank ones are
//     dropped, and each must be at most maxTagLength characters
//
// Characters are counted as Unicode code points. Responds 201 with the new
// question (status OPEN).
func createQuestionHandler(w http.ResponseWriter, r *http.Request) {
	enableCORS(w)
	claims, ok := r.Context().Value(userContextKey).(jwt.MapClaims)
	if !ok {
		respondError(w, http.StatusUnauthorized, "Authentication required")
		return
	}
	userID := getUserID(claims)
	userName := getUserName(claims)

	r.Body = http.MaxBytesReader(w, r.Body, maxRequestBody)
	var req CreateQuestionRequest
	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
		respondError(w, http.StatusBadRequest, "Invalid request body")
		return
	}

	// strings.Count(s, "") is 1 + the number of Unicode code points in s.
	charCount := func(s string) int { return strings.Count(s, "") - 1 }

	title := strings.TrimSpace(req.Title)
	if title == "" || strings.TrimSpace(req.Content) == "" {
		respondError(w, http.StatusBadRequest, "Title and content are required")
		return
	}
	if charCount(title) > maxQuestionTitleLength {
		respondError(w, http.StatusBadRequest,
			fmt.Sprintf("Title must be at most %d characters", maxQuestionTitleLength))
		return
	}
	if charCount(req.Content) > maxQuestionContentLength {
		respondError(w, http.StatusBadRequest,
			fmt.Sprintf("Content must be at most %d characters", maxQuestionContentLength))
		return
	}
	if len(req.Tags) > maxTagsCount {
		respondError(w, http.StatusBadRequest,
			fmt.Sprintf("At most %d tags are allowed", maxTagsCount))
		return
	}
	tags := make([]string, 0, len(req.Tags))
	for _, tag := range req.Tags {
		tag = strings.TrimSpace(tag)
		if tag == "" {
			continue
		}
		if charCount(tag) > maxTagLength {
			respondError(w, http.StatusBadRequest,
				fmt.Sprintf("Each tag must be at most %d characters", maxTagLength))
			return
		}
		tags = append(tags, tag)
	}

	var q Question
	err := db.QueryRow(`INSERT INTO forum_questions (title, content, author_id, author_name, tags)
VALUES ($1, $2, $3, $4, $5) RETURNING id, created_at, updated_at`,
		title, req.Content, userID, userName, tagsToPostgres(tags)).
		Scan(&q.ID, &q.CreatedAt, &q.UpdatedAt)
	if err != nil {
		log.Println("Database error:", err)
		respondError(w, http.StatusInternalServerError, "Failed to create question")
		return
	}
	q.Title = title
	q.Content = req.Content
	q.AuthorID = userID
	q.AuthorName = userName
	q.Tags = tags // never nil, so it encodes as [] rather than null
	q.Status = "OPEN"
	log.Printf("AUDIT: User %d created question %d", userID, q.ID)
	respondJSON(w, http.StatusCreated, q)
}
// updateQuestionHandler serves PUT /questions/:id. Only the question's author,
// or a caller for whom hasRole(claims, "ADMIN") holds, may edit. The body (at
// most maxRequestBody bytes) is a CreateQuestionRequest validated like a new
// question: trimmed non-empty title, non-blank content, the max* length
// limits, and blank tags dropped. Title, content, and tags are all replaced.
func updateQuestionHandler(w http.ResponseWriter, r *http.Request) {
	enableCORS(w)
	claims, ok := r.Context().Value(userContextKey).(jwt.MapClaims)
	if !ok {
		respondError(w, http.StatusUnauthorized, "Authentication required")
		return
	}
	userID := getUserID(claims)
	isAdmin := hasRole(claims, "ADMIN")

	id, err := strconv.Atoi(strings.TrimPrefix(r.URL.Path, "/questions/"))
	if err != nil {
		respondError(w, http.StatusBadRequest, "Invalid question ID")
		return
	}

	// Check ownership. Any lookup failure must stop here: falling through with
	// authorID == 0 would match a caller whose token carries no user id.
	var authorID int
	err = db.QueryRow("SELECT author_id FROM forum_questions WHERE id = $1", id).Scan(&authorID)
	if errors.Is(err, sql.ErrNoRows) {
		respondError(w, http.StatusNotFound, "Question not found")
		return
	}
	if err != nil {
		log.Println("Database error:", err)
		respondError(w, http.StatusInternalServerError, "Database error")
		return
	}
	if authorID != userID && !isAdmin {
		respondError(w, http.StatusForbidden, "You can only edit your own questions")
		return
	}

	r.Body = http.MaxBytesReader(w, r.Body, maxRequestBody)
	var req CreateQuestionRequest
	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
		respondError(w, http.StatusBadRequest, "Invalid request body")
		return
	}

	// strings.Count(s, "") is 1 + the number of Unicode code points in s.
	charCount := func(s string) int { return strings.Count(s, "") - 1 }

	title := strings.TrimSpace(req.Title)
	if title == "" || strings.TrimSpace(req.Content) == "" {
		respondError(w, http.StatusBadRequest, "Title and content are required")
		return
	}
	if charCount(title) > maxQuestionTitleLength {
		respondError(w, http.StatusBadRequest,
			fmt.Sprintf("Title must be at most %d characters", maxQuestionTitleLength))
		return
	}
	if charCount(req.Content) > maxQuestionContentLength {
		respondError(w, http.StatusBadRequest,
			fmt.Sprintf("Content must be at most %d characters", maxQuestionContentLength))
		return
	}
	if len(req.Tags) > maxTagsCount {
		respondError(w, http.StatusBadRequest,
			fmt.Sprintf("At most %d tags are allowed", maxTagsCount))
		return
	}
	tags := make([]string, 0, len(req.Tags))
	for _, tag := range req.Tags {
		tag = strings.TrimSpace(tag)
		if tag == "" {
			continue
		}
		if charCount(tag) > maxTagLength {
			respondError(w, http.StatusBadRequest,
				fmt.Sprintf("Each tag must be at most %d characters", maxTagLength))
			return
		}
		tags = append(tags, tag)
	}

	_, err = db.Exec(`UPDATE forum_questions SET title = $1, content = $2, tags = $3, updated_at = CURRENT_TIMESTAMP
WHERE id = $4`, title, req.Content, tagsToPostgres(tags), id)
	if err != nil {
		log.Println("Database error:", err)
		respondError(w, http.StatusInternalServerError, "Failed to update question")
		return
	}
	log.Printf("AUDIT: User %d updated question %d", userID, id)
	respondJSON(w, http.StatusOK, map[string]string{"message": "Question updated"})
}
// deleteQuestionHandler serves DELETE /questions/:id. Only the question's
// author, or a caller for whom hasRole(claims, "ADMIN") holds, may delete.
// Its answers are removed by the ON DELETE CASCADE on forum_answers.
func deleteQuestionHandler(w http.ResponseWriter, r *http.Request) {
	enableCORS(w)
	claims, ok := r.Context().Value(userContextKey).(jwt.MapClaims)
	if !ok {
		respondError(w, http.StatusUnauthorized, "Authentication required")
		return
	}
	userID := getUserID(claims)
	isAdmin := hasRole(claims, "ADMIN")

	id, err := strconv.Atoi(strings.TrimPrefix(r.URL.Path, "/questions/"))
	if err != nil {
		respondError(w, http.StatusBadRequest, "Invalid question ID")
		return
	}

	// Check ownership. Any lookup failure must stop here: falling through with
	// authorID == 0 would match a caller whose token carries no user id.
	var authorID int
	err = db.QueryRow("SELECT author_id FROM forum_questions WHERE id = $1", id).Scan(&authorID)
	if errors.Is(err, sql.ErrNoRows) {
		respondError(w, http.StatusNotFound, "Question not found")
		return
	}
	if err != nil {
		log.Println("Database error:", err)
		respondError(w, http.StatusInternalServerError, "Database error")
		return
	}
	if authorID != userID && !isAdmin {
		respondError(w, http.StatusForbidden, "You can only delete your own questions")
		return
	}

	if _, err := db.Exec("DELETE FROM forum_questions WHERE id = $1", id); err != nil {
		log.Println("Database error:", err)
		respondError(w, http.StatusInternalServerError, "Failed to delete question")
		return
	}
	log.Printf("AUDIT: User %d deleted question %d", userID, id)
	respondJSON(w, http.StatusOK, map[string]string{"message": "Question deleted"})
}
// ============ ANSWER HANDLERS ============
// createAnswerHandler serves POST /questions/:id/answers for an authenticated
// caller. The body (at most maxRequestBody bytes) is a CreateAnswerRequest
// whose content must be non-blank and at most maxAnswerContentLength
// characters (Unicode code points).
//
// The answer insert and the question's answer_count/status update run in one
// transaction that first locks the question row. The CLOSED check therefore
// cannot race with a concurrent close, and the counter cannot drift from the
// real number of answers. Responds 201 with the new answer.
func createAnswerHandler(w http.ResponseWriter, r *http.Request) {
	enableCORS(w)
	claims, ok := r.Context().Value(userContextKey).(jwt.MapClaims)
	if !ok {
		respondError(w, http.StatusUnauthorized, "Authentication required")
		return
	}
	userID := getUserID(claims)
	userName := getUserName(claims)

	// Extract question ID from path like /questions/123/answers
	parts := strings.Split(strings.TrimPrefix(r.URL.Path, "/questions/"), "/")
	if len(parts) < 2 {
		respondError(w, http.StatusBadRequest, "Invalid path")
		return
	}
	questionID, err := strconv.Atoi(parts[0])
	if err != nil {
		respondError(w, http.StatusBadRequest, "Invalid question ID")
		return
	}

	r.Body = http.MaxBytesReader(w, r.Body, maxRequestBody)
	var req CreateAnswerRequest
	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
		respondError(w, http.StatusBadRequest, "Invalid request body")
		return
	}
	if strings.TrimSpace(req.Content) == "" {
		respondError(w, http.StatusBadRequest, "Content is required")
		return
	}
	// strings.Count(s, "") is 1 + the number of Unicode code points in s.
	if strings.Count(req.Content, "")-1 > maxAnswerContentLength {
		respondError(w, http.StatusBadRequest,
			fmt.Sprintf("Content must be at most %d characters", maxAnswerContentLength))
		return
	}

	tx, err := db.Begin()
	if err != nil {
		log.Println("Database error:", err)
		respondError(w, http.StatusInternalServerError, "Failed to create answer")
		return
	}
	defer tx.Rollback() // no-op once Commit has succeeded

	var status string
	err = tx.QueryRow("SELECT status FROM forum_questions WHERE id = $1 FOR UPDATE", questionID).Scan(&status)
	if errors.Is(err, sql.ErrNoRows) {
		respondError(w, http.StatusNotFound, "Question not found")
		return
	}
	if err != nil {
		log.Println("Database error:", err)
		respondError(w, http.StatusInternalServerError, "Database error")
		return
	}
	if status == "CLOSED" {
		respondError(w, http.StatusBadRequest, "This question is closed")
		return
	}

	var a Answer
	err = tx.QueryRow(`INSERT INTO forum_answers (question_id, content, author_id, author_name)
VALUES ($1, $2, $3, $4) RETURNING id, created_at, updated_at`,
		questionID, req.Content, userID, userName).
		Scan(&a.ID, &a.CreatedAt, &a.UpdatedAt)
	if err != nil {
		log.Println("Database error:", err)
		respondError(w, http.StatusInternalServerError, "Failed to create answer")
		return
	}
	if _, err := tx.Exec("UPDATE forum_questions SET answer_count = answer_count + 1, status = 'ANSWERED' WHERE id = $1", questionID); err != nil {
		log.Println("Database error:", err)
		respondError(w, http.StatusInternalServerError, "Failed to create answer")
		return
	}
	if err := tx.Commit(); err != nil {
		log.Println("Database error:", err)
		respondError(w, http.StatusInternalServerError, "Failed to create answer")
		return
	}

	a.QuestionID = questionID
	a.Content = req.Content
	a.AuthorID = userID
	a.AuthorName = userName
	log.Printf("AUDIT: User %d answered question %d", userID, questionID)
	respondJSON(w, http.StatusCreated, a)
}
// updateAnswerHandler serves PUT /answers/:id. Only the answer's author, or a
// caller for whom hasRole(claims, "ADMIN") holds, may edit. The body (at most
// maxRequestBody bytes) is a CreateAnswerRequest whose content must be
// non-blank and at most maxAnswerContentLength characters (code points).
func updateAnswerHandler(w http.ResponseWriter, r *http.Request) {
	enableCORS(w)
	claims, ok := r.Context().Value(userContextKey).(jwt.MapClaims)
	if !ok {
		respondError(w, http.StatusUnauthorized, "Authentication required")
		return
	}
	userID := getUserID(claims)
	isAdmin := hasRole(claims, "ADMIN")

	id, err := strconv.Atoi(strings.TrimPrefix(r.URL.Path, "/answers/"))
	if err != nil {
		respondError(w, http.StatusBadRequest, "Invalid answer ID")
		return
	}

	// Check ownership. Any lookup failure must stop here: falling through with
	// authorID == 0 would match a caller whose token carries no user id.
	var authorID int
	err = db.QueryRow("SELECT author_id FROM forum_answers WHERE id = $1", id).Scan(&authorID)
	if errors.Is(err, sql.ErrNoRows) {
		respondError(w, http.StatusNotFound, "Answer not found")
		return
	}
	if err != nil {
		log.Println("Database error:", err)
		respondError(w, http.StatusInternalServerError, "Database error")
		return
	}
	if authorID != userID && !isAdmin {
		respondError(w, http.StatusForbidden, "You can only edit your own answers")
		return
	}

	r.Body = http.MaxBytesReader(w, r.Body, maxRequestBody)
	var req CreateAnswerRequest
	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
		respondError(w, http.StatusBadRequest, "Invalid request body")
		return
	}
	if strings.TrimSpace(req.Content) == "" {
		respondError(w, http.StatusBadRequest, "Content is required")
		return
	}
	// strings.Count(s, "") is 1 + the number of Unicode code points in s.
	if strings.Count(req.Content, "")-1 > maxAnswerContentLength {
		respondError(w, http.StatusBadRequest,
			fmt.Sprintf("Content must be at most %d characters", maxAnswerContentLength))
		return
	}

	_, err = db.Exec(`UPDATE forum_answers SET content = $1, updated_at = CURRENT_TIMESTAMP WHERE id = $2`,
		req.Content, id)
	if err != nil {
		log.Println("Database error:", err)
		respondError(w, http.StatusInternalServerError, "Failed to update answer")
		return
	}
	log.Printf("AUDIT: User %d updated answer %d", userID, id)
	respondJSON(w, http.StatusOK, map[string]string{"message": "Answer updated"})
}
// deleteAnswerHandler serves DELETE /answers/:id. Only the answer's author, or
// a caller for whom hasRole(claims, "ADMIN") holds, may delete.
//
// The delete and the parent question's bookkeeping run in one transaction:
// answer_count is decremented (never below zero), accepted_answer_id is
// cleared if it pointed at this answer, and an ANSWERED question whose last
// answer is removed goes back to OPEN.
func deleteAnswerHandler(w http.ResponseWriter, r *http.Request) {
	enableCORS(w)
	claims, ok := r.Context().Value(userContextKey).(jwt.MapClaims)
	if !ok {
		respondError(w, http.StatusUnauthorized, "Authentication required")
		return
	}
	userID := getUserID(claims)
	isAdmin := hasRole(claims, "ADMIN")

	id, err := strconv.Atoi(strings.TrimPrefix(r.URL.Path, "/answers/"))
	if err != nil {
		respondError(w, http.StatusBadRequest, "Invalid answer ID")
		return
	}

	// Check ownership and get question ID. Any lookup failure must stop here:
	// falling through with authorID == 0 would match a caller with no user id.
	var authorID, questionID int
	err = db.QueryRow("SELECT author_id, question_id FROM forum_answers WHERE id = $1", id).Scan(&authorID, &questionID)
	if errors.Is(err, sql.ErrNoRows) {
		respondError(w, http.StatusNotFound, "Answer not found")
		return
	}
	if err != nil {
		log.Println("Database error:", err)
		respondError(w, http.StatusInternalServerError, "Database error")
		return
	}
	if authorID != userID && !isAdmin {
		respondError(w, http.StatusForbidden, "You can only delete your own answers")
		return
	}

	tx, err := db.Begin()
	if err != nil {
		log.Println("Database error:", err)
		respondError(w, http.StatusInternalServerError, "Failed to delete answer")
		return
	}
	defer tx.Rollback() // no-op once Commit has succeeded

	res, err := tx.Exec("DELETE FROM forum_answers WHERE id = $1", id)
	if err != nil {
		log.Println("Database error:", err)
		respondError(w, http.StatusInternalServerError, "Failed to delete answer")
		return
	}
	// A concurrent delete may already have removed the row; don't decrement
	// the counter a second time.
	if n, err := res.RowsAffected(); err == nil && n == 0 {
		respondError(w, http.StatusNotFound, "Answer not found")
		return
	}
	// Right-hand-side column references in an UPDATE see the pre-update
	// values, so "answer_count <= 1" means "this was the last answer".
	_, err = tx.Exec(`UPDATE forum_questions SET
answer_count = GREATEST(answer_count - 1, 0),
accepted_answer_id = CASE WHEN accepted_answer_id = $2 THEN NULL ELSE accepted_answer_id END,
status = CASE WHEN status = 'ANSWERED' AND answer_count <= 1 THEN 'OPEN' ELSE status END
WHERE id = $1`, questionID, id)
	if err != nil {
		log.Println("Database error:", err)
		respondError(w, http.StatusInternalServerError, "Failed to delete answer")
		return
	}
	if err := tx.Commit(); err != nil {
		log.Println("Database error:", err)
		respondError(w, http.StatusInternalServerError, "Failed to delete answer")
		return
	}

	log.Printf("AUDIT: User %d deleted answer %d", userID, id)
	respondJSON(w, http.StatusOK, map[string]string{"message": "Answer deleted"})
}
// acceptAnswerHandler serves POST /answers/:id/accept. Only the author of the
// answer's question may accept. This answer becomes the question's single
// accepted answer (any previous one is cleared), is recorded in
// forum_questions.accepted_answer_id, and the question status becomes
// ANSWERED. The writes are applied in one transaction so a failure cannot
// leave the question and its answers disagreeing.
func acceptAnswerHandler(w http.ResponseWriter, r *http.Request) {
	enableCORS(w)
	claims, ok := r.Context().Value(userContextKey).(jwt.MapClaims)
	if !ok {
		respondError(w, http.StatusUnauthorized, "Authentication required")
		return
	}
	userID := getUserID(claims)

	idStr := strings.TrimSuffix(strings.TrimPrefix(r.URL.Path, "/answers/"), "/accept")
	id, err := strconv.Atoi(idStr)
	if err != nil {
		respondError(w, http.StatusBadRequest, "Invalid answer ID")
		return
	}

	// Fetch the answer's question and that question's author in one query.
	var questionID, questionAuthorID int
	err = db.QueryRow(`SELECT a.question_id, q.author_id
FROM forum_answers a JOIN forum_questions q ON q.id = a.question_id
WHERE a.id = $1`, id).Scan(&questionID, &questionAuthorID)
	if errors.Is(err, sql.ErrNoRows) {
		respondError(w, http.StatusNotFound, "Answer not found")
		return
	}
	if err != nil {
		log.Println("Database error:", err)
		respondError(w, http.StatusInternalServerError, "Database error")
		return
	}
	if questionAuthorID != userID {
		respondError(w, http.StatusForbidden, "Only the question author can accept answers")
		return
	}

	tx, err := db.Begin()
	if err != nil {
		log.Println("Database error:", err)
		respondError(w, http.StatusInternalServerError, "Failed to accept answer")
		return
	}
	defer tx.Rollback() // no-op once Commit has succeeded

	// One statement leaves exactly this answer accepted and all others not.
	if _, err := tx.Exec("UPDATE forum_answers SET is_accepted = (id = $2) WHERE question_id = $1", questionID, id); err != nil {
		log.Println("Database error:", err)
		respondError(w, http.StatusInternalServerError, "Failed to accept answer")
		return
	}
	if _, err := tx.Exec("UPDATE forum_questions SET accepted_answer_id = $1, status = 'ANSWERED' WHERE id = $2", id, questionID); err != nil {
		log.Println("Database error:", err)
		respondError(w, http.StatusInternalServerError, "Failed to accept answer")
		return
	}
	if err := tx.Commit(); err != nil {
		log.Println("Database error:", err)
		respondError(w, http.StatusInternalServerError, "Failed to accept answer")
		return
	}

	log.Printf("AUDIT: User %d accepted answer %d for question %d", userID, id, questionID)
	respondJSON(w, http.StatusOK, map[string]string{"message": "Answer accepted"})
}
// verifyAnswerHandler serves POST /answers/:id/verify (registered behind
// requireRole(..., "ADMIN")). It marks the answer verified and records the
// verifying user and time. Responds 404 if no such answer exists, rather than
// claiming success.
func verifyAnswerHandler(w http.ResponseWriter, r *http.Request) {
	enableCORS(w)
	claims, ok := r.Context().Value(userContextKey).(jwt.MapClaims)
	if !ok {
		respondError(w, http.StatusUnauthorized, "Authentication required")
		return
	}
	userID := getUserID(claims)

	idStr := strings.TrimSuffix(strings.TrimPrefix(r.URL.Path, "/answers/"), "/verify")
	id, err := strconv.Atoi(idStr)
	if err != nil {
		respondError(w, http.StatusBadRequest, "Invalid answer ID")
		return
	}

	res, err := db.Exec(`UPDATE forum_answers SET is_verified = TRUE, verified_by = $1,
verified_at = CURRENT_TIMESTAMP WHERE id = $2`, userID, id)
	if err != nil {
		log.Println("Database error:", err)
		respondError(w, http.StatusInternalServerError, "Failed to verify answer")
		return
	}
	if n, err := res.RowsAffected(); err == nil && n == 0 {
		respondError(w, http.StatusNotFound, "Answer not found")
		return
	}
	log.Printf("AUDIT: Admin %d verified answer %d", userID, id)
	respondJSON(w, http.StatusOK, map[string]string{"message": "Answer verified"})
}
// ============ VOTING HANDLERS ============
// applyForumVote records userID's vote (1 or -1) on one question or answer and
// keeps that row's upvotes/downvotes counters in step with forum_votes. A
// repeated identical vote changes nothing; an opposite vote flips the
// existing one.
//
// targetTable must be "forum_questions" or "forum_answers", and targetType the
// matching forum_votes.target_type ("question"/"answer"). Both come from the
// handlers below, never from the request, because targetTable is formatted
// into SQL.
//
// Everything runs in one transaction that first locks the target row. This
// serializes concurrent votes on the same target so the counters cannot
// drift. The returned error wraps sql.ErrNoRows when the target does not exist.
func applyForumVote(targetTable, targetType string, targetID, userID, voteType int) error {
	tx, err := db.Begin()
	if err != nil {
		return fmt.Errorf("begin vote transaction: %w", err)
	}
	defer tx.Rollback() // no-op once Commit has succeeded

	var lockedID int
	lockSQL := fmt.Sprintf("SELECT id FROM %s WHERE id = $1 FOR UPDATE", targetTable)
	if err := tx.QueryRow(lockSQL, targetID).Scan(&lockedID); err != nil {
		return fmt.Errorf("lock %s %d: %w", targetType, targetID, err)
	}

	var existing int
	err = tx.QueryRow(`SELECT vote_type FROM forum_votes
WHERE user_id = $1 AND target_type = $2 AND target_id = $3`,
		userID, targetType, targetID).Scan(&existing)

	var voteSQL, counterSQL string
	switch {
	case errors.Is(err, sql.ErrNoRows):
		// New vote.
		voteSQL = `INSERT INTO forum_votes (user_id, target_type, target_id, vote_type) VALUES ($1, $2, $3, $4)`
		counter := "downvotes"
		if voteType == 1 {
			counter = "upvotes"
		}
		counterSQL = fmt.Sprintf("UPDATE %s SET %s = %s + 1 WHERE id = $1", targetTable, counter, counter)
	case err != nil:
		// Never fall through to the "changed vote" path on a lookup failure:
		// that would shift the counters without a matching forum_votes row.
		return fmt.Errorf("look up existing vote: %w", err)
	case existing == voteType:
		// Same vote again: nothing to change.
		return nil
	default:
		// Changing vote.
		voteSQL = `UPDATE forum_votes SET vote_type = $4 WHERE user_id = $1 AND target_type = $2 AND target_id = $3`
		if voteType == 1 {
			counterSQL = fmt.Sprintf("UPDATE %s SET upvotes = upvotes + 1, downvotes = downvotes - 1 WHERE id = $1", targetTable)
		} else {
			counterSQL = fmt.Sprintf("UPDATE %s SET upvotes = upvotes - 1, downvotes = downvotes + 1 WHERE id = $1", targetTable)
		}
	}

	if _, err := tx.Exec(voteSQL, userID, targetType, targetID, voteType); err != nil {
		return fmt.Errorf("save vote: %w", err)
	}
	if _, err := tx.Exec(counterSQL, targetID); err != nil {
		return fmt.Errorf("update vote counters: %w", err)
	}
	if err := tx.Commit(); err != nil {
		return fmt.Errorf("commit vote: %w", err)
	}
	return nil
}

// voteQuestionHandler serves POST /questions/:id/vote for an authenticated
// caller. The body is a VoteRequest whose voteType must be 1 (upvote) or -1
// (downvote). Responds 404 if the question does not exist.
func voteQuestionHandler(w http.ResponseWriter, r *http.Request) {
	enableCORS(w)
	claims, ok := r.Context().Value(userContextKey).(jwt.MapClaims)
	if !ok {
		respondError(w, http.StatusUnauthorized, "Authentication required")
		return
	}
	userID := getUserID(claims)

	idStr := strings.TrimSuffix(strings.TrimPrefix(r.URL.Path, "/questions/"), "/vote")
	id, err := strconv.Atoi(idStr)
	if err != nil {
		respondError(w, http.StatusBadRequest, "Invalid question ID")
		return
	}

	r.Body = http.MaxBytesReader(w, r.Body, maxRequestBody)
	var req VoteRequest
	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
		respondError(w, http.StatusBadRequest, "Invalid request body")
		return
	}
	if req.VoteType != 1 && req.VoteType != -1 {
		respondError(w, http.StatusBadRequest, "VoteType must be 1 (upvote) or -1 (downvote)")
		return
	}

	if err := applyForumVote("forum_questions", "question", id, userID, req.VoteType); err != nil {
		if errors.Is(err, sql.ErrNoRows) {
			respondError(w, http.StatusNotFound, "Question not found")
			return
		}
		log.Println("Vote error:", err)
		respondError(w, http.StatusInternalServerError, "Failed to record vote")
		return
	}
	respondJSON(w, http.StatusOK, map[string]string{"message": "Vote recorded"})
}

// voteAnswerHandler serves POST /answers/:id/vote for an authenticated
// caller. The body is a VoteRequest whose voteType must be 1 or -1. Responds
// 404 if the answer does not exist.
func voteAnswerHandler(w http.ResponseWriter, r *http.Request) {
	enableCORS(w)
	claims, ok := r.Context().Value(userContextKey).(jwt.MapClaims)
	if !ok {
		respondError(w, http.StatusUnauthorized, "Authentication required")
		return
	}
	userID := getUserID(claims)

	idStr := strings.TrimSuffix(strings.TrimPrefix(r.URL.Path, "/answers/"), "/vote")
	id, err := strconv.Atoi(idStr)
	if err != nil {
		respondError(w, http.StatusBadRequest, "Invalid answer ID")
		return
	}

	r.Body = http.MaxBytesReader(w, r.Body, maxRequestBody)
	var req VoteRequest
	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
		respondError(w, http.StatusBadRequest, "Invalid request body")
		return
	}
	if req.VoteType != 1 && req.VoteType != -1 {
		respondError(w, http.StatusBadRequest, "VoteType must be 1 or -1")
		return
	}

	if err := applyForumVote("forum_answers", "answer", id, userID, req.VoteType); err != nil {
		if errors.Is(err, sql.ErrNoRows) {
			respondError(w, http.StatusNotFound, "Answer not found")
			return
		}
		log.Println("Vote error:", err)
		respondError(w, http.StatusInternalServerError, "Failed to record vote")
		return
	}
	respondJSON(w, http.StatusOK, map[string]string{"message": "Vote recorded"})
}
// closeQuestionHandler serves POST /questions/:id/close (registered behind
// requireRole(..., "ADMIN")). It sets the question's status to CLOSED, which
// createAnswerHandler checks before accepting new answers. Responds 404 if no
// such question exists, rather than claiming success.
func closeQuestionHandler(w http.ResponseWriter, r *http.Request) {
	enableCORS(w)
	claims, ok := r.Context().Value(userContextKey).(jwt.MapClaims)
	if !ok {
		respondError(w, http.StatusUnauthorized, "Authentication required")
		return
	}
	userID := getUserID(claims)

	idStr := strings.TrimSuffix(strings.TrimPrefix(r.URL.Path, "/questions/"), "/close")
	id, err := strconv.Atoi(idStr)
	if err != nil {
		respondError(w, http.StatusBadRequest, "Invalid question ID")
		return
	}

	res, err := db.Exec("UPDATE forum_questions SET status = 'CLOSED', updated_at = CURRENT_TIMESTAMP WHERE id = $1", id)
	if err != nil {
		log.Println("Database error:", err)
		respondError(w, http.StatusInternalServerError, "Failed to close question")
		return
	}
	if n, err := res.RowsAffected(); err == nil && n == 0 {
		respondError(w, http.StatusNotFound, "Question not found")
		return
	}
	log.Printf("AUDIT: Admin %d closed question %d", userID, id)
	respondJSON(w, http.StatusOK, map[string]string{"message": "Question closed"})
}
func main() {
loadConfig()
initDB()
defer db.Close()
// Public endpoints
http.HandleFunc("/questions", func(w http.ResponseWriter, r *http.Request) {
enableCORS(w)
if r.Method == http.MethodOptions {
w.WriteHeader(http.StatusOK)
return
}
if r.Method == http.MethodGet {
listQuestionsHandler(w, r)
} else if r.Method == http.MethodPost {
authMiddleware(createQuestionHandler)(w, r)
} else {
respondError(w, http.StatusMethodNotAllowed, "Method not allowed")
}
})
http.HandleFunc("/questions/", func(w http.ResponseWriter, r *http.Request) {
enableCORS(w)
if r.Method == http.MethodOptions {
w.WriteHeader(http.StatusOK)
return
}
path := r.URL.Path
switch {
case strings.HasSuffix(path, "/answers"):
if r.Method == http.MethodPost {
authMiddleware(createAnswerHandler)(w, r)
} else {
respondError(w, http.StatusMethodNotAllowed, "Method not allowed")
}
case strings.HasSuffix(path, "/vote"):
if r.Method == http.MethodPost {
authMiddleware(voteQuestionHandler)(w, r)
} else {
respondError(w, http.StatusMethodNotAllowed, "Method not allowed")
}
case strings.HasSuffix(path, "/close"):
if r.Method == http.MethodPost {
requireRole(closeQuestionHandler, "ADMIN")(w, r)
} else {
respondError(w, http.StatusMethodNotAllowed, "Method not allowed")
}
default:
if r.Method == http.MethodGet {
getQuestionHandler(w, r)
} else if r.Method == http.MethodPut {
authMiddleware(updateQuestionHandler)(w, r)
} else if r.Method == http.MethodDelete {
authMiddleware(deleteQuestionHandler)(w, r)
} else {
respondError(w, http.StatusMethodNotAllowed, "Method not allowed")
}
}
})
http.HandleFunc("/answers/", func(w http.ResponseWriter, r *http.Request) {
enableCORS(w)
if r.Method == http.MethodOptions {
w.WriteHeader(http.StatusOK)
return
}
path := r.URL.Path
switch {
case strings.HasSuffix(path, "/accept"):
if r.Method == http.MethodPost {
authMiddleware(acceptAnswerHandler)(w, r)
} else {
respondError(w, http.StatusMethodNotAllowed, "Method not allowed")
}
case strings.HasSuffix(path, "/verify"):
if r.Method == http.MethodPost {
requireRole(verifyAnswerHandler, "ADMIN")(w, r)
} else {
respondError(w, http.StatusMethodNotAllowed, "Method not allowed")
}
case strings.HasSuffix(path, "/vote"):
if r.Method == http.MethodPost {
authMiddleware(voteAnswerHandler)(w, r)
} else {
respondError(w, http.StatusMethodNotAllowed, "Method not allowed")
}
default:
if r.Method == http.MethodPut {
authMiddleware(updateAnswerHandler)(w, r)
} else if r.Method == http.MethodDelete {
authMiddleware(deleteAnswerHandler)(w, r)
} else {
respondError(w, http.StatusMethodNotAllowed, "Method not allowed")
}
}
})
// Health check (both /health and /healthz for compatibility)
healthHandler := func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
fmt.Fprintln(w, "ok")
}
http.HandleFunc("/health", healthHandler)
http.HandleFunc("/healthz", healthHandler)
port := os.Getenv("PORT")
if port == "" {
port = "8080"
}
// Wrap all routes with rate limiting and body size limit
rateLimitedHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Limit request body size to prevent DoS
if r.Body != nil {
r.Body = http.MaxBytesReader(w, r.Body, maxRequestBody)
}
clientIP := getClientIP(r)
if r.Method == http.MethodGet || r.Method == http.MethodHead {
if readLimiter.checkRateLimit(clientIP, maxReadRequests, rateLimitWindow) {
log.Printf("SECURITY: Read rate limit exceeded for IP %s on %s", clientIP, r.URL.Path)
http.Error(w, "Too many requests. Please slow down.", http.StatusTooManyRequests)
return
}
} else if r.Method != http.MethodOptions {
if writeLimiter.checkRateLimit(clientIP, maxWriteRequests, rateLimitWindow) {
log.Printf("SECURITY: Write rate limit exceeded for IP %s on %s", clientIP, r.URL.Path)
http.Error(w, "Too many requests. Please slow down.", http.StatusTooManyRequests)
return
}
}
http.DefaultServeMux.ServeHTTP(w, r)
})
server := &http.Server{
Addr: ":" + port,
Handler: rateLimitedHandler,
ReadHeaderTimeout: 10 * time.Second,
ReadTimeout: 15 * time.Second,
WriteTimeout: 15 * time.Second,
IdleTimeout: 60 * time.Second,
}
// Graceful shutdown
done := make(chan bool, 1)
quit := make(chan os.Signal, 1)
signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM)
go func() {
<-quit
log.Println("Forum Service shutting down...")
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
server.SetKeepAlivesEnabled(false)
if err := server.Shutdown(ctx); err != nil {
log.Printf("Could not gracefully shutdown: %v", err)
}
close(done)
}()
log.Printf("Forum service starting on port %s\n", port)
if err := server.ListenAndServe(); err != nil && err != http.ErrServerClosed {
log.Fatalf("Server error: %v", err)
}
<-done
log.Println("Forum Service stopped")
}