package main

import (
	"context"
	"crypto/aes"
	"crypto/cipher"
	"crypto/rand"
	"database/sql"
	"encoding/hex"
	"encoding/json"
	"fmt"
	"io"
	"log"
	"net/http"
	"os"
	"strconv"
	"strings"
	"time"

	"github.com/golang-jwt/jwt/v5"
	_ "github.com/lib/pq"
)

// db is the shared database handle; it is opened and pinged in main.
var db *sql.DB

// encryptionKey is the 32-byte key decoded from ENCRYPTION_KEY in main.
var encryptionKey []byte

// LLMConfig represents the configuration for an LLM provider
type LLMConfig struct {
	APIKey      string  `json:"apiKey"`
	Model       string  `json:"model,omitempty"`
	Temperature float64 `json:"temperature,omitempty"`
	MaxTokens   int     `json:"maxTokens,omitempty"`
}

// ChatMessage represents a message in the chat history
type ChatMessage struct {
	ID         int       `json:"id,omitempty"`
	Role       string    `json:"role"`
	Content    string    `json:"content"`
	TokensUsed int       `json:"tokensUsed,omitempty"`
	Timestamp  time.Time `json:"timestamp"`
}

// ChatRequest represents a chat request
type ChatRequest struct {
	Message string        `json:"message"`
	History []ChatMessage `json:"history"`
}

// ChatResponse represents a chat response
type ChatResponse struct {
	Response   string `json:"response"`
	TokensUsed int    `json:"tokensUsed"`
	Model      string `json:"model"`
}

// ErrorResponse represents an error response
type ErrorResponse struct {
	Error   string `json:"error"`
	Message string `json:"message"`
}

