mirror of
https://github.com/QuantumNous/new-api.git
synced 2026-04-18 17:07:27 +00:00
Merge pull request #2355 from QuantumNous/feat/optimize-token-counter
feat: refactor token estimation logic
This commit is contained in:
@@ -209,7 +209,7 @@ func StreamResponseOpenAI2Claude(openAIResponse *dto.ChatCompletionsStreamRespon
|
||||
Type: "message",
|
||||
Role: "assistant",
|
||||
Usage: &dto.ClaudeUsage{
|
||||
InputTokens: info.PromptTokens,
|
||||
InputTokens: info.GetEstimatePromptTokens(),
|
||||
OutputTokens: 0,
|
||||
},
|
||||
}
|
||||
@@ -734,12 +734,18 @@ func StreamResponseOpenAI2Gemini(openAIResponse *dto.ChatCompletionsStreamRespon
|
||||
geminiResponse := &dto.GeminiChatResponse{
|
||||
Candidates: make([]dto.GeminiChatCandidate, 0, len(openAIResponse.Choices)),
|
||||
UsageMetadata: dto.GeminiUsageMetadata{
|
||||
PromptTokenCount: info.PromptTokens,
|
||||
PromptTokenCount: info.GetEstimatePromptTokens(),
|
||||
CandidatesTokenCount: 0, // 流式响应中可能没有完整的 usage 信息
|
||||
TotalTokenCount: info.PromptTokens,
|
||||
TotalTokenCount: info.GetEstimatePromptTokens(),
|
||||
},
|
||||
}
|
||||
|
||||
if openAIResponse.Usage != nil {
|
||||
geminiResponse.UsageMetadata.PromptTokenCount = openAIResponse.Usage.PromptTokens
|
||||
geminiResponse.UsageMetadata.CandidatesTokenCount = openAIResponse.Usage.CompletionTokens
|
||||
geminiResponse.UsageMetadata.TotalTokenCount = openAIResponse.Usage.TotalTokens
|
||||
}
|
||||
|
||||
for _, choice := range openAIResponse.Choices {
|
||||
candidate := dto.GeminiChatCandidate{
|
||||
Index: int64(choice.Index),
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"image"
|
||||
@@ -12,7 +11,6 @@ import (
|
||||
"math"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
"unicode/utf8"
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
@@ -23,64 +21,8 @@ import (
|
||||
"github.com/QuantumNous/new-api/types"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/tiktoken-go/tokenizer"
|
||||
"github.com/tiktoken-go/tokenizer/codec"
|
||||
)
|
||||
|
||||
// tokenEncoderMap won't grow after initialization
|
||||
var defaultTokenEncoder tokenizer.Codec
|
||||
|
||||
// tokenEncoderMap is used to store token encoders for different models
|
||||
var tokenEncoderMap = make(map[string]tokenizer.Codec)
|
||||
|
||||
// tokenEncoderMutex protects tokenEncoderMap for concurrent access
|
||||
var tokenEncoderMutex sync.RWMutex
|
||||
|
||||
func InitTokenEncoders() {
|
||||
common.SysLog("initializing token encoders")
|
||||
defaultTokenEncoder = codec.NewCl100kBase()
|
||||
common.SysLog("token encoders initialized")
|
||||
}
|
||||
|
||||
func getTokenEncoder(model string) tokenizer.Codec {
|
||||
// First, try to get the encoder from cache with read lock
|
||||
tokenEncoderMutex.RLock()
|
||||
if encoder, exists := tokenEncoderMap[model]; exists {
|
||||
tokenEncoderMutex.RUnlock()
|
||||
return encoder
|
||||
}
|
||||
tokenEncoderMutex.RUnlock()
|
||||
|
||||
// If not in cache, create new encoder with write lock
|
||||
tokenEncoderMutex.Lock()
|
||||
defer tokenEncoderMutex.Unlock()
|
||||
|
||||
// Double-check if another goroutine already created the encoder
|
||||
if encoder, exists := tokenEncoderMap[model]; exists {
|
||||
return encoder
|
||||
}
|
||||
|
||||
// Create new encoder
|
||||
modelCodec, err := tokenizer.ForModel(tokenizer.Model(model))
|
||||
if err != nil {
|
||||
// Cache the default encoder for this model to avoid repeated failures
|
||||
tokenEncoderMap[model] = defaultTokenEncoder
|
||||
return defaultTokenEncoder
|
||||
}
|
||||
|
||||
// Cache the new encoder
|
||||
tokenEncoderMap[model] = modelCodec
|
||||
return modelCodec
|
||||
}
|
||||
|
||||
func getTokenNum(tokenEncoder tokenizer.Codec, text string) int {
|
||||
if text == "" {
|
||||
return 0
|
||||
}
|
||||
tkm, _ := tokenEncoder.Count(text)
|
||||
return tkm
|
||||
}
|
||||
|
||||
func getImageToken(fileMeta *types.FileMeta, model string, stream bool) (int, error) {
|
||||
if fileMeta == nil {
|
||||
return 0, fmt.Errorf("image_url_is_nil")
|
||||
@@ -257,7 +199,7 @@ func getImageToken(fileMeta *types.FileMeta, model string, stream bool) (int, er
|
||||
return tiles*tileTokens + baseTokens, nil
|
||||
}
|
||||
|
||||
func CountRequestToken(c *gin.Context, meta *types.TokenCountMeta, info *relaycommon.RelayInfo) (int, error) {
|
||||
func EstimateRequestToken(c *gin.Context, meta *types.TokenCountMeta, info *relaycommon.RelayInfo) (int, error) {
|
||||
// 是否统计token
|
||||
if !constant.CountToken {
|
||||
return 0, nil
|
||||
@@ -375,14 +317,14 @@ func CountRequestToken(c *gin.Context, meta *types.TokenCountMeta, info *relayco
|
||||
for i, file := range meta.Files {
|
||||
switch file.FileType {
|
||||
case types.FileTypeImage:
|
||||
if info.RelayFormat == types.RelayFormatGemini {
|
||||
tkm += 520 // gemini per input image tokens
|
||||
} else {
|
||||
if common.IsOpenAITextModel(info.UpstreamModelName) {
|
||||
token, err := getImageToken(file, model, info.IsStream)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("error counting image token, media index[%d], original data[%s], err: %v", i, file.OriginData, err)
|
||||
}
|
||||
tkm += token
|
||||
} else {
|
||||
tkm += 520
|
||||
}
|
||||
case types.FileTypeAudio:
|
||||
tkm += 256
|
||||
@@ -399,111 +341,6 @@ func CountRequestToken(c *gin.Context, meta *types.TokenCountMeta, info *relayco
|
||||
return tkm, nil
|
||||
}
|
||||
|
||||
func CountTokenClaudeRequest(request dto.ClaudeRequest, model string) (int, error) {
|
||||
tkm := 0
|
||||
|
||||
// Count tokens in messages
|
||||
msgTokens, err := CountTokenClaudeMessages(request.Messages, model, request.Stream)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
tkm += msgTokens
|
||||
|
||||
// Count tokens in system message
|
||||
if request.System != "" {
|
||||
systemTokens := CountTokenInput(request.System, model)
|
||||
tkm += systemTokens
|
||||
}
|
||||
|
||||
if request.Tools != nil {
|
||||
// check is array
|
||||
if tools, ok := request.Tools.([]any); ok {
|
||||
if len(tools) > 0 {
|
||||
parsedTools, err1 := common.Any2Type[[]dto.Tool](request.Tools)
|
||||
if err1 != nil {
|
||||
return 0, fmt.Errorf("tools: Input should be a valid list: %v", err)
|
||||
}
|
||||
toolTokens, err2 := CountTokenClaudeTools(parsedTools, model)
|
||||
if err2 != nil {
|
||||
return 0, fmt.Errorf("tools: %v", err)
|
||||
}
|
||||
tkm += toolTokens
|
||||
}
|
||||
} else {
|
||||
return 0, errors.New("tools: Input should be a valid list")
|
||||
}
|
||||
}
|
||||
|
||||
return tkm, nil
|
||||
}
|
||||
|
||||
func CountTokenClaudeMessages(messages []dto.ClaudeMessage, model string, stream bool) (int, error) {
|
||||
tokenEncoder := getTokenEncoder(model)
|
||||
tokenNum := 0
|
||||
|
||||
for _, message := range messages {
|
||||
// Count tokens for role
|
||||
tokenNum += getTokenNum(tokenEncoder, message.Role)
|
||||
if message.IsStringContent() {
|
||||
tokenNum += getTokenNum(tokenEncoder, message.GetStringContent())
|
||||
} else {
|
||||
content, err := message.ParseContent()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
for _, mediaMessage := range content {
|
||||
switch mediaMessage.Type {
|
||||
case "text":
|
||||
tokenNum += getTokenNum(tokenEncoder, mediaMessage.GetText())
|
||||
case "image":
|
||||
//imageTokenNum, err := getClaudeImageToken(mediaMsg.Source, model, stream)
|
||||
//if err != nil {
|
||||
// return 0, err
|
||||
//}
|
||||
tokenNum += 1000
|
||||
case "tool_use":
|
||||
if mediaMessage.Input != nil {
|
||||
tokenNum += getTokenNum(tokenEncoder, mediaMessage.Name)
|
||||
inputJSON, _ := json.Marshal(mediaMessage.Input)
|
||||
tokenNum += getTokenNum(tokenEncoder, string(inputJSON))
|
||||
}
|
||||
case "tool_result":
|
||||
if mediaMessage.Content != nil {
|
||||
contentJSON, _ := json.Marshal(mediaMessage.Content)
|
||||
tokenNum += getTokenNum(tokenEncoder, string(contentJSON))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Add a constant for message formatting (this may need adjustment based on Claude's exact formatting)
|
||||
tokenNum += len(messages) * 2 // Assuming 2 tokens per message for formatting
|
||||
|
||||
return tokenNum, nil
|
||||
}
|
||||
|
||||
func CountTokenClaudeTools(tools []dto.Tool, model string) (int, error) {
|
||||
tokenEncoder := getTokenEncoder(model)
|
||||
tokenNum := 0
|
||||
|
||||
for _, tool := range tools {
|
||||
tokenNum += getTokenNum(tokenEncoder, tool.Name)
|
||||
tokenNum += getTokenNum(tokenEncoder, tool.Description)
|
||||
|
||||
schemaJSON, err := json.Marshal(tool.InputSchema)
|
||||
if err != nil {
|
||||
return 0, errors.New(fmt.Sprintf("marshal_tool_schema_fail: %s", err.Error()))
|
||||
}
|
||||
tokenNum += getTokenNum(tokenEncoder, string(schemaJSON))
|
||||
}
|
||||
|
||||
// Add a constant for tool formatting (this may need adjustment based on Claude's exact formatting)
|
||||
tokenNum += len(tools) * 3 // Assuming 3 tokens per tool for formatting
|
||||
|
||||
return tokenNum, nil
|
||||
}
|
||||
|
||||
func CountTokenRealtime(info *relaycommon.RelayInfo, request dto.RealtimeEvent, model string) (int, int, error) {
|
||||
audioToken := 0
|
||||
textToken := 0
|
||||
@@ -578,31 +415,6 @@ func CountTokenInput(input any, model string) int {
|
||||
return CountTokenInput(fmt.Sprintf("%v", input), model)
|
||||
}
|
||||
|
||||
func CountTokenStreamChoices(messages []dto.ChatCompletionsStreamResponseChoice, model string) int {
|
||||
tokens := 0
|
||||
for _, message := range messages {
|
||||
tkm := CountTokenInput(message.Delta.GetContentString(), model)
|
||||
tokens += tkm
|
||||
if message.Delta.ToolCalls != nil {
|
||||
for _, tool := range message.Delta.ToolCalls {
|
||||
tkm := CountTokenInput(tool.Function.Name, model)
|
||||
tokens += tkm
|
||||
tkm = CountTokenInput(tool.Function.Arguments, model)
|
||||
tokens += tkm
|
||||
}
|
||||
}
|
||||
}
|
||||
return tokens
|
||||
}
|
||||
|
||||
func CountTTSToken(text string, model string) int {
|
||||
if strings.HasPrefix(model, "tts") {
|
||||
return utf8.RuneCountInString(text)
|
||||
} else {
|
||||
return CountTextToken(text, model)
|
||||
}
|
||||
}
|
||||
|
||||
func CountAudioTokenInput(audioBase64 string, audioFormat string) (int, error) {
|
||||
if audioBase64 == "" {
|
||||
return 0, nil
|
||||
@@ -625,17 +437,16 @@ func CountAudioTokenOutput(audioBase64 string, audioFormat string) (int, error)
|
||||
return int(duration / 60 * 200 / 0.24), nil
|
||||
}
|
||||
|
||||
//func CountAudioToken(sec float64, audioType string) {
|
||||
// if audioType == "input" {
|
||||
//
|
||||
// }
|
||||
//}
|
||||
|
||||
// CountTextToken 统计文本的token数量,仅当文本包含敏感词,返回错误,同时返回token数量
|
||||
// CountTextToken 统计文本的token数量,仅OpenAI模型使用tokenizer,其余模型使用估算
|
||||
func CountTextToken(text string, model string) int {
|
||||
if text == "" {
|
||||
return 0
|
||||
}
|
||||
tokenEncoder := getTokenEncoder(model)
|
||||
return getTokenNum(tokenEncoder, text)
|
||||
if common.IsOpenAITextModel(model) {
|
||||
tokenEncoder := getTokenEncoder(model)
|
||||
return getTokenNum(tokenEncoder, text)
|
||||
} else {
|
||||
// 非openai模型,使用tiktoken-go计算没有意义,使用估算节省资源
|
||||
return EstimateTokenByModel(model, text)
|
||||
}
|
||||
}
|
||||
|
||||
230
service/token_estimator.go
Normal file
230
service/token_estimator.go
Normal file
@@ -0,0 +1,230 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"math"
|
||||
"strings"
|
||||
"sync"
|
||||
"unicode"
|
||||
)
|
||||
|
||||
// Provider 定义模型厂商大类
|
||||
type Provider string
|
||||
|
||||
const (
|
||||
OpenAI Provider = "openai" // 代表 GPT-3.5, GPT-4, GPT-4o
|
||||
Gemini Provider = "gemini" // 代表 Gemini 1.0, 1.5 Pro/Flash
|
||||
Claude Provider = "claude" // 代表 Claude 3, 3.5 Sonnet
|
||||
Unknown Provider = "unknown" // 兜底默认
|
||||
)
|
||||
|
||||
// multipliers 定义不同厂商的计费权重
|
||||
type multipliers struct {
|
||||
Word float64 // 英文单词 (每词)
|
||||
Number float64 // 数字 (每连续数字串)
|
||||
CJK float64 // 中日韩字符 (每字)
|
||||
Symbol float64 // 普通标点符号 (每个)
|
||||
MathSymbol float64 // 数学符号 (∑,∫,∂,√等,每个)
|
||||
URLDelim float64 // URL分隔符 (/,:,?,&,=,#,%) - tokenizer优化好
|
||||
AtSign float64 // @符号 - 导致单词切分,消耗较高
|
||||
Emoji float64 // Emoji表情 (每个)
|
||||
Newline float64 // 换行符/制表符 (每个)
|
||||
Space float64 // 空格 (每个)
|
||||
BasePad int // 基础起步消耗 (Start/End tokens)
|
||||
}
|
||||
|
||||
var (
|
||||
multipliersMap = map[Provider]multipliers{
|
||||
Gemini: {
|
||||
Word: 1.15, Number: 2.8, CJK: 0.68, Symbol: 0.38, MathSymbol: 1.05, URLDelim: 1.2, AtSign: 2.5, Emoji: 1.08, Newline: 1.15, Space: 0.2, BasePad: 0,
|
||||
},
|
||||
Claude: {
|
||||
Word: 1.13, Number: 1.63, CJK: 1.21, Symbol: 0.4, MathSymbol: 4.52, URLDelim: 1.26, AtSign: 2.82, Emoji: 2.6, Newline: 0.89, Space: 0.39, BasePad: 0,
|
||||
},
|
||||
OpenAI: {
|
||||
Word: 1.02, Number: 1.55, CJK: 0.85, Symbol: 0.4, MathSymbol: 2.68, URLDelim: 1.0, AtSign: 2.0, Emoji: 2.12, Newline: 0.5, Space: 0.42, BasePad: 0,
|
||||
},
|
||||
}
|
||||
multipliersLock sync.RWMutex
|
||||
)
|
||||
|
||||
// getMultipliers 根据厂商获取权重配置
|
||||
func getMultipliers(p Provider) multipliers {
|
||||
multipliersLock.RLock()
|
||||
defer multipliersLock.RUnlock()
|
||||
|
||||
switch p {
|
||||
case Gemini:
|
||||
return multipliersMap[Gemini]
|
||||
case Claude:
|
||||
return multipliersMap[Claude]
|
||||
case OpenAI:
|
||||
return multipliersMap[OpenAI]
|
||||
default:
|
||||
// 默认兜底 (按 OpenAI 的算)
|
||||
return multipliersMap[OpenAI]
|
||||
}
|
||||
}
|
||||
|
||||
// EstimateToken 计算 Token 数量
|
||||
func EstimateToken(provider Provider, text string) int {
|
||||
m := getMultipliers(provider)
|
||||
var count float64
|
||||
|
||||
// 状态机变量
|
||||
type WordType int
|
||||
const (
|
||||
None WordType = iota
|
||||
Latin
|
||||
Number
|
||||
)
|
||||
currentWordType := None
|
||||
|
||||
for _, r := range text {
|
||||
// 1. 处理空格和换行符
|
||||
if unicode.IsSpace(r) {
|
||||
currentWordType = None
|
||||
// 换行符和制表符使用Newline权重
|
||||
if r == '\n' || r == '\t' {
|
||||
count += m.Newline
|
||||
} else {
|
||||
// 普通空格使用Space权重
|
||||
count += m.Space
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
// 2. 处理 CJK (中日韩) - 按字符计费
|
||||
if isCJK(r) {
|
||||
currentWordType = None
|
||||
count += m.CJK
|
||||
continue
|
||||
}
|
||||
|
||||
// 3. 处理Emoji - 使用专门的Emoji权重
|
||||
if isEmoji(r) {
|
||||
currentWordType = None
|
||||
count += m.Emoji
|
||||
continue
|
||||
}
|
||||
|
||||
// 4. 处理拉丁字母/数字 (英文单词)
|
||||
if isLatinOrNumber(r) {
|
||||
isNum := unicode.IsNumber(r)
|
||||
newType := Latin
|
||||
if isNum {
|
||||
newType = Number
|
||||
}
|
||||
|
||||
// 如果之前不在单词中,或者类型发生变化(字母<->数字),则视为新token
|
||||
// 注意:对于OpenAI,通常"version 3.5"会切分,"abc123xyz"有时也会切分
|
||||
// 这里简单起见,字母和数字切换时增加权重
|
||||
if currentWordType == None || currentWordType != newType {
|
||||
if newType == Number {
|
||||
count += m.Number
|
||||
} else {
|
||||
count += m.Word
|
||||
}
|
||||
currentWordType = newType
|
||||
}
|
||||
// 单词中间的字符不额外计费
|
||||
continue
|
||||
}
|
||||
|
||||
// 5. 处理标点符号/特殊字符 - 按类型使用不同权重
|
||||
currentWordType = None
|
||||
if isMathSymbol(r) {
|
||||
count += m.MathSymbol
|
||||
} else if r == '@' {
|
||||
count += m.AtSign
|
||||
} else if isURLDelim(r) {
|
||||
count += m.URLDelim
|
||||
} else {
|
||||
count += m.Symbol
|
||||
}
|
||||
}
|
||||
|
||||
// 向上取整并加上基础 padding
|
||||
return int(math.Ceil(count)) + m.BasePad
|
||||
}
|
||||
|
||||
// 辅助:判断是否为 CJK 字符
|
||||
func isCJK(r rune) bool {
|
||||
return unicode.Is(unicode.Han, r) ||
|
||||
(r >= 0x3040 && r <= 0x30FF) || // 日文
|
||||
(r >= 0xAC00 && r <= 0xD7A3) // 韩文
|
||||
}
|
||||
|
||||
// 辅助:判断是否为单词主体 (字母或数字)
|
||||
func isLatinOrNumber(r rune) bool {
|
||||
return unicode.IsLetter(r) || unicode.IsNumber(r)
|
||||
}
|
||||
|
||||
// 辅助:判断是否为Emoji字符
|
||||
func isEmoji(r rune) bool {
|
||||
// Emoji的Unicode范围
|
||||
// 基本范围:0x1F300-0x1F9FF (Emoticons, Symbols, Pictographs)
|
||||
// 补充范围:0x2600-0x26FF (Misc Symbols), 0x2700-0x27BF (Dingbats)
|
||||
// 表情符号:0x1F600-0x1F64F (Emoticons)
|
||||
// 其他:0x1F900-0x1F9FF (Supplemental Symbols and Pictographs)
|
||||
return (r >= 0x1F300 && r <= 0x1F9FF) ||
|
||||
(r >= 0x2600 && r <= 0x26FF) ||
|
||||
(r >= 0x2700 && r <= 0x27BF) ||
|
||||
(r >= 0x1F600 && r <= 0x1F64F) ||
|
||||
(r >= 0x1F900 && r <= 0x1F9FF) ||
|
||||
(r >= 0x1FA00 && r <= 0x1FAFF) // Symbols and Pictographs Extended-A
|
||||
}
|
||||
|
||||
// 辅助:判断是否为数学符号
|
||||
func isMathSymbol(r rune) bool {
|
||||
// 数学运算符和符号
|
||||
// 基本数学符号:∑ ∫ ∂ √ ∞ ≤ ≥ ≠ ≈ ± × ÷
|
||||
// 上下标数字:² ³ ¹ ⁴ ⁵ ⁶ ⁷ ⁸ ⁹ ⁰
|
||||
// 希腊字母等也常用于数学
|
||||
mathSymbols := "∑∫∂√∞≤≥≠≈±×÷∈∉∋∌⊂⊃⊆⊇∪∩∧∨¬∀∃∄∅∆∇∝∟∠∡∢°′″‴⁺⁻⁼⁽⁾ⁿ₀₁₂₃₄₅₆₇₈₉₊₋₌₍₎²³¹⁴⁵⁶⁷⁸⁹⁰"
|
||||
for _, m := range mathSymbols {
|
||||
if r == m {
|
||||
return true
|
||||
}
|
||||
}
|
||||
// Mathematical Operators (U+2200–U+22FF)
|
||||
if r >= 0x2200 && r <= 0x22FF {
|
||||
return true
|
||||
}
|
||||
// Supplemental Mathematical Operators (U+2A00–U+2AFF)
|
||||
if r >= 0x2A00 && r <= 0x2AFF {
|
||||
return true
|
||||
}
|
||||
// Mathematical Alphanumeric Symbols (U+1D400–U+1D7FF)
|
||||
if r >= 0x1D400 && r <= 0x1D7FF {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// 辅助:判断是否为URL分隔符(tokenizer对这些优化较好)
|
||||
func isURLDelim(r rune) bool {
|
||||
// URL中常见的分隔符,tokenizer通常优化处理
|
||||
urlDelims := "/:?&=;#%"
|
||||
for _, d := range urlDelims {
|
||||
if r == d {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func EstimateTokenByModel(model, text string) int {
|
||||
// strings.Contains(model, "gpt-4o")
|
||||
if text == "" {
|
||||
return 0
|
||||
}
|
||||
|
||||
model = strings.ToLower(model)
|
||||
if strings.Contains(model, "gemini") {
|
||||
return EstimateToken(Gemini, text)
|
||||
} else if strings.Contains(model, "claude") {
|
||||
return EstimateToken(Claude, text)
|
||||
} else {
|
||||
return EstimateToken(OpenAI, text)
|
||||
}
|
||||
}
|
||||
63
service/tokenizer.go
Normal file
63
service/tokenizer.go
Normal file
@@ -0,0 +1,63 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"sync"
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
"github.com/tiktoken-go/tokenizer"
|
||||
"github.com/tiktoken-go/tokenizer/codec"
|
||||
)
|
||||
|
||||
// tokenEncoderMap won't grow after initialization
|
||||
var defaultTokenEncoder tokenizer.Codec
|
||||
|
||||
// tokenEncoderMap is used to store token encoders for different models
|
||||
var tokenEncoderMap = make(map[string]tokenizer.Codec)
|
||||
|
||||
// tokenEncoderMutex protects tokenEncoderMap for concurrent access
|
||||
var tokenEncoderMutex sync.RWMutex
|
||||
|
||||
func InitTokenEncoders() {
|
||||
common.SysLog("initializing token encoders")
|
||||
defaultTokenEncoder = codec.NewCl100kBase()
|
||||
common.SysLog("token encoders initialized")
|
||||
}
|
||||
|
||||
func getTokenEncoder(model string) tokenizer.Codec {
|
||||
// First, try to get the encoder from cache with read lock
|
||||
tokenEncoderMutex.RLock()
|
||||
if encoder, exists := tokenEncoderMap[model]; exists {
|
||||
tokenEncoderMutex.RUnlock()
|
||||
return encoder
|
||||
}
|
||||
tokenEncoderMutex.RUnlock()
|
||||
|
||||
// If not in cache, create new encoder with write lock
|
||||
tokenEncoderMutex.Lock()
|
||||
defer tokenEncoderMutex.Unlock()
|
||||
|
||||
// Double-check if another goroutine already created the encoder
|
||||
if encoder, exists := tokenEncoderMap[model]; exists {
|
||||
return encoder
|
||||
}
|
||||
|
||||
// Create new encoder
|
||||
modelCodec, err := tokenizer.ForModel(tokenizer.Model(model))
|
||||
if err != nil {
|
||||
// Cache the default encoder for this model to avoid repeated failures
|
||||
tokenEncoderMap[model] = defaultTokenEncoder
|
||||
return defaultTokenEncoder
|
||||
}
|
||||
|
||||
// Cache the new encoder
|
||||
tokenEncoderMap[model] = modelCodec
|
||||
return modelCodec
|
||||
}
|
||||
|
||||
func getTokenNum(tokenEncoder tokenizer.Codec, text string) int {
|
||||
if text == "" {
|
||||
return 0
|
||||
}
|
||||
tkm, _ := tokenEncoder.Count(text)
|
||||
return tkm
|
||||
}
|
||||
@@ -23,8 +23,7 @@ func ResponseText2Usage(c *gin.Context, responseText string, modeName string, pr
|
||||
common.SetContextKey(c, constant.ContextKeyLocalCountTokens, true)
|
||||
usage := &dto.Usage{}
|
||||
usage.PromptTokens = promptTokens
|
||||
ctkm := CountTextToken(responseText, modeName)
|
||||
usage.CompletionTokens = ctkm
|
||||
usage.CompletionTokens = EstimateTokenByModel(modeName, responseText)
|
||||
usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
|
||||
return usage
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user