From 6bc3e62fd541e4bcad7576f5e18e92b15806f7b4 Mon Sep 17 00:00:00 2001 From: CaIon Date: Tue, 30 Sep 2025 16:52:14 +0800 Subject: [PATCH] feat: add endpoint type selection to channel testing functionality --- common/endpoint_defaults.go | 1 + constant/endpoint_type.go | 1 + controller/channel-test.go | 277 ++++++++++++------ .../table/channels/modals/ModelTestModal.jsx | 28 +- web/src/hooks/channels/useChannelsData.jsx | 14 +- 5 files changed, 235 insertions(+), 86 deletions(-) diff --git a/common/endpoint_defaults.go b/common/endpoint_defaults.go index ffc263507..25f9c68eb 100644 --- a/common/endpoint_defaults.go +++ b/common/endpoint_defaults.go @@ -23,6 +23,7 @@ var defaultEndpointInfoMap = map[constant.EndpointType]EndpointInfo{ constant.EndpointTypeGemini: {Path: "/v1beta/models/{model}:generateContent", Method: "POST"}, constant.EndpointTypeJinaRerank: {Path: "/rerank", Method: "POST"}, constant.EndpointTypeImageGeneration: {Path: "/v1/images/generations", Method: "POST"}, + constant.EndpointTypeEmbeddings: {Path: "/v1/embeddings", Method: "POST"}, } // GetDefaultEndpointInfo 返回指定端点类型的默认信息以及是否存在 diff --git a/constant/endpoint_type.go b/constant/endpoint_type.go index ef096b759..f799e5ba8 100644 --- a/constant/endpoint_type.go +++ b/constant/endpoint_type.go @@ -9,6 +9,7 @@ const ( EndpointTypeGemini EndpointType = "gemini" EndpointTypeJinaRerank EndpointType = "jina-rerank" EndpointTypeImageGeneration EndpointType = "image-generation" + EndpointTypeEmbeddings EndpointType = "embeddings" //EndpointTypeMidjourney EndpointType = "midjourney-proxy" //EndpointTypeSuno EndpointType = "suno-proxy" //EndpointTypeKling EndpointType = "kling" diff --git a/controller/channel-test.go b/controller/channel-test.go index 9ea6eed75..b3a3be4eb 100644 --- a/controller/channel-test.go +++ b/controller/channel-test.go @@ -38,7 +38,7 @@ type testResult struct { newAPIError *types.NewAPIError } -func testChannel(channel *model.Channel, testModel string) testResult { +func testChannel(channel *model.Channel, testModel string, endpointType string) testResult { tik := time.Now() if channel.Type == constant.ChannelTypeMidjourney { return testResult{ @@ -81,18 +81,26 @@ func testChannel(channel *model.Channel, testModel string) testResult { requestPath := "/v1/chat/completions" - // 先判断是否为 Embedding 模型 - if strings.Contains(strings.ToLower(testModel), "embedding") || - strings.HasPrefix(testModel, "m3e") || // m3e 系列模型 - strings.Contains(testModel, "bge-") || // bge 系列模型 - strings.Contains(testModel, "embed") || - channel.Type == constant.ChannelTypeMokaAI { // 其他 embedding 模型 - requestPath = "/v1/embeddings" // 修改请求路径 - } + // 如果指定了端点类型,使用指定的端点类型 + if endpointType != "" { + if endpointInfo, ok := common.GetDefaultEndpointInfo(constant.EndpointType(endpointType)); ok { + requestPath = endpointInfo.Path + } + } else { + // 如果没有指定端点类型,使用原有的自动检测逻辑 + // 先判断是否为 Embedding 模型 + if strings.Contains(strings.ToLower(testModel), "embedding") || + strings.HasPrefix(testModel, "m3e") || // m3e 系列模型 + strings.Contains(testModel, "bge-") || // bge 系列模型 + strings.Contains(testModel, "embed") || + channel.Type == constant.ChannelTypeMokaAI { // 其他 embedding 模型 + requestPath = "/v1/embeddings" // 修改请求路径 + } - // VolcEngine 图像生成模型 - if channel.Type == constant.ChannelTypeVolcEngine && strings.Contains(testModel, "seedream") { - requestPath = "/v1/images/generations" + // VolcEngine 图像生成模型 + if channel.Type == constant.ChannelTypeVolcEngine && strings.Contains(testModel, "seedream") { + requestPath = "/v1/images/generations" + } } c.Request = &http.Request{ @@ -114,21 +122,6 @@ func testChannel(channel *model.Channel, testModel string) testResult { } } - // 重新检查模型类型并更新请求路径 - if strings.Contains(strings.ToLower(testModel), "embedding") || - strings.HasPrefix(testModel, "m3e") || - strings.Contains(testModel, "bge-") || - strings.Contains(testModel, "embed") || - channel.Type == constant.ChannelTypeMokaAI { - requestPath = "/v1/embeddings" - c.Request.URL.Path = requestPath - } - - if channel.Type == constant.ChannelTypeVolcEngine && strings.Contains(testModel, "seedream") { - requestPath = "/v1/images/generations" - c.Request.URL.Path = requestPath - } - cache, err := model.GetUserCache(1) if err != nil { return testResult{ @@ -153,17 +146,54 @@ func testChannel(channel *model.Channel, testModel string) testResult { newAPIError: newAPIError, } } - request := buildTestRequest(testModel) - // Determine relay format based on request path - relayFormat := types.RelayFormatOpenAI - if c.Request.URL.Path == "/v1/embeddings" { - relayFormat = types.RelayFormatEmbedding - } - if c.Request.URL.Path == "/v1/images/generations" { - relayFormat = types.RelayFormatOpenAIImage + // Determine relay format based on endpoint type or request path + var relayFormat types.RelayFormat + if endpointType != "" { + // 根据指定的端点类型设置 relayFormat + switch constant.EndpointType(endpointType) { + case constant.EndpointTypeOpenAI: + relayFormat = types.RelayFormatOpenAI + case constant.EndpointTypeOpenAIResponse: + relayFormat = types.RelayFormatOpenAIResponses + case constant.EndpointTypeAnthropic: + relayFormat = types.RelayFormatClaude + case constant.EndpointTypeGemini: + relayFormat = types.RelayFormatGemini + case constant.EndpointTypeJinaRerank: + relayFormat = types.RelayFormatRerank + case constant.EndpointTypeImageGeneration: + relayFormat = types.RelayFormatOpenAIImage + case constant.EndpointTypeEmbeddings: + relayFormat = types.RelayFormatEmbedding + default: + relayFormat = types.RelayFormatOpenAI + } + } else { + // 根据请求路径自动检测 + relayFormat = types.RelayFormatOpenAI + if c.Request.URL.Path == "/v1/embeddings" { + relayFormat = types.RelayFormatEmbedding + } + if c.Request.URL.Path == "/v1/images/generations" { + relayFormat = types.RelayFormatOpenAIImage + } + if c.Request.URL.Path == "/v1/messages" { + relayFormat = types.RelayFormatClaude + } + if strings.Contains(c.Request.URL.Path, "/v1beta/models") { + relayFormat = types.RelayFormatGemini + } + if c.Request.URL.Path == "/v1/rerank" || c.Request.URL.Path == "/rerank" { + relayFormat = types.RelayFormatRerank + } + if c.Request.URL.Path == "/v1/responses" { + relayFormat = types.RelayFormatOpenAIResponses + } } + request := buildTestRequest(testModel, endpointType) + info, err := relaycommon.GenRelayInfo(c, relayFormat, request, nil) if err != nil { @@ -186,7 +216,8 @@ func testChannel(channel *model.Channel, testModel string) testResult { } testModel = info.UpstreamModelName - request.Model = testModel + // 更新请求中的模型名称 + request.SetModelName(testModel) apiType, _ := common.ChannelType2APIType(channel.Type) adaptor := relay.GetAdaptor(apiType) @@ -216,33 +247,62 @@ func testChannel(channel *model.Channel, testModel string) testResult { var convertedRequest any // 根据 RelayMode 选择正确的转换函数 - if info.RelayMode == relayconstant.RelayModeEmbeddings { - // 创建一个 EmbeddingRequest - embeddingRequest := dto.EmbeddingRequest{ - Input: request.Input, - Model: request.Model, - } - // 调用专门用于 Embedding 的转换函数 - convertedRequest, err = adaptor.ConvertEmbeddingRequest(c, info, embeddingRequest) - } else if info.RelayMode == relayconstant.RelayModeImagesGenerations { - // 创建一个 ImageRequest - prompt := "cat" - if request.Prompt != nil { - if promptStr, ok := request.Prompt.(string); ok && promptStr != "" { - prompt = promptStr + switch info.RelayMode { + case relayconstant.RelayModeEmbeddings: + // Embedding 请求 - request 已经是正确的类型 + if embeddingReq, ok := request.(*dto.EmbeddingRequest); ok { + convertedRequest, err = adaptor.ConvertEmbeddingRequest(c, info, *embeddingReq) + } else { + return testResult{ + context: c, + localErr: errors.New("invalid embedding request type"), + newAPIError: types.NewError(errors.New("invalid embedding request type"), types.ErrorCodeConvertRequestFailed), } } - imageRequest := dto.ImageRequest{ - Prompt: prompt, - Model: request.Model, - N: uint(request.N), - Size: request.Size, + case relayconstant.RelayModeImagesGenerations: + // 图像生成请求 - request 已经是正确的类型 + if imageReq, ok := request.(*dto.ImageRequest); ok { + convertedRequest, err = adaptor.ConvertImageRequest(c, info, *imageReq) + } else { + return testResult{ + context: c, + localErr: errors.New("invalid image request type"), + newAPIError: types.NewError(errors.New("invalid image request type"), types.ErrorCodeConvertRequestFailed), + } + } + case relayconstant.RelayModeRerank: + // Rerank 请求 - request 已经是正确的类型 + if rerankReq, ok := request.(*dto.RerankRequest); ok { + convertedRequest, err = adaptor.ConvertRerankRequest(c, info.RelayMode, *rerankReq) + } else { + return testResult{ + context: c, + localErr: errors.New("invalid rerank request type"), + newAPIError: types.NewError(errors.New("invalid rerank request type"), types.ErrorCodeConvertRequestFailed), + } + } + case relayconstant.RelayModeResponses: + // Response 请求 - request 已经是正确的类型 + if responseReq, ok := request.(*dto.OpenAIResponsesRequest); ok { + convertedRequest, err = adaptor.ConvertOpenAIResponsesRequest(c, info, *responseReq) + } else { + return testResult{ + context: c, + localErr: errors.New("invalid response request type"), + newAPIError: types.NewError(errors.New("invalid response request type"), types.ErrorCodeConvertRequestFailed), + } + } + default: + // Chat/Completion 等其他请求类型 + if generalReq, ok := request.(*dto.GeneralOpenAIRequest); ok { + convertedRequest, err = adaptor.ConvertOpenAIRequest(c, info, generalReq) + } else { + return testResult{ + context: c, + localErr: errors.New("invalid general request type"), + newAPIError: types.NewError(errors.New("invalid general request type"), types.ErrorCodeConvertRequestFailed), + } } - // 调用专门用于图像生成的转换函数 - convertedRequest, err = adaptor.ConvertImageRequest(c, info, imageRequest) - } else { - // 对其他所有请求类型(如 Chat),保持原有逻辑 - convertedRequest, err = adaptor.ConvertOpenAIRequest(c, info, request) } if err != nil { @@ -345,22 +405,82 @@ func testChannel(channel *model.Channel, testModel string) testResult { } } -func buildTestRequest(model string) *dto.GeneralOpenAIRequest { - testRequest := &dto.GeneralOpenAIRequest{ - Model: "", // this will be set later - Stream: false, +func buildTestRequest(model string, endpointType string) dto.Request { + // 根据端点类型构建不同的测试请求 + if endpointType != "" { + switch constant.EndpointType(endpointType) { + case constant.EndpointTypeEmbeddings: + // 返回 EmbeddingRequest + return &dto.EmbeddingRequest{ + Model: model, + Input: []any{"hello world"}, + } + case constant.EndpointTypeImageGeneration: + // 返回 ImageRequest + return &dto.ImageRequest{ + Model: model, + Prompt: "a cute cat", + N: 1, + Size: "1024x1024", + } + case constant.EndpointTypeJinaRerank: + // 返回 RerankRequest + return &dto.RerankRequest{ + Model: model, + Query: "What is Deep Learning?", + Documents: []any{"Deep Learning is a subset of machine learning.", "Machine learning is a field of artificial intelligence."}, + TopN: 2, + } + case constant.EndpointTypeOpenAIResponse: + // 返回 OpenAIResponsesRequest + return &dto.OpenAIResponsesRequest{ + Model: model, + Input: json.RawMessage("\"hi\""), + } + case constant.EndpointTypeAnthropic, constant.EndpointTypeGemini, constant.EndpointTypeOpenAI: + // 返回 GeneralOpenAIRequest + maxTokens := uint(10) + if constant.EndpointType(endpointType) == constant.EndpointTypeGemini { + maxTokens = 3000 + } + return &dto.GeneralOpenAIRequest{ + Model: model, + Stream: false, + Messages: []dto.Message{ + { + Role: "user", + Content: "hi", + }, + }, + MaxTokens: maxTokens, + } + } } + // 自动检测逻辑(保持原有行为) // 先判断是否为 Embedding 模型 - if strings.Contains(strings.ToLower(model), "embedding") || // 其他 embedding 模型 - strings.HasPrefix(model, "m3e") || // m3e 系列模型 + if strings.Contains(strings.ToLower(model), "embedding") || + strings.HasPrefix(model, "m3e") || strings.Contains(model, "bge-") { - testRequest.Model = model - // Embedding 请求 - testRequest.Input = []any{"hello world"} // 修改为any,因为dto/openai_request.go 的ParseInput方法无法处理[]string类型 - return testRequest + // 返回 EmbeddingRequest + return &dto.EmbeddingRequest{ + Model: model, + Input: []any{"hello world"}, + } } - // 并非Embedding 模型 + + // Chat/Completion 请求 - 返回 GeneralOpenAIRequest + testRequest := &dto.GeneralOpenAIRequest{ + Model: model, + Stream: false, + Messages: []dto.Message{ + { + Role: "user", + Content: "hi", + }, + }, + } + if strings.HasPrefix(model, "o") { testRequest.MaxCompletionTokens = 10 } else if strings.Contains(model, "thinking") { @@ -373,12 +493,6 @@ func buildTestRequest(model string) *dto.GeneralOpenAIRequest { testRequest.MaxTokens = 10 } - testMessage := dto.Message{ - Role: "user", - Content: "hi", - } - testRequest.Model = model - testRequest.Messages = append(testRequest.Messages, testMessage) return testRequest } @@ -402,8 +516,9 @@ func TestChannel(c *gin.Context) { // } //}() testModel := c.Query("model") + endpointType := c.Query("endpoint_type") tik := time.Now() - result := testChannel(channel, testModel) + result := testChannel(channel, testModel, endpointType) if result.localErr != nil { c.JSON(http.StatusOK, gin.H{ "success": false, @@ -429,7 +544,6 @@ func TestChannel(c *gin.Context) { "message": "", "time": consumedTime, }) - return } var testAllChannelsLock sync.Mutex @@ -463,7 +577,7 @@ func testAllChannels(notify bool) error { for _, channel := range channels { isChannelEnabled := channel.Status == common.ChannelStatusEnabled tik := time.Now() - result := testChannel(channel, "") + result := testChannel(channel, "", "") tok := time.Now() milliseconds := tok.Sub(tik).Milliseconds() @@ -477,7 +591,7 @@ func testAllChannels(notify bool) error { // 当错误检查通过,才检查响应时间 if common.AutomaticDisableChannelEnabled && !shouldBanChannel { if milliseconds > disableThreshold { - err := errors.New(fmt.Sprintf("响应时间 %.2fs 超过阈值 %.2fs", float64(milliseconds)/1000.0, float64(disableThreshold)/1000.0)) + err := fmt.Errorf("响应时间 %.2fs 超过阈值 %.2fs", float64(milliseconds)/1000.0, float64(disableThreshold)/1000.0) newAPIError = types.NewOpenAIError(err, types.ErrorCodeChannelResponseTimeExceeded, http.StatusRequestTimeout) shouldBanChannel = true } @@ -514,7 +628,6 @@ func TestAllChannels(c *gin.Context) { "success": true, "message": "", }) - return } var autoTestChannelsOnce sync.Once diff --git a/web/src/components/table/channels/modals/ModelTestModal.jsx b/web/src/components/table/channels/modals/ModelTestModal.jsx index c643ed100..7cc56612d 100644 --- a/web/src/components/table/channels/modals/ModelTestModal.jsx +++ b/web/src/components/table/channels/modals/ModelTestModal.jsx @@ -25,6 +25,7 @@ import { Table, Tag, Typography, + Select, } from '@douyinfe/semi-ui'; import { IconSearch } from '@douyinfe/semi-icons'; import { copy, showError, showInfo, showSuccess } from '../../../../helpers'; @@ -45,6 +46,8 @@ const ModelTestModal = ({ testChannel, modelTablePage, setModelTablePage, + selectedEndpointType, + setSelectedEndpointType, allSelectingRef, isMobile, t, @@ -59,6 +62,17 @@ const ModelTestModal = ({ ) : []; + const endpointTypeOptions = [ + { value: '', label: t('自动检测') }, + { value: 'openai', label: 'OpenAI (/v1/chat/completions)' }, + { value: 'openai-response', label: 'OpenAI Response (/v1/responses)' }, + { value: 'anthropic', label: 'Anthropic (/v1/messages)' }, + { value: 'gemini', label: 'Gemini (/v1beta/models/{model}:generateContent)' }, + { value: 'jina-rerank', label: 'Jina Rerank (/rerank)' }, + { value: 'image-generation', label: t('图像生成') + ' (/v1/images/generations)' }, + { value: 'embeddings', label: 'Embeddings (/v1/embeddings)' }, + ]; + const handleCopySelected = () => { if (selectedModelKeys.length === 0) { showError(t('请先选择模型!')); @@ -152,7 +166,7 @@ const ModelTestModal = ({ return (