mirror of
https://github.com/QuantumNous/new-api.git
synced 2026-03-30 11:51:20 +00:00
453 lines
13 KiB
Go
453 lines
13 KiB
Go
package service
|
||
|
||
import (
|
||
"errors"
|
||
"fmt"
|
||
"image"
|
||
_ "image/gif"
|
||
_ "image/jpeg"
|
||
_ "image/png"
|
||
"log"
|
||
"math"
|
||
"path/filepath"
|
||
"strings"
|
||
"unicode/utf8"
|
||
|
||
"github.com/QuantumNous/new-api/common"
|
||
"github.com/QuantumNous/new-api/constant"
|
||
"github.com/QuantumNous/new-api/dto"
|
||
relaycommon "github.com/QuantumNous/new-api/relay/common"
|
||
constant2 "github.com/QuantumNous/new-api/relay/constant"
|
||
"github.com/QuantumNous/new-api/types"
|
||
|
||
"github.com/gin-gonic/gin"
|
||
)
|
||
|
||
func getImageToken(fileMeta *types.FileMeta, model string, stream bool) (int, error) {
|
||
if fileMeta == nil {
|
||
return 0, fmt.Errorf("image_url_is_nil")
|
||
}
|
||
|
||
// Defaults for 4o/4.1/4.5 family unless overridden below
|
||
baseTokens := 85
|
||
tileTokens := 170
|
||
|
||
// Model classification
|
||
lowerModel := strings.ToLower(model)
|
||
|
||
// Special cases from existing behavior
|
||
if strings.HasPrefix(lowerModel, "glm-4") {
|
||
return 1047, nil
|
||
}
|
||
|
||
// Patch-based models (32x32 patches, capped at 1536, with multiplier)
|
||
isPatchBased := false
|
||
multiplier := 1.0
|
||
switch {
|
||
case strings.Contains(lowerModel, "gpt-4.1-mini"):
|
||
isPatchBased = true
|
||
multiplier = 1.62
|
||
case strings.Contains(lowerModel, "gpt-4.1-nano"):
|
||
isPatchBased = true
|
||
multiplier = 2.46
|
||
case strings.HasPrefix(lowerModel, "o4-mini"):
|
||
isPatchBased = true
|
||
multiplier = 1.72
|
||
case strings.HasPrefix(lowerModel, "gpt-5-mini"):
|
||
isPatchBased = true
|
||
multiplier = 1.62
|
||
case strings.HasPrefix(lowerModel, "gpt-5-nano"):
|
||
isPatchBased = true
|
||
multiplier = 2.46
|
||
}
|
||
|
||
// Tile-based model tokens and bases per doc
|
||
if !isPatchBased {
|
||
if strings.HasPrefix(lowerModel, "gpt-4o-mini") {
|
||
baseTokens = 2833
|
||
tileTokens = 5667
|
||
} else if strings.HasPrefix(lowerModel, "gpt-5-chat-latest") || (strings.HasPrefix(lowerModel, "gpt-5") && !strings.Contains(lowerModel, "mini") && !strings.Contains(lowerModel, "nano")) {
|
||
baseTokens = 70
|
||
tileTokens = 140
|
||
} else if strings.HasPrefix(lowerModel, "o1") || strings.HasPrefix(lowerModel, "o3") || strings.HasPrefix(lowerModel, "o1-pro") {
|
||
baseTokens = 75
|
||
tileTokens = 150
|
||
} else if strings.Contains(lowerModel, "computer-use-preview") {
|
||
baseTokens = 65
|
||
tileTokens = 129
|
||
} else if strings.Contains(lowerModel, "4.1") || strings.Contains(lowerModel, "4o") || strings.Contains(lowerModel, "4.5") {
|
||
baseTokens = 85
|
||
tileTokens = 170
|
||
}
|
||
}
|
||
|
||
// Respect existing feature flags/short-circuits
|
||
if fileMeta.Detail == "low" && !isPatchBased {
|
||
return baseTokens, nil
|
||
}
|
||
|
||
// Whether to count image tokens at all
|
||
if !constant.GetMediaToken {
|
||
return 3 * baseTokens, nil
|
||
}
|
||
|
||
if !constant.GetMediaTokenNotStream && !stream {
|
||
return 3 * baseTokens, nil
|
||
}
|
||
// Normalize detail
|
||
if fileMeta.Detail == "auto" || fileMeta.Detail == "" {
|
||
fileMeta.Detail = "high"
|
||
}
|
||
|
||
// Decode image to get dimensions
|
||
var config image.Config
|
||
var err error
|
||
var format string
|
||
var b64str string
|
||
|
||
if fileMeta.ParsedData != nil {
|
||
config, format, b64str, err = DecodeBase64ImageData(fileMeta.ParsedData.Base64Data)
|
||
} else {
|
||
if strings.HasPrefix(fileMeta.OriginData, "http") {
|
||
config, format, err = DecodeUrlImageData(fileMeta.OriginData)
|
||
} else {
|
||
common.SysLog(fmt.Sprintf("decoding image"))
|
||
config, format, b64str, err = DecodeBase64ImageData(fileMeta.OriginData)
|
||
}
|
||
fileMeta.MimeType = format
|
||
}
|
||
|
||
if err != nil {
|
||
return 0, err
|
||
}
|
||
|
||
if config.Width == 0 || config.Height == 0 {
|
||
// not an image
|
||
if format != "" && b64str != "" {
|
||
// file type
|
||
return 3 * baseTokens, nil
|
||
}
|
||
return 0, errors.New(fmt.Sprintf("fail to decode base64 config: %s", fileMeta.OriginData))
|
||
}
|
||
|
||
width := config.Width
|
||
height := config.Height
|
||
log.Printf("format: %s, width: %d, height: %d", format, width, height)
|
||
|
||
if isPatchBased {
|
||
// 32x32 patch-based calculation with 1536 cap and model multiplier
|
||
ceilDiv := func(a, b int) int { return (a + b - 1) / b }
|
||
rawPatchesW := ceilDiv(width, 32)
|
||
rawPatchesH := ceilDiv(height, 32)
|
||
rawPatches := rawPatchesW * rawPatchesH
|
||
if rawPatches > 1536 {
|
||
// scale down
|
||
area := float64(width * height)
|
||
r := math.Sqrt(float64(32*32*1536) / area)
|
||
wScaled := float64(width) * r
|
||
hScaled := float64(height) * r
|
||
// adjust to fit whole number of patches after scaling
|
||
adjW := math.Floor(wScaled/32.0) / (wScaled / 32.0)
|
||
adjH := math.Floor(hScaled/32.0) / (hScaled / 32.0)
|
||
adj := math.Min(adjW, adjH)
|
||
if !math.IsNaN(adj) && adj > 0 {
|
||
r = r * adj
|
||
}
|
||
wScaled = float64(width) * r
|
||
hScaled = float64(height) * r
|
||
patchesW := math.Ceil(wScaled / 32.0)
|
||
patchesH := math.Ceil(hScaled / 32.0)
|
||
imageTokens := int(patchesW * patchesH)
|
||
if imageTokens > 1536 {
|
||
imageTokens = 1536
|
||
}
|
||
return int(math.Round(float64(imageTokens) * multiplier)), nil
|
||
}
|
||
// below cap
|
||
imageTokens := rawPatches
|
||
return int(math.Round(float64(imageTokens) * multiplier)), nil
|
||
}
|
||
|
||
// Tile-based calculation for 4o/4.1/4.5/o1/o3/etc.
|
||
// Step 1: fit within 2048x2048 square
|
||
maxSide := math.Max(float64(width), float64(height))
|
||
fitScale := 1.0
|
||
if maxSide > 2048 {
|
||
fitScale = maxSide / 2048.0
|
||
}
|
||
fitW := int(math.Round(float64(width) / fitScale))
|
||
fitH := int(math.Round(float64(height) / fitScale))
|
||
|
||
// Step 2: scale so that shortest side is exactly 768
|
||
minSide := math.Min(float64(fitW), float64(fitH))
|
||
if minSide == 0 {
|
||
return baseTokens, nil
|
||
}
|
||
shortScale := 768.0 / minSide
|
||
finalW := int(math.Round(float64(fitW) * shortScale))
|
||
finalH := int(math.Round(float64(fitH) * shortScale))
|
||
|
||
// Count 512px tiles
|
||
tilesW := (finalW + 512 - 1) / 512
|
||
tilesH := (finalH + 512 - 1) / 512
|
||
tiles := tilesW * tilesH
|
||
|
||
if common.DebugEnabled {
|
||
log.Printf("scaled to: %dx%d, tiles: %d", finalW, finalH, tiles)
|
||
}
|
||
|
||
return tiles*tileTokens + baseTokens, nil
|
||
}
|
||
|
||
func EstimateRequestToken(c *gin.Context, meta *types.TokenCountMeta, info *relaycommon.RelayInfo) (int, error) {
|
||
// 是否统计token
|
||
if !constant.CountToken {
|
||
return 0, nil
|
||
}
|
||
|
||
if meta == nil {
|
||
return 0, errors.New("token count meta is nil")
|
||
}
|
||
|
||
if info.RelayFormat == types.RelayFormatOpenAIRealtime {
|
||
return 0, nil
|
||
}
|
||
if info.RelayMode == constant2.RelayModeAudioTranscription || info.RelayMode == constant2.RelayModeAudioTranslation {
|
||
multiForm, err := common.ParseMultipartFormReusable(c)
|
||
if err != nil {
|
||
return 0, fmt.Errorf("error parsing multipart form: %v", err)
|
||
}
|
||
fileHeaders := multiForm.File["file"]
|
||
totalAudioToken := 0
|
||
for _, fileHeader := range fileHeaders {
|
||
file, err := fileHeader.Open()
|
||
if err != nil {
|
||
return 0, fmt.Errorf("error opening audio file: %v", err)
|
||
}
|
||
defer file.Close()
|
||
// get ext and io.seeker
|
||
ext := filepath.Ext(fileHeader.Filename)
|
||
duration, err := common.GetAudioDuration(c.Request.Context(), file, ext)
|
||
if err != nil {
|
||
return 0, fmt.Errorf("error getting audio duration: %v", err)
|
||
}
|
||
// 一分钟 1000 token,与 $price / minute 对齐
|
||
totalAudioToken += int(math.Round(math.Ceil(duration) / 60.0 * 1000))
|
||
}
|
||
return totalAudioToken, nil
|
||
}
|
||
|
||
model := common.GetContextKeyString(c, constant.ContextKeyOriginalModel)
|
||
tkm := 0
|
||
|
||
if meta.TokenType == types.TokenTypeTextNumber {
|
||
tkm += utf8.RuneCountInString(meta.CombineText)
|
||
} else {
|
||
tkm += CountTextToken(meta.CombineText, model)
|
||
}
|
||
|
||
if info.RelayFormat == types.RelayFormatOpenAI {
|
||
tkm += meta.ToolsCount * 8
|
||
tkm += meta.MessagesCount * 3 // 每条消息的格式化token数量
|
||
tkm += meta.NameCount * 3
|
||
tkm += 3
|
||
}
|
||
|
||
shouldFetchFiles := true
|
||
|
||
if info.RelayFormat == types.RelayFormatGemini {
|
||
shouldFetchFiles = false
|
||
}
|
||
|
||
// 是否本地计算媒体token数量
|
||
if !constant.GetMediaToken {
|
||
shouldFetchFiles = false
|
||
}
|
||
|
||
// 是否在非流模式下本地计算媒体token数量
|
||
if !constant.GetMediaTokenNotStream && !info.IsStream {
|
||
shouldFetchFiles = false
|
||
}
|
||
|
||
for _, file := range meta.Files {
|
||
if strings.HasPrefix(file.OriginData, "http") {
|
||
if shouldFetchFiles {
|
||
mineType, err := GetFileTypeFromUrl(c, file.OriginData, "token_counter")
|
||
if err != nil {
|
||
return 0, fmt.Errorf("error getting file base64 from url: %v", err)
|
||
}
|
||
if strings.HasPrefix(mineType, "image/") {
|
||
file.FileType = types.FileTypeImage
|
||
} else if strings.HasPrefix(mineType, "video/") {
|
||
file.FileType = types.FileTypeVideo
|
||
} else if strings.HasPrefix(mineType, "audio/") {
|
||
file.FileType = types.FileTypeAudio
|
||
} else {
|
||
file.FileType = types.FileTypeFile
|
||
}
|
||
file.MimeType = mineType
|
||
}
|
||
} else if strings.HasPrefix(file.OriginData, "data:") {
|
||
// get mime type from base64 header
|
||
parts := strings.SplitN(file.OriginData, ",", 2)
|
||
if len(parts) >= 1 {
|
||
header := parts[0]
|
||
// Extract mime type from "data:mime/type;base64" format
|
||
if strings.Contains(header, ":") && strings.Contains(header, ";") {
|
||
mimeStart := strings.Index(header, ":") + 1
|
||
mimeEnd := strings.Index(header, ";")
|
||
if mimeStart < mimeEnd {
|
||
mineType := header[mimeStart:mimeEnd]
|
||
if strings.HasPrefix(mineType, "image/") {
|
||
file.FileType = types.FileTypeImage
|
||
} else if strings.HasPrefix(mineType, "video/") {
|
||
file.FileType = types.FileTypeVideo
|
||
} else if strings.HasPrefix(mineType, "audio/") {
|
||
file.FileType = types.FileTypeAudio
|
||
} else {
|
||
file.FileType = types.FileTypeFile
|
||
}
|
||
file.MimeType = mineType
|
||
}
|
||
}
|
||
}
|
||
}
|
||
}
|
||
|
||
for i, file := range meta.Files {
|
||
switch file.FileType {
|
||
case types.FileTypeImage:
|
||
if common.IsOpenAITextModel(model) {
|
||
token, err := getImageToken(file, model, info.IsStream)
|
||
if err != nil {
|
||
return 0, fmt.Errorf("error counting image token, media index[%d], original data[%s], err: %v", i, file.OriginData, err)
|
||
}
|
||
tkm += token
|
||
} else {
|
||
tkm += 520
|
||
}
|
||
case types.FileTypeAudio:
|
||
tkm += 256
|
||
case types.FileTypeVideo:
|
||
tkm += 4096 * 2
|
||
case types.FileTypeFile:
|
||
tkm += 4096
|
||
default:
|
||
tkm += 4096 // Default case for unknown file types
|
||
}
|
||
}
|
||
|
||
common.SetContextKey(c, constant.ContextKeyPromptTokens, tkm)
|
||
return tkm, nil
|
||
}
|
||
|
||
func CountTokenRealtime(info *relaycommon.RelayInfo, request dto.RealtimeEvent, model string) (int, int, error) {
|
||
audioToken := 0
|
||
textToken := 0
|
||
switch request.Type {
|
||
case dto.RealtimeEventTypeSessionUpdate:
|
||
if request.Session != nil {
|
||
msgTokens := CountTextToken(request.Session.Instructions, model)
|
||
textToken += msgTokens
|
||
}
|
||
case dto.RealtimeEventResponseAudioDelta:
|
||
// count audio token
|
||
atk, err := CountAudioTokenOutput(request.Delta, info.OutputAudioFormat)
|
||
if err != nil {
|
||
return 0, 0, fmt.Errorf("error counting audio token: %v", err)
|
||
}
|
||
audioToken += atk
|
||
case dto.RealtimeEventResponseAudioTranscriptionDelta, dto.RealtimeEventResponseFunctionCallArgumentsDelta:
|
||
// count text token
|
||
tkm := CountTextToken(request.Delta, model)
|
||
textToken += tkm
|
||
case dto.RealtimeEventInputAudioBufferAppend:
|
||
// count audio token
|
||
atk, err := CountAudioTokenInput(request.Audio, info.InputAudioFormat)
|
||
if err != nil {
|
||
return 0, 0, fmt.Errorf("error counting audio token: %v", err)
|
||
}
|
||
audioToken += atk
|
||
case dto.RealtimeEventConversationItemCreated:
|
||
if request.Item != nil {
|
||
switch request.Item.Type {
|
||
case "message":
|
||
for _, content := range request.Item.Content {
|
||
if content.Type == "input_text" {
|
||
tokens := CountTextToken(content.Text, model)
|
||
textToken += tokens
|
||
}
|
||
}
|
||
}
|
||
}
|
||
case dto.RealtimeEventTypeResponseDone:
|
||
// count tools token
|
||
if !info.IsFirstRequest {
|
||
if info.RealtimeTools != nil && len(info.RealtimeTools) > 0 {
|
||
for _, tool := range info.RealtimeTools {
|
||
toolTokens := CountTokenInput(tool, model)
|
||
textToken += 8
|
||
textToken += toolTokens
|
||
}
|
||
}
|
||
}
|
||
}
|
||
return textToken, audioToken, nil
|
||
}
|
||
|
||
func CountTokenInput(input any, model string) int {
|
||
switch v := input.(type) {
|
||
case string:
|
||
return CountTextToken(v, model)
|
||
case []string:
|
||
text := ""
|
||
for _, s := range v {
|
||
text += s
|
||
}
|
||
return CountTextToken(text, model)
|
||
case []interface{}:
|
||
text := ""
|
||
for _, item := range v {
|
||
text += fmt.Sprintf("%v", item)
|
||
}
|
||
return CountTextToken(text, model)
|
||
}
|
||
return CountTokenInput(fmt.Sprintf("%v", input), model)
|
||
}
|
||
|
||
func CountAudioTokenInput(audioBase64 string, audioFormat string) (int, error) {
|
||
if audioBase64 == "" {
|
||
return 0, nil
|
||
}
|
||
duration, err := parseAudio(audioBase64, audioFormat)
|
||
if err != nil {
|
||
return 0, err
|
||
}
|
||
return int(duration / 60 * 100 / 0.06), nil
|
||
}
|
||
|
||
func CountAudioTokenOutput(audioBase64 string, audioFormat string) (int, error) {
|
||
if audioBase64 == "" {
|
||
return 0, nil
|
||
}
|
||
duration, err := parseAudio(audioBase64, audioFormat)
|
||
if err != nil {
|
||
return 0, err
|
||
}
|
||
return int(duration / 60 * 200 / 0.24), nil
|
||
}
|
||
|
||
// CountTextToken 统计文本的token数量,仅OpenAI模型使用tokenizer,其余模型使用估算
|
||
func CountTextToken(text string, model string) int {
|
||
if text == "" {
|
||
return 0
|
||
}
|
||
if common.IsOpenAITextModel(model) {
|
||
tokenEncoder := getTokenEncoder(model)
|
||
return getTokenNum(tokenEncoder, text)
|
||
} else {
|
||
// 非openai模型,使用tiktoken-go计算没有意义,使用估算节省资源
|
||
return EstimateTokenByModel(model, text)
|
||
}
|
||
}
|