diff --git a/common/model.go b/common/model.go index 14ca19115..4ebc7b532 100644 --- a/common/model.go +++ b/common/model.go @@ -17,6 +17,13 @@ var ( "flux-", "flux.1-", } + OpenAITextModels = []string{ + "gpt-", + "o1", + "o3", + "o4", + "chatgpt", + } ) func IsOpenAIResponseOnlyModel(modelName string) bool { @@ -40,3 +47,13 @@ func IsImageGenerationModel(modelName string) bool { } return false } + +func IsOpenAITextModel(modelName string) bool { + modelName = strings.ToLower(modelName) + for _, m := range OpenAITextModels { + if strings.Contains(modelName, m) { + return true + } + } + return false +} diff --git a/constant/context_key.go b/constant/context_key.go index e5461d43c..4de704619 100644 --- a/constant/context_key.go +++ b/constant/context_key.go @@ -3,8 +3,9 @@ package constant type ContextKey string const ( - ContextKeyTokenCountMeta ContextKey = "token_count_meta" - ContextKeyPromptTokens ContextKey = "prompt_tokens" + ContextKeyTokenCountMeta ContextKey = "token_count_meta" + ContextKeyPromptTokens ContextKey = "prompt_tokens" + ContextKeyEstimatedTokens ContextKey = "estimated_tokens" ContextKeyOriginalModel ContextKey = "original_model" ContextKeyRequestStartTime ContextKey = "request_start_time" diff --git a/controller/channel-test.go b/controller/channel-test.go index 171cca22b..1c77fb030 100644 --- a/controller/channel-test.go +++ b/controller/channel-test.go @@ -351,7 +351,7 @@ func testChannel(channel *model.Channel, testModel string, endpointType string) newAPIError: types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError), } } - info.PromptTokens = usage.PromptTokens + info.SetEstimatePromptTokens(usage.PromptTokens) quota := 0 if !priceData.UsePrice { diff --git a/controller/relay.go b/controller/relay.go index f8a233e99..347083521 100644 --- a/controller/relay.go +++ b/controller/relay.go @@ -125,13 +125,13 @@ func Relay(c *gin.Context, relayFormat types.RelayFormat) { } } - tokens, err := service.CountRequestToken(c, meta, relayInfo) + tokens, err := service.EstimateRequestToken(c, meta, relayInfo) if err != nil { newAPIError = types.NewError(err, types.ErrorCodeCountTokenFailed) return } - relayInfo.SetPromptTokens(tokens) + relayInfo.SetEstimatePromptTokens(tokens) priceData, err := helper.ModelPriceHelper(c, relayInfo, tokens, meta) if err != nil { diff --git a/relay/channel/claude/relay-claude.go b/relay/channel/claude/relay-claude.go index ea6515b22..b815a69fb 100644 --- a/relay/channel/claude/relay-claude.go +++ b/relay/channel/claude/relay-claude.go @@ -673,7 +673,7 @@ func HandleStreamResponseData(c *gin.Context, info *relaycommon.RelayInfo, claud func HandleStreamFinalResponse(c *gin.Context, info *relaycommon.RelayInfo, claudeInfo *ClaudeResponseInfo, requestMode int) { if requestMode == RequestModeCompletion { - claudeInfo.Usage = service.ResponseText2Usage(c, claudeInfo.ResponseText.String(), info.UpstreamModelName, info.PromptTokens) + claudeInfo.Usage = service.ResponseText2Usage(c, claudeInfo.ResponseText.String(), info.UpstreamModelName, info.GetEstimatePromptTokens()) } else { if claudeInfo.Usage.PromptTokens == 0 { //上游出错 @@ -734,10 +734,7 @@ func HandleClaudeResponseData(c *gin.Context, info *relaycommon.RelayInfo, claud return types.WithClaudeError(*claudeError, http.StatusInternalServerError) } if requestMode == RequestModeCompletion { - completionTokens := service.CountTextToken(claudeResponse.Completion, info.OriginModelName) - claudeInfo.Usage.PromptTokens = info.PromptTokens - 
claudeInfo.Usage.CompletionTokens = completionTokens - claudeInfo.Usage.TotalTokens = info.PromptTokens + completionTokens + claudeInfo.Usage = service.ResponseText2Usage(c, claudeResponse.Completion, info.UpstreamModelName, info.GetEstimatePromptTokens()) } else { claudeInfo.Usage.PromptTokens = claudeResponse.Usage.InputTokens claudeInfo.Usage.CompletionTokens = claudeResponse.Usage.OutputTokens diff --git a/relay/channel/cloudflare/relay_cloudflare.go b/relay/channel/cloudflare/relay_cloudflare.go index c7b4f2b26..cb8a641a1 100644 --- a/relay/channel/cloudflare/relay_cloudflare.go +++ b/relay/channel/cloudflare/relay_cloudflare.go @@ -74,7 +74,7 @@ func cfStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Res if err := scanner.Err(); err != nil { logger.LogError(c, "error_scanning_stream_response: "+err.Error()) } - usage := service.ResponseText2Usage(c, responseText, info.UpstreamModelName, info.PromptTokens) + usage := service.ResponseText2Usage(c, responseText, info.UpstreamModelName, info.GetEstimatePromptTokens()) if info.ShouldIncludeUsage { response := helper.GenerateFinalUsageResponse(id, info.StartTime.Unix(), info.UpstreamModelName, *usage) err := helper.ObjectData(c, response) @@ -105,7 +105,7 @@ func cfHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) for _, choice := range response.Choices { responseText += choice.Message.StringContent() } - usage := service.ResponseText2Usage(c, responseText, info.UpstreamModelName, info.PromptTokens) + usage := service.ResponseText2Usage(c, responseText, info.UpstreamModelName, info.GetEstimatePromptTokens()) response.Usage = *usage response.Id = helper.GetResponseID(c) jsonResponse, err := json.Marshal(response) @@ -142,10 +142,6 @@ func cfSTTHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Respon c.Writer.WriteHeader(resp.StatusCode) _, _ = c.Writer.Write(jsonResponse) - usage := &dto.Usage{} - usage.PromptTokens = info.PromptTokens - usage.CompletionTokens = service.CountTextToken(cfResp.Result.Text, info.UpstreamModelName) - usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens - + usage := service.ResponseText2Usage(c, cfResp.Result.Text, info.UpstreamModelName, info.GetEstimatePromptTokens()) return nil, usage } diff --git a/relay/channel/cohere/relay-cohere.go b/relay/channel/cohere/relay-cohere.go index 2e23e3867..d51c05499 100644 --- a/relay/channel/cohere/relay-cohere.go +++ b/relay/channel/cohere/relay-cohere.go @@ -165,7 +165,7 @@ func cohereStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http } }) if usage.PromptTokens == 0 { - usage = service.ResponseText2Usage(c, responseText, info.UpstreamModelName, info.PromptTokens) + usage = service.ResponseText2Usage(c, responseText, info.UpstreamModelName, info.GetEstimatePromptTokens()) } return usage, nil } @@ -225,9 +225,9 @@ func cohereRerankHandler(c *gin.Context, resp *http.Response, info *relaycommon. 
} usage := dto.Usage{} if cohereResp.Meta.BilledUnits.InputTokens == 0 { - usage.PromptTokens = info.PromptTokens + usage.PromptTokens = info.GetEstimatePromptTokens() usage.CompletionTokens = 0 - usage.TotalTokens = info.PromptTokens + usage.TotalTokens = info.GetEstimatePromptTokens() } else { usage.PromptTokens = cohereResp.Meta.BilledUnits.InputTokens usage.CompletionTokens = cohereResp.Meta.BilledUnits.OutputTokens diff --git a/relay/channel/dify/relay-dify.go b/relay/channel/dify/relay-dify.go index 8f58e86c6..24f5218a4 100644 --- a/relay/channel/dify/relay-dify.go +++ b/relay/channel/dify/relay-dify.go @@ -246,7 +246,7 @@ func difyStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.R }) helper.Done(c) if usage.TotalTokens == 0 { - usage = service.ResponseText2Usage(c, responseText, info.UpstreamModelName, info.PromptTokens) + usage = service.ResponseText2Usage(c, responseText, info.UpstreamModelName, info.GetEstimatePromptTokens()) } usage.CompletionTokens += nodeToken return usage, nil diff --git a/relay/channel/gemini/relay-gemini-native.go b/relay/channel/gemini/relay-gemini-native.go index c3f7aa0a6..f25d9ebf0 100644 --- a/relay/channel/gemini/relay-gemini-native.go +++ b/relay/channel/gemini/relay-gemini-native.go @@ -5,7 +5,6 @@ import ( "net/http" "github.com/QuantumNous/new-api/common" - "github.com/QuantumNous/new-api/constant" "github.com/QuantumNous/new-api/dto" "github.com/QuantumNous/new-api/logger" relaycommon "github.com/QuantumNous/new-api/relay/common" @@ -70,12 +69,7 @@ func NativeGeminiEmbeddingHandler(c *gin.Context, resp *http.Response, info *rel println(string(responseBody)) } - usage := &dto.Usage{ - PromptTokens: info.PromptTokens, - TotalTokens: info.PromptTokens, - } - - common.SetContextKey(c, constant.ContextKeyLocalCountTokens, true) + usage := service.ResponseText2Usage(c, "", info.UpstreamModelName, info.GetEstimatePromptTokens()) if info.IsGeminiBatchEmbedding { var geminiResponse dto.GeminiBatchEmbeddingResponse diff --git a/relay/channel/gemini/relay-gemini.go b/relay/channel/gemini/relay-gemini.go index 2f855d028..ae892ed89 100644 --- a/relay/channel/gemini/relay-gemini.go +++ b/relay/channel/gemini/relay-gemini.go @@ -1115,7 +1115,7 @@ func geminiStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http if usage.CompletionTokens <= 0 { str := responseText.String() if len(str) > 0 { - usage = service.ResponseText2Usage(c, responseText.String(), info.UpstreamModelName, info.PromptTokens) + usage = service.ResponseText2Usage(c, responseText.String(), info.UpstreamModelName, info.GetEstimatePromptTokens()) } else { usage = &dto.Usage{} } @@ -1288,11 +1288,7 @@ func GeminiEmbeddingHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *h // Google has not yet clarified how embedding models will be billed // refer to openai billing method to use input tokens billing // https://platform.openai.com/docs/guides/embeddings#what-are-embeddings - usage := &dto.Usage{ - PromptTokens: info.PromptTokens, - CompletionTokens: 0, - TotalTokens: info.PromptTokens, - } + usage := service.ResponseText2Usage(c, "", info.UpstreamModelName, info.GetEstimatePromptTokens()) openAIResponse.Usage = *usage jsonResponse, jsonErr := common.Marshal(openAIResponse) diff --git a/relay/channel/minimax/tts.go b/relay/channel/minimax/tts.go index 4a52d2145..8900f5a9f 100644 --- a/relay/channel/minimax/tts.go +++ b/relay/channel/minimax/tts.go @@ -163,7 +163,7 @@ func handleTTSResponse(c *gin.Context, resp *http.Response, info *relaycommon.Re } usage = 
&dto.Usage{ - PromptTokens: info.PromptTokens, + PromptTokens: info.GetEstimatePromptTokens(), CompletionTokens: 0, TotalTokens: int(minimaxResp.ExtraInfo.UsageCharacters), } diff --git a/relay/channel/openai/relay-openai.go b/relay/channel/openai/relay-openai.go index eafb11d99..8c55ae7a7 100644 --- a/relay/channel/openai/relay-openai.go +++ b/relay/channel/openai/relay-openai.go @@ -183,7 +183,7 @@ func OaiStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Re } if !containStreamUsage { - usage = service.ResponseText2Usage(c, responseTextBuilder.String(), info.UpstreamModelName, info.PromptTokens) + usage = service.ResponseText2Usage(c, responseTextBuilder.String(), info.UpstreamModelName, info.GetEstimatePromptTokens()) usage.CompletionTokens += toolCount * 7 } @@ -245,9 +245,9 @@ func OpenaiHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Respo } } simpleResponse.Usage = dto.Usage{ - PromptTokens: info.PromptTokens, + PromptTokens: info.GetEstimatePromptTokens(), CompletionTokens: completionTokens, - TotalTokens: info.PromptTokens + completionTokens, + TotalTokens: info.GetEstimatePromptTokens() + completionTokens, } usageModified = true } @@ -336,8 +336,8 @@ func OpenaiTTSHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel // and can be terminated directly. defer service.CloseResponseBodyGracefully(resp) usage := &dto.Usage{} - usage.PromptTokens = info.PromptTokens - usage.TotalTokens = info.PromptTokens + usage.PromptTokens = info.GetEstimatePromptTokens() + usage.TotalTokens = info.GetEstimatePromptTokens() for k, v := range resp.Header { c.Writer.Header().Set(k, v[0]) } @@ -383,7 +383,7 @@ func OpenaiSTTHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel } usage := &dto.Usage{} - usage.PromptTokens = info.PromptTokens + usage.PromptTokens = info.GetEstimatePromptTokens() usage.CompletionTokens = 0 usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens return nil, usage diff --git a/relay/channel/openai/relay_responses.go b/relay/channel/openai/relay_responses.go index 3f8eb69a8..b92c8c723 100644 --- a/relay/channel/openai/relay_responses.go +++ b/relay/channel/openai/relay_responses.go @@ -141,7 +141,7 @@ func OaiResponsesStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp } if usage.PromptTokens == 0 && usage.CompletionTokens != 0 { - usage.PromptTokens = info.PromptTokens + usage.PromptTokens = info.GetEstimatePromptTokens() } usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens diff --git a/relay/channel/palm/adaptor.go b/relay/channel/palm/adaptor.go index 3ae900981..3c1302d81 100644 --- a/relay/channel/palm/adaptor.go +++ b/relay/channel/palm/adaptor.go @@ -81,7 +81,7 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom if info.IsStream { var responseText string err, responseText = palmStreamHandler(c, resp) - usage = service.ResponseText2Usage(c, responseText, info.UpstreamModelName, info.PromptTokens) + usage = service.ResponseText2Usage(c, responseText, info.UpstreamModelName, info.GetEstimatePromptTokens()) } else { usage, err = palmHandler(c, info, resp) } diff --git a/relay/channel/palm/relay-palm.go b/relay/channel/palm/relay-palm.go index abfb92c0e..786ea4cd2 100644 --- a/relay/channel/palm/relay-palm.go +++ b/relay/channel/palm/relay-palm.go @@ -121,13 +121,8 @@ func palmHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Respons }, resp.StatusCode) } fullTextResponse := responsePaLM2OpenAI(&palmResponse) - completionTokens := 
service.CountTextToken(palmResponse.Candidates[0].Content, info.UpstreamModelName) - usage := dto.Usage{ - PromptTokens: info.PromptTokens, - CompletionTokens: completionTokens, - TotalTokens: info.PromptTokens + completionTokens, - } - fullTextResponse.Usage = usage + usage := service.ResponseText2Usage(c, palmResponse.Candidates[0].Content, info.UpstreamModelName, info.GetEstimatePromptTokens()) + fullTextResponse.Usage = *usage jsonResponse, err := common.Marshal(fullTextResponse) if err != nil { return nil, types.NewError(err, types.ErrorCodeBadResponseBody) @@ -135,5 +130,5 @@ func palmHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Respons c.Writer.Header().Set("Content-Type", "application/json") c.Writer.WriteHeader(resp.StatusCode) service.IOCopyBytesGracefully(c, resp, jsonResponse) - return &usage, nil + return usage, nil } diff --git a/relay/channel/tencent/relay-tencent.go b/relay/channel/tencent/relay-tencent.go index 77192c0d2..dbe7750e4 100644 --- a/relay/channel/tencent/relay-tencent.go +++ b/relay/channel/tencent/relay-tencent.go @@ -105,7 +105,7 @@ func tencentStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *htt data = strings.TrimPrefix(data, "data:") var tencentResponse TencentChatResponse - err := json.Unmarshal([]byte(data), &tencentResponse) + err := common.Unmarshal([]byte(data), &tencentResponse) if err != nil { common.SysLog("error unmarshalling stream response: " + err.Error()) continue @@ -130,7 +130,7 @@ func tencentStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *htt service.CloseResponseBodyGracefully(resp) - return service.ResponseText2Usage(c, responseText, info.UpstreamModelName, info.PromptTokens), nil + return service.ResponseText2Usage(c, responseText, info.UpstreamModelName, info.GetEstimatePromptTokens()), nil } func tencentHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) { diff --git a/relay/channel/volcengine/tts.go b/relay/channel/volcengine/tts.go index 166fab8ef..2b03981d4 100644 --- a/relay/channel/volcengine/tts.go +++ b/relay/channel/volcengine/tts.go @@ -184,9 +184,9 @@ func handleTTSResponse(c *gin.Context, resp *http.Response, info *relaycommon.Re c.Data(http.StatusOK, contentType, audioData) usage = &dto.Usage{ - PromptTokens: info.PromptTokens, + PromptTokens: info.GetEstimatePromptTokens(), CompletionTokens: 0, - TotalTokens: info.PromptTokens, + TotalTokens: info.GetEstimatePromptTokens(), } return usage, nil @@ -284,9 +284,9 @@ func handleTTSWebSocketResponse(c *gin.Context, requestURL string, volcRequest V if msg.Sequence < 0 { c.Status(http.StatusOK) usage = &dto.Usage{ - PromptTokens: info.PromptTokens, + PromptTokens: info.GetEstimatePromptTokens(), CompletionTokens: 0, - TotalTokens: info.PromptTokens, + TotalTokens: info.GetEstimatePromptTokens(), } return usage, nil } @@ -297,9 +297,9 @@ func handleTTSWebSocketResponse(c *gin.Context, requestURL string, volcRequest V c.Status(http.StatusOK) usage = &dto.Usage{ - PromptTokens: info.PromptTokens, + PromptTokens: info.GetEstimatePromptTokens(), CompletionTokens: 0, - TotalTokens: info.PromptTokens, + TotalTokens: info.GetEstimatePromptTokens(), } return usage, nil } diff --git a/relay/channel/xai/text.go b/relay/channel/xai/text.go index a5acbd2a0..aa4d329f3 100644 --- a/relay/channel/xai/text.go +++ b/relay/channel/xai/text.go @@ -70,7 +70,7 @@ func xAIStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Re }) if !containStreamUsage { - usage = 
service.ResponseText2Usage(c, responseTextBuilder.String(), info.UpstreamModelName, info.PromptTokens) + usage = service.ResponseText2Usage(c, responseTextBuilder.String(), info.UpstreamModelName, info.GetEstimatePromptTokens()) usage.CompletionTokens += toolCount * 7 } diff --git a/relay/common/relay_info.go b/relay/common/relay_info.go index 33ef4d14c..1882eca89 100644 --- a/relay/common/relay_info.go +++ b/relay/common/relay_info.go @@ -73,6 +73,11 @@ type ChannelMeta struct { SupportStreamOptions bool // 是否支持流式选项 } +type TokenCountMeta struct { + //promptTokens int + estimatePromptTokens int +} + type RelayInfo struct { TokenId int TokenKey string @@ -91,7 +96,6 @@ type RelayInfo struct { RelayMode int OriginModelName string RequestURLPath string - PromptTokens int ShouldIncludeUsage bool DisablePing bool // 是否禁止向下游发送自定义 Ping ClientWs *websocket.Conn @@ -115,6 +119,7 @@ type RelayInfo struct { Request dto.Request ThinkingContentInfo + TokenCountMeta *ClaudeConvertInfo *RerankerInfo *ResponsesUsageInfo @@ -189,7 +194,7 @@ func (info *RelayInfo) ToString() string { fmt.Fprintf(b, "IsPlayground: %t, ", info.IsPlayground) fmt.Fprintf(b, "RequestURLPath: %q, ", info.RequestURLPath) fmt.Fprintf(b, "OriginModelName: %q, ", info.OriginModelName) - fmt.Fprintf(b, "PromptTokens: %d, ", info.PromptTokens) + fmt.Fprintf(b, "EstimatePromptTokens: %d, ", info.estimatePromptTokens) fmt.Fprintf(b, "ShouldIncludeUsage: %t, ", info.ShouldIncludeUsage) fmt.Fprintf(b, "DisablePing: %t, ", info.DisablePing) fmt.Fprintf(b, "SendResponseCount: %d, ", info.SendResponseCount) @@ -391,7 +396,6 @@ func genBaseRelayInfo(c *gin.Context, request dto.Request) *RelayInfo { UserEmail: common.GetContextKeyString(c, constant.ContextKeyUserEmail), OriginModelName: common.GetContextKeyString(c, constant.ContextKeyOriginalModel), - PromptTokens: common.GetContextKeyInt(c, constant.ContextKeyPromptTokens), TokenId: common.GetContextKeyInt(c, constant.ContextKeyTokenId), TokenKey: common.GetContextKeyString(c, constant.ContextKeyTokenKey), @@ -408,6 +412,10 @@ func genBaseRelayInfo(c *gin.Context, request dto.Request) *RelayInfo { IsFirstThinkingContent: true, SendLastThinkingContent: false, }, + TokenCountMeta: TokenCountMeta{ + //promptTokens: common.GetContextKeyInt(c, constant.ContextKeyPromptTokens), + estimatePromptTokens: common.GetContextKeyInt(c, constant.ContextKeyEstimatedTokens), + }, } if info.RelayMode == relayconstant.RelayModeUnknown { @@ -463,8 +471,16 @@ func GenRelayInfo(c *gin.Context, relayFormat types.RelayFormat, request dto.Req } } -func (info *RelayInfo) SetPromptTokens(promptTokens int) { - info.PromptTokens = promptTokens +//func (info *RelayInfo) SetPromptTokens(promptTokens int) { +// info.promptTokens = promptTokens +//} + +func (info *RelayInfo) SetEstimatePromptTokens(promptTokens int) { + info.estimatePromptTokens = promptTokens +} + +func (info *RelayInfo) GetEstimatePromptTokens() int { + return info.estimatePromptTokens } func (info *RelayInfo) SetFirstResponseTime() { diff --git a/relay/common_handler/rerank.go b/relay/common_handler/rerank.go index daf005df4..f52a91b03 100644 --- a/relay/common_handler/rerank.go +++ b/relay/common_handler/rerank.go @@ -57,8 +57,8 @@ func RerankHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Respo jinaResp = dto.RerankResponse{ Results: jinaRespResults, Usage: dto.Usage{ - PromptTokens: info.PromptTokens, - TotalTokens: info.PromptTokens, + PromptTokens: info.GetEstimatePromptTokens(), + TotalTokens: info.GetEstimatePromptTokens(), }, } } 
else { diff --git a/relay/compatible_handler.go b/relay/compatible_handler.go index cb3b5d5f2..60934505d 100644 --- a/relay/compatible_handler.go +++ b/relay/compatible_handler.go @@ -192,9 +192,9 @@ func TextHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *types func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage *dto.Usage, extraContent string) { if usage == nil { usage = &dto.Usage{ - PromptTokens: relayInfo.PromptTokens, + PromptTokens: relayInfo.GetEstimatePromptTokens(), CompletionTokens: 0, - TotalTokens: relayInfo.PromptTokens, + TotalTokens: relayInfo.GetEstimatePromptTokens(), } extraContent += "(可能是请求出错)" } diff --git a/service/convert.go b/service/convert.go index 975ab2d0b..93fff2386 100644 --- a/service/convert.go +++ b/service/convert.go @@ -209,7 +209,7 @@ func StreamResponseOpenAI2Claude(openAIResponse *dto.ChatCompletionsStreamRespon Type: "message", Role: "assistant", Usage: &dto.ClaudeUsage{ - InputTokens: info.PromptTokens, + InputTokens: info.GetEstimatePromptTokens(), OutputTokens: 0, }, } @@ -734,12 +734,18 @@ func StreamResponseOpenAI2Gemini(openAIResponse *dto.ChatCompletionsStreamRespon geminiResponse := &dto.GeminiChatResponse{ Candidates: make([]dto.GeminiChatCandidate, 0, len(openAIResponse.Choices)), UsageMetadata: dto.GeminiUsageMetadata{ - PromptTokenCount: info.PromptTokens, + PromptTokenCount: info.GetEstimatePromptTokens(), CandidatesTokenCount: 0, // 流式响应中可能没有完整的 usage 信息 - TotalTokenCount: info.PromptTokens, + TotalTokenCount: info.GetEstimatePromptTokens(), }, } + if openAIResponse.Usage != nil { + geminiResponse.UsageMetadata.PromptTokenCount = openAIResponse.Usage.PromptTokens + geminiResponse.UsageMetadata.CandidatesTokenCount = openAIResponse.Usage.CompletionTokens + geminiResponse.UsageMetadata.TotalTokenCount = openAIResponse.Usage.TotalTokens + } + for _, choice := range openAIResponse.Choices { candidate := dto.GeminiChatCandidate{ Index: int64(choice.Index), diff --git a/service/token_counter.go b/service/token_counter.go index e4bd1a3cc..4f004b5bb 100644 --- a/service/token_counter.go +++ b/service/token_counter.go @@ -1,7 +1,6 @@ package service import ( - "encoding/json" "errors" "fmt" "image" @@ -12,7 +11,6 @@ import ( "math" "path/filepath" "strings" - "sync" "unicode/utf8" "github.com/QuantumNous/new-api/common" @@ -23,64 +21,8 @@ import ( "github.com/QuantumNous/new-api/types" "github.com/gin-gonic/gin" - "github.com/tiktoken-go/tokenizer" - "github.com/tiktoken-go/tokenizer/codec" ) -// tokenEncoderMap won't grow after initialization -var defaultTokenEncoder tokenizer.Codec - -// tokenEncoderMap is used to store token encoders for different models -var tokenEncoderMap = make(map[string]tokenizer.Codec) - -// tokenEncoderMutex protects tokenEncoderMap for concurrent access -var tokenEncoderMutex sync.RWMutex - -func InitTokenEncoders() { - common.SysLog("initializing token encoders") - defaultTokenEncoder = codec.NewCl100kBase() - common.SysLog("token encoders initialized") -} - -func getTokenEncoder(model string) tokenizer.Codec { - // First, try to get the encoder from cache with read lock - tokenEncoderMutex.RLock() - if encoder, exists := tokenEncoderMap[model]; exists { - tokenEncoderMutex.RUnlock() - return encoder - } - tokenEncoderMutex.RUnlock() - - // If not in cache, create new encoder with write lock - tokenEncoderMutex.Lock() - defer tokenEncoderMutex.Unlock() - - // Double-check if another goroutine already created the encoder - if encoder, exists := tokenEncoderMap[model]; 
exists { - return encoder - } - - // Create new encoder - modelCodec, err := tokenizer.ForModel(tokenizer.Model(model)) - if err != nil { - // Cache the default encoder for this model to avoid repeated failures - tokenEncoderMap[model] = defaultTokenEncoder - return defaultTokenEncoder - } - - // Cache the new encoder - tokenEncoderMap[model] = modelCodec - return modelCodec -} - -func getTokenNum(tokenEncoder tokenizer.Codec, text string) int { - if text == "" { - return 0 - } - tkm, _ := tokenEncoder.Count(text) - return tkm -} - func getImageToken(fileMeta *types.FileMeta, model string, stream bool) (int, error) { if fileMeta == nil { return 0, fmt.Errorf("image_url_is_nil") @@ -257,7 +199,7 @@ func getImageToken(fileMeta *types.FileMeta, model string, stream bool) (int, er return tiles*tileTokens + baseTokens, nil } -func CountRequestToken(c *gin.Context, meta *types.TokenCountMeta, info *relaycommon.RelayInfo) (int, error) { +func EstimateRequestToken(c *gin.Context, meta *types.TokenCountMeta, info *relaycommon.RelayInfo) (int, error) { // 是否统计token if !constant.CountToken { return 0, nil @@ -375,14 +317,14 @@ func CountRequestToken(c *gin.Context, meta *types.TokenCountMeta, info *relayco for i, file := range meta.Files { switch file.FileType { case types.FileTypeImage: - if info.RelayFormat == types.RelayFormatGemini { - tkm += 520 // gemini per input image tokens - } else { + if common.IsOpenAITextModel(info.UpstreamModelName) { token, err := getImageToken(file, model, info.IsStream) if err != nil { return 0, fmt.Errorf("error counting image token, media index[%d], original data[%s], err: %v", i, file.OriginData, err) } tkm += token + } else { + tkm += 520 } case types.FileTypeAudio: tkm += 256 @@ -399,111 +341,6 @@ func CountRequestToken(c *gin.Context, meta *types.TokenCountMeta, info *relayco return tkm, nil } -func CountTokenClaudeRequest(request dto.ClaudeRequest, model string) (int, error) { - tkm := 0 - - // Count tokens in messages - msgTokens, err := CountTokenClaudeMessages(request.Messages, model, request.Stream) - if err != nil { - return 0, err - } - tkm += msgTokens - - // Count tokens in system message - if request.System != "" { - systemTokens := CountTokenInput(request.System, model) - tkm += systemTokens - } - - if request.Tools != nil { - // check is array - if tools, ok := request.Tools.([]any); ok { - if len(tools) > 0 { - parsedTools, err1 := common.Any2Type[[]dto.Tool](request.Tools) - if err1 != nil { - return 0, fmt.Errorf("tools: Input should be a valid list: %v", err) - } - toolTokens, err2 := CountTokenClaudeTools(parsedTools, model) - if err2 != nil { - return 0, fmt.Errorf("tools: %v", err) - } - tkm += toolTokens - } - } else { - return 0, errors.New("tools: Input should be a valid list") - } - } - - return tkm, nil -} - -func CountTokenClaudeMessages(messages []dto.ClaudeMessage, model string, stream bool) (int, error) { - tokenEncoder := getTokenEncoder(model) - tokenNum := 0 - - for _, message := range messages { - // Count tokens for role - tokenNum += getTokenNum(tokenEncoder, message.Role) - if message.IsStringContent() { - tokenNum += getTokenNum(tokenEncoder, message.GetStringContent()) - } else { - content, err := message.ParseContent() - if err != nil { - return 0, err - } - for _, mediaMessage := range content { - switch mediaMessage.Type { - case "text": - tokenNum += getTokenNum(tokenEncoder, mediaMessage.GetText()) - case "image": - //imageTokenNum, err := getClaudeImageToken(mediaMsg.Source, model, stream) - //if err != nil { - // 
return 0, err
-				//}
-				tokenNum += 1000
-			case "tool_use":
-				if mediaMessage.Input != nil {
-					tokenNum += getTokenNum(tokenEncoder, mediaMessage.Name)
-					inputJSON, _ := json.Marshal(mediaMessage.Input)
-					tokenNum += getTokenNum(tokenEncoder, string(inputJSON))
-				}
-			case "tool_result":
-				if mediaMessage.Content != nil {
-					contentJSON, _ := json.Marshal(mediaMessage.Content)
-					tokenNum += getTokenNum(tokenEncoder, string(contentJSON))
-				}
-			}
-		}
-	}
-
-	// Add a constant for message formatting (this may need adjustment based on Claude's exact formatting)
-	tokenNum += len(messages) * 2 // Assuming 2 tokens per message for formatting
-
-	return tokenNum, nil
-}
-
-func CountTokenClaudeTools(tools []dto.Tool, model string) (int, error) {
-	tokenEncoder := getTokenEncoder(model)
-	tokenNum := 0
-
-	for _, tool := range tools {
-		tokenNum += getTokenNum(tokenEncoder, tool.Name)
-		tokenNum += getTokenNum(tokenEncoder, tool.Description)
-
-		schemaJSON, err := json.Marshal(tool.InputSchema)
-		if err != nil {
-			return 0, errors.New(fmt.Sprintf("marshal_tool_schema_fail: %s", err.Error()))
-		}
-		tokenNum += getTokenNum(tokenEncoder, string(schemaJSON))
-	}
-
-	// Add a constant for tool formatting (this may need adjustment based on Claude's exact formatting)
-	tokenNum += len(tools) * 3 // Assuming 3 tokens per tool for formatting
-
-	return tokenNum, nil
-}
-
 func CountTokenRealtime(info *relaycommon.RelayInfo, request dto.RealtimeEvent, model string) (int, int, error) {
 	audioToken := 0
 	textToken := 0
@@ -578,31 +415,6 @@ func CountTokenInput(input any, model string) int {
 	return CountTokenInput(fmt.Sprintf("%v", input), model)
 }
 
-func CountTokenStreamChoices(messages []dto.ChatCompletionsStreamResponseChoice, model string) int {
-	tokens := 0
-	for _, message := range messages {
-		tkm := CountTokenInput(message.Delta.GetContentString(), model)
-		tokens += tkm
-		if message.Delta.ToolCalls != nil {
-			for _, tool := range message.Delta.ToolCalls {
-				tkm := CountTokenInput(tool.Function.Name, model)
-				tokens += tkm
-				tkm = CountTokenInput(tool.Function.Arguments, model)
-				tokens += tkm
-			}
-		}
-	}
-	return tokens
-}
-
-func CountTTSToken(text string, model string) int {
-	if strings.HasPrefix(model, "tts") {
-		return utf8.RuneCountInString(text)
-	} else {
-		return CountTextToken(text, model)
-	}
-}
-
 func CountAudioTokenInput(audioBase64 string, audioFormat string) (int, error) {
 	if audioBase64 == "" {
 		return 0, nil
 	}
@@ -625,17 +437,16 @@ func CountAudioTokenOutput(audioBase64 string, audioFormat string) (int, error)
 	return int(duration / 60 * 200 / 0.24), nil
 }
 
-//func CountAudioToken(sec float64, audioType string) {
-//	if audioType == "input" {
-//
-//	}
-//}
-
-// CountTextToken 统计文本的token数量,仅当文本包含敏感词,返回错误,同时返回token数量
+// CountTextToken counts the tokens of a text; only OpenAI models go through the tokenizer, all other models are estimated
 func CountTextToken(text string, model string) int {
 	if text == "" {
 		return 0
 	}
-	tokenEncoder := getTokenEncoder(model)
-	return getTokenNum(tokenEncoder, text)
+	if common.IsOpenAITextModel(model) {
+		tokenEncoder := getTokenEncoder(model)
+		return getTokenNum(tokenEncoder, text)
+	} else {
+		// for non-OpenAI models an exact tiktoken count is meaningless; estimate instead to save resources
+		return EstimateTokenByModel(model, text)
+	}
 }
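
CountTextToken now routes by vendor: model names matching OpenAITextModels ("gpt-", "o1", "o3", "o4", "chatgpt") keep the exact tiktoken count, while every other model falls through to the character-class estimator added below. A minimal sketch of the two paths (model names and imports illustrative, not exhaustive):

	// exact path: "gpt-4o" matches the "gpt-" prefix in OpenAITextModels
	exact := service.CountTextToken("hello world", "gpt-4o")
	// estimated path: no tokenizer round-trip for non-OpenAI vendors
	approx := service.CountTextToken("hello world", "claude-3-5-sonnet")
	fmt.Println(exact, approx)
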
diff --git a/service/token_estimator.go b/service/token_estimator.go new file mode 100644 index 000000000..1579db1c9 --- /dev/null +++ b/service/token_estimator.go @@ -0,0 +1,223 @@ +package service
+
+import (
+	"math"
+	"strings"
+	"unicode"
+)
+
+// Provider identifies the model vendor family
+type Provider string
+
+const (
+	OpenAI Provider = "openai" // covers GPT-3.5, GPT-4, GPT-4o
+	Gemini Provider = "gemini" // covers Gemini 1.0, 1.5 Pro/Flash
+	Claude Provider = "claude" // covers Claude 3, 3.5 Sonnet
+	Unknown Provider = "unknown" // default fallback
+)
+
+// multipliers defines the per-vendor billing weights
+type multipliers struct {
+	Word float64 // English words (per word)
+	Number float64 // numbers (per consecutive digit run)
+	CJK float64 // CJK characters (per character)
+	Symbol float64 // ordinary punctuation (each)
+	MathSymbol float64 // math symbols (∑, ∫, ∂, √, etc., each)
+	URLDelim float64 // URL delimiters (/,:,?,&,=,#,%) - well optimized by tokenizers
+	AtSign float64 // @ sign - forces word splits, relatively expensive
+	Emoji float64 // emoji (each)
+	Newline float64 // newlines/tabs (each)
+	Space float64 // spaces (each)
+	BasePad int // base overhead (start/end tokens)
+}
+
+var multipliersMap = map[Provider]multipliers{
+	Gemini: {
+		Word: 1.15, Number: 2.8, CJK: 0.68, Symbol: 0.38, MathSymbol: 1.05, URLDelim: 1.2, AtSign: 2.5, Emoji: 1.08, Newline: 1.15, Space: 0.2, BasePad: 0,
+	},
+	Claude: {
+		Word: 1.13, Number: 1.63, CJK: 1.21, Symbol: 0.4, MathSymbol: 4.52, URLDelim: 1.26, AtSign: 2.82, Emoji: 2.6, Newline: 0.89, Space: 0.39, BasePad: 0,
+	},
+	OpenAI: {
+		Word: 1.02, Number: 1.55, CJK: 0.85, Symbol: 0.4, MathSymbol: 2.68, URLDelim: 1.0, AtSign: 2.0, Emoji: 2.12, Newline: 0.5, Space: 0.42, BasePad: 0,
+	},
+}
+
+// getMultipliers returns the weight profile for a vendor
+func getMultipliers(p Provider) multipliers {
+	switch p {
+	case Gemini:
+		return multipliersMap[Gemini]
+	case Claude:
+		return multipliersMap[Claude]
+	case OpenAI:
+		return multipliersMap[OpenAI]
+	default:
+		// fallback (use the OpenAI weights)
+		return multipliersMap[OpenAI]
+	}
+}
+
+// EstimateToken estimates the token count of a text
+func EstimateToken(provider Provider, text string) int {
+	m := getMultipliers(provider)
+	var count float64
+
+	// state machine variables
+	type WordType int
+	const (
+		None WordType = iota
+		Latin
+		Number
+	)
+	currentWordType := None
+
+	for _, r := range text {
+		// 1. Whitespace and newlines
+		if unicode.IsSpace(r) {
+			currentWordType = None
+			// newlines and tabs use the Newline weight
+			if r == '\n' || r == '\t' {
+				count += m.Newline
+			} else {
+				// plain spaces use the Space weight
+				count += m.Space
+			}
+			continue
+		}
+
+		// 2. CJK (Chinese/Japanese/Korean) - charged per character
+		if isCJK(r) {
+			currentWordType = None
+			count += m.CJK
+			continue
+		}
+
+		// 3. Emoji - uses the dedicated Emoji weight
+		if isEmoji(r) {
+			currentWordType = None
+			count += m.Emoji
+			continue
+		}
+
+		// 4. Latin letters / digits (English words)
+		if isLatinOrNumber(r) {
+			isNum := unicode.IsNumber(r)
+			newType := Latin
+			if isNum {
+				newType = Number
+			}
+
+			// Starting a new word, or switching type (letter<->digit), counts as a new token.
+			// Note: for OpenAI, "version 3.5" usually splits, and "abc123xyz" sometimes does too.
+			// For simplicity, add weight whenever letters and digits alternate.
+			if currentWordType == None || currentWordType != newType {
+				if newType == Number {
+					count += m.Number
+				} else {
+					count += m.Word
+				}
+				currentWordType = newType
+			}
+			// characters inside a word cost nothing extra
+			continue
+		}
+
+		// 5. Punctuation / special characters - weighted by type
+		currentWordType = None
+		if isMathSymbol(r) {
+			count += m.MathSymbol
+		} else if r == '@' {
+			count += m.AtSign
+		} else if isURLDelim(r) {
+			count += m.URLDelim
+		} else {
+			count += m.Symbol
+		}
+	}
+
+	// round up and add the base padding
+	return int(math.Ceil(count)) + m.BasePad
+}
+
+// helper: report whether r is a CJK character
+func isCJK(r rune) bool {
+	return unicode.Is(unicode.Han, r) ||
+		(r >= 0x3040 && r <= 0x30FF) || // Japanese kana
+		(r >= 0xAC00 && r <= 0xD7A3) // Korean hangul
+}
+
+// helper: report whether r is part of a word (letter or digit)
+func isLatinOrNumber(r rune) bool {
+	return unicode.IsLetter(r) || unicode.IsNumber(r)
+}
+
+// helper: report whether r is an emoji character
+func isEmoji(r rune) bool {
+	// Unicode ranges for emoji
+	// core range: 0x1F300-0x1F9FF (Emoticons, Symbols, Pictographs)
+	// supplements: 0x2600-0x26FF (Misc Symbols), 0x2700-0x27BF (Dingbats)
+	// emoticons: 0x1F600-0x1F64F (Emoticons)
+	// others: 0x1F900-0x1F9FF (Supplemental Symbols and Pictographs)
+	return (r >= 0x1F300 && r <= 0x1F9FF) ||
+		(r >= 0x2600 && r <= 0x26FF) ||
+		(r >= 0x2700 && r <= 0x27BF) ||
+		(r >= 0x1F600 && r <= 0x1F64F) ||
+		(r >= 0x1F900 && r <= 0x1F9FF) ||
+		(r >= 0x1FA00 && r <= 0x1FAFF) // Symbols and Pictographs Extended-A
+}
+
+// helper: report whether r is a math symbol
+func isMathSymbol(r rune) bool {
+	// mathematical operators and symbols
+	// basics: ∑ ∫ ∂ √ ∞ ≤ ≥ ≠ ≈ ± × ÷
+	// super/subscript digits: ² ³ ¹ ⁴ ⁵ ⁶ ⁷ ⁸ ⁹ ⁰
+	// Greek letters and the like are also common in math
+	mathSymbols := "∑∫∂√∞≤≥≠≈±×÷∈∉∋∌⊂⊃⊆⊇∪∩∧∨¬∀∃∄∅∆∇∝∟∠∡∢°′″‴⁺⁻⁼⁽⁾ⁿ₀₁₂₃₄₅₆₇₈₉₊₋₌₍₎²³¹⁴⁵⁶⁷⁸⁹⁰"
+	for _, m := range mathSymbols {
+		if r == m {
+			return true
+		}
+	}
+	// Mathematical Operators (U+2200–U+22FF)
+	if r >= 0x2200 && r <= 0x22FF {
+		return true
+	}
+	// Supplemental Mathematical Operators (U+2A00–U+2AFF)
+	if r >= 0x2A00 && r <= 0x2AFF {
+		return true
+	}
+	// Mathematical Alphanumeric Symbols (U+1D400–U+1D7FF)
+	if r >= 0x1D400 && r <= 0x1D7FF {
+		return true
+	}
+	return false
+}
+
+// helper: report whether r is a URL delimiter (tokenizers handle these well)
+func isURLDelim(r rune) bool {
+	// common URL separators, usually well optimized by tokenizers
+	urlDelims := "/:?&=;#%"
+	for _, d := range urlDelims {
+		if r == d {
+			return true
+		}
+	}
+	return false
+}
+
+// EstimateTokenByModel picks the vendor estimator from the model name
+func EstimateTokenByModel(model, text string) int {
+	if text == "" {
+		return 0
+	}
+
+	model = strings.ToLower(model)
+	if strings.Contains(model, "gemini") {
+		return EstimateToken(Gemini, text)
+	} else if strings.Contains(model, "claude") {
+		return EstimateToken(Claude, text)
+	} else {
+		return EstimateToken(OpenAI, text)
+	}
+}
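
EstimateToken walks the text once with a small state machine and sums per-class weights. With the OpenAI profile, "Hello world 2024" scores Word(1.02) + Space(0.42) + Word(1.02) + Space(0.42) + Number(1.55) = 4.43, and the ceiling makes it 5. A sketch of that call (an approximation by design, not a billing guarantee):

	n := service.EstimateToken(service.OpenAI, "Hello world 2024")
	fmt.Println(n) // 5: two words, two spaces, one digit run, rounded up
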
diff --git a/service/tokenizer.go b/service/tokenizer.go
new file mode 100644
index 000000000..9cf632b86
--- /dev/null
+++ b/service/tokenizer.go
@@ -0,0 +1,63 @@
+package service
+
+import (
+	"sync"
+
+	"github.com/QuantumNous/new-api/common"
+	"github.com/tiktoken-go/tokenizer"
+	"github.com/tiktoken-go/tokenizer/codec"
+)
+
+// defaultTokenEncoder is the cl100k_base fallback codec, set once at initialization
+var defaultTokenEncoder tokenizer.Codec
+
+// tokenEncoderMap is used to store token encoders for different models
+var tokenEncoderMap = make(map[string]tokenizer.Codec)
+
+// tokenEncoderMutex protects tokenEncoderMap for concurrent access
+var tokenEncoderMutex sync.RWMutex
+
+func InitTokenEncoders() {
+	common.SysLog("initializing token encoders")
+	defaultTokenEncoder = codec.NewCl100kBase()
+	common.SysLog("token encoders initialized")
+}
+
+func getTokenEncoder(model string) tokenizer.Codec {
+	// First, try to get the encoder from cache with read lock
+	tokenEncoderMutex.RLock()
+	if encoder, exists := tokenEncoderMap[model]; exists {
+		tokenEncoderMutex.RUnlock()
+		return encoder
+	}
+	tokenEncoderMutex.RUnlock()
+
+	// If not in cache, create new encoder with write lock
+	tokenEncoderMutex.Lock()
+	defer tokenEncoderMutex.Unlock()
+
+	// Double-check if another goroutine already created the encoder
+	if encoder, exists := tokenEncoderMap[model]; exists {
+		return encoder
+	}
+
+	// Create new encoder
+	modelCodec, err := tokenizer.ForModel(tokenizer.Model(model))
+	if err != nil {
+		// Cache the default encoder for this model to avoid repeated failures
+		tokenEncoderMap[model] = defaultTokenEncoder
+		return defaultTokenEncoder
+	}
+
+	// Cache the new encoder
+	tokenEncoderMap[model] = modelCodec
+	return modelCodec
+}
+
+func getTokenNum(tokenEncoder tokenizer.Codec, text string) int {
+	if text == "" {
+		return 0
+	}
+	tkm, _ := tokenEncoder.Count(text)
+	return tkm
+}
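
getTokenEncoder keeps one codec per model behind an RWMutex with a double-check after the lock upgrade, and pins unknown model names to the cl100k_base fallback so a failed tokenizer.ForModel lookup is attempted only once. InitTokenEncoders has to run at startup; otherwise that fallback is still nil when it gets cached. Usage sketch, package-internal since these helpers are unexported (the printed count is illustrative):

	InitTokenEncoders()              // sets the cl100k_base fallback
	enc := getTokenEncoder("gpt-4o") // codec cached on first use
	fmt.Println(getTokenNum(enc, "hello world"))
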
diff --git a/service/usage_helpr.go b/service/usage_helpr.go
index 825c9f150..97d54c4f9 100644
--- a/service/usage_helpr.go
+++ b/service/usage_helpr.go
@@ -23,8 +23,7 @@ func ResponseText2Usage(c *gin.Context, responseText string, modeName string, pr
 	common.SetContextKey(c, constant.ContextKeyLocalCountTokens, true)
 	usage := &dto.Usage{}
 	usage.PromptTokens = promptTokens
-	ctkm := CountTextToken(responseText, modeName)
-	usage.CompletionTokens = ctkm
+	usage.CompletionTokens = EstimateTokenByModel(modeName, responseText)
 	usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
 	return usage
 }
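
Taken together, the renamed flow is: Relay calls service.EstimateRequestToken, stores the result via relayInfo.SetEstimatePromptTokens, and any handler left without upstream usage falls back to service.ResponseText2Usage, which pairs the stored estimate with EstimateTokenByModel over the response text. Condensed sketch of the call chain (error handling elided):

	tokens, err := service.EstimateRequestToken(c, meta, relayInfo)
	if err == nil {
		relayInfo.SetEstimatePromptTokens(tokens)
	}
	// later, when the upstream response carried no usage block:
	usage := service.ResponseText2Usage(c, responseText,
		relayInfo.UpstreamModelName, relayInfo.GetEstimatePromptTokens())
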