Merge pull request #2875 from seefs001/feature/channel-test-stream

feat: channel test with stream=true
This commit is contained in:
Calcium-Ion
2026-02-08 00:17:07 +08:00
committed by GitHub
9 changed files with 230 additions and 40 deletions

View File

@@ -31,6 +31,7 @@ import (
"github.com/bytedance/gopkg/util/gopool"
"github.com/samber/lo"
"github.com/tidwall/gjson"
"github.com/gin-gonic/gin"
)
@@ -41,7 +42,21 @@ type testResult struct {
newAPIError *types.NewAPIError
}
func testChannel(channel *model.Channel, testModel string, endpointType string) testResult {
func normalizeChannelTestEndpoint(channel *model.Channel, modelName, endpointType string) string {
normalized := strings.TrimSpace(endpointType)
if normalized != "" {
return normalized
}
if strings.HasSuffix(modelName, ratio_setting.CompactModelSuffix) {
return string(constant.EndpointTypeOpenAIResponseCompact)
}
if channel != nil && channel.Type == constant.ChannelTypeCodex {
return string(constant.EndpointTypeOpenAIResponse)
}
return normalized
}
func testChannel(channel *model.Channel, testModel string, endpointType string, isStream bool) testResult {
tik := time.Now()
var unsupportedTestChannelTypes = []int{
constant.ChannelTypeMidjourney,
@@ -76,6 +91,8 @@ func testChannel(channel *model.Channel, testModel string, endpointType string)
}
}
endpointType = normalizeChannelTestEndpoint(channel, testModel, endpointType)
requestPath := "/v1/chat/completions"
// 如果指定了端点类型,使用指定的端点类型
@@ -200,7 +217,7 @@ func testChannel(channel *model.Channel, testModel string, endpointType string)
}
}
request := buildTestRequest(testModel, endpointType, channel)
request := buildTestRequest(testModel, endpointType, channel, isStream)
info, err := relaycommon.GenRelayInfo(c, relayFormat, request, nil)
@@ -418,16 +435,16 @@ func testChannel(channel *model.Channel, testModel string, endpointType string)
newAPIError: respErr,
}
}
if usageA == nil {
usage, usageErr := coerceTestUsage(usageA, isStream, info.GetEstimatePromptTokens())
if usageErr != nil {
return testResult{
context: c,
localErr: errors.New("usage is nil"),
newAPIError: types.NewOpenAIError(errors.New("usage is nil"), types.ErrorCodeBadResponseBody, http.StatusInternalServerError),
localErr: usageErr,
newAPIError: types.NewOpenAIError(usageErr, types.ErrorCodeBadResponseBody, http.StatusInternalServerError),
}
}
usage := usageA.(*dto.Usage)
result := w.Result()
respBody, err := io.ReadAll(result.Body)
respBody, err := readTestResponseBody(result.Body, isStream)
if err != nil {
return testResult{
context: c,
@@ -435,6 +452,13 @@ func testChannel(channel *model.Channel, testModel string, endpointType string)
newAPIError: types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError),
}
}
if bodyErr := detectErrorFromTestResponseBody(respBody); bodyErr != nil {
return testResult{
context: c,
localErr: bodyErr,
newAPIError: types.NewOpenAIError(bodyErr, types.ErrorCodeBadResponseBody, http.StatusInternalServerError),
}
}
info.SetEstimatePromptTokens(usage.PromptTokens)
quota := 0
@@ -473,7 +497,101 @@ func testChannel(channel *model.Channel, testModel string, endpointType string)
}
}
func buildTestRequest(model string, endpointType string, channel *model.Channel) dto.Request {
func coerceTestUsage(usageAny any, isStream bool, estimatePromptTokens int) (*dto.Usage, error) {
switch u := usageAny.(type) {
case *dto.Usage:
return u, nil
case dto.Usage:
return &u, nil
case nil:
if !isStream {
return nil, errors.New("usage is nil")
}
usage := &dto.Usage{
PromptTokens: estimatePromptTokens,
}
usage.TotalTokens = usage.PromptTokens
return usage, nil
default:
if !isStream {
return nil, fmt.Errorf("invalid usage type: %T", usageAny)
}
usage := &dto.Usage{
PromptTokens: estimatePromptTokens,
}
usage.TotalTokens = usage.PromptTokens
return usage, nil
}
}
func readTestResponseBody(body io.ReadCloser, isStream bool) ([]byte, error) {
defer func() { _ = body.Close() }()
const maxStreamLogBytes = 8 << 10
if isStream {
return io.ReadAll(io.LimitReader(body, maxStreamLogBytes))
}
return io.ReadAll(body)
}
func detectErrorFromTestResponseBody(respBody []byte) error {
b := bytes.TrimSpace(respBody)
if len(b) == 0 {
return nil
}
if message := detectErrorMessageFromJSONBytes(b); message != "" {
return fmt.Errorf("upstream error: %s", message)
}
for _, line := range bytes.Split(b, []byte{'\n'}) {
line = bytes.TrimSpace(line)
if len(line) == 0 {
continue
}
if !bytes.HasPrefix(line, []byte("data:")) {
continue
}
payload := bytes.TrimSpace(bytes.TrimPrefix(line, []byte("data:")))
if len(payload) == 0 || bytes.Equal(payload, []byte("[DONE]")) {
continue
}
if message := detectErrorMessageFromJSONBytes(payload); message != "" {
return fmt.Errorf("upstream error: %s", message)
}
}
return nil
}
func detectErrorMessageFromJSONBytes(jsonBytes []byte) string {
if len(jsonBytes) == 0 {
return ""
}
if jsonBytes[0] != '{' && jsonBytes[0] != '[' {
return ""
}
errVal := gjson.GetBytes(jsonBytes, "error")
if !errVal.Exists() || errVal.Type == gjson.Null {
return ""
}
message := gjson.GetBytes(jsonBytes, "error.message").String()
if message == "" {
message = gjson.GetBytes(jsonBytes, "error.error.message").String()
}
if message == "" && errVal.Type == gjson.String {
message = errVal.String()
}
if message == "" {
message = errVal.Raw
}
message = strings.TrimSpace(message)
if message == "" {
return "upstream returned error payload"
}
return message
}
func buildTestRequest(model string, endpointType string, channel *model.Channel, isStream bool) dto.Request {
testResponsesInput := json.RawMessage(`[{"role":"user","content":"hi"}]`)
// 根据端点类型构建不同的测试请求
@@ -504,8 +622,9 @@ func buildTestRequest(model string, endpointType string, channel *model.Channel)
case constant.EndpointTypeOpenAIResponse:
// 返回 OpenAIResponsesRequest
return &dto.OpenAIResponsesRequest{
Model: model,
Input: json.RawMessage(`[{"role":"user","content":"hi"}]`),
Model: model,
Input: json.RawMessage(`[{"role":"user","content":"hi"}]`),
Stream: isStream,
}
case constant.EndpointTypeOpenAIResponseCompact:
// 返回 OpenAIResponsesCompactionRequest
@@ -519,9 +638,9 @@ func buildTestRequest(model string, endpointType string, channel *model.Channel)
if constant.EndpointType(endpointType) == constant.EndpointTypeGemini {
maxTokens = 3000
}
return &dto.GeneralOpenAIRequest{
req := &dto.GeneralOpenAIRequest{
Model: model,
Stream: false,
Stream: isStream,
Messages: []dto.Message{
{
Role: "user",
@@ -530,6 +649,10 @@ func buildTestRequest(model string, endpointType string, channel *model.Channel)
},
MaxTokens: maxTokens,
}
if isStream {
req.StreamOptions = &dto.StreamOptions{IncludeUsage: true}
}
return req
}
}
@@ -565,15 +688,16 @@ func buildTestRequest(model string, endpointType string, channel *model.Channel)
// Responses-only models (e.g. codex series)
if strings.Contains(strings.ToLower(model), "codex") {
return &dto.OpenAIResponsesRequest{
Model: model,
Input: json.RawMessage(`[{"role":"user","content":"hi"}]`),
Model: model,
Input: json.RawMessage(`[{"role":"user","content":"hi"}]`),
Stream: isStream,
}
}
// Chat/Completion 请求 - 返回 GeneralOpenAIRequest
testRequest := &dto.GeneralOpenAIRequest{
Model: model,
Stream: false,
Stream: isStream,
Messages: []dto.Message{
{
Role: "user",
@@ -581,6 +705,9 @@ func buildTestRequest(model string, endpointType string, channel *model.Channel)
},
},
}
if isStream {
testRequest.StreamOptions = &dto.StreamOptions{IncludeUsage: true}
}
if strings.HasPrefix(model, "o") {
testRequest.MaxCompletionTokens = 16
@@ -618,8 +745,9 @@ func TestChannel(c *gin.Context) {
//}()
testModel := c.Query("model")
endpointType := c.Query("endpoint_type")
isStream, _ := strconv.ParseBool(c.Query("stream"))
tik := time.Now()
result := testChannel(channel, testModel, endpointType)
result := testChannel(channel, testModel, endpointType, isStream)
if result.localErr != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
@@ -678,7 +806,7 @@ func testAllChannels(notify bool) error {
for _, channel := range channels {
isChannelEnabled := channel.Status == common.ChannelStatusEnabled
tik := time.Now()
result := testChannel(channel, "", "")
result := testChannel(channel, "", "", false)
tok := time.Now()
milliseconds := tok.Sub(tik).Milliseconds()