Files
new-api/relay/channel/gemini/relay-gemini.go

1747 lines
56 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package gemini
import (
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"strconv"
"strings"
"time"
"unicode/utf8"
"github.com/QuantumNous/new-api/common"
"github.com/QuantumNous/new-api/constant"
"github.com/QuantumNous/new-api/dto"
"github.com/QuantumNous/new-api/logger"
"github.com/QuantumNous/new-api/relay/channel/openai"
relaycommon "github.com/QuantumNous/new-api/relay/common"
"github.com/QuantumNous/new-api/relay/helper"
"github.com/QuantumNous/new-api/service"
"github.com/QuantumNous/new-api/setting/model_setting"
"github.com/QuantumNous/new-api/setting/reasoning"
"github.com/QuantumNous/new-api/types"
"github.com/gin-gonic/gin"
"github.com/samber/lo"
)
// geminiSupportedMimeTypes is the allow-list of MIME types accepted for inline
// media. CovertOpenAI2Gemini looks up the lower-cased MIME type of every
// image_url part in this map and rejects the request when it is absent.
// Reference:
// https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/inference?hl=zh-cn#blob
var geminiSupportedMimeTypes = map[string]bool{
	"application/pdf": true,
	"audio/mpeg": true,
	"audio/mp3": true,
	"audio/wav": true,
	"image/png": true,
	"image/jpeg": true,
	"image/jpg": true, // support old image/jpeg
	"image/webp": true,
	"text/plain": true,
	"video/mov": true,
	"video/mpeg": true,
	"video/mp4": true,
	"video/mpg": true,
	"video/avi": true,
	"video/wmv": true,
	"video/mpegps": true,
	"video/flv": true,
}
// thoughtSignatureBypassValue is the placeholder thought signature that
// CovertOpenAI2Gemini attaches (as a JSON-quoted string) to assistant/model
// parts when thought-signature attachment is enabled for the channel.
const thoughtSignatureBypassValue = "context_engineering_is_the_way_to_go"
// Thinking budget ranges allowed by Gemini. clampThinkingBudget selects the
// applicable pair from the model name.
const (
	pro25MinBudget = 128 // lower bound for models matched by isNew25ProModel
	pro25MaxBudget = 32768 // upper bound for models matched by isNew25ProModel
	flash25MaxBudget = 24576 // upper bound for every other model (their lower bound is 0)
	flash25LiteMinBudget = 512 // lower bound for models matched by is25FlashLiteModel
	flash25LiteMaxBudget = 24576 // upper bound for models matched by is25FlashLiteModel
)
// isNew25ProModel reports whether modelName starts with "gemini-2.5-pro" but
// is not one of the two older preview releases (05-06 and 03-25), which are
// matched by prefix as well so suffixed variants are excluded too.
func isNew25ProModel(modelName string) bool {
	if !strings.HasPrefix(modelName, "gemini-2.5-pro") {
		return false
	}
	legacyPreviews := [...]string{
		"gemini-2.5-pro-preview-05-06",
		"gemini-2.5-pro-preview-03-25",
	}
	for _, legacy := range legacyPreviews {
		if strings.HasPrefix(modelName, legacy) {
			return false
		}
	}
	return true
}
// is25FlashLiteModel reports whether modelName belongs to the
// "gemini-2.5-flash-lite" family, matched by prefix.
func is25FlashLiteModel(modelName string) bool {
	const flashLitePrefix = "gemini-2.5-flash-lite"
	return strings.HasPrefix(modelName, flashLitePrefix)
}
// clampThinkingBudget 根据模型名称将预算限制在允许的范围内
func clampThinkingBudget(modelName string, budget int) int {
isNew25Pro := isNew25ProModel(modelName)
is25FlashLite := is25FlashLiteModel(modelName)
if is25FlashLite {
if budget < flash25LiteMinBudget {
return flash25LiteMinBudget
}
if budget > flash25LiteMaxBudget {
return flash25LiteMaxBudget
}
} else if isNew25Pro {
if budget < pro25MinBudget {
return pro25MinBudget
}
if budget > pro25MaxBudget {
return pro25MaxBudget
}
} else { // 其他模型
if budget < 0 {
return 0
}
if budget > flash25MaxBudget {
return flash25MaxBudget
}
}
return budget
}
// "effort": "high" - Allocates a large portion of tokens for reasoning (approximately 80% of max_tokens)
// "effort": "medium" - Allocates a moderate portion of tokens (approximately 50% of max_tokens)
// "effort": "low" - Allocates a smaller portion of tokens (approximately 20% of max_tokens)
// "effort": "minimal" - Allocates a minimal portion of tokens (approximately 5% of max_tokens)
func clampThinkingBudgetByEffort(modelName string, effort string) int {
isNew25Pro := isNew25ProModel(modelName)
is25FlashLite := is25FlashLiteModel(modelName)
maxBudget := 0
if is25FlashLite {
maxBudget = flash25LiteMaxBudget
}
if isNew25Pro {
maxBudget = pro25MaxBudget
} else {
maxBudget = flash25MaxBudget
}
switch effort {
case "high":
maxBudget = maxBudget * 80 / 100
case "medium":
maxBudget = maxBudget * 50 / 100
case "low":
maxBudget = maxBudget * 20 / 100
case "minimal":
maxBudget = maxBudget * 5 / 100
}
return clampThinkingBudget(modelName, maxBudget)
}
// ThinkingAdaptor derives geminiRequest.GenerationConfig.ThinkingConfig from
// the suffix of info.UpstreamModelName. It does nothing unless the thinking
// adapter is enabled in the Gemini settings.
//
// Recognized model-name forms, checked in this order:
//   - "...-thinking-<N>": thinking budget N (clamped for the model), thoughts included.
//   - "...-thinking": thoughts included. Except for the two old 2.5 Pro
//     previews, a budget is also set: a configured percentage of
//     MaxOutputTokens when that is set and positive, otherwise (when an OpenAI
//     request is supplied) a budget derived from its ReasoningEffort.
//   - "...-nothinking": thinking budget 0, except for new 2.5 Pro models,
//     which are left untouched.
//   - an effort suffix recognized by reasoning.TrimEffortSuffix: thoughts
//     included with that thinking level; info.ReasoningEffort is set to it.
//
// Only the first element of oaiRequest is consulted.
func ThinkingAdaptor(geminiRequest *dto.GeminiChatRequest, info *relaycommon.RelayInfo, oaiRequest ...dto.GeneralOpenAIRequest) {
	if !model_setting.GetGeminiSettings().ThinkingAdapterEnabled {
		return
	}
	modelName := info.UpstreamModelName

	switch {
	case strings.Contains(modelName, "-thinking-"):
		parts := strings.SplitN(modelName, "-thinking-", 2)
		if len(parts) != 2 || parts[1] == "" {
			return
		}
		budgetTokens, err := strconv.Atoi(parts[1])
		if err != nil {
			// A non-numeric suffix is not a budget; leave the request untouched.
			return
		}
		clampedBudget := clampThinkingBudget(modelName, budgetTokens)
		geminiRequest.GenerationConfig.ThinkingConfig = &dto.GeminiThinkingConfig{
			ThinkingBudget:  common.GetPointer(clampedBudget),
			IncludeThoughts: true,
		}

	case strings.HasSuffix(modelName, "-thinking"):
		geminiRequest.GenerationConfig.ThinkingConfig = &dto.GeminiThinkingConfig{
			IncludeThoughts: true,
		}
		// The two old 2.5 Pro previews only get IncludeThoughts, no budget.
		if strings.HasPrefix(modelName, "gemini-2.5-pro-preview-05-06") ||
			strings.HasPrefix(modelName, "gemini-2.5-pro-preview-03-25") {
			return
		}
		maxOutputTokens := geminiRequest.GenerationConfig.MaxOutputTokens
		if maxOutputTokens != nil && *maxOutputTokens > 0 {
			budgetTokens := model_setting.GetGeminiSettings().ThinkingAdapterBudgetTokensPercentage * float64(*maxOutputTokens)
			clampedBudget := clampThinkingBudget(modelName, int(budgetTokens))
			geminiRequest.GenerationConfig.ThinkingConfig.ThinkingBudget = common.GetPointer(clampedBudget)
		} else if len(oaiRequest) > 0 {
			// No usable max token count: fall back to the request's reasoning effort.
			geminiRequest.GenerationConfig.ThinkingConfig.ThinkingBudget = common.GetPointer(clampThinkingBudgetByEffort(modelName, oaiRequest[0].ReasoningEffort))
		}

	case strings.HasSuffix(modelName, "-nothinking"):
		// Uses the shared helper instead of re-implementing its prefix checks inline.
		if !isNew25ProModel(modelName) {
			geminiRequest.GenerationConfig.ThinkingConfig = &dto.GeminiThinkingConfig{
				ThinkingBudget: common.GetPointer(0),
			}
		}

	default:
		if _, level, ok := reasoning.TrimEffortSuffix(info.UpstreamModelName); ok && level != "" {
			geminiRequest.GenerationConfig.ThinkingConfig = &dto.GeminiThinkingConfig{
				IncludeThoughts: true,
				ThinkingLevel:   level,
			}
			info.ReasoningEffort = level
		}
	}
}
// CovertOpenAI2Gemini converts an OpenAI-style chat request into a Gemini chat
// request.
//
// It maps, in order:
//   - sampling options (temperature, top_p, max tokens, seed, up to 5 stop sequences);
//   - extra_body.google.thinking_config / image_config (snake_case only; the
//     camelCase spellings are rejected with an error);
//   - the model-name driven thinking configuration via ThinkingAdaptor, unless
//     extra_body.google was present on a model without the "-nothinking" suffix;
//   - safety settings, taken from the configured threshold of every category
//     in SafetySettingList;
//   - tools and tool_choice (the function names "googleSearch",
//     "codeExecution" and "urlContext" select Gemini built-in tools);
//   - response_format json_schema / json_object;
//   - messages: system/developer messages are joined into SystemInstructions,
//     tool/function messages become functionResponse parts, and all other
//     messages are converted part by part (text, images, files, audio).
//
// It returns an error for malformed extra_body, undecodable tool-call
// arguments, unsupported or undecodable media, and file parts that reference a
// file id instead of inline base64 data.
func CovertOpenAI2Gemini(c *gin.Context, textRequest dto.GeneralOpenAIRequest, info *relaycommon.RelayInfo) (*dto.GeminiChatRequest, error) {
	geminiRequest := dto.GeminiChatRequest{
		Contents: make([]dto.GeminiChatContent, 0, len(textRequest.Messages)),
		GenerationConfig: dto.GeminiChatGenerationConfig{
			Temperature: textRequest.Temperature,
		},
	}

	if textRequest.TopP != nil && *textRequest.TopP > 0 {
		geminiRequest.GenerationConfig.TopP = common.GetPointer(*textRequest.TopP)
	}

	if maxTokens := textRequest.GetMaxTokens(); maxTokens > 0 {
		geminiRequest.GenerationConfig.MaxOutputTokens = common.GetPointer(maxTokens)
	}

	if textRequest.Seed != nil && *textRequest.Seed != 0 {
		geminiSeed := int64(lo.FromPtr(textRequest.Seed))
		geminiRequest.GenerationConfig.Seed = common.GetPointer(geminiSeed)
	}

	// Placeholder thought signatures are only attached for Gemini / Vertex AI
	// channels, and only when the feature is enabled in the settings.
	attachThoughtSignature := (info.ChannelType == constant.ChannelTypeGemini ||
		info.ChannelType == constant.ChannelTypeVertexAi) &&
		model_setting.GetGeminiSettings().FunctionCallThoughtSignatureEnabled

	if model_setting.IsGeminiModelSupportImagine(info.UpstreamModelName) {
		geminiRequest.GenerationConfig.ResponseModalities = []string{
			"TEXT",
			"IMAGE",
		}
	}

	if stopSequences := parseStopSequences(textRequest.Stop); len(stopSequences) > 0 {
		// Gemini supports up to 5 stop sequences
		if len(stopSequences) > 5 {
			stopSequences = stopSequences[:5]
		}
		geminiRequest.GenerationConfig.StopSequences = stopSequences
	}

	adaptorWithExtraBody := false

	// patch extra_body
	if len(textRequest.ExtraBody) > 0 {
		var extraBody map[string]interface{}
		if err := common.Unmarshal(textRequest.ExtraBody, &extraBody); err != nil {
			return nil, fmt.Errorf("invalid extra body: %w", err)
		}
		// eg. {"google":{"thinking_config":{"thinking_budget":5324,"include_thoughts":true}}}
		if googleBody, ok := extraBody["google"].(map[string]interface{}); ok {
			if !strings.HasSuffix(info.UpstreamModelName, "-nothinking") {
				adaptorWithExtraBody = true
				// check error param name like thinkingConfig, should be thinking_config
				if _, hasErrorParam := googleBody["thinkingConfig"]; hasErrorParam {
					return nil, errors.New("extra_body.google.thinkingConfig is not supported, use extra_body.google.thinking_config instead")
				}

				if thinkingConfig, ok := googleBody["thinking_config"].(map[string]interface{}); ok {
					// check error param name like thinkingBudget, should be thinking_budget
					if _, hasErrorParam := thinkingConfig["thinkingBudget"]; hasErrorParam {
						return nil, errors.New("extra_body.google.thinking_config.thinkingBudget is not supported, use extra_body.google.thinking_config.thinking_budget instead")
					}
					var hasThinkingConfig bool
					var tempThinkingConfig dto.GeminiThinkingConfig

					if thinkingBudget, exists := thinkingConfig["thinking_budget"]; exists {
						switch v := thinkingBudget.(type) {
						case float64:
							budgetInt := int(v)
							tempThinkingConfig.ThinkingBudget = common.GetPointer(budgetInt)
							if budgetInt > 0 {
								// A positive budget enables thoughts by default.
								tempThinkingConfig.IncludeThoughts = true
							} else {
								// Present but zero or negative: thinking is disabled.
								tempThinkingConfig.IncludeThoughts = false
							}
							hasThinkingConfig = true
						default:
							return nil, errors.New("extra_body.google.thinking_config.thinking_budget must be an integer")
						}
					}

					// An explicit include_thoughts overrides the budget-derived default.
					if includeThoughts, exists := thinkingConfig["include_thoughts"]; exists {
						if v, ok := includeThoughts.(bool); ok {
							tempThinkingConfig.IncludeThoughts = v
							hasThinkingConfig = true
						} else {
							return nil, errors.New("extra_body.google.thinking_config.include_thoughts must be a boolean")
						}
					}
					if thinkingLevel, exists := thinkingConfig["thinking_level"]; exists {
						if v, ok := thinkingLevel.(string); ok {
							tempThinkingConfig.ThinkingLevel = v
							hasThinkingConfig = true
						} else {
							return nil, errors.New("extra_body.google.thinking_config.thinking_level must be a string")
						}
					}

					if hasThinkingConfig {
						// Allocate only when a setting was supplied, so later
						// field assignments never dereference a nil pointer.
						if geminiRequest.GenerationConfig.ThinkingConfig == nil {
							geminiRequest.GenerationConfig.ThinkingConfig = &tempThinkingConfig
						} else {
							// Already allocated: merge the supplied fields into it.
							if tempThinkingConfig.ThinkingBudget != nil {
								geminiRequest.GenerationConfig.ThinkingConfig.ThinkingBudget = tempThinkingConfig.ThinkingBudget
							}
							geminiRequest.GenerationConfig.ThinkingConfig.IncludeThoughts = tempThinkingConfig.IncludeThoughts
							if tempThinkingConfig.ThinkingLevel != "" {
								geminiRequest.GenerationConfig.ThinkingConfig.ThinkingLevel = tempThinkingConfig.ThinkingLevel
							}
						}
					}
				}
			}

			// check error param name like imageConfig, should be image_config
			if _, hasErrorParam := googleBody["imageConfig"]; hasErrorParam {
				return nil, errors.New("extra_body.google.imageConfig is not supported, use extra_body.google.image_config instead")
			}

			if imageConfig, ok := googleBody["image_config"].(map[string]interface{}); ok {
				// check error param name like aspectRatio, should be aspect_ratio
				if _, hasErrorParam := imageConfig["aspectRatio"]; hasErrorParam {
					return nil, errors.New("extra_body.google.image_config.aspectRatio is not supported, use extra_body.google.image_config.aspect_ratio instead")
				}
				// check error param name like imageSize, should be image_size
				if _, hasErrorParam := imageConfig["imageSize"]; hasErrorParam {
					return nil, errors.New("extra_body.google.image_config.imageSize is not supported, use extra_body.google.image_config.image_size instead")
				}

				// convert snake_case to camelCase for Gemini API
				geminiImageConfig := make(map[string]interface{})
				if aspectRatio, ok := imageConfig["aspect_ratio"]; ok {
					geminiImageConfig["aspectRatio"] = aspectRatio
				}
				if imageSize, ok := imageConfig["image_size"]; ok {
					geminiImageConfig["imageSize"] = imageSize
				}

				if len(geminiImageConfig) > 0 {
					imageConfigBytes, err := common.Marshal(geminiImageConfig)
					if err != nil {
						return nil, fmt.Errorf("failed to marshal image_config: %w", err)
					}
					geminiRequest.GenerationConfig.ImageConfig = imageConfigBytes
				}
			}
		}
	}

	if !adaptorWithExtraBody {
		ThinkingAdaptor(&geminiRequest, info, textRequest)
	}

	safetySettings := make([]dto.GeminiChatSafetySettings, 0, len(SafetySettingList))
	for _, category := range SafetySettingList {
		safetySettings = append(safetySettings, dto.GeminiChatSafetySettings{
			Category:  category,
			Threshold: model_setting.GetGeminiSafetySetting(category),
		})
	}
	geminiRequest.SafetySettings = safetySettings

	if textRequest.Tools != nil {
		functions := make([]dto.FunctionRequest, 0, len(textRequest.Tools))
		googleSearch := false
		codeExecution := false
		urlContext := false
		for _, tool := range textRequest.Tools {
			// These three names select Gemini built-in tools instead of
			// declaring a user function.
			if tool.Function.Name == "googleSearch" {
				googleSearch = true
				continue
			}
			if tool.Function.Name == "codeExecution" {
				codeExecution = true
				continue
			}
			if tool.Function.Name == "urlContext" {
				urlContext = true
				continue
			}
			// A schema whose "properties" object is empty is sent without parameters.
			if tool.Function.Parameters != nil {
				params, ok := tool.Function.Parameters.(map[string]interface{})
				if ok {
					if props, hasProps := params["properties"].(map[string]interface{}); hasProps {
						if len(props) == 0 {
							tool.Function.Parameters = nil
						}
					}
				}
			}
			// Clean the parameters before appending
			cleanedParams := cleanFunctionParameters(tool.Function.Parameters)
			tool.Function.Parameters = cleanedParams
			functions = append(functions, tool.Function)
		}
		geminiTools := geminiRequest.GetTools()
		if codeExecution {
			geminiTools = append(geminiTools, dto.GeminiChatTool{
				CodeExecution: make(map[string]string),
			})
		}
		if googleSearch {
			geminiTools = append(geminiTools, dto.GeminiChatTool{
				GoogleSearch: make(map[string]string),
			})
		}
		if urlContext {
			geminiTools = append(geminiTools, dto.GeminiChatTool{
				URLContext: make(map[string]string),
			})
		}
		if len(functions) > 0 {
			geminiTools = append(geminiTools, dto.GeminiChatTool{
				FunctionDeclarations: functions,
			})
		}
		geminiRequest.SetTools(geminiTools)

		// Convert OpenAI tool_choice to Gemini toolConfig.functionCallingConfig
		// Mapping: "auto" -> "AUTO", "none" -> "NONE", "required" -> "ANY"
		// Object format: {"type": "function", "function": {"name": "xxx"}} -> "ANY" + allowedFunctionNames
		if textRequest.ToolChoice != nil {
			geminiRequest.ToolConfig = convertToolChoiceToGeminiConfig(textRequest.ToolChoice)
		}
	}

	if textRequest.ResponseFormat != nil && (textRequest.ResponseFormat.Type == "json_schema" || textRequest.ResponseFormat.Type == "json_object") {
		geminiRequest.GenerationConfig.ResponseMimeType = "application/json"

		if len(textRequest.ResponseFormat.JsonSchema) > 0 {
			// Decode the raw JSON schema first. A schema that fails to decode
			// is ignored and only the JSON MIME type is requested.
			var jsonSchema dto.FormatJsonSchema
			if err := common.Unmarshal(textRequest.ResponseFormat.JsonSchema, &jsonSchema); err == nil {
				cleanedSchema := removeAdditionalPropertiesWithDepth(jsonSchema.Schema, 0)
				geminiRequest.GenerationConfig.ResponseSchema = cleanedSchema
			}
		}
	}

	// toolCallNames maps a tool call id to its function name so that a later
	// tool message without a name can still be attributed to its function.
	toolCallNames := make(map[string]string)
	var systemContent []string
	for _, message := range textRequest.Messages {
		if message.Role == "system" || message.Role == "developer" {
			systemContent = append(systemContent, message.StringContent())
			continue
		}

		if message.Role == "tool" || message.Role == "function" {
			// Function responses go into a user turn: start one when there is
			// no content yet or the previous turn belongs to the model.
			if len(geminiRequest.Contents) == 0 || geminiRequest.Contents[len(geminiRequest.Contents)-1].Role == "model" {
				geminiRequest.Contents = append(geminiRequest.Contents, dto.GeminiChatContent{
					Role: "user",
				})
			}
			lastParts := &geminiRequest.Contents[len(geminiRequest.Contents)-1].Parts
			name := ""
			if message.Name != nil {
				name = *message.Name
			} else if val, exists := toolCallNames[message.ToolCallId]; exists {
				name = val
			}
			var contentMap map[string]interface{}
			contentStr := message.StringContent()

			// 1. Try to decode the content as a JSON object.
			if err := json.Unmarshal([]byte(contentStr), &contentMap); err != nil {
				// 2. Otherwise try a JSON array and wrap it in an object.
				var contentSlice []interface{}
				if err := json.Unmarshal([]byte(contentStr), &contentSlice); err == nil {
					contentMap = map[string]interface{}{"result": contentSlice}
				} else {
					// 3. Otherwise treat the content as plain text.
					contentMap = map[string]interface{}{"content": contentStr}
				}
			}

			functionResp := &dto.GeminiFunctionResponse{
				Name:     name,
				Response: contentMap,
			}

			*lastParts = append(*lastParts, dto.GeminiPart{
				FunctionResponse: functionResp,
			})
			continue
		}

		var parts []dto.GeminiPart
		content := dto.GeminiChatContent{
			Role: message.Role,
		}
		shouldAttachThoughtSignature := attachThoughtSignature && (message.Role == "assistant" || message.Role == "model")
		signatureAttached := false
		if message.ToolCalls != nil {
			for _, call := range message.ParseToolCalls() {
				args := map[string]interface{}{}
				if call.Function.Arguments != "" {
					if json.Unmarshal([]byte(call.Function.Arguments), &args) != nil {
						return nil, fmt.Errorf("invalid arguments for function %s, args: %s", call.Function.Name, call.Function.Arguments)
					}
				}
				toolCall := dto.GeminiPart{
					FunctionCall: &dto.FunctionCall{
						FunctionName: call.Function.Name,
						Arguments:    args,
					},
				}
				// Only the first function call of a message carries the signature.
				if shouldAttachThoughtSignature && !signatureAttached && hasFunctionCallContent(toolCall.FunctionCall) && len(toolCall.ThoughtSignature) == 0 {
					toolCall.ThoughtSignature = json.RawMessage(strconv.Quote(thoughtSignatureBypassValue))
					signatureAttached = true
				}
				parts = append(parts, toolCall)
				toolCallNames[call.ID] = call.Function.Name
			}
		}

		openaiContent := message.ParseContent()
		for _, part := range openaiContent {
			if part.Type == dto.ContentTypeText {
				if part.Text == "" {
					continue
				}
				// check markdown image ![image](data:image/jpeg;base64,xxxxxxxxxxxx)
				// Plain substring search is used instead of a regular
				// expression to stay cheap on very large texts.
				text := part.Text
				hasMarkdownImage := false
				for {
					// Quick check for a markdown image marker.
					startIdx := strings.Index(text, "![")
					if startIdx == -1 {
						break
					}
					// Locate "](" followed by a data URL.
					bracketIdx := strings.Index(text[startIdx:], "](data:")
					if bracketIdx == -1 {
						break
					}
					bracketIdx += startIdx
					// Locate the closing ")".
					closeIdx := strings.Index(text[bracketIdx+2:], ")")
					if closeIdx == -1 {
						break
					}
					closeIdx += bracketIdx + 2

					hasMarkdownImage = true

					// Emit the text that precedes the image.
					if startIdx > 0 {
						textBefore := text[:startIdx]
						if textBefore != "" {
							parts = append(parts, dto.GeminiPart{
								Text: textBefore,
							})
						}
					}

					// The data URL sits between "](" and ")".
					dataUrl := text[bracketIdx+2 : closeIdx]
					format, base64String, err := service.DecodeBase64FileData(dataUrl)
					if err != nil {
						return nil, fmt.Errorf("decode markdown base64 image data failed: %w", err)
					}
					imgPart := dto.GeminiPart{
						InlineData: &dto.GeminiInlineData{
							MimeType: format,
							Data:     base64String,
						},
					}
					if shouldAttachThoughtSignature {
						imgPart.ThoughtSignature = json.RawMessage(strconv.Quote(thoughtSignatureBypassValue))
					}
					parts = append(parts, imgPart)
					// Continue with the text after the image.
					text = text[closeIdx+1:]
				}
				if !hasMarkdownImage {
					// No embedded image: forward the original text unchanged.
					parts = append(parts, dto.GeminiPart{
						Text: part.Text,
					})
				} else if text != "" {
					// Keep the text that follows the last embedded image; it
					// used to be dropped silently.
					parts = append(parts, dto.GeminiPart{
						Text: text,
					})
				}
			} else if part.Type == dto.ContentTypeImageURL {
				// Fetch the image through the shared file service; anything
				// not starting with "http" is treated as base64 data.
				var source *types.FileSource
				imageUrl := part.GetImageMedia().Url
				if strings.HasPrefix(imageUrl, "http") {
					source = types.NewURLFileSource(imageUrl)
				} else {
					source = types.NewBase64FileSource(imageUrl, "")
				}
				base64Data, mimeType, err := service.GetBase64Data(c, source, "formatting image for Gemini")
				if err != nil {
					return nil, fmt.Errorf("get file data from '%s' failed: %w", source.GetIdentifier(), err)
				}

				// Reject MIME types that are not on the Gemini allow-list.
				if _, ok := geminiSupportedMimeTypes[strings.ToLower(mimeType)]; !ok {
					return nil, fmt.Errorf("mime type is not supported by Gemini: '%s', url: '%s', supported types are: %v", mimeType, source.GetIdentifier(), getSupportedMimeTypesList())
				}

				parts = append(parts, dto.GeminiPart{
					InlineData: &dto.GeminiInlineData{
						MimeType: mimeType,
						Data:     base64Data,
					},
				})
			} else if part.Type == dto.ContentTypeFile {
				if part.GetFile().FileId != "" {
					return nil, fmt.Errorf("only base64 file is supported in gemini")
				}
				fileSource := types.NewBase64FileSource(part.GetFile().FileData, "")
				base64Data, mimeType, err := service.GetBase64Data(c, fileSource, "formatting file for Gemini")
				if err != nil {
					return nil, fmt.Errorf("decode base64 file data failed: %w", err)
				}
				parts = append(parts, dto.GeminiPart{
					InlineData: &dto.GeminiInlineData{
						MimeType: mimeType,
						Data:     base64Data,
					},
				})
			} else if part.Type == dto.ContentTypeInputAudio {
				if part.GetInputAudio().Data == "" {
					return nil, fmt.Errorf("only base64 audio is supported in gemini")
				}
				audioSource := types.NewBase64FileSource(part.GetInputAudio().Data, "audio/"+part.GetInputAudio().Format)
				base64Data, mimeType, err := service.GetBase64Data(c, audioSource, "formatting audio for Gemini")
				if err != nil {
					return nil, fmt.Errorf("decode base64 audio data failed: %w", err)
				}
				parts = append(parts, dto.GeminiPart{
					InlineData: &dto.GeminiInlineData{
						MimeType: mimeType,
						Data:     base64Data,
					},
				})
			}
		}

		// No function call carried the signature (no tool calls, or none with
		// content): attach it to the first non-empty text part instead.
		if shouldAttachThoughtSignature && !signatureAttached && len(parts) > 0 {
			for i := range parts {
				if parts[i].Text != "" {
					parts[i].ThoughtSignature = json.RawMessage(strconv.Quote(thoughtSignatureBypassValue))
					break
				}
			}
		}

		content.Parts = parts

		// there's no assistant role in gemini and API shall vomit if Role is not user or model
		if content.Role == "assistant" {
			content.Role = "model"
		}
		// Messages that produced no parts are omitted entirely.
		if len(content.Parts) > 0 {
			geminiRequest.Contents = append(geminiRequest.Contents, content)
		}
	}

	if len(systemContent) > 0 {
		geminiRequest.SystemInstructions = &dto.GeminiChatContent{
			Parts: []dto.GeminiPart{
				{
					Text: strings.Join(systemContent, "\n"),
				},
			},
		}
	}

	return &geminiRequest, nil
}
// parseStopSequences normalizes the OpenAI "stop" parameter into a list of
// stop sequences.
//
// stop may be a single string, a []string, or a []interface{} as produced by
// JSON decoding. Empty strings and non-string elements are dropped in every
// form. It returns nil for a nil value, an empty string, or an unsupported
// type; a list input always yields a non-nil (possibly empty) slice.
func parseStopSequences(stop any) []string {
	switch v := stop.(type) {
	case string:
		if v != "" {
			return []string{v}
		}
	case []string:
		// Filter empty entries like the []interface{} branch does.
		sequences := make([]string, 0, len(v))
		for _, str := range v {
			if str != "" {
				sequences = append(sequences, str)
			}
		}
		return sequences
	case []interface{}:
		sequences := make([]string, 0, len(v))
		for _, item := range v {
			if str, ok := item.(string); ok && str != "" {
				sequences = append(sequences, str)
			}
		}
		return sequences
	}
	return nil
}
// hasFunctionCallContent reports whether call carries anything worth sending:
// a non-blank function name, or non-empty arguments. A nil call has no
// content. Arguments of a type other than nil, string, map or slice count as
// content.
func hasFunctionCallContent(call *dto.FunctionCall) bool {
	switch {
	case call == nil:
		return false
	case strings.TrimSpace(call.FunctionName) != "":
		return true
	}

	// No usable name: decide from the shape of the arguments.
	switch args := call.Arguments.(type) {
	case map[string]interface{}:
		return len(args) != 0
	case []interface{}:
		return len(args) != 0
	case string:
		return strings.TrimSpace(args) != ""
	case nil:
		return false
	}
	return true
}
// getSupportedMimeTypesList returns every key of geminiSupportedMimeTypes for
// use in error messages. The order is unspecified because it follows Go's map
// iteration order.
func getSupportedMimeTypesList() []string {
	mimeTypes := make([]string, len(geminiSupportedMimeTypes))
	next := 0
	for mimeType := range geminiSupportedMimeTypes {
		mimeTypes[next] = mimeType
		next++
	}
	return mimeTypes
}
// geminiOpenAPISchemaAllowedFields is the set of schema keys that survive
// cleaning of function parameters: cleanFunctionParametersWithDepth and
// cleanFunctionParametersShallow copy only keys present in this set and drop
// every other key of a schema object.
var geminiOpenAPISchemaAllowedFields = map[string]struct{}{
	"anyOf": {},
	"default": {},
	"description": {},
	"enum": {},
	"example": {},
	"format": {},
	"items": {},
	"maxItems": {},
	"maxLength": {},
	"maxProperties": {},
	"maximum": {},
	"minItems": {},
	"minLength": {},
	"minProperties": {},
	"minimum": {},
	"nullable": {},
	"pattern": {},
	"properties": {},
	"propertyOrdering": {},
	"required": {},
	"title": {},
	"type": {},
}
// geminiFunctionSchemaMaxDepth is the nesting depth at which
// cleanFunctionParametersWithDepth stops recursing and falls back to
// cleanFunctionParametersShallow.
const geminiFunctionSchemaMaxDepth = 64
// cleanFunctionParameters recursively removes unsupported fields from Gemini function parameters.
// It is the entry point for cleanFunctionParametersWithDepth, starting at
// depth 0, and returns whatever that function returns (nil for nil params).
func cleanFunctionParameters(params interface{}) interface{} {
	return cleanFunctionParametersWithDepth(params, 0)
}
// cleanFunctionParametersWithDepth returns a cleaned copy of a function
// parameter schema, recursing into nested schemas.
//
// For a schema object it keeps only the keys in
// geminiOpenAPISchemaAllowedFields, normalizes "type"/"nullable", and then
// cleans "properties", "items" and "anyOf" one level deeper. A tuple-style
// "items" array is replaced by its cleaned first element. A slice is cleaned
// element by element; any other value is returned unchanged. Once depth
// reaches geminiFunctionSchemaMaxDepth the value is handed to
// cleanFunctionParametersShallow instead of recursing further.
func cleanFunctionParametersWithDepth(params interface{}, depth int) interface{} {
	if params == nil {
		return nil
	}
	if depth >= geminiFunctionSchemaMaxDepth {
		return cleanFunctionParametersShallow(params)
	}

	// Arrays of schemas: clean each element.
	if list, isList := params.([]interface{}); isList {
		cleanedList := make([]interface{}, len(list))
		for idx := range list {
			cleanedList[idx] = cleanFunctionParametersWithDepth(list[idx], depth+1)
		}
		return cleanedList
	}

	schema, isMap := params.(map[string]interface{})
	if !isMap {
		// Neither a map nor an array (e.g. a primitive): return as is.
		return params
	}

	// Keep only the Gemini-supported OpenAPI schema subset.
	filtered := make(map[string]interface{}, len(schema))
	for key, value := range schema {
		if _, allowed := geminiOpenAPISchemaAllowedFields[key]; !allowed {
			continue
		}
		filtered[key] = value
	}
	normalizeGeminiSchemaTypeAndNullable(filtered)

	if props, ok := filtered["properties"].(map[string]interface{}); ok && props != nil {
		cleanedProps := make(map[string]interface{})
		for name := range props {
			cleanedProps[name] = cleanFunctionParametersWithDepth(props[name], depth+1)
		}
		filtered["properties"] = cleanedProps
	}

	switch items := filtered["items"].(type) {
	case map[string]interface{}:
		if items != nil {
			filtered["items"] = cleanFunctionParametersWithDepth(items, depth+1)
		}
	case []interface{}:
		// OpenAPI tuple-style items is not supported by Gemini SDK Schema;
		// keep the first entry to avoid API rejection.
		if len(items) > 0 {
			filtered["items"] = cleanFunctionParametersWithDepth(items[0], depth+1)
		}
	}

	if alternatives, ok := filtered["anyOf"].([]interface{}); ok && alternatives != nil {
		cleanedAlternatives := make([]interface{}, len(alternatives))
		for idx := range alternatives {
			cleanedAlternatives[idx] = cleanFunctionParametersWithDepth(alternatives[idx], depth+1)
		}
		filtered["anyOf"] = cleanedAlternatives
	}

	return filtered
}
// cleanFunctionParametersShallow is the non-recursive fallback used once the
// schema depth limit is reached.
//
// For a schema object it keeps the allowed keys except the nested ones
// ("properties", "items", "anyOf") and normalizes "type"/"nullable". A slice
// becomes an empty slice, and any other value is returned unchanged.
func cleanFunctionParametersShallow(params interface{}) interface{} {
	if _, isList := params.([]interface{}); isList {
		// Prefer an empty list over deep recursion on attacker-controlled inputs.
		return []interface{}{}
	}

	schema, isMap := params.(map[string]interface{})
	if !isMap {
		return params
	}

	flat := make(map[string]interface{}, len(schema))
	for key, value := range schema {
		// Skip nested structures so recursion stops here and large subtrees
		// are not retained.
		switch key {
		case "properties", "items", "anyOf":
			continue
		}
		if _, allowed := geminiOpenAPISchemaAllowedFields[key]; allowed {
			flat[key] = value
		}
	}
	normalizeGeminiSchemaTypeAndNullable(flat)
	return flat
}
// normalizeGeminiSchemaTypeAndNullable rewrites schema["type"] in place.
//
// The six known JSON Schema type names are matched case-insensitively
// (ignoring surrounding whitespace) and replaced by their upper-case form;
// unknown names are kept verbatim. The "null" type is expressed as
// schema["nullable"] = true instead. When "type" is a list, the first non-null
// string entry becomes the type, and "type" is removed if there is none.
// Schemas without a "type", or with a "type" of any other kind, are left
// untouched.
func normalizeGeminiSchemaTypeAndNullable(schema map[string]interface{}) {
	declared, present := schema["type"]
	if !present || declared == nil {
		return
	}

	// canonical returns the normalized name, or isNull=true for "null".
	canonical := func(name string) (typeName string, isNull bool) {
		switch key := strings.ToLower(strings.TrimSpace(name)); key {
		case "null":
			return "", true
		case "object", "array", "string", "integer", "number", "boolean":
			return strings.ToUpper(key), false
		}
		return name, false
	}

	if single, isString := declared.(string); isString {
		typeName, isNull := canonical(single)
		if isNull {
			schema["nullable"] = true
			delete(schema, "type")
		} else {
			schema["type"] = typeName
		}
		return
	}

	list, isList := declared.([]interface{})
	if !isList {
		return
	}

	var (
		firstType string
		sawNull   bool
	)
	for _, entry := range list {
		name, isString := entry.(string)
		if !isString {
			continue
		}
		typeName, isNull := canonical(name)
		switch {
		case isNull:
			sawNull = true
		case firstType == "":
			firstType = typeName
		}
	}

	if sawNull {
		schema["nullable"] = true
	}
	if firstType == "" {
		delete(schema, "type")
		return
	}
	schema["type"] = firstType
}
// removeAdditionalPropertiesWithDepth strips fields from a JSON schema, in
// place, and returns the same schema value.
//
// At every visited schema object "title" and "$schema" are deleted. Objects
// additionally lose "additionalProperties", and their "properties" values and
// "allOf"/"anyOf"/"oneOf" entries are processed one level deeper; arrays have
// a map-valued "items" processed one level deeper. Schemas of any other type
// (or without a type) are not descended into. Non-map and empty values are
// returned unchanged, and nothing is modified once depth reaches 5.
func removeAdditionalPropertiesWithDepth(schema interface{}, depth int) interface{} {
	if depth >= 5 {
		return schema
	}
	node, isMap := schema.(map[string]interface{})
	if !isMap || len(node) == 0 {
		return schema
	}

	delete(node, "title")
	delete(node, "$schema")

	if node["type"] == "array" {
		if items, ok := node["items"].(map[string]interface{}); ok {
			node["items"] = removeAdditionalPropertiesWithDepth(items, depth+1)
		}
		return node
	}
	if node["type"] != "object" {
		// Only objects and arrays contain nested schemas handled here.
		return schema
	}

	delete(node, "additionalProperties")
	if properties, ok := node["properties"].(map[string]interface{}); ok {
		for name := range properties {
			properties[name] = removeAdditionalPropertiesWithDepth(properties[name], depth+1)
		}
	}
	for _, combinator := range [...]string{"allOf", "anyOf", "oneOf"} {
		alternatives, ok := node[combinator].([]interface{})
		if !ok {
			continue
		}
		for idx := range alternatives {
			alternatives[idx] = removeAdditionalPropertiesWithDepth(alternatives[idx], depth+1)
		}
	}
	return node
}
// unescapeString resolves backslash escape sequences in s.
//
// The sequences \" \\ \/ \' \b \f \n \r \t are replaced by the character they
// denote. Any other escaped character is emitted unchanged together with its
// backslash, and so is a lone backslash at the very end of the input. It
// returns an error when s contains a byte sequence that is not valid UTF-8.
func unescapeString(s string) (string, error) {
	var b strings.Builder
	b.Grow(len(s))

	escaped := false
	for i := 0; i < len(s); {
		r, size := utf8.DecodeRuneInString(s[i:])
		// Only a width-1 RuneError signals a decoding failure; a correctly
		// encoded U+FFFD has width 3 and is legitimate input.
		if r == utf8.RuneError && size == 1 {
			return "", errors.New("invalid UTF-8 encoding")
		}
		i += size

		if !escaped {
			if r == '\\' {
				escaped = true // remember the backslash; the next rune decides
			} else {
				b.WriteRune(r)
			}
			continue
		}

		escaped = false
		switch r {
		case '"', '\\', '/', '\'':
			b.WriteRune(r)
		case 'b':
			b.WriteByte('\b')
		case 'f':
			b.WriteByte('\f')
		case 'n':
			b.WriteByte('\n')
		case 'r':
			b.WriteByte('\r')
		case 't':
			b.WriteByte('\t')
		default:
			// Unknown escape: emit it exactly as written.
			b.WriteByte('\\')
			b.WriteRune(r)
		}
	}

	// A dangling backslash has nothing to escape; keep it instead of dropping it.
	if escaped {
		b.WriteByte('\\')
	}
	return b.String(), nil
}
// unescapeMapOrSlice applies unescapeString to every string reachable inside
// data.
//
// Maps and slices are updated in place, recursively, and returned. A string
// is returned unescaped, or unchanged when unescapeString reports an error.
// Values of any other type are returned as they are.
func unescapeMapOrSlice(data interface{}) interface{} {
	if text, isString := data.(string); isString {
		unescaped, err := unescapeString(text)
		if err != nil {
			return text
		}
		return unescaped
	}

	switch container := data.(type) {
	case map[string]interface{}:
		for key := range container {
			container[key] = unescapeMapOrSlice(container[key])
		}
	case []interface{}:
		for idx := range container {
			container[idx] = unescapeMapOrSlice(container[idx])
		}
	}
	return data
}
// getResponseToolCall converts the function call carried by a Gemini part
// into an OpenAI-style tool call with a freshly generated "call_<uuid>" id.
//
// The arguments are serialized with json.Marshal directly; no additional
// unescaping is applied because JSON encoding already handles escape
// characters. It returns nil when item or its FunctionCall is nil, or when the
// arguments cannot be marshalled.
func getResponseToolCall(item *dto.GeminiPart) *dto.ToolCallResponse {
	// Guard against parts without a function call instead of panicking.
	if item == nil || item.FunctionCall == nil {
		return nil
	}

	argsBytes, err := json.Marshal(item.FunctionCall.Arguments)
	if err != nil {
		return nil
	}
	return &dto.ToolCallResponse{
		ID:   fmt.Sprintf("call_%s", common.GetUUID()),
		Type: "function",
		Function: dto.FunctionResponse{
			Arguments: string(argsBytes),
			Name:      item.FunctionCall.FunctionName,
		},
	}
}
// buildUsageFromGeminiMetadata translates Gemini usage metadata into an
// OpenAI-style usage record.
//
// Prompt tokens are the prompt count plus the tool-use prompt count, falling
// back to fallbackPromptTokens when that sum is not positive. Completion
// tokens are candidate plus thought tokens; if that is not positive but a
// total is reported, they are derived as total minus prompt. Audio and text
// prompt details are summed over both the prompt and tool-use modality
// breakdowns, and when neither is reported all prompt tokens count as text.
func buildUsageFromGeminiMetadata(metadata dto.GeminiUsageMetadata, fallbackPromptTokens int) dto.Usage {
	promptTokens := metadata.PromptTokenCount + metadata.ToolUsePromptTokenCount
	if promptTokens <= 0 && fallbackPromptTokens > 0 {
		promptTokens = fallbackPromptTokens
	}

	var usage dto.Usage
	usage.PromptTokens = promptTokens
	usage.CompletionTokens = metadata.CandidatesTokenCount + metadata.ThoughtsTokenCount
	usage.TotalTokens = metadata.TotalTokenCount
	usage.PromptTokensDetails.CachedTokens = metadata.CachedContentTokenCount
	usage.CompletionTokenDetails.ReasoningTokens = metadata.ThoughtsTokenCount

	// Per-modality breakdown of the regular prompt.
	for _, detail := range metadata.PromptTokensDetails {
		switch detail.Modality {
		case "AUDIO":
			usage.PromptTokensDetails.AudioTokens += detail.TokenCount
		case "TEXT":
			usage.PromptTokensDetails.TextTokens += detail.TokenCount
		}
	}
	// Per-modality breakdown of the tool-use prompt, added on top.
	for _, detail := range metadata.ToolUsePromptTokensDetails {
		switch detail.Modality {
		case "AUDIO":
			usage.PromptTokensDetails.AudioTokens += detail.TokenCount
		case "TEXT":
			usage.PromptTokensDetails.TextTokens += detail.TokenCount
		}
	}

	if usage.CompletionTokens <= 0 && usage.TotalTokens > 0 {
		usage.CompletionTokens = usage.TotalTokens - usage.PromptTokens
	}

	noModalityBreakdown := usage.PromptTokensDetails.TextTokens == 0 && usage.PromptTokensDetails.AudioTokens == 0
	if noModalityBreakdown && usage.PromptTokens > 0 {
		usage.PromptTokensDetails.TextTokens = usage.PromptTokens
	}

	return usage
}
// responseGeminiChat2OpenAI converts a non-streaming Gemini chat response
// into an OpenAI chat.completion response, producing one choice per candidate.
//
// Within a candidate, inline media becomes a markdown data-URL link, function
// calls become tool calls, thought parts set the reasoning content (a later
// thought part overwrites an earlier one), executable code and code execution
// results become fenced code blocks, and remaining text parts are kept except
// for a bare "\n". The collected texts are joined with newlines.
//
// The finish reason is "tool_calls" when the candidate produced tool calls.
// Otherwise it is mapped from Gemini's reason: STOP -> stop, MAX_TOKENS ->
// length, anything else -> content_filter.
func responseGeminiChat2OpenAI(c *gin.Context, response *dto.GeminiChatResponse) *dto.OpenAITextResponse {
	fullTextResponse := dto.OpenAITextResponse{
		Id:      helper.GetResponseID(c),
		Object:  "chat.completion",
		Created: common.GetTimestamp(),
		Choices: make([]dto.OpenAITextResponseChoice, 0, len(response.Candidates)),
	}

	for _, candidate := range response.Candidates {
		choice := dto.OpenAITextResponseChoice{
			Index: int(candidate.Index),
			Message: dto.Message{
				Role:    "assistant",
				Content: "",
			},
			FinishReason: constant.FinishReasonStop,
		}

		// Tracked per candidate so that tool calls in one candidate do not
		// force the tool_calls finish reason onto the candidates after it.
		hasToolCalls := false

		if len(candidate.Content.Parts) > 0 {
			var texts []string
			var toolCalls []dto.ToolCallResponse
			for _, part := range candidate.Content.Parts {
				switch {
				case part.InlineData != nil:
					// Media content.
					if strings.HasPrefix(part.InlineData.MimeType, "image") {
						imgText := "![image](data:" + part.InlineData.MimeType + ";base64," + part.InlineData.Data + ")"
						texts = append(texts, imgText)
					} else {
						// Other media types are exposed as a plain link.
						texts = append(texts, fmt.Sprintf("[media](data:%s;base64,%s)", part.InlineData.MimeType, part.InlineData.Data))
					}
				case part.FunctionCall != nil:
					choice.FinishReason = constant.FinishReasonToolCalls
					if call := getResponseToolCall(&part); call != nil {
						toolCalls = append(toolCalls, *call)
					}
				case part.Thought:
					choice.Message.ReasoningContent = part.Text
				case part.ExecutableCode != nil:
					texts = append(texts, "```"+part.ExecutableCode.Language+"\n"+part.ExecutableCode.Code+"\n```")
				case part.CodeExecutionResult != nil:
					texts = append(texts, "```output\n"+part.CodeExecutionResult.Output+"\n```")
				default:
					// Skip parts that consist of a single newline only.
					if part.Text != "\n" {
						texts = append(texts, part.Text)
					}
				}
			}
			if len(toolCalls) > 0 {
				choice.Message.SetToolCalls(toolCalls)
				hasToolCalls = true
			}
			choice.Message.SetStringContent(strings.Join(texts, "\n"))
		}

		if candidate.FinishReason != nil {
			switch *candidate.FinishReason {
			case "STOP":
				choice.FinishReason = constant.FinishReasonStop
			case "MAX_TOKENS":
				choice.FinishReason = constant.FinishReasonLength
			case "SAFETY", "RECITATION", "BLOCKLIST", "PROHIBITED_CONTENT", "SPII", "OTHER":
				// Safety filter, recitation, blocklist, prohibited content,
				// sensitive personal information, and other reasons.
				choice.FinishReason = constant.FinishReasonContentFilter
			default:
				choice.FinishReason = constant.FinishReasonContentFilter
			}
		}
		// Tool calls take precedence over the reason reported by Gemini.
		if hasToolCalls {
			choice.FinishReason = constant.FinishReasonToolCalls
		}

		fullTextResponse.Choices = append(fullTextResponse.Choices, choice)
	}

	return &fullTextResponse
}
// streamResponseGeminiChat2OpenAI converts one Gemini streaming chunk into an
// OpenAI-style chat.completion.chunk.
//
// The boolean result is true when at least one candidate reported the "STOP"
// finish reason. Such a candidate's choice carries no finish reason itself, so
// the caller decides how the stop is signalled. Other finish reasons map to
// length ("MAX_TOKENS") or content_filter (everything else), and a chunk
// containing a function call always ends up with tool_calls.
//
// Collected text is joined with "\n" and written to the delta's content, or to
// its reasoning content when any part of the candidate was a thought.
// NOTE(review): a chunk mixing thought and non-thought parts is emitted
// entirely as reasoning content — confirm this is intended.
func streamResponseGeminiChat2OpenAI(geminiResponse *dto.GeminiChatResponse) (*dto.ChatCompletionsStreamResponse, bool) {
	stopped := false
	streamChoices := make([]dto.ChatCompletionsStreamResponseChoice, 0, len(geminiResponse.Candidates))

	for _, candidate := range geminiResponse.Candidates {
		if candidate.FinishReason != nil && *candidate.FinishReason == "STOP" {
			stopped = true
			candidate.FinishReason = nil
		}

		streamChoice := dto.ChatCompletionsStreamResponseChoice{
			Index: int(candidate.Index),
			Delta: dto.ChatCompletionsStreamResponseChoiceDelta{},
		}

		// Map the Gemini finish reason to the OpenAI finish_reason. "STOP" was
		// cleared above, so only non-stop reasons can reach this point.
		if candidate.FinishReason != nil {
			if *candidate.FinishReason == "MAX_TOKENS" {
				streamChoice.FinishReason = &constant.FinishReasonLength
			} else {
				streamChoice.FinishReason = &constant.FinishReasonContentFilter
			}
		}

		var pieces []string
		sawToolCall := false
		sawThought := false
		for _, part := range candidate.Content.Parts {
			switch {
			case part.InlineData != nil:
				// Only images are forwarded; other inline media is dropped here.
				if strings.HasPrefix(part.InlineData.MimeType, "image") {
					pieces = append(pieces, "![image](data:"+part.InlineData.MimeType+";base64,"+part.InlineData.Data+")")
				}
			case part.FunctionCall != nil:
				sawToolCall = true
				if call := getResponseToolCall(&part); call != nil {
					call.SetIndex(len(streamChoice.Delta.ToolCalls))
					streamChoice.Delta.ToolCalls = append(streamChoice.Delta.ToolCalls, *call)
				}
			case part.Thought:
				sawThought = true
				pieces = append(pieces, part.Text)
			case part.ExecutableCode != nil:
				pieces = append(pieces, "```"+part.ExecutableCode.Language+"\n"+part.ExecutableCode.Code+"\n```\n")
			case part.CodeExecutionResult != nil:
				pieces = append(pieces, "```output\n"+part.CodeExecutionResult.Output+"\n```\n")
			case part.Text != "\n":
				pieces = append(pieces, part.Text)
			}
		}

		joined := strings.Join(pieces, "\n")
		if sawThought {
			streamChoice.Delta.SetReasoningContent(joined)
		} else {
			streamChoice.Delta.SetContentString(joined)
		}
		if sawToolCall {
			streamChoice.FinishReason = &constant.FinishReasonToolCalls
		}
		streamChoices = append(streamChoices, streamChoice)
	}

	return &dto.ChatCompletionsStreamResponse{
		Object:  "chat.completion.chunk",
		Choices: streamChoices,
	}, stopped
}
// handleStream serializes resp and hands it to openai.HandleStreamFormat using
// the channel's ForceFormat and ThinkingToContent settings. Marshalling and
// formatting failures are returned wrapped with context.
func handleStream(c *gin.Context, info *relaycommon.RelayInfo, resp *dto.ChatCompletionsStreamResponse) error {
	payload, marshalErr := common.Marshal(resp)
	if marshalErr != nil {
		return fmt.Errorf("failed to marshal stream response: %w", marshalErr)
	}

	settings := info.ChannelSetting
	if formatErr := openai.HandleStreamFormat(c, info, string(payload), settings.ForceFormat, settings.ThinkingToContent); formatErr != nil {
		return fmt.Errorf("failed to handle stream format: %w", formatErr)
	}
	return nil
}
// handleFinalStream serializes the closing stream response and passes it,
// together with its id, creation time, model, system fingerprint and usage, to
// openai.HandleFinalResponse. Only a marshalling failure is reported.
func handleFinalStream(c *gin.Context, info *relaycommon.RelayInfo, resp *dto.ChatCompletionsStreamResponse) error {
	payload, marshalErr := common.Marshal(resp)
	if marshalErr != nil {
		return fmt.Errorf("failed to marshal stream response: %w", marshalErr)
	}

	openai.HandleFinalResponse(
		c,
		info,
		string(payload),
		resp.Id,
		resp.Created,
		resp.Model,
		resp.GetSystemFingerprint(),
		resp.Usage,
		false,
	)
	return nil
}
// geminiStreamHandler scans a Gemini streaming response, decodes every chunk
// and passes it to callback, while collecting what is needed to report usage.
//
// Scanning stops when a chunk fails to decode or when callback returns false.
// Usage is taken from the most recent chunk whose usage metadata has a non-zero
// total token count. If no completion tokens were reported:
//   - inline media parts with a MIME type are billed at 1400 tokens each;
//   - otherwise usage is estimated from the accumulated part text when at
//     least one response was received, or left empty when none was.
//
// A prompt block reason on a chunk without candidates is recorded in the
// request context under constant.ContextKeyAdminRejectReason.
func geminiStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response, callback func(data string, geminiResponse *dto.GeminiChatResponse) bool) (*dto.Usage, *types.NewAPIError) {
	usage := &dto.Usage{}
	mediaParts := 0
	var collected strings.Builder

	onChunk := func(data string) bool {
		var chunk dto.GeminiChatResponse
		if decodeErr := common.UnmarshalJsonStr(data, &chunk); decodeErr != nil {
			logger.LogError(c, "error unmarshalling stream response: "+decodeErr.Error())
			return false
		}

		if len(chunk.Candidates) == 0 && chunk.PromptFeedback != nil && chunk.PromptFeedback.BlockReason != nil {
			reason := fmt.Sprintf("gemini_block_reason=%s", *chunk.PromptFeedback.BlockReason)
			common.SetContextKey(c, constant.ContextKeyAdminRejectReason, reason)
		}

		// Count inline media parts and accumulate text for the usage fallback.
		for _, candidate := range chunk.Candidates {
			for _, part := range candidate.Content.Parts {
				if part.InlineData != nil && part.InlineData.MimeType != "" {
					mediaParts++
				}
				if part.Text != "" {
					collected.WriteString(part.Text)
				}
			}
		}

		// Refresh the usage statistics whenever the chunk carries them.
		if chunk.UsageMetadata.TotalTokenCount != 0 {
			*usage = buildUsageFromGeminiMetadata(chunk.UsageMetadata, info.GetEstimatePromptTokens())
		}
		return callback(data, &chunk)
	}
	helper.StreamScannerHandler(c, resp, info, onChunk)

	if mediaParts != 0 && usage.CompletionTokens == 0 {
		usage.CompletionTokens = mediaParts * 1400
	}
	if usage.CompletionTokens <= 0 {
		if info.ReceivedResponseCount > 0 {
			usage = service.ResponseText2Usage(c, collected.String(), info.UpstreamModelName, info.GetEstimatePromptTokens())
		} else {
			usage = &dto.Usage{}
		}
	}
	return usage, nil
}
// GeminiChatStreamHandler relays a Gemini streaming response to the client as
// OpenAI-style chat.completion.chunk events and returns the usage gathered by
// geminiStreamHandler.
//
// For every upstream chunk it:
//  1. converts the chunk with streamResponseGeminiChat2OpenAI and stamps the
//     shared response id, creation time and upstream model name on it;
//  2. gives each tool call with a non-empty ID a stable index per choice, so a
//     given ID keeps the same index across chunks;
//  3. while nothing has been sent yet (info.SendResponseCount == 0), emits a
//     start chunk first; when the chunk carries tool calls, the start chunk
//     holds copies of them with empty arguments and the tracked finish reason
//     becomes tool_calls;
//  4. forwards the converted chunk and, if Gemini reported "STOP", a stop
//     chunk carrying the tracked finish reason.
//
// After the stream ends, a final usage chunk is sent. Failures while writing to
// the client are logged and never abort the stream.
func GeminiChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
	id := helper.GetResponseID(c)
	createAt := common.GetTimestamp()
	finishReason := constant.FinishReasonStop
	toolCallIndexByChoice := make(map[int]map[string]int)
	nextToolCallIndexByChoice := make(map[int]int)

	usage, err := geminiStreamHandler(c, info, resp, func(data string, geminiResponse *dto.GeminiChatResponse) bool {
		response, isStop := streamResponseGeminiChat2OpenAI(geminiResponse)
		response.Id = id
		response.Created = createAt
		response.Model = info.UpstreamModelName

		// Re-index tool calls so that each ID maps to one index per choice.
		for choiceIdx := range response.Choices {
			choiceKey := response.Choices[choiceIdx].Index
			for toolIdx := range response.Choices[choiceIdx].Delta.ToolCalls {
				tool := &response.Choices[choiceIdx].Delta.ToolCalls[toolIdx]
				if tool.ID == "" {
					continue
				}
				known := toolCallIndexByChoice[choiceKey]
				if known == nil {
					known = make(map[string]int)
					toolCallIndexByChoice[choiceKey] = known
				}
				idx, seen := known[tool.ID]
				if !seen {
					idx = nextToolCallIndexByChoice[choiceKey]
					nextToolCallIndexByChoice[choiceKey] = idx + 1
					known[tool.ID] = idx
				}
				tool.SetIndex(idx)
			}
		}

		logger.LogDebug(c, fmt.Sprintf("info.SendResponseCount = %d", info.SendResponseCount))
		if info.SendResponseCount == 0 {
			// Send the first (start) response before any content.
			emptyResponse := helper.GenerateStartEmptyResponse(id, createAt, info.UpstreamModelName, nil)
			hasToolCall := response.IsToolCall()
			if hasToolCall {
				if len(emptyResponse.Choices) > 0 && len(response.Choices) > 0 {
					// Copy the tool calls so that blanking the arguments does
					// not touch the chunk that is forwarded afterwards.
					toolCalls := response.Choices[0].Delta.ToolCalls
					copiedToolCalls := make([]dto.ToolCallResponse, len(toolCalls))
					for idx := range toolCalls {
						copiedToolCalls[idx] = toolCalls[idx]
						copiedToolCalls[idx].Function.Arguments = ""
					}
					emptyResponse.Choices[0].Delta.ToolCalls = copiedToolCalls
				}
				finishReason = constant.FinishReasonToolCalls
			}
			if err := handleStream(c, info, emptyResponse); err != nil {
				logger.LogError(c, err.Error())
			}
			if hasToolCall {
				response.ClearToolCalls()
				if response.IsFinished() {
					response.Choices[0].FinishReason = nil
				}
			}
		}

		if err := handleStream(c, info, response); err != nil {
			logger.LogError(c, err.Error())
		}
		if isStop {
			// Log a failed stop chunk like every other send failure in this
			// handler instead of discarding the error.
			stopResponse := helper.GenerateStopResponse(id, createAt, info.UpstreamModelName, finishReason)
			if err := handleStream(c, info, stopResponse); err != nil {
				logger.LogError(c, err.Error())
			}
		}
		return true
	})
	if err != nil {
		return usage, err
	}

	response := helper.GenerateFinalUsageResponse(id, createAt, info.UpstreamModelName, *usage)
	if handleErr := handleFinalStream(c, info, response); handleErr != nil {
		common.SysLog("send final response failed: " + handleErr.Error())
	}
	return usage, nil
}
// GeminiChatHandler handles a non-streaming Gemini chat response: it reads and
// decodes the upstream body, converts it to the client's relay format and
// writes it to the client.
//
// When the response has no candidates, an error is written to the client
// directly (Claude-shaped for the Claude relay format, OpenAI-shaped
// otherwise): a prompt-blocked error if Gemini supplied a block reason, an
// empty-response error if not. The reason is also stored in the request context
// under constant.ContextKeyAdminRejectReason. In that case the usage built from
// the upstream metadata is still returned with a nil error.
//
// Otherwise the body is re-encoded for the OpenAI and Claude relay formats and
// passed through unchanged for the Gemini format.
//
// It returns an error if the body cannot be read, decoded, or re-encoded.
func GeminiChatHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
	// Deferred so the upstream body is also released when reading it fails.
	defer service.CloseResponseBodyGracefully(resp)

	responseBody, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
	}
	if common.DebugEnabled {
		println(string(responseBody))
	}

	var geminiResponse dto.GeminiChatResponse
	if err = common.Unmarshal(responseBody, &geminiResponse); err != nil {
		return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
	}

	if len(geminiResponse.Candidates) == 0 {
		usage := buildUsageFromGeminiMetadata(geminiResponse.UsageMetadata, info.GetEstimatePromptTokens())
		var newAPIError *types.NewAPIError
		if geminiResponse.PromptFeedback != nil && geminiResponse.PromptFeedback.BlockReason != nil {
			common.SetContextKey(c, constant.ContextKeyAdminRejectReason, fmt.Sprintf("gemini_block_reason=%s", *geminiResponse.PromptFeedback.BlockReason))
			newAPIError = types.NewOpenAIError(
				errors.New("request blocked by Gemini API: "+*geminiResponse.PromptFeedback.BlockReason),
				types.ErrorCodePromptBlocked,
				http.StatusBadRequest,
			)
		} else {
			common.SetContextKey(c, constant.ContextKeyAdminRejectReason, "gemini_empty_candidates")
			newAPIError = types.NewOpenAIError(
				errors.New("empty response from Gemini API"),
				types.ErrorCodeEmptyResponse,
				http.StatusInternalServerError,
			)
		}
		service.ResetStatusCode(newAPIError, c.GetString("status_code_mapping"))
		switch info.RelayFormat {
		case types.RelayFormatClaude:
			c.JSON(newAPIError.StatusCode, gin.H{
				"type":  "error",
				"error": newAPIError.ToClaudeError(),
			})
		default:
			c.JSON(newAPIError.StatusCode, gin.H{
				"error": newAPIError.ToOpenAIError(),
			})
		}
		return &usage, nil
	}

	fullTextResponse := responseGeminiChat2OpenAI(c, &geminiResponse)
	fullTextResponse.Model = info.UpstreamModelName
	usage := buildUsageFromGeminiMetadata(geminiResponse.UsageMetadata, info.GetEstimatePromptTokens())
	fullTextResponse.Usage = usage

	switch info.RelayFormat {
	case types.RelayFormatOpenAI:
		responseBody, err = common.Marshal(fullTextResponse)
		if err != nil {
			return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
		}
	case types.RelayFormatClaude:
		claudeResp := service.ResponseOpenAI2Claude(fullTextResponse, info)
		claudeRespStr, err := common.Marshal(claudeResp)
		if err != nil {
			return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
		}
		responseBody = claudeRespStr
	case types.RelayFormatGemini:
		// The upstream body is forwarded as received.
	}

	service.IOCopyBytesGracefully(c, resp, responseBody)
	return &usage, nil
}
// GeminiEmbeddingHandler converts a Gemini batch embedding response into an
// OpenAI-style embedding list and writes it to the client.
//
// Each upstream embedding becomes one "embedding" item whose index is its
// position in the upstream list. Usage is derived from the estimated prompt
// tokens only (empty completion text), following OpenAI's input-token billing
// for embeddings, since Google has not clarified how embedding models are
// billed:
// https://ai.google.dev/gemini-api/docs/pricing?hl=zh-cn#text-embedding-004
// https://platform.openai.com/docs/guides/embeddings#what-are-embeddings
//
// It returns an error if the body cannot be read or decoded, or if the
// converted response cannot be encoded.
func GeminiEmbeddingHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
	defer service.CloseResponseBodyGracefully(resp)

	rawBody, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
	}

	var batch dto.GeminiBatchEmbeddingResponse
	if err = common.Unmarshal(rawBody, &batch); err != nil {
		return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
	}

	// Convert to the OpenAI response format.
	items := make([]dto.OpenAIEmbeddingResponseItem, len(batch.Embeddings))
	for idx := range batch.Embeddings {
		items[idx] = dto.OpenAIEmbeddingResponseItem{
			Object:    "embedding",
			Embedding: batch.Embeddings[idx].Values,
			Index:     idx,
		}
	}

	usage := service.ResponseText2Usage(c, "", info.UpstreamModelName, info.GetEstimatePromptTokens())
	converted := dto.OpenAIEmbeddingResponse{
		Object: "list",
		Data:   items,
		Model:  info.UpstreamModelName,
		Usage:  *usage,
	}

	encoded, err := common.Marshal(converted)
	if err != nil {
		return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
	}
	service.IOCopyBytesGracefully(c, resp, encoded)
	return usage, nil
}
// GeminiImageHandler converts a Gemini image generation response into an
// OpenAI-style image response and writes it to the client with the upstream
// status code.
//
// Predictions that carry a RAI filter reason are skipped. Usage is a fixed 258
// prompt tokens per returned image, with no completion tokens:
// https://github.com/google-gemini/cookbook/blob/719a27d752aac33f39de18a8d3cb42a70874917e/quickstarts/Counting_Tokens.ipynb
//
// It returns an error if the body cannot be read or decoded, if the upstream
// returned no predictions, or if the converted response cannot be encoded.
func GeminiImageHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
	// Deferred so the upstream body is released on every path, including a
	// failed read; the close error is not actionable here.
	defer func() { _ = resp.Body.Close() }()

	responseBody, readErr := io.ReadAll(resp.Body)
	if readErr != nil {
		return nil, types.NewOpenAIError(readErr, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
	}

	var geminiResponse dto.GeminiImageResponse
	if jsonErr := common.Unmarshal(responseBody, &geminiResponse); jsonErr != nil {
		return nil, types.NewOpenAIError(jsonErr, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
	}
	if len(geminiResponse.Predictions) == 0 {
		return nil, types.NewOpenAIError(errors.New("no images generated"), types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
	}

	// Convert to the OpenAI response format.
	openAIResponse := dto.ImageResponse{
		Created: common.GetTimestamp(),
		Data:    make([]dto.ImageData, 0, len(geminiResponse.Predictions)),
	}
	for _, prediction := range geminiResponse.Predictions {
		if prediction.RaiFilteredReason != "" {
			continue // skip filtered image
		}
		openAIResponse.Data = append(openAIResponse.Data, dto.ImageData{
			B64Json: prediction.BytesBase64Encoded,
		})
	}

	jsonResponse, jsonErr := json.Marshal(openAIResponse)
	if jsonErr != nil {
		return nil, types.NewError(jsonErr, types.ErrorCodeBadResponseBody)
	}
	c.Writer.Header().Set("Content-Type", "application/json")
	c.Writer.WriteHeader(resp.StatusCode)
	_, _ = c.Writer.Write(jsonResponse)

	// Each generated image is counted as a fixed 258 tokens.
	const imageTokens = 258
	generatedImages := len(openAIResponse.Data)
	usage := &dto.Usage{
		PromptTokens:     imageTokens * generatedImages,
		CompletionTokens: 0, // image generation does not calculate completion tokens
		TotalTokens:      imageTokens * generatedImages,
	}
	return usage, nil
}
// GeminiModelsResponse is the JSON payload that FetchGeminiModels decodes for
// one page of the model listing.
type GeminiModelsResponse struct {
// Models holds the models returned on this page.
Models []dto.GeminiModel `json:"models"`
// NextPageToken is sent back as the pageToken of the next request;
// FetchGeminiModels stops paging when it is empty.
NextPageToken string `json:"nextPageToken"`
}
// FetchGeminiModels lists the model names exposed by a Gemini-compatible
// endpoint by paging through GET {baseURL}/v1beta/models.
//
// The API key is sent in the x-goog-api-key header and requests go through the
// HTTP client returned by service.GetHttpClientWithProxy(proxyURL). Every page
// request is bounded by a 30 second timeout. The "models/" prefix is stripped
// from each name, and entries whose name is not a string are skipped. Paging
// stops when the server returns an empty nextPageToken or after 100 pages,
// whichever comes first.
//
// It returns an error if the client cannot be created, a request cannot be
// built or fails, the server answers with a non-200 status, or the body cannot
// be read or decoded.
func FetchGeminiModels(baseURL, apiKey, proxyURL string) ([]string, error) {
	client, err := service.GetHttpClientWithProxy(proxyURL)
	if err != nil {
		return nil, fmt.Errorf("创建HTTP客户端失败: %w", err)
	}

	// fetchPage retrieves and decodes a single page. Running it as its own
	// function lets defer release the timeout context and the response body
	// on every return path.
	fetchPage := func(pageToken string) (*GeminiModelsResponse, error) {
		ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
		defer cancel()

		request, err := http.NewRequestWithContext(ctx, http.MethodGet, fmt.Sprintf("%s/v1beta/models", baseURL), nil)
		if err != nil {
			return nil, fmt.Errorf("创建请求失败: %w", err)
		}
		if pageToken != "" {
			// Encode the token rather than splicing it into the URL: page
			// tokens are opaque and may contain characters such as '+', '/'
			// or '=' that are not safe in a raw query string.
			query := request.URL.Query()
			query.Set("pageToken", pageToken)
			request.URL.RawQuery = query.Encode()
		}
		request.Header.Set("x-goog-api-key", apiKey)

		response, err := client.Do(request)
		if err != nil {
			return nil, fmt.Errorf("请求失败: %w", err)
		}
		defer response.Body.Close()

		body, readErr := io.ReadAll(response.Body)
		if response.StatusCode != http.StatusOK {
			// The body is only used for diagnostics here, so a partial read
			// is acceptable.
			return nil, fmt.Errorf("服务器返回错误 %d: %s", response.StatusCode, string(body))
		}
		if readErr != nil {
			return nil, fmt.Errorf("读取响应失败: %w", readErr)
		}

		var page GeminiModelsResponse
		if err = common.Unmarshal(body, &page); err != nil {
			return nil, fmt.Errorf("解析响应失败: %w", err)
		}
		return &page, nil
	}

	const maxPages = 100 // Safety limit to prevent infinite loops
	allModels := make([]string, 0)
	nextPageToken := ""
	for page := 0; page < maxPages; page++ {
		modelsResponse, err := fetchPage(nextPageToken)
		if err != nil {
			return nil, err
		}
		for _, model := range modelsResponse.Models {
			modelNameValue, ok := model.Name.(string)
			if !ok {
				continue
			}
			allModels = append(allModels, strings.TrimPrefix(modelNameValue, "models/"))
		}
		nextPageToken = modelsResponse.NextPageToken
		if nextPageToken == "" {
			break
		}
	}
	return allModels, nil
}
// convertToolChoiceToGeminiConfig translates an OpenAI tool_choice value into
// a Gemini toolConfig.
//
// OpenAI tool_choice forms and the resulting functionCallingConfig:
//   - "auto", or any unrecognized string: mode "AUTO" (the model decides)
//   - "none": mode "NONE" (the model won't call functions)
//   - "required": mode "ANY" (the model must call at least one function)
//   - {"type": "function", "function": {"name": "xxx"}}: mode "ANY", limited
//     to that function when a non-empty name is present
//
// A nil value, an object whose type is not "function", and any other Go type
// all yield nil.
func convertToolChoiceToGeminiConfig(toolChoice any) *dto.ToolConfig {
	switch choice := toolChoice.(type) {
	case string:
		// Start from AUTO, which also covers "auto" and unknown strings.
		cfg := &dto.ToolConfig{
			FunctionCallingConfig: &dto.FunctionCallingConfig{Mode: "AUTO"},
		}
		switch choice {
		case "none":
			cfg.FunctionCallingConfig.Mode = "NONE"
		case "required":
			cfg.FunctionCallingConfig.Mode = "ANY"
		}
		return cfg

	case map[string]any:
		if choice["type"] != "function" {
			// Unsupported object structure.
			return nil
		}
		cfg := &dto.ToolConfig{
			FunctionCallingConfig: &dto.FunctionCallingConfig{Mode: "ANY"},
		}
		// Restrict the call to the named function when one is given.
		fn, _ := choice["function"].(map[string]any)
		if name, _ := fn["name"].(string); name != "" {
			cfg.FunctionCallingConfig.AllowedFunctionNames = []string{name}
		}
		return cfg
	}

	// nil or an unsupported type.
	return nil
}