diff --git a/common/init.go b/common/init.go index 66f05165b..1ea356ba3 100644 --- a/common/init.go +++ b/common/init.go @@ -111,6 +111,7 @@ func initConstantEnv() { constant.MaxFileDownloadMB = GetEnvOrDefault("MAX_FILE_DOWNLOAD_MB", 20) // ForceStreamOption 覆盖请求参数,强制返回usage信息 constant.ForceStreamOption = GetEnvOrDefaultBool("FORCE_STREAM_OPTION", true) + constant.CountToken = GetEnvOrDefaultBool("CountToken", true) constant.GetMediaToken = GetEnvOrDefaultBool("GET_MEDIA_TOKEN", true) constant.GetMediaTokenNotStream = GetEnvOrDefaultBool("GET_MEDIA_TOKEN_NOT_STREAM", false) constant.UpdateTask = GetEnvOrDefaultBool("UPDATE_TASK", true) diff --git a/constant/env.go b/constant/env.go index 09d4a2f36..ade835c01 100644 --- a/constant/env.go +++ b/constant/env.go @@ -4,6 +4,7 @@ var StreamingTimeout int var DifyDebug bool var MaxFileDownloadMB int var ForceStreamOption bool +var CountToken bool var GetMediaToken bool var GetMediaTokenNotStream bool var UpdateTask bool diff --git a/service/token_counter.go b/service/token_counter.go index 325fbd7ab..e4bd1a3cc 100644 --- a/service/token_counter.go +++ b/service/token_counter.go @@ -143,6 +143,12 @@ func getImageToken(fileMeta *types.FileMeta, model string, stream bool) (int, er if fileMeta.Detail == "low" && !isPatchBased { return baseTokens, nil } + + // Whether to count image tokens at all + if !constant.GetMediaToken { + return 3 * baseTokens, nil + } + if !constant.GetMediaTokenNotStream && !stream { return 3 * baseTokens, nil } @@ -150,10 +156,6 @@ func getImageToken(fileMeta *types.FileMeta, model string, stream bool) (int, er if fileMeta.Detail == "auto" || fileMeta.Detail == "" { fileMeta.Detail = "high" } - // Whether to count image tokens at all - if !constant.GetMediaToken { - return 3 * baseTokens, nil - } // Decode image to get dimensions var config image.Config @@ -256,16 +258,15 @@ func getImageToken(fileMeta *types.FileMeta, model string, stream bool) (int, er } func CountRequestToken(c *gin.Context, meta *types.TokenCountMeta, info *relaycommon.RelayInfo) (int, error) { + // 是否统计token + if !constant.CountToken { + return 0, nil + } + if meta == nil { return 0, errors.New("token count meta is nil") } - if !constant.GetMediaToken { - return 0, nil - } - if !constant.GetMediaTokenNotStream && !info.IsStream { - return 0, nil - } if info.RelayFormat == types.RelayFormatOpenAIRealtime { return 0, nil } @@ -316,9 +317,19 @@ func CountRequestToken(c *gin.Context, meta *types.TokenCountMeta, info *relayco shouldFetchFiles = false } - if shouldFetchFiles { - for _, file := range meta.Files { - if strings.HasPrefix(file.OriginData, "http") { + // 是否本地计算媒体token数量 + if !constant.GetMediaToken { + shouldFetchFiles = false + } + + // 是否在非流模式下本地计算媒体token数量 + if !constant.GetMediaTokenNotStream && !info.IsStream { + shouldFetchFiles = false + } + + for _, file := range meta.Files { + if strings.HasPrefix(file.OriginData, "http") { + if shouldFetchFiles { mineType, err := GetFileTypeFromUrl(c, file.OriginData, "token_counter") if err != nil { return 0, fmt.Errorf("error getting file base64 from url: %v", err) @@ -333,28 +344,28 @@ func CountRequestToken(c *gin.Context, meta *types.TokenCountMeta, info *relayco file.FileType = types.FileTypeFile } file.MimeType = mineType - } else if strings.HasPrefix(file.OriginData, "data:") { - // get mime type from base64 header - parts := strings.SplitN(file.OriginData, ",", 2) - if len(parts) >= 1 { - header := parts[0] - // Extract mime type from "data:mime/type;base64" format - if strings.Contains(header, ":") && strings.Contains(header, ";") { - mimeStart := strings.Index(header, ":") + 1 - mimeEnd := strings.Index(header, ";") - if mimeStart < mimeEnd { - mineType := header[mimeStart:mimeEnd] - if strings.HasPrefix(mineType, "image/") { - file.FileType = types.FileTypeImage - } else if strings.HasPrefix(mineType, "video/") { - file.FileType = types.FileTypeVideo - } else if strings.HasPrefix(mineType, "audio/") { - file.FileType = types.FileTypeAudio - } else { - file.FileType = types.FileTypeFile - } - file.MimeType = mineType + } + } else if strings.HasPrefix(file.OriginData, "data:") { + // get mime type from base64 header + parts := strings.SplitN(file.OriginData, ",", 2) + if len(parts) >= 1 { + header := parts[0] + // Extract mime type from "data:mime/type;base64" format + if strings.Contains(header, ":") && strings.Contains(header, ";") { + mimeStart := strings.Index(header, ":") + 1 + mimeEnd := strings.Index(header, ";") + if mimeStart < mimeEnd { + mineType := header[mimeStart:mimeEnd] + if strings.HasPrefix(mineType, "image/") { + file.FileType = types.FileTypeImage + } else if strings.HasPrefix(mineType, "video/") { + file.FileType = types.FileTypeVideo + } else if strings.HasPrefix(mineType, "audio/") { + file.FileType = types.FileTypeAudio + } else { + file.FileType = types.FileTypeFile } + file.MimeType = mineType } } } @@ -365,7 +376,7 @@ func CountRequestToken(c *gin.Context, meta *types.TokenCountMeta, info *relayco switch file.FileType { case types.FileTypeImage: if info.RelayFormat == types.RelayFormatGemini { - tkm += 256 + tkm += 520 // gemini per input image tokens } else { token, err := getImageToken(file, model, info.IsStream) if err != nil {