package main
import (
	"bytes"
	"context"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"strings"
	"time"

	"github.com/sashabaranov/go-openai"
)
// callOpenAI sends a message to OpenAI's API
|
|
func callOpenAI(apiKey, model, message string, history []ChatMessage, temperature float64, maxTokens int) (string, int, error) {
|
|
if model == "" {
|
|
model = "gpt-4o"
|
|
}
|
|
|
|
client := openai.NewClient(apiKey)
|
|
|
|
// Build messages array
|
|
messages := []openai.ChatCompletionMessage{}
|
|
|
|
// Add history
|
|
for _, msg := range history {
|
|
role := msg.Role
|
|
if role == "assistant" {
|
|
role = openai.ChatMessageRoleAssistant
|
|
} else {
|
|
role = openai.ChatMessageRoleUser
|
|
}
|
|
messages = append(messages, openai.ChatCompletionMessage{
|
|
Role: role,
|
|
Content: msg.Content,
|
|
})
|
|
}
|
|
|
|
// Add current message
|
|
messages = append(messages, openai.ChatCompletionMessage{
|
|
Role: openai.ChatMessageRoleUser,
|
|
Content: message,
|
|
})
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
|
defer cancel()
|
|
|
|
resp, err := client.CreateChatCompletion(ctx, openai.ChatCompletionRequest{
|
|
Model: model,
|
|
Messages: messages,
|
|
Temperature: float32(temperature),
|
|
MaxTokens: maxTokens,
|
|
})
|
|
|
|
if err != nil {
|
|
return "", 0, fmt.Errorf("openai api error: %w", err)
|
|
}
|
|
|
|
if len(resp.Choices) == 0 {
|
|
return "", 0, fmt.Errorf("no response from openai")
|
|
}
|
|
|
|
return resp.Choices[0].Message.Content, resp.Usage.TotalTokens, nil
|
|
}
|
|
|
|
// callGemini sends a message to Google Gemini's API
|
|
func callGemini(apiKey, model, message string, history []ChatMessage, temperature float64, maxTokens int) (string, int, error) {
|
|
if model == "" {
|
|
model = "gemini-2.0-flash-exp"
|
|
}
|
|
|
|
url := fmt.Sprintf("https://generativelanguage.googleapis.com/v1beta/models/%s:generateContent?key=%s", model, apiKey)
|
|
|
|
// Build request
|
|
type Part struct {
|
|
Text string `json:"text"`
|
|
}
|
|
type Content struct {
|
|
Role string `json:"role,omitempty"`
|
|
Parts []Part `json:"parts"`
|
|
}
|
|
type GenerationConfig struct {
|
|
Temperature float64 `json:"temperature"`
|
|
MaxOutputTokens int `json:"maxOutputTokens"`
|
|
}
|
|
type Request struct {
|
|
Contents []Content `json:"contents"`
|
|
GenerationConfig GenerationConfig `json:"generationConfig"`
|
|
}
|
|
|
|
contents := []Content{}
|
|
|
|
// Add history
|
|
for _, msg := range history {
|
|
role := "user"
|
|
if msg.Role == "assistant" {
|
|
role = "model"
|
|
}
|
|
contents = append(contents, Content{
|
|
Role: role,
|
|
Parts: []Part{{Text: msg.Content}},
|
|
})
|
|
}
|
|
|
|
// Add current message
|
|
contents = append(contents, Content{
|
|
Role: "user",
|
|
Parts: []Part{{Text: message}},
|
|
})
|
|
|
|
reqBody := Request{
|
|
Contents: contents,
|
|
GenerationConfig: GenerationConfig{
|
|
Temperature: temperature,
|
|
MaxOutputTokens: maxTokens,
|
|
},
|
|
}
|
|
|
|
jsonData, err := json.Marshal(reqBody)
|
|
if err != nil {
|
|
return "", 0, err
|
|
}
|
|
|
|
resp, err := http.Post(url, "application/json", bytes.NewBuffer(jsonData))
|
|
if err != nil {
|
|
return "", 0, fmt.Errorf("gemini api error: %w", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
if resp.StatusCode != http.StatusOK {
|
|
body, _ := io.ReadAll(resp.Body)
|
|
return "", 0, fmt.Errorf("gemini api error: %s - %s", resp.Status, string(body))
|
|
}
|
|
|
|
type Candidate struct {
|
|
Content struct {
|
|
Parts []Part `json:"parts"`
|
|
} `json:"content"`
|
|
}
|
|
type UsageMetadata struct {
|
|
TotalTokenCount int `json:"totalTokenCount"`
|
|
}
|
|
type Response struct {
|
|
Candidates []Candidate `json:"candidates"`
|
|
UsageMetadata UsageMetadata `json:"usageMetadata"`
|
|
}
|
|
|
|
var geminiResp Response
|
|
if err := json.NewDecoder(resp.Body).Decode(&geminiResp); err != nil {
|
|
return "", 0, err
|
|
}
|
|
|
|
if len(geminiResp.Candidates) == 0 || len(geminiResp.Candidates[0].Content.Parts) == 0 {
|
|
return "", 0, fmt.Errorf("no response from gemini")
|
|
}
|
|
|
|
return geminiResp.Candidates[0].Content.Parts[0].Text, geminiResp.UsageMetadata.TotalTokenCount, nil
|
|
}
|
|
|
|
// callClaude sends a message to Anthropic Claude's API
|
|
func callClaude(apiKey, model, message string, history []ChatMessage, temperature float64, maxTokens int) (string, int, error) {
|
|
if model == "" {
|
|
model = "claude-sonnet-4-5-20250929"
|
|
}
|
|
|
|
url := "https://api.anthropic.com/v1/messages"
|
|
|
|
// Build request
|
|
type Message struct {
|
|
Role string `json:"role"`
|
|
Content string `json:"content"`
|
|
}
|
|
type Request struct {
|
|
Model string `json:"model"`
|
|
Messages []Message `json:"messages"`
|
|
MaxTokens int `json:"max_tokens"`
|
|
Temperature float64 `json:"temperature"`
|
|
}
|
|
|
|
messages := []Message{}
|
|
|
|
// Add history
|
|
for _, msg := range history {
|
|
messages = append(messages, Message{
|
|
Role: msg.Role,
|
|
Content: msg.Content,
|
|
})
|
|
}
|
|
|
|
// Add current message
|
|
messages = append(messages, Message{
|
|
Role: "user",
|
|
Content: message,
|
|
})
|
|
|
|
reqBody := Request{
|
|
Model: model,
|
|
Messages: messages,
|
|
MaxTokens: maxTokens,
|
|
Temperature: temperature,
|
|
}
|
|
|
|
jsonData, err := json.Marshal(reqBody)
|
|
if err != nil {
|
|
return "", 0, err
|
|
}
|
|
|
|
req, err := http.NewRequest("POST", url, bytes.NewBuffer(jsonData))
|
|
if err != nil {
|
|
return "", 0, err
|
|
}
|
|
|
|
req.Header.Set("Content-Type", "application/json")
|
|
req.Header.Set("x-api-key", apiKey)
|
|
req.Header.Set("anthropic-version", "2023-06-01")
|
|
|
|
client := &http.Client{Timeout: 30 * time.Second}
|
|
resp, err := client.Do(req)
|
|
if err != nil {
|
|
return "", 0, fmt.Errorf("claude api error: %w", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
if resp.StatusCode != http.StatusOK {
|
|
body, _ := io.ReadAll(resp.Body)
|
|
return "", 0, fmt.Errorf("claude api error: %s - %s", resp.Status, string(body))
|
|
}
|
|
|
|
type ContentBlock struct {
|
|
Type string `json:"type"`
|
|
Text string `json:"text"`
|
|
}
|
|
type Usage struct {
|
|
InputTokens int `json:"input_tokens"`
|
|
OutputTokens int `json:"output_tokens"`
|
|
}
|
|
type Response struct {
|
|
Content []ContentBlock `json:"content"`
|
|
Usage Usage `json:"usage"`
|
|
}
|
|
|
|
var claudeResp Response
|
|
if err := json.NewDecoder(resp.Body).Decode(&claudeResp); err != nil {
|
|
return "", 0, err
|
|
}
|
|
|
|
if len(claudeResp.Content) == 0 {
|
|
return "", 0, fmt.Errorf("no response from claude")
|
|
}
|
|
|
|
totalTokens := claudeResp.Usage.InputTokens + claudeResp.Usage.OutputTokens
|
|
return claudeResp.Content[0].Text, totalTokens, nil
|
|
}
|
|
|
|
// callQwen sends a message to Qwen AI's API
|
|
func callQwen(apiKey, model, message string, history []ChatMessage, temperature float64, maxTokens int) (string, int, error) {
|
|
if model == "" {
|
|
model = "qwen-max"
|
|
}
|
|
|
|
url := "https://dashscope.aliyuncs.com/api/v1/services/aigc/text-generation/generation"
|
|
|
|
// Build request
|
|
type Message struct {
|
|
Role string `json:"role"`
|
|
Content string `json:"content"`
|
|
}
|
|
type Input struct {
|
|
Messages []Message `json:"messages"`
|
|
}
|
|
type Parameters struct {
|
|
ResultFormat string `json:"result_format"`
|
|
Temperature float64 `json:"temperature,omitempty"`
|
|
MaxTokens int `json:"max_tokens,omitempty"`
|
|
}
|
|
type Request struct {
|
|
Model string `json:"model"`
|
|
Input Input `json:"input"`
|
|
Parameters Parameters `json:"parameters"`
|
|
}
|
|
|
|
messages := []Message{}
|
|
|
|
// Add history
|
|
for _, msg := range history {
|
|
messages = append(messages, Message{
|
|
Role: msg.Role,
|
|
Content: msg.Content,
|
|
})
|
|
}
|
|
|
|
// Add current message
|
|
messages = append(messages, Message{
|
|
Role: "user",
|
|
Content: message,
|
|
})
|
|
|
|
reqBody := Request{
|
|
Model: model,
|
|
Input: Input{Messages: messages},
|
|
Parameters: Parameters{
|
|
ResultFormat: "message",
|
|
Temperature: temperature,
|
|
MaxTokens: maxTokens,
|
|
},
|
|
}
|
|
|
|
jsonData, err := json.Marshal(reqBody)
|
|
if err != nil {
|
|
return "", 0, err
|
|
}
|
|
|
|
req, err := http.NewRequest("POST", url, bytes.NewBuffer(jsonData))
|
|
if err != nil {
|
|
return "", 0, err
|
|
}
|
|
|
|
req.Header.Set("Content-Type", "application/json")
|
|
req.Header.Set("Authorization", "Bearer "+apiKey)
|
|
|
|
client := &http.Client{Timeout: 30 * time.Second}
|
|
resp, err := client.Do(req)
|
|
if err != nil {
|
|
return "", 0, fmt.Errorf("qwen api error: %w", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
if resp.StatusCode != http.StatusOK {
|
|
body, _ := io.ReadAll(resp.Body)
|
|
return "", 0, fmt.Errorf("qwen api error: %s - %s", resp.Status, string(body))
|
|
}
|
|
|
|
type Output struct {
|
|
Choices []struct {
|
|
Message Message `json:"message"`
|
|
} `json:"choices"`
|
|
}
|
|
type Usage struct {
|
|
TotalTokens int `json:"total_tokens"`
|
|
}
|
|
type Response struct {
|
|
Output Output `json:"output"`
|
|
Usage Usage `json:"usage"`
|
|
}
|
|
|
|
var qwenResp Response
|
|
if err := json.NewDecoder(resp.Body).Decode(&qwenResp); err != nil {
|
|
return "", 0, err
|
|
}
|
|
|
|
if len(qwenResp.Output.Choices) == 0 {
|
|
return "", 0, fmt.Errorf("no response from qwen")
|
|
}
|
|
|
|
return qwenResp.Output.Choices[0].Message.Content, qwenResp.Usage.TotalTokens, nil
|
|
}
|
|
|
|
// callHuggingFace sends a message to HuggingFace's Inference API
|
|
func callHuggingFace(apiKey, model, message string, history []ChatMessage, temperature float64, maxTokens int) (string, int, error) {
|
|
if model == "" {
|
|
model = "meta-llama/Llama-3.3-70B-Instruct"
|
|
}
|
|
|
|
url := fmt.Sprintf("https://api-inference.huggingface.co/models/%s", model)
|
|
|
|
// Build conversation prompt
|
|
prompt := ""
|
|
for _, msg := range history {
|
|
if msg.Role == "user" {
|
|
prompt += fmt.Sprintf("User: %s\n", msg.Content)
|
|
} else {
|
|
prompt += fmt.Sprintf("Assistant: %s\n", msg.Content)
|
|
}
|
|
}
|
|
prompt += fmt.Sprintf("User: %s\nAssistant:", message)
|
|
|
|
// Build request
|
|
type Parameters struct {
|
|
Temperature float64 `json:"temperature"`
|
|
MaxNewTokens int `json:"max_new_tokens"`
|
|
ReturnFullText bool `json:"return_full_text"`
|
|
}
|
|
type Request struct {
|
|
Inputs string `json:"inputs"`
|
|
Parameters Parameters `json:"parameters"`
|
|
}
|
|
|
|
reqBody := Request{
|
|
Inputs: prompt,
|
|
Parameters: Parameters{
|
|
Temperature: temperature,
|
|
MaxNewTokens: maxTokens,
|
|
ReturnFullText: false,
|
|
},
|
|
}
|
|
|
|
jsonData, err := json.Marshal(reqBody)
|
|
if err != nil {
|
|
return "", 0, err
|
|
}
|
|
|
|
req, err := http.NewRequest("POST", url, bytes.NewBuffer(jsonData))
|
|
if err != nil {
|
|
return "", 0, err
|
|
}
|
|
|
|
req.Header.Set("Content-Type", "application/json")
|
|
req.Header.Set("Authorization", "Bearer "+apiKey)
|
|
|
|
client := &http.Client{Timeout: 60 * time.Second} // HF can be slower
|
|
resp, err := client.Do(req)
|
|
if err != nil {
|
|
return "", 0, fmt.Errorf("huggingface api error: %w", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
if resp.StatusCode != http.StatusOK {
|
|
body, _ := io.ReadAll(resp.Body)
|
|
return "", 0, fmt.Errorf("huggingface api error: %s - %s", resp.Status, string(body))
|
|
}
|
|
|
|
type Response struct {
|
|
GeneratedText string `json:"generated_text"`
|
|
}
|
|
|
|
var hfResp []Response
|
|
if err := json.NewDecoder(resp.Body).Decode(&hfResp); err != nil {
|
|
return "", 0, err
|
|
}
|
|
|
|
if len(hfResp) == 0 {
|
|
return "", 0, fmt.Errorf("no response from huggingface")
|
|
}
|
|
|
|
// Estimate tokens (rough approximation: ~4 chars per token)
|
|
estimatedTokens := len(hfResp[0].GeneratedText) / 4
|
|
|
|
return hfResp[0].GeneratedText, estimatedTokens, nil
|
|
}
|