Merge pull request #2742 from seefs001/fix/pr-2540

feat(gemini): 支持 tool_choice 参数转换,优化多个渠道错误处理
This commit is contained in:
Calcium-Ion
2026-01-26 15:12:09 +08:00
committed by GitHub
10 changed files with 247 additions and 32 deletions

View File

@@ -8,11 +8,13 @@ import (
"strings"
"github.com/QuantumNous/new-api/common"
"github.com/QuantumNous/new-api/constant"
"github.com/QuantumNous/new-api/dto"
"github.com/QuantumNous/new-api/logger"
"github.com/QuantumNous/new-api/relay/channel/openrouter"
relaycommon "github.com/QuantumNous/new-api/relay/common"
"github.com/QuantumNous/new-api/relay/helper"
"github.com/QuantumNous/new-api/relay/reasonmap"
"github.com/QuantumNous/new-api/service"
"github.com/QuantumNous/new-api/setting/model_setting"
"github.com/QuantumNous/new-api/types"
@@ -27,17 +29,15 @@ const (
)
func stopReasonClaude2OpenAI(reason string) string {
switch reason {
case "stop_sequence":
return "stop"
case "end_turn":
return "stop"
case "max_tokens":
return "length"
case "tool_use":
return "tool_calls"
default:
return reason
return reasonmap.ClaudeStopReasonToOpenAIFinishReason(reason)
}
// maybeMarkClaudeRefusal records an admin reject reason on the request
// context when Claude reports a "refusal" stop reason (case-insensitive).
// It is a no-op for a nil context or any other stop reason.
func maybeMarkClaudeRefusal(c *gin.Context, stopReason string) {
	if c == nil || !strings.EqualFold(stopReason, "refusal") {
		return
	}
	common.SetContextKey(c, constant.ContextKeyAdminRejectReason, "claude_stop_reason=refusal")
}
@@ -644,6 +644,12 @@ func HandleStreamResponseData(c *gin.Context, info *relaycommon.RelayInfo, claud
if claudeError := claudeResponse.GetClaudeError(); claudeError != nil && claudeError.Type != "" {
return types.WithClaudeError(*claudeError, http.StatusInternalServerError)
}
if claudeResponse.StopReason != "" {
maybeMarkClaudeRefusal(c, claudeResponse.StopReason)
}
if claudeResponse.Delta != nil && claudeResponse.Delta.StopReason != nil {
maybeMarkClaudeRefusal(c, *claudeResponse.Delta.StopReason)
}
if info.RelayFormat == types.RelayFormatClaude {
FormatClaudeResponseInfo(requestMode, &claudeResponse, nil, claudeInfo)
@@ -735,6 +741,7 @@ func HandleClaudeResponseData(c *gin.Context, info *relaycommon.RelayInfo, claud
if claudeError := claudeResponse.GetClaudeError(); claudeError != nil && claudeError.Type != "" {
return types.WithClaudeError(*claudeError, http.StatusInternalServerError)
}
maybeMarkClaudeRefusal(c, claudeResponse.StopReason)
if requestMode == RequestModeCompletion {
claudeInfo.Usage = service.ResponseText2Usage(c, claudeResponse.Completion, info.UpstreamModelName, info.GetEstimatePromptTokens())
} else {

View File

@@ -1,10 +1,12 @@
package gemini
import (
"fmt"
"io"
"net/http"
"github.com/QuantumNous/new-api/common"
"github.com/QuantumNous/new-api/constant"
"github.com/QuantumNous/new-api/dto"
"github.com/QuantumNous/new-api/logger"
relaycommon "github.com/QuantumNous/new-api/relay/common"
@@ -35,6 +37,10 @@ func GeminiTextGenerationHandler(c *gin.Context, info *relaycommon.RelayInfo, re
return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
}
if len(geminiResponse.Candidates) == 0 && geminiResponse.PromptFeedback != nil && geminiResponse.PromptFeedback.BlockReason != nil {
common.SetContextKey(c, constant.ContextKeyAdminRejectReason, fmt.Sprintf("gemini_block_reason=%s", *geminiResponse.PromptFeedback.BlockReason))
}
// 计算使用量(基于 UsageMetadata)
usage := dto.Usage{
PromptTokens: geminiResponse.UsageMetadata.PromptTokenCount,

View File

@@ -359,6 +359,13 @@ func CovertOpenAI2Gemini(c *gin.Context, textRequest dto.GeneralOpenAIRequest, i
})
}
geminiRequest.SetTools(geminiTools)
// [NEW] Convert OpenAI tool_choice to Gemini toolConfig.functionCallingConfig
// Mapping: "auto" -> "AUTO", "none" -> "NONE", "required" -> "ANY"
// Object format: {"type": "function", "function": {"name": "xxx"}} -> "ANY" + allowedFunctionNames
if textRequest.ToolChoice != nil {
geminiRequest.ToolConfig = convertToolChoiceToGeminiConfig(textRequest.ToolChoice)
}
}
if textRequest.ResponseFormat != nil && (textRequest.ResponseFormat.Type == "json_schema" || textRequest.ResponseFormat.Type == "json_object") {
@@ -1031,6 +1038,24 @@ func responseGeminiChat2OpenAI(c *gin.Context, response *dto.GeminiChatResponse)
choice.FinishReason = constant.FinishReasonStop
case "MAX_TOKENS":
choice.FinishReason = constant.FinishReasonLength
case "SAFETY":
// Safety filter triggered
choice.FinishReason = constant.FinishReasonContentFilter
case "RECITATION":
// Recitation (citation) detected
choice.FinishReason = constant.FinishReasonContentFilter
case "BLOCKLIST":
// Blocklist triggered
choice.FinishReason = constant.FinishReasonContentFilter
case "PROHIBITED_CONTENT":
// Prohibited content detected
choice.FinishReason = constant.FinishReasonContentFilter
case "SPII":
// Sensitive personally identifiable information
choice.FinishReason = constant.FinishReasonContentFilter
case "OTHER":
// Other reasons
choice.FinishReason = constant.FinishReasonContentFilter
default:
choice.FinishReason = constant.FinishReasonContentFilter
}
@@ -1062,13 +1087,34 @@ func streamResponseGeminiChat2OpenAI(geminiResponse *dto.GeminiChatResponse) (*d
isTools := false
isThought := false
if candidate.FinishReason != nil {
// p := GeminiConvertFinishReason(*candidate.FinishReason)
// Map Gemini FinishReason to OpenAI finish_reason
switch *candidate.FinishReason {
case "STOP":
// Normal completion
choice.FinishReason = &constant.FinishReasonStop
case "MAX_TOKENS":
// Reached maximum token limit
choice.FinishReason = &constant.FinishReasonLength
case "SAFETY":
// Safety filter triggered
choice.FinishReason = &constant.FinishReasonContentFilter
case "RECITATION":
// Recitation (citation) detected
choice.FinishReason = &constant.FinishReasonContentFilter
case "BLOCKLIST":
// Blocklist triggered
choice.FinishReason = &constant.FinishReasonContentFilter
case "PROHIBITED_CONTENT":
// Prohibited content detected
choice.FinishReason = &constant.FinishReasonContentFilter
case "SPII":
// Sensitive personally identifiable information
choice.FinishReason = &constant.FinishReasonContentFilter
case "OTHER":
// Other reasons
choice.FinishReason = &constant.FinishReasonContentFilter
default:
// Unknown reason, treat as content filter
choice.FinishReason = &constant.FinishReasonContentFilter
}
}
@@ -1151,6 +1197,10 @@ func geminiStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http
return false
}
if len(geminiResponse.Candidates) == 0 && geminiResponse.PromptFeedback != nil && geminiResponse.PromptFeedback.BlockReason != nil {
common.SetContextKey(c, constant.ContextKeyAdminRejectReason, fmt.Sprintf("gemini_block_reason=%s", *geminiResponse.PromptFeedback.BlockReason))
}
// 统计图片数量
for _, candidate := range geminiResponse.Candidates {
for _, part := range candidate.Content.Parts {
@@ -1309,12 +1359,52 @@ func GeminiChatHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.R
return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
}
if len(geminiResponse.Candidates) == 0 {
//return nil, types.NewOpenAIError(errors.New("no candidates returned"), types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
//if geminiResponse.PromptFeedback != nil && geminiResponse.PromptFeedback.BlockReason != nil {
// return nil, types.NewOpenAIError(errors.New("request blocked by Gemini API: "+*geminiResponse.PromptFeedback.BlockReason), types.ErrorCodePromptBlocked, http.StatusBadRequest)
//} else {
// return nil, types.NewOpenAIError(errors.New("empty response from Gemini API"), types.ErrorCodeEmptyResponse, http.StatusInternalServerError)
//}
usage := dto.Usage{
PromptTokens: geminiResponse.UsageMetadata.PromptTokenCount,
}
usage.CompletionTokenDetails.ReasoningTokens = geminiResponse.UsageMetadata.ThoughtsTokenCount
for _, detail := range geminiResponse.UsageMetadata.PromptTokensDetails {
if detail.Modality == "AUDIO" {
usage.PromptTokensDetails.AudioTokens = detail.TokenCount
} else if detail.Modality == "TEXT" {
usage.PromptTokensDetails.TextTokens = detail.TokenCount
}
}
if usage.PromptTokens <= 0 {
usage.PromptTokens = info.GetEstimatePromptTokens()
}
var newAPIError *types.NewAPIError
if geminiResponse.PromptFeedback != nil && geminiResponse.PromptFeedback.BlockReason != nil {
common.SetContextKey(c, constant.ContextKeyAdminRejectReason, fmt.Sprintf("gemini_block_reason=%s", *geminiResponse.PromptFeedback.BlockReason))
newAPIError = types.NewOpenAIError(
errors.New("request blocked by Gemini API: "+*geminiResponse.PromptFeedback.BlockReason),
types.ErrorCodePromptBlocked,
http.StatusBadRequest,
)
} else {
common.SetContextKey(c, constant.ContextKeyAdminRejectReason, "gemini_empty_candidates")
newAPIError = types.NewOpenAIError(
errors.New("empty response from Gemini API"),
types.ErrorCodeEmptyResponse,
http.StatusInternalServerError,
)
}
service.ResetStatusCode(newAPIError, c.GetString("status_code_mapping"))
switch info.RelayFormat {
case types.RelayFormatClaude:
c.JSON(newAPIError.StatusCode, gin.H{
"type": "error",
"error": newAPIError.ToClaudeError(),
})
default:
c.JSON(newAPIError.StatusCode, gin.H{
"error": newAPIError.ToOpenAIError(),
})
}
return &usage, nil
}
fullTextResponse := responseGeminiChat2OpenAI(c, &geminiResponse)
fullTextResponse.Model = info.UpstreamModelName
@@ -1530,3 +1620,62 @@ func FetchGeminiModels(baseURL, apiKey, proxyURL string) ([]string, error) {
return allModels, nil
}
// convertToolChoiceToGeminiConfig converts OpenAI tool_choice to Gemini toolConfig
// OpenAI tool_choice values:
// - "auto": Let the model decide (default)
// - "none": Don't call any tools
// - "required": Must call at least one tool
// - {"type": "function", "function": {"name": "xxx"}}: Call specific function
//
// Gemini functionCallingConfig.mode values:
// - "AUTO": Model decides whether to call functions
// - "NONE": Model won't call functions
// - "ANY": Model must call at least one function
// convertToolChoiceToGeminiConfig converts OpenAI tool_choice to Gemini toolConfig.
//
// OpenAI tool_choice values:
//   - "auto": Let the model decide (default)
//   - "none": Don't call any tools
//   - "required": Must call at least one tool
//   - {"type": "function", "function": {"name": "xxx"}}: Call specific function
//
// Gemini functionCallingConfig.mode values:
//   - "AUTO": Model decides whether to call functions
//   - "NONE": Model won't call functions
//   - "ANY": Model must call at least one function
//
// Returns nil when toolChoice is nil or has an unsupported shape, in which
// case the caller leaves toolConfig unset on the request.
func convertToolChoiceToGeminiConfig(toolChoice any) *dto.ToolConfig {
	if toolChoice == nil {
		return nil
	}
	// String form: "auto", "none", "required".
	if s, ok := toolChoice.(string); ok {
		mode := "AUTO" // unknown strings fall back to AUTO
		switch s {
		case "none":
			mode = "NONE"
		case "required":
			mode = "ANY"
		}
		return &dto.ToolConfig{
			FunctionCallingConfig: &dto.FunctionCallingConfig{Mode: mode},
		}
	}
	// Object form: {"type": "function", "function": {"name": "xxx"}}.
	// (map[string]any is what encoding/json produces for a JSON object.)
	if m, ok := toolChoice.(map[string]any); ok {
		if m["type"] != "function" {
			// Unsupported map structure (type is not "function").
			return nil
		}
		config := &dto.ToolConfig{
			FunctionCallingConfig: &dto.FunctionCallingConfig{Mode: "ANY"},
		}
		// Restrict the model to the named function when one is specified.
		if fn, ok := m["function"].(map[string]any); ok {
			if name, ok := fn["name"].(string); ok && name != "" {
				config.FunctionCallingConfig.AllowedFunctionNames = []string{name}
			}
		}
		return config
	}
	// Unsupported type.
	return nil
}

View File

@@ -229,6 +229,13 @@ func OpenaiHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Respo
return nil, types.WithOpenAIError(*oaiError, resp.StatusCode)
}
for _, choice := range simpleResponse.Choices {
if choice.FinishReason == constant.FinishReasonContentFilter {
common.SetContextKey(c, constant.ContextKeyAdminRejectReason, "openai_finish_reason=content_filter")
break
}
}
forceFormat := false
if info.ChannelSetting.ForceFormat {
forceFormat = true

View File

@@ -237,6 +237,9 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage
}
extraContent = append(extraContent, "上游无计费信息")
}
adminRejectReason := common.GetContextKeyString(ctx, constant.ContextKeyAdminRejectReason)
useTimeSeconds := time.Now().Unix() - relayInfo.StartTime.Unix()
promptTokens := usage.PromptTokens
cacheTokens := usage.PromptTokensDetails.CachedTokens
@@ -461,6 +464,9 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage
}
logContent := strings.Join(extraContent, ", ")
other := service.GenerateTextOtherInfo(ctx, relayInfo, modelRatio, groupRatio, completionRatio, cacheTokens, cacheRatio, modelPrice, relayInfo.PriceData.GroupRatioInfo.GroupSpecialRatio)
if adminRejectReason != "" {
other["reject_reason"] = adminRejectReason
}
// Chat-based calls to Claude models must be tagged here, because Claude-format
// logs render the input differently from the default rendering.
if isClaudeUsageSemantic {
other["claude"] = true

View File

@@ -0,0 +1,41 @@
package reasonmap
import (
"strings"
"github.com/QuantumNous/new-api/constant"
)
// ClaudeStopReasonToOpenAIFinishReason maps a Claude stop_reason to the
// equivalent OpenAI finish_reason. Matching is case-insensitive; values
// with no known mapping are passed through unchanged.
func ClaudeStopReasonToOpenAIFinishReason(stopReason string) string {
	switch strings.ToLower(stopReason) {
	case "stop_sequence", "end_turn":
		// Both indicate normal completion.
		return "stop"
	case "max_tokens":
		return "length"
	case "tool_use":
		return "tool_calls"
	case "refusal":
		// Claude refused to answer; surface as a content filter to OpenAI clients.
		return constant.FinishReasonContentFilter
	default:
		return stopReason
	}
}
// OpenAIFinishReasonToClaudeStopReason maps an OpenAI finish_reason to the
// equivalent Claude stop_reason. Matching is case-insensitive; values with
// no known mapping are passed through unchanged.
func OpenAIFinishReasonToClaudeStopReason(finishReason string) string {
	switch strings.ToLower(finishReason) {
	case "tool_calls":
		return "tool_use"
	case "length", "max_tokens":
		// Both spellings mean the token limit was hit.
		return "max_tokens"
	case constant.FinishReasonContentFilter:
		// Content filtering maps to Claude's refusal stop reason.
		return "refusal"
	case "stop":
		return "end_turn"
	case "stop_sequence":
		// Already a valid Claude stop reason; keep as-is.
		return "stop_sequence"
	default:
		return finishReason
	}
}