diff --git a/relay/channel/openai/relay-openai.go b/relay/channel/openai/relay-openai.go index cce9235b5..4b13a7df1 100644 --- a/relay/channel/openai/relay-openai.go +++ b/relay/channel/openai/relay-openai.go @@ -2,6 +2,7 @@ package openai import ( "bytes" + "encoding/json" "fmt" "io" "math" @@ -280,11 +281,6 @@ func OpenaiTTSHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel func OpenaiSTTHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, responseFormat string) (*types.NewAPIError, *dto.Usage) { defer service.CloseResponseBodyGracefully(resp) - // count tokens by audio file duration - audioTokens, err := countAudioTokens(c) - if err != nil { - return types.NewError(err, types.ErrorCodeCountTokenFailed), nil - } responseBody, err := io.ReadAll(resp.Body) if err != nil { return types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError), nil @@ -292,6 +288,26 @@ func OpenaiSTTHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel // 写入新的 response body service.IOCopyBytesGracefully(c, resp, responseBody) + var responseData struct { + Usage *dto.Usage `json:"usage"` + } + if err := json.Unmarshal(responseBody, &responseData); err == nil && responseData.Usage != nil { + if responseData.Usage.TotalTokens > 0 { + usage := responseData.Usage + if usage.PromptTokens == 0 { + usage.PromptTokens = usage.InputTokens + } + if usage.CompletionTokens == 0 { + usage.CompletionTokens = usage.OutputTokens + } + return nil, usage + } + } + + audioTokens, err := countAudioTokens(c) + if err != nil { + return types.NewError(err, types.ErrorCodeCountTokenFailed), nil + } usage := &dto.Usage{} usage.PromptTokens = audioTokens usage.CompletionTokens = 0