Files
web-hosts/domains/coppertone.tech/backend/functions/contact-service/main.go
2025-12-26 13:38:04 +01:00

658 lines
18 KiB
Go

package main
import (
"context"
"database/sql"
"encoding/json"
"fmt"
"log"
"net"
"net/http"
"os"
"os/signal"
"regexp"
"strings"
"sync"
"syscall"
"time"
"github.com/golang-jwt/jwt/v5"
_ "github.com/lib/pq"
)
// Limits applied to incoming HTTP traffic.
const (
	// rateLimitWindow is the length of each per-IP counting window.
	rateLimitWindow = time.Minute
	// maxSubmitRequests caps contact form submissions per IP per window.
	maxSubmitRequests = 5
	// maxReadRequests caps GET requests per IP per window.
	maxReadRequests = 50
	// maxRequestBody is the largest accepted request body (1 MiB).
	maxRequestBody = 1 << 20
)
// rateLimiter counts requests per key (the client IP in this service) in
// fixed windows. All access goes through checkRateLimit, which holds mu for
// its whole duration. The zero value is ready to use.
type rateLimiter struct {
	mu        sync.Mutex
	requests  map[string]*requestInfo
	lastSweep time.Time // when expired entries were last purged from requests
}

// requestInfo is the request count for one key in its current window.
type requestInfo struct {
	count    int
	firstReq time.Time
}

var submitLimiter = &rateLimiter{requests: make(map[string]*requestInfo)}
var readLimiter = &rateLimiter{requests: make(map[string]*requestInfo)}

// checkRateLimit records one request for key and reports whether key has now
// made more than maxRequests requests in its current window. A key's window
// starts at its first request and restarts once it is older than window.
//
// Expired entries are purged at most once per window. Without this, every
// distinct key ever seen (e.g. every spoofed client IP) stayed in the map for
// the life of the process.
func (rl *rateLimiter) checkRateLimit(key string, maxRequests int, window time.Duration) bool {
	rl.mu.Lock()
	defer rl.mu.Unlock()

	if rl.requests == nil {
		rl.requests = make(map[string]*requestInfo)
	}

	now := time.Now()

	// Amortized cleanup: one O(n) sweep per window rather than per request.
	// Deleting an expired entry is equivalent to resetting it on next use.
	if now.Sub(rl.lastSweep) > window {
		for k, info := range rl.requests {
			if now.Sub(info.firstReq) > window {
				delete(rl.requests, k)
			}
		}
		rl.lastSweep = now
	}

	info, exists := rl.requests[key]
	if !exists || now.Sub(info.firstReq) > window {
		rl.requests[key] = &requestInfo{count: 1, firstReq: now}
		return false
	}
	info.count++
	return info.count > maxRequests
}
// getClientIP returns the address used to key per-client rate limits and
// security log lines.
//
// Sources, in order: the first X-Forwarded-For entry, then X-Real-IP, then
// the host part of r.RemoteAddr. A header value is used only if it is an IP
// address (optionally with a port), and it is returned in canonical form.
// Anything else falls through to the next source. Previously any header text
// (up to the full header size) became a rate-limiter map key and log field.
//
// NOTE(review): X-Forwarded-For and X-Real-IP are client-controlled unless a
// trusted reverse proxy overwrites them, so a client can still pick its own
// rate-limit key. Confirm the deployment's proxy setup before relying on
// these limits.
func getClientIP(r *http.Request) string {
	if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
		first, _, _ := strings.Cut(xff, ",")
		if ip := canonicalIP(first); ip != "" {
			return ip
		}
	}
	if ip := canonicalIP(r.Header.Get("X-Real-IP")); ip != "" {
		return ip
	}
	host, _, err := net.SplitHostPort(r.RemoteAddr)
	if err != nil {
		return r.RemoteAddr
	}
	return host
}

// canonicalIP returns the canonical text form of s if s (after trimming
// whitespace) is an IP address, optionally followed by a port
// ("203.0.113.5:443", "[2001:db8::1]:443"). Otherwise it returns "".
func canonicalIP(s string) string {
	s = strings.TrimSpace(s)
	if ip := net.ParseIP(s); ip != nil {
		return ip.String()
	}
	if host, _, err := net.SplitHostPort(s); err == nil {
		if ip := net.ParseIP(host); ip != nil {
			return ip.String()
		}
	}
	return ""
}
// Process-wide state set up in main (jwtSecret) and initDB (db) before the
// server starts accepting requests.
var (
	// jwtSecret is the HMAC key used to verify admin bearer tokens.
	jwtSecret []byte
	// db is the shared PostgreSQL connection pool.
	db *sql.DB
)

