Compare commits

..

1 Commits

Author SHA1 Message Date
coderabbitai[bot]
40a3e19a78 📝 Add docstrings to fix/channel-test-responses-fallback
Docstrings generation was requested by @FlowerRealm.

* https://github.com/QuantumNous/new-api/pull/2501#issuecomment-3686382220

The following files were modified:

* `controller/channel-test.go`
* `relay/helper/valid_request.go`
* `service/error.go`
2025-12-23 11:56:30 +00:00
9 changed files with 93 additions and 53 deletions

View File

@@ -40,6 +40,13 @@ type testResult struct {
newAPIError *types.NewAPIError
}
// testChannel executes a test request against the given channel using the provided testModel and optional endpointType,
// and returns a testResult containing the test context and any encountered error information.
// It selects or derives a model when testModel is empty, auto-detects the request endpoint (chat, responses, embeddings, images, rerank) when endpointType is not specified,
// converts and relays the request to the upstream adapter, and parses the upstream response to collect usage and pricing information.
// On upstream responses that indicate the chat/completions `messages` parameter is unsupported and endpointType was not specified, it will retry the test using the Responses API.
// The function records consumption logs and returns a testResult with a populated context on success, or with localErr/newAPIError set on failure;
// for channel types that are not supported for testing it returns a localErr explaining that the channel test is not supported.
func testChannel(channel *model.Channel, testModel string, endpointType string) testResult {
tik := time.Now()
var unsupportedTestChannelTypes = []int{
@@ -75,6 +82,8 @@ func testChannel(channel *model.Channel, testModel string, endpointType string)
}
}
originTestModel := testModel
requestPath := "/v1/chat/completions"
// 如果指定了端点类型,使用指定的端点类型
@@ -84,6 +93,10 @@ func testChannel(channel *model.Channel, testModel string, endpointType string)
}
} else {
// 如果没有指定端点类型,使用原有的自动检测逻辑
if common.IsOpenAIResponseOnlyModel(testModel) {
requestPath = "/v1/responses"
}
// 先判断是否为 Embedding 模型
if strings.Contains(strings.ToLower(testModel), "embedding") ||
strings.HasPrefix(testModel, "m3e") || // m3e 系列模型
@@ -319,6 +332,13 @@ func testChannel(channel *model.Channel, testModel string, endpointType string)
httpResp = resp.(*http.Response)
if httpResp.StatusCode != http.StatusOK {
err := service.RelayErrorHandler(c.Request.Context(), httpResp, true)
// 自动检测模式下,如果上游不支持 chat.completions 的 messages 参数,尝试切换到 Responses API 再测一次。
if endpointType == "" && requestPath == "/v1/chat/completions" && err != nil {
lowerErr := strings.ToLower(err.Error())
if strings.Contains(lowerErr, "unsupported parameter") && strings.Contains(lowerErr, "messages") {
return testChannel(channel, originTestModel, string(constant.EndpointTypeOpenAIResponse))
}
}
return testResult{
context: c,
localErr: err,
@@ -389,6 +409,7 @@ func testChannel(channel *model.Channel, testModel string, endpointType string)
}
}
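// buildTestRequest builds the test request for the given model: a request matching endpointType when one is specified, or, in auto-detection mode, an embedding request
// for embedding models, and otherwise a chat/completion request with model-specific token limit heuristics.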
func buildTestRequest(model string, endpointType string) dto.Request {
// 根据端点类型构建不同的测试请求
if endpointType != "" {
@@ -417,9 +438,12 @@ func buildTestRequest(model string, endpointType string) dto.Request {
}
case constant.EndpointTypeOpenAIResponse:
// 返回 OpenAIResponsesRequest
maxOutputTokens := uint(10)
return &dto.OpenAIResponsesRequest{
Model: model,
Input: json.RawMessage("\"hi\""),
Model: model,
Input: json.RawMessage(`[{"role":"user","content":"hi"}]`),
MaxOutputTokens: maxOutputTokens,
Stream: true,
}
case constant.EndpointTypeAnthropic, constant.EndpointTypeGemini, constant.EndpointTypeOpenAI:
// 返回 GeneralOpenAIRequest
@@ -442,6 +466,16 @@ func buildTestRequest(model string, endpointType string) dto.Request {
}
// 自动检测逻辑(保持原有行为)
if common.IsOpenAIResponseOnlyModel(model) {
maxOutputTokens := uint(10)
return &dto.OpenAIResponsesRequest{
Model: model,
Input: json.RawMessage(`[{"role":"user","content":"hi"}]`),
MaxOutputTokens: maxOutputTokens,
Stream: true,
}
}
// 先判断是否为 Embedding 模型
if strings.Contains(strings.ToLower(model), "embedding") ||
strings.HasPrefix(model, "m3e") ||
@@ -640,4 +674,4 @@ func AutomaticallyTestChannels() {
}
}
})
}
}

View File

@@ -110,17 +110,18 @@ func setupLogin(user *model.User, c *gin.Context) {
})
return
}
cleanUser := model.User{
Id: user.Id,
Username: user.Username,
DisplayName: user.DisplayName,
Role: user.Role,
Status: user.Status,
Group: user.Group,
}
c.JSON(http.StatusOK, gin.H{
"message": "",
"success": true,
"data": map[string]any{
"id": user.Id,
"username": user.Username,
"display_name": user.DisplayName,
"role": user.Role,
"status": user.Status,
"group": user.Group,
},
"data": cleanUser,
})
}

View File

@@ -483,11 +483,9 @@ func StreamResponseClaude2OpenAI(reqMode int, claudeResponse *dto.ClaudeResponse
}
}
} else if claudeResponse.Type == "message_delta" {
if claudeResponse.Delta != nil && claudeResponse.Delta.StopReason != nil {
finishReason := stopReasonClaude2OpenAI(*claudeResponse.Delta.StopReason)
if finishReason != "null" {
choice.FinishReason = &finishReason
}
finishReason := stopReasonClaude2OpenAI(*claudeResponse.Delta.StopReason)
if finishReason != "null" {
choice.FinishReason = &finishReason
}
//claudeUsage = &claudeResponse.Usage
} else if claudeResponse.Type == "message_stop" {

View File

@@ -596,7 +596,7 @@ func applyUsagePostProcessing(info *relaycommon.RelayInfo, usage *dto.Usage, res
if usage.PromptTokensDetails.CachedTokens == 0 && usage.PromptCacheHitTokens != 0 {
usage.PromptTokensDetails.CachedTokens = usage.PromptCacheHitTokens
}
case constant.ChannelTypeZhipu_v4, constant.ChannelTypeMoonshot:
case constant.ChannelTypeZhipu_v4:
if usage.PromptTokensDetails.CachedTokens == 0 {
if usage.InputTokensDetails != nil && usage.InputTokensDetails.CachedTokens > 0 {
usage.PromptTokensDetails.CachedTokens = usage.InputTokensDetails.CachedTokens

View File

@@ -300,20 +300,14 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage
if !relayInfo.PriceData.UsePrice {
baseTokens := dPromptTokens
// 减去 cached tokens
// Anthropic API 的 input_tokens 已经不包含缓存 tokens，不需要减去
// OpenAI/OpenRouter 等 API 的 prompt_tokens 包含缓存 tokens，需要减去
var cachedTokensWithRatio decimal.Decimal
if !dCacheTokens.IsZero() {
if relayInfo.ChannelType != constant.ChannelTypeAnthropic {
baseTokens = baseTokens.Sub(dCacheTokens)
}
baseTokens = baseTokens.Sub(dCacheTokens)
cachedTokensWithRatio = dCacheTokens.Mul(dCacheRatio)
}
var dCachedCreationTokensWithRatio decimal.Decimal
if !dCachedCreationTokens.IsZero() {
if relayInfo.ChannelType != constant.ChannelTypeAnthropic {
baseTokens = baseTokens.Sub(dCachedCreationTokens)
}
baseTokens = baseTokens.Sub(dCachedCreationTokens)
dCachedCreationTokensWithRatio = dCachedCreationTokens.Mul(dCachedCreationRatio)
}

View File

@@ -110,6 +110,8 @@ func GetAndValidateEmbeddingRequest(c *gin.Context, relayMode int) (*dto.Embeddi
return embeddingRequest, nil
}
// GetAndValidateResponsesRequest parses the HTTP request body into an OpenAIResponsesRequest and ensures the Model field is provided.
// It returns the parsed request, or an error if the body cannot be parsed or the Model is empty.
func GetAndValidateResponsesRequest(c *gin.Context) (*dto.OpenAIResponsesRequest, error) {
request := &dto.OpenAIResponsesRequest{}
err := common.UnmarshalBodyReusable(c, request)
@@ -119,9 +121,6 @@ func GetAndValidateResponsesRequest(c *gin.Context) (*dto.OpenAIResponsesRequest
if request.Model == "" {
return nil, errors.New("model is required")
}
if request.Input == nil {
return nil, errors.New("input is required")
}
return request, nil
}
@@ -324,4 +323,4 @@ func GetAndValidateGeminiBatchEmbeddingRequest(c *gin.Context) (*dto.GeminiBatch
return nil, err
}
return request, nil
}
}

View File

@@ -389,29 +389,25 @@ func StreamResponseOpenAI2Claude(openAIResponse *dto.ChatCompletionsStreamRespon
}
idx := blockIndex
if toolCall.Function.Name != "" {
claudeResponses = append(claudeResponses, &dto.ClaudeResponse{
Index: &idx,
Type: "content_block_start",
ContentBlock: &dto.ClaudeMediaMessage{
Id: toolCall.ID,
Type: "tool_use",
Name: toolCall.Function.Name,
Input: map[string]interface{}{},
},
})
}
if len(toolCall.Function.Arguments) > 0 {
claudeResponses = append(claudeResponses, &dto.ClaudeResponse{
Index: &idx,
Type: "content_block_delta",
Delta: &dto.ClaudeMediaMessage{
Type: "input_json_delta",
PartialJson: &toolCall.Function.Arguments,
},
})
}
claudeResponses = append(claudeResponses, &dto.ClaudeResponse{
Index: &idx,
Type: "content_block_start",
ContentBlock: &dto.ClaudeMediaMessage{
Id: toolCall.ID,
Type: "tool_use",
Name: toolCall.Function.Name,
Input: map[string]interface{}{},
},
})
claudeResponses = append(claudeResponses, &dto.ClaudeResponse{
Index: &idx,
Type: "content_block_delta",
Delta: &dto.ClaudeMediaMessage{
Type: "input_json_delta",
PartialJson: &toolCall.Function.Arguments,
},
})
info.ClaudeConvertInfo.Index = blockIndex
}

View File

@@ -81,11 +81,24 @@ func ClaudeErrorWrapperLocal(err error, code string, statusCode int) *dto.Claude
return claudeErr
}
// RelayErrorHandler converts an HTTP error response into a structured types.NewAPIError.
// It returns a NewAPIError initialized with the response status code and one of:
// - an Err describing an absent or unreadable body,
// - an Err containing the unmarshaled error message (or status + raw body when showBodyWhenFail is true), or
// - an embedded OpenAI-style error when the response body contains a compatible error object.
// The returned NewAPIError's status code reflects resp.StatusCode.
func RelayErrorHandler(ctx context.Context, resp *http.Response, showBodyWhenFail bool) (newApiErr *types.NewAPIError) {
newApiErr = types.InitOpenAIError(types.ErrorCodeBadResponseStatusCode, resp.StatusCode)
if resp.Body == nil {
newApiErr.Err = errors.New("response body is nil")
return
}
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
CloseResponseBodyGracefully(resp)
newApiErr.Err = fmt.Errorf("read response body failed: %w", err)
return
}
CloseResponseBodyGracefully(resp)
@@ -156,4 +169,4 @@ func TaskErrorWrapper(err error, code string, statusCode int) *dto.TaskError {
}
return taskError
}
}

View File

@@ -7,6 +7,7 @@ import (
"github.com/QuantumNous/new-api/common"
"github.com/QuantumNous/new-api/setting/operation_setting"
"github.com/QuantumNous/new-api/setting/reasoning"
)
// from songquanpeng/one-api
@@ -828,6 +829,10 @@ func FormatMatchingModelName(name string) string {
name = handleThinkingBudgetModel(name, "gemini-2.5-pro", "gemini-2.5-pro-thinking-*")
}
if base, _, ok := reasoning.TrimEffortSuffix(name); ok {
name = base
}
if strings.HasPrefix(name, "gpt-4-gizmo") {
name = "gpt-4-gizmo-*"
}