mirror of
https://github.com/QuantumNous/new-api.git
synced 2026-03-30 02:25:00 +00:00
Merge pull request #2742 from seefs001/fix/pr-2540
feat(gemini): 支持 tool_choice 参数转换,优化多个渠道错误处理 (support tool_choice parameter conversion for Gemini; improve error handling across multiple channels)
This commit is contained in:
@@ -8,11 +8,13 @@ import (
|
||||
"strings"
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
"github.com/QuantumNous/new-api/constant"
|
||||
"github.com/QuantumNous/new-api/dto"
|
||||
"github.com/QuantumNous/new-api/logger"
|
||||
"github.com/QuantumNous/new-api/relay/channel/openrouter"
|
||||
relaycommon "github.com/QuantumNous/new-api/relay/common"
|
||||
"github.com/QuantumNous/new-api/relay/helper"
|
||||
"github.com/QuantumNous/new-api/relay/reasonmap"
|
||||
"github.com/QuantumNous/new-api/service"
|
||||
"github.com/QuantumNous/new-api/setting/model_setting"
|
||||
"github.com/QuantumNous/new-api/types"
|
||||
@@ -27,17 +29,15 @@ const (
|
||||
)
|
||||
|
||||
func stopReasonClaude2OpenAI(reason string) string {
|
||||
switch reason {
|
||||
case "stop_sequence":
|
||||
return "stop"
|
||||
case "end_turn":
|
||||
return "stop"
|
||||
case "max_tokens":
|
||||
return "length"
|
||||
case "tool_use":
|
||||
return "tool_calls"
|
||||
default:
|
||||
return reason
|
||||
return reasonmap.ClaudeStopReasonToOpenAIFinishReason(reason)
|
||||
}
|
||||
|
||||
func maybeMarkClaudeRefusal(c *gin.Context, stopReason string) {
|
||||
if c == nil {
|
||||
return
|
||||
}
|
||||
if strings.EqualFold(stopReason, "refusal") {
|
||||
common.SetContextKey(c, constant.ContextKeyAdminRejectReason, "claude_stop_reason=refusal")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -644,6 +644,12 @@ func HandleStreamResponseData(c *gin.Context, info *relaycommon.RelayInfo, claud
|
||||
if claudeError := claudeResponse.GetClaudeError(); claudeError != nil && claudeError.Type != "" {
|
||||
return types.WithClaudeError(*claudeError, http.StatusInternalServerError)
|
||||
}
|
||||
if claudeResponse.StopReason != "" {
|
||||
maybeMarkClaudeRefusal(c, claudeResponse.StopReason)
|
||||
}
|
||||
if claudeResponse.Delta != nil && claudeResponse.Delta.StopReason != nil {
|
||||
maybeMarkClaudeRefusal(c, *claudeResponse.Delta.StopReason)
|
||||
}
|
||||
if info.RelayFormat == types.RelayFormatClaude {
|
||||
FormatClaudeResponseInfo(requestMode, &claudeResponse, nil, claudeInfo)
|
||||
|
||||
@@ -735,6 +741,7 @@ func HandleClaudeResponseData(c *gin.Context, info *relaycommon.RelayInfo, claud
|
||||
if claudeError := claudeResponse.GetClaudeError(); claudeError != nil && claudeError.Type != "" {
|
||||
return types.WithClaudeError(*claudeError, http.StatusInternalServerError)
|
||||
}
|
||||
maybeMarkClaudeRefusal(c, claudeResponse.StopReason)
|
||||
if requestMode == RequestModeCompletion {
|
||||
claudeInfo.Usage = service.ResponseText2Usage(c, claudeResponse.Completion, info.UpstreamModelName, info.GetEstimatePromptTokens())
|
||||
} else {
|
||||
|
||||
@@ -1,10 +1,12 @@
|
||||
package gemini
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
"github.com/QuantumNous/new-api/constant"
|
||||
"github.com/QuantumNous/new-api/dto"
|
||||
"github.com/QuantumNous/new-api/logger"
|
||||
relaycommon "github.com/QuantumNous/new-api/relay/common"
|
||||
@@ -35,6 +37,10 @@ func GeminiTextGenerationHandler(c *gin.Context, info *relaycommon.RelayInfo, re
|
||||
return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
if len(geminiResponse.Candidates) == 0 && geminiResponse.PromptFeedback != nil && geminiResponse.PromptFeedback.BlockReason != nil {
|
||||
common.SetContextKey(c, constant.ContextKeyAdminRejectReason, fmt.Sprintf("gemini_block_reason=%s", *geminiResponse.PromptFeedback.BlockReason))
|
||||
}
|
||||
|
||||
// 计算使用量(基于 UsageMetadata)
|
||||
usage := dto.Usage{
|
||||
PromptTokens: geminiResponse.UsageMetadata.PromptTokenCount,
|
||||
|
||||
@@ -359,6 +359,13 @@ func CovertOpenAI2Gemini(c *gin.Context, textRequest dto.GeneralOpenAIRequest, i
|
||||
})
|
||||
}
|
||||
geminiRequest.SetTools(geminiTools)
|
||||
|
||||
// [NEW] Convert OpenAI tool_choice to Gemini toolConfig.functionCallingConfig
|
||||
// Mapping: "auto" -> "AUTO", "none" -> "NONE", "required" -> "ANY"
|
||||
// Object format: {"type": "function", "function": {"name": "xxx"}} -> "ANY" + allowedFunctionNames
|
||||
if textRequest.ToolChoice != nil {
|
||||
geminiRequest.ToolConfig = convertToolChoiceToGeminiConfig(textRequest.ToolChoice)
|
||||
}
|
||||
}
|
||||
|
||||
if textRequest.ResponseFormat != nil && (textRequest.ResponseFormat.Type == "json_schema" || textRequest.ResponseFormat.Type == "json_object") {
|
||||
@@ -1031,6 +1038,24 @@ func responseGeminiChat2OpenAI(c *gin.Context, response *dto.GeminiChatResponse)
|
||||
choice.FinishReason = constant.FinishReasonStop
|
||||
case "MAX_TOKENS":
|
||||
choice.FinishReason = constant.FinishReasonLength
|
||||
case "SAFETY":
|
||||
// Safety filter triggered
|
||||
choice.FinishReason = constant.FinishReasonContentFilter
|
||||
case "RECITATION":
|
||||
// Recitation (citation) detected
|
||||
choice.FinishReason = constant.FinishReasonContentFilter
|
||||
case "BLOCKLIST":
|
||||
// Blocklist triggered
|
||||
choice.FinishReason = constant.FinishReasonContentFilter
|
||||
case "PROHIBITED_CONTENT":
|
||||
// Prohibited content detected
|
||||
choice.FinishReason = constant.FinishReasonContentFilter
|
||||
case "SPII":
|
||||
// Sensitive personally identifiable information
|
||||
choice.FinishReason = constant.FinishReasonContentFilter
|
||||
case "OTHER":
|
||||
// Other reasons
|
||||
choice.FinishReason = constant.FinishReasonContentFilter
|
||||
default:
|
||||
choice.FinishReason = constant.FinishReasonContentFilter
|
||||
}
|
||||
@@ -1062,13 +1087,34 @@ func streamResponseGeminiChat2OpenAI(geminiResponse *dto.GeminiChatResponse) (*d
|
||||
isTools := false
|
||||
isThought := false
|
||||
if candidate.FinishReason != nil {
|
||||
// p := GeminiConvertFinishReason(*candidate.FinishReason)
|
||||
// Map Gemini FinishReason to OpenAI finish_reason
|
||||
switch *candidate.FinishReason {
|
||||
case "STOP":
|
||||
// Normal completion
|
||||
choice.FinishReason = &constant.FinishReasonStop
|
||||
case "MAX_TOKENS":
|
||||
// Reached maximum token limit
|
||||
choice.FinishReason = &constant.FinishReasonLength
|
||||
case "SAFETY":
|
||||
// Safety filter triggered
|
||||
choice.FinishReason = &constant.FinishReasonContentFilter
|
||||
case "RECITATION":
|
||||
// Recitation (citation) detected
|
||||
choice.FinishReason = &constant.FinishReasonContentFilter
|
||||
case "BLOCKLIST":
|
||||
// Blocklist triggered
|
||||
choice.FinishReason = &constant.FinishReasonContentFilter
|
||||
case "PROHIBITED_CONTENT":
|
||||
// Prohibited content detected
|
||||
choice.FinishReason = &constant.FinishReasonContentFilter
|
||||
case "SPII":
|
||||
// Sensitive personally identifiable information
|
||||
choice.FinishReason = &constant.FinishReasonContentFilter
|
||||
case "OTHER":
|
||||
// Other reasons
|
||||
choice.FinishReason = &constant.FinishReasonContentFilter
|
||||
default:
|
||||
// Unknown reason, treat as content filter
|
||||
choice.FinishReason = &constant.FinishReasonContentFilter
|
||||
}
|
||||
}
|
||||
@@ -1151,6 +1197,10 @@ func geminiStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http
|
||||
return false
|
||||
}
|
||||
|
||||
if len(geminiResponse.Candidates) == 0 && geminiResponse.PromptFeedback != nil && geminiResponse.PromptFeedback.BlockReason != nil {
|
||||
common.SetContextKey(c, constant.ContextKeyAdminRejectReason, fmt.Sprintf("gemini_block_reason=%s", *geminiResponse.PromptFeedback.BlockReason))
|
||||
}
|
||||
|
||||
// 统计图片数量
|
||||
for _, candidate := range geminiResponse.Candidates {
|
||||
for _, part := range candidate.Content.Parts {
|
||||
@@ -1309,12 +1359,52 @@ func GeminiChatHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.R
|
||||
return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
|
||||
}
|
||||
if len(geminiResponse.Candidates) == 0 {
|
||||
//return nil, types.NewOpenAIError(errors.New("no candidates returned"), types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
|
||||
//if geminiResponse.PromptFeedback != nil && geminiResponse.PromptFeedback.BlockReason != nil {
|
||||
// return nil, types.NewOpenAIError(errors.New("request blocked by Gemini API: "+*geminiResponse.PromptFeedback.BlockReason), types.ErrorCodePromptBlocked, http.StatusBadRequest)
|
||||
//} else {
|
||||
// return nil, types.NewOpenAIError(errors.New("empty response from Gemini API"), types.ErrorCodeEmptyResponse, http.StatusInternalServerError)
|
||||
//}
|
||||
usage := dto.Usage{
|
||||
PromptTokens: geminiResponse.UsageMetadata.PromptTokenCount,
|
||||
}
|
||||
usage.CompletionTokenDetails.ReasoningTokens = geminiResponse.UsageMetadata.ThoughtsTokenCount
|
||||
for _, detail := range geminiResponse.UsageMetadata.PromptTokensDetails {
|
||||
if detail.Modality == "AUDIO" {
|
||||
usage.PromptTokensDetails.AudioTokens = detail.TokenCount
|
||||
} else if detail.Modality == "TEXT" {
|
||||
usage.PromptTokensDetails.TextTokens = detail.TokenCount
|
||||
}
|
||||
}
|
||||
if usage.PromptTokens <= 0 {
|
||||
usage.PromptTokens = info.GetEstimatePromptTokens()
|
||||
}
|
||||
|
||||
var newAPIError *types.NewAPIError
|
||||
if geminiResponse.PromptFeedback != nil && geminiResponse.PromptFeedback.BlockReason != nil {
|
||||
common.SetContextKey(c, constant.ContextKeyAdminRejectReason, fmt.Sprintf("gemini_block_reason=%s", *geminiResponse.PromptFeedback.BlockReason))
|
||||
newAPIError = types.NewOpenAIError(
|
||||
errors.New("request blocked by Gemini API: "+*geminiResponse.PromptFeedback.BlockReason),
|
||||
types.ErrorCodePromptBlocked,
|
||||
http.StatusBadRequest,
|
||||
)
|
||||
} else {
|
||||
common.SetContextKey(c, constant.ContextKeyAdminRejectReason, "gemini_empty_candidates")
|
||||
newAPIError = types.NewOpenAIError(
|
||||
errors.New("empty response from Gemini API"),
|
||||
types.ErrorCodeEmptyResponse,
|
||||
http.StatusInternalServerError,
|
||||
)
|
||||
}
|
||||
|
||||
service.ResetStatusCode(newAPIError, c.GetString("status_code_mapping"))
|
||||
|
||||
switch info.RelayFormat {
|
||||
case types.RelayFormatClaude:
|
||||
c.JSON(newAPIError.StatusCode, gin.H{
|
||||
"type": "error",
|
||||
"error": newAPIError.ToClaudeError(),
|
||||
})
|
||||
default:
|
||||
c.JSON(newAPIError.StatusCode, gin.H{
|
||||
"error": newAPIError.ToOpenAIError(),
|
||||
})
|
||||
}
|
||||
return &usage, nil
|
||||
}
|
||||
fullTextResponse := responseGeminiChat2OpenAI(c, &geminiResponse)
|
||||
fullTextResponse.Model = info.UpstreamModelName
|
||||
@@ -1530,3 +1620,62 @@ func FetchGeminiModels(baseURL, apiKey, proxyURL string) ([]string, error) {
|
||||
|
||||
return allModels, nil
|
||||
}
|
||||
|
||||
// convertToolChoiceToGeminiConfig converts OpenAI tool_choice to Gemini toolConfig
|
||||
// OpenAI tool_choice values:
|
||||
// - "auto": Let the model decide (default)
|
||||
// - "none": Don't call any tools
|
||||
// - "required": Must call at least one tool
|
||||
// - {"type": "function", "function": {"name": "xxx"}}: Call specific function
|
||||
//
|
||||
// Gemini functionCallingConfig.mode values:
|
||||
// - "AUTO": Model decides whether to call functions
|
||||
// - "NONE": Model won't call functions
|
||||
// - "ANY": Model must call at least one function
|
||||
func convertToolChoiceToGeminiConfig(toolChoice any) *dto.ToolConfig {
|
||||
if toolChoice == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Handle string values: "auto", "none", "required"
|
||||
if toolChoiceStr, ok := toolChoice.(string); ok {
|
||||
config := &dto.ToolConfig{
|
||||
FunctionCallingConfig: &dto.FunctionCallingConfig{},
|
||||
}
|
||||
switch toolChoiceStr {
|
||||
case "auto":
|
||||
config.FunctionCallingConfig.Mode = "AUTO"
|
||||
case "none":
|
||||
config.FunctionCallingConfig.Mode = "NONE"
|
||||
case "required":
|
||||
config.FunctionCallingConfig.Mode = "ANY"
|
||||
default:
|
||||
// Unknown string value, default to AUTO
|
||||
config.FunctionCallingConfig.Mode = "AUTO"
|
||||
}
|
||||
return config
|
||||
}
|
||||
|
||||
// Handle object value: {"type": "function", "function": {"name": "xxx"}}
|
||||
if toolChoiceMap, ok := toolChoice.(map[string]interface{}); ok {
|
||||
if toolChoiceMap["type"] == "function" {
|
||||
config := &dto.ToolConfig{
|
||||
FunctionCallingConfig: &dto.FunctionCallingConfig{
|
||||
Mode: "ANY",
|
||||
},
|
||||
}
|
||||
// Extract function name if specified
|
||||
if function, ok := toolChoiceMap["function"].(map[string]interface{}); ok {
|
||||
if name, ok := function["name"].(string); ok && name != "" {
|
||||
config.FunctionCallingConfig.AllowedFunctionNames = []string{name}
|
||||
}
|
||||
}
|
||||
return config
|
||||
}
|
||||
// Unsupported map structure (type is not "function"), return nil
|
||||
return nil
|
||||
}
|
||||
|
||||
// Unsupported type, return nil
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -229,6 +229,13 @@ func OpenaiHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Respo
|
||||
return nil, types.WithOpenAIError(*oaiError, resp.StatusCode)
|
||||
}
|
||||
|
||||
for _, choice := range simpleResponse.Choices {
|
||||
if choice.FinishReason == constant.FinishReasonContentFilter {
|
||||
common.SetContextKey(c, constant.ContextKeyAdminRejectReason, "openai_finish_reason=content_filter")
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
forceFormat := false
|
||||
if info.ChannelSetting.ForceFormat {
|
||||
forceFormat = true
|
||||
|
||||
@@ -237,6 +237,9 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage
|
||||
}
|
||||
extraContent = append(extraContent, "上游无计费信息")
|
||||
}
|
||||
|
||||
adminRejectReason := common.GetContextKeyString(ctx, constant.ContextKeyAdminRejectReason)
|
||||
|
||||
useTimeSeconds := time.Now().Unix() - relayInfo.StartTime.Unix()
|
||||
promptTokens := usage.PromptTokens
|
||||
cacheTokens := usage.PromptTokensDetails.CachedTokens
|
||||
@@ -461,6 +464,9 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage
|
||||
}
|
||||
logContent := strings.Join(extraContent, ", ")
|
||||
other := service.GenerateTextOtherInfo(ctx, relayInfo, modelRatio, groupRatio, completionRatio, cacheTokens, cacheRatio, modelPrice, relayInfo.PriceData.GroupRatioInfo.GroupSpecialRatio)
|
||||
if adminRejectReason != "" {
|
||||
other["reject_reason"] = adminRejectReason
|
||||
}
|
||||
// For chat-based calls to the Claude model, tagging is required. Using Claude's rendering logs, the two approaches handle input rendering differently.
|
||||
if isClaudeUsageSemantic {
|
||||
other["claude"] = true
|
||||
|
||||
41
relay/reasonmap/reasonmap.go
Normal file
41
relay/reasonmap/reasonmap.go
Normal file
@@ -0,0 +1,41 @@
|
||||
package reasonmap
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"github.com/QuantumNous/new-api/constant"
|
||||
)
|
||||
|
||||
func ClaudeStopReasonToOpenAIFinishReason(stopReason string) string {
|
||||
switch strings.ToLower(stopReason) {
|
||||
case "stop_sequence":
|
||||
return "stop"
|
||||
case "end_turn":
|
||||
return "stop"
|
||||
case "max_tokens":
|
||||
return "length"
|
||||
case "tool_use":
|
||||
return "tool_calls"
|
||||
case "refusal":
|
||||
return constant.FinishReasonContentFilter
|
||||
default:
|
||||
return stopReason
|
||||
}
|
||||
}
|
||||
|
||||
func OpenAIFinishReasonToClaudeStopReason(finishReason string) string {
|
||||
switch strings.ToLower(finishReason) {
|
||||
case "stop":
|
||||
return "end_turn"
|
||||
case "stop_sequence":
|
||||
return "stop_sequence"
|
||||
case "length", "max_tokens":
|
||||
return "max_tokens"
|
||||
case constant.FinishReasonContentFilter:
|
||||
return "refusal"
|
||||
case "tool_calls":
|
||||
return "tool_use"
|
||||
default:
|
||||
return finishReason
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user