// ContactSubmission is one stored contact form message, as returned by the
// admin endpoints. Status is one of NEW, READ, REPLIED or ARCHIVED.
type ContactSubmission struct {
	ID        int       `json:"id"`
	Name      string    `json:"name"`
	Email     string    `json:"email"`
	Phone     string    `json:"phone,omitempty"`
	Subject   string    `json:"subject,omitempty"`
	Message   string    `json:"message"`
	Status    string    `json:"status"` // NEW, READ, REPLIED, ARCHIVED
	CreatedAt time.Time `json:"createdAt"`
	UpdatedAt time.Time `json:"updatedAt"`
}

// ContactResponse is the JSON envelope for success and error replies.
// ID is only present when non-zero (e.g. after a new submission).
type ContactResponse struct {
	Success bool   `json:"success"`
	Message string `json:"message"`
	ID      int    `json:"id,omitempty"`
}
// main reads configuration from the environment, connects to the database,
// registers routes and serves HTTP until SIGINT or SIGTERM starts a graceful
// shutdown (30s deadline).
func main() {
	port := os.Getenv("PORT")
	if port == "" {
		port = "8080"
	}

	// Admin tokens are HMAC-verified with this key; refuse to start without a
	// sufficiently long one.
	secret := os.Getenv("JWT_SECRET")
	if len(secret) < 32 {
		log.Fatal("JWT_SECRET must be set and at least 32 characters")
	}
	jwtSecret = []byte(secret)

	if err := initDB(); err != nil {
		log.Fatalf("Failed to initialize database: %v", err)
	}
	defer db.Close()

	mux := http.NewServeMux()

	// Public routes.
	mux.HandleFunc("/submit", corsMiddleware(rateLimitSubmit(submitHandler)))
	for _, healthPath := range []string{"/health", "/healthz"} {
		mux.HandleFunc(healthPath, corsMiddleware(healthHandler))
	}

	// Admin routes (JWT required).
	mux.HandleFunc("/submissions", corsMiddleware(authMiddleware(listSubmissionsHandler)))
	mux.HandleFunc("/submissions/", corsMiddleware(authMiddleware(submissionHandler)))

	// Outermost layer: cap every request body, and apply the per-IP read
	// limit to GET requests before they reach the router.
	handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if r.Body != nil {
			r.Body = http.MaxBytesReader(w, r.Body, maxRequestBody)
		}
		ip := getClientIP(r)
		if r.Method == http.MethodGet && readLimiter.checkRateLimit(ip, maxReadRequests, rateLimitWindow) {
			log.Printf("SECURITY: Read rate limit exceeded for IP %s on %s", ip, r.URL.Path)
			http.Error(w, "Too many requests. Please slow down.", http.StatusTooManyRequests)
			return
		}
		mux.ServeHTTP(w, r)
	})

	server := &http.Server{
		Addr:              ":" + port,
		Handler:           handler,
		ReadHeaderTimeout: 10 * time.Second,
		ReadTimeout:       15 * time.Second,
		WriteTimeout:      15 * time.Second,
		IdleTimeout:       60 * time.Second,
	}

	stopped := make(chan struct{})
	sigCh := make(chan os.Signal, 1)
	signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM)

	go func() {
		<-sigCh
		log.Println("Contact Service shutting down...")

		ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
		defer cancel()

		server.SetKeepAlivesEnabled(false)
		if err := server.Shutdown(ctx); err != nil {
			log.Printf("Could not gracefully shutdown: %v", err)
		}
		close(stopped)
	}()

	log.Printf("Contact Service starting on port %s", port)
	if err := server.ListenAndServe(); err != nil && err != http.ErrServerClosed {
		log.Fatalf("Server error: %v", err)
	}

	// Wait for Shutdown to finish draining in-flight requests.
	<-stopped
	log.Println("Contact Service stopped")
}
// rateLimitSubmit wraps next with the strict per-IP contact form limit.
// Only POST requests are counted; every other method is passed straight
// through to next.
func rateLimitSubmit(next http.HandlerFunc) http.HandlerFunc {
	return func(w http.ResponseWriter, r *http.Request) {
		if r.Method != http.MethodPost {
			next.ServeHTTP(w, r)
			return
		}

		ip := getClientIP(r)
		if submitLimiter.checkRateLimit(ip, maxSubmitRequests, rateLimitWindow) {
			log.Printf("SECURITY: Contact form rate limit exceeded for IP %s", ip)
			http.Error(w, "Too many submissions. Please wait before trying again.", http.StatusTooManyRequests)
			return
		}
		next.ServeHTTP(w, r)
	}
}
// initDB opens the PostgreSQL pool described by the DB_* environment
// variables, verifies it with a ping, runs runMigrations and leaves the pool
// in the package-level db.
//
// Required: DB_HOST, DB_USER, DB_PASSWORD, DB_NAME.
// Optional: DB_SSL_MODE (defaults to "disable") and DB_SCHEMA (empty, public,
// dev, testing or prod; a non-public schema is placed ahead of public on the
// search_path).
//
// NOTE(review): the "disable" default means unencrypted database traffic
// unless DB_SSL_MODE is set; confirm that is intended outside local setups.
func initDB() error {
	host := os.Getenv("DB_HOST")
	user := os.Getenv("DB_USER")
	password := os.Getenv("DB_PASSWORD")
	dbname := os.Getenv("DB_NAME")
	// Require all database configuration - no hardcoded defaults for security.
	if host == "" || user == "" || password == "" || dbname == "" {
		return fmt.Errorf("database configuration missing: DB_HOST, DB_USER, DB_PASSWORD, DB_NAME required")
	}

	sslmode := os.Getenv("DB_SSL_MODE")
	if sslmode == "" {
		sslmode = "disable"
	}

	schema := os.Getenv("DB_SCHEMA")
	validSchemas := map[string]bool{"": true, "public": true, "dev": true, "testing": true, "prod": true}
	if !validSchemas[schema] {
		return fmt.Errorf("invalid DB_SCHEMA '%s'. Must be: dev, testing, prod, or empty for public", schema)
	}
	customSchema := schema != "" && schema != "public"

	// Every value is quoted. Unquoted, a password containing a space, quote or
	// backslash would be split or misread by the key=value DSN parser.
	params := []string{
		"host=" + quoteConnValue(host),
		"user=" + quoteConnValue(user),
		"password=" + quoteConnValue(password),
		"dbname=" + quoteConnValue(dbname),
		"sslmode=" + quoteConnValue(sslmode),
	}
	if customSchema {
		params = append(params, "search_path="+quoteConnValue(schema+",public"))
	}

	const maxOpenConns = 25

	var err error
	db, err = sql.Open("postgres", strings.Join(params, " "))
	if err != nil {
		return fmt.Errorf("failed to open database: %w", err)
	}

	// Configure connection pool limits.
	db.SetMaxOpenConns(maxOpenConns)
	db.SetMaxIdleConns(5)
	db.SetConnMaxLifetime(5 * time.Minute)
	db.SetConnMaxIdleTime(1 * time.Minute)

	if err := db.Ping(); err != nil {
		_ = db.Close() // release the pool; the ping error is what the caller needs
		return fmt.Errorf("failed to ping database: %w", err)
	}
	if err := runMigrations(); err != nil {
		_ = db.Close()
		return fmt.Errorf("failed to run migrations: %w", err)
	}

	schemaInfo := "public"
	if customSchema {
		schemaInfo = schema
	}
	log.Printf("Database connection established (schema: %s, max_conns: %d)", schemaInfo, maxOpenConns)
	return nil
}

