Merge pull request #2355 from QuantumNous/feat/optimize-token-counter

feat: refactor token estimation logic
Calcium-Ion authored 2025-12-02 21:51:09 +08:00, committed by GitHub
26 changed files with 396 additions and 275 deletions

View File

@@ -209,7 +209,7 @@ func StreamResponseOpenAI2Claude(openAIResponse *dto.ChatCompletionsStreamRespon
Type: "message",
Role: "assistant",
Usage: &dto.ClaudeUsage{
-InputTokens: info.PromptTokens,
+InputTokens: info.GetEstimatePromptTokens(),
OutputTokens: 0,
},
}
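Note: GetEstimatePromptTokens is introduced on relaycommon.RelayInfo elsewhere in this PR and its body is not shown in this diff. A minimal sketch of the assumed shape, with EstimatePromptTokens as a hypothetical cached field:

    // Hypothetical sketch only - not code from this PR.
    // Assumed: RelayInfo caches the estimate produced by EstimateRequestToken
    // and falls back to the raw PromptTokens field when none was stored.
    func (info *RelayInfo) GetEstimatePromptTokens() int {
        if info.EstimatePromptTokens > 0 {
            return info.EstimatePromptTokens
        }
        return info.PromptTokens
    }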
@@ -734,12 +734,18 @@ func StreamResponseOpenAI2Gemini(openAIResponse *dto.ChatCompletionsStreamRespon
geminiResponse := &dto.GeminiChatResponse{
Candidates: make([]dto.GeminiChatCandidate, 0, len(openAIResponse.Choices)),
UsageMetadata: dto.GeminiUsageMetadata{
-PromptTokenCount: info.PromptTokens,
+PromptTokenCount: info.GetEstimatePromptTokens(),
CandidatesTokenCount: 0, // the stream response may not carry complete usage info
-TotalTokenCount: info.PromptTokens,
+TotalTokenCount: info.GetEstimatePromptTokens(),
},
}
+if openAIResponse.Usage != nil {
+geminiResponse.UsageMetadata.PromptTokenCount = openAIResponse.Usage.PromptTokens
+geminiResponse.UsageMetadata.CandidatesTokenCount = openAIResponse.Usage.CompletionTokens
+geminiResponse.UsageMetadata.TotalTokenCount = openAIResponse.Usage.TotalTokens
+}
for _, choice := range openAIResponse.Choices {
candidate := dto.GeminiChatCandidate{
Index: int64(choice.Index),
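The pattern this hunk applies, reduced to a standalone helper (a sketch under assumed names, not code from the PR): seed the usage metadata with the local estimate, then overwrite it with the upstream-reported figures when the stream actually carries them.

    // resolvePromptTokens is a made-up name for illustration.
    func resolvePromptTokens(estimate int, upstream *dto.Usage) int {
        if upstream != nil && upstream.PromptTokens > 0 {
            return upstream.PromptTokens // authoritative count from the provider
        }
        return estimate // local estimate as the fallback
    }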

View File

@@ -1,7 +1,6 @@
package service
import (
"encoding/json"
"errors"
"fmt"
"image"
@@ -12,7 +11,6 @@ import (
"math"
"path/filepath"
"strings"
"sync"
"unicode/utf8"
"github.com/QuantumNous/new-api/common"
@@ -23,64 +21,8 @@ import (
"github.com/QuantumNous/new-api/types"
"github.com/gin-gonic/gin"
"github.com/tiktoken-go/tokenizer"
"github.com/tiktoken-go/tokenizer/codec"
)
-// tokenEncoderMap won't grow after initialization
-var defaultTokenEncoder tokenizer.Codec
-// tokenEncoderMap is used to store token encoders for different models
-var tokenEncoderMap = make(map[string]tokenizer.Codec)
-// tokenEncoderMutex protects tokenEncoderMap for concurrent access
-var tokenEncoderMutex sync.RWMutex
-func InitTokenEncoders() {
-common.SysLog("initializing token encoders")
-defaultTokenEncoder = codec.NewCl100kBase()
-common.SysLog("token encoders initialized")
-}
-func getTokenEncoder(model string) tokenizer.Codec {
-// First, try to get the encoder from cache with read lock
-tokenEncoderMutex.RLock()
-if encoder, exists := tokenEncoderMap[model]; exists {
-tokenEncoderMutex.RUnlock()
-return encoder
-}
-tokenEncoderMutex.RUnlock()
-// If not in cache, create new encoder with write lock
-tokenEncoderMutex.Lock()
-defer tokenEncoderMutex.Unlock()
-// Double-check if another goroutine already created the encoder
-if encoder, exists := tokenEncoderMap[model]; exists {
-return encoder
-}
-// Create new encoder
-modelCodec, err := tokenizer.ForModel(tokenizer.Model(model))
-if err != nil {
-// Cache the default encoder for this model to avoid repeated failures
-tokenEncoderMap[model] = defaultTokenEncoder
-return defaultTokenEncoder
-}
-// Cache the new encoder
-tokenEncoderMap[model] = modelCodec
-return modelCodec
-}
-func getTokenNum(tokenEncoder tokenizer.Codec, text string) int {
-if text == "" {
-return 0
-}
-tkm, _ := tokenEncoder.Count(text)
-return tkm
-}
func getImageToken(fileMeta *types.FileMeta, model string, stream bool) (int, error) {
if fileMeta == nil {
return 0, fmt.Errorf("image_url_is_nil")
@@ -257,7 +199,7 @@ func getImageToken(fileMeta *types.FileMeta, model string, stream bool) (int, er
return tiles*tileTokens + baseTokens, nil
}
-func CountRequestToken(c *gin.Context, meta *types.TokenCountMeta, info *relaycommon.RelayInfo) (int, error) {
+func EstimateRequestToken(c *gin.Context, meta *types.TokenCountMeta, info *relaycommon.RelayInfo) (int, error) {
// whether token counting is enabled
if !constant.CountToken {
return 0, nil
@@ -375,14 +317,14 @@ func CountRequestToken(c *gin.Context, meta *types.TokenCountMeta, info *relayco
for i, file := range meta.Files {
switch file.FileType {
case types.FileTypeImage:
-if info.RelayFormat == types.RelayFormatGemini {
-tkm += 520 // gemini per input image tokens
-} else {
+if common.IsOpenAITextModel(info.UpstreamModelName) {
token, err := getImageToken(file, model, info.IsStream)
if err != nil {
return 0, fmt.Errorf("error counting image token, media index[%d], original data[%s], err: %v", i, file.OriginData, err)
}
tkm += token
+} else {
+tkm += 520
}
case types.FileTypeAudio:
tkm += 256
@@ -399,111 +341,6 @@ func CountRequestToken(c *gin.Context, meta *types.TokenCountMeta, info *relayco
return tkm, nil
}
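common.IsOpenAITextModel is new in this PR and its body does not appear in this diff; presumably it is a name-based check along these lines (hypothetical sketch, in package common):

    // Hypothetical sketch only - the real helper is not shown in this diff.
    func IsOpenAITextModel(model string) bool {
        model = strings.ToLower(model)
        return strings.HasPrefix(model, "gpt-") ||
            strings.HasPrefix(model, "chatgpt-") ||
            strings.HasPrefix(model, "o1") // o-series models, assumed
    }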
-func CountTokenClaudeRequest(request dto.ClaudeRequest, model string) (int, error) {
-tkm := 0
-// Count tokens in messages
-msgTokens, err := CountTokenClaudeMessages(request.Messages, model, request.Stream)
-if err != nil {
-return 0, err
-}
-tkm += msgTokens
-// Count tokens in system message
-if request.System != "" {
-systemTokens := CountTokenInput(request.System, model)
-tkm += systemTokens
-}
-if request.Tools != nil {
-// check is array
-if tools, ok := request.Tools.([]any); ok {
-if len(tools) > 0 {
-parsedTools, err1 := common.Any2Type[[]dto.Tool](request.Tools)
-if err1 != nil {
-return 0, fmt.Errorf("tools: Input should be a valid list: %v", err)
-}
-toolTokens, err2 := CountTokenClaudeTools(parsedTools, model)
-if err2 != nil {
-return 0, fmt.Errorf("tools: %v", err)
-}
-tkm += toolTokens
-}
-} else {
-return 0, errors.New("tools: Input should be a valid list")
-}
-}
-return tkm, nil
-}
-func CountTokenClaudeMessages(messages []dto.ClaudeMessage, model string, stream bool) (int, error) {
-tokenEncoder := getTokenEncoder(model)
-tokenNum := 0
-for _, message := range messages {
-// Count tokens for role
-tokenNum += getTokenNum(tokenEncoder, message.Role)
-if message.IsStringContent() {
-tokenNum += getTokenNum(tokenEncoder, message.GetStringContent())
-} else {
-content, err := message.ParseContent()
-if err != nil {
-return 0, err
-}
-for _, mediaMessage := range content {
-switch mediaMessage.Type {
-case "text":
-tokenNum += getTokenNum(tokenEncoder, mediaMessage.GetText())
-case "image":
-//imageTokenNum, err := getClaudeImageToken(mediaMsg.Source, model, stream)
-//if err != nil {
-// return 0, err
-//}
-tokenNum += 1000
-case "tool_use":
-if mediaMessage.Input != nil {
-tokenNum += getTokenNum(tokenEncoder, mediaMessage.Name)
-inputJSON, _ := json.Marshal(mediaMessage.Input)
-tokenNum += getTokenNum(tokenEncoder, string(inputJSON))
-}
-case "tool_result":
-if mediaMessage.Content != nil {
-contentJSON, _ := json.Marshal(mediaMessage.Content)
-tokenNum += getTokenNum(tokenEncoder, string(contentJSON))
-}
-}
-}
-}
-}
-// Add a constant for message formatting (this may need adjustment based on Claude's exact formatting)
-tokenNum += len(messages) * 2 // Assuming 2 tokens per message for formatting
-return tokenNum, nil
-}
-func CountTokenClaudeTools(tools []dto.Tool, model string) (int, error) {
-tokenEncoder := getTokenEncoder(model)
-tokenNum := 0
-for _, tool := range tools {
-tokenNum += getTokenNum(tokenEncoder, tool.Name)
-tokenNum += getTokenNum(tokenEncoder, tool.Description)
-schemaJSON, err := json.Marshal(tool.InputSchema)
-if err != nil {
-return 0, errors.New(fmt.Sprintf("marshal_tool_schema_fail: %s", err.Error()))
-}
-tokenNum += getTokenNum(tokenEncoder, string(schemaJSON))
-}
-// Add a constant for tool formatting (this may need adjustment based on Claude's exact formatting)
-tokenNum += len(tools) * 3 // Assuming 3 tokens per tool for formatting
-return tokenNum, nil
-}
func CountTokenRealtime(info *relaycommon.RelayInfo, request dto.RealtimeEvent, model string) (int, int, error) {
audioToken := 0
textToken := 0
@@ -578,31 +415,6 @@ func CountTokenInput(input any, model string) int {
return CountTokenInput(fmt.Sprintf("%v", input), model)
}
-func CountTokenStreamChoices(messages []dto.ChatCompletionsStreamResponseChoice, model string) int {
-tokens := 0
-for _, message := range messages {
-tkm := CountTokenInput(message.Delta.GetContentString(), model)
-tokens += tkm
-if message.Delta.ToolCalls != nil {
-for _, tool := range message.Delta.ToolCalls {
-tkm := CountTokenInput(tool.Function.Name, model)
-tokens += tkm
-tkm = CountTokenInput(tool.Function.Arguments, model)
-tokens += tkm
-}
-}
-}
-return tokens
-}
-func CountTTSToken(text string, model string) int {
-if strings.HasPrefix(model, "tts") {
-return utf8.RuneCountInString(text)
-} else {
-return CountTextToken(text, model)
-}
-}
func CountAudioTokenInput(audioBase64 string, audioFormat string) (int, error) {
if audioBase64 == "" {
return 0, nil
@@ -625,17 +437,16 @@ func CountAudioTokenOutput(audioBase64 string, audioFormat string) (int, error)
return int(duration / 60 * 200 / 0.24), nil
}
//func CountAudioToken(sec float64, audioType string) {
// if audioType == "input" {
//
// }
//}
-// CountTextToken counts the tokens in a text; it returns an error only when the text contains sensitive words, along with the token count
+// CountTextToken counts the tokens in a text; only OpenAI models use the tokenizer, all other models are estimated
func CountTextToken(text string, model string) int {
if text == "" {
return 0
}
-tokenEncoder := getTokenEncoder(model)
-return getTokenNum(tokenEncoder, text)
+if common.IsOpenAITextModel(model) {
+tokenEncoder := getTokenEncoder(model)
+return getTokenNum(tokenEncoder, text)
+} else {
+// running tiktoken-go for non-OpenAI models is meaningless; estimate instead to save resources
+return EstimateTokenByModel(model, text)
+}
}
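From a caller's point of view the dispatch now looks like this (illustrative only):

    CountTextToken("hello world", "gpt-4o")            // exact count via the tiktoken encoder
    CountTextToken("hello world", "claude-3-5-sonnet") // heuristic count via EstimateTokenByModel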

service/token_estimator.go (new file, +230 lines)
View File

@@ -0,0 +1,230 @@
package service
import (
"math"
"strings"
"sync"
"unicode"
)
// Provider identifies the model vendor family
type Provider string
const (
OpenAI Provider = "openai" // covers GPT-3.5, GPT-4, GPT-4o
Gemini Provider = "gemini" // covers Gemini 1.0, 1.5 Pro/Flash
Claude Provider = "claude" // covers Claude 3, 3.5 Sonnet
Unknown Provider = "unknown" // fallback default
)
// multipliers defines the per-vendor token cost weights
type multipliers struct {
Word float64 // English words (per word)
Number float64 // digits (per consecutive digit run)
CJK float64 // CJK characters (per character)
Symbol float64 // ordinary punctuation (each)
MathSymbol float64 // math symbols (∑, ∫, ∂, √, etc., each)
URLDelim float64 // URL delimiters (/ : ? & = # %) - tokenizers optimize these well
AtSign float64 // the @ sign - forces word splits, relatively expensive
Emoji float64 // emoji (each)
Newline float64 // newlines/tabs (each)
Space float64 // spaces (each)
BasePad int // base start-up cost (start/end tokens)
}
var (
multipliersMap = map[Provider]multipliers{
Gemini: {
Word: 1.15, Number: 2.8, CJK: 0.68, Symbol: 0.38, MathSymbol: 1.05, URLDelim: 1.2, AtSign: 2.5, Emoji: 1.08, Newline: 1.15, Space: 0.2, BasePad: 0,
},
Claude: {
Word: 1.13, Number: 1.63, CJK: 1.21, Symbol: 0.4, MathSymbol: 4.52, URLDelim: 1.26, AtSign: 2.82, Emoji: 2.6, Newline: 0.89, Space: 0.39, BasePad: 0,
},
OpenAI: {
Word: 1.02, Number: 1.55, CJK: 0.85, Symbol: 0.4, MathSymbol: 2.68, URLDelim: 1.0, AtSign: 2.0, Emoji: 2.12, Newline: 0.5, Space: 0.42, BasePad: 0,
},
}
multipliersLock sync.RWMutex
)
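To get a feel for how the tables diverge, the same math symbol costs very differently per vendor (values read straight off the map above):

    EstimateToken(Gemini, "∑") // ceil(1.05) = 2
    EstimateToken(Claude, "∑") // ceil(4.52) = 5
    EstimateToken(OpenAI, "∑") // ceil(2.68) = 3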
// getMultipliers returns the weight table for the given vendor
func getMultipliers(p Provider) multipliers {
multipliersLock.RLock()
defer multipliersLock.RUnlock()
switch p {
case Gemini:
return multipliersMap[Gemini]
case Claude:
return multipliersMap[Claude]
case OpenAI:
return multipliersMap[OpenAI]
default:
// fallback default (use the OpenAI weights)
return multipliersMap[OpenAI]
}
}
// EstimateToken estimates the token count of text under the given provider's weights
func EstimateToken(provider Provider, text string) int {
m := getMultipliers(provider)
var count float64
// state machine variables
type WordType int
const (
None WordType = iota
Latin
Number
)
currentWordType := None
for _, r := range text {
// 1. handle spaces and newlines
if unicode.IsSpace(r) {
currentWordType = None
// newlines and tabs use the Newline weight
if r == '\n' || r == '\t' {
count += m.Newline
} else {
// ordinary spaces use the Space weight
count += m.Space
}
continue
}
// 2. handle CJK (Chinese/Japanese/Korean) - charged per character
if isCJK(r) {
currentWordType = None
count += m.CJK
continue
}
// 3. handle emoji - use the dedicated Emoji weight
if isEmoji(r) {
currentWordType = None
count += m.Emoji
continue
}
// 4. handle Latin letters/digits (English words)
if isLatinOrNumber(r) {
isNum := unicode.IsNumber(r)
newType := Latin
if isNum {
newType = Number
}
// if we were not inside a word, or the type flipped (letter <-> digit), treat it as a new token
// note: OpenAI typically splits "version 3.5", and sometimes "abc123xyz" as well
// for simplicity, charge again whenever letters and digits alternate
if currentWordType == None || currentWordType != newType {
if newType == Number {
count += m.Number
} else {
count += m.Word
}
currentWordType = newType
}
// characters in the middle of a word incur no extra charge
continue
}
// 5. handle punctuation/special characters - different weights by type
currentWordType = None
if isMathSymbol(r) {
count += m.MathSymbol
} else if r == '@' {
count += m.AtSign
} else if isURLDelim(r) {
count += m.URLDelim
} else {
count += m.Symbol
}
}
// round up and add the base padding
return int(math.Ceil(count)) + m.BasePad
}
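A worked example under the OpenAI table makes the state machine concrete:

    // EstimateToken(OpenAI, "Hello, world! 你好")
    //   "Hello"    -> Word   1.02 (charged once, at the first letter)
    //   ","        -> Symbol 0.40
    //   " " (x2)   -> Space  0.42 each
    //   "world"    -> Word   1.02
    //   "!"        -> Symbol 0.40
    //   "你", "好" -> CJK    0.85 each
    // total = 5.38, returned as ceil(5.38) = 6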
// helper: reports whether r is a CJK character
func isCJK(r rune) bool {
return unicode.Is(unicode.Han, r) ||
(r >= 0x3040 && r <= 0x30FF) || // Japanese kana
(r >= 0xAC00 && r <= 0xD7A3) // Korean Hangul
}
// helper: reports whether r is word material (a letter or digit)
func isLatinOrNumber(r rune) bool {
return unicode.IsLetter(r) || unicode.IsNumber(r)
}
// 辅助判断是否为Emoji字符
func isEmoji(r rune) bool {
// Emoji的Unicode范围
// 基本范围0x1F300-0x1F9FF (Emoticons, Symbols, Pictographs)
// 补充范围0x2600-0x26FF (Misc Symbols), 0x2700-0x27BF (Dingbats)
// 表情符号0x1F600-0x1F64F (Emoticons)
// 其他0x1F900-0x1F9FF (Supplemental Symbols and Pictographs)
return (r >= 0x1F300 && r <= 0x1F9FF) ||
(r >= 0x2600 && r <= 0x26FF) ||
(r >= 0x2700 && r <= 0x27BF) ||
(r >= 0x1F600 && r <= 0x1F64F) ||
(r >= 0x1F900 && r <= 0x1F9FF) ||
(r >= 0x1FA00 && r <= 0x1FAFF) // Symbols and Pictographs Extended-A
}
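One caveat: detection is per rune, so multi-rune emoji decompose. Regional-indicator flags such as "🇨🇳" (U+1F1E8 U+1F1F3) fall outside every range above and are billed as two plain symbols, while a ZWJ family like "👨‍👩‍👧" scores three Emoji charges plus two Symbol charges for the U+200D joiners. A quick check (inside package service, assuming fmt is imported):

    for _, r := range "🇨🇳👨‍👩‍👧" {
        fmt.Printf("%U isEmoji=%v\n", r, isEmoji(r))
    }
    // U+1F1E8 false, U+1F1F3 false      -> flag billed as 2x Symbol
    // U+1F468 true, U+200D false, ...   -> family billed as 3x Emoji + 2x Symbol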
// helper: reports whether r is a math symbol
func isMathSymbol(r rune) bool {
// mathematical operators and symbols
// basic math symbols: ∑ ∫ ∂ √ ∞ ≤ ≥ ≠ ≈ ± × ÷
// superscript/subscript digits: ² ³ ¹ ⁴ ⁵ ⁶ ⁷ ⁸ ⁹ ⁰
// Greek letters and the like are also common in math
mathSymbols := "∑∫∂√∞≤≥≠≈±×÷∈∉∋∌⊂⊃⊆⊇∪∩∧∨¬∀∃∄∅∆∇∝∟∠∡∢°′″‴⁺⁻⁼⁽⁾ⁿ₀₁₂₃₄₅₆₇₈₉₊₋₌₍₎²³¹⁴⁵⁶⁷⁸⁹⁰"
for _, m := range mathSymbols {
if r == m {
return true
}
}
// Mathematical Operators (U+2200U+22FF)
if r >= 0x2200 && r <= 0x22FF {
return true
}
// Supplemental Mathematical Operators (U+2A00U+2AFF)
if r >= 0x2A00 && r <= 0x2AFF {
return true
}
// Mathematical Alphanumeric Symbols (U+1D400U+1D7FF)
if r >= 0x1D400 && r <= 0x1D7FF {
return true
}
return false
}
// 辅助判断是否为URL分隔符tokenizer对这些优化较好
func isURLDelim(r rune) bool {
// common URL delimiters, usually handled efficiently by tokenizers
urlDelims := "/:?&=;#%"
for _, d := range urlDelims {
if r == d {
return true
}
}
return false
}
func EstimateTokenByModel(model, text string) int {
if text == "" {
return 0
}
model = strings.ToLower(model)
if strings.Contains(model, "gemini") {
return EstimateToken(Gemini, text)
} else if strings.Contains(model, "claude") {
return EstimateToken(Claude, text)
} else {
return EstimateToken(OpenAI, text)
}
}
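Routing is a substring match on the lowercased model name, with the OpenAI table as the default for anything unrecognized:

    EstimateTokenByModel("gemini-2.0-flash", "hi there")  // Gemini weights
    EstimateTokenByModel("claude-3-5-sonnet", "hi there") // Claude weights
    EstimateTokenByModel("llama-3-70b", "hi there")       // falls through to OpenAI weights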

service/tokenizer.go (new file, +63 lines)
View File

@@ -0,0 +1,63 @@
package service
import (
"sync"
"github.com/QuantumNous/new-api/common"
"github.com/tiktoken-go/tokenizer"
"github.com/tiktoken-go/tokenizer/codec"
)
// tokenEncoderMap won't grow after initialization
var defaultTokenEncoder tokenizer.Codec
// tokenEncoderMap is used to store token encoders for different models
var tokenEncoderMap = make(map[string]tokenizer.Codec)
// tokenEncoderMutex protects tokenEncoderMap for concurrent access
var tokenEncoderMutex sync.RWMutex
func InitTokenEncoders() {
common.SysLog("initializing token encoders")
defaultTokenEncoder = codec.NewCl100kBase()
common.SysLog("token encoders initialized")
}
func getTokenEncoder(model string) tokenizer.Codec {
// First, try to get the encoder from cache with read lock
tokenEncoderMutex.RLock()
if encoder, exists := tokenEncoderMap[model]; exists {
tokenEncoderMutex.RUnlock()
return encoder
}
tokenEncoderMutex.RUnlock()
// If not in cache, create new encoder with write lock
tokenEncoderMutex.Lock()
defer tokenEncoderMutex.Unlock()
// Double-check if another goroutine already created the encoder
if encoder, exists := tokenEncoderMap[model]; exists {
return encoder
}
// Create new encoder
modelCodec, err := tokenizer.ForModel(tokenizer.Model(model))
if err != nil {
// Cache the default encoder for this model to avoid repeated failures
tokenEncoderMap[model] = defaultTokenEncoder
return defaultTokenEncoder
}
// Cache the new encoder
tokenEncoderMap[model] = modelCodec
return modelCodec
}
func getTokenNum(tokenEncoder tokenizer.Codec, text string) int {
if text == "" {
return 0
}
tkm, _ := tokenEncoder.Count(text)
return tkm
}
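Typical in-package use, assuming InitTokenEncoders has run at startup (both helpers are unexported):

    enc := getTokenEncoder("gpt-4o") // falls back to cl100k_base if the model is unknown to tiktoken-go
    n := getTokenNum(enc, "The quick brown fox")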

View File

@@ -23,8 +23,7 @@ func ResponseText2Usage(c *gin.Context, responseText string, modeName string, pr
common.SetContextKey(c, constant.ContextKeyLocalCountTokens, true)
usage := &dto.Usage{}
usage.PromptTokens = promptTokens
-ctkm := CountTextToken(responseText, modeName)
-usage.CompletionTokens = ctkm
+usage.CompletionTokens = EstimateTokenByModel(modeName, responseText)
usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
return usage
}
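For a non-OpenAI model the completion side is now estimated rather than tokenized. Assuming the parameter truncated in the hunk header is the prompt-token count, usage resolves like this (sketch):

    usage := ResponseText2Usage(c, "generated reply", "claude-3-opus", 12 /* promptTokens, assumed */)
    // usage.CompletionTokens == EstimateTokenByModel("claude-3-opus", "generated reply")
    // usage.TotalTokens     == 12 + usage.CompletionTokens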