Merge pull request #2875 from seefs001/feature/channel-test-stream

feat: channel test with stream=true
This commit is contained in:
Calcium-Ion
2026-02-08 00:17:07 +08:00
committed by GitHub
9 changed files with 230 additions and 40 deletions

View File

@@ -31,6 +31,7 @@ import (
"github.com/bytedance/gopkg/util/gopool" "github.com/bytedance/gopkg/util/gopool"
"github.com/samber/lo" "github.com/samber/lo"
"github.com/tidwall/gjson"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
) )
@@ -41,7 +42,21 @@ type testResult struct {
newAPIError *types.NewAPIError newAPIError *types.NewAPIError
} }
func testChannel(channel *model.Channel, testModel string, endpointType string) testResult { func normalizeChannelTestEndpoint(channel *model.Channel, modelName, endpointType string) string {
normalized := strings.TrimSpace(endpointType)
if normalized != "" {
return normalized
}
if strings.HasSuffix(modelName, ratio_setting.CompactModelSuffix) {
return string(constant.EndpointTypeOpenAIResponseCompact)
}
if channel != nil && channel.Type == constant.ChannelTypeCodex {
return string(constant.EndpointTypeOpenAIResponse)
}
return normalized
}
func testChannel(channel *model.Channel, testModel string, endpointType string, isStream bool) testResult {
tik := time.Now() tik := time.Now()
var unsupportedTestChannelTypes = []int{ var unsupportedTestChannelTypes = []int{
constant.ChannelTypeMidjourney, constant.ChannelTypeMidjourney,
@@ -76,6 +91,8 @@ func testChannel(channel *model.Channel, testModel string, endpointType string)
} }
} }
endpointType = normalizeChannelTestEndpoint(channel, testModel, endpointType)
requestPath := "/v1/chat/completions" requestPath := "/v1/chat/completions"
// 如果指定了端点类型,使用指定的端点类型 // 如果指定了端点类型,使用指定的端点类型
@@ -200,7 +217,7 @@ func testChannel(channel *model.Channel, testModel string, endpointType string)
} }
} }
request := buildTestRequest(testModel, endpointType, channel) request := buildTestRequest(testModel, endpointType, channel, isStream)
info, err := relaycommon.GenRelayInfo(c, relayFormat, request, nil) info, err := relaycommon.GenRelayInfo(c, relayFormat, request, nil)
@@ -418,16 +435,16 @@ func testChannel(channel *model.Channel, testModel string, endpointType string)
newAPIError: respErr, newAPIError: respErr,
} }
} }
if usageA == nil { usage, usageErr := coerceTestUsage(usageA, isStream, info.GetEstimatePromptTokens())
if usageErr != nil {
return testResult{ return testResult{
context: c, context: c,
localErr: errors.New("usage is nil"), localErr: usageErr,
newAPIError: types.NewOpenAIError(errors.New("usage is nil"), types.ErrorCodeBadResponseBody, http.StatusInternalServerError), newAPIError: types.NewOpenAIError(usageErr, types.ErrorCodeBadResponseBody, http.StatusInternalServerError),
} }
} }
usage := usageA.(*dto.Usage)
result := w.Result() result := w.Result()
respBody, err := io.ReadAll(result.Body) respBody, err := readTestResponseBody(result.Body, isStream)
if err != nil { if err != nil {
return testResult{ return testResult{
context: c, context: c,
@@ -435,6 +452,13 @@ func testChannel(channel *model.Channel, testModel string, endpointType string)
newAPIError: types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError), newAPIError: types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError),
} }
} }
if bodyErr := detectErrorFromTestResponseBody(respBody); bodyErr != nil {
return testResult{
context: c,
localErr: bodyErr,
newAPIError: types.NewOpenAIError(bodyErr, types.ErrorCodeBadResponseBody, http.StatusInternalServerError),
}
}
info.SetEstimatePromptTokens(usage.PromptTokens) info.SetEstimatePromptTokens(usage.PromptTokens)
quota := 0 quota := 0
@@ -473,7 +497,101 @@ func testChannel(channel *model.Channel, testModel string, endpointType string)
} }
} }
func buildTestRequest(model string, endpointType string, channel *model.Channel) dto.Request { func coerceTestUsage(usageAny any, isStream bool, estimatePromptTokens int) (*dto.Usage, error) {
switch u := usageAny.(type) {
case *dto.Usage:
return u, nil
case dto.Usage:
return &u, nil
case nil:
if !isStream {
return nil, errors.New("usage is nil")
}
usage := &dto.Usage{
PromptTokens: estimatePromptTokens,
}
usage.TotalTokens = usage.PromptTokens
return usage, nil
default:
if !isStream {
return nil, fmt.Errorf("invalid usage type: %T", usageAny)
}
usage := &dto.Usage{
PromptTokens: estimatePromptTokens,
}
usage.TotalTokens = usage.PromptTokens
return usage, nil
}
}
func readTestResponseBody(body io.ReadCloser, isStream bool) ([]byte, error) {
defer func() { _ = body.Close() }()
const maxStreamLogBytes = 8 << 10
if isStream {
return io.ReadAll(io.LimitReader(body, maxStreamLogBytes))
}
return io.ReadAll(body)
}
// detectErrorFromTestResponseBody inspects a channel-test response body
// for an upstream error payload. It first tries the whole body as one
// JSON document (non-stream responses), then falls back to scanning
// SSE "data:" event lines. Returns nil when no error payload is found.
func detectErrorFromTestResponseBody(respBody []byte) error {
	trimmed := bytes.TrimSpace(respBody)
	if len(trimmed) == 0 {
		return nil
	}
	// Whole body as a single JSON document.
	if msg := detectErrorMessageFromJSONBytes(trimmed); msg != "" {
		return fmt.Errorf("upstream error: %s", msg)
	}
	// SSE stream: examine each "data:" event payload individually.
	for _, raw := range bytes.Split(trimmed, []byte{'\n'}) {
		event := bytes.TrimSpace(raw)
		if !bytes.HasPrefix(event, []byte("data:")) {
			continue
		}
		payload := bytes.TrimSpace(event[len("data:"):])
		if len(payload) == 0 || string(payload) == "[DONE]" {
			continue
		}
		if msg := detectErrorMessageFromJSONBytes(payload); msg != "" {
			return fmt.Errorf("upstream error: %s", msg)
		}
	}
	return nil
}
// detectErrorMessageFromJSONBytes extracts a human-readable message
// from a JSON payload that carries an "error" field. It returns "" when
// the payload is not a JSON object/array or has no non-null error
// field; when an error is present but yields no usable message, a
// generic placeholder is returned so callers still surface the failure.
func detectErrorMessageFromJSONBytes(jsonBytes []byte) string {
	if len(jsonBytes) == 0 {
		return ""
	}
	// Only object/array payloads can carry a structured error field.
	switch jsonBytes[0] {
	case '{', '[':
	default:
		return ""
	}
	errVal := gjson.GetBytes(jsonBytes, "error")
	if !errVal.Exists() || errVal.Type == gjson.Null {
		return ""
	}
	// Try the common error shapes in order of specificity.
	message := gjson.GetBytes(jsonBytes, "error.message").String()
	if message == "" {
		message = gjson.GetBytes(jsonBytes, "error.error.message").String()
	}
	if message == "" && errVal.Type == gjson.String {
		message = errVal.String()
	}
	if message == "" {
		// Fall back to the raw JSON of the error value.
		message = errVal.Raw
	}
	if trimmed := strings.TrimSpace(message); trimmed != "" {
		return trimmed
	}
	return "upstream returned error payload"
}
func buildTestRequest(model string, endpointType string, channel *model.Channel, isStream bool) dto.Request {
testResponsesInput := json.RawMessage(`[{"role":"user","content":"hi"}]`) testResponsesInput := json.RawMessage(`[{"role":"user","content":"hi"}]`)
// 根据端点类型构建不同的测试请求 // 根据端点类型构建不同的测试请求
@@ -504,8 +622,9 @@ func buildTestRequest(model string, endpointType string, channel *model.Channel)
case constant.EndpointTypeOpenAIResponse: case constant.EndpointTypeOpenAIResponse:
// 返回 OpenAIResponsesRequest // 返回 OpenAIResponsesRequest
return &dto.OpenAIResponsesRequest{ return &dto.OpenAIResponsesRequest{
Model: model, Model: model,
Input: json.RawMessage(`[{"role":"user","content":"hi"}]`), Input: json.RawMessage(`[{"role":"user","content":"hi"}]`),
Stream: isStream,
} }
case constant.EndpointTypeOpenAIResponseCompact: case constant.EndpointTypeOpenAIResponseCompact:
// 返回 OpenAIResponsesCompactionRequest // 返回 OpenAIResponsesCompactionRequest
@@ -519,9 +638,9 @@ func buildTestRequest(model string, endpointType string, channel *model.Channel)
if constant.EndpointType(endpointType) == constant.EndpointTypeGemini { if constant.EndpointType(endpointType) == constant.EndpointTypeGemini {
maxTokens = 3000 maxTokens = 3000
} }
return &dto.GeneralOpenAIRequest{ req := &dto.GeneralOpenAIRequest{
Model: model, Model: model,
Stream: false, Stream: isStream,
Messages: []dto.Message{ Messages: []dto.Message{
{ {
Role: "user", Role: "user",
@@ -530,6 +649,10 @@ func buildTestRequest(model string, endpointType string, channel *model.Channel)
}, },
MaxTokens: maxTokens, MaxTokens: maxTokens,
} }
if isStream {
req.StreamOptions = &dto.StreamOptions{IncludeUsage: true}
}
return req
} }
} }
@@ -565,15 +688,16 @@ func buildTestRequest(model string, endpointType string, channel *model.Channel)
// Responses-only models (e.g. codex series) // Responses-only models (e.g. codex series)
if strings.Contains(strings.ToLower(model), "codex") { if strings.Contains(strings.ToLower(model), "codex") {
return &dto.OpenAIResponsesRequest{ return &dto.OpenAIResponsesRequest{
Model: model, Model: model,
Input: json.RawMessage(`[{"role":"user","content":"hi"}]`), Input: json.RawMessage(`[{"role":"user","content":"hi"}]`),
Stream: isStream,
} }
} }
// Chat/Completion 请求 - 返回 GeneralOpenAIRequest // Chat/Completion 请求 - 返回 GeneralOpenAIRequest
testRequest := &dto.GeneralOpenAIRequest{ testRequest := &dto.GeneralOpenAIRequest{
Model: model, Model: model,
Stream: false, Stream: isStream,
Messages: []dto.Message{ Messages: []dto.Message{
{ {
Role: "user", Role: "user",
@@ -581,6 +705,9 @@ func buildTestRequest(model string, endpointType string, channel *model.Channel)
}, },
}, },
} }
if isStream {
testRequest.StreamOptions = &dto.StreamOptions{IncludeUsage: true}
}
if strings.HasPrefix(model, "o") { if strings.HasPrefix(model, "o") {
testRequest.MaxCompletionTokens = 16 testRequest.MaxCompletionTokens = 16
@@ -618,8 +745,9 @@ func TestChannel(c *gin.Context) {
//}() //}()
testModel := c.Query("model") testModel := c.Query("model")
endpointType := c.Query("endpoint_type") endpointType := c.Query("endpoint_type")
isStream, _ := strconv.ParseBool(c.Query("stream"))
tik := time.Now() tik := time.Now()
result := testChannel(channel, testModel, endpointType) result := testChannel(channel, testModel, endpointType, isStream)
if result.localErr != nil { if result.localErr != nil {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": false, "success": false,
@@ -678,7 +806,7 @@ func testAllChannels(notify bool) error {
for _, channel := range channels { for _, channel := range channels {
isChannelEnabled := channel.Status == common.ChannelStatusEnabled isChannelEnabled := channel.Status == common.ChannelStatusEnabled
tik := time.Now() tik := time.Now()
result := testChannel(channel, "", "") result := testChannel(channel, "", "", false)
tok := time.Now() tok := time.Now()
milliseconds := tok.Sub(tik).Milliseconds() milliseconds := tok.Sub(tik).Milliseconds()

View File

@@ -26,8 +26,10 @@ import {
Tag, Tag,
Typography, Typography,
Select, Select,
Switch,
Banner,
} from '@douyinfe/semi-ui'; } from '@douyinfe/semi-ui';
import { IconSearch } from '@douyinfe/semi-icons'; import { IconSearch, IconInfoCircle } from '@douyinfe/semi-icons';
import { copy, showError, showInfo, showSuccess } from '../../../../helpers'; import { copy, showError, showInfo, showSuccess } from '../../../../helpers';
import { MODEL_TABLE_PAGE_SIZE } from '../../../../constants'; import { MODEL_TABLE_PAGE_SIZE } from '../../../../constants';
@@ -48,11 +50,25 @@ const ModelTestModal = ({
setModelTablePage, setModelTablePage,
selectedEndpointType, selectedEndpointType,
setSelectedEndpointType, setSelectedEndpointType,
isStreamTest,
setIsStreamTest,
allSelectingRef, allSelectingRef,
isMobile, isMobile,
t, t,
}) => { }) => {
const hasChannel = Boolean(currentTestChannel); const hasChannel = Boolean(currentTestChannel);
const streamToggleDisabled = [
'embeddings',
'image-generation',
'jina-rerank',
'openai-response-compact',
].includes(selectedEndpointType);
React.useEffect(() => {
if (streamToggleDisabled && isStreamTest) {
setIsStreamTest(false);
}
}, [streamToggleDisabled, isStreamTest, setIsStreamTest]);
const filteredModels = hasChannel const filteredModels = hasChannel
? currentTestChannel.models ? currentTestChannel.models
@@ -181,6 +197,7 @@ const ModelTestModal = ({
currentTestChannel, currentTestChannel,
record.model, record.model,
selectedEndpointType, selectedEndpointType,
isStreamTest,
) )
} }
loading={isTesting} loading={isTesting}
@@ -258,25 +275,46 @@ const ModelTestModal = ({
> >
{hasChannel && ( {hasChannel && (
<div className='model-test-scroll'> <div className='model-test-scroll'>
{/* 端点类型选择器 */} {/* Endpoint toolbar */}
<div className='flex items-center gap-2 w-full mb-2'> <div className='flex flex-col sm:flex-row sm:items-center gap-2 w-full mb-2'>
<Typography.Text strong>{t('端点类型')}:</Typography.Text> <div className='flex items-center gap-2 flex-1 min-w-0'>
<Select <Typography.Text strong className='shrink-0'>
value={selectedEndpointType} {t('端点类型')}:
onChange={setSelectedEndpointType} </Typography.Text>
optionList={endpointTypeOptions} <Select
className='!w-full' value={selectedEndpointType}
placeholder={t('选择端点类型')} onChange={setSelectedEndpointType}
/> optionList={endpointTypeOptions}
className='!w-full min-w-0'
placeholder={t('选择端点类型')}
/>
</div>
<div className='flex items-center justify-between sm:justify-end gap-2 shrink-0'>
<Typography.Text strong className='shrink-0'>
{t('流式')}:
</Typography.Text>
<Switch
checked={isStreamTest}
onChange={setIsStreamTest}
size='small'
disabled={streamToggleDisabled}
aria-label={t('流式')}
/>
</div>
</div> </div>
<Typography.Text type='tertiary' size='small' className='block mb-2'>
{t( <Banner
type='info'
closeIcon={null}
icon={<IconInfoCircle />}
className='!rounded-lg mb-2'
description={t(
'说明:本页测试为非流式请求;若渠道仅支持流式返回,可能出现测试失败,请以实际使用为准。', '说明:本页测试为非流式请求;若渠道仅支持流式返回,可能出现测试失败,请以实际使用为准。',
)} )}
</Typography.Text> />
{/* 搜索与操作按钮 */} {/* 搜索与操作按钮 */}
<div className='flex items-center justify-end gap-2 w-full mb-2'> <div className='flex flex-col sm:flex-row sm:items-center gap-2 w-full mb-2'>
<Input <Input
placeholder={t('搜索模型...')} placeholder={t('搜索模型...')}
value={modelSearchKeyword} value={modelSearchKeyword}
@@ -284,16 +322,17 @@ const ModelTestModal = ({
setModelSearchKeyword(v); setModelSearchKeyword(v);
setModelTablePage(1); setModelTablePage(1);
}} }}
className='!w-full' className='!w-full sm:!flex-1'
prefix={<IconSearch />} prefix={<IconSearch />}
showClear showClear
/> />
<Button onClick={handleCopySelected}>{t('复制已选')}</Button> <div className='flex items-center justify-end gap-2'>
<Button onClick={handleCopySelected}>{t('复制已选')}</Button>
<Button type='tertiary' onClick={handleSelectSuccess}> <Button type='tertiary' onClick={handleSelectSuccess}>
{t('选择成功')} {t('选择成功')}
</Button> </Button>
</div>
</div> </div>
<Table <Table

View File

@@ -87,6 +87,7 @@ export const useChannelsData = () => {
const [isBatchTesting, setIsBatchTesting] = useState(false); const [isBatchTesting, setIsBatchTesting] = useState(false);
const [modelTablePage, setModelTablePage] = useState(1); const [modelTablePage, setModelTablePage] = useState(1);
const [selectedEndpointType, setSelectedEndpointType] = useState(''); const [selectedEndpointType, setSelectedEndpointType] = useState('');
const [isStreamTest, setIsStreamTest] = useState(false);
const [globalPassThroughEnabled, setGlobalPassThroughEnabled] = const [globalPassThroughEnabled, setGlobalPassThroughEnabled] =
useState(false); useState(false);
@@ -851,7 +852,12 @@ export const useChannelsData = () => {
}; };
// Test channel - 单个模型测试,参考旧版实现 // Test channel - 单个模型测试,参考旧版实现
const testChannel = async (record, model, endpointType = '') => { const testChannel = async (
record,
model,
endpointType = '',
stream = false,
) => {
const testKey = `${record.id}-${model}`; const testKey = `${record.id}-${model}`;
// 检查是否应该停止批量测试 // 检查是否应该停止批量测试
@@ -867,6 +873,9 @@ export const useChannelsData = () => {
if (endpointType) { if (endpointType) {
url += `&endpoint_type=${endpointType}`; url += `&endpoint_type=${endpointType}`;
} }
if (stream) {
url += `&stream=true`;
}
const res = await API.get(url); const res = await API.get(url);
// 检查是否在请求期间被停止 // 检查是否在请求期间被停止
@@ -995,7 +1004,12 @@ export const useChannelsData = () => {
); );
const batchPromises = batch.map((model) => const batchPromises = batch.map((model) =>
testChannel(currentTestChannel, model, selectedEndpointType), testChannel(
currentTestChannel,
model,
selectedEndpointType,
isStreamTest,
),
); );
const batchResults = await Promise.allSettled(batchPromises); const batchResults = await Promise.allSettled(batchPromises);
results.push(...batchResults); results.push(...batchResults);
@@ -1080,6 +1094,7 @@ export const useChannelsData = () => {
setSelectedModelKeys([]); setSelectedModelKeys([]);
setModelTablePage(1); setModelTablePage(1);
setSelectedEndpointType(''); setSelectedEndpointType('');
setIsStreamTest(false);
// 可选择性保留测试结果,这里不清空以便用户查看 // 可选择性保留测试结果,这里不清空以便用户查看
}; };
@@ -1170,6 +1185,8 @@ export const useChannelsData = () => {
setModelTablePage, setModelTablePage,
selectedEndpointType, selectedEndpointType,
setSelectedEndpointType, setSelectedEndpointType,
isStreamTest,
setIsStreamTest,
allSelectingRef, allSelectingRef,
// Multi-key management states // Multi-key management states

View File

@@ -1548,6 +1548,7 @@
"流": "stream", "流": "stream",
"流式响应完成": "Streaming response completed", "流式响应完成": "Streaming response completed",
"流式输出": "Streaming Output", "流式输出": "Streaming Output",
"流式": "Streaming",
"流量端口": "Traffic Port", "流量端口": "Traffic Port",
"浅色": "Light", "浅色": "Light",
"浅色模式": "Light Mode", "浅色模式": "Light Mode",

View File

@@ -1558,6 +1558,7 @@
"流": "Flux", "流": "Flux",
"流式响应完成": "Flux terminé", "流式响应完成": "Flux terminé",
"流式输出": "Sortie en flux", "流式输出": "Sortie en flux",
"流式": "Streaming",
"流量端口": "Traffic Port", "流量端口": "Traffic Port",
"浅色": "Clair", "浅色": "Clair",
"浅色模式": "Mode clair", "浅色模式": "Mode clair",

View File

@@ -1543,6 +1543,7 @@
"流": "ストリーム", "流": "ストリーム",
"流式响应完成": "ストリーム完了", "流式响应完成": "ストリーム完了",
"流式输出": "ストリーム出力", "流式输出": "ストリーム出力",
"流式": "ストリーミング",
"流量端口": "Traffic Port", "流量端口": "Traffic Port",
"浅色": "ライト", "浅色": "ライト",
"浅色模式": "ライトモード", "浅色模式": "ライトモード",

View File

@@ -1569,6 +1569,7 @@
"流": "Поток", "流": "Поток",
"流式响应完成": "Поток завершён", "流式响应完成": "Поток завершён",
"流式输出": "Потоковый вывод", "流式输出": "Потоковый вывод",
"流式": "Стриминг",
"流量端口": "Traffic Port", "流量端口": "Traffic Port",
"浅色": "Светлая", "浅色": "Светлая",
"浅色模式": "Светлый режим", "浅色模式": "Светлый режим",

View File

@@ -1597,6 +1597,7 @@
"流": "luồng", "流": "luồng",
"流式响应完成": "Luồng hoàn tất", "流式响应完成": "Luồng hoàn tất",
"流式输出": "Đầu ra luồng", "流式输出": "Đầu ra luồng",
"流式": "Streaming",
"流量端口": "Traffic Port", "流量端口": "Traffic Port",
"浅色": "Sáng", "浅色": "Sáng",
"浅色模式": "Chế độ sáng", "浅色模式": "Chế độ sáng",

View File

@@ -1538,6 +1538,7 @@
"流": "流", "流": "流",
"流式响应完成": "流式响应完成", "流式响应完成": "流式响应完成",
"流式输出": "流式输出", "流式输出": "流式输出",
"流式": "流式",
"流量端口": "流量端口", "流量端口": "流量端口",
"浅色": "浅色", "浅色": "浅色",
"浅色模式": "浅色模式", "浅色模式": "浅色模式",