// quoteConnValue renders v as a single-quoted value for a key=value
// PostgreSQL connection string, backslash-escaping any backslashes and single
// quotes it contains.
func quoteConnValue(v string) string {
	return "'" + strings.NewReplacer(`\`, `\\`, `'`, `\'`).Replace(v) + "'"
}
// runMigrations idempotently creates the contact_submissions table and its
// status and created_at indexes using the package-level db.
func runMigrations() error {
	const schemaDDL = `
CREATE TABLE IF NOT EXISTS contact_submissions (
id SERIAL PRIMARY KEY,
name VARCHAR(255) NOT NULL,
email VARCHAR(255) NOT NULL,
phone VARCHAR(50),
subject VARCHAR(255),
message TEXT NOT NULL,
status VARCHAR(20) DEFAULT 'NEW',
created_at TIMESTAMP WITH TIME ZONE DEFAULT NOW(),
updated_at TIMESTAMP WITH TIME ZONE DEFAULT NOW()
);
CREATE INDEX IF NOT EXISTS idx_contact_submissions_status ON contact_submissions(status);
CREATE INDEX IF NOT EXISTS idx_contact_submissions_created_at ON contact_submissions(created_at DESC);
`
	if _, err := db.Exec(schemaDDL); err != nil {
		return fmt.Errorf("failed to execute contact_submissions migration: %w", err)
	}
	return nil
}
// corsMiddleware adds CORS and security headers to every response and
// answers OPTIONS preflight requests itself with 200, without calling next.
//
// Allowed origins come from CORS_ALLOW_ORIGIN: a comma-separated list in
// which "*" allows any origin. It defaults to the local development servers.
// The list is parsed once, when the middleware is constructed, not per request.
// A matching, non-empty Origin is echoed back in Access-Control-Allow-Origin.
// Because the response therefore depends on the Origin header, "Vary: Origin"
// is always set so shared caches do not serve one origin's CORS headers to
// another.
func corsMiddleware(next http.HandlerFunc) http.HandlerFunc {
	originList := os.Getenv("CORS_ALLOW_ORIGIN")
	if originList == "" {
		originList = "http://localhost:5173,http://localhost:8090,http://localhost:8091"
	}

	allowAny := false
	allowedOrigins := make(map[string]bool)
	for _, o := range strings.Split(originList, ",") {
		// Trim before every comparison so "a, *" honours the wildcard.
		switch o = strings.TrimSpace(o); o {
		case "":
			// Ignore empty entries from stray commas; they must not match a
			// request that has no Origin header.
		case "*":
			allowAny = true
		default:
			allowedOrigins[o] = true
		}
	}

	return func(w http.ResponseWriter, r *http.Request) {
		h := w.Header()

		h.Add("Vary", "Origin")
		if origin := r.Header.Get("Origin"); origin != "" && (allowAny || allowedOrigins[origin]) {
			h.Set("Access-Control-Allow-Origin", origin)
		}

		// CORS headers
		h.Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS")
		h.Set("Access-Control-Allow-Headers", "Content-Type, Authorization")
		h.Set("Access-Control-Allow-Credentials", "true")

		// Security headers
		h.Set("X-Content-Type-Options", "nosniff")
		h.Set("X-Frame-Options", "DENY")
		h.Set("X-XSS-Protection", "1; mode=block")
		h.Set("Referrer-Policy", "strict-origin-when-cross-origin")
		h.Set("Content-Security-Policy", "default-src 'self'")
		h.Set("Permissions-Policy", "geolocation=(), microphone=(), camera=()")

		if r.Method == http.MethodOptions {
			w.WriteHeader(http.StatusOK)
			return
		}
		next(w, r)
	}
}
// authMiddleware guards the admin endpoints. A request passes only if it
// carries "Authorization: Bearer <jwt>", the token verifies against jwtSecret
// using an HMAC signing method, and its "roles" claim lists SUPERUSER, ADMIN
// or STAFF. A missing header, an empty token or a token that fails
// verification gets 401. A roles claim that is missing, not a list, or has
// none of those roles gets 403.
func authMiddleware(next http.HandlerFunc) http.HandlerFunc {
	const bearer = "Bearer "

	// keyFunc accepts only HMAC-family algorithms; anything else is rejected
	// before the signature is checked.
	keyFunc := func(tok *jwt.Token) (interface{}, error) {
		if _, isHMAC := tok.Method.(*jwt.SigningMethodHMAC); !isHMAC {
			return nil, fmt.Errorf("unexpected signing method: %v", tok.Header["alg"])
		}
		return jwtSecret, nil
	}

	return func(w http.ResponseWriter, r *http.Request) {
		authHeader := r.Header.Get("Authorization")
		if !strings.HasPrefix(authHeader, bearer) || len(authHeader) == len(bearer) {
			http.Error(w, "Unauthorized", http.StatusUnauthorized)
			return
		}

		token, err := jwt.Parse(authHeader[len(bearer):], keyFunc)
		if err != nil {
			log.Printf("JWT validation error: %v", err)
			http.Error(w, "Invalid token", http.StatusUnauthorized)
			return
		}
		if !token.Valid {
			http.Error(w, "Invalid token", http.StatusUnauthorized)
			return
		}

		claims, isMap := token.Claims.(jwt.MapClaims)
		if !isMap {
			http.Error(w, "Invalid token claims", http.StatusUnauthorized)
			return
		}

		roles, isList := claims["roles"].([]interface{})
		if !isList {
			http.Error(w, "Forbidden", http.StatusForbidden)
			return
		}

		// Hand off to next as soon as one privileged role is found.
		for _, role := range roles {
			name, isString := role.(string)
			if !isString {
				continue
			}
			switch name {
			case "SUPERUSER", "ADMIN", "STAFF":
				next(w, r)
				return
			}
		}
		http.Error(w, "Forbidden - insufficient permissions", http.StatusForbidden)
	}
}
// healthHandler is the liveness probe. It always replies 200 with the JSON
// body {"status":"healthy"}, whatever the request.
func healthHandler(w http.ResponseWriter, r *http.Request) {
	payload := map[string]string{"status": "healthy"}
	w.Header().Set("Content-Type", "application/json")
	_ = json.NewEncoder(w).Encode(payload)
}
// submitHandler handles public contact form submissions (POST only). The JSON
// body carries name, email, phone, subject and message. Every field is
// trimmed and validated, then the row is stored with status NEW.
// Responses: 201 with the new id, 400 for invalid input, 405 for other
// methods, 500 if the insert fails.
//
// The length limits mirror the VARCHAR sizes created in runMigrations, so an
// oversized field is reported as a 400 instead of failing the insert with a 500.
func submitHandler(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodPost {
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}

	var submission struct {
		Name    string `json:"name"`
		Email   string `json:"email"`
		Phone   string `json:"phone"`
		Subject string `json:"subject"`
		Message string `json:"message"`
	}
	if err := json.NewDecoder(r.Body).Decode(&submission); err != nil {
		sendError(w, "Invalid request body", http.StatusBadRequest)
		return
	}

	// Trim every field so whitespace-only input counts as empty and padding
	// is not stored.
	submission.Name = strings.TrimSpace(submission.Name)
	submission.Email = strings.TrimSpace(submission.Email)
	submission.Phone = strings.TrimSpace(submission.Phone)
	submission.Subject = strings.TrimSpace(submission.Subject)
	submission.Message = strings.TrimSpace(submission.Message)

	// Lengths are counted in characters, not bytes, which is what the user
	// sees and how VARCHAR(n) limits are expressed.
	var problem string
	switch {
	case submission.Name == "":
		problem = "Name is required"
	case charCount(submission.Name) > 255:
		problem = "Name must be at most 255 characters"
	case submission.Email == "":
		problem = "Email is required"
	case !isValidEmail(submission.Email):
		problem = "Invalid email address"
	case charCount(submission.Email) > 255:
		problem = "Email must be at most 255 characters"
	case charCount(submission.Phone) > 50:
		problem = "Phone must be at most 50 characters"
	case charCount(submission.Subject) > 255:
		problem = "Subject must be at most 255 characters"
	case submission.Message == "":
		problem = "Message is required"
	case charCount(submission.Message) < 10:
		problem = "Message must be at least 10 characters"
	}
	if problem != "" {
		sendError(w, problem, http.StatusBadRequest)
		return
	}

	var id int
	err := db.QueryRow(`
INSERT INTO contact_submissions (name, email, phone, subject, message, status)
VALUES ($1, $2, $3, $4, $5, 'NEW')
RETURNING id
`, submission.Name, submission.Email, submission.Phone, submission.Subject, submission.Message).Scan(&id)
	if err != nil {
		log.Printf("Failed to insert contact submission: %v", err)
		sendError(w, "Failed to submit message", http.StatusInternalServerError)
		return
	}

	// %q escapes user-supplied newlines and control characters so the name
	// cannot forge extra log lines. The email already passed isValidEmail.
	log.Printf("New contact submission %d from %q <%s>", id, submission.Name, submission.Email)

	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(http.StatusCreated)
	if err := json.NewEncoder(w).Encode(ContactResponse{
		Success: true,
		Message: "Thank you for your message! We will get back to you shortly.",
		ID:      id,
	}); err != nil {
		log.Printf("Failed to encode submit response: %v", err)
	}
}

// charCount returns the number of characters (Unicode code points) in s.
// Each invalid UTF-8 byte counts as one.
func charCount(s string) int {
	n := 0
	for range s {
		n++
	}
	return n
}
// listSubmissionsHandler handles GET /submissions. It returns every
// submission, newest first, optionally filtered by the "status" query
// parameter, as {"submissions": [...], "total": n}. Responds 405 for other
// methods and 500 if the query or row iteration fails.
func listSubmissionsHandler(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodGet {
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}

	var (
		rows *sql.Rows
		err  error
	)
	if status := r.URL.Query().Get("status"); status != "" {
		rows, err = db.Query(`
SELECT id, name, email, phone, subject, message, status, created_at, updated_at
FROM contact_submissions
WHERE status = $1
ORDER BY created_at DESC
`, status)
	} else {
		rows, err = db.Query(`
SELECT id, name, email, phone, subject, message, status, created_at, updated_at
FROM contact_submissions
ORDER BY created_at DESC
`)
	}
	if err != nil {
		log.Printf("Failed to query submissions: %v", err)
		sendError(w, "Failed to fetch submissions", http.StatusInternalServerError)
		return
	}
	defer rows.Close()

	// Non-nil so an empty result encodes as [] rather than null.
	submissions := []ContactSubmission{}
	for rows.Next() {
		var s ContactSubmission
		var phone, subject sql.NullString
		if err := rows.Scan(&s.ID, &s.Name, &s.Email, &phone, &subject, &s.Message, &s.Status, &s.CreatedAt, &s.UpdatedAt); err != nil {
			// Best effort: skip an unreadable row rather than fail the whole list.
			log.Printf("Failed to scan submission: %v", err)
			continue
		}
		// NULL columns scan as "" and are then omitted via omitempty.
		s.Phone = phone.String
		s.Subject = subject.String
		submissions = append(submissions, s)
	}
	// rows.Next returns false both at the end and on failure. Without this
	// check a mid-iteration error would return a silently truncated 200.
	if err := rows.Err(); err != nil {
		log.Printf("Failed to iterate submissions: %v", err)
		sendError(w, "Failed to fetch submissions", http.StatusInternalServerError)
		return
	}

	w.Header().Set("Content-Type", "application/json")
	// Wrap response in object with submissions array and total count for frontend compatibility.
	if err := json.NewEncoder(w).Encode(map[string]interface{}{
		"submissions": submissions,
		"total":       len(submissions),
	}); err != nil {
		log.Printf("Failed to encode submissions: %v", err)
	}
}
// submissionHandler routes /submissions/{id}. GET returns the submission,
// PUT updates its status and DELETE removes it; other methods get 405.
// The id segment must be entirely decimal digits. Previously fmt.Sscanf
// accepted any numeric prefix, so "/submissions/12abc" or "/submissions/12/x"
// silently acted on id 12, and signed values such as "-5" or "+5" were
// accepted too.
func submissionHandler(w http.ResponseWriter, r *http.Request) {
	idText := strings.TrimPrefix(r.URL.Path, "/submissions/")
	if idText == "" {
		http.Error(w, "Submission ID required", http.StatusBadRequest)
		return
	}
	if strings.TrimLeft(idText, "0123456789") != "" {
		http.Error(w, "Invalid submission ID", http.StatusBadRequest)
		return
	}
	var id int
	// Digits only at this point; Sscanf still rejects values that overflow int.
	if _, err := fmt.Sscanf(idText, "%d", &id); err != nil {
		http.Error(w, "Invalid submission ID", http.StatusBadRequest)
		return
	}

	switch r.Method {
	case http.MethodGet:
		getSubmission(w, id)
	case http.MethodPut:
		updateSubmission(w, r, id)
	case http.MethodDelete:
		deleteSubmission(w, id)
	default:
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
	}
}
// getSubmission writes the submission with the given id as JSON. It replies
// 404 if no row has that id and 500 on any other database error.
func getSubmission(w http.ResponseWriter, id int) {
	var s ContactSubmission
	var phone, subject sql.NullString
	err := db.QueryRow(`
SELECT id, name, email, phone, subject, message, status, created_at, updated_at
FROM contact_submissions WHERE id = $1
`, id).Scan(&s.ID, &s.Name, &s.Email, &phone, &subject, &s.Message, &s.Status, &s.CreatedAt, &s.UpdatedAt)
	if err == sql.ErrNoRows {
		http.Error(w, "Submission not found", http.StatusNotFound)
		return
	}
	if err != nil {
		// Log the cause, as the other handlers do; the client only sees a
		// generic message.
		log.Printf("Failed to fetch submission %d: %v", id, err)
		sendError(w, "Failed to fetch submission", http.StatusInternalServerError)
		return
	}

	// NULL columns scan as "" and are then omitted via omitempty.
	s.Phone = phone.String
	s.Subject = subject.String

	w.Header().Set("Content-Type", "application/json")
	if err := json.NewEncoder(w).Encode(s); err != nil {
		log.Printf("Failed to encode submission %d: %v", id, err)
	}
}
// updateSubmission sets the status of submission id from a JSON body
// {"status": "..."} and refreshes updated_at. Status must be NEW, READ,
// REPLIED or ARCHIVED. Replies 400 for a bad body or status, 404 if no row
// has the id, and 500 on database errors.
func updateSubmission(w http.ResponseWriter, r *http.Request, id int) {
	var update struct {
		Status string `json:"status"`
	}
	if err := json.NewDecoder(r.Body).Decode(&update); err != nil {
		sendError(w, "Invalid request body", http.StatusBadRequest)
		return
	}

	switch update.Status {
	case "NEW", "READ", "REPLIED", "ARCHIVED":
	default:
		sendError(w, "Invalid status. Must be NEW, READ, REPLIED, or ARCHIVED", http.StatusBadRequest)
		return
	}

	result, err := db.Exec(`
UPDATE contact_submissions SET status = $1, updated_at = NOW() WHERE id = $2
`, update.Status, id)
	if err != nil {
		log.Printf("Failed to update submission %d: %v", id, err)
		sendError(w, "Failed to update submission", http.StatusInternalServerError)
		return
	}

	// If this error were ignored, rowsAffected would be 0 and the client
	// would get a false 404.
	rowsAffected, err := result.RowsAffected()
	if err != nil {
		log.Printf("Failed to read rows affected for submission %d: %v", id, err)
		sendError(w, "Failed to update submission", http.StatusInternalServerError)
		return
	}
	if rowsAffected == 0 {
		http.Error(w, "Submission not found", http.StatusNotFound)
		return
	}

	w.Header().Set("Content-Type", "application/json")
	if err := json.NewEncoder(w).Encode(ContactResponse{Success: true, Message: "Submission updated"}); err != nil {
		log.Printf("Failed to encode update response: %v", err)
	}
}
// deleteSubmission permanently removes submission id. It replies 404 if no
// row has that id and 500 on database errors.
func deleteSubmission(w http.ResponseWriter, id int) {
	result, err := db.Exec(`DELETE FROM contact_submissions WHERE id = $1`, id)
	if err != nil {
		log.Printf("Failed to delete submission %d: %v", id, err)
		sendError(w, "Failed to delete submission", http.StatusInternalServerError)
		return
	}

	// If this error were ignored, rowsAffected would be 0 and the client
	// would get a false 404.
	rowsAffected, err := result.RowsAffected()
	if err != nil {
		log.Printf("Failed to read rows affected for submission %d: %v", id, err)
		sendError(w, "Failed to delete submission", http.StatusInternalServerError)
		return
	}
	if rowsAffected == 0 {
		http.Error(w, "Submission not found", http.StatusNotFound)
		return
	}

	w.Header().Set("Content-Type", "application/json")
	if err := json.NewEncoder(w).Encode(ContactResponse{Success: true, Message: "Submission deleted"}); err != nil {
		log.Printf("Failed to encode delete response: %v", err)
	}
}
// sendError replies with the given HTTP status and a JSON ContactResponse
// whose Success is false and whose Message is message.
func sendError(w http.ResponseWriter, message string, status int) {
	body := ContactResponse{Message: message} // Success stays false, ID is omitted
	headers := w.Header()
	headers.Set("Content-Type", "application/json")
	w.WriteHeader(status)
	_ = json.NewEncoder(w).Encode(body)
}
// emailPattern is a pragmatic address check: an ASCII local part, an "@", a
// domain with at least one dot, and a TLD of two or more letters. It is
// compiled once at package initialisation rather than on every call.
var emailPattern = regexp.MustCompile(`^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$`)

// isValidEmail reports whether the whole of email matches emailPattern.
func isValidEmail(email string) bool {
	return emailPattern.MatchString(email)
}