mirror of
https://github.com/QuantumNous/new-api.git
synced 2026-04-19 09:48:39 +00:00
feat: add endpoint type selection to channel testing functionality
This commit is contained in:
@@ -23,6 +23,7 @@ var defaultEndpointInfoMap = map[constant.EndpointType]EndpointInfo{
|
|||||||
constant.EndpointTypeGemini: {Path: "/v1beta/models/{model}:generateContent", Method: "POST"},
|
constant.EndpointTypeGemini: {Path: "/v1beta/models/{model}:generateContent", Method: "POST"},
|
||||||
constant.EndpointTypeJinaRerank: {Path: "/rerank", Method: "POST"},
|
constant.EndpointTypeJinaRerank: {Path: "/rerank", Method: "POST"},
|
||||||
constant.EndpointTypeImageGeneration: {Path: "/v1/images/generations", Method: "POST"},
|
constant.EndpointTypeImageGeneration: {Path: "/v1/images/generations", Method: "POST"},
|
||||||
|
constant.EndpointTypeEmbeddings: {Path: "/v1/embeddings", Method: "POST"},
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetDefaultEndpointInfo 返回指定端点类型的默认信息以及是否存在
|
// GetDefaultEndpointInfo 返回指定端点类型的默认信息以及是否存在
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ const (
|
|||||||
EndpointTypeGemini EndpointType = "gemini"
|
EndpointTypeGemini EndpointType = "gemini"
|
||||||
EndpointTypeJinaRerank EndpointType = "jina-rerank"
|
EndpointTypeJinaRerank EndpointType = "jina-rerank"
|
||||||
EndpointTypeImageGeneration EndpointType = "image-generation"
|
EndpointTypeImageGeneration EndpointType = "image-generation"
|
||||||
|
EndpointTypeEmbeddings EndpointType = "embeddings"
|
||||||
//EndpointTypeMidjourney EndpointType = "midjourney-proxy"
|
//EndpointTypeMidjourney EndpointType = "midjourney-proxy"
|
||||||
//EndpointTypeSuno EndpointType = "suno-proxy"
|
//EndpointTypeSuno EndpointType = "suno-proxy"
|
||||||
//EndpointTypeKling EndpointType = "kling"
|
//EndpointTypeKling EndpointType = "kling"
|
||||||
|
|||||||
@@ -38,7 +38,7 @@ type testResult struct {
|
|||||||
newAPIError *types.NewAPIError
|
newAPIError *types.NewAPIError
|
||||||
}
|
}
|
||||||
|
|
||||||
func testChannel(channel *model.Channel, testModel string) testResult {
|
func testChannel(channel *model.Channel, testModel string, endpointType string) testResult {
|
||||||
tik := time.Now()
|
tik := time.Now()
|
||||||
if channel.Type == constant.ChannelTypeMidjourney {
|
if channel.Type == constant.ChannelTypeMidjourney {
|
||||||
return testResult{
|
return testResult{
|
||||||
@@ -81,18 +81,26 @@ func testChannel(channel *model.Channel, testModel string) testResult {
|
|||||||
|
|
||||||
requestPath := "/v1/chat/completions"
|
requestPath := "/v1/chat/completions"
|
||||||
|
|
||||||
// 先判断是否为 Embedding 模型
|
// 如果指定了端点类型,使用指定的端点类型
|
||||||
if strings.Contains(strings.ToLower(testModel), "embedding") ||
|
if endpointType != "" {
|
||||||
strings.HasPrefix(testModel, "m3e") || // m3e 系列模型
|
if endpointInfo, ok := common.GetDefaultEndpointInfo(constant.EndpointType(endpointType)); ok {
|
||||||
strings.Contains(testModel, "bge-") || // bge 系列模型
|
requestPath = endpointInfo.Path
|
||||||
strings.Contains(testModel, "embed") ||
|
}
|
||||||
channel.Type == constant.ChannelTypeMokaAI { // 其他 embedding 模型
|
} else {
|
||||||
requestPath = "/v1/embeddings" // 修改请求路径
|
// 如果没有指定端点类型,使用原有的自动检测逻辑
|
||||||
}
|
// 先判断是否为 Embedding 模型
|
||||||
|
if strings.Contains(strings.ToLower(testModel), "embedding") ||
|
||||||
|
strings.HasPrefix(testModel, "m3e") || // m3e 系列模型
|
||||||
|
strings.Contains(testModel, "bge-") || // bge 系列模型
|
||||||
|
strings.Contains(testModel, "embed") ||
|
||||||
|
channel.Type == constant.ChannelTypeMokaAI { // 其他 embedding 模型
|
||||||
|
requestPath = "/v1/embeddings" // 修改请求路径
|
||||||
|
}
|
||||||
|
|
||||||
// VolcEngine 图像生成模型
|
// VolcEngine 图像生成模型
|
||||||
if channel.Type == constant.ChannelTypeVolcEngine && strings.Contains(testModel, "seedream") {
|
if channel.Type == constant.ChannelTypeVolcEngine && strings.Contains(testModel, "seedream") {
|
||||||
requestPath = "/v1/images/generations"
|
requestPath = "/v1/images/generations"
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
c.Request = &http.Request{
|
c.Request = &http.Request{
|
||||||
@@ -114,21 +122,6 @@ func testChannel(channel *model.Channel, testModel string) testResult {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// 重新检查模型类型并更新请求路径
|
|
||||||
if strings.Contains(strings.ToLower(testModel), "embedding") ||
|
|
||||||
strings.HasPrefix(testModel, "m3e") ||
|
|
||||||
strings.Contains(testModel, "bge-") ||
|
|
||||||
strings.Contains(testModel, "embed") ||
|
|
||||||
channel.Type == constant.ChannelTypeMokaAI {
|
|
||||||
requestPath = "/v1/embeddings"
|
|
||||||
c.Request.URL.Path = requestPath
|
|
||||||
}
|
|
||||||
|
|
||||||
if channel.Type == constant.ChannelTypeVolcEngine && strings.Contains(testModel, "seedream") {
|
|
||||||
requestPath = "/v1/images/generations"
|
|
||||||
c.Request.URL.Path = requestPath
|
|
||||||
}
|
|
||||||
|
|
||||||
cache, err := model.GetUserCache(1)
|
cache, err := model.GetUserCache(1)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return testResult{
|
return testResult{
|
||||||
@@ -153,17 +146,54 @@ func testChannel(channel *model.Channel, testModel string) testResult {
|
|||||||
newAPIError: newAPIError,
|
newAPIError: newAPIError,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
request := buildTestRequest(testModel)
|
|
||||||
|
|
||||||
// Determine relay format based on request path
|
// Determine relay format based on endpoint type or request path
|
||||||
relayFormat := types.RelayFormatOpenAI
|
var relayFormat types.RelayFormat
|
||||||
if c.Request.URL.Path == "/v1/embeddings" {
|
if endpointType != "" {
|
||||||
relayFormat = types.RelayFormatEmbedding
|
// 根据指定的端点类型设置 relayFormat
|
||||||
}
|
switch constant.EndpointType(endpointType) {
|
||||||
if c.Request.URL.Path == "/v1/images/generations" {
|
case constant.EndpointTypeOpenAI:
|
||||||
relayFormat = types.RelayFormatOpenAIImage
|
relayFormat = types.RelayFormatOpenAI
|
||||||
|
case constant.EndpointTypeOpenAIResponse:
|
||||||
|
relayFormat = types.RelayFormatOpenAIResponses
|
||||||
|
case constant.EndpointTypeAnthropic:
|
||||||
|
relayFormat = types.RelayFormatClaude
|
||||||
|
case constant.EndpointTypeGemini:
|
||||||
|
relayFormat = types.RelayFormatGemini
|
||||||
|
case constant.EndpointTypeJinaRerank:
|
||||||
|
relayFormat = types.RelayFormatRerank
|
||||||
|
case constant.EndpointTypeImageGeneration:
|
||||||
|
relayFormat = types.RelayFormatOpenAIImage
|
||||||
|
case constant.EndpointTypeEmbeddings:
|
||||||
|
relayFormat = types.RelayFormatEmbedding
|
||||||
|
default:
|
||||||
|
relayFormat = types.RelayFormatOpenAI
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// 根据请求路径自动检测
|
||||||
|
relayFormat = types.RelayFormatOpenAI
|
||||||
|
if c.Request.URL.Path == "/v1/embeddings" {
|
||||||
|
relayFormat = types.RelayFormatEmbedding
|
||||||
|
}
|
||||||
|
if c.Request.URL.Path == "/v1/images/generations" {
|
||||||
|
relayFormat = types.RelayFormatOpenAIImage
|
||||||
|
}
|
||||||
|
if c.Request.URL.Path == "/v1/messages" {
|
||||||
|
relayFormat = types.RelayFormatClaude
|
||||||
|
}
|
||||||
|
if strings.Contains(c.Request.URL.Path, "/v1beta/models") {
|
||||||
|
relayFormat = types.RelayFormatGemini
|
||||||
|
}
|
||||||
|
if c.Request.URL.Path == "/v1/rerank" || c.Request.URL.Path == "/rerank" {
|
||||||
|
relayFormat = types.RelayFormatRerank
|
||||||
|
}
|
||||||
|
if c.Request.URL.Path == "/v1/responses" {
|
||||||
|
relayFormat = types.RelayFormatOpenAIResponses
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
request := buildTestRequest(testModel, endpointType)
|
||||||
|
|
||||||
info, err := relaycommon.GenRelayInfo(c, relayFormat, request, nil)
|
info, err := relaycommon.GenRelayInfo(c, relayFormat, request, nil)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -186,7 +216,8 @@ func testChannel(channel *model.Channel, testModel string) testResult {
|
|||||||
}
|
}
|
||||||
|
|
||||||
testModel = info.UpstreamModelName
|
testModel = info.UpstreamModelName
|
||||||
request.Model = testModel
|
// 更新请求中的模型名称
|
||||||
|
request.SetModelName(testModel)
|
||||||
|
|
||||||
apiType, _ := common.ChannelType2APIType(channel.Type)
|
apiType, _ := common.ChannelType2APIType(channel.Type)
|
||||||
adaptor := relay.GetAdaptor(apiType)
|
adaptor := relay.GetAdaptor(apiType)
|
||||||
@@ -216,33 +247,62 @@ func testChannel(channel *model.Channel, testModel string) testResult {
|
|||||||
|
|
||||||
var convertedRequest any
|
var convertedRequest any
|
||||||
// 根据 RelayMode 选择正确的转换函数
|
// 根据 RelayMode 选择正确的转换函数
|
||||||
if info.RelayMode == relayconstant.RelayModeEmbeddings {
|
switch info.RelayMode {
|
||||||
// 创建一个 EmbeddingRequest
|
case relayconstant.RelayModeEmbeddings:
|
||||||
embeddingRequest := dto.EmbeddingRequest{
|
// Embedding 请求 - request 已经是正确的类型
|
||||||
Input: request.Input,
|
if embeddingReq, ok := request.(*dto.EmbeddingRequest); ok {
|
||||||
Model: request.Model,
|
convertedRequest, err = adaptor.ConvertEmbeddingRequest(c, info, *embeddingReq)
|
||||||
}
|
} else {
|
||||||
// 调用专门用于 Embedding 的转换函数
|
return testResult{
|
||||||
convertedRequest, err = adaptor.ConvertEmbeddingRequest(c, info, embeddingRequest)
|
context: c,
|
||||||
} else if info.RelayMode == relayconstant.RelayModeImagesGenerations {
|
localErr: errors.New("invalid embedding request type"),
|
||||||
// 创建一个 ImageRequest
|
newAPIError: types.NewError(errors.New("invalid embedding request type"), types.ErrorCodeConvertRequestFailed),
|
||||||
prompt := "cat"
|
|
||||||
if request.Prompt != nil {
|
|
||||||
if promptStr, ok := request.Prompt.(string); ok && promptStr != "" {
|
|
||||||
prompt = promptStr
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
imageRequest := dto.ImageRequest{
|
case relayconstant.RelayModeImagesGenerations:
|
||||||
Prompt: prompt,
|
// 图像生成请求 - request 已经是正确的类型
|
||||||
Model: request.Model,
|
if imageReq, ok := request.(*dto.ImageRequest); ok {
|
||||||
N: uint(request.N),
|
convertedRequest, err = adaptor.ConvertImageRequest(c, info, *imageReq)
|
||||||
Size: request.Size,
|
} else {
|
||||||
|
return testResult{
|
||||||
|
context: c,
|
||||||
|
localErr: errors.New("invalid image request type"),
|
||||||
|
newAPIError: types.NewError(errors.New("invalid image request type"), types.ErrorCodeConvertRequestFailed),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
case relayconstant.RelayModeRerank:
|
||||||
|
// Rerank 请求 - request 已经是正确的类型
|
||||||
|
if rerankReq, ok := request.(*dto.RerankRequest); ok {
|
||||||
|
convertedRequest, err = adaptor.ConvertRerankRequest(c, info.RelayMode, *rerankReq)
|
||||||
|
} else {
|
||||||
|
return testResult{
|
||||||
|
context: c,
|
||||||
|
localErr: errors.New("invalid rerank request type"),
|
||||||
|
newAPIError: types.NewError(errors.New("invalid rerank request type"), types.ErrorCodeConvertRequestFailed),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
case relayconstant.RelayModeResponses:
|
||||||
|
// Response 请求 - request 已经是正确的类型
|
||||||
|
if responseReq, ok := request.(*dto.OpenAIResponsesRequest); ok {
|
||||||
|
convertedRequest, err = adaptor.ConvertOpenAIResponsesRequest(c, info, *responseReq)
|
||||||
|
} else {
|
||||||
|
return testResult{
|
||||||
|
context: c,
|
||||||
|
localErr: errors.New("invalid response request type"),
|
||||||
|
newAPIError: types.NewError(errors.New("invalid response request type"), types.ErrorCodeConvertRequestFailed),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
// Chat/Completion 等其他请求类型
|
||||||
|
if generalReq, ok := request.(*dto.GeneralOpenAIRequest); ok {
|
||||||
|
convertedRequest, err = adaptor.ConvertOpenAIRequest(c, info, generalReq)
|
||||||
|
} else {
|
||||||
|
return testResult{
|
||||||
|
context: c,
|
||||||
|
localErr: errors.New("invalid general request type"),
|
||||||
|
newAPIError: types.NewError(errors.New("invalid general request type"), types.ErrorCodeConvertRequestFailed),
|
||||||
|
}
|
||||||
}
|
}
|
||||||
// 调用专门用于图像生成的转换函数
|
|
||||||
convertedRequest, err = adaptor.ConvertImageRequest(c, info, imageRequest)
|
|
||||||
} else {
|
|
||||||
// 对其他所有请求类型(如 Chat),保持原有逻辑
|
|
||||||
convertedRequest, err = adaptor.ConvertOpenAIRequest(c, info, request)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -345,22 +405,82 @@ func testChannel(channel *model.Channel, testModel string) testResult {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func buildTestRequest(model string) *dto.GeneralOpenAIRequest {
|
func buildTestRequest(model string, endpointType string) dto.Request {
|
||||||
testRequest := &dto.GeneralOpenAIRequest{
|
// 根据端点类型构建不同的测试请求
|
||||||
Model: "", // this will be set later
|
if endpointType != "" {
|
||||||
Stream: false,
|
switch constant.EndpointType(endpointType) {
|
||||||
|
case constant.EndpointTypeEmbeddings:
|
||||||
|
// 返回 EmbeddingRequest
|
||||||
|
return &dto.EmbeddingRequest{
|
||||||
|
Model: model,
|
||||||
|
Input: []any{"hello world"},
|
||||||
|
}
|
||||||
|
case constant.EndpointTypeImageGeneration:
|
||||||
|
// 返回 ImageRequest
|
||||||
|
return &dto.ImageRequest{
|
||||||
|
Model: model,
|
||||||
|
Prompt: "a cute cat",
|
||||||
|
N: 1,
|
||||||
|
Size: "1024x1024",
|
||||||
|
}
|
||||||
|
case constant.EndpointTypeJinaRerank:
|
||||||
|
// 返回 RerankRequest
|
||||||
|
return &dto.RerankRequest{
|
||||||
|
Model: model,
|
||||||
|
Query: "What is Deep Learning?",
|
||||||
|
Documents: []any{"Deep Learning is a subset of machine learning.", "Machine learning is a field of artificial intelligence."},
|
||||||
|
TopN: 2,
|
||||||
|
}
|
||||||
|
case constant.EndpointTypeOpenAIResponse:
|
||||||
|
// 返回 OpenAIResponsesRequest
|
||||||
|
return &dto.OpenAIResponsesRequest{
|
||||||
|
Model: model,
|
||||||
|
Input: json.RawMessage("\"hi\""),
|
||||||
|
}
|
||||||
|
case constant.EndpointTypeAnthropic, constant.EndpointTypeGemini, constant.EndpointTypeOpenAI:
|
||||||
|
// 返回 GeneralOpenAIRequest
|
||||||
|
maxTokens := uint(10)
|
||||||
|
if constant.EndpointType(endpointType) == constant.EndpointTypeGemini {
|
||||||
|
maxTokens = 3000
|
||||||
|
}
|
||||||
|
return &dto.GeneralOpenAIRequest{
|
||||||
|
Model: model,
|
||||||
|
Stream: false,
|
||||||
|
Messages: []dto.Message{
|
||||||
|
{
|
||||||
|
Role: "user",
|
||||||
|
Content: "hi",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
MaxTokens: maxTokens,
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 自动检测逻辑(保持原有行为)
|
||||||
// 先判断是否为 Embedding 模型
|
// 先判断是否为 Embedding 模型
|
||||||
if strings.Contains(strings.ToLower(model), "embedding") || // 其他 embedding 模型
|
if strings.Contains(strings.ToLower(model), "embedding") ||
|
||||||
strings.HasPrefix(model, "m3e") || // m3e 系列模型
|
strings.HasPrefix(model, "m3e") ||
|
||||||
strings.Contains(model, "bge-") {
|
strings.Contains(model, "bge-") {
|
||||||
testRequest.Model = model
|
// 返回 EmbeddingRequest
|
||||||
// Embedding 请求
|
return &dto.EmbeddingRequest{
|
||||||
testRequest.Input = []any{"hello world"} // 修改为any,因为dto/openai_request.go 的ParseInput方法无法处理[]string类型
|
Model: model,
|
||||||
return testRequest
|
Input: []any{"hello world"},
|
||||||
|
}
|
||||||
}
|
}
|
||||||
// 并非Embedding 模型
|
|
||||||
|
// Chat/Completion 请求 - 返回 GeneralOpenAIRequest
|
||||||
|
testRequest := &dto.GeneralOpenAIRequest{
|
||||||
|
Model: model,
|
||||||
|
Stream: false,
|
||||||
|
Messages: []dto.Message{
|
||||||
|
{
|
||||||
|
Role: "user",
|
||||||
|
Content: "hi",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
if strings.HasPrefix(model, "o") {
|
if strings.HasPrefix(model, "o") {
|
||||||
testRequest.MaxCompletionTokens = 10
|
testRequest.MaxCompletionTokens = 10
|
||||||
} else if strings.Contains(model, "thinking") {
|
} else if strings.Contains(model, "thinking") {
|
||||||
@@ -373,12 +493,6 @@ func buildTestRequest(model string) *dto.GeneralOpenAIRequest {
|
|||||||
testRequest.MaxTokens = 10
|
testRequest.MaxTokens = 10
|
||||||
}
|
}
|
||||||
|
|
||||||
testMessage := dto.Message{
|
|
||||||
Role: "user",
|
|
||||||
Content: "hi",
|
|
||||||
}
|
|
||||||
testRequest.Model = model
|
|
||||||
testRequest.Messages = append(testRequest.Messages, testMessage)
|
|
||||||
return testRequest
|
return testRequest
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -402,8 +516,9 @@ func TestChannel(c *gin.Context) {
|
|||||||
// }
|
// }
|
||||||
//}()
|
//}()
|
||||||
testModel := c.Query("model")
|
testModel := c.Query("model")
|
||||||
|
endpointType := c.Query("endpoint_type")
|
||||||
tik := time.Now()
|
tik := time.Now()
|
||||||
result := testChannel(channel, testModel)
|
result := testChannel(channel, testModel, endpointType)
|
||||||
if result.localErr != nil {
|
if result.localErr != nil {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
"success": false,
|
"success": false,
|
||||||
@@ -429,7 +544,6 @@ func TestChannel(c *gin.Context) {
|
|||||||
"message": "",
|
"message": "",
|
||||||
"time": consumedTime,
|
"time": consumedTime,
|
||||||
})
|
})
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
var testAllChannelsLock sync.Mutex
|
var testAllChannelsLock sync.Mutex
|
||||||
@@ -463,7 +577,7 @@ func testAllChannels(notify bool) error {
|
|||||||
for _, channel := range channels {
|
for _, channel := range channels {
|
||||||
isChannelEnabled := channel.Status == common.ChannelStatusEnabled
|
isChannelEnabled := channel.Status == common.ChannelStatusEnabled
|
||||||
tik := time.Now()
|
tik := time.Now()
|
||||||
result := testChannel(channel, "")
|
result := testChannel(channel, "", "")
|
||||||
tok := time.Now()
|
tok := time.Now()
|
||||||
milliseconds := tok.Sub(tik).Milliseconds()
|
milliseconds := tok.Sub(tik).Milliseconds()
|
||||||
|
|
||||||
@@ -477,7 +591,7 @@ func testAllChannels(notify bool) error {
|
|||||||
// 当错误检查通过,才检查响应时间
|
// 当错误检查通过,才检查响应时间
|
||||||
if common.AutomaticDisableChannelEnabled && !shouldBanChannel {
|
if common.AutomaticDisableChannelEnabled && !shouldBanChannel {
|
||||||
if milliseconds > disableThreshold {
|
if milliseconds > disableThreshold {
|
||||||
err := errors.New(fmt.Sprintf("响应时间 %.2fs 超过阈值 %.2fs", float64(milliseconds)/1000.0, float64(disableThreshold)/1000.0))
|
err := fmt.Errorf("响应时间 %.2fs 超过阈值 %.2fs", float64(milliseconds)/1000.0, float64(disableThreshold)/1000.0)
|
||||||
newAPIError = types.NewOpenAIError(err, types.ErrorCodeChannelResponseTimeExceeded, http.StatusRequestTimeout)
|
newAPIError = types.NewOpenAIError(err, types.ErrorCodeChannelResponseTimeExceeded, http.StatusRequestTimeout)
|
||||||
shouldBanChannel = true
|
shouldBanChannel = true
|
||||||
}
|
}
|
||||||
@@ -514,7 +628,6 @@ func TestAllChannels(c *gin.Context) {
|
|||||||
"success": true,
|
"success": true,
|
||||||
"message": "",
|
"message": "",
|
||||||
})
|
})
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
var autoTestChannelsOnce sync.Once
|
var autoTestChannelsOnce sync.Once
|
||||||
|
|||||||
@@ -25,6 +25,7 @@ import {
|
|||||||
Table,
|
Table,
|
||||||
Tag,
|
Tag,
|
||||||
Typography,
|
Typography,
|
||||||
|
Select,
|
||||||
} from '@douyinfe/semi-ui';
|
} from '@douyinfe/semi-ui';
|
||||||
import { IconSearch } from '@douyinfe/semi-icons';
|
import { IconSearch } from '@douyinfe/semi-icons';
|
||||||
import { copy, showError, showInfo, showSuccess } from '../../../../helpers';
|
import { copy, showError, showInfo, showSuccess } from '../../../../helpers';
|
||||||
@@ -45,6 +46,8 @@ const ModelTestModal = ({
|
|||||||
testChannel,
|
testChannel,
|
||||||
modelTablePage,
|
modelTablePage,
|
||||||
setModelTablePage,
|
setModelTablePage,
|
||||||
|
selectedEndpointType,
|
||||||
|
setSelectedEndpointType,
|
||||||
allSelectingRef,
|
allSelectingRef,
|
||||||
isMobile,
|
isMobile,
|
||||||
t,
|
t,
|
||||||
@@ -59,6 +62,17 @@ const ModelTestModal = ({
|
|||||||
)
|
)
|
||||||
: [];
|
: [];
|
||||||
|
|
||||||
|
const endpointTypeOptions = [
|
||||||
|
{ value: '', label: t('自动检测') },
|
||||||
|
{ value: 'openai', label: 'OpenAI (/v1/chat/completions)' },
|
||||||
|
{ value: 'openai-response', label: 'OpenAI Response (/v1/responses)' },
|
||||||
|
{ value: 'anthropic', label: 'Anthropic (/v1/messages)' },
|
||||||
|
{ value: 'gemini', label: 'Gemini (/v1beta/models/{model}:generateContent)' },
|
||||||
|
{ value: 'jina-rerank', label: 'Jina Rerank (/rerank)' },
|
||||||
|
{ value: 'image-generation', label: t('图像生成') + ' (/v1/images/generations)' },
|
||||||
|
{ value: 'embeddings', label: 'Embeddings (/v1/embeddings)' },
|
||||||
|
];
|
||||||
|
|
||||||
const handleCopySelected = () => {
|
const handleCopySelected = () => {
|
||||||
if (selectedModelKeys.length === 0) {
|
if (selectedModelKeys.length === 0) {
|
||||||
showError(t('请先选择模型!'));
|
showError(t('请先选择模型!'));
|
||||||
@@ -152,7 +166,7 @@ const ModelTestModal = ({
|
|||||||
return (
|
return (
|
||||||
<Button
|
<Button
|
||||||
type='tertiary'
|
type='tertiary'
|
||||||
onClick={() => testChannel(currentTestChannel, record.model)}
|
onClick={() => testChannel(currentTestChannel, record.model, selectedEndpointType)}
|
||||||
loading={isTesting}
|
loading={isTesting}
|
||||||
size='small'
|
size='small'
|
||||||
>
|
>
|
||||||
@@ -228,6 +242,18 @@ const ModelTestModal = ({
|
|||||||
>
|
>
|
||||||
{hasChannel && (
|
{hasChannel && (
|
||||||
<div className='model-test-scroll'>
|
<div className='model-test-scroll'>
|
||||||
|
{/* 端点类型选择器 */}
|
||||||
|
<div className='flex items-center gap-2 w-full mb-2'>
|
||||||
|
<Typography.Text strong>{t('端点类型')}:</Typography.Text>
|
||||||
|
<Select
|
||||||
|
value={selectedEndpointType}
|
||||||
|
onChange={setSelectedEndpointType}
|
||||||
|
optionList={endpointTypeOptions}
|
||||||
|
className='!w-full'
|
||||||
|
placeholder={t('选择端点类型')}
|
||||||
|
/>
|
||||||
|
</div>
|
||||||
|
|
||||||
{/* 搜索与操作按钮 */}
|
{/* 搜索与操作按钮 */}
|
||||||
<div className='flex items-center justify-end gap-2 w-full mb-2'>
|
<div className='flex items-center justify-end gap-2 w-full mb-2'>
|
||||||
<Input
|
<Input
|
||||||
|
|||||||
@@ -80,6 +80,7 @@ export const useChannelsData = () => {
|
|||||||
const [selectedModelKeys, setSelectedModelKeys] = useState([]);
|
const [selectedModelKeys, setSelectedModelKeys] = useState([]);
|
||||||
const [isBatchTesting, setIsBatchTesting] = useState(false);
|
const [isBatchTesting, setIsBatchTesting] = useState(false);
|
||||||
const [modelTablePage, setModelTablePage] = useState(1);
|
const [modelTablePage, setModelTablePage] = useState(1);
|
||||||
|
const [selectedEndpointType, setSelectedEndpointType] = useState('');
|
||||||
|
|
||||||
// 使用 ref 来避免闭包问题,类似旧版实现
|
// 使用 ref 来避免闭包问题,类似旧版实现
|
||||||
const shouldStopBatchTestingRef = useRef(false);
|
const shouldStopBatchTestingRef = useRef(false);
|
||||||
@@ -691,7 +692,7 @@ export const useChannelsData = () => {
|
|||||||
};
|
};
|
||||||
|
|
||||||
// Test channel - 单个模型测试,参考旧版实现
|
// Test channel - 单个模型测试,参考旧版实现
|
||||||
const testChannel = async (record, model) => {
|
const testChannel = async (record, model, endpointType = '') => {
|
||||||
const testKey = `${record.id}-${model}`;
|
const testKey = `${record.id}-${model}`;
|
||||||
|
|
||||||
// 检查是否应该停止批量测试
|
// 检查是否应该停止批量测试
|
||||||
@@ -703,7 +704,11 @@ export const useChannelsData = () => {
|
|||||||
setTestingModels(prev => new Set([...prev, model]));
|
setTestingModels(prev => new Set([...prev, model]));
|
||||||
|
|
||||||
try {
|
try {
|
||||||
const res = await API.get(`/api/channel/test/${record.id}?model=${model}`);
|
let url = `/api/channel/test/${record.id}?model=${model}`;
|
||||||
|
if (endpointType) {
|
||||||
|
url += `&endpoint_type=${endpointType}`;
|
||||||
|
}
|
||||||
|
const res = await API.get(url);
|
||||||
|
|
||||||
// 检查是否在请求期间被停止
|
// 检查是否在请求期间被停止
|
||||||
if (shouldStopBatchTestingRef.current && isBatchTesting) {
|
if (shouldStopBatchTestingRef.current && isBatchTesting) {
|
||||||
@@ -820,7 +825,7 @@ export const useChannelsData = () => {
|
|||||||
.replace('${total}', models.length)
|
.replace('${total}', models.length)
|
||||||
);
|
);
|
||||||
|
|
||||||
const batchPromises = batch.map(model => testChannel(currentTestChannel, model));
|
const batchPromises = batch.map(model => testChannel(currentTestChannel, model, selectedEndpointType));
|
||||||
const batchResults = await Promise.allSettled(batchPromises);
|
const batchResults = await Promise.allSettled(batchPromises);
|
||||||
results.push(...batchResults);
|
results.push(...batchResults);
|
||||||
|
|
||||||
@@ -902,6 +907,7 @@ export const useChannelsData = () => {
|
|||||||
setTestingModels(new Set());
|
setTestingModels(new Set());
|
||||||
setSelectedModelKeys([]);
|
setSelectedModelKeys([]);
|
||||||
setModelTablePage(1);
|
setModelTablePage(1);
|
||||||
|
setSelectedEndpointType('');
|
||||||
// 可选择性保留测试结果,这里不清空以便用户查看
|
// 可选择性保留测试结果,这里不清空以便用户查看
|
||||||
};
|
};
|
||||||
|
|
||||||
@@ -989,6 +995,8 @@ export const useChannelsData = () => {
|
|||||||
isBatchTesting,
|
isBatchTesting,
|
||||||
modelTablePage,
|
modelTablePage,
|
||||||
setModelTablePage,
|
setModelTablePage,
|
||||||
|
selectedEndpointType,
|
||||||
|
setSelectedEndpointType,
|
||||||
allSelectingRef,
|
allSelectingRef,
|
||||||
|
|
||||||
// Multi-key management states
|
// Multi-key management states
|
||||||
|
|||||||
Reference in New Issue
Block a user