// LLM provider configuration and chat proxy service (~860 lines, 24 KiB).
package main
|
|
|
|
import (
	"context"
	"crypto/aes"
	"crypto/cipher"
	"crypto/rand"
	"database/sql"
	"encoding/hex"
	"encoding/json"
	"fmt"
	"io"
	"log"
	"net/http"
	"os"
	"sort"
	"strconv"
	"strings"
	"time"

	"github.com/golang-jwt/jwt/v5"
	_ "github.com/lib/pq"
)
|
|
|
|
var db *sql.DB
|
|
var encryptionKey []byte
|
|
|
|
// LLMConfig represents the configuration for an LLM provider
|
|
type LLMConfig struct {
|
|
APIKey string `json:"apiKey"`
|
|
Model string `json:"model,omitempty"`
|
|
Temperature float64 `json:"temperature,omitempty"`
|
|
MaxTokens int `json:"maxTokens,omitempty"`
|
|
}
|
|
|
|
// ChatMessage represents a message in the chat history
|
|
type ChatMessage struct {
|
|
ID int `json:"id,omitempty"`
|
|
Role string `json:"role"`
|
|
Content string `json:"content"`
|
|
TokensUsed int `json:"tokensUsed,omitempty"`
|
|
Timestamp time.Time `json:"timestamp"`
|
|
}
|
|
|
|
// ChatRequest represents a chat request
|
|
type ChatRequest struct {
|
|
Message string `json:"message"`
|
|
History []ChatMessage `json:"history"`
|
|
}
|
|
|
|
// ChatResponse represents a chat response
|
|
type ChatResponse struct {
|
|
Response string `json:"response"`
|
|
TokensUsed int `json:"tokensUsed"`
|
|
Model string `json:"model"`
|
|
}
|
|
|
|
// ErrorResponse represents an error response
|
|
type ErrorResponse struct {
|
|
Error string `json:"error"`
|
|
Message string `json:"message"`
|
|
}
|
|
|
|
func main() {
|
|
// Load encryption key
|
|
keyHex := os.Getenv("ENCRYPTION_KEY")
|
|
if keyHex == "" {
|
|
log.Fatal("ENCRYPTION_KEY environment variable not set")
|
|
}
|
|
var err error
|
|
encryptionKey, err = hex.DecodeString(keyHex)
|
|
if err != nil || len(encryptionKey) != 32 {
|
|
log.Fatal("ENCRYPTION_KEY must be a 64-character hex string (32 bytes)")
|
|
}
|
|
|
|
// Database connection
|
|
dbHost := os.Getenv("DB_HOST")
|
|
dbUser := os.Getenv("DB_USER")
|
|
dbPassword := os.Getenv("DB_PASSWORD")
|
|
dbName := os.Getenv("DB_NAME")
|
|
dbSchema := os.Getenv("DB_SCHEMA")
|
|
if dbSchema == "" {
|
|
dbSchema = "public"
|
|
}
|
|
|
|
connStr := fmt.Sprintf("host=%s user=%s password=%s dbname=%s sslmode=disable search_path=%s",
|
|
dbHost, dbUser, dbPassword, dbName, dbSchema)
|
|
|
|
db, err = sql.Open("postgres", connStr)
|
|
if err != nil {
|
|
log.Fatalf("Failed to connect to database: %v", err)
|
|
}
|
|
defer db.Close()
|
|
|
|
// Test database connection
|
|
if err = db.Ping(); err != nil {
|
|
log.Fatalf("Failed to ping database: %v", err)
|
|
}
|
|
|
|
log.Println("LLM Service started successfully")
|
|
|
|
// Routes
|
|
http.HandleFunc("/health", healthHandler)
|
|
http.HandleFunc("/llm/models", corsMiddleware(getModelsHandler))
|
|
http.HandleFunc("/llm/configs", corsMiddleware(authMiddleware(getConfigsHandler)))
|
|
http.HandleFunc("/llm/config/", corsMiddleware(authMiddleware(configHandler)))
|
|
http.HandleFunc("/llm/chat/", corsMiddleware(authMiddleware(chatHandler)))
|
|
http.HandleFunc("/llm/history/", corsMiddleware(authMiddleware(historyHandler)))
|
|
|
|
port := os.Getenv("PORT")
|
|
if port == "" {
|
|
port = "8080"
|
|
}
|
|
|
|
log.Printf("LLM Service listening on port %s", port)
|
|
log.Fatal(http.ListenAndServe(":"+port, nil))
|
|
}
|
|
|
|
// healthHandler returns service health status
|
|
func healthHandler(w http.ResponseWriter, r *http.Request) {
|
|
w.Header().Set("Content-Type", "application/json")
|
|
json.NewEncoder(w).Encode(map[string]string{"status": "healthy"})
|
|
}
|
|
|
|
// ModelInfo represents information about an available model
|
|
type ModelInfo struct {
|
|
ID string `json:"id"`
|
|
Name string `json:"name"`
|
|
Description string `json:"description,omitempty"`
|
|
IsDefault bool `json:"isDefault,omitempty"`
|
|
}
|
|
|
|
// getModelsHandler returns available models for each provider
|
|
// It fetches models dynamically from provider APIs where possible
|
|
func getModelsHandler(w http.ResponseWriter, r *http.Request) {
|
|
if r.Method != http.MethodGet {
|
|
respondWithError(w, http.StatusMethodNotAllowed, "Method not allowed")
|
|
return
|
|
}
|
|
|
|
models := make(map[string][]ModelInfo)
|
|
|
|
// Fetch OpenAI models dynamically
|
|
models["openai"] = fetchOpenAIModels()
|
|
|
|
// Fetch Anthropic/Claude models dynamically
|
|
models["claude"] = fetchClaudeModels()
|
|
|
|
// Gemini, Qwen, HuggingFace - use curated lists (APIs require auth or don't have model listing)
|
|
models["gemini"] = []ModelInfo{
|
|
{ID: "gemini-2.0-flash-exp", Name: "Gemini 2.0 Flash", Description: "Latest fast model", IsDefault: true},
|
|
{ID: "gemini-1.5-pro", Name: "Gemini 1.5 Pro", Description: "High capability model"},
|
|
{ID: "gemini-1.5-flash", Name: "Gemini 1.5 Flash", Description: "Fast and efficient"},
|
|
{ID: "gemini-1.0-pro", Name: "Gemini 1.0 Pro", Description: "Stable production model"},
|
|
}
|
|
|
|
models["qwen"] = []ModelInfo{
|
|
{ID: "qwen-max", Name: "Qwen Max", Description: "Most capable Qwen model", IsDefault: true},
|
|
{ID: "qwen-plus", Name: "Qwen Plus", Description: "Balanced performance"},
|
|
{ID: "qwen-turbo", Name: "Qwen Turbo", Description: "Fast responses"},
|
|
{ID: "qwen-long", Name: "Qwen Long", Description: "Extended context length"},
|
|
}
|
|
|
|
models["huggingface"] = []ModelInfo{
|
|
{ID: "meta-llama/Llama-3.3-70B-Instruct", Name: "Llama 3.3 70B", Description: "Latest Llama model", IsDefault: true},
|
|
{ID: "meta-llama/Llama-3.2-90B-Vision-Instruct", Name: "Llama 3.2 90B Vision", Description: "Multimodal model"},
|
|
{ID: "meta-llama/Llama-3.1-70B-Instruct", Name: "Llama 3.1 70B", Description: "Previous generation"},
|
|
{ID: "mistralai/Mixtral-8x7B-Instruct-v0.1", Name: "Mixtral 8x7B", Description: "Mixture of experts"},
|
|
{ID: "microsoft/Phi-3-medium-128k-instruct", Name: "Phi-3 Medium", Description: "Compact but capable"},
|
|
}
|
|
|
|
respondWithJSON(w, http.StatusOK, map[string]interface{}{"models": models})
|
|
}
|
|
|
|
// fetchOpenAIModels fetches the list of available models from OpenAI API
|
|
func fetchOpenAIModels() []ModelInfo {
|
|
// Fallback models if API call fails
|
|
fallback := []ModelInfo{
|
|
{ID: "gpt-4o", Name: "GPT-4o", Description: "Most capable multimodal model", IsDefault: true},
|
|
{ID: "gpt-4o-mini", Name: "GPT-4o Mini", Description: "Fast and affordable"},
|
|
{ID: "gpt-4-turbo", Name: "GPT-4 Turbo", Description: "High capability with vision"},
|
|
{ID: "gpt-3.5-turbo", Name: "GPT-3.5 Turbo", Description: "Fast and cost-effective"},
|
|
}
|
|
|
|
// Try to get API key from environment or a configured user
|
|
apiKey := os.Getenv("OPENAI_API_KEY")
|
|
if apiKey == "" {
|
|
// Try to get from first configured user
|
|
var encryptedKey string
|
|
err := db.QueryRow(`SELECT api_key_encrypted FROM llm_configs WHERE provider = 'openai' LIMIT 1`).Scan(&encryptedKey)
|
|
if err == nil {
|
|
decrypted, err := decrypt(encryptedKey)
|
|
if err == nil {
|
|
apiKey = decrypted
|
|
}
|
|
}
|
|
}
|
|
|
|
if apiKey == "" {
|
|
return fallback
|
|
}
|
|
|
|
client := &http.Client{Timeout: 10 * time.Second}
|
|
req, err := http.NewRequest("GET", "https://api.openai.com/v1/models", nil)
|
|
if err != nil {
|
|
log.Printf("Error creating OpenAI request: %v", err)
|
|
return fallback
|
|
}
|
|
req.Header.Set("Authorization", "Bearer "+apiKey)
|
|
|
|
resp, err := client.Do(req)
|
|
if err != nil {
|
|
log.Printf("Error fetching OpenAI models: %v", err)
|
|
return fallback
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
if resp.StatusCode != http.StatusOK {
|
|
log.Printf("OpenAI API returned status %d", resp.StatusCode)
|
|
return fallback
|
|
}
|
|
|
|
var result struct {
|
|
Data []struct {
|
|
ID string `json:"id"`
|
|
Created int64 `json:"created"`
|
|
OwnedBy string `json:"owned_by"`
|
|
} `json:"data"`
|
|
}
|
|
|
|
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
|
|
log.Printf("Error decoding OpenAI response: %v", err)
|
|
return fallback
|
|
}
|
|
|
|
// Filter to chat models only and sort by preference
|
|
var models []ModelInfo
|
|
chatModelPrefixes := []string{"gpt-4", "gpt-3.5", "o1", "o3", "chatgpt"}
|
|
|
|
for _, m := range result.Data {
|
|
isChatModel := false
|
|
for _, prefix := range chatModelPrefixes {
|
|
if strings.HasPrefix(m.ID, prefix) {
|
|
isChatModel = true
|
|
break
|
|
}
|
|
}
|
|
if !isChatModel {
|
|
continue
|
|
}
|
|
|
|
// Skip specific non-chat variants
|
|
if strings.Contains(m.ID, "instruct") || strings.Contains(m.ID, "vision") ||
|
|
strings.Contains(m.ID, "audio") || strings.Contains(m.ID, "realtime") ||
|
|
strings.Contains(m.ID, "transcribe") || strings.Contains(m.ID, "tts") {
|
|
continue
|
|
}
|
|
|
|
name := formatModelName(m.ID)
|
|
models = append(models, ModelInfo{
|
|
ID: m.ID,
|
|
Name: name,
|
|
IsDefault: m.ID == "gpt-4o",
|
|
})
|
|
}
|
|
|
|
// Sort models: gpt-4o first, then gpt-4, then o1/o3, then gpt-3.5
|
|
sortModels(models)
|
|
|
|
if len(models) == 0 {
|
|
return fallback
|
|
}
|
|
|
|
return models
|
|
}
|
|
|
|
// fetchClaudeModels fetches the list of available models from Anthropic API
|
|
func fetchClaudeModels() []ModelInfo {
|
|
fallback := []ModelInfo{
|
|
{ID: "claude-sonnet-4-5-20250929", Name: "Claude Sonnet 4.5", Description: "Latest balanced model", IsDefault: true},
|
|
{ID: "claude-opus-4-5-20251101", Name: "Claude Opus 4.5", Description: "Most capable model"},
|
|
{ID: "claude-3-5-sonnet-20241022", Name: "Claude 3.5 Sonnet", Description: "Previous generation balanced"},
|
|
{ID: "claude-3-5-haiku-20241022", Name: "Claude 3.5 Haiku", Description: "Fast and affordable"},
|
|
}
|
|
|
|
apiKey := os.Getenv("ANTHROPIC_API_KEY")
|
|
if apiKey == "" {
|
|
var encryptedKey string
|
|
err := db.QueryRow(`SELECT api_key_encrypted FROM llm_configs WHERE provider = 'claude' LIMIT 1`).Scan(&encryptedKey)
|
|
if err == nil {
|
|
decrypted, err := decrypt(encryptedKey)
|
|
if err == nil {
|
|
apiKey = decrypted
|
|
}
|
|
}
|
|
}
|
|
|
|
if apiKey == "" {
|
|
return fallback
|
|
}
|
|
|
|
client := &http.Client{Timeout: 10 * time.Second}
|
|
req, err := http.NewRequest("GET", "https://api.anthropic.com/v1/models", nil)
|
|
if err != nil {
|
|
log.Printf("Error creating Anthropic request: %v", err)
|
|
return fallback
|
|
}
|
|
req.Header.Set("x-api-key", apiKey)
|
|
req.Header.Set("anthropic-version", "2023-06-01")
|
|
|
|
resp, err := client.Do(req)
|
|
if err != nil {
|
|
log.Printf("Error fetching Claude models: %v", err)
|
|
return fallback
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
if resp.StatusCode != http.StatusOK {
|
|
log.Printf("Anthropic API returned status %d", resp.StatusCode)
|
|
return fallback
|
|
}
|
|
|
|
var result struct {
|
|
Data []struct {
|
|
ID string `json:"id"`
|
|
DisplayName string `json:"display_name"`
|
|
Type string `json:"type"`
|
|
} `json:"data"`
|
|
}
|
|
|
|
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
|
|
log.Printf("Error decoding Anthropic response: %v", err)
|
|
return fallback
|
|
}
|
|
|
|
var models []ModelInfo
|
|
for _, m := range result.Data {
|
|
if m.Type != "model" {
|
|
continue
|
|
}
|
|
name := m.DisplayName
|
|
if name == "" {
|
|
name = formatModelName(m.ID)
|
|
}
|
|
models = append(models, ModelInfo{
|
|
ID: m.ID,
|
|
Name: name,
|
|
IsDefault: strings.Contains(m.ID, "sonnet") && strings.Contains(m.ID, "4-5"),
|
|
})
|
|
}
|
|
|
|
if len(models) == 0 {
|
|
return fallback
|
|
}
|
|
|
|
return models
|
|
}
|
|
|
|
// formatModelName converts model ID to display name
|
|
func formatModelName(id string) string {
|
|
name := strings.ReplaceAll(id, "-", " ")
|
|
name = strings.ReplaceAll(name, "_", " ")
|
|
// Capitalize first letter of each word
|
|
words := strings.Fields(name)
|
|
for i, word := range words {
|
|
if len(word) > 0 {
|
|
words[i] = strings.ToUpper(string(word[0])) + word[1:]
|
|
}
|
|
}
|
|
return strings.Join(words, " ")
|
|
}
|
|
|
|
// sortModels sorts models by preference
|
|
func sortModels(models []ModelInfo) {
|
|
priority := func(id string) int {
|
|
switch {
|
|
case strings.HasPrefix(id, "gpt-4o") && !strings.Contains(id, "mini"):
|
|
return 0
|
|
case strings.HasPrefix(id, "gpt-4o-mini"):
|
|
return 1
|
|
case strings.HasPrefix(id, "gpt-4"):
|
|
return 2
|
|
case strings.HasPrefix(id, "o1") || strings.HasPrefix(id, "o3"):
|
|
return 3
|
|
case strings.HasPrefix(id, "gpt-3.5"):
|
|
return 4
|
|
case strings.HasPrefix(id, "chatgpt"):
|
|
return 5
|
|
default:
|
|
return 6
|
|
}
|
|
}
|
|
|
|
for i := 0; i < len(models)-1; i++ {
|
|
for j := i + 1; j < len(models); j++ {
|
|
if priority(models[i].ID) > priority(models[j].ID) {
|
|
models[i], models[j] = models[j], models[i]
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// CORS middleware
|
|
func corsMiddleware(next http.HandlerFunc) http.HandlerFunc {
|
|
return func(w http.ResponseWriter, r *http.Request) {
|
|
w.Header().Set("Access-Control-Allow-Origin", "*")
|
|
w.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS")
|
|
w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization")
|
|
|
|
if r.Method == "OPTIONS" {
|
|
w.WriteHeader(http.StatusOK)
|
|
return
|
|
}
|
|
|
|
next(w, r)
|
|
}
|
|
}
|
|
|
|
// authMiddleware validates JWT tokens and extracts user ID
|
|
func authMiddleware(next http.HandlerFunc) http.HandlerFunc {
|
|
return func(w http.ResponseWriter, r *http.Request) {
|
|
authHeader := r.Header.Get("Authorization")
|
|
if authHeader == "" {
|
|
respondWithError(w, http.StatusUnauthorized, "Authorization header required")
|
|
return
|
|
}
|
|
|
|
tokenString := strings.TrimPrefix(authHeader, "Bearer ")
|
|
jwtSecret := os.Getenv("JWT_SECRET")
|
|
|
|
token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) {
|
|
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
|
|
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
|
|
}
|
|
return []byte(jwtSecret), nil
|
|
})
|
|
|
|
if err != nil {
|
|
log.Printf("JWT parse error: %v", err)
|
|
respondWithError(w, http.StatusUnauthorized, "Invalid token")
|
|
return
|
|
}
|
|
if !token.Valid {
|
|
log.Printf("JWT token is not valid")
|
|
respondWithError(w, http.StatusUnauthorized, "Invalid token")
|
|
return
|
|
}
|
|
|
|
claims, ok := token.Claims.(jwt.MapClaims)
|
|
if !ok {
|
|
respondWithError(w, http.StatusUnauthorized, "Invalid token claims")
|
|
return
|
|
}
|
|
|
|
userID, ok := claims["user_id"].(float64)
|
|
if !ok {
|
|
respondWithError(w, http.StatusUnauthorized, "Invalid user_id in token")
|
|
return
|
|
}
|
|
|
|
// Add user ID to context
|
|
ctx := context.WithValue(r.Context(), "userID", int(userID))
|
|
next(w, r.WithContext(ctx))
|
|
}
|
|
}
|
|
|
|
// getConfigsHandler returns all LLM configurations for the authenticated user
|
|
func getConfigsHandler(w http.ResponseWriter, r *http.Request) {
|
|
if r.Method != http.MethodGet {
|
|
respondWithError(w, http.StatusMethodNotAllowed, "Method not allowed")
|
|
return
|
|
}
|
|
|
|
userID := r.Context().Value("userID").(int)
|
|
|
|
rows, err := db.Query(`
|
|
SELECT provider, api_key_encrypted, model, temperature, max_tokens
|
|
FROM llm_configs
|
|
WHERE user_id = $1
|
|
`, userID)
|
|
if err != nil {
|
|
respondWithError(w, http.StatusInternalServerError, "Failed to query configs")
|
|
return
|
|
}
|
|
defer rows.Close()
|
|
|
|
configs := make(map[string]LLMConfig)
|
|
|
|
for rows.Next() {
|
|
var provider, apiKeyEncrypted, model string
|
|
var temperature float64
|
|
var maxTokens int
|
|
|
|
if err := rows.Scan(&provider, &apiKeyEncrypted, &model, &temperature, &maxTokens); err != nil {
|
|
respondWithError(w, http.StatusInternalServerError, "Failed to scan config")
|
|
return
|
|
}
|
|
|
|
// Mask API key for security (don't return actual key)
|
|
configs[provider] = LLMConfig{
|
|
APIKey: "sk-***",
|
|
Model: model,
|
|
Temperature: temperature,
|
|
MaxTokens: maxTokens,
|
|
}
|
|
}
|
|
|
|
respondWithJSON(w, http.StatusOK, map[string]interface{}{"configs": configs})
|
|
}
|
|
|
|
// configHandler handles CRUD operations for individual provider configs
|
|
func configHandler(w http.ResponseWriter, r *http.Request) {
|
|
userID := r.Context().Value("userID").(int)
|
|
provider := strings.TrimPrefix(r.URL.Path, "/llm/config/")
|
|
|
|
if provider == "" {
|
|
respondWithError(w, http.StatusBadRequest, "Provider required")
|
|
return
|
|
}
|
|
|
|
switch r.Method {
|
|
case http.MethodPost:
|
|
saveConfig(w, r, userID, provider)
|
|
case http.MethodDelete:
|
|
deleteConfig(w, r, userID, provider)
|
|
default:
|
|
respondWithError(w, http.StatusMethodNotAllowed, "Method not allowed")
|
|
}
|
|
}
|
|
|
|
// saveConfig saves or updates an LLM configuration
|
|
func saveConfig(w http.ResponseWriter, r *http.Request, userID int, provider string) {
|
|
var config LLMConfig
|
|
if err := json.NewDecoder(r.Body).Decode(&config); err != nil {
|
|
respondWithError(w, http.StatusBadRequest, "Invalid request body")
|
|
return
|
|
}
|
|
|
|
if config.APIKey == "" {
|
|
respondWithError(w, http.StatusBadRequest, "API key required")
|
|
return
|
|
}
|
|
|
|
// Set defaults
|
|
if config.Temperature == 0 {
|
|
config.Temperature = 0.7
|
|
}
|
|
if config.MaxTokens == 0 {
|
|
config.MaxTokens = 2048
|
|
}
|
|
|
|
// Encrypt API key
|
|
encryptedKey, err := encrypt(config.APIKey)
|
|
if err != nil {
|
|
respondWithError(w, http.StatusInternalServerError, "Failed to encrypt API key")
|
|
return
|
|
}
|
|
|
|
// Upsert configuration
|
|
_, err = db.Exec(`
|
|
INSERT INTO llm_configs (user_id, provider, api_key_encrypted, model, temperature, max_tokens, updated_at)
|
|
VALUES ($1, $2, $3, $4, $5, $6, NOW())
|
|
ON CONFLICT (user_id, provider)
|
|
DO UPDATE SET
|
|
api_key_encrypted = EXCLUDED.api_key_encrypted,
|
|
model = EXCLUDED.model,
|
|
temperature = EXCLUDED.temperature,
|
|
max_tokens = EXCLUDED.max_tokens,
|
|
updated_at = NOW()
|
|
`, userID, provider, encryptedKey, config.Model, config.Temperature, config.MaxTokens)
|
|
|
|
if err != nil {
|
|
log.Printf("Failed to save config: %v", err)
|
|
respondWithError(w, http.StatusInternalServerError, "Failed to save configuration")
|
|
return
|
|
}
|
|
|
|
respondWithJSON(w, http.StatusOK, map[string]interface{}{
|
|
"success": true,
|
|
"message": "Configuration saved",
|
|
})
|
|
}
|
|
|
|
// deleteConfig deletes an LLM configuration
|
|
func deleteConfig(w http.ResponseWriter, r *http.Request, userID int, provider string) {
|
|
result, err := db.Exec(`DELETE FROM llm_configs WHERE user_id = $1 AND provider = $2`, userID, provider)
|
|
if err != nil {
|
|
respondWithError(w, http.StatusInternalServerError, "Failed to delete configuration")
|
|
return
|
|
}
|
|
|
|
rowsAffected, _ := result.RowsAffected()
|
|
if rowsAffected == 0 {
|
|
respondWithError(w, http.StatusNotFound, "Configuration not found")
|
|
return
|
|
}
|
|
|
|
// Also delete chat history for this provider
|
|
db.Exec(`DELETE FROM llm_chat_history WHERE user_id = $1 AND provider = $2`, userID, provider)
|
|
|
|
respondWithJSON(w, http.StatusOK, map[string]interface{}{
|
|
"success": true,
|
|
"message": "Configuration deleted",
|
|
})
|
|
}
|
|
|
|
// chatHandler handles chat requests to LLM providers
|
|
func chatHandler(w http.ResponseWriter, r *http.Request) {
|
|
if r.Method != http.MethodPost {
|
|
respondWithError(w, http.StatusMethodNotAllowed, "Method not allowed")
|
|
return
|
|
}
|
|
|
|
userID := r.Context().Value("userID").(int)
|
|
provider := strings.TrimPrefix(r.URL.Path, "/llm/chat/")
|
|
|
|
if provider == "" {
|
|
respondWithError(w, http.StatusBadRequest, "Provider required")
|
|
return
|
|
}
|
|
|
|
var chatReq ChatRequest
|
|
if err := json.NewDecoder(r.Body).Decode(&chatReq); err != nil {
|
|
respondWithError(w, http.StatusBadRequest, "Invalid request body")
|
|
return
|
|
}
|
|
|
|
if chatReq.Message == "" {
|
|
respondWithError(w, http.StatusBadRequest, "Message required")
|
|
return
|
|
}
|
|
|
|
// Get provider configuration
|
|
var apiKeyEncrypted, model string
|
|
var temperature float64
|
|
var maxTokens int
|
|
|
|
err := db.QueryRow(`
|
|
SELECT api_key_encrypted, model, temperature, max_tokens
|
|
FROM llm_configs
|
|
WHERE user_id = $1 AND provider = $2
|
|
`, userID, provider).Scan(&apiKeyEncrypted, &model, &temperature, &maxTokens)
|
|
|
|
if err == sql.ErrNoRows {
|
|
respondWithError(w, http.StatusNotFound, "Provider not configured")
|
|
return
|
|
}
|
|
if err != nil {
|
|
respondWithError(w, http.StatusInternalServerError, "Failed to get configuration")
|
|
return
|
|
}
|
|
|
|
// Decrypt API key
|
|
apiKey, err := decrypt(apiKeyEncrypted)
|
|
if err != nil {
|
|
respondWithError(w, http.StatusInternalServerError, "Failed to decrypt API key")
|
|
return
|
|
}
|
|
|
|
// Save user message to history
|
|
db.Exec(`
|
|
INSERT INTO llm_chat_history (user_id, provider, role, content, created_at)
|
|
VALUES ($1, $2, 'user', $3, NOW())
|
|
`, userID, provider, chatReq.Message)
|
|
|
|
// Call appropriate LLM provider
|
|
var response string
|
|
var tokensUsed int
|
|
|
|
switch provider {
|
|
case "openai":
|
|
response, tokensUsed, err = callOpenAI(apiKey, model, chatReq.Message, chatReq.History, temperature, maxTokens)
|
|
case "gemini":
|
|
response, tokensUsed, err = callGemini(apiKey, model, chatReq.Message, chatReq.History, temperature, maxTokens)
|
|
case "claude":
|
|
response, tokensUsed, err = callClaude(apiKey, model, chatReq.Message, chatReq.History, temperature, maxTokens)
|
|
case "qwen":
|
|
response, tokensUsed, err = callQwen(apiKey, model, chatReq.Message, chatReq.History, temperature, maxTokens)
|
|
case "huggingface":
|
|
response, tokensUsed, err = callHuggingFace(apiKey, model, chatReq.Message, chatReq.History, temperature, maxTokens)
|
|
default:
|
|
respondWithError(w, http.StatusBadRequest, "Unsupported provider")
|
|
return
|
|
}
|
|
|
|
if err != nil {
|
|
log.Printf("LLM call failed: %v", err)
|
|
respondWithError(w, http.StatusInternalServerError, fmt.Sprintf("LLM call failed: %v", err))
|
|
return
|
|
}
|
|
|
|
// Save assistant response to history
|
|
db.Exec(`
|
|
INSERT INTO llm_chat_history (user_id, provider, role, content, tokens_used, created_at)
|
|
VALUES ($1, $2, 'assistant', $3, $4, NOW())
|
|
`, userID, provider, response, tokensUsed)
|
|
|
|
respondWithJSON(w, http.StatusOK, ChatResponse{
|
|
Response: response,
|
|
TokensUsed: tokensUsed,
|
|
Model: model,
|
|
})
|
|
}
|
|
|
|
// historyHandler handles chat history operations
|
|
func historyHandler(w http.ResponseWriter, r *http.Request) {
|
|
userID := r.Context().Value("userID").(int)
|
|
provider := strings.TrimPrefix(r.URL.Path, "/llm/history/")
|
|
|
|
if provider == "" {
|
|
respondWithError(w, http.StatusBadRequest, "Provider required")
|
|
return
|
|
}
|
|
|
|
switch r.Method {
|
|
case http.MethodGet:
|
|
getHistory(w, r, userID, provider)
|
|
case http.MethodDelete:
|
|
clearHistory(w, r, userID, provider)
|
|
default:
|
|
respondWithError(w, http.StatusMethodNotAllowed, "Method not allowed")
|
|
}
|
|
}
|
|
|
|
// getHistory returns chat history for a provider
|
|
func getHistory(w http.ResponseWriter, r *http.Request, userID int, provider string) {
|
|
limit := 50
|
|
offset := 0
|
|
|
|
if limitStr := r.URL.Query().Get("limit"); limitStr != "" {
|
|
if l, err := strconv.Atoi(limitStr); err == nil {
|
|
limit = l
|
|
}
|
|
}
|
|
if offsetStr := r.URL.Query().Get("offset"); offsetStr != "" {
|
|
if o, err := strconv.Atoi(offsetStr); err == nil {
|
|
offset = o
|
|
}
|
|
}
|
|
|
|
rows, err := db.Query(`
|
|
SELECT id, role, content, tokens_used, created_at
|
|
FROM llm_chat_history
|
|
WHERE user_id = $1 AND provider = $2
|
|
ORDER BY created_at DESC
|
|
LIMIT $3 OFFSET $4
|
|
`, userID, provider, limit, offset)
|
|
|
|
if err != nil {
|
|
respondWithError(w, http.StatusInternalServerError, "Failed to get history")
|
|
return
|
|
}
|
|
defer rows.Close()
|
|
|
|
var history []ChatMessage
|
|
for rows.Next() {
|
|
var msg ChatMessage
|
|
var tokensUsed sql.NullInt64
|
|
|
|
if err := rows.Scan(&msg.ID, &msg.Role, &msg.Content, &tokensUsed, &msg.Timestamp); err != nil {
|
|
respondWithError(w, http.StatusInternalServerError, "Failed to scan history")
|
|
return
|
|
}
|
|
|
|
if tokensUsed.Valid {
|
|
msg.TokensUsed = int(tokensUsed.Int64)
|
|
}
|
|
|
|
history = append(history, msg)
|
|
}
|
|
|
|
// Get total count
|
|
var total int
|
|
db.QueryRow(`SELECT COUNT(*) FROM llm_chat_history WHERE user_id = $1 AND provider = $2`, userID, provider).Scan(&total)
|
|
|
|
respondWithJSON(w, http.StatusOK, map[string]interface{}{
|
|
"history": history,
|
|
"total": total,
|
|
})
|
|
}
|
|
|
|
// clearHistory deletes all chat history for a provider
|
|
func clearHistory(w http.ResponseWriter, r *http.Request, userID int, provider string) {
|
|
_, err := db.Exec(`DELETE FROM llm_chat_history WHERE user_id = $1 AND provider = $2`, userID, provider)
|
|
if err != nil {
|
|
respondWithError(w, http.StatusInternalServerError, "Failed to clear history")
|
|
return
|
|
}
|
|
|
|
respondWithJSON(w, http.StatusOK, map[string]interface{}{
|
|
"success": true,
|
|
"message": "Chat history cleared",
|
|
})
|
|
}
|
|
|
|
// Encryption/Decryption functions
|
|
func encrypt(plaintext string) (string, error) {
|
|
block, err := aes.NewCipher(encryptionKey)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
|
|
gcm, err := cipher.NewGCM(block)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
|
|
nonce := make([]byte, gcm.NonceSize())
|
|
if _, err := io.ReadFull(rand.Reader, nonce); err != nil {
|
|
return "", err
|
|
}
|
|
|
|
ciphertext := gcm.Seal(nonce, nonce, []byte(plaintext), nil)
|
|
return hex.EncodeToString(ciphertext), nil
|
|
}
|
|
|
|
func decrypt(ciphertext string) (string, error) {
|
|
data, err := hex.DecodeString(ciphertext)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
|
|
block, err := aes.NewCipher(encryptionKey)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
|
|
gcm, err := cipher.NewGCM(block)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
|
|
nonceSize := gcm.NonceSize()
|
|
if len(data) < nonceSize {
|
|
return "", fmt.Errorf("ciphertext too short")
|
|
}
|
|
|
|
nonce, ciphertextBytes := data[:nonceSize], data[nonceSize:]
|
|
plaintext, err := gcm.Open(nil, nonce, ciphertextBytes, nil)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
|
|
return string(plaintext), nil
|
|
}
|
|
|
|
// Helper functions
|
|
func respondWithJSON(w http.ResponseWriter, code int, payload interface{}) {
|
|
response, _ := json.Marshal(payload)
|
|
w.Header().Set("Content-Type", "application/json")
|
|
w.WriteHeader(code)
|
|
w.Write(response)
|
|
}
|
|
|
|
func respondWithError(w http.ResponseWriter, code int, message string) {
|
|
respondWithJSON(w, code, ErrorResponse{
|
|
Error: http.StatusText(code),
|
|
Message: message,
|
|
})
|
|
}
|