diff --git a/controller/channel-test.go b/controller/channel-test.go index 5ae04e8a0..2f2d45012 100644 --- a/controller/channel-test.go +++ b/controller/channel-test.go @@ -31,6 +31,7 @@ import ( "github.com/bytedance/gopkg/util/gopool" "github.com/samber/lo" + "github.com/tidwall/gjson" "github.com/gin-gonic/gin" ) @@ -41,7 +42,21 @@ type testResult struct { newAPIError *types.NewAPIError } -func testChannel(channel *model.Channel, testModel string, endpointType string) testResult { +func normalizeChannelTestEndpoint(channel *model.Channel, modelName, endpointType string) string { + normalized := strings.TrimSpace(endpointType) + if normalized != "" { + return normalized + } + if strings.HasSuffix(modelName, ratio_setting.CompactModelSuffix) { + return string(constant.EndpointTypeOpenAIResponseCompact) + } + if channel != nil && channel.Type == constant.ChannelTypeCodex { + return string(constant.EndpointTypeOpenAIResponse) + } + return normalized +} + +func testChannel(channel *model.Channel, testModel string, endpointType string, isStream bool) testResult { tik := time.Now() var unsupportedTestChannelTypes = []int{ constant.ChannelTypeMidjourney, @@ -76,6 +91,8 @@ func testChannel(channel *model.Channel, testModel string, endpointType string) } } + endpointType = normalizeChannelTestEndpoint(channel, testModel, endpointType) + requestPath := "/v1/chat/completions" // 如果指定了端点类型,使用指定的端点类型 @@ -200,7 +217,7 @@ func testChannel(channel *model.Channel, testModel string, endpointType string) } } - request := buildTestRequest(testModel, endpointType, channel) + request := buildTestRequest(testModel, endpointType, channel, isStream) info, err := relaycommon.GenRelayInfo(c, relayFormat, request, nil) @@ -418,16 +435,16 @@ func testChannel(channel *model.Channel, testModel string, endpointType string) newAPIError: respErr, } } - if usageA == nil { + usage, usageErr := coerceTestUsage(usageA, isStream, info.GetEstimatePromptTokens()) + if usageErr != nil { return testResult{ context: c, - localErr: errors.New("usage is nil"), - newAPIError: types.NewOpenAIError(errors.New("usage is nil"), types.ErrorCodeBadResponseBody, http.StatusInternalServerError), + localErr: usageErr, + newAPIError: types.NewOpenAIError(usageErr, types.ErrorCodeBadResponseBody, http.StatusInternalServerError), } } - usage := usageA.(*dto.Usage) result := w.Result() - respBody, err := io.ReadAll(result.Body) + respBody, err := readTestResponseBody(result.Body, isStream) if err != nil { return testResult{ context: c, @@ -435,6 +452,13 @@ func testChannel(channel *model.Channel, testModel string, endpointType string) newAPIError: types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError), } } + if bodyErr := detectErrorFromTestResponseBody(respBody); bodyErr != nil { + return testResult{ + context: c, + localErr: bodyErr, + newAPIError: types.NewOpenAIError(bodyErr, types.ErrorCodeBadResponseBody, http.StatusInternalServerError), + } + } info.SetEstimatePromptTokens(usage.PromptTokens) quota := 0 @@ -473,7 +497,101 @@ func testChannel(channel *model.Channel, testModel string, endpointType string) } } -func buildTestRequest(model string, endpointType string, channel *model.Channel) dto.Request { +func coerceTestUsage(usageAny any, isStream bool, estimatePromptTokens int) (*dto.Usage, error) { + switch u := usageAny.(type) { + case *dto.Usage: + return u, nil + case dto.Usage: + return &u, nil + case nil: + if !isStream { + return nil, errors.New("usage is nil") + } + usage := &dto.Usage{ + PromptTokens: estimatePromptTokens, + } + usage.TotalTokens = usage.PromptTokens + return usage, nil + default: + if !isStream { + return nil, fmt.Errorf("invalid usage type: %T", usageAny) + } + usage := &dto.Usage{ + PromptTokens: estimatePromptTokens, + } + usage.TotalTokens = usage.PromptTokens + return usage, nil + } +} + +func readTestResponseBody(body io.ReadCloser, isStream bool) ([]byte, error) { + defer func() { _ = body.Close() }() + const maxStreamLogBytes = 8 << 10 + if isStream { + return io.ReadAll(io.LimitReader(body, maxStreamLogBytes)) + } + return io.ReadAll(body) +} + +func detectErrorFromTestResponseBody(respBody []byte) error { + b := bytes.TrimSpace(respBody) + if len(b) == 0 { + return nil + } + if message := detectErrorMessageFromJSONBytes(b); message != "" { + return fmt.Errorf("upstream error: %s", message) + } + + for _, line := range bytes.Split(b, []byte{'\n'}) { + line = bytes.TrimSpace(line) + if len(line) == 0 { + continue + } + if !bytes.HasPrefix(line, []byte("data:")) { + continue + } + payload := bytes.TrimSpace(bytes.TrimPrefix(line, []byte("data:"))) + if len(payload) == 0 || bytes.Equal(payload, []byte("[DONE]")) { + continue + } + if message := detectErrorMessageFromJSONBytes(payload); message != "" { + return fmt.Errorf("upstream error: %s", message) + } + } + + return nil +} + +func detectErrorMessageFromJSONBytes(jsonBytes []byte) string { + if len(jsonBytes) == 0 { + return "" + } + if jsonBytes[0] != '{' && jsonBytes[0] != '[' { + return "" + } + errVal := gjson.GetBytes(jsonBytes, "error") + if !errVal.Exists() || errVal.Type == gjson.Null { + return "" + } + + message := gjson.GetBytes(jsonBytes, "error.message").String() + if message == "" { + message = gjson.GetBytes(jsonBytes, "error.error.message").String() + } + if message == "" && errVal.Type == gjson.String { + message = errVal.String() + } + if message == "" { + message = errVal.Raw + } + message = strings.TrimSpace(message) + if message == "" { + return "upstream returned error payload" + } + return message +} + +func buildTestRequest(model string, endpointType string, channel *model.Channel, isStream bool) dto.Request { testResponsesInput := json.RawMessage(`[{"role":"user","content":"hi"}]`) // 根据端点类型构建不同的测试请求 @@ -504,8 +622,9 @@ func buildTestRequest(model string, endpointType string, channel *model.Channel) case constant.EndpointTypeOpenAIResponse: // 返回 OpenAIResponsesRequest return &dto.OpenAIResponsesRequest{ - Model: model, - Input: json.RawMessage(`[{"role":"user","content":"hi"}]`), + Model: model, + Input: json.RawMessage(`[{"role":"user","content":"hi"}]`), + Stream: isStream, } case constant.EndpointTypeOpenAIResponseCompact: // 返回 OpenAIResponsesCompactionRequest @@ -519,9 +638,9 @@ func buildTestRequest(model string, endpointType string, channel *model.Channel) if constant.EndpointType(endpointType) == constant.EndpointTypeGemini { maxTokens = 3000 } - return &dto.GeneralOpenAIRequest{ + req := &dto.GeneralOpenAIRequest{ Model: model, - Stream: false, + Stream: isStream, Messages: []dto.Message{ { Role: "user", @@ -530,6 +649,10 @@ func buildTestRequest(model string, endpointType string, channel *model.Channel) }, MaxTokens: maxTokens, } + if isStream { + req.StreamOptions = &dto.StreamOptions{IncludeUsage: true} + } + return req } } @@ -565,15 +688,16 @@ func buildTestRequest(model string, endpointType string, channel *model.Channel) // Responses-only models (e.g. codex series) if strings.Contains(strings.ToLower(model), "codex") { return &dto.OpenAIResponsesRequest{ - Model: model, - Input: json.RawMessage(`[{"role":"user","content":"hi"}]`), + Model: model, + Input: json.RawMessage(`[{"role":"user","content":"hi"}]`), + Stream: isStream, } } // Chat/Completion 请求 - 返回 GeneralOpenAIRequest testRequest := &dto.GeneralOpenAIRequest{ Model: model, - Stream: false, + Stream: isStream, Messages: []dto.Message{ { Role: "user", @@ -581,6 +705,9 @@ func buildTestRequest(model string, endpointType string, channel *model.Channel) }, }, } + if isStream { + testRequest.StreamOptions = &dto.StreamOptions{IncludeUsage: true} + } if strings.HasPrefix(model, "o") { testRequest.MaxCompletionTokens = 16 @@ -618,8 +745,9 @@ func TestChannel(c *gin.Context) { //}() testModel := c.Query("model") endpointType := c.Query("endpoint_type") + isStream, _ := strconv.ParseBool(c.Query("stream")) tik := time.Now() - result := testChannel(channel, testModel, endpointType) + result := testChannel(channel, testModel, endpointType, isStream) if result.localErr != nil { c.JSON(http.StatusOK, gin.H{ "success": false, @@ -678,7 +806,7 @@ func testAllChannels(notify bool) error { for _, channel := range channels { isChannelEnabled := channel.Status == common.ChannelStatusEnabled tik := time.Now() - result := testChannel(channel, "", "") + result := testChannel(channel, "", "", false) tok := time.Now() milliseconds := tok.Sub(tik).Milliseconds() diff --git a/web/src/components/table/channels/modals/ModelTestModal.jsx b/web/src/components/table/channels/modals/ModelTestModal.jsx index 47aa66cbe..490cf54be 100644 --- a/web/src/components/table/channels/modals/ModelTestModal.jsx +++ b/web/src/components/table/channels/modals/ModelTestModal.jsx @@ -26,8 +26,10 @@ import { Tag, Typography, Select, + Switch, + Banner, } from '@douyinfe/semi-ui'; -import { IconSearch } from '@douyinfe/semi-icons'; +import { IconSearch, IconInfoCircle } from '@douyinfe/semi-icons'; import { copy, showError, showInfo, showSuccess } from '../../../../helpers'; import { MODEL_TABLE_PAGE_SIZE } from '../../../../constants'; @@ -48,11 +50,25 @@ const ModelTestModal = ({ setModelTablePage, selectedEndpointType, setSelectedEndpointType, + isStreamTest, + setIsStreamTest, allSelectingRef, isMobile, t, }) => { const hasChannel = Boolean(currentTestChannel); + const streamToggleDisabled = [ + 'embeddings', + 'image-generation', + 'jina-rerank', + 'openai-response-compact', + ].includes(selectedEndpointType); + + React.useEffect(() => { + if (streamToggleDisabled && isStreamTest) { + setIsStreamTest(false); + } + }, [streamToggleDisabled, isStreamTest, setIsStreamTest]); const filteredModels = hasChannel ? currentTestChannel.models @@ -181,6 +197,7 @@ const ModelTestModal = ({ currentTestChannel, record.model, selectedEndpointType, + isStreamTest, ) } loading={isTesting} @@ -258,25 +275,46 @@ const ModelTestModal = ({ > {hasChannel && (