package main import ( "context" "database/sql" "encoding/json" "fmt" "log" "net" "net/http" "os" "os/signal" "regexp" "strings" "sync" "syscall" "time" "github.com/golang-jwt/jwt/v5" _ "github.com/lib/pq" ) // Rate limiting configuration const ( rateLimitWindow = 1 * time.Minute maxSubmitRequests = 5 // Max contact form submissions per minute per IP maxReadRequests = 50 // Max read requests per minute per IP maxRequestBody = 1 << 20 // 1MB max request body size ) type rateLimiter struct { mu sync.RWMutex requests map[string]*requestInfo } type requestInfo struct { count int firstReq time.Time } var submitLimiter = &rateLimiter{requests: make(map[string]*requestInfo)} var readLimiter = &rateLimiter{requests: make(map[string]*requestInfo)} func (rl *rateLimiter) checkRateLimit(key string, maxRequests int, window time.Duration) bool { rl.mu.Lock() defer rl.mu.Unlock() now := time.Now() info, exists := rl.requests[key] if !exists { rl.requests[key] = &requestInfo{count: 1, firstReq: now} return false } if now.Sub(info.firstReq) > window { info.count = 1 info.firstReq = now return false } info.count++ return info.count > maxRequests } func getClientIP(r *http.Request) string { xff := r.Header.Get("X-Forwarded-For") if xff != "" { if idx := strings.Index(xff, ","); idx != -1 { return strings.TrimSpace(xff[:idx]) } return strings.TrimSpace(xff) } xri := r.Header.Get("X-Real-IP") if xri != "" { return strings.TrimSpace(xri) } ip, _, err := net.SplitHostPort(r.RemoteAddr) if err != nil { return r.RemoteAddr } return ip } var jwtSecret []byte var db *sql.DB // ContactSubmission represents a contact form submission type ContactSubmission struct { ID int `json:"id"` Name string `json:"name"` Email string `json:"email"` Phone string `json:"phone,omitempty"` Subject string `json:"subject,omitempty"` Message string `json:"message"` Status string `json:"status"` // NEW, READ, REPLIED, ARCHIVED CreatedAt time.Time `json:"createdAt"` UpdatedAt time.Time `json:"updatedAt"` } // ContactResponse for API responses type ContactResponse struct { Success bool `json:"success"` Message string `json:"message"` ID int `json:"id,omitempty"` } func main() { port := os.Getenv("PORT") if port == "" { port = "8080" } // Initialize JWT secret (required for authentication) secret := os.Getenv("JWT_SECRET") if secret == "" || len(secret) < 32 { log.Fatal("JWT_SECRET must be set and at least 32 characters") } jwtSecret = []byte(secret) // Initialize database if err := initDB(); err != nil { log.Fatalf("Failed to initialize database: %v", err) } defer db.Close() // Setup routes mux := http.NewServeMux() // Public endpoints mux.HandleFunc("/submit", corsMiddleware(rateLimitSubmit(submitHandler))) mux.HandleFunc("/health", corsMiddleware(healthHandler)) mux.HandleFunc("/healthz", corsMiddleware(healthHandler)) // Admin endpoints (protected) mux.HandleFunc("/submissions", corsMiddleware(authMiddleware(listSubmissionsHandler))) mux.HandleFunc("/submissions/", corsMiddleware(authMiddleware(submissionHandler))) // Wrap with rate limiting and body size limit for all requests rateLimitedMux := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { // Limit request body size to prevent DoS if r.Body != nil { r.Body = http.MaxBytesReader(w, r.Body, maxRequestBody) } clientIP := getClientIP(r) // Read requests (GET) have higher limits if r.Method == http.MethodGet { if readLimiter.checkRateLimit(clientIP, maxReadRequests, rateLimitWindow) { log.Printf("SECURITY: Read rate limit exceeded for IP %s on %s", clientIP, r.URL.Path) http.Error(w, "Too many requests. Please slow down.", http.StatusTooManyRequests) return } } mux.ServeHTTP(w, r) }) server := &http.Server{ Addr: ":" + port, Handler: rateLimitedMux, ReadHeaderTimeout: 10 * time.Second, ReadTimeout: 15 * time.Second, WriteTimeout: 15 * time.Second, IdleTimeout: 60 * time.Second, } // Graceful shutdown done := make(chan bool, 1) quit := make(chan os.Signal, 1) signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM) go func() { <-quit log.Println("Contact Service shutting down...") ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) defer cancel() server.SetKeepAlivesEnabled(false) if err := server.Shutdown(ctx); err != nil { log.Printf("Could not gracefully shutdown: %v", err) } close(done) }() log.Printf("Contact Service starting on port %s", port) if err := server.ListenAndServe(); err != nil && err != http.ErrServerClosed { log.Fatalf("Server error: %v", err) } <-done log.Println("Contact Service stopped") } // rateLimitSubmit applies strict rate limiting for contact form submissions func rateLimitSubmit(next http.HandlerFunc) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { if r.Method == http.MethodPost { clientIP := getClientIP(r) if submitLimiter.checkRateLimit(clientIP, maxSubmitRequests, rateLimitWindow) { log.Printf("SECURITY: Contact form rate limit exceeded for IP %s", clientIP) http.Error(w, "Too many submissions. Please wait before trying again.", http.StatusTooManyRequests) return } } next.ServeHTTP(w, r) } } func initDB() error { host := os.Getenv("DB_HOST") user := os.Getenv("DB_USER") password := os.Getenv("DB_PASSWORD") dbname := os.Getenv("DB_NAME") // Require all database configuration - no hardcoded defaults for security if host == "" || user == "" || password == "" || dbname == "" { return fmt.Errorf("database configuration missing: DB_HOST, DB_USER, DB_PASSWORD, DB_NAME required") } sslmode := os.Getenv("DB_SSL_MODE") if sslmode == "" { sslmode = "disable" } schema := os.Getenv("DB_SCHEMA") // Validate schema value if provided validSchemas := map[string]bool{"": true, "public": true, "dev": true, "testing": true, "prod": true} if !validSchemas[schema] { return fmt.Errorf("invalid DB_SCHEMA '%s'. Must be: dev, testing, prod, or empty for public", schema) } connStr := fmt.Sprintf("host=%s user=%s password=%s dbname=%s sslmode=%s", host, user, password, dbname, sslmode) if schema != "" && schema != "public" { connStr += fmt.Sprintf(" search_path=%s,public", schema) } var err error db, err = sql.Open("postgres", connStr) if err != nil { return fmt.Errorf("failed to open database: %w", err) } // Configure connection pool limits db.SetMaxOpenConns(25) db.SetMaxIdleConns(5) db.SetConnMaxLifetime(5 * time.Minute) db.SetConnMaxIdleTime(1 * time.Minute) if err := db.Ping(); err != nil { return fmt.Errorf("failed to ping database: %w", err) } // Run migrations if err := runMigrations(); err != nil { return fmt.Errorf("failed to run migrations: %w", err) } schemaInfo := "public" if schema != "" && schema != "public" { schemaInfo = schema } log.Printf("Database connection established (schema: %s, max_conns: 25)", schemaInfo) return nil } func runMigrations() error { migration := ` CREATE TABLE IF NOT EXISTS contact_submissions ( id SERIAL PRIMARY KEY, name VARCHAR(255) NOT NULL, email VARCHAR(255) NOT NULL, phone VARCHAR(50), subject VARCHAR(255), message TEXT NOT NULL, status VARCHAR(20) DEFAULT 'NEW', created_at TIMESTAMP WITH TIME ZONE DEFAULT NOW(), updated_at TIMESTAMP WITH TIME ZONE DEFAULT NOW() ); CREATE INDEX IF NOT EXISTS idx_contact_submissions_status ON contact_submissions(status); CREATE INDEX IF NOT EXISTS idx_contact_submissions_created_at ON contact_submissions(created_at DESC); ` _, err := db.Exec(migration) if err != nil { return fmt.Errorf("failed to execute contact_submissions migration: %w", err) } return nil } func corsMiddleware(next http.HandlerFunc) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { origin := os.Getenv("CORS_ALLOW_ORIGIN") if origin == "" { origin = "http://localhost:5173,http://localhost:8090,http://localhost:8091" } requestOrigin := r.Header.Get("Origin") allowed := false for _, o := range strings.Split(origin, ",") { if strings.TrimSpace(o) == requestOrigin || o == "*" { allowed = true break } } if allowed { w.Header().Set("Access-Control-Allow-Origin", requestOrigin) } // CORS headers w.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS") w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization") w.Header().Set("Access-Control-Allow-Credentials", "true") // Security headers w.Header().Set("X-Content-Type-Options", "nosniff") w.Header().Set("X-Frame-Options", "DENY") w.Header().Set("X-XSS-Protection", "1; mode=block") w.Header().Set("Referrer-Policy", "strict-origin-when-cross-origin") w.Header().Set("Content-Security-Policy", "default-src 'self'") w.Header().Set("Permissions-Policy", "geolocation=(), microphone=(), camera=()") if r.Method == "OPTIONS" { w.WriteHeader(http.StatusOK) return } next(w, r) } } func authMiddleware(next http.HandlerFunc) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { // Check for Authorization header authHeader := r.Header.Get("Authorization") if authHeader == "" || !strings.HasPrefix(authHeader, "Bearer ") { http.Error(w, "Unauthorized", http.StatusUnauthorized) return } tokenString := strings.TrimPrefix(authHeader, "Bearer ") if tokenString == "" { http.Error(w, "Unauthorized", http.StatusUnauthorized) return } // Parse and validate JWT token token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) { // Verify signing method if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok { return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"]) } return jwtSecret, nil }) if err != nil { log.Printf("JWT validation error: %v", err) http.Error(w, "Invalid token", http.StatusUnauthorized) return } if !token.Valid { http.Error(w, "Invalid token", http.StatusUnauthorized) return } // Check for admin/staff roles (only they can view submissions) claims, ok := token.Claims.(jwt.MapClaims) if !ok { http.Error(w, "Invalid token claims", http.StatusUnauthorized) return } // Verify user has appropriate role roles, ok := claims["roles"].([]interface{}) if !ok { http.Error(w, "Forbidden", http.StatusForbidden) return } hasPermission := false for _, role := range roles { roleStr, _ := role.(string) if roleStr == "SUPERUSER" || roleStr == "ADMIN" || roleStr == "STAFF" { hasPermission = true break } } if !hasPermission { http.Error(w, "Forbidden - insufficient permissions", http.StatusForbidden) return } next(w, r) } } func healthHandler(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") json.NewEncoder(w).Encode(map[string]string{"status": "healthy"}) } func submitHandler(w http.ResponseWriter, r *http.Request) { if r.Method != "POST" { http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) return } var submission struct { Name string `json:"name"` Email string `json:"email"` Phone string `json:"phone"` Subject string `json:"subject"` Message string `json:"message"` } if err := json.NewDecoder(r.Body).Decode(&submission); err != nil { sendError(w, "Invalid request body", http.StatusBadRequest) return } // Validate required fields submission.Name = strings.TrimSpace(submission.Name) submission.Email = strings.TrimSpace(submission.Email) submission.Message = strings.TrimSpace(submission.Message) if submission.Name == "" { sendError(w, "Name is required", http.StatusBadRequest) return } if submission.Email == "" { sendError(w, "Email is required", http.StatusBadRequest) return } if !isValidEmail(submission.Email) { sendError(w, "Invalid email address", http.StatusBadRequest) return } if submission.Message == "" { sendError(w, "Message is required", http.StatusBadRequest) return } if len(submission.Message) < 10 { sendError(w, "Message must be at least 10 characters", http.StatusBadRequest) return } // Insert into database var id int err := db.QueryRow(` INSERT INTO contact_submissions (name, email, phone, subject, message, status) VALUES ($1, $2, $3, $4, $5, 'NEW') RETURNING id `, submission.Name, submission.Email, submission.Phone, submission.Subject, submission.Message).Scan(&id) if err != nil { log.Printf("Failed to insert contact submission: %v", err) sendError(w, "Failed to submit message", http.StatusInternalServerError) return } log.Printf("New contact submission from %s <%s>", submission.Name, submission.Email) w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusCreated) json.NewEncoder(w).Encode(ContactResponse{ Success: true, Message: "Thank you for your message! We will get back to you shortly.", ID: id, }) } func listSubmissionsHandler(w http.ResponseWriter, r *http.Request) { if r.Method != "GET" { http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) return } status := r.URL.Query().Get("status") var rows *sql.Rows var err error if status != "" { rows, err = db.Query(` SELECT id, name, email, phone, subject, message, status, created_at, updated_at FROM contact_submissions WHERE status = $1 ORDER BY created_at DESC `, status) } else { rows, err = db.Query(` SELECT id, name, email, phone, subject, message, status, created_at, updated_at FROM contact_submissions ORDER BY created_at DESC `) } if err != nil { log.Printf("Failed to query submissions: %v", err) sendError(w, "Failed to fetch submissions", http.StatusInternalServerError) return } defer rows.Close() submissions := []ContactSubmission{} for rows.Next() { var s ContactSubmission var phone, subject sql.NullString if err := rows.Scan(&s.ID, &s.Name, &s.Email, &phone, &subject, &s.Message, &s.Status, &s.CreatedAt, &s.UpdatedAt); err != nil { log.Printf("Failed to scan submission: %v", err) continue } if phone.Valid { s.Phone = phone.String } if subject.Valid { s.Subject = subject.String } submissions = append(submissions, s) } w.Header().Set("Content-Type", "application/json") // Wrap response in object with submissions array and total count for frontend compatibility json.NewEncoder(w).Encode(map[string]interface{}{ "submissions": submissions, "total": len(submissions), }) } func submissionHandler(w http.ResponseWriter, r *http.Request) { // Extract ID from URL path path := strings.TrimPrefix(r.URL.Path, "/submissions/") if path == "" { http.Error(w, "Submission ID required", http.StatusBadRequest) return } var id int if _, err := fmt.Sscanf(path, "%d", &id); err != nil { http.Error(w, "Invalid submission ID", http.StatusBadRequest) return } switch r.Method { case "GET": getSubmission(w, id) case "PUT": updateSubmission(w, r, id) case "DELETE": deleteSubmission(w, id) default: http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) } } func getSubmission(w http.ResponseWriter, id int) { var s ContactSubmission var phone, subject sql.NullString err := db.QueryRow(` SELECT id, name, email, phone, subject, message, status, created_at, updated_at FROM contact_submissions WHERE id = $1 `, id).Scan(&s.ID, &s.Name, &s.Email, &phone, &subject, &s.Message, &s.Status, &s.CreatedAt, &s.UpdatedAt) if err == sql.ErrNoRows { http.Error(w, "Submission not found", http.StatusNotFound) return } if err != nil { sendError(w, "Failed to fetch submission", http.StatusInternalServerError) return } if phone.Valid { s.Phone = phone.String } if subject.Valid { s.Subject = subject.String } w.Header().Set("Content-Type", "application/json") json.NewEncoder(w).Encode(s) } func updateSubmission(w http.ResponseWriter, r *http.Request, id int) { var update struct { Status string `json:"status"` } if err := json.NewDecoder(r.Body).Decode(&update); err != nil { sendError(w, "Invalid request body", http.StatusBadRequest) return } // Validate status validStatuses := map[string]bool{"NEW": true, "READ": true, "REPLIED": true, "ARCHIVED": true} if !validStatuses[update.Status] { sendError(w, "Invalid status. Must be NEW, READ, REPLIED, or ARCHIVED", http.StatusBadRequest) return } result, err := db.Exec(` UPDATE contact_submissions SET status = $1, updated_at = NOW() WHERE id = $2 `, update.Status, id) if err != nil { sendError(w, "Failed to update submission", http.StatusInternalServerError) return } rowsAffected, _ := result.RowsAffected() if rowsAffected == 0 { http.Error(w, "Submission not found", http.StatusNotFound) return } w.Header().Set("Content-Type", "application/json") json.NewEncoder(w).Encode(ContactResponse{Success: true, Message: "Submission updated"}) } func deleteSubmission(w http.ResponseWriter, id int) { result, err := db.Exec(`DELETE FROM contact_submissions WHERE id = $1`, id) if err != nil { sendError(w, "Failed to delete submission", http.StatusInternalServerError) return } rowsAffected, _ := result.RowsAffected() if rowsAffected == 0 { http.Error(w, "Submission not found", http.StatusNotFound) return } w.Header().Set("Content-Type", "application/json") json.NewEncoder(w).Encode(ContactResponse{Success: true, Message: "Submission deleted"}) } func sendError(w http.ResponseWriter, message string, status int) { w.Header().Set("Content-Type", "application/json") w.WriteHeader(status) json.NewEncoder(w).Encode(ContactResponse{Success: false, Message: message}) } func isValidEmail(email string) bool { emailRegex := regexp.MustCompile(`^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$`) return emailRegex.MatchString(email) }