1310 lines
38 KiB
Go
1310 lines
38 KiB
Go
package main
|
|
|
|
import (
|
|
"context"
|
|
"database/sql"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"log"
|
|
"net"
|
|
"net/http"
|
|
"os"
|
|
"os/signal"
|
|
"strconv"
|
|
"strings"
|
|
"sync"
|
|
"syscall"
|
|
"time"
|
|
|
|
"github.com/golang-jwt/jwt/v5"
|
|
_ "github.com/lib/pq"
|
|
)
|
|
|
|
// Rate limiting configuration
|
|
const (
|
|
rateLimitWindow = 1 * time.Minute
|
|
maxWriteRequests = 30 // Max write requests per minute per IP
|
|
maxReadRequests = 100 // Max read requests per minute per IP
|
|
maxRequestBody = 1 << 20 // 1MB max request body size
|
|
)
|
|
|
|
// Input validation limits for forum content.
const (
	maxQuestionTitleLength   = 200    // maximum question title length
	maxQuestionContentLength = 50_000 // ~50KB for questions
	maxAnswerContentLength   = 50_000 // ~50KB for answers
	maxTagLength             = 50     // maximum length of a single tag
	maxTagsCount             = 10     // maximum number of tags per question
)
|
|
|
|
type rateLimiter struct {
|
|
mu sync.RWMutex
|
|
requests map[string]*requestInfo
|
|
}
|
|
|
|
type requestInfo struct {
|
|
count int
|
|
firstReq time.Time
|
|
}
|
|
|
|
var writeLimiter = &rateLimiter{requests: make(map[string]*requestInfo)}
|
|
var readLimiter = &rateLimiter{requests: make(map[string]*requestInfo)}
|
|
|
|
func (rl *rateLimiter) checkRateLimit(key string, maxRequests int, window time.Duration) bool {
|
|
rl.mu.Lock()
|
|
defer rl.mu.Unlock()
|
|
|
|
now := time.Now()
|
|
info, exists := rl.requests[key]
|
|
|
|
if !exists {
|
|
rl.requests[key] = &requestInfo{count: 1, firstReq: now}
|
|
return false
|
|
}
|
|
|
|
if now.Sub(info.firstReq) > window {
|
|
info.count = 1
|
|
info.firstReq = now
|
|
return false
|
|
}
|
|
|
|
info.count++
|
|
return info.count > maxRequests
|
|
}
|
|
|
|
// getClientIP returns a best-effort client address for r.
//
// Precedence:
//  1. the first comma-separated entry of X-Forwarded-For (trimmed);
//  2. X-Real-IP (trimmed);
//  3. the host part of r.RemoteAddr, or RemoteAddr verbatim if it has no port.
//
// NOTE(review): both headers are client-controlled unless a trusted proxy
// overwrites them, so a value used as a rate-limit key can be spoofed.
func getClientIP(r *http.Request) string {
	if forwarded := r.Header.Get("X-Forwarded-For"); forwarded != "" {
		first := forwarded
		if comma := strings.IndexByte(forwarded, ','); comma >= 0 {
			first = forwarded[:comma]
		}
		return strings.TrimSpace(first)
	}

	if realIP := r.Header.Get("X-Real-IP"); realIP != "" {
		return strings.TrimSpace(realIP)
	}

	host, _, splitErr := net.SplitHostPort(r.RemoteAddr)
	if splitErr != nil {
		return r.RemoteAddr
	}
	return host
}
|
|
|
|
// contextKey is an unexported type for request-context keys, so keys set
// here cannot collide with string keys used by other packages.
type contextKey string