// main loads the encryption key, connects to PostgreSQL, registers the HTTP
// routes on the default mux and serves until the listener fails.
//
// Environment: ENCRYPTION_KEY (required, 64 hex chars), DB_HOST, DB_USER,
// DB_PASSWORD, DB_NAME, DB_SCHEMA (default "public"), DB_SSLMODE (default
// "disable") and PORT (default "8080").
func main() {
	// Load encryption key
	keyHex := os.Getenv("ENCRYPTION_KEY")
	if keyHex == "" {
		log.Fatal("ENCRYPTION_KEY environment variable not set")
	}

	var err error
	encryptionKey, err = hex.DecodeString(keyHex)
	if err != nil || len(encryptionKey) != 32 {
		log.Fatal("ENCRYPTION_KEY must be a 64-character hex string (32 bytes)")
	}

	// Database connection
	dbHost := os.Getenv("DB_HOST")
	dbUser := os.Getenv("DB_USER")
	dbPassword := os.Getenv("DB_PASSWORD")
	dbName := os.Getenv("DB_NAME")
	dbSchema := os.Getenv("DB_SCHEMA")
	if dbSchema == "" {
		dbSchema = "public"
	}
	// sslmode used to be hard-coded; "disable" stays the default so existing
	// deployments keep working, but TLS can now be enabled via DB_SSLMODE.
	dbSSLMode := os.Getenv("DB_SSLMODE")
	if dbSSLMode == "" {
		dbSSLMode = "disable"
	}

	// Every value is single-quoted with backslash escapes. Unquoted, an empty
	// value or one containing whitespace makes the key/value parser swallow
	// the following "key=value" pair or split the value in two.
	connEscaper := strings.NewReplacer(`\`, `\\`, `'`, `\'`)
	quote := func(v string) string {
		return "'" + connEscaper.Replace(v) + "'"
	}
	connStr := fmt.Sprintf("host=%s user=%s password=%s dbname=%s sslmode=%s search_path=%s",
		quote(dbHost), quote(dbUser), quote(dbPassword), quote(dbName), quote(dbSSLMode), quote(dbSchema))

	db, err = sql.Open("postgres", connStr)
	if err != nil {
		log.Fatalf("Failed to connect to database: %v", err)
	}
	// Note: the log.Fatal calls below exit the process without running this
	// deferred Close.
	defer db.Close()

	// Test database connection
	if err = db.Ping(); err != nil {
		log.Fatalf("Failed to ping database: %v", err)
	}

	log.Println("LLM Service started successfully")

	// Routes
	http.HandleFunc("/health", healthHandler)
	http.HandleFunc("/llm/models", corsMiddleware(getModelsHandler))
	http.HandleFunc("/llm/configs", corsMiddleware(authMiddleware(getConfigsHandler)))
	http.HandleFunc("/llm/config/", corsMiddleware(authMiddleware(configHandler)))
	http.HandleFunc("/llm/chat/", corsMiddleware(authMiddleware(chatHandler)))
	http.HandleFunc("/llm/history/", corsMiddleware(authMiddleware(historyHandler)))

	port := os.Getenv("PORT")
	if port == "" {
		port = "8080"
	}

	// Read-side timeouts stop slow or idle clients from pinning connections
	// forever. No WriteTimeout is set on purpose: chat requests wait on an
	// upstream LLM call whose duration is not known here.
	srv := &http.Server{
		Addr:              ":" + port,
		ReadHeaderTimeout: 10 * time.Second,
		ReadTimeout:       30 * time.Second,
		IdleTimeout:       120 * time.Second,
	}

	log.Printf("LLM Service listening on port %s", port)
	log.Fatal(srv.ListenAndServe())
}

// healthHandler returns service health status
func healthHandler(w http.ResponseWriter, r *http.Request) {
	w.Header().Set("Content-Type", "application/json")
	if err := json.NewEncoder(w).Encode(map[string]string{"status": "healthy"}); err != nil {
		log.Printf("Failed to write health response: %v", err)
	}
}

// ModelInfo represents information about an available model
type ModelInfo struct {
	ID          string `json:"id"`
	Name        string `json:"name"`
	Description string `json:"description,omitempty"`
	IsDefault   bool   `json:"isDefault,omitempty"`
}

// getModelsHandler returns available models for each provider
// It fetches models dynamically from provider APIs where possible
// and only accepts GET; the response is {"models": {provider: [ModelInfo]}}.
func getModelsHandler(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodGet {
		respondWithError(w, http.StatusMethodNotAllowed, "Method not allowed")
		return
	}

	models := make(map[string][]ModelInfo, 5)

	// Fetch OpenAI models dynamically
	models["openai"] = fetchOpenAIModels()

	// Fetch Anthropic/Claude models dynamically
	models["claude"] = fetchClaudeModels()

	// Gemini, Qwen, HuggingFace - use curated lists (APIs require auth or don't have model listing)
	models["gemini"] = []ModelInfo{
		{ID: "gemini-2.0-flash-exp", Name: "Gemini 2.0 Flash", Description: "Latest fast model", IsDefault: true},
		{ID: "gemini-1.5-pro", Name: "Gemini 1.5 Pro", Description: "High capability model"},
		{ID: "gemini-1.5-flash", Name: "Gemini 1.5 Flash", Description: "Fast and efficient"},
		{ID: "gemini-1.0-pro", Name: "Gemini 1.0 Pro", Description: "Stable production model"},
	}

	models["qwen"] = []ModelInfo{
		{ID: "qwen-max", Name: "Qwen Max", Description: "Most capable Qwen model", IsDefault: true},
		{ID: "qwen-plus", Name: "Qwen Plus", Description: "Balanced performance"},
		{ID: "qwen-turbo", Name: "Qwen Turbo", Description: "Fast responses"},
		{ID: "qwen-long", Name: "Qwen Long", Description: "Extended context length"},
	}

	models["huggingface"] = []ModelInfo{
		{ID: "meta-llama/Llama-3.3-70B-Instruct", Name: "Llama 3.3 70B", Description: "Latest Llama model", IsDefault: true},
		{ID: "meta-llama/Llama-3.2-90B-Vision-Instruct", Name: "Llama 3.2 90B Vision", Description: "Multimodal model"},
		{ID: "meta-llama/Llama-3.1-70B-Instruct", Name: "Llama 3.1 70B", Description: "Previous generation"},
		{ID: "mistralai/Mixtral-8x7B-Instruct-v0.1", Name: "Mixtral 8x7B", Description: "Mixture of experts"},
		{ID: "microsoft/Phi-3-medium-128k-instruct", Name: "Phi-3 Medium", Description: "Compact but capable"},
	}

	respondWithJSON(w, http.StatusOK, map[string]interface{}{"models": models})
}

// fetchOpenAIModels fetches the list of available models from OpenAI API
//
// The API key comes from OPENAI_API_KEY or, failing that, from a stored
// 'openai' row in llm_configs. A curated fallback list is returned when no
// key is available, the request fails, or no chat model survives filtering.
// Exactly one returned model is flagged IsDefault: "gpt-4o" when present,
// otherwise the first model after sorting.
func fetchOpenAIModels() []ModelInfo {
	// Fallback models if API call fails
	fallback := []ModelInfo{
		{ID: "gpt-4o", Name: "GPT-4o", Description: "Most capable multimodal model", IsDefault: true},
		{ID: "gpt-4o-mini", Name: "GPT-4o Mini", Description: "Fast and affordable"},
		{ID: "gpt-4-turbo", Name: "GPT-4 Turbo", Description: "High capability with vision"},
		{ID: "gpt-3.5-turbo", Name: "GPT-3.5 Turbo", Description: "Fast and cost-effective"},
	}

	// Try to get API key from environment or a configured user
	apiKey := os.Getenv("OPENAI_API_KEY")
	if apiKey == "" {
		// NOTE(review): this borrows whichever user's key the database returns
		// first, and /llm/models is registered without authMiddleware, so an
		// anonymous caller can trigger requests billed to that user. Confirm
		// this is intended.
		var encryptedKey string
		if err := db.QueryRow(`SELECT api_key_encrypted FROM llm_configs WHERE provider = 'openai' LIMIT 1`).Scan(&encryptedKey); err == nil {
			if decrypted, decErr := decrypt(encryptedKey); decErr == nil {
				apiKey = decrypted
			}
		}
	}

	if apiKey == "" {
		return fallback
	}

	client := &http.Client{Timeout: 10 * time.Second}
	req, err := http.NewRequest(http.MethodGet, "https://api.openai.com/v1/models", nil)
	if err != nil {
		log.Printf("Error creating OpenAI request: %v", err)
		return fallback
	}
	req.Header.Set("Authorization", "Bearer "+apiKey)

	resp, err := client.Do(req)
	if err != nil {
		log.Printf("Error fetching OpenAI models: %v", err)
		return fallback
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		log.Printf("OpenAI API returned status %d", resp.StatusCode)
		return fallback
	}

	var result struct {
		Data []struct {
			ID      string `json:"id"`
			Created int64  `json:"created"`
			OwnedBy string `json:"owned_by"`
		} `json:"data"`
	}
	// Cap how much of the upstream body is read so an unexpected response
	// cannot exhaust memory.
	if err := json.NewDecoder(io.LimitReader(resp.Body, 4<<20)).Decode(&result); err != nil {
		log.Printf("Error decoding OpenAI response: %v", err)
		return fallback
	}

	// Filter to chat models only and sort by preference
	chatModelPrefixes := []string{"gpt-4", "gpt-3.5", "o1", "o3", "chatgpt"}
	// Substrings that mark specific non-chat variants.
	excludedMarkers := []string{"instruct", "vision", "audio", "realtime", "transcribe", "tts"}

	models := make([]ModelInfo, 0, len(result.Data))
	hasDefault := false
	for _, m := range result.Data {
		isChatModel := false
		for _, prefix := range chatModelPrefixes {
			if strings.HasPrefix(m.ID, prefix) {
				isChatModel = true
				break
			}
		}
		if !isChatModel {
			continue
		}

		excluded := false
		for _, marker := range excludedMarkers {
			if strings.Contains(m.ID, marker) {
				excluded = true
				break
			}
		}
		if excluded {
			continue
		}

		isDefault := !hasDefault && m.ID == "gpt-4o"
		if isDefault {
			hasDefault = true
		}
		models = append(models, ModelInfo{
			ID:        m.ID,
			Name:      formatModelName(m.ID),
			IsDefault: isDefault,
		})
	}

	if len(models) == 0 {
		return fallback
	}

	// Sort models: gpt-4o first, then gpt-4, then o1/o3, then gpt-3.5
	sortModels(models)

	// The fallback list always carries a default; keep that guarantee when
	// the account does not expose "gpt-4o".
	if !hasDefault {
		models[0].IsDefault = true
	}

	return models
}

// fetchClaudeModels fetches the list of available models from Anthropic API
func fetchClaudeModels() []ModelInfo { fallback := []ModelInfo{ {ID: "claude-sonnet-4-5-20250929", Name: "Claude Sonnet 4.5", Description: "Latest balanced model", IsDefault: true}, {ID: "claude-opus-4-5-20251101", Name: "Claude Opus 4.5", Description: "Most capable model"}, {ID: "claude-3-5-sonnet-20241022", Name: "Claude 3.5 Sonnet", Description: "Previous generation balanced"}, {ID: "claude-3-5-haiku-20241022", Name: "Claude 3.5 Haiku", Description: "Fast and affordable"}, } apiKey := os.Getenv("ANTHROPIC_API_KEY") if apiKey == "" { var encryptedKey string err := db.QueryRow(`SELECT api_key_encrypted FROM llm_configs WHERE provider = 'claude' LIMIT 1`).Scan(&encryptedKey) if err == nil { decrypted, err := decrypt(encryptedKey) if err == nil { apiKey = decrypted } } } if apiKey == "" { return fallback } client := &http.Client{Timeout: 10 * time.Second} req, err := http.NewRequest("GET", "https://api.anthropic.com/v1/models", nil) if err != nil { log.Printf("Error creating Anthropic request: %v", err) return fallback } req.Header.Set("x-api-key", apiKey) req.Header.Set("anthropic-version", "2023-06-01") resp, err := client.Do(req) if err != nil { log.Printf("Error fetching Claude models: %v", err) return fallback } defer resp.Body.Close() if resp.StatusCode != http.StatusOK { log.Printf("Anthropic API returned status %d", resp.StatusCode) return fallback } var result struct { Data []struct { ID string `json:"id"` DisplayName string `json:"display_name"` Type string `json:"type"` } `json:"data"` } if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { log.Printf("Error decoding Anthropic response: %v", err) return fallback } var models []ModelInfo for _, m := range result.Data { if m.Type != "model" { continue } name := m.DisplayName if name == "" { name = formatModelName(m.ID) } models = append(models, ModelInfo{ ID: m.ID, Name: name, IsDefault: strings.Contains(m.ID, "sonnet") && strings.Contains(m.ID, "4-5"), }) } if len(models) == 0 { return 
fallback } return models } // formatModelName converts model ID to display name func formatModelName(id string) string { name := strings.ReplaceAll(id, "-", " ") name = strings.ReplaceAll(name, "_", " ") // Capitalize first letter of each word words := strings.Fields(name) for i, word := range words { if len(word) > 0 { words[i] = strings.ToUpper(string(word[0])) + word[1:] } } return strings.Join(words, " ") } // sortModels sorts models by preference func sortModels(models []ModelInfo) { priority := func(id string) int { switch { case strings.HasPrefix(id, "gpt-4o") && !strings.Contains(id, "mini"): return 0 case strings.HasPrefix(id, "gpt-4o-mini"): return 1 case strings.HasPrefix(id, "gpt-4"): return 2 case strings.HasPrefix(id, "o1") || strings.HasPrefix(id, "o3"): return 3 case strings.HasPrefix(id, "gpt-3.5"): return 4 case strings.HasPrefix(id, "chatgpt"): return 5 default: return 6 } } for i := 0; i < len(models)-1; i++ { for j := i + 1; j < len(models); j++ { if priority(models[i].ID) > priority(models[j].ID) { models[i], models[j] = models[j], models[i] } } } } // CORS middleware func corsMiddleware(next http.HandlerFunc) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Access-Control-Allow-Origin", "*") w.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS") w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization") if r.Method == "OPTIONS" { w.WriteHeader(http.StatusOK) return } next(w, r) } } // authMiddleware validates JWT tokens and extracts user ID func authMiddleware(next http.HandlerFunc) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { authHeader := r.Header.Get("Authorization") if authHeader == "" { respondWithError(w, http.StatusUnauthorized, "Authorization header required") return } tokenString := strings.TrimPrefix(authHeader, "Bearer ") jwtSecret := os.Getenv("JWT_SECRET") token, err := jwt.Parse(tokenString, func(token *jwt.Token) 
(interface{}, error) { if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok { return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"]) } return []byte(jwtSecret), nil }) if err != nil { log.Printf("JWT parse error: %v", err) respondWithError(w, http.StatusUnauthorized, "Invalid token") return } if !token.Valid { log.Printf("JWT token is not valid") respondWithError(w, http.StatusUnauthorized, "Invalid token") return } claims, ok := token.Claims.(jwt.MapClaims) if !ok { respondWithError(w, http.StatusUnauthorized, "Invalid token claims") return } userID, ok := claims["user_id"].(float64) if !ok { respondWithError(w, http.StatusUnauthorized, "Invalid user_id in token") return } // Add user ID to context ctx := context.WithValue(r.Context(), "userID", int(userID)) next(w, r.WithContext(ctx)) } } // getConfigsHandler returns all LLM configurations for the authenticated user func getConfigsHandler(w http.ResponseWriter, r *http.Request) { if r.Method != http.MethodGet { respondWithError(w, http.StatusMethodNotAllowed, "Method not allowed") return } userID := r.Context().Value("userID").(int) rows, err := db.Query(` SELECT provider, api_key_encrypted, model, temperature, max_tokens FROM llm_configs WHERE user_id = $1 `, userID) if err != nil { respondWithError(w, http.StatusInternalServerError, "Failed to query configs") return } defer rows.Close() configs := make(map[string]LLMConfig) for rows.Next() { var provider, apiKeyEncrypted, model string var temperature float64 var maxTokens int if err := rows.Scan(&provider, &apiKeyEncrypted, &model, &temperature, &maxTokens); err != nil { respondWithError(w, http.StatusInternalServerError, "Failed to scan config") return } // Mask API key for security (don't return actual key) configs[provider] = LLMConfig{ APIKey: "sk-***", Model: model, Temperature: temperature, MaxTokens: maxTokens, } } respondWithJSON(w, http.StatusOK, map[string]interface{}{"configs": configs}) } // configHandler handles CRUD 
operations for individual provider configs func configHandler(w http.ResponseWriter, r *http.Request) { userID := r.Context().Value("userID").(int) provider := strings.TrimPrefix(r.URL.Path, "/llm/config/") if provider == "" { respondWithError(w, http.StatusBadRequest, "Provider required") return } switch r.Method { case http.MethodPost: saveConfig(w, r, userID, provider) case http.MethodDelete: deleteConfig(w, r, userID, provider) default: respondWithError(w, http.StatusMethodNotAllowed, "Method not allowed") } } // saveConfig saves or updates an LLM configuration func saveConfig(w http.ResponseWriter, r *http.Request, userID int, provider string) { var config LLMConfig if err := json.NewDecoder(r.Body).Decode(&config); err != nil { respondWithError(w, http.StatusBadRequest, "Invalid request body") return } if config.APIKey == "" { respondWithError(w, http.StatusBadRequest, "API key required") return } // Set defaults if config.Temperature == 0 { config.Temperature = 0.7 } if config.MaxTokens == 0 { config.MaxTokens = 2048 } // Encrypt API key encryptedKey, err := encrypt(config.APIKey) if err != nil { respondWithError(w, http.StatusInternalServerError, "Failed to encrypt API key") return } // Upsert configuration _, err = db.Exec(` INSERT INTO llm_configs (user_id, provider, api_key_encrypted, model, temperature, max_tokens, updated_at) VALUES ($1, $2, $3, $4, $5, $6, NOW()) ON CONFLICT (user_id, provider) DO UPDATE SET api_key_encrypted = EXCLUDED.api_key_encrypted, model = EXCLUDED.model, temperature = EXCLUDED.temperature, max_tokens = EXCLUDED.max_tokens, updated_at = NOW() `, userID, provider, encryptedKey, config.Model, config.Temperature, config.MaxTokens) if err != nil { log.Printf("Failed to save config: %v", err) respondWithError(w, http.StatusInternalServerError, "Failed to save configuration") return } respondWithJSON(w, http.StatusOK, map[string]interface{}{ "success": true, "message": "Configuration saved", }) } // deleteConfig deletes an LLM 
configuration func deleteConfig(w http.ResponseWriter, r *http.Request, userID int, provider string) { result, err := db.Exec(`DELETE FROM llm_configs WHERE user_id = $1 AND provider = $2`, userID, provider) if err != nil { respondWithError(w, http.StatusInternalServerError, "Failed to delete configuration") return } rowsAffected, _ := result.RowsAffected() if rowsAffected == 0 { respondWithError(w, http.StatusNotFound, "Configuration not found") return } // Also delete chat history for this provider db.Exec(`DELETE FROM llm_chat_history WHERE user_id = $1 AND provider = $2`, userID, provider) respondWithJSON(w, http.StatusOK, map[string]interface{}{ "success": true, "message": "Configuration deleted", }) } // chatHandler handles chat requests to LLM providers func chatHandler(w http.ResponseWriter, r *http.Request) { if r.Method != http.MethodPost { respondWithError(w, http.StatusMethodNotAllowed, "Method not allowed") return } userID := r.Context().Value("userID").(int) provider := strings.TrimPrefix(r.URL.Path, "/llm/chat/") if provider == "" { respondWithError(w, http.StatusBadRequest, "Provider required") return } var chatReq ChatRequest if err := json.NewDecoder(r.Body).Decode(&chatReq); err != nil { respondWithError(w, http.StatusBadRequest, "Invalid request body") return } if chatReq.Message == "" { respondWithError(w, http.StatusBadRequest, "Message required") return } // Get provider configuration var apiKeyEncrypted, model string var temperature float64 var maxTokens int err := db.QueryRow(` SELECT api_key_encrypted, model, temperature, max_tokens FROM llm_configs WHERE user_id = $1 AND provider = $2 `, userID, provider).Scan(&apiKeyEncrypted, &model, &temperature, &maxTokens) if err == sql.ErrNoRows { respondWithError(w, http.StatusNotFound, "Provider not configured") return } if err != nil { respondWithError(w, http.StatusInternalServerError, "Failed to get configuration") return } // Decrypt API key apiKey, err := decrypt(apiKeyEncrypted) if err != 
nil { respondWithError(w, http.StatusInternalServerError, "Failed to decrypt API key") return } // Save user message to history db.Exec(` INSERT INTO llm_chat_history (user_id, provider, role, content, created_at) VALUES ($1, $2, 'user', $3, NOW()) `, userID, provider, chatReq.Message) // Call appropriate LLM provider var response string var tokensUsed int switch provider { case "openai": response, tokensUsed, err = callOpenAI(apiKey, model, chatReq.Message, chatReq.History, temperature, maxTokens) case "gemini": response, tokensUsed, err = callGemini(apiKey, model, chatReq.Message, chatReq.History, temperature, maxTokens) case "claude": response, tokensUsed, err = callClaude(apiKey, model, chatReq.Message, chatReq.History, temperature, maxTokens) case "qwen": response, tokensUsed, err = callQwen(apiKey, model, chatReq.Message, chatReq.History, temperature, maxTokens) case "huggingface": response, tokensUsed, err = callHuggingFace(apiKey, model, chatReq.Message, chatReq.History, temperature, maxTokens) default: respondWithError(w, http.StatusBadRequest, "Unsupported provider") return } if err != nil { log.Printf("LLM call failed: %v", err) respondWithError(w, http.StatusInternalServerError, fmt.Sprintf("LLM call failed: %v", err)) return } // Save assistant response to history db.Exec(` INSERT INTO llm_chat_history (user_id, provider, role, content, tokens_used, created_at) VALUES ($1, $2, 'assistant', $3, $4, NOW()) `, userID, provider, response, tokensUsed) respondWithJSON(w, http.StatusOK, ChatResponse{ Response: response, TokensUsed: tokensUsed, Model: model, }) } // historyHandler handles chat history operations func historyHandler(w http.ResponseWriter, r *http.Request) { userID := r.Context().Value("userID").(int) provider := strings.TrimPrefix(r.URL.Path, "/llm/history/") if provider == "" { respondWithError(w, http.StatusBadRequest, "Provider required") return } switch r.Method { case http.MethodGet: getHistory(w, r, userID, provider) case 
http.MethodDelete: clearHistory(w, r, userID, provider) default: respondWithError(w, http.StatusMethodNotAllowed, "Method not allowed") } } // getHistory returns chat history for a provider func getHistory(w http.ResponseWriter, r *http.Request, userID int, provider string) { limit := 50 offset := 0 if limitStr := r.URL.Query().Get("limit"); limitStr != "" { if l, err := strconv.Atoi(limitStr); err == nil { limit = l } } if offsetStr := r.URL.Query().Get("offset"); offsetStr != "" { if o, err := strconv.Atoi(offsetStr); err == nil { offset = o } } rows, err := db.Query(` SELECT id, role, content, tokens_used, created_at FROM llm_chat_history WHERE user_id = $1 AND provider = $2 ORDER BY created_at DESC LIMIT $3 OFFSET $4 `, userID, provider, limit, offset) if err != nil { respondWithError(w, http.StatusInternalServerError, "Failed to get history") return } defer rows.Close() var history []ChatMessage for rows.Next() { var msg ChatMessage var tokensUsed sql.NullInt64 if err := rows.Scan(&msg.ID, &msg.Role, &msg.Content, &tokensUsed, &msg.Timestamp); err != nil { respondWithError(w, http.StatusInternalServerError, "Failed to scan history") return } if tokensUsed.Valid { msg.TokensUsed = int(tokensUsed.Int64) } history = append(history, msg) } // Get total count var total int db.QueryRow(`SELECT COUNT(*) FROM llm_chat_history WHERE user_id = $1 AND provider = $2`, userID, provider).Scan(&total) respondWithJSON(w, http.StatusOK, map[string]interface{}{ "history": history, "total": total, }) } // clearHistory deletes all chat history for a provider func clearHistory(w http.ResponseWriter, r *http.Request, userID int, provider string) { _, err := db.Exec(`DELETE FROM llm_chat_history WHERE user_id = $1 AND provider = $2`, userID, provider) if err != nil { respondWithError(w, http.StatusInternalServerError, "Failed to clear history") return } respondWithJSON(w, http.StatusOK, map[string]interface{}{ "success": true, "message": "Chat history cleared", }) } // 
Encryption/Decryption functions func encrypt(plaintext string) (string, error) { block, err := aes.NewCipher(encryptionKey) if err != nil { return "", err } gcm, err := cipher.NewGCM(block) if err != nil { return "", err } nonce := make([]byte, gcm.NonceSize()) if _, err := io.ReadFull(rand.Reader, nonce); err != nil { return "", err } ciphertext := gcm.Seal(nonce, nonce, []byte(plaintext), nil) return hex.EncodeToString(ciphertext), nil } func decrypt(ciphertext string) (string, error) { data, err := hex.DecodeString(ciphertext) if err != nil { return "", err } block, err := aes.NewCipher(encryptionKey) if err != nil { return "", err } gcm, err := cipher.NewGCM(block) if err != nil { return "", err } nonceSize := gcm.NonceSize() if len(data) < nonceSize { return "", fmt.Errorf("ciphertext too short") } nonce, ciphertextBytes := data[:nonceSize], data[nonceSize:] plaintext, err := gcm.Open(nil, nonce, ciphertextBytes, nil) if err != nil { return "", err } return string(plaintext), nil } // Helper functions func respondWithJSON(w http.ResponseWriter, code int, payload interface{}) { response, _ := json.Marshal(payload) w.Header().Set("Content-Type", "application/json") w.WriteHeader(code) w.Write(response) } func respondWithError(w http.ResponseWriter, code int, message string) { respondWithJSON(w, code, ErrorResponse{ Error: http.StatusText(code), Message: message, }) }