diff --git a/relay/channel/gemini/relay-gemini.go b/relay/channel/gemini/relay-gemini.go index db5ea489c..20a352dde 100644 --- a/relay/channel/gemini/relay-gemini.go +++ b/relay/channel/gemini/relay-gemini.go @@ -356,6 +356,13 @@ func CovertOpenAI2Gemini(c *gin.Context, textRequest dto.GeneralOpenAIRequest, i }) } geminiRequest.SetTools(geminiTools) + + // [NEW] Convert OpenAI tool_choice to Gemini toolConfig.functionCallingConfig + // Mapping: "auto" -> "AUTO", "none" -> "NONE", "required" -> "ANY" + // Object format: {"type": "function", "function": {"name": "xxx"}} -> "ANY" + allowedFunctionNames + if textRequest.ToolChoice != nil { + geminiRequest.ToolConfig = convertToolChoiceToGeminiConfig(textRequest.ToolChoice) + } } if textRequest.ResponseFormat != nil && (textRequest.ResponseFormat.Type == "json_schema" || textRequest.ResponseFormat.Type == "json_object") { @@ -960,6 +967,24 @@ func responseGeminiChat2OpenAI(c *gin.Context, response *dto.GeminiChatResponse) choice.FinishReason = constant.FinishReasonStop case "MAX_TOKENS": choice.FinishReason = constant.FinishReasonLength + case "SAFETY": + // Safety filter triggered + choice.FinishReason = constant.FinishReasonContentFilter + case "RECITATION": + // Recitation (citation) detected + choice.FinishReason = constant.FinishReasonContentFilter + case "BLOCKLIST": + // Blocklist triggered + choice.FinishReason = constant.FinishReasonContentFilter + case "PROHIBITED_CONTENT": + // Prohibited content detected + choice.FinishReason = constant.FinishReasonContentFilter + case "SPII": + // Sensitive personally identifiable information + choice.FinishReason = constant.FinishReasonContentFilter + case "OTHER": + // Other reasons + choice.FinishReason = constant.FinishReasonContentFilter default: choice.FinishReason = constant.FinishReasonContentFilter } @@ -991,13 +1016,34 @@ func streamResponseGeminiChat2OpenAI(geminiResponse *dto.GeminiChatResponse) (*d isTools := false isThought := false if candidate.FinishReason != nil { - // p := GeminiConvertFinishReason(*candidate.FinishReason) + // Map Gemini FinishReason to OpenAI finish_reason switch *candidate.FinishReason { case "STOP": + // Normal completion choice.FinishReason = &constant.FinishReasonStop case "MAX_TOKENS": + // Reached maximum token limit choice.FinishReason = &constant.FinishReasonLength + case "SAFETY": + // Safety filter triggered + choice.FinishReason = &constant.FinishReasonContentFilter + case "RECITATION": + // Recitation (citation) detected + choice.FinishReason = &constant.FinishReasonContentFilter + case "BLOCKLIST": + // Blocklist triggered + choice.FinishReason = &constant.FinishReasonContentFilter + case "PROHIBITED_CONTENT": + // Prohibited content detected + choice.FinishReason = &constant.FinishReasonContentFilter + case "SPII": + // Sensitive personally identifiable information + choice.FinishReason = &constant.FinishReasonContentFilter + case "OTHER": + // Other reasons + choice.FinishReason = &constant.FinishReasonContentFilter default: + // Unknown reason, treat as content filter choice.FinishReason = &constant.FinishReasonContentFilter } } @@ -1214,12 +1260,20 @@ func GeminiChatHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.R return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError) } if len(geminiResponse.Candidates) == 0 { - //return nil, types.NewOpenAIError(errors.New("no candidates returned"), types.ErrorCodeBadResponseBody, http.StatusInternalServerError) - //if geminiResponse.PromptFeedback != nil && geminiResponse.PromptFeedback.BlockReason != nil { - // return nil, types.NewOpenAIError(errors.New("request blocked by Gemini API: "+*geminiResponse.PromptFeedback.BlockReason), types.ErrorCodePromptBlocked, http.StatusBadRequest) - //} else { - // return nil, types.NewOpenAIError(errors.New("empty response from Gemini API"), types.ErrorCodeEmptyResponse, http.StatusInternalServerError) - //} + // [FIX] Return meaningful error when Candidates is empty + if geminiResponse.PromptFeedback != nil && geminiResponse.PromptFeedback.BlockReason != nil { + return nil, types.NewOpenAIError( + errors.New("request blocked by Gemini API: "+*geminiResponse.PromptFeedback.BlockReason), + types.ErrorCodePromptBlocked, + http.StatusBadRequest, + ) + } else { + return nil, types.NewOpenAIError( + errors.New("empty response from Gemini API"), + types.ErrorCodeEmptyResponse, + http.StatusInternalServerError, + ) + } } fullTextResponse := responseGeminiChat2OpenAI(c, &geminiResponse) fullTextResponse.Model = info.UpstreamModelName @@ -1362,3 +1416,63 @@ func GeminiImageHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http. return usage, nil } + +// convertToolChoiceToGeminiConfig converts OpenAI tool_choice to Gemini toolConfig +// OpenAI tool_choice values: +// - "auto": Let the model decide (default) +// - "none": Don't call any tools +// - "required": Must call at least one tool +// - {"type": "function", "function": {"name": "xxx"}}: Call specific function +// +// Gemini functionCallingConfig.mode values: +// - "AUTO": Model decides whether to call functions +// - "NONE": Model won't call functions +// - "ANY": Model must call at least one function +func convertToolChoiceToGeminiConfig(toolChoice any) *dto.ToolConfig { + if toolChoice == nil { + return nil + } + + // Handle string values: "auto", "none", "required" + if toolChoiceStr, ok := toolChoice.(string); ok { + config := &dto.ToolConfig{ + FunctionCallingConfig: &dto.FunctionCallingConfig{}, + } + switch toolChoiceStr { + case "auto": + config.FunctionCallingConfig.Mode = "AUTO" + case "none": + config.FunctionCallingConfig.Mode = "NONE" + case "required": + config.FunctionCallingConfig.Mode = "ANY" + default: + // Unknown string value, default to AUTO + config.FunctionCallingConfig.Mode = "AUTO" + } + return config + } + + // Handle object value: {"type": "function", "function": {"name": "xxx"}} + if toolChoiceMap, ok := toolChoice.(map[string]interface{}); ok { + if toolChoiceMap["type"] == "function" { + config := &dto.ToolConfig{ + FunctionCallingConfig: &dto.FunctionCallingConfig{ + Mode: "ANY", + }, + } + // Extract function name if specified + if function, ok := toolChoiceMap["function"].(map[string]interface{}); ok { + if name, ok := function["name"].(string); ok && name != "" { + config.FunctionCallingConfig.AllowedFunctionNames = []string{name} + } + } + return config + } + // Unsupported map structure (type is not "function"), return nil + return nil + } + + // Unsupported type, return nil + return nil +} +