// userContextKey is the context key under which the auth middleware stores
// the verified JWT claims.
var userContextKey contextKey = "userClaims"
|
|
|
|
// Question represents a forum question. Its fields mirror the columns of the
// forum_questions table and are serialised with camelCase JSON names.
type Question struct {
	ID          int       `json:"id"`
	Title       string    `json:"title"`
	Content     string    `json:"content"`
	AuthorID    int       `json:"authorId"`
	AuthorName  string    `json:"authorName"`
	Tags        []string  `json:"tags"`
	Upvotes     int       `json:"upvotes"`
	Downvotes   int       `json:"downvotes"`
	AnswerCount int       `json:"answerCount"`
	ViewCount   int       `json:"viewCount"`
	AcceptedID  *int      `json:"acceptedAnswerId,omitempty"` // nil when accepted_answer_id is NULL
	Status      string    `json:"status"`                     // OPEN, ANSWERED, CLOSED
	CreatedAt   time.Time `json:"createdAt"`
	UpdatedAt   time.Time `json:"updatedAt"`
}
|
|
|
|
// Answer represents an answer to a question. Its fields mirror the columns
// of the forum_answers table and are serialised with camelCase JSON names.
type Answer struct {
	ID         int        `json:"id"`
	QuestionID int        `json:"questionId"`
	Content    string     `json:"content"`
	AuthorID   int        `json:"authorId"`
	AuthorName string     `json:"authorName"`
	Upvotes    int        `json:"upvotes"`
	Downvotes  int        `json:"downvotes"`
	IsAccepted bool       `json:"isAccepted"` // chosen by the question author
	IsVerified bool       `json:"isVerified"` // Admin-verified correct answer
	VerifiedBy *int       `json:"verifiedBy,omitempty"` // nil when verified_by is NULL
	VerifiedAt *time.Time `json:"verifiedAt,omitempty"` // nil when verified_at is NULL
	CreatedAt  time.Time  `json:"createdAt"`
	UpdatedAt  time.Time  `json:"updatedAt"`
}
|
|
|
|
// Vote represents a user's vote on a question or answer. Its fields mirror
// the columns of the forum_votes table.
type Vote struct {
	ID         int       `json:"id"`
	UserID     int       `json:"userId"`
	TargetType string    `json:"targetType"` // "question" or "answer"
	TargetID   int       `json:"targetId"`
	VoteType   int       `json:"voteType"` // 1 = upvote, -1 = downvote
	CreatedAt  time.Time `json:"createdAt"`
}
|
|
|
|
// Request payloads decoded from JSON request bodies.
type (
	// CreateQuestionRequest is the body for creating or updating a question.
	CreateQuestionRequest struct {
		Title   string   `json:"title"`
		Content string   `json:"content"`
		Tags    []string `json:"tags"`
	}

	// CreateAnswerRequest is the body for creating or updating an answer.
	CreateAnswerRequest struct {
		Content string `json:"content"`
	}

	// VoteRequest is the body for casting a vote.
	VoteRequest struct {
		VoteType int `json:"voteType"` // 1 or -1
	}
)
|
|
|
|
var (
|
|
db *sql.DB
|
|
jwtSecret []byte
|
|
)
|
|
|
|
// initDB opens the PostgreSQL connection pool into the package-level db,
// verifies connectivity, and creates the forum tables and indexes if they do
// not already exist.
//
// Configuration is read from the environment:
//   - DB_HOST, DB_USER, DB_PASSWORD, DB_NAME: required.
//   - DB_SSL_MODE: defaults to "require".
//   - DB_SCHEMA: one of "", public, dev, testing, prod. A non-public schema is
//     put first on search_path.
//
// Any failure terminates the process via log.Fatal.
func initDB() {
	const (
		maxOpenConns    = 25
		maxIdleConns    = 5
		connMaxLifetime = 5 * time.Minute
		connMaxIdleTime = 1 * time.Minute
	)

	// quoteConnValue renders v as a single-quoted libpq keyword/value
	// parameter. Without quoting, a space, quote or backslash (common in
	// passwords) breaks the DSN or injects extra parameters.
	quoteConnValue := func(v string) string {
		v = strings.ReplaceAll(v, `\`, `\\`)
		v = strings.ReplaceAll(v, `'`, `\'`)
		return "'" + v + "'"
	}

	dbHost := strings.TrimSpace(os.Getenv("DB_HOST"))
	dbUser := strings.TrimSpace(os.Getenv("DB_USER"))
	dbPassword := strings.TrimSpace(os.Getenv("DB_PASSWORD"))
	dbName := strings.TrimSpace(os.Getenv("DB_NAME"))
	dbSSLMode := strings.TrimSpace(os.Getenv("DB_SSL_MODE"))
	dbSchema := strings.TrimSpace(os.Getenv("DB_SCHEMA"))

	if dbHost == "" || dbUser == "" || dbPassword == "" || dbName == "" {
		log.Fatal("Database configuration missing: DB_HOST, DB_USER, DB_PASSWORD, DB_NAME required")
	}

	if dbSSLMode == "" {
		dbSSLMode = "require"
	}

	// Validate schema value if provided
	validSchemas := map[string]bool{"": true, "public": true, "dev": true, "testing": true, "prod": true}
	if !validSchemas[dbSchema] {
		log.Fatalf("Invalid DB_SCHEMA '%s'. Must be: dev, testing, prod, or empty for public", dbSchema)
	}

	connStr := fmt.Sprintf("host=%s user=%s password=%s dbname=%s sslmode=%s",
		quoteConnValue(dbHost), quoteConnValue(dbUser), quoteConnValue(dbPassword),
		quoteConnValue(dbName), quoteConnValue(dbSSLMode))
	schemaInfo := "public"
	if dbSchema != "" && dbSchema != "public" {
		connStr += " search_path=" + quoteConnValue(dbSchema+",public")
		schemaInfo = dbSchema
	}

	var err error
	db, err = sql.Open("postgres", connStr)
	if err != nil {
		log.Fatalf("Failed to open database: %v", err)
	}

	// Configure connection pool limits
	db.SetMaxOpenConns(maxOpenConns)
	db.SetMaxIdleConns(maxIdleConns)
	db.SetConnMaxLifetime(connMaxLifetime)
	db.SetConnMaxIdleTime(connMaxIdleTime)

	if err = db.Ping(); err != nil {
		log.Fatalf("Failed to connect to database: %v", err)
	}

	log.Printf("Connected to database (SSL mode: %s, schema: %s, max_conns: %d)", dbSSLMode, schemaInfo, maxOpenConns)

	// Create tables
	createTablesSQL := `
CREATE TABLE IF NOT EXISTS forum_questions (
	id SERIAL PRIMARY KEY,
	title VARCHAR(500) NOT NULL,
	content TEXT NOT NULL,
	author_id INTEGER NOT NULL,
	author_name VARCHAR(255) NOT NULL,
	tags TEXT[] DEFAULT '{}',
	upvotes INTEGER DEFAULT 0,
	downvotes INTEGER DEFAULT 0,
	answer_count INTEGER DEFAULT 0,
	view_count INTEGER DEFAULT 0,
	accepted_answer_id INTEGER,
	status VARCHAR(50) DEFAULT 'OPEN',
	created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
	updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
);

CREATE TABLE IF NOT EXISTS forum_answers (
	id SERIAL PRIMARY KEY,
	question_id INTEGER NOT NULL REFERENCES forum_questions(id) ON DELETE CASCADE,
	content TEXT NOT NULL,
	author_id INTEGER NOT NULL,
	author_name VARCHAR(255) NOT NULL,
	upvotes INTEGER DEFAULT 0,
	downvotes INTEGER DEFAULT 0,
	is_accepted BOOLEAN DEFAULT FALSE,
	is_verified BOOLEAN DEFAULT FALSE,
	verified_by INTEGER,
	verified_at TIMESTAMP,
	created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
	updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
);

CREATE TABLE IF NOT EXISTS forum_votes (
	id SERIAL PRIMARY KEY,
	user_id INTEGER NOT NULL,
	target_type VARCHAR(20) NOT NULL,
	target_id INTEGER NOT NULL,
	vote_type INTEGER NOT NULL,
	created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
	UNIQUE(user_id, target_type, target_id)
);

CREATE INDEX IF NOT EXISTS idx_questions_author ON forum_questions(author_id);
CREATE INDEX IF NOT EXISTS idx_questions_status ON forum_questions(status);
CREATE INDEX IF NOT EXISTS idx_questions_created ON forum_questions(created_at DESC);
CREATE INDEX IF NOT EXISTS idx_answers_question ON forum_answers(question_id);
CREATE INDEX IF NOT EXISTS idx_answers_author ON forum_answers(author_id);
CREATE INDEX IF NOT EXISTS idx_votes_user ON forum_votes(user_id);
CREATE INDEX IF NOT EXISTS idx_votes_target ON forum_votes(target_type, target_id);
`

	if _, err := db.Exec(createTablesSQL); err != nil {
		log.Fatalf("Failed to create tables: %v", err)
	}

	log.Println("Database tables initialized")
}
|
|
|
|
// loadConfig loads JWT_SECRET (surrounding whitespace trimmed) into
// jwtSecret. It terminates the process when the secret is shorter than
// 32 bytes.
func loadConfig() {
	secret := strings.TrimSpace(os.Getenv("JWT_SECRET"))
	jwtSecret = []byte(secret)
	if len(jwtSecret) >= 32 {
		return
	}
	log.Fatal("JWT_SECRET must be set and at least 32 characters")
}
|
|
|
|
// enableCORS sets CORS and security response headers on w.
//
// The allowed origin comes from CORS_ALLOW_ORIGIN (whitespace trimmed). It
// falls back to http://localhost:8090 when the variable is unset or blank.
// The variable is re-read on every call.
func enableCORS(w http.ResponseWriter) {
	allowOrigin := strings.TrimSpace(os.Getenv("CORS_ALLOW_ORIGIN"))
	if allowOrigin == "" {
		allowOrigin = "http://localhost:8090"
	}

	headerValues := [...][2]string{
		{"Access-Control-Allow-Origin", allowOrigin},
		{"Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, PATCH, OPTIONS"},
		{"Access-Control-Allow-Headers", "Content-Type, Authorization"},

		// Security headers.
		// NOTE(review): X-XSS-Protection is ignored by current browsers; kept for legacy clients.
		{"X-Content-Type-Options", "nosniff"},
		{"X-Frame-Options", "DENY"},
		{"X-XSS-Protection", "1; mode=block"},
		{"Referrer-Policy", "strict-origin-when-cross-origin"},
		{"Content-Security-Policy", "default-src 'self'"},
	}

	hdr := w.Header()
	for _, pair := range headerValues {
		hdr.Set(pair[0], pair[1])
	}
}
|
|
|
|
func authMiddleware(next http.HandlerFunc) http.HandlerFunc {
|
|
return func(w http.ResponseWriter, r *http.Request) {
|
|
authHeader := r.Header.Get("Authorization")
|
|
if authHeader == "" {
|
|
http.Error(w, "Authorization header required", http.StatusUnauthorized)
|
|
return
|
|
}
|
|
|
|
tokenString := strings.TrimPrefix(authHeader, "Bearer ")
|
|
if tokenString == authHeader {
|
|
http.Error(w, "Invalid authorization format", http.StatusUnauthorized)
|
|
return
|
|
}
|
|
|
|
token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) {
|
|
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
|
|
return nil, fmt.Errorf("unexpected signing method")
|
|
}
|
|
return jwtSecret, nil
|
|
})
|
|
|
|
if err != nil || !token.Valid {
|
|
http.Error(w, "Invalid or expired token", http.StatusUnauthorized)
|
|
return
|
|
}
|
|
|
|
claims, ok := token.Claims.(jwt.MapClaims)
|
|
if !ok {
|
|
http.Error(w, "Invalid token claims", http.StatusUnauthorized)
|
|
return
|
|
}
|
|
|
|
ctx := context.WithValue(r.Context(), userContextKey, claims)
|
|
next.ServeHTTP(w, r.WithContext(ctx))
|
|
}
|
|
}
|
|
|
|
func optionalAuth(next http.HandlerFunc) http.HandlerFunc {
|
|
return func(w http.ResponseWriter, r *http.Request) {
|
|
authHeader := r.Header.Get("Authorization")
|
|
if authHeader == "" {
|
|
next.ServeHTTP(w, r)
|
|
return
|
|
}
|
|
|
|
tokenString := strings.TrimPrefix(authHeader, "Bearer ")
|
|
if tokenString == authHeader {
|
|
next.ServeHTTP(w, r)
|
|
return
|
|
}
|
|
|
|
token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) {
|
|
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
|
|
return nil, fmt.Errorf("unexpected signing method")
|
|
}
|
|
return jwtSecret, nil
|
|
})
|
|
|
|
if err != nil || !token.Valid {
|
|
next.ServeHTTP(w, r)
|
|
return
|
|
}
|
|
|
|
claims, ok := token.Claims.(jwt.MapClaims)
|
|
if !ok {
|
|
next.ServeHTTP(w, r)
|
|
return
|
|
}
|
|
|
|
ctx := context.WithValue(r.Context(), userContextKey, claims)
|
|
next.ServeHTTP(w, r.WithContext(ctx))
|
|
}
|
|
}
|
|
|
|
// requireRole wraps next with authMiddleware and additionally requires the
// token's "roles" claim to contain at least one of allowedRoles.
//
// It responds 403 "No roles found" when the claim is missing or malformed,
// and 403 "Insufficient permissions" when no role matches.
//
// NOTE(review): unlike hasRole, this check gives SUPERUSER no implicit
// access — confirm that difference is intended.
func requireRole(next http.HandlerFunc, allowedRoles ...string) http.HandlerFunc {
	isAllowed := func(role string) bool {
		for _, candidate := range allowedRoles {
			if candidate == role {
				return true
			}
		}
		return false
	}

	return authMiddleware(func(w http.ResponseWriter, r *http.Request) {
		// authMiddleware guarantees MapClaims are present under userContextKey.
		claims := r.Context().Value(userContextKey).(jwt.MapClaims)

		roles, rolesErr := extractRoles(claims)
		if rolesErr != nil {
			http.Error(w, "No roles found", http.StatusForbidden)
			return
		}

		for _, role := range roles {
			if isAllowed(role) {
				next.ServeHTTP(w, r)
				return
			}
		}

		http.Error(w, "Insufficient permissions", http.StatusForbidden)
	})
}
|
|
|
|
func extractRoles(claims jwt.MapClaims) ([]string, error) {
|
|
rawRoles, ok := claims["roles"]
|
|
if !ok {
|
|
return nil, errors.New("roles missing")
|
|
}
|
|
|
|
switch v := rawRoles.(type) {
|
|
case []interface{}:
|
|
out := make([]string, 0, len(v))
|
|
for _, r := range v {
|
|
roleStr, ok := r.(string)
|
|
if !ok {
|
|
return nil, errors.New("role not string")
|
|
}
|
|
out = append(out, roleStr)
|
|
}
|
|
return out, nil
|
|
case []string:
|
|
return v, nil
|
|
default:
|
|
return nil, errors.New("invalid roles type")
|
|
}
|
|
}
|
|
|
|
func hasRole(claims jwt.MapClaims, role string) bool {
|
|
roles, err := extractRoles(claims)
|
|
if err != nil {
|
|
return false
|
|
}
|
|
for _, r := range roles {
|
|
// SUPERUSER has all permissions
|
|
if r == "SUPERUSER" {
|
|
return true
|
|
}
|
|
if r == role {
|
|
return true
|
|
}
|
|
}
|
|
return false
|
|
}
|
|
|
|
func getUserID(claims jwt.MapClaims) int {
|
|
if id, ok := claims["userId"].(float64); ok {
|
|
return int(id)
|
|
}
|
|
if id, ok := claims["user_id"].(float64); ok {
|
|
return int(id)
|
|
}
|
|
return 0
|
|
}
|
|
|
|
// getUserName returns a display name for the token holder. It tries, in
// order:
//  1. the trimmed "name" claim;
//  2. the trimmed local part (text before the first '@') of the "email"
//     claim;
//  3. "Anonymous".
//
// Whitespace-only values and an empty local part (e.g. "@example.com") fall
// through, so the function never returns an empty author name.
func getUserName(claims jwt.MapClaims) string {
	if name, ok := claims["name"].(string); ok {
		if name = strings.TrimSpace(name); name != "" {
			return name
		}
	}

	if email, ok := claims["email"].(string); ok {
		local := email
		if at := strings.IndexByte(email, '@'); at >= 0 {
			local = email[:at]
		}
		if local = strings.TrimSpace(local); local != "" {
			return local
		}
	}

	return "Anonymous"
}
|
|
|
|
// respondJSON writes data as a JSON body with the given status code and
// Content-Type application/json. The body ends with a newline, as
// json.Encoder.Encode would produce.
//
// data is marshalled before anything is written. If encoding fails, the
// client gets a clean 500 JSON error instead of the intended status with a
// truncated body. Marshal and write failures are logged.
func respondJSON(w http.ResponseWriter, status int, data interface{}) {
	body, err := json.Marshal(data)
	if err != nil {
		log.Printf("respondJSON: encoding response: %v", err)
		w.Header().Set("Content-Type", "application/json")
		w.WriteHeader(http.StatusInternalServerError)
		if _, werr := w.Write([]byte("{\"error\":\"Internal server error\"}\n")); werr != nil {
			log.Printf("respondJSON: writing error response: %v", werr)
		}
		return
	}

	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(status)
	if _, err := w.Write(append(body, '\n')); err != nil {
		log.Printf("respondJSON: writing response: %v", err)
	}
}
|
|
|
|
// respondError sends {"error": message} with the given status via respondJSON.
func respondError(w http.ResponseWriter, status int, message string) {
	payload := map[string]string{"error": message}
	respondJSON(w, status, payload)
}
|
|
|
|
func parseTags(tagsStr string) []string {
|
|
if tagsStr == "" || tagsStr == "{}" {
|
|
return []string{}
|
|
}
|
|
tagsStr = strings.Trim(tagsStr, "{}")
|
|
if tagsStr == "" {
|
|
return []string{}
|
|
}
|
|
return strings.Split(tagsStr, ",")
|
|
}
|
|
|
|
func tagsToPostgres(tags []string) string {
|
|
if len(tags) == 0 {
|
|
return "{}"
|
|
}
|
|
return "{" + strings.Join(tags, ",") + "}"
|
|
}
|
|
|
|
// ============ QUESTION HANDLERS ============
|
|
|
|
// listQuestionsHandler handles GET /questions.
//
// Query parameters:
//   - tag: only questions whose tags contain this value;
//   - status: exact status match;
//   - sort: "votes" (net score), "unanswered" (answer_count = 0, newest
//     first) or anything else for newest first;
//   - limit: 1-100, default 20;
//   - offset: >= 0, default 0.
//
// All filters are bound as SQL parameters. A database, scan or iteration
// error produces a 500 instead of a silently truncated list.
func listQuestionsHandler(w http.ResponseWriter, r *http.Request) {
	enableCORS(w)
	if r.Method == http.MethodOptions {
		w.WriteHeader(http.StatusOK)
		return
	}

	params := r.URL.Query()
	tag := params.Get("tag")
	status := params.Get("status")
	sortBy := params.Get("sort") // "newest", "votes", "unanswered"

	limit := 20
	if l, err := strconv.Atoi(params.Get("limit")); err == nil && l > 0 && l <= 100 {
		limit = l
	}
	offset := 0
	if o, err := strconv.Atoi(params.Get("offset")); err == nil && o >= 0 {
		offset = o
	}

	query := `SELECT id, title, content, author_id, author_name, tags, upvotes, downvotes,
		answer_count, view_count, accepted_answer_id, status, created_at, updated_at
		FROM forum_questions WHERE 1=1`
	var args []interface{}

	if tag != "" {
		args = append(args, tag)
		query += fmt.Sprintf(" AND $%d = ANY(tags)", len(args))
	}
	if status != "" {
		args = append(args, status)
		query += fmt.Sprintf(" AND status = $%d", len(args))
	}

	switch sortBy {
	case "votes":
		query += " ORDER BY (upvotes - downvotes) DESC, created_at DESC"
	case "unanswered":
		query += " AND answer_count = 0 ORDER BY created_at DESC"
	default: // newest
		query += " ORDER BY created_at DESC"
	}

	args = append(args, limit, offset)
	query += fmt.Sprintf(" LIMIT $%d OFFSET $%d", len(args)-1, len(args))

	rows, err := db.QueryContext(r.Context(), query, args...)
	if err != nil {
		log.Println("Database error:", err)
		respondError(w, http.StatusInternalServerError, "Failed to fetch questions")
		return
	}
	defer rows.Close()

	questions := []Question{}
	for rows.Next() {
		var q Question
		var tags string
		var acceptedID sql.NullInt64
		if err := rows.Scan(&q.ID, &q.Title, &q.Content, &q.AuthorID, &q.AuthorName, &tags,
			&q.Upvotes, &q.Downvotes, &q.AnswerCount, &q.ViewCount, &acceptedID, &q.Status,
			&q.CreatedAt, &q.UpdatedAt); err != nil {
			log.Println("Scan error:", err)
			respondError(w, http.StatusInternalServerError, "Failed to fetch questions")
			return
		}
		q.Tags = parseTags(tags)
		if acceptedID.Valid {
			id := int(acceptedID.Int64)
			q.AcceptedID = &id
		}
		questions = append(questions, q)
	}

	if err := rows.Err(); err != nil {
		log.Println("Rows error:", err)
		respondError(w, http.StatusInternalServerError, "Failed to fetch questions")
		return
	}

	respondJSON(w, http.StatusOK, questions)
}
|
|
|
|
// getQuestionHandler handles GET /questions/{id}.
//
// It increments the question's view count (best-effort), then responds with
// {"question": ..., "answers": [...]}. Answers are ordered accepted first,
// then by net score, then oldest first.
//
// Responses: 400 for a non-numeric id, 404 when the question does not
// exist, 500 on any database, scan or iteration error.
func getQuestionHandler(w http.ResponseWriter, r *http.Request) {
	enableCORS(w)
	if r.Method == http.MethodOptions {
		w.WriteHeader(http.StatusOK)
		return
	}

	id, err := strconv.Atoi(strings.TrimPrefix(r.URL.Path, "/questions/"))
	if err != nil {
		respondError(w, http.StatusBadRequest, "Invalid question ID")
		return
	}
	ctx := r.Context()

	// Best-effort: a failed view-count bump must not block reading the question.
	if _, err := db.ExecContext(ctx, "UPDATE forum_questions SET view_count = view_count + 1 WHERE id = $1", id); err != nil {
		log.Println("View count update error:", err)
	}

	var q Question
	var tags string
	var acceptedID sql.NullInt64
	err = db.QueryRowContext(ctx, `SELECT id, title, content, author_id, author_name, tags, upvotes, downvotes,
		answer_count, view_count, accepted_answer_id, status, created_at, updated_at
		FROM forum_questions WHERE id = $1`, id).
		Scan(&q.ID, &q.Title, &q.Content, &q.AuthorID, &q.AuthorName, &tags,
			&q.Upvotes, &q.Downvotes, &q.AnswerCount, &q.ViewCount, &acceptedID, &q.Status,
			&q.CreatedAt, &q.UpdatedAt)
	if errors.Is(err, sql.ErrNoRows) {
		respondError(w, http.StatusNotFound, "Question not found")
		return
	}
	if err != nil {
		log.Println("Database error:", err)
		respondError(w, http.StatusInternalServerError, "Database error")
		return
	}

	q.Tags = parseTags(tags)
	if acceptedID.Valid {
		aid := int(acceptedID.Int64)
		q.AcceptedID = &aid
	}

	rows, err := db.QueryContext(ctx, `SELECT id, question_id, content, author_id, author_name, upvotes, downvotes,
		is_accepted, is_verified, verified_by, verified_at, created_at, updated_at
		FROM forum_answers WHERE question_id = $1
		ORDER BY is_accepted DESC, (upvotes - downvotes) DESC, created_at ASC`, id)
	if err != nil {
		log.Println("Database error:", err)
		respondError(w, http.StatusInternalServerError, "Failed to fetch answers")
		return
	}
	defer rows.Close()

	answers := []Answer{}
	for rows.Next() {
		var a Answer
		var verifiedBy sql.NullInt64
		var verifiedAt sql.NullTime
		if err := rows.Scan(&a.ID, &a.QuestionID, &a.Content, &a.AuthorID, &a.AuthorName,
			&a.Upvotes, &a.Downvotes, &a.IsAccepted, &a.IsVerified, &verifiedBy, &verifiedAt,
			&a.CreatedAt, &a.UpdatedAt); err != nil {
			log.Println("Scan error:", err)
			respondError(w, http.StatusInternalServerError, "Failed to fetch answers")
			return
		}
		if verifiedBy.Valid {
			vb := int(verifiedBy.Int64)
			a.VerifiedBy = &vb
		}
		if verifiedAt.Valid {
			vt := verifiedAt.Time
			a.VerifiedAt = &vt
		}
		answers = append(answers, a)
	}
	if err := rows.Err(); err != nil {
		log.Println("Rows error:", err)
		respondError(w, http.StatusInternalServerError, "Failed to fetch answers")
		return
	}

	respondJSON(w, http.StatusOK, map[string]interface{}{
		"question": q,
		"answers":  answers,
	})
}
|
|
|
|
// createQuestionHandler handles POST /questions for an authenticated user.
//
// The JSON body is capped at maxRequestBody and validated before insert:
//   - title: trimmed, required, at most maxQuestionTitleLength characters;
//   - content: required (not whitespace-only), at most
//     maxQuestionContentLength bytes;
//   - tags: at most maxTagsCount, each trimmed, non-empty and at most
//     maxTagLength characters.
//
// Responds 201 with the created question, 400 on invalid input, 500 on
// database failure.
//
// NOTE(review): getUserID returns 0 for tokens without a numeric user id;
// such users would all share ownership of what they create.
func createQuestionHandler(w http.ResponseWriter, r *http.Request) {
	enableCORS(w)
	claims := r.Context().Value(userContextKey).(jwt.MapClaims)
	userID := getUserID(claims)
	userName := getUserName(claims)

	r.Body = http.MaxBytesReader(w, r.Body, maxRequestBody)
	var req CreateQuestionRequest
	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
		respondError(w, http.StatusBadRequest, "Invalid request body")
		return
	}

	req.Title = strings.TrimSpace(req.Title)
	if req.Title == "" || strings.TrimSpace(req.Content) == "" {
		respondError(w, http.StatusBadRequest, "Title and content are required")
		return
	}
	if len([]rune(req.Title)) > maxQuestionTitleLength {
		respondError(w, http.StatusBadRequest, fmt.Sprintf("Title must be at most %d characters", maxQuestionTitleLength))
		return
	}
	if len(req.Content) > maxQuestionContentLength {
		respondError(w, http.StatusBadRequest, fmt.Sprintf("Content must be at most %d bytes", maxQuestionContentLength))
		return
	}
	if len(req.Tags) > maxTagsCount {
		respondError(w, http.StatusBadRequest, fmt.Sprintf("At most %d tags are allowed", maxTagsCount))
		return
	}
	tags := make([]string, 0, len(req.Tags))
	for _, tag := range req.Tags {
		tag = strings.TrimSpace(tag)
		if tag == "" || len([]rune(tag)) > maxTagLength {
			respondError(w, http.StatusBadRequest, fmt.Sprintf("Each tag must be 1-%d characters", maxTagLength))
			return
		}
		tags = append(tags, tag)
	}

	var q Question
	err := db.QueryRowContext(r.Context(), `INSERT INTO forum_questions (title, content, author_id, author_name, tags)
		VALUES ($1, $2, $3, $4, $5) RETURNING id, created_at, updated_at`,
		req.Title, req.Content, userID, userName, tagsToPostgres(tags)).
		Scan(&q.ID, &q.CreatedAt, &q.UpdatedAt)
	if err != nil {
		log.Println("Database error:", err)
		respondError(w, http.StatusInternalServerError, "Failed to create question")
		return
	}

	q.Title = req.Title
	q.Content = req.Content
	q.AuthorID = userID
	q.AuthorName = userName
	q.Tags = tags // non-nil, so it serialises as [] rather than null
	q.Status = "OPEN"

	log.Printf("AUDIT: User %d created question %d", userID, q.ID)
	respondJSON(w, http.StatusCreated, q)
}
|
|
|
|
// updateQuestionHandler handles PUT /questions/{id}.
//
// Only the question's author or an ADMIN may update it. The body must carry
// the full title and content (validated with the same limits as creation),
// so a partial body can no longer wipe fields to empty strings.
//
// Responses:
//   - 200 on success;
//   - 400 on bad id or invalid body;
//   - 403 when not allowed;
//   - 404 when the question does not exist (or vanished concurrently);
//   - 500 on database failure.
func updateQuestionHandler(w http.ResponseWriter, r *http.Request) {
	enableCORS(w)
	claims := r.Context().Value(userContextKey).(jwt.MapClaims)
	userID := getUserID(claims)
	isAdmin := hasRole(claims, "ADMIN")

	id, err := strconv.Atoi(strings.TrimPrefix(r.URL.Path, "/questions/"))
	if err != nil {
		respondError(w, http.StatusBadRequest, "Invalid question ID")
		return
	}
	ctx := r.Context()

	// Check ownership. Any lookup error must stop here; otherwise authorID
	// stays 0 and the permission check runs against bogus data.
	var authorID int
	err = db.QueryRowContext(ctx, "SELECT author_id FROM forum_questions WHERE id = $1", id).Scan(&authorID)
	if errors.Is(err, sql.ErrNoRows) {
		respondError(w, http.StatusNotFound, "Question not found")
		return
	}
	if err != nil {
		log.Println("Database error:", err)
		respondError(w, http.StatusInternalServerError, "Database error")
		return
	}
	if authorID != userID && !isAdmin {
		respondError(w, http.StatusForbidden, "You can only edit your own questions")
		return
	}

	r.Body = http.MaxBytesReader(w, r.Body, maxRequestBody)
	var req CreateQuestionRequest
	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
		respondError(w, http.StatusBadRequest, "Invalid request body")
		return
	}

	req.Title = strings.TrimSpace(req.Title)
	if req.Title == "" || strings.TrimSpace(req.Content) == "" {
		respondError(w, http.StatusBadRequest, "Title and content are required")
		return
	}
	if len([]rune(req.Title)) > maxQuestionTitleLength {
		respondError(w, http.StatusBadRequest, fmt.Sprintf("Title must be at most %d characters", maxQuestionTitleLength))
		return
	}
	if len(req.Content) > maxQuestionContentLength {
		respondError(w, http.StatusBadRequest, fmt.Sprintf("Content must be at most %d bytes", maxQuestionContentLength))
		return
	}
	if len(req.Tags) > maxTagsCount {
		respondError(w, http.StatusBadRequest, fmt.Sprintf("At most %d tags are allowed", maxTagsCount))
		return
	}
	tags := make([]string, 0, len(req.Tags))
	for _, tag := range req.Tags {
		tag = strings.TrimSpace(tag)
		if tag == "" || len([]rune(tag)) > maxTagLength {
			respondError(w, http.StatusBadRequest, fmt.Sprintf("Each tag must be 1-%d characters", maxTagLength))
			return
		}
		tags = append(tags, tag)
	}

	res, err := db.ExecContext(ctx, `UPDATE forum_questions SET title = $1, content = $2, tags = $3, updated_at = CURRENT_TIMESTAMP
		WHERE id = $4`, req.Title, req.Content, tagsToPostgres(tags), id)
	if err != nil {
		log.Println("Database error:", err)
		respondError(w, http.StatusInternalServerError, "Failed to update question")
		return
	}
	if n, err := res.RowsAffected(); err == nil && n == 0 {
		respondError(w, http.StatusNotFound, "Question not found")
		return
	}

	log.Printf("AUDIT: User %d updated question %d", userID, id)
	respondJSON(w, http.StatusOK, map[string]string{"message": "Question updated"})
}
|
|
|
|
// deleteQuestionHandler handles DELETE /questions/{id}.
//
// Only the question's author or an ADMIN may delete it. Its answers are
// removed by the ON DELETE CASCADE foreign key on forum_answers.
//
// Responses: 200 on success, 400 on bad id, 403 when not allowed, 404 when
// the question does not exist, 500 on database failure.
func deleteQuestionHandler(w http.ResponseWriter, r *http.Request) {
	enableCORS(w)
	claims := r.Context().Value(userContextKey).(jwt.MapClaims)
	userID := getUserID(claims)
	isAdmin := hasRole(claims, "ADMIN")

	id, err := strconv.Atoi(strings.TrimPrefix(r.URL.Path, "/questions/"))
	if err != nil {
		respondError(w, http.StatusBadRequest, "Invalid question ID")
		return
	}
	ctx := r.Context()

	// Check ownership; treat any lookup failure as fatal for the request.
	var authorID int
	err = db.QueryRowContext(ctx, "SELECT author_id FROM forum_questions WHERE id = $1", id).Scan(&authorID)
	if errors.Is(err, sql.ErrNoRows) {
		respondError(w, http.StatusNotFound, "Question not found")
		return
	}
	if err != nil {
		log.Println("Database error:", err)
		respondError(w, http.StatusInternalServerError, "Database error")
		return
	}
	if authorID != userID && !isAdmin {
		respondError(w, http.StatusForbidden, "You can only delete your own questions")
		return
	}

	res, err := db.ExecContext(ctx, "DELETE FROM forum_questions WHERE id = $1", id)
	if err != nil {
		log.Println("Database error:", err)
		respondError(w, http.StatusInternalServerError, "Failed to delete question")
		return
	}
	if n, err := res.RowsAffected(); err == nil && n == 0 {
		respondError(w, http.StatusNotFound, "Question not found")
		return
	}

	log.Printf("AUDIT: User %d deleted question %d", userID, id)
	respondJSON(w, http.StatusOK, map[string]string{"message": "Question deleted"})
}
|
|
|
|
// ============ ANSWER HANDLERS ============
|
|
|
|
// createAnswerHandler handles POST /questions/{id}/answers.
//
// It rejects answers to missing (404) or CLOSED (400) questions. Content
// must not be whitespace-only and may be at most maxAnswerContentLength
// bytes; the body is capped at maxRequestBody.
//
// The insert and the question's answer_count/status update run in one
// transaction, so the counter cannot drift from the real number of answers.
// Responds 201 with the created answer.
func createAnswerHandler(w http.ResponseWriter, r *http.Request) {
	enableCORS(w)
	claims := r.Context().Value(userContextKey).(jwt.MapClaims)
	userID := getUserID(claims)
	userName := getUserName(claims)

	// Extract question ID from path like /questions/123/answers
	parts := strings.Split(strings.TrimPrefix(r.URL.Path, "/questions/"), "/")
	if len(parts) < 2 {
		respondError(w, http.StatusBadRequest, "Invalid path")
		return
	}
	questionID, err := strconv.Atoi(parts[0])
	if err != nil {
		respondError(w, http.StatusBadRequest, "Invalid question ID")
		return
	}
	ctx := r.Context()

	var status string
	err = db.QueryRowContext(ctx, "SELECT status FROM forum_questions WHERE id = $1", questionID).Scan(&status)
	if errors.Is(err, sql.ErrNoRows) {
		respondError(w, http.StatusNotFound, "Question not found")
		return
	}
	if err != nil {
		log.Println("Database error:", err)
		respondError(w, http.StatusInternalServerError, "Database error")
		return
	}
	if status == "CLOSED" {
		respondError(w, http.StatusBadRequest, "This question is closed")
		return
	}

	r.Body = http.MaxBytesReader(w, r.Body, maxRequestBody)
	var req CreateAnswerRequest
	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
		respondError(w, http.StatusBadRequest, "Invalid request body")
		return
	}
	if strings.TrimSpace(req.Content) == "" {
		respondError(w, http.StatusBadRequest, "Content is required")
		return
	}
	if len(req.Content) > maxAnswerContentLength {
		respondError(w, http.StatusBadRequest, fmt.Sprintf("Content must be at most %d bytes", maxAnswerContentLength))
		return
	}

	tx, err := db.BeginTx(ctx, nil)
	if err != nil {
		log.Println("Database error:", err)
		respondError(w, http.StatusInternalServerError, "Failed to create answer")
		return
	}
	defer func() { _ = tx.Rollback() }() // no-op once committed

	var a Answer
	err = tx.QueryRowContext(ctx, `INSERT INTO forum_answers (question_id, content, author_id, author_name)
		VALUES ($1, $2, $3, $4) RETURNING id, created_at, updated_at`,
		questionID, req.Content, userID, userName).
		Scan(&a.ID, &a.CreatedAt, &a.UpdatedAt)
	if err != nil {
		log.Println("Database error:", err)
		respondError(w, http.StatusInternalServerError, "Failed to create answer")
		return
	}

	if _, err := tx.ExecContext(ctx,
		"UPDATE forum_questions SET answer_count = answer_count + 1, status = 'ANSWERED' WHERE id = $1",
		questionID); err != nil {
		log.Println("Database error:", err)
		respondError(w, http.StatusInternalServerError, "Failed to create answer")
		return
	}
	if err := tx.Commit(); err != nil {
		log.Println("Database error:", err)
		respondError(w, http.StatusInternalServerError, "Failed to create answer")
		return
	}

	a.QuestionID = questionID
	a.Content = req.Content
	a.AuthorID = userID
	a.AuthorName = userName

	log.Printf("AUDIT: User %d answered question %d", userID, questionID)
	respondJSON(w, http.StatusCreated, a)
}
|
|
|
|
// updateAnswerHandler handles PUT /answers/{id}.
//
// Only the answer's author or an ADMIN may edit it. Content must not be
// whitespace-only and may be at most maxAnswerContentLength bytes; the body
// is capped at maxRequestBody.
//
// Responses: 200 on success, 400 on bad id or body, 403 when not allowed,
// 404 when the answer does not exist, 500 on database failure.
func updateAnswerHandler(w http.ResponseWriter, r *http.Request) {
	enableCORS(w)
	claims := r.Context().Value(userContextKey).(jwt.MapClaims)
	userID := getUserID(claims)
	isAdmin := hasRole(claims, "ADMIN")

	id, err := strconv.Atoi(strings.TrimPrefix(r.URL.Path, "/answers/"))
	if err != nil {
		respondError(w, http.StatusBadRequest, "Invalid answer ID")
		return
	}
	ctx := r.Context()

	// Check ownership; any lookup failure stops the request.
	var authorID int
	err = db.QueryRowContext(ctx, "SELECT author_id FROM forum_answers WHERE id = $1", id).Scan(&authorID)
	if errors.Is(err, sql.ErrNoRows) {
		respondError(w, http.StatusNotFound, "Answer not found")
		return
	}
	if err != nil {
		log.Println("Database error:", err)
		respondError(w, http.StatusInternalServerError, "Database error")
		return
	}
	if authorID != userID && !isAdmin {
		respondError(w, http.StatusForbidden, "You can only edit your own answers")
		return
	}

	r.Body = http.MaxBytesReader(w, r.Body, maxRequestBody)
	var req CreateAnswerRequest
	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
		respondError(w, http.StatusBadRequest, "Invalid request body")
		return
	}
	if strings.TrimSpace(req.Content) == "" {
		respondError(w, http.StatusBadRequest, "Content is required")
		return
	}
	if len(req.Content) > maxAnswerContentLength {
		respondError(w, http.StatusBadRequest, fmt.Sprintf("Content must be at most %d bytes", maxAnswerContentLength))
		return
	}

	res, err := db.ExecContext(ctx,
		`UPDATE forum_answers SET content = $1, updated_at = CURRENT_TIMESTAMP WHERE id = $2`,
		req.Content, id)
	if err != nil {
		log.Println("Database error:", err)
		respondError(w, http.StatusInternalServerError, "Failed to update answer")
		return
	}
	if n, err := res.RowsAffected(); err == nil && n == 0 {
		respondError(w, http.StatusNotFound, "Answer not found")
		return
	}

	log.Printf("AUDIT: User %d updated answer %d", userID, id)
	respondJSON(w, http.StatusOK, map[string]string{"message": "Answer updated"})
}
|
|
|
|
// deleteAnswerHandler handles DELETE /answers/{id}.
//
// Only the answer's author or an ADMIN may delete it. In a single
// transaction it deletes the answer, then updates the parent question:
//   - answer_count is decremented, never below 0;
//   - accepted_answer_id is cleared if it pointed at the deleted answer, so
//     the question never references a missing answer.
//
// Responses: 200 on success, 400 on bad id, 403 when not allowed, 404 when
// the answer does not exist, 500 on database failure.
func deleteAnswerHandler(w http.ResponseWriter, r *http.Request) {
	enableCORS(w)
	claims := r.Context().Value(userContextKey).(jwt.MapClaims)
	userID := getUserID(claims)
	isAdmin := hasRole(claims, "ADMIN")

	id, err := strconv.Atoi(strings.TrimPrefix(r.URL.Path, "/answers/"))
	if err != nil {
		respondError(w, http.StatusBadRequest, "Invalid answer ID")
		return
	}
	ctx := r.Context()

	// Check ownership and get question ID; any lookup failure stops here.
	var authorID, questionID int
	err = db.QueryRowContext(ctx, "SELECT author_id, question_id FROM forum_answers WHERE id = $1", id).
		Scan(&authorID, &questionID)
	if errors.Is(err, sql.ErrNoRows) {
		respondError(w, http.StatusNotFound, "Answer not found")
		return
	}
	if err != nil {
		log.Println("Database error:", err)
		respondError(w, http.StatusInternalServerError, "Database error")
		return
	}
	if authorID != userID && !isAdmin {
		respondError(w, http.StatusForbidden, "You can only delete your own answers")
		return
	}

	tx, err := db.BeginTx(ctx, nil)
	if err != nil {
		log.Println("Database error:", err)
		respondError(w, http.StatusInternalServerError, "Failed to delete answer")
		return
	}
	defer func() { _ = tx.Rollback() }() // no-op once committed

	res, err := tx.ExecContext(ctx, "DELETE FROM forum_answers WHERE id = $1", id)
	if err != nil {
		log.Println("Database error:", err)
		respondError(w, http.StatusInternalServerError, "Failed to delete answer")
		return
	}
	if n, err := res.RowsAffected(); err == nil && n == 0 {
		respondError(w, http.StatusNotFound, "Answer not found")
		return
	}

	if _, err := tx.ExecContext(ctx, `UPDATE forum_questions
		SET answer_count = GREATEST(answer_count - 1, 0),
			accepted_answer_id = NULLIF(accepted_answer_id, $2)
		WHERE id = $1`, questionID, id); err != nil {
		log.Println("Database error:", err)
		respondError(w, http.StatusInternalServerError, "Failed to delete answer")
		return
	}
	if err := tx.Commit(); err != nil {
		log.Println("Database error:", err)
		respondError(w, http.StatusInternalServerError, "Failed to delete answer")
		return
	}

	log.Printf("AUDIT: User %d deleted answer %d", userID, id)
	respondJSON(w, http.StatusOK, map[string]string{"message": "Answer deleted"})
}
|
|
|
|
// acceptAnswerHandler handles POST /answers/{id}/accept.
//
// Only the author of the answer's question may accept. In one transaction
// it:
//   - clears is_accepted on every answer to the question;
//   - marks this answer accepted;
//   - records it as the question's accepted_answer_id and sets status
//     ANSWERED.
//
// Either all of these take effect or none do.
//
// Responses: 200 on success, 400 on bad id, 403 when not the question
// author, 404 when the answer or question is missing, 500 on database
// failure.
func acceptAnswerHandler(w http.ResponseWriter, r *http.Request) {
	enableCORS(w)
	claims := r.Context().Value(userContextKey).(jwt.MapClaims)
	userID := getUserID(claims)

	idStr := strings.TrimSuffix(strings.TrimPrefix(r.URL.Path, "/answers/"), "/accept")
	id, err := strconv.Atoi(idStr)
	if err != nil {
		respondError(w, http.StatusBadRequest, "Invalid answer ID")
		return
	}
	ctx := r.Context()

	var questionID int
	err = db.QueryRowContext(ctx, "SELECT question_id FROM forum_answers WHERE id = $1", id).Scan(&questionID)
	if errors.Is(err, sql.ErrNoRows) {
		respondError(w, http.StatusNotFound, "Answer not found")
		return
	}
	if err != nil {
		log.Println("Database error:", err)
		respondError(w, http.StatusInternalServerError, "Database error")
		return
	}

	// Check if user owns the question
	var questionAuthorID int
	err = db.QueryRowContext(ctx, "SELECT author_id FROM forum_questions WHERE id = $1", questionID).Scan(&questionAuthorID)
	if errors.Is(err, sql.ErrNoRows) {
		respondError(w, http.StatusNotFound, "Question not found")
		return
	}
	if err != nil {
		log.Println("Database error:", err)
		respondError(w, http.StatusInternalServerError, "Database error")
		return
	}
	if questionAuthorID != userID {
		respondError(w, http.StatusForbidden, "Only the question author can accept answers")
		return
	}

	tx, err := db.BeginTx(ctx, nil)
	if err != nil {
		log.Println("Database error:", err)
		respondError(w, http.StatusInternalServerError, "Failed to accept answer")
		return
	}
	defer func() { _ = tx.Rollback() }() // no-op once committed

	steps := []struct {
		query string
		args  []interface{}
	}{
		// Unaccept any previously accepted answer
		{"UPDATE forum_answers SET is_accepted = FALSE WHERE question_id = $1", []interface{}{questionID}},
		// Accept this answer
		{"UPDATE forum_answers SET is_accepted = TRUE WHERE id = $1", []interface{}{id}},
		{"UPDATE forum_questions SET accepted_answer_id = $1, status = 'ANSWERED' WHERE id = $2", []interface{}{id, questionID}},
	}
	for _, step := range steps {
		if _, err := tx.ExecContext(ctx, step.query, step.args...); err != nil {
			log.Println("Database error:", err)
			respondError(w, http.StatusInternalServerError, "Failed to accept answer")
			return
		}
	}
	if err := tx.Commit(); err != nil {
		log.Println("Database error:", err)
		respondError(w, http.StatusInternalServerError, "Failed to accept answer")
		return
	}

	log.Printf("AUDIT: User %d accepted answer %d for question %d", userID, id, questionID)
	respondJSON(w, http.StatusOK, map[string]string{"message": "Answer accepted"})
}
|
|
|
|
// verifyAnswerHandler handles POST /answers/{id}/verify.
//
// It marks the answer as verified, recording the caller's user id and the
// current timestamp. The handler itself does not check roles; presumably it
// is wrapped with an ADMIN role check at registration — confirm at the
// routing site.
//
// Responses: 200 on success, 400 on bad id, 404 when no such answer exists,
// 500 on database failure.
func verifyAnswerHandler(w http.ResponseWriter, r *http.Request) {
	enableCORS(w)
	claims := r.Context().Value(userContextKey).(jwt.MapClaims)
	userID := getUserID(claims)

	idStr := strings.TrimSuffix(strings.TrimPrefix(r.URL.Path, "/answers/"), "/verify")
	id, err := strconv.Atoi(idStr)
	if err != nil {
		respondError(w, http.StatusBadRequest, "Invalid answer ID")
		return
	}

	res, err := db.ExecContext(r.Context(), `UPDATE forum_answers SET is_verified = TRUE, verified_by = $1,
		verified_at = CURRENT_TIMESTAMP WHERE id = $2`, userID, id)
	if err != nil {
		log.Println("Database error:", err)
		respondError(w, http.StatusInternalServerError, "Failed to verify answer")
		return
	}
	// Without this check, verifying a nonexistent answer reported success.
	if n, err := res.RowsAffected(); err == nil && n == 0 {
		respondError(w, http.StatusNotFound, "Answer not found")
		return
	}

	log.Printf("AUDIT: Admin %d verified answer %d", userID, id)
	respondJSON(w, http.StatusOK, map[string]string{"message": "Answer verified"})
}
|
|
|
|
// ============ VOTING HANDLERS ============
|
|
|
|
// voteQuestionHandler handles POST /questions/{id}/vote with body
// {"voteType": 1 | -1}.
//
// Each user holds at most one vote per question:
//   - a first vote inserts a forum_votes row and bumps the matching counter;
//   - switching sides moves one unit between upvotes and downvotes;
//   - repeating the same vote is a no-op.
//
// Everything runs in one transaction that first locks the question row
// (SELECT ... FOR UPDATE). This serialises concurrent votes on the question,
// so a double-submitted first vote cannot be counted twice. The lock also
// yields a 404 for a missing question.
//
// Responses: 200 "Vote recorded", 400 on bad id or body, 404 for a missing
// question, 500 on database failure.
func voteQuestionHandler(w http.ResponseWriter, r *http.Request) {
	enableCORS(w)
	claims := r.Context().Value(userContextKey).(jwt.MapClaims)
	userID := getUserID(claims)

	idStr := strings.TrimSuffix(strings.TrimPrefix(r.URL.Path, "/questions/"), "/vote")
	id, err := strconv.Atoi(idStr)
	if err != nil {
		respondError(w, http.StatusBadRequest, "Invalid question ID")
		return
	}

	r.Body = http.MaxBytesReader(w, r.Body, maxRequestBody)
	var req VoteRequest
	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
		respondError(w, http.StatusBadRequest, "Invalid request body")
		return
	}
	if req.VoteType != 1 && req.VoteType != -1 {
		respondError(w, http.StatusBadRequest, "VoteType must be 1 (upvote) or -1 (downvote)")
		return
	}

	ctx := r.Context()
	fail := func(err error) {
		log.Println("Database error:", err)
		respondError(w, http.StatusInternalServerError, "Failed to record vote")
	}

	tx, err := db.BeginTx(ctx, nil)
	if err != nil {
		fail(err)
		return
	}
	defer func() { _ = tx.Rollback() }() // no-op once committed

	var lockedID int
	err = tx.QueryRowContext(ctx, "SELECT id FROM forum_questions WHERE id = $1 FOR UPDATE", id).Scan(&lockedID)
	if errors.Is(err, sql.ErrNoRows) {
		respondError(w, http.StatusNotFound, "Question not found")
		return
	}
	if err != nil {
		fail(err)
		return
	}

	// Check if user already voted
	var existingVote int
	err = tx.QueryRowContext(ctx,
		"SELECT vote_type FROM forum_votes WHERE user_id = $1 AND target_type = 'question' AND target_id = $2",
		userID, id).Scan(&existingVote)

	type statement struct {
		query string
		args  []interface{}
	}
	var plan []statement

	switch {
	case errors.Is(err, sql.ErrNoRows):
		// New vote
		counterSQL := "UPDATE forum_questions SET upvotes = upvotes + 1 WHERE id = $1"
		if req.VoteType != 1 {
			counterSQL = "UPDATE forum_questions SET downvotes = downvotes + 1 WHERE id = $1"
		}
		plan = []statement{
			{"INSERT INTO forum_votes (user_id, target_type, target_id, vote_type) VALUES ($1, 'question', $2, $3)",
				[]interface{}{userID, id, req.VoteType}},
			{counterSQL, []interface{}{id}},
		}
	case err != nil:
		fail(err)
		return
	case existingVote != req.VoteType:
		// Changing vote
		counterSQL := "UPDATE forum_questions SET upvotes = upvotes + 1, downvotes = downvotes - 1 WHERE id = $1"
		if req.VoteType != 1 {
			counterSQL = "UPDATE forum_questions SET upvotes = upvotes - 1, downvotes = downvotes + 1 WHERE id = $1"
		}
		plan = []statement{
			{"UPDATE forum_votes SET vote_type = $1 WHERE user_id = $2 AND target_type = 'question' AND target_id = $3",
				[]interface{}{req.VoteType, userID, id}},
			{counterSQL, []interface{}{id}},
		}
	default:
		// Same vote as before: nothing to change.
	}

	for _, stmt := range plan {
		if _, err := tx.ExecContext(ctx, stmt.query, stmt.args...); err != nil {
			fail(err)
			return
		}
	}
	if err := tx.Commit(); err != nil {
		fail(err)
		return
	}

	respondJSON(w, http.StatusOK, map[string]string{"message": "Vote recorded"})
}
|
|
|
|
// POST /answers/:id/vote - Vote on answer
|
|
func voteAnswerHandler(w http.ResponseWriter, r *http.Request) {
|
|
enableCORS(w)
|
|
claims := r.Context().Value(userContextKey).(jwt.MapClaims)
|
|
userID := getUserID(claims)
|
|
|
|
idStr := strings.TrimPrefix(r.URL.Path, "/answers/")
|
|
idStr = strings.TrimSuffix(idStr, "/vote")
|
|
id, err := strconv.Atoi(idStr)
|
|
if err != nil {
|
|
respondError(w, http.StatusBadRequest, "Invalid answer ID")
|
|
return
|
|
}
|
|
|
|
var req VoteRequest
|
|
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
|
respondError(w, http.StatusBadRequest, "Invalid request body")
|
|
return
|
|
}
|
|
|
|
if req.VoteType != 1 && req.VoteType != -1 {
|
|
respondError(w, http.StatusBadRequest, "VoteType must be 1 or -1")
|
|
return
|
|
}
|
|
|
|
// Check if user already voted
|
|
var existingVote int
|
|
err = db.QueryRow("SELECT vote_type FROM forum_votes WHERE user_id = $1 AND target_type = 'answer' AND target_id = $2",
|
|
userID, id).Scan(&existingVote)
|
|
|
|
if err == sql.ErrNoRows {
|
|
// New vote
|
|
db.Exec("INSERT INTO forum_votes (user_id, target_type, target_id, vote_type) VALUES ($1, 'answer', $2, $3)",
|
|
userID, id, req.VoteType)
|
|
if req.VoteType == 1 {
|
|
db.Exec("UPDATE forum_answers SET upvotes = upvotes + 1 WHERE id = $1", id)
|
|
} else {
|
|
db.Exec("UPDATE forum_answers SET downvotes = downvotes + 1 WHERE id = $1", id)
|
|
}
|
|
} else if existingVote != req.VoteType {
|
|
// Changing vote
|
|
db.Exec("UPDATE forum_votes SET vote_type = $1 WHERE user_id = $2 AND target_type = 'answer' AND target_id = $3",
|
|
req.VoteType, userID, id)
|
|
if req.VoteType == 1 {
|
|
db.Exec("UPDATE forum_answers SET upvotes = upvotes + 1, downvotes = downvotes - 1 WHERE id = $1", id)
|
|
} else {
|
|
db.Exec("UPDATE forum_answers SET upvotes = upvotes - 1, downvotes = downvotes + 1 WHERE id = $1", id)
|
|
}
|
|
}
|
|
|
|
respondJSON(w, http.StatusOK, map[string]string{"message": "Vote recorded"})
|
|
}
|
|
|
|
// POST /questions/:id/close - Close question (ADMIN only)
|
|
func closeQuestionHandler(w http.ResponseWriter, r *http.Request) {
|
|
enableCORS(w)
|
|
claims := r.Context().Value(userContextKey).(jwt.MapClaims)
|
|
userID := getUserID(claims)
|
|
|
|
idStr := strings.TrimPrefix(r.URL.Path, "/questions/")
|
|
idStr = strings.TrimSuffix(idStr, "/close")
|
|
id, err := strconv.Atoi(idStr)
|
|
if err != nil {
|
|
respondError(w, http.StatusBadRequest, "Invalid question ID")
|
|
return
|
|
}
|
|
|
|
_, err = db.Exec("UPDATE forum_questions SET status = 'CLOSED', updated_at = CURRENT_TIMESTAMP WHERE id = $1", id)
|
|
if err != nil {
|
|
respondError(w, http.StatusInternalServerError, "Failed to close question")
|
|
return
|
|
}
|
|
|
|
log.Printf("AUDIT: Admin %d closed question %d", userID, id)
|
|
respondJSON(w, http.StatusOK, map[string]string{"message": "Question closed"})
|
|
}
|
|
|
|
func main() {
|
|
loadConfig()
|
|
initDB()
|
|
defer db.Close()
|
|
|
|
// Public endpoints
|
|
http.HandleFunc("/questions", func(w http.ResponseWriter, r *http.Request) {
|
|
enableCORS(w)
|
|
if r.Method == http.MethodOptions {
|
|
w.WriteHeader(http.StatusOK)
|
|
return
|
|
}
|
|
if r.Method == http.MethodGet {
|
|
listQuestionsHandler(w, r)
|
|
} else if r.Method == http.MethodPost {
|
|
authMiddleware(createQuestionHandler)(w, r)
|
|
} else {
|
|
respondError(w, http.StatusMethodNotAllowed, "Method not allowed")
|
|
}
|
|
})
|
|
|
|
http.HandleFunc("/questions/", func(w http.ResponseWriter, r *http.Request) {
|
|
enableCORS(w)
|
|
if r.Method == http.MethodOptions {
|
|
w.WriteHeader(http.StatusOK)
|
|
return
|
|
}
|
|
|
|
path := r.URL.Path
|
|
|
|
switch {
|
|
case strings.HasSuffix(path, "/answers"):
|
|
if r.Method == http.MethodPost {
|
|
authMiddleware(createAnswerHandler)(w, r)
|
|
} else {
|
|
respondError(w, http.StatusMethodNotAllowed, "Method not allowed")
|
|
}
|
|
case strings.HasSuffix(path, "/vote"):
|
|
if r.Method == http.MethodPost {
|
|
authMiddleware(voteQuestionHandler)(w, r)
|
|
} else {
|
|
respondError(w, http.StatusMethodNotAllowed, "Method not allowed")
|
|
}
|
|
case strings.HasSuffix(path, "/close"):
|
|
if r.Method == http.MethodPost {
|
|
requireRole(closeQuestionHandler, "ADMIN")(w, r)
|
|
} else {
|
|
respondError(w, http.StatusMethodNotAllowed, "Method not allowed")
|
|
}
|
|
default:
|
|
if r.Method == http.MethodGet {
|
|
getQuestionHandler(w, r)
|
|
} else if r.Method == http.MethodPut {
|
|
authMiddleware(updateQuestionHandler)(w, r)
|
|
} else if r.Method == http.MethodDelete {
|
|
authMiddleware(deleteQuestionHandler)(w, r)
|
|
} else {
|
|
respondError(w, http.StatusMethodNotAllowed, "Method not allowed")
|
|
}
|
|
}
|
|
})
|
|
|
|
http.HandleFunc("/answers/", func(w http.ResponseWriter, r *http.Request) {
|
|
enableCORS(w)
|
|
if r.Method == http.MethodOptions {
|
|
w.WriteHeader(http.StatusOK)
|
|
return
|
|
}
|
|
|
|
path := r.URL.Path
|
|
|
|
switch {
|
|
case strings.HasSuffix(path, "/accept"):
|
|
if r.Method == http.MethodPost {
|
|
authMiddleware(acceptAnswerHandler)(w, r)
|
|
} else {
|
|
respondError(w, http.StatusMethodNotAllowed, "Method not allowed")
|
|
}
|
|
case strings.HasSuffix(path, "/verify"):
|
|
if r.Method == http.MethodPost {
|
|
requireRole(verifyAnswerHandler, "ADMIN")(w, r)
|
|
} else {
|
|
respondError(w, http.StatusMethodNotAllowed, "Method not allowed")
|
|
}
|
|
case strings.HasSuffix(path, "/vote"):
|
|
if r.Method == http.MethodPost {
|
|
authMiddleware(voteAnswerHandler)(w, r)
|
|
} else {
|
|
respondError(w, http.StatusMethodNotAllowed, "Method not allowed")
|
|
}
|
|
default:
|
|
if r.Method == http.MethodPut {
|
|
authMiddleware(updateAnswerHandler)(w, r)
|
|
} else if r.Method == http.MethodDelete {
|
|
authMiddleware(deleteAnswerHandler)(w, r)
|
|
} else {
|
|
respondError(w, http.StatusMethodNotAllowed, "Method not allowed")
|
|
}
|
|
}
|
|
})
|
|
|
|
// Health check (both /health and /healthz for compatibility)
|
|
healthHandler := func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(http.StatusOK)
|
|
fmt.Fprintln(w, "ok")
|
|
}
|
|
http.HandleFunc("/health", healthHandler)
|
|
http.HandleFunc("/healthz", healthHandler)
|
|
|
|
port := os.Getenv("PORT")
|
|
if port == "" {
|
|
port = "8080"
|
|
}
|
|
|
|
// Wrap all routes with rate limiting and body size limit
|
|
rateLimitedHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
// Limit request body size to prevent DoS
|
|
if r.Body != nil {
|
|
r.Body = http.MaxBytesReader(w, r.Body, maxRequestBody)
|
|
}
|
|
|
|
clientIP := getClientIP(r)
|
|
|
|
if r.Method == http.MethodGet || r.Method == http.MethodHead {
|
|
if readLimiter.checkRateLimit(clientIP, maxReadRequests, rateLimitWindow) {
|
|
log.Printf("SECURITY: Read rate limit exceeded for IP %s on %s", clientIP, r.URL.Path)
|
|
http.Error(w, "Too many requests. Please slow down.", http.StatusTooManyRequests)
|
|
return
|
|
}
|
|
} else if r.Method != http.MethodOptions {
|
|
if writeLimiter.checkRateLimit(clientIP, maxWriteRequests, rateLimitWindow) {
|
|
log.Printf("SECURITY: Write rate limit exceeded for IP %s on %s", clientIP, r.URL.Path)
|
|
http.Error(w, "Too many requests. Please slow down.", http.StatusTooManyRequests)
|
|
return
|
|
}
|
|
}
|
|
|
|
http.DefaultServeMux.ServeHTTP(w, r)
|
|
})
|
|
|
|
server := &http.Server{
|
|
Addr: ":" + port,
|
|
Handler: rateLimitedHandler,
|
|
ReadHeaderTimeout: 10 * time.Second,
|
|
ReadTimeout: 15 * time.Second,
|
|
WriteTimeout: 15 * time.Second,
|
|
IdleTimeout: 60 * time.Second,
|
|
}
|
|
|
|
// Graceful shutdown
|
|
done := make(chan bool, 1)
|
|
quit := make(chan os.Signal, 1)
|
|
signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM)
|
|
|
|
go func() {
|
|
<-quit
|
|
log.Println("Forum Service shutting down...")
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
|
defer cancel()
|
|
|
|
server.SetKeepAlivesEnabled(false)
|
|
if err := server.Shutdown(ctx); err != nil {
|
|
log.Printf("Could not gracefully shutdown: %v", err)
|
|
}
|
|
close(done)
|
|
}()
|
|
|
|
log.Printf("Forum service starting on port %s\n", port)
|
|
if err := server.ListenAndServe(); err != nil && err != http.ErrServerClosed {
|
|
log.Fatalf("Server error: %v", err)
|
|
}
|
|
|
|
<-done
|
|
log.Println("Forum Service stopped")
|
|
}
|