diff --git a/controller/channel-test.go b/controller/channel-test.go index 5ae04e8a0..2f2d45012 100644 --- a/controller/channel-test.go +++ b/controller/channel-test.go @@ -31,6 +31,7 @@ import ( "github.com/bytedance/gopkg/util/gopool" "github.com/samber/lo" + "github.com/tidwall/gjson" "github.com/gin-gonic/gin" ) @@ -41,7 +42,21 @@ type testResult struct { newAPIError *types.NewAPIError } -func testChannel(channel *model.Channel, testModel string, endpointType string) testResult { +func normalizeChannelTestEndpoint(channel *model.Channel, modelName, endpointType string) string { + normalized := strings.TrimSpace(endpointType) + if normalized != "" { + return normalized + } + if strings.HasSuffix(modelName, ratio_setting.CompactModelSuffix) { + return string(constant.EndpointTypeOpenAIResponseCompact) + } + if channel != nil && channel.Type == constant.ChannelTypeCodex { + return string(constant.EndpointTypeOpenAIResponse) + } + return normalized +} + +func testChannel(channel *model.Channel, testModel string, endpointType string, isStream bool) testResult { tik := time.Now() var unsupportedTestChannelTypes = []int{ constant.ChannelTypeMidjourney, @@ -76,6 +91,8 @@ func testChannel(channel *model.Channel, testModel string, endpointType string) } } + endpointType = normalizeChannelTestEndpoint(channel, testModel, endpointType) + requestPath := "/v1/chat/completions" // 如果指定了端点类型,使用指定的端点类型 @@ -200,7 +217,7 @@ func testChannel(channel *model.Channel, testModel string, endpointType string) } } - request := buildTestRequest(testModel, endpointType, channel) + request := buildTestRequest(testModel, endpointType, channel, isStream) info, err := relaycommon.GenRelayInfo(c, relayFormat, request, nil) @@ -418,16 +435,16 @@ func testChannel(channel *model.Channel, testModel string, endpointType string) newAPIError: respErr, } } - if usageA == nil { + usage, usageErr := coerceTestUsage(usageA, isStream, info.GetEstimatePromptTokens()) + if usageErr != nil { return testResult{ context: c, - localErr: errors.New("usage is nil"), - newAPIError: types.NewOpenAIError(errors.New("usage is nil"), types.ErrorCodeBadResponseBody, http.StatusInternalServerError), + localErr: usageErr, + newAPIError: types.NewOpenAIError(usageErr, types.ErrorCodeBadResponseBody, http.StatusInternalServerError), } } - usage := usageA.(*dto.Usage) result := w.Result() - respBody, err := io.ReadAll(result.Body) + respBody, err := readTestResponseBody(result.Body, isStream) if err != nil { return testResult{ context: c, @@ -435,6 +452,13 @@ func testChannel(channel *model.Channel, testModel string, endpointType string) newAPIError: types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError), } } + if bodyErr := detectErrorFromTestResponseBody(respBody); bodyErr != nil { + return testResult{ + context: c, + localErr: bodyErr, + newAPIError: types.NewOpenAIError(bodyErr, types.ErrorCodeBadResponseBody, http.StatusInternalServerError), + } + } info.SetEstimatePromptTokens(usage.PromptTokens) quota := 0 @@ -473,7 +497,101 @@ func testChannel(channel *model.Channel, testModel string, endpointType string) } } -func buildTestRequest(model string, endpointType string, channel *model.Channel) dto.Request { +func coerceTestUsage(usageAny any, isStream bool, estimatePromptTokens int) (*dto.Usage, error) { + switch u := usageAny.(type) { + case *dto.Usage: + return u, nil + case dto.Usage: + return &u, nil + case nil: + if !isStream { + return nil, errors.New("usage is nil") + } + usage := &dto.Usage{ + PromptTokens: estimatePromptTokens, + } + usage.TotalTokens = usage.PromptTokens + return usage, nil + default: + if !isStream { + return nil, fmt.Errorf("invalid usage type: %T", usageAny) + } + usage := &dto.Usage{ + PromptTokens: estimatePromptTokens, + } + usage.TotalTokens = usage.PromptTokens + return usage, nil + } +} + +func readTestResponseBody(body io.ReadCloser, isStream bool) ([]byte, error) { + defer func() { _ = body.Close() }() + const maxStreamLogBytes = 8 << 10 + if isStream { + return io.ReadAll(io.LimitReader(body, maxStreamLogBytes)) + } + return io.ReadAll(body) +} + +func detectErrorFromTestResponseBody(respBody []byte) error { + b := bytes.TrimSpace(respBody) + if len(b) == 0 { + return nil + } + if message := detectErrorMessageFromJSONBytes(b); message != "" { + return fmt.Errorf("upstream error: %s", message) + } + + for _, line := range bytes.Split(b, []byte{'\n'}) { + line = bytes.TrimSpace(line) + if len(line) == 0 { + continue + } + if !bytes.HasPrefix(line, []byte("data:")) { + continue + } + payload := bytes.TrimSpace(bytes.TrimPrefix(line, []byte("data:"))) + if len(payload) == 0 || bytes.Equal(payload, []byte("[DONE]")) { + continue + } + if message := detectErrorMessageFromJSONBytes(payload); message != "" { + return fmt.Errorf("upstream error: %s", message) + } + } + + return nil +} + +func detectErrorMessageFromJSONBytes(jsonBytes []byte) string { + if len(jsonBytes) == 0 { + return "" + } + if jsonBytes[0] != '{' && jsonBytes[0] != '[' { + return "" + } + errVal := gjson.GetBytes(jsonBytes, "error") + if !errVal.Exists() || errVal.Type == gjson.Null { + return "" + } + + message := gjson.GetBytes(jsonBytes, "error.message").String() + if message == "" { + message = gjson.GetBytes(jsonBytes, "error.error.message").String() + } + if message == "" && errVal.Type == gjson.String { + message = errVal.String() + } + if message == "" { + message = errVal.Raw + } + message = strings.TrimSpace(message) + if message == "" { + return "upstream returned error payload" + } + return message +} + +func buildTestRequest(model string, endpointType string, channel *model.Channel, isStream bool) dto.Request { testResponsesInput := json.RawMessage(`[{"role":"user","content":"hi"}]`) // 根据端点类型构建不同的测试请求 @@ -504,8 +622,9 @@ func buildTestRequest(model string, endpointType string, channel *model.Channel) case constant.EndpointTypeOpenAIResponse: // 返回 OpenAIResponsesRequest return &dto.OpenAIResponsesRequest{ - Model: model, - Input: json.RawMessage(`[{"role":"user","content":"hi"}]`), + Model: model, + Input: json.RawMessage(`[{"role":"user","content":"hi"}]`), + Stream: isStream, } case constant.EndpointTypeOpenAIResponseCompact: // 返回 OpenAIResponsesCompactionRequest @@ -519,9 +638,9 @@ func buildTestRequest(model string, endpointType string, channel *model.Channel) if constant.EndpointType(endpointType) == constant.EndpointTypeGemini { maxTokens = 3000 } - return &dto.GeneralOpenAIRequest{ + req := &dto.GeneralOpenAIRequest{ Model: model, - Stream: false, + Stream: isStream, Messages: []dto.Message{ { Role: "user", @@ -530,6 +649,10 @@ func buildTestRequest(model string, endpointType string, channel *model.Channel) }, MaxTokens: maxTokens, } + if isStream { + req.StreamOptions = &dto.StreamOptions{IncludeUsage: true} + } + return req } } @@ -565,15 +688,16 @@ func buildTestRequest(model string, endpointType string, channel *model.Channel) // Responses-only models (e.g. codex series) if strings.Contains(strings.ToLower(model), "codex") { return &dto.OpenAIResponsesRequest{ - Model: model, - Input: json.RawMessage(`[{"role":"user","content":"hi"}]`), + Model: model, + Input: json.RawMessage(`[{"role":"user","content":"hi"}]`), + Stream: isStream, } } // Chat/Completion 请求 - 返回 GeneralOpenAIRequest testRequest := &dto.GeneralOpenAIRequest{ Model: model, - Stream: false, + Stream: isStream, Messages: []dto.Message{ { Role: "user", @@ -581,6 +705,9 @@ func buildTestRequest(model string, endpointType string, channel *model.Channel) }, }, } + if isStream { + testRequest.StreamOptions = &dto.StreamOptions{IncludeUsage: true} + } if strings.HasPrefix(model, "o") { testRequest.MaxCompletionTokens = 16 @@ -618,8 +745,9 @@ func TestChannel(c *gin.Context) { //}() testModel := c.Query("model") endpointType := c.Query("endpoint_type") + isStream, _ := strconv.ParseBool(c.Query("stream")) tik := time.Now() - result := testChannel(channel, testModel, endpointType) + result := testChannel(channel, testModel, endpointType, isStream) if result.localErr != nil { c.JSON(http.StatusOK, gin.H{ "success": false, @@ -678,7 +806,7 @@ func testAllChannels(notify bool) error { for _, channel := range channels { isChannelEnabled := channel.Status == common.ChannelStatusEnabled tik := time.Now() - result := testChannel(channel, "", "") + result := testChannel(channel, "", "", false) tok := time.Now() milliseconds := tok.Sub(tik).Milliseconds() diff --git a/web/src/components/table/channels/modals/ModelTestModal.jsx b/web/src/components/table/channels/modals/ModelTestModal.jsx index 47aa66cbe..490cf54be 100644 --- a/web/src/components/table/channels/modals/ModelTestModal.jsx +++ b/web/src/components/table/channels/modals/ModelTestModal.jsx @@ -26,8 +26,10 @@ import { Tag, Typography, Select, + Switch, + Banner, } from '@douyinfe/semi-ui'; -import { IconSearch } from '@douyinfe/semi-icons'; +import { IconSearch, IconInfoCircle } from '@douyinfe/semi-icons'; import { copy, showError, showInfo, showSuccess } from '../../../../helpers'; import { MODEL_TABLE_PAGE_SIZE } from '../../../../constants'; @@ -48,11 +50,25 @@ const ModelTestModal = ({ setModelTablePage, selectedEndpointType, setSelectedEndpointType, + isStreamTest, + setIsStreamTest, allSelectingRef, isMobile, t, }) => { const hasChannel = Boolean(currentTestChannel); + const streamToggleDisabled = [ + 'embeddings', + 'image-generation', + 'jina-rerank', + 'openai-response-compact', + ].includes(selectedEndpointType); + + React.useEffect(() => { + if (streamToggleDisabled && isStreamTest) { + setIsStreamTest(false); + } + }, [streamToggleDisabled, isStreamTest, setIsStreamTest]); const filteredModels = hasChannel ? currentTestChannel.models @@ -181,6 +197,7 @@ const ModelTestModal = ({ currentTestChannel, record.model, selectedEndpointType, + isStreamTest, ) } loading={isTesting} @@ -258,25 +275,46 @@ const ModelTestModal = ({ > {hasChannel && (
- {/* 端点类型选择器 */} -
- {t('端点类型')}: - +
+
+ + {t('流式')}: + + +
- - {t( + + } + className='!rounded-lg mb-2' + description={t( '说明:本页测试为非流式请求;若渠道仅支持流式返回,可能出现测试失败,请以实际使用为准。', )} - + /> {/* 搜索与操作按钮 */} -
+
} showClear /> - - - +
+ + +
{ const [isBatchTesting, setIsBatchTesting] = useState(false); const [modelTablePage, setModelTablePage] = useState(1); const [selectedEndpointType, setSelectedEndpointType] = useState(''); + const [isStreamTest, setIsStreamTest] = useState(false); const [globalPassThroughEnabled, setGlobalPassThroughEnabled] = useState(false); @@ -851,7 +852,12 @@ export const useChannelsData = () => { }; // Test channel - 单个模型测试,参考旧版实现 - const testChannel = async (record, model, endpointType = '') => { + const testChannel = async ( + record, + model, + endpointType = '', + stream = false, + ) => { const testKey = `${record.id}-${model}`; // 检查是否应该停止批量测试 @@ -867,6 +873,9 @@ export const useChannelsData = () => { if (endpointType) { url += `&endpoint_type=${endpointType}`; } + if (stream) { + url += `&stream=true`; + } const res = await API.get(url); // 检查是否在请求期间被停止 @@ -995,7 +1004,12 @@ export const useChannelsData = () => { ); const batchPromises = batch.map((model) => - testChannel(currentTestChannel, model, selectedEndpointType), + testChannel( + currentTestChannel, + model, + selectedEndpointType, + isStreamTest, + ), ); const batchResults = await Promise.allSettled(batchPromises); results.push(...batchResults); @@ -1080,6 +1094,7 @@ export const useChannelsData = () => { setSelectedModelKeys([]); setModelTablePage(1); setSelectedEndpointType(''); + setIsStreamTest(false); // 可选择性保留测试结果,这里不清空以便用户查看 }; @@ -1170,6 +1185,8 @@ export const useChannelsData = () => { setModelTablePage, selectedEndpointType, setSelectedEndpointType, + isStreamTest, + setIsStreamTest, allSelectingRef, // Multi-key management states diff --git a/web/src/i18n/locales/en.json b/web/src/i18n/locales/en.json index aaaac3a0d..b6cff6096 100644 --- a/web/src/i18n/locales/en.json +++ b/web/src/i18n/locales/en.json @@ -1548,6 +1548,7 @@ "流": "stream", "流式响应完成": "Streaming response completed", "流式输出": "Streaming Output", + "流式": "Streaming", "流量端口": "Traffic Port", "浅色": "Light", "浅色模式": "Light Mode", diff --git a/web/src/i18n/locales/fr.json b/web/src/i18n/locales/fr.json index b811e202e..65f2298d2 100644 --- a/web/src/i18n/locales/fr.json +++ b/web/src/i18n/locales/fr.json @@ -1558,6 +1558,7 @@ "流": "Flux", "流式响应完成": "Flux terminé", "流式输出": "Sortie en flux", + "流式": "Streaming", "流量端口": "Traffic Port", "浅色": "Clair", "浅色模式": "Mode clair", diff --git a/web/src/i18n/locales/ja.json b/web/src/i18n/locales/ja.json index 83cdd48e5..5b1090dff 100644 --- a/web/src/i18n/locales/ja.json +++ b/web/src/i18n/locales/ja.json @@ -1543,6 +1543,7 @@ "流": "ストリーム", "流式响应完成": "ストリーム完了", "流式输出": "ストリーム出力", + "流式": "ストリーミング", "流量端口": "Traffic Port", "浅色": "ライト", "浅色模式": "ライトモード", diff --git a/web/src/i18n/locales/ru.json b/web/src/i18n/locales/ru.json index c33128b88..810837004 100644 --- a/web/src/i18n/locales/ru.json +++ b/web/src/i18n/locales/ru.json @@ -1569,6 +1569,7 @@ "流": "Поток", "流式响应完成": "Поток завершён", "流式输出": "Потоковый вывод", + "流式": "Стриминг", "流量端口": "Traffic Port", "浅色": "Светлая", "浅色模式": "Светлый режим", diff --git a/web/src/i18n/locales/vi.json b/web/src/i18n/locales/vi.json index 979d870b4..54b9a0c01 100644 --- a/web/src/i18n/locales/vi.json +++ b/web/src/i18n/locales/vi.json @@ -1597,6 +1597,7 @@ "流": "luồng", "流式响应完成": "Luồng hoàn tất", "流式输出": "Đầu ra luồng", + "流式": "Streaming", "流量端口": "Traffic Port", "浅色": "Sáng", "浅色模式": "Chế độ sáng", diff --git a/web/src/i18n/locales/zh.json b/web/src/i18n/locales/zh.json index dd2f689e1..739a1c7f7 100644 --- a/web/src/i18n/locales/zh.json +++ b/web/src/i18n/locales/zh.json @@ -1538,6 +1538,7 @@ "流": "流", "流式响应完成": "流式响应完成", "流式输出": "流式输出", + "流式": "流式", "流量端口": "流量端口", "浅色": "浅色", "浅色模式": "浅色模式",