mirror of
https://github.com/QuantumNous/new-api.git
synced 2026-03-30 07:57:03 +00:00
* fix: use the correct request path when testing rerank models. * fix: when testing the Responses endpoint, send the `input` parameter as an array, because some channels (such as Codex) do not accept a single string.
703 lines
21 KiB
Go
703 lines
21 KiB
Go
package controller
|
|
|
|
import (
|
|
"bytes"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"math"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"net/url"
|
|
"strconv"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/QuantumNous/new-api/common"
|
|
"github.com/QuantumNous/new-api/constant"
|
|
"github.com/QuantumNous/new-api/dto"
|
|
"github.com/QuantumNous/new-api/middleware"
|
|
"github.com/QuantumNous/new-api/model"
|
|
"github.com/QuantumNous/new-api/relay"
|
|
relaycommon "github.com/QuantumNous/new-api/relay/common"
|
|
relayconstant "github.com/QuantumNous/new-api/relay/constant"
|
|
"github.com/QuantumNous/new-api/relay/helper"
|
|
"github.com/QuantumNous/new-api/service"
|
|
"github.com/QuantumNous/new-api/setting/operation_setting"
|
|
"github.com/QuantumNous/new-api/types"
|
|
|
|
"github.com/bytedance/gopkg/util/gopool"
|
|
"github.com/samber/lo"
|
|
|
|
"github.com/gin-gonic/gin"
|
|
)
|
|
|
|
type testResult struct {
|
|
context *gin.Context
|
|
localErr error
|
|
newAPIError *types.NewAPIError
|
|
}
|
|
|
|
func testChannel(channel *model.Channel, testModel string, endpointType string) testResult {
|
|
tik := time.Now()
|
|
var unsupportedTestChannelTypes = []int{
|
|
constant.ChannelTypeMidjourney,
|
|
constant.ChannelTypeMidjourneyPlus,
|
|
constant.ChannelTypeSunoAPI,
|
|
constant.ChannelTypeKling,
|
|
constant.ChannelTypeJimeng,
|
|
constant.ChannelTypeDoubaoVideo,
|
|
constant.ChannelTypeVidu,
|
|
}
|
|
if lo.Contains(unsupportedTestChannelTypes, channel.Type) {
|
|
channelTypeName := constant.GetChannelTypeName(channel.Type)
|
|
return testResult{
|
|
localErr: fmt.Errorf("%s channel test is not supported", channelTypeName),
|
|
}
|
|
}
|
|
w := httptest.NewRecorder()
|
|
c, _ := gin.CreateTestContext(w)
|
|
|
|
testModel = strings.TrimSpace(testModel)
|
|
if testModel == "" {
|
|
if channel.TestModel != nil && *channel.TestModel != "" {
|
|
testModel = strings.TrimSpace(*channel.TestModel)
|
|
} else {
|
|
models := channel.GetModels()
|
|
if len(models) > 0 {
|
|
testModel = strings.TrimSpace(models[0])
|
|
}
|
|
if testModel == "" {
|
|
testModel = "gpt-4o-mini"
|
|
}
|
|
}
|
|
}
|
|
|
|
requestPath := "/v1/chat/completions"
|
|
|
|
// 如果指定了端点类型,使用指定的端点类型
|
|
if endpointType != "" {
|
|
if endpointInfo, ok := common.GetDefaultEndpointInfo(constant.EndpointType(endpointType)); ok {
|
|
requestPath = endpointInfo.Path
|
|
}
|
|
} else {
|
|
// 如果没有指定端点类型,使用原有的自动检测逻辑
|
|
|
|
if strings.Contains(strings.ToLower(testModel), "rerank") {
|
|
requestPath = "/v1/rerank"
|
|
}
|
|
|
|
// 先判断是否为 Embedding 模型
|
|
if strings.Contains(strings.ToLower(testModel), "embedding") ||
|
|
strings.HasPrefix(testModel, "m3e") || // m3e 系列模型
|
|
strings.Contains(testModel, "bge-") || // bge 系列模型
|
|
strings.Contains(testModel, "embed") ||
|
|
channel.Type == constant.ChannelTypeMokaAI { // 其他 embedding 模型
|
|
requestPath = "/v1/embeddings" // 修改请求路径
|
|
}
|
|
|
|
// VolcEngine 图像生成模型
|
|
if channel.Type == constant.ChannelTypeVolcEngine && strings.Contains(testModel, "seedream") {
|
|
requestPath = "/v1/images/generations"
|
|
}
|
|
|
|
// responses-only models
|
|
if strings.Contains(strings.ToLower(testModel), "codex") {
|
|
requestPath = "/v1/responses"
|
|
}
|
|
}
|
|
|
|
c.Request = &http.Request{
|
|
Method: "POST",
|
|
URL: &url.URL{Path: requestPath}, // 使用动态路径
|
|
Body: nil,
|
|
Header: make(http.Header),
|
|
}
|
|
|
|
cache, err := model.GetUserCache(1)
|
|
if err != nil {
|
|
return testResult{
|
|
localErr: err,
|
|
newAPIError: nil,
|
|
}
|
|
}
|
|
cache.WriteContext(c)
|
|
|
|
//c.Request.Header.Set("Authorization", "Bearer "+channel.Key)
|
|
c.Request.Header.Set("Content-Type", "application/json")
|
|
c.Set("channel", channel.Type)
|
|
c.Set("base_url", channel.GetBaseURL())
|
|
group, _ := model.GetUserGroup(1, false)
|
|
c.Set("group", group)
|
|
|
|
newAPIError := middleware.SetupContextForSelectedChannel(c, channel, testModel)
|
|
if newAPIError != nil {
|
|
return testResult{
|
|
context: c,
|
|
localErr: newAPIError,
|
|
newAPIError: newAPIError,
|
|
}
|
|
}
|
|
|
|
// Determine relay format based on endpoint type or request path
|
|
var relayFormat types.RelayFormat
|
|
if endpointType != "" {
|
|
// 根据指定的端点类型设置 relayFormat
|
|
switch constant.EndpointType(endpointType) {
|
|
case constant.EndpointTypeOpenAI:
|
|
relayFormat = types.RelayFormatOpenAI
|
|
case constant.EndpointTypeOpenAIResponse:
|
|
relayFormat = types.RelayFormatOpenAIResponses
|
|
case constant.EndpointTypeAnthropic:
|
|
relayFormat = types.RelayFormatClaude
|
|
case constant.EndpointTypeGemini:
|
|
relayFormat = types.RelayFormatGemini
|
|
case constant.EndpointTypeJinaRerank:
|
|
relayFormat = types.RelayFormatRerank
|
|
case constant.EndpointTypeImageGeneration:
|
|
relayFormat = types.RelayFormatOpenAIImage
|
|
case constant.EndpointTypeEmbeddings:
|
|
relayFormat = types.RelayFormatEmbedding
|
|
default:
|
|
relayFormat = types.RelayFormatOpenAI
|
|
}
|
|
} else {
|
|
// 根据请求路径自动检测
|
|
relayFormat = types.RelayFormatOpenAI
|
|
if c.Request.URL.Path == "/v1/embeddings" {
|
|
relayFormat = types.RelayFormatEmbedding
|
|
}
|
|
if c.Request.URL.Path == "/v1/images/generations" {
|
|
relayFormat = types.RelayFormatOpenAIImage
|
|
}
|
|
if c.Request.URL.Path == "/v1/messages" {
|
|
relayFormat = types.RelayFormatClaude
|
|
}
|
|
if strings.Contains(c.Request.URL.Path, "/v1beta/models") {
|
|
relayFormat = types.RelayFormatGemini
|
|
}
|
|
if c.Request.URL.Path == "/v1/rerank" || c.Request.URL.Path == "/rerank" {
|
|
relayFormat = types.RelayFormatRerank
|
|
}
|
|
if c.Request.URL.Path == "/v1/responses" {
|
|
relayFormat = types.RelayFormatOpenAIResponses
|
|
}
|
|
}
|
|
|
|
request := buildTestRequest(testModel, endpointType, channel)
|
|
|
|
info, err := relaycommon.GenRelayInfo(c, relayFormat, request, nil)
|
|
|
|
if err != nil {
|
|
return testResult{
|
|
context: c,
|
|
localErr: err,
|
|
newAPIError: types.NewError(err, types.ErrorCodeGenRelayInfoFailed),
|
|
}
|
|
}
|
|
|
|
info.IsChannelTest = true
|
|
info.InitChannelMeta(c)
|
|
|
|
err = helper.ModelMappedHelper(c, info, request)
|
|
if err != nil {
|
|
return testResult{
|
|
context: c,
|
|
localErr: err,
|
|
newAPIError: types.NewError(err, types.ErrorCodeChannelModelMappedError),
|
|
}
|
|
}
|
|
|
|
testModel = info.UpstreamModelName
|
|
// 更新请求中的模型名称
|
|
request.SetModelName(testModel)
|
|
|
|
apiType, _ := common.ChannelType2APIType(channel.Type)
|
|
adaptor := relay.GetAdaptor(apiType)
|
|
if adaptor == nil {
|
|
return testResult{
|
|
context: c,
|
|
localErr: fmt.Errorf("invalid api type: %d, adaptor is nil", apiType),
|
|
newAPIError: types.NewError(fmt.Errorf("invalid api type: %d, adaptor is nil", apiType), types.ErrorCodeInvalidApiType),
|
|
}
|
|
}
|
|
|
|
//// 创建一个用于日志的 info 副本,移除 ApiKey
|
|
//logInfo := info
|
|
//logInfo.ApiKey = ""
|
|
common.SysLog(fmt.Sprintf("testing channel %d with model %s , info %+v ", channel.Id, testModel, info.ToString()))
|
|
|
|
priceData, err := helper.ModelPriceHelper(c, info, 0, request.GetTokenCountMeta())
|
|
if err != nil {
|
|
return testResult{
|
|
context: c,
|
|
localErr: err,
|
|
newAPIError: types.NewError(err, types.ErrorCodeModelPriceError),
|
|
}
|
|
}
|
|
|
|
adaptor.Init(info)
|
|
|
|
var convertedRequest any
|
|
// 根据 RelayMode 选择正确的转换函数
|
|
switch info.RelayMode {
|
|
case relayconstant.RelayModeEmbeddings:
|
|
// Embedding 请求 - request 已经是正确的类型
|
|
if embeddingReq, ok := request.(*dto.EmbeddingRequest); ok {
|
|
convertedRequest, err = adaptor.ConvertEmbeddingRequest(c, info, *embeddingReq)
|
|
} else {
|
|
return testResult{
|
|
context: c,
|
|
localErr: errors.New("invalid embedding request type"),
|
|
newAPIError: types.NewError(errors.New("invalid embedding request type"), types.ErrorCodeConvertRequestFailed),
|
|
}
|
|
}
|
|
case relayconstant.RelayModeImagesGenerations:
|
|
// 图像生成请求 - request 已经是正确的类型
|
|
if imageReq, ok := request.(*dto.ImageRequest); ok {
|
|
convertedRequest, err = adaptor.ConvertImageRequest(c, info, *imageReq)
|
|
} else {
|
|
return testResult{
|
|
context: c,
|
|
localErr: errors.New("invalid image request type"),
|
|
newAPIError: types.NewError(errors.New("invalid image request type"), types.ErrorCodeConvertRequestFailed),
|
|
}
|
|
}
|
|
case relayconstant.RelayModeRerank:
|
|
// Rerank 请求 - request 已经是正确的类型
|
|
if rerankReq, ok := request.(*dto.RerankRequest); ok {
|
|
convertedRequest, err = adaptor.ConvertRerankRequest(c, info.RelayMode, *rerankReq)
|
|
} else {
|
|
return testResult{
|
|
context: c,
|
|
localErr: errors.New("invalid rerank request type"),
|
|
newAPIError: types.NewError(errors.New("invalid rerank request type"), types.ErrorCodeConvertRequestFailed),
|
|
}
|
|
}
|
|
case relayconstant.RelayModeResponses:
|
|
// Response 请求 - request 已经是正确的类型
|
|
if responseReq, ok := request.(*dto.OpenAIResponsesRequest); ok {
|
|
convertedRequest, err = adaptor.ConvertOpenAIResponsesRequest(c, info, *responseReq)
|
|
} else {
|
|
return testResult{
|
|
context: c,
|
|
localErr: errors.New("invalid response request type"),
|
|
newAPIError: types.NewError(errors.New("invalid response request type"), types.ErrorCodeConvertRequestFailed),
|
|
}
|
|
}
|
|
default:
|
|
// Chat/Completion 等其他请求类型
|
|
if generalReq, ok := request.(*dto.GeneralOpenAIRequest); ok {
|
|
convertedRequest, err = adaptor.ConvertOpenAIRequest(c, info, generalReq)
|
|
} else {
|
|
return testResult{
|
|
context: c,
|
|
localErr: errors.New("invalid general request type"),
|
|
newAPIError: types.NewError(errors.New("invalid general request type"), types.ErrorCodeConvertRequestFailed),
|
|
}
|
|
}
|
|
}
|
|
|
|
if err != nil {
|
|
return testResult{
|
|
context: c,
|
|
localErr: err,
|
|
newAPIError: types.NewError(err, types.ErrorCodeConvertRequestFailed),
|
|
}
|
|
}
|
|
jsonData, err := json.Marshal(convertedRequest)
|
|
if err != nil {
|
|
return testResult{
|
|
context: c,
|
|
localErr: err,
|
|
newAPIError: types.NewError(err, types.ErrorCodeJsonMarshalFailed),
|
|
}
|
|
}
|
|
|
|
//jsonData, err = relaycommon.RemoveDisabledFields(jsonData, info.ChannelOtherSettings)
|
|
//if err != nil {
|
|
// return testResult{
|
|
// context: c,
|
|
// localErr: err,
|
|
// newAPIError: types.NewError(err, types.ErrorCodeConvertRequestFailed),
|
|
// }
|
|
//}
|
|
|
|
if len(info.ParamOverride) > 0 {
|
|
jsonData, err = relaycommon.ApplyParamOverride(jsonData, info.ParamOverride, relaycommon.BuildParamOverrideContext(info))
|
|
if err != nil {
|
|
return testResult{
|
|
context: c,
|
|
localErr: err,
|
|
newAPIError: types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid),
|
|
}
|
|
}
|
|
}
|
|
|
|
requestBody := bytes.NewBuffer(jsonData)
|
|
c.Request.Body = io.NopCloser(bytes.NewBuffer(jsonData))
|
|
resp, err := adaptor.DoRequest(c, info, requestBody)
|
|
if err != nil {
|
|
return testResult{
|
|
context: c,
|
|
localErr: err,
|
|
newAPIError: types.NewOpenAIError(err, types.ErrorCodeDoRequestFailed, http.StatusInternalServerError),
|
|
}
|
|
}
|
|
var httpResp *http.Response
|
|
if resp != nil {
|
|
httpResp = resp.(*http.Response)
|
|
if httpResp.StatusCode != http.StatusOK {
|
|
err := service.RelayErrorHandler(c.Request.Context(), httpResp, true)
|
|
common.SysError(fmt.Sprintf(
|
|
"channel test bad response: channel_id=%d name=%s type=%d model=%s endpoint_type=%s status=%d err=%v",
|
|
channel.Id,
|
|
channel.Name,
|
|
channel.Type,
|
|
testModel,
|
|
endpointType,
|
|
httpResp.StatusCode,
|
|
err,
|
|
))
|
|
return testResult{
|
|
context: c,
|
|
localErr: err,
|
|
newAPIError: types.NewOpenAIError(err, types.ErrorCodeBadResponse, http.StatusInternalServerError),
|
|
}
|
|
}
|
|
}
|
|
usageA, respErr := adaptor.DoResponse(c, httpResp, info)
|
|
if respErr != nil {
|
|
return testResult{
|
|
context: c,
|
|
localErr: respErr,
|
|
newAPIError: respErr,
|
|
}
|
|
}
|
|
if usageA == nil {
|
|
return testResult{
|
|
context: c,
|
|
localErr: errors.New("usage is nil"),
|
|
newAPIError: types.NewOpenAIError(errors.New("usage is nil"), types.ErrorCodeBadResponseBody, http.StatusInternalServerError),
|
|
}
|
|
}
|
|
usage := usageA.(*dto.Usage)
|
|
result := w.Result()
|
|
respBody, err := io.ReadAll(result.Body)
|
|
if err != nil {
|
|
return testResult{
|
|
context: c,
|
|
localErr: err,
|
|
newAPIError: types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError),
|
|
}
|
|
}
|
|
info.SetEstimatePromptTokens(usage.PromptTokens)
|
|
|
|
quota := 0
|
|
if !priceData.UsePrice {
|
|
quota = usage.PromptTokens + int(math.Round(float64(usage.CompletionTokens)*priceData.CompletionRatio))
|
|
quota = int(math.Round(float64(quota) * priceData.ModelRatio))
|
|
if priceData.ModelRatio != 0 && quota <= 0 {
|
|
quota = 1
|
|
}
|
|
} else {
|
|
quota = int(priceData.ModelPrice * common.QuotaPerUnit)
|
|
}
|
|
tok := time.Now()
|
|
milliseconds := tok.Sub(tik).Milliseconds()
|
|
consumedTime := float64(milliseconds) / 1000.0
|
|
other := service.GenerateTextOtherInfo(c, info, priceData.ModelRatio, priceData.GroupRatioInfo.GroupRatio, priceData.CompletionRatio,
|
|
usage.PromptTokensDetails.CachedTokens, priceData.CacheRatio, priceData.ModelPrice, priceData.GroupRatioInfo.GroupSpecialRatio)
|
|
model.RecordConsumeLog(c, 1, model.RecordConsumeLogParams{
|
|
ChannelId: channel.Id,
|
|
PromptTokens: usage.PromptTokens,
|
|
CompletionTokens: usage.CompletionTokens,
|
|
ModelName: info.OriginModelName,
|
|
TokenName: "模型测试",
|
|
Quota: quota,
|
|
Content: "模型测试",
|
|
UseTimeSeconds: int(consumedTime),
|
|
IsStream: info.IsStream,
|
|
Group: info.UsingGroup,
|
|
Other: other,
|
|
})
|
|
common.SysLog(fmt.Sprintf("testing channel #%d, response: \n%s", channel.Id, string(respBody)))
|
|
return testResult{
|
|
context: c,
|
|
localErr: nil,
|
|
newAPIError: nil,
|
|
}
|
|
}
|
|
|
|
func buildTestRequest(model string, endpointType string, channel *model.Channel) dto.Request {
|
|
// 根据端点类型构建不同的测试请求
|
|
if endpointType != "" {
|
|
switch constant.EndpointType(endpointType) {
|
|
case constant.EndpointTypeEmbeddings:
|
|
// 返回 EmbeddingRequest
|
|
return &dto.EmbeddingRequest{
|
|
Model: model,
|
|
Input: []any{"hello world"},
|
|
}
|
|
case constant.EndpointTypeImageGeneration:
|
|
// 返回 ImageRequest
|
|
return &dto.ImageRequest{
|
|
Model: model,
|
|
Prompt: "a cute cat",
|
|
N: 1,
|
|
Size: "1024x1024",
|
|
}
|
|
case constant.EndpointTypeJinaRerank:
|
|
// 返回 RerankRequest
|
|
return &dto.RerankRequest{
|
|
Model: model,
|
|
Query: "What is Deep Learning?",
|
|
Documents: []any{"Deep Learning is a subset of machine learning.", "Machine learning is a field of artificial intelligence."},
|
|
TopN: 2,
|
|
}
|
|
case constant.EndpointTypeOpenAIResponse:
|
|
// 返回 OpenAIResponsesRequest
|
|
return &dto.OpenAIResponsesRequest{
|
|
Model: model,
|
|
Input: json.RawMessage(`[{"role":"user","content":"hi"}]`),
|
|
}
|
|
case constant.EndpointTypeAnthropic, constant.EndpointTypeGemini, constant.EndpointTypeOpenAI:
|
|
// 返回 GeneralOpenAIRequest
|
|
maxTokens := uint(16)
|
|
if constant.EndpointType(endpointType) == constant.EndpointTypeGemini {
|
|
maxTokens = 3000
|
|
}
|
|
return &dto.GeneralOpenAIRequest{
|
|
Model: model,
|
|
Stream: false,
|
|
Messages: []dto.Message{
|
|
{
|
|
Role: "user",
|
|
Content: "hi",
|
|
},
|
|
},
|
|
MaxTokens: maxTokens,
|
|
}
|
|
}
|
|
}
|
|
|
|
// 自动检测逻辑(保持原有行为)
|
|
if strings.Contains(strings.ToLower(model), "rerank") {
|
|
return &dto.RerankRequest{
|
|
Model: model,
|
|
Query: "What is Deep Learning?",
|
|
Documents: []any{"Deep Learning is a subset of machine learning.", "Machine learning is a field of artificial intelligence."},
|
|
TopN: 2,
|
|
}
|
|
}
|
|
|
|
// 先判断是否为 Embedding 模型
|
|
if strings.Contains(strings.ToLower(model), "embedding") ||
|
|
strings.HasPrefix(model, "m3e") ||
|
|
strings.Contains(model, "bge-") {
|
|
// 返回 EmbeddingRequest
|
|
return &dto.EmbeddingRequest{
|
|
Model: model,
|
|
Input: []any{"hello world"},
|
|
}
|
|
}
|
|
|
|
// Responses-only models (e.g. codex series)
|
|
if strings.Contains(strings.ToLower(model), "codex") {
|
|
return &dto.OpenAIResponsesRequest{
|
|
Model: model,
|
|
Input: json.RawMessage(`[{"role":"user","content":"hi"}]`),
|
|
}
|
|
}
|
|
|
|
// Chat/Completion 请求 - 返回 GeneralOpenAIRequest
|
|
testRequest := &dto.GeneralOpenAIRequest{
|
|
Model: model,
|
|
Stream: false,
|
|
Messages: []dto.Message{
|
|
{
|
|
Role: "user",
|
|
Content: "hi",
|
|
},
|
|
},
|
|
}
|
|
|
|
if strings.HasPrefix(model, "o") {
|
|
testRequest.MaxCompletionTokens = 16
|
|
} else if strings.Contains(model, "thinking") {
|
|
if !strings.Contains(model, "claude") {
|
|
testRequest.MaxTokens = 50
|
|
}
|
|
} else if strings.Contains(model, "gemini") {
|
|
testRequest.MaxTokens = 3000
|
|
} else {
|
|
testRequest.MaxTokens = 16
|
|
}
|
|
|
|
return testRequest
|
|
}
|
|
|
|
func TestChannel(c *gin.Context) {
|
|
channelId, err := strconv.Atoi(c.Param("id"))
|
|
if err != nil {
|
|
common.ApiError(c, err)
|
|
return
|
|
}
|
|
channel, err := model.CacheGetChannel(channelId)
|
|
if err != nil {
|
|
channel, err = model.GetChannelById(channelId, true)
|
|
if err != nil {
|
|
common.ApiError(c, err)
|
|
return
|
|
}
|
|
}
|
|
//defer func() {
|
|
// if channel.ChannelInfo.IsMultiKey {
|
|
// go func() { _ = channel.SaveChannelInfo() }()
|
|
// }
|
|
//}()
|
|
testModel := c.Query("model")
|
|
endpointType := c.Query("endpoint_type")
|
|
tik := time.Now()
|
|
result := testChannel(channel, testModel, endpointType)
|
|
if result.localErr != nil {
|
|
c.JSON(http.StatusOK, gin.H{
|
|
"success": false,
|
|
"message": result.localErr.Error(),
|
|
"time": 0.0,
|
|
})
|
|
return
|
|
}
|
|
tok := time.Now()
|
|
milliseconds := tok.Sub(tik).Milliseconds()
|
|
go channel.UpdateResponseTime(milliseconds)
|
|
consumedTime := float64(milliseconds) / 1000.0
|
|
if result.newAPIError != nil {
|
|
c.JSON(http.StatusOK, gin.H{
|
|
"success": false,
|
|
"message": result.newAPIError.Error(),
|
|
"time": consumedTime,
|
|
})
|
|
return
|
|
}
|
|
c.JSON(http.StatusOK, gin.H{
|
|
"success": true,
|
|
"message": "",
|
|
"time": consumedTime,
|
|
})
|
|
}
|
|
|
|
// testAllChannelsLock guards testAllChannelsRunning, which is true while a
// testAllChannels run is in progress. Together they prevent overlapping runs.
var (
	testAllChannelsLock    sync.Mutex
	testAllChannelsRunning bool
)
|
|
|
|
func testAllChannels(notify bool) error {
|
|
|
|
testAllChannelsLock.Lock()
|
|
if testAllChannelsRunning {
|
|
testAllChannelsLock.Unlock()
|
|
return errors.New("测试已在运行中")
|
|
}
|
|
testAllChannelsRunning = true
|
|
testAllChannelsLock.Unlock()
|
|
channels, getChannelErr := model.GetAllChannels(0, 0, true, false)
|
|
if getChannelErr != nil {
|
|
return getChannelErr
|
|
}
|
|
var disableThreshold = int64(common.ChannelDisableThreshold * 1000)
|
|
if disableThreshold == 0 {
|
|
disableThreshold = 10000000 // a impossible value
|
|
}
|
|
gopool.Go(func() {
|
|
// 使用 defer 确保无论如何都会重置运行状态,防止死锁
|
|
defer func() {
|
|
testAllChannelsLock.Lock()
|
|
testAllChannelsRunning = false
|
|
testAllChannelsLock.Unlock()
|
|
}()
|
|
|
|
for _, channel := range channels {
|
|
isChannelEnabled := channel.Status == common.ChannelStatusEnabled
|
|
tik := time.Now()
|
|
result := testChannel(channel, "", "")
|
|
tok := time.Now()
|
|
milliseconds := tok.Sub(tik).Milliseconds()
|
|
|
|
shouldBanChannel := false
|
|
newAPIError := result.newAPIError
|
|
// request error disables the channel
|
|
if newAPIError != nil {
|
|
shouldBanChannel = service.ShouldDisableChannel(channel.Type, result.newAPIError)
|
|
}
|
|
|
|
// 当错误检查通过,才检查响应时间
|
|
if common.AutomaticDisableChannelEnabled && !shouldBanChannel {
|
|
if milliseconds > disableThreshold {
|
|
err := fmt.Errorf("响应时间 %.2fs 超过阈值 %.2fs", float64(milliseconds)/1000.0, float64(disableThreshold)/1000.0)
|
|
newAPIError = types.NewOpenAIError(err, types.ErrorCodeChannelResponseTimeExceeded, http.StatusRequestTimeout)
|
|
shouldBanChannel = true
|
|
}
|
|
}
|
|
|
|
// disable channel
|
|
if isChannelEnabled && shouldBanChannel && channel.GetAutoBan() {
|
|
processChannelError(result.context, *types.NewChannelError(channel.Id, channel.Type, channel.Name, channel.ChannelInfo.IsMultiKey, common.GetContextKeyString(result.context, constant.ContextKeyChannelKey), channel.GetAutoBan()), newAPIError)
|
|
}
|
|
|
|
// enable channel
|
|
if !isChannelEnabled && service.ShouldEnableChannel(newAPIError, channel.Status) {
|
|
service.EnableChannel(channel.Id, common.GetContextKeyString(result.context, constant.ContextKeyChannelKey), channel.Name)
|
|
}
|
|
|
|
channel.UpdateResponseTime(milliseconds)
|
|
time.Sleep(common.RequestInterval)
|
|
}
|
|
|
|
if notify {
|
|
service.NotifyRootUser(dto.NotifyTypeChannelTest, "通道测试完成", "所有通道测试已完成")
|
|
}
|
|
})
|
|
return nil
|
|
}
|
|
|
|
func TestAllChannels(c *gin.Context) {
|
|
err := testAllChannels(true)
|
|
if err != nil {
|
|
common.ApiError(c, err)
|
|
return
|
|
}
|
|
c.JSON(http.StatusOK, gin.H{
|
|
"success": true,
|
|
"message": "",
|
|
})
|
|
}
|
|
|
|
// autoTestChannelsOnce ensures AutomaticallyTestChannels starts its periodic
// test loop at most once per process.
var autoTestChannelsOnce sync.Once
|
|
|
|
func AutomaticallyTestChannels() {
|
|
// 只在Master节点定时测试渠道
|
|
if !common.IsMasterNode {
|
|
return
|
|
}
|
|
autoTestChannelsOnce.Do(func() {
|
|
for {
|
|
if !operation_setting.GetMonitorSetting().AutoTestChannelEnabled {
|
|
time.Sleep(1 * time.Minute)
|
|
continue
|
|
}
|
|
for {
|
|
frequency := operation_setting.GetMonitorSetting().AutoTestChannelMinutes
|
|
time.Sleep(time.Duration(int(math.Round(frequency))) * time.Minute)
|
|
common.SysLog(fmt.Sprintf("automatically test channels with interval %f minutes", frequency))
|
|
common.SysLog("automatically testing all channels")
|
|
_ = testAllChannels(false)
|
|
common.SysLog("automatically channel test finished")
|
|
if !operation_setting.GetMonitorSetting().AutoTestChannelEnabled {
|
|
break
|
|
}
|
|
}
|
|
}
|
|
})
|
|
}
|