diff --git a/dto/audio.go b/dto/audio.go index ea51516f8..c6f5b9479 100644 --- a/dto/audio.go +++ b/dto/audio.go @@ -2,6 +2,7 @@ package dto import ( "encoding/json" + "strings" "github.com/QuantumNous/new-api/types" @@ -24,11 +25,14 @@ func (r *AudioRequest) GetTokenCountMeta() *types.TokenCountMeta { CombineText: r.Input, TokenType: types.TokenTypeTextNumber, } + if strings.Contains(r.Model, "gpt") { + meta.TokenType = types.TokenTypeTokenizer + } return meta } func (r *AudioRequest) IsStream(c *gin.Context) bool { - return false + return r.StreamFormat == "sse" } func (r *AudioRequest) SetModelName(modelName string) { diff --git a/relay/audio_handler.go b/relay/audio_handler.go index 15fbb9390..39eb03d39 100644 --- a/relay/audio_handler.go +++ b/relay/audio_handler.go @@ -67,8 +67,11 @@ func AudioHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *type service.ResetStatusCode(newAPIError, statusCodeMappingStr) return newAPIError } - - postConsumeQuota(c, info, usage.(*dto.Usage), "") + if usage.(*dto.Usage).CompletionTokenDetails.AudioTokens > 0 || usage.(*dto.Usage).PromptTokensDetails.AudioTokens > 0 { + service.PostAudioConsumeQuota(c, info, usage.(*dto.Usage), "") + } else { + postConsumeQuota(c, info, usage.(*dto.Usage), "") + } return nil } diff --git a/relay/channel/openai/audio.go b/relay/channel/openai/audio.go new file mode 100644 index 000000000..b267dcfbb --- /dev/null +++ b/relay/channel/openai/audio.go @@ -0,0 +1,145 @@ +package openai + +import ( + "bytes" + "fmt" + "io" + "math" + "net/http" + + "github.com/QuantumNous/new-api/common" + "github.com/QuantumNous/new-api/constant" + "github.com/QuantumNous/new-api/dto" + "github.com/QuantumNous/new-api/logger" + relaycommon "github.com/QuantumNous/new-api/relay/common" + "github.com/QuantumNous/new-api/relay/helper" + "github.com/QuantumNous/new-api/service" + "github.com/QuantumNous/new-api/types" + "github.com/gin-gonic/gin" +) + +func OpenaiTTSHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) *dto.Usage { + // the status code has been judged before, if there is a body reading failure, + // it should be regarded as a non-recoverable error, so it should not return err for external retry. + // Analogous to nginx's load balancing, it will only retry if it can't be requested or + // if the upstream returns a specific status code, once the upstream has already written the header, + // the subsequent failure of the response body should be regarded as a non-recoverable error, + // and can be terminated directly. + defer service.CloseResponseBodyGracefully(resp) + usage := &dto.Usage{} + usage.PromptTokens = info.GetEstimatePromptTokens() + usage.TotalTokens = info.GetEstimatePromptTokens() + for k, v := range resp.Header { + c.Writer.Header().Set(k, v[0]) + } + c.Writer.WriteHeader(resp.StatusCode) + + if info.IsStream { + helper.StreamScannerHandler(c, resp, info, func(data string) bool { + if service.SundaySearch(data, "usage") { + var simpleResponse dto.SimpleResponse + err := common.Unmarshal([]byte(data), &simpleResponse) + if err != nil { + logger.LogError(c, err.Error()) + } + if simpleResponse.Usage.TotalTokens != 0 { + usage.PromptTokens = simpleResponse.Usage.InputTokens + usage.CompletionTokens = simpleResponse.OutputTokens + usage.TotalTokens = simpleResponse.TotalTokens + } + } + _ = helper.StringData(c, data) + return true + }) + } else { + common.SetContextKey(c, constant.ContextKeyLocalCountTokens, true) + // 读取响应体到缓冲区 + bodyBytes, err := io.ReadAll(resp.Body) + if err != nil { + logger.LogError(c, fmt.Sprintf("failed to read TTS response body: %v", err)) + c.Writer.WriteHeaderNow() + return usage + } + + // 写入响应到客户端 + c.Writer.WriteHeaderNow() + _, err = c.Writer.Write(bodyBytes) + if err != nil { + logger.LogError(c, fmt.Sprintf("failed to write TTS response: %v", err)) + } + + // 计算音频时长并更新 usage + audioFormat := "mp3" // 默认格式 + if audioReq, ok := info.Request.(*dto.AudioRequest); ok && audioReq.ResponseFormat != "" { + audioFormat = audioReq.ResponseFormat + } + + var duration float64 + var durationErr error + + if audioFormat == "pcm" { + // PCM 格式没有文件头,根据 OpenAI TTS 的 PCM 参数计算时长 + // 采样率: 24000 Hz, 位深度: 16-bit (2 bytes), 声道数: 1 + const sampleRate = 24000 + const bytesPerSample = 2 + const channels = 1 + duration = float64(len(bodyBytes)) / float64(sampleRate*bytesPerSample*channels) + } else { + ext := "." + audioFormat + reader := bytes.NewReader(bodyBytes) + duration, durationErr = common.GetAudioDuration(c.Request.Context(), reader, ext) + } + + usage.PromptTokensDetails.TextTokens = usage.PromptTokens + usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens + + if durationErr != nil { + logger.LogWarn(c, fmt.Sprintf("failed to get audio duration: %v", durationErr)) + // 如果无法获取时长,则设置保底的 CompletionTokens,根据body大小计算 + sizeInKB := float64(len(bodyBytes)) / 1000.0 + estimatedTokens := int(math.Ceil(sizeInKB)) // 粗略估算每KB约等于1 token + usage.CompletionTokens = estimatedTokens + usage.CompletionTokenDetails.AudioTokens = estimatedTokens + } else if duration > 0 { + // 计算 token: ceil(duration) / 60.0 * 1000,即每分钟 1000 tokens + completionTokens := int(math.Round(math.Ceil(duration) / 60.0 * 1000)) + usage.CompletionTokens = completionTokens + usage.CompletionTokenDetails.AudioTokens = completionTokens + } + } + + return usage +} + +func OpenaiSTTHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, responseFormat string) (*types.NewAPIError, *dto.Usage) { + defer service.CloseResponseBodyGracefully(resp) + + responseBody, err := io.ReadAll(resp.Body) + if err != nil { + return types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError), nil + } + // 写入新的 response body + service.IOCopyBytesGracefully(c, resp, responseBody) + + var responseData struct { + Usage *dto.Usage `json:"usage"` + } + if err := common.Unmarshal(responseBody, &responseData); err == nil && responseData.Usage != nil { + if responseData.Usage.TotalTokens > 0 { + usage := responseData.Usage + if usage.PromptTokens == 0 { + usage.PromptTokens = usage.InputTokens + } + if usage.CompletionTokens == 0 { + usage.CompletionTokens = usage.OutputTokens + } + return nil, usage + } + } + + usage := &dto.Usage{} + usage.PromptTokens = info.GetEstimatePromptTokens() + usage.CompletionTokens = 0 + usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens + return nil, usage +} diff --git a/relay/channel/openai/relay-openai.go b/relay/channel/openai/relay-openai.go index 8c55ae7a7..5819f7071 100644 --- a/relay/channel/openai/relay-openai.go +++ b/relay/channel/openai/relay-openai.go @@ -1,7 +1,6 @@ package openai import ( - "encoding/json" "fmt" "io" "net/http" @@ -151,7 +150,7 @@ func OaiStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Re var streamResp struct { Usage *dto.Usage `json:"usage"` } - err := json.Unmarshal([]byte(secondLastStreamData), &streamResp) + err := common.Unmarshal([]byte(secondLastStreamData), &streamResp) if err == nil && streamResp.Usage != nil && service.ValidUsage(streamResp.Usage) { usage = streamResp.Usage containStreamUsage = true @@ -327,68 +326,6 @@ func streamTTSResponse(c *gin.Context, resp *http.Response) { } } -func OpenaiTTSHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) *dto.Usage { - // the status code has been judged before, if there is a body reading failure, - // it should be regarded as a non-recoverable error, so it should not return err for external retry. - // Analogous to nginx's load balancing, it will only retry if it can't be requested or - // if the upstream returns a specific status code, once the upstream has already written the header, - // the subsequent failure of the response body should be regarded as a non-recoverable error, - // and can be terminated directly. - defer service.CloseResponseBodyGracefully(resp) - usage := &dto.Usage{} - usage.PromptTokens = info.GetEstimatePromptTokens() - usage.TotalTokens = info.GetEstimatePromptTokens() - for k, v := range resp.Header { - c.Writer.Header().Set(k, v[0]) - } - c.Writer.WriteHeader(resp.StatusCode) - - isStreaming := resp.ContentLength == -1 || resp.Header.Get("Content-Length") == "" - if isStreaming { - streamTTSResponse(c, resp) - } else { - c.Writer.WriteHeaderNow() - _, err := io.Copy(c.Writer, resp.Body) - if err != nil { - logger.LogError(c, err.Error()) - } - } - return usage -} - -func OpenaiSTTHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, responseFormat string) (*types.NewAPIError, *dto.Usage) { - defer service.CloseResponseBodyGracefully(resp) - - responseBody, err := io.ReadAll(resp.Body) - if err != nil { - return types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError), nil - } - // 写入新的 response body - service.IOCopyBytesGracefully(c, resp, responseBody) - - var responseData struct { - Usage *dto.Usage `json:"usage"` - } - if err := json.Unmarshal(responseBody, &responseData); err == nil && responseData.Usage != nil { - if responseData.Usage.TotalTokens > 0 { - usage := responseData.Usage - if usage.PromptTokens == 0 { - usage.PromptTokens = usage.InputTokens - } - if usage.CompletionTokens == 0 { - usage.CompletionTokens = usage.OutputTokens - } - return nil, usage - } - } - - usage := &dto.Usage{} - usage.PromptTokens = info.GetEstimatePromptTokens() - usage.CompletionTokens = 0 - usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens - return nil, usage -} - func OpenaiRealtimeHandler(c *gin.Context, info *relaycommon.RelayInfo) (*types.NewAPIError, *dto.RealtimeUsage) { if info == nil || info.ClientWs == nil || info.TargetWs == nil { return types.NewError(fmt.Errorf("invalid websocket connection"), types.ErrorCodeBadResponse), nil @@ -687,7 +624,7 @@ func extractCachedTokensFromBody(body []byte) (int, bool) { } `json:"usage"` } - if err := json.Unmarshal(body, &payload); err != nil { + if err := common.Unmarshal(body, &payload); err != nil { return 0, false } diff --git a/relay/compatible_handler.go b/relay/compatible_handler.go index 60934505d..f46ff9de9 100644 --- a/relay/compatible_handler.go +++ b/relay/compatible_handler.go @@ -181,7 +181,7 @@ func TextHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *types return newApiErr } - if strings.HasPrefix(info.OriginModelName, "gpt-4o-audio") { + if usage.(*dto.Usage).CompletionTokenDetails.AudioTokens > 0 || usage.(*dto.Usage).PromptTokensDetails.AudioTokens > 0 { service.PostAudioConsumeQuota(c, info, usage.(*dto.Usage), "") } else { postConsumeQuota(c, info, usage.(*dto.Usage), "") diff --git a/setting/ratio_setting/model_ratio.go b/setting/ratio_setting/model_ratio.go index bef82e57e..89e768a05 100644 --- a/setting/ratio_setting/model_ratio.go +++ b/setting/ratio_setting/model_ratio.go @@ -536,7 +536,7 @@ func getHardcodedCompletionModelRatio(name string) (float64, bool) { if name == "gpt-4o-2024-05-13" { return 3, true } - return 4, true + return 4, false } // gpt-5 匹配 if strings.HasPrefix(name, "gpt-5") {