mirror of
https://github.com/QuantumNous/new-api.git
synced 2026-04-07 22:27:26 +00:00
Compare commits
16 Commits
v0.9.28
...
sora-remix
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
80ad0067ec | ||
|
|
7395e90013 | ||
|
|
d732cdd259 | ||
|
|
0b9f6a58bc | ||
|
|
7d1bad1b37 | ||
|
|
45556c961f | ||
|
|
ffc45a756e | ||
|
|
48635360cd | ||
|
|
e7e5cc2c05 | ||
|
|
0c051e968f | ||
|
|
f5b409d74f | ||
|
|
509d1f633a | ||
|
|
0c6d890f6e | ||
|
|
2f7eebcd10 | ||
|
|
3954feb993 | ||
|
|
b6a02d8303 |
@@ -238,6 +238,7 @@ docker run --name new-api -d --restart always \
|
||||
- `gemini-2.5-flash-nothinking` - Disable thinking mode
|
||||
- `gemini-2.5-pro-thinking` - Enable thinking mode
|
||||
- `gemini-2.5-pro-thinking-128` - Enable thinking mode with thinking budget of 128 tokens
|
||||
- You can also append `-low`, `-medium`, or `-high` to any Gemini model name to request the corresponding reasoning effort (no extra thinking-budget suffix needed).
|
||||
|
||||
</details>
|
||||
|
||||
|
||||
@@ -234,6 +234,7 @@ docker run --name new-api -d --restart always \
|
||||
- `gemini-2.5-flash-nothinking` - Désactiver le mode de pensée
|
||||
- `gemini-2.5-pro-thinking` - Activer le mode de pensée
|
||||
- `gemini-2.5-pro-thinking-128` - Activer le mode de pensée avec budget de pensée de 128 tokens
|
||||
- Vous pouvez également ajouter les suffixes `-low`, `-medium` ou `-high` aux modèles Gemini pour fixer le niveau d’effort de raisonnement (sans suffixe de budget supplémentaire).
|
||||
|
||||
</details>
|
||||
|
||||
|
||||
@@ -243,6 +243,7 @@ docker run --name new-api -d --restart always \
|
||||
- `gemini-2.5-flash-nothinking` - 思考モードを無効にする
|
||||
- `gemini-2.5-pro-thinking` - 思考モードを有効にする
|
||||
- `gemini-2.5-pro-thinking-128` - 思考モードを有効にし、思考予算を128トークンに設定する
|
||||
- Gemini モデル名の末尾に `-low` / `-medium` / `-high` を付けることで推論強度を直接指定できます(追加の思考予算サフィックスは不要です)。
|
||||
|
||||
</details>
|
||||
|
||||
|
||||
@@ -239,6 +239,7 @@ docker run --name new-api -d --restart always \
|
||||
- `gemini-2.5-flash-nothinking` - 禁用思考模式
|
||||
- `gemini-2.5-pro-thinking` - 启用思考模式
|
||||
- `gemini-2.5-pro-thinking-128` - 启用思考模式,并设置思考预算为128tokens
|
||||
- 也可以直接在 Gemini 模型名称后追加 `-low` / `-medium` / `-high` 来控制思考力度(无需再设置思考预算后缀)
|
||||
|
||||
</details>
|
||||
|
||||
|
||||
@@ -121,6 +121,9 @@ var BatchUpdateInterval int
|
||||
|
||||
var RelayTimeout int // unit is second
|
||||
|
||||
var RelayMaxIdleConns int
|
||||
var RelayMaxIdleConnsPerHost int
|
||||
|
||||
var GeminiSafetySetting string
|
||||
|
||||
// https://docs.cohere.com/docs/safety-modes Type; NONE/CONTEXTUAL/STRICT
|
||||
|
||||
@@ -90,6 +90,8 @@ func InitEnv() {
|
||||
SyncFrequency = GetEnvOrDefault("SYNC_FREQUENCY", 60)
|
||||
BatchUpdateInterval = GetEnvOrDefault("BATCH_UPDATE_INTERVAL", 5)
|
||||
RelayTimeout = GetEnvOrDefault("RELAY_TIMEOUT", 0)
|
||||
RelayMaxIdleConns = GetEnvOrDefault("RELAY_MAX_IDLE_CONNS", 500)
|
||||
RelayMaxIdleConnsPerHost = GetEnvOrDefault("RELAY_MAX_IDLE_CONNS_PER_HOST", 100)
|
||||
|
||||
// Initialize string variables with GetEnvOrDefaultString
|
||||
GeminiSafetySetting = GetEnvOrDefaultString("GEMINI_SAFETY_SETTING", "BLOCK_NONE")
|
||||
@@ -129,6 +131,8 @@ func initConstantEnv() {
|
||||
constant.GenerateDefaultToken = GetEnvOrDefaultBool("GENERATE_DEFAULT_TOKEN", false)
|
||||
// 是否启用错误日志
|
||||
constant.ErrorLogEnabled = GetEnvOrDefaultBool("ERROR_LOG_ENABLED", false)
|
||||
// 任务轮询时查询的最大数量
|
||||
constant.TaskQueryLimit = GetEnvOrDefault("TASK_QUERY_LIMIT", 1000)
|
||||
|
||||
soraPatchStr := GetEnvOrDefaultString("TASK_PRICE_PATCH", "")
|
||||
if soraPatchStr != "" {
|
||||
|
||||
@@ -17,6 +17,13 @@ var (
|
||||
"flux-",
|
||||
"flux.1-",
|
||||
}
|
||||
OpenAITextModels = []string{
|
||||
"gpt-",
|
||||
"o1",
|
||||
"o3",
|
||||
"o4",
|
||||
"chatgpt",
|
||||
}
|
||||
)
|
||||
|
||||
func IsOpenAIResponseOnlyModel(modelName string) bool {
|
||||
@@ -40,3 +47,13 @@ func IsImageGenerationModel(modelName string) bool {
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func IsOpenAITextModel(modelName string) bool {
|
||||
modelName = strings.ToLower(modelName)
|
||||
for _, m := range OpenAITextModels {
|
||||
if strings.Contains(modelName, m) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -3,8 +3,9 @@ package constant
|
||||
type ContextKey string
|
||||
|
||||
const (
|
||||
ContextKeyTokenCountMeta ContextKey = "token_count_meta"
|
||||
ContextKeyPromptTokens ContextKey = "prompt_tokens"
|
||||
ContextKeyTokenCountMeta ContextKey = "token_count_meta"
|
||||
ContextKeyPromptTokens ContextKey = "prompt_tokens"
|
||||
ContextKeyEstimatedTokens ContextKey = "estimated_tokens"
|
||||
|
||||
ContextKeyOriginalModel ContextKey = "original_model"
|
||||
ContextKeyRequestStartTime ContextKey = "request_start_time"
|
||||
|
||||
@@ -15,6 +15,7 @@ var NotifyLimitCount int
|
||||
var NotificationLimitDurationMinute int
|
||||
var GenerateDefaultToken bool
|
||||
var ErrorLogEnabled bool
|
||||
var TaskQueryLimit int
|
||||
|
||||
// temporary variable for sora patch, will be removed in future
|
||||
var TaskPricePatches []string
|
||||
|
||||
@@ -15,6 +15,7 @@ const (
|
||||
TaskActionTextGenerate = "textGenerate"
|
||||
TaskActionFirstTailGenerate = "firstTailGenerate"
|
||||
TaskActionReferenceGenerate = "referenceGenerate"
|
||||
TaskActionRemix = "remixGenerate"
|
||||
)
|
||||
|
||||
var SunoModel2Action = map[string]string{
|
||||
|
||||
@@ -351,7 +351,7 @@ func testChannel(channel *model.Channel, testModel string, endpointType string)
|
||||
newAPIError: types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError),
|
||||
}
|
||||
}
|
||||
info.PromptTokens = usage.PromptTokens
|
||||
info.SetEstimatePromptTokens(usage.PromptTokens)
|
||||
|
||||
quota := 0
|
||||
if !priceData.UsePrice {
|
||||
|
||||
@@ -125,13 +125,13 @@ func Relay(c *gin.Context, relayFormat types.RelayFormat) {
|
||||
}
|
||||
}
|
||||
|
||||
tokens, err := service.CountRequestToken(c, meta, relayInfo)
|
||||
tokens, err := service.EstimateRequestToken(c, meta, relayInfo)
|
||||
if err != nil {
|
||||
newAPIError = types.NewError(err, types.ErrorCodeCountTokenFailed)
|
||||
return
|
||||
}
|
||||
|
||||
relayInfo.SetPromptTokens(tokens)
|
||||
relayInfo.SetEstimatePromptTokens(tokens)
|
||||
|
||||
priceData, err := helper.ModelPriceHelper(c, relayInfo, tokens, meta)
|
||||
if err != nil {
|
||||
|
||||
@@ -29,7 +29,7 @@ func UpdateTaskBulk() {
|
||||
time.Sleep(time.Duration(15) * time.Second)
|
||||
common.SysLog("任务进度轮询开始")
|
||||
ctx := context.TODO()
|
||||
allTasks := model.GetAllUnFinishSyncTasks(500)
|
||||
allTasks := model.GetAllUnFinishSyncTasks(constant.TaskQueryLimit)
|
||||
platformTask := make(map[constant.TaskPlatform][]*model.Task)
|
||||
for _, t := range allTasks {
|
||||
platformTask[t.Platform] = append(platformTask[t.Platform], t)
|
||||
|
||||
@@ -16,13 +16,13 @@
|
||||
"name": "OpenAI格式(Responses)"
|
||||
},
|
||||
{
|
||||
"name": "图片生成/编辑(Images)"
|
||||
"name": "图片生成"
|
||||
},
|
||||
{
|
||||
"name": "图片生成/编辑(Images)/OpenAI兼容格式"
|
||||
"name": "图片生成/OpenAI兼容格式"
|
||||
},
|
||||
{
|
||||
"name": "图片生成/编辑(Images)/Qwen千问"
|
||||
"name": "图片生成/Qwen千问"
|
||||
},
|
||||
{
|
||||
"name": "视频生成"
|
||||
@@ -292,7 +292,7 @@
|
||||
"description": " 百炼qwen-image系列图片生成",
|
||||
"operationId": "createImage",
|
||||
"tags": [
|
||||
"图片生成/编辑(Images)/Qwen千问"
|
||||
"图片生成/Qwen千问"
|
||||
],
|
||||
"parameters": [],
|
||||
"requestBody": {
|
||||
@@ -408,7 +408,7 @@
|
||||
"description": " 百炼qwen-image系列图片生成",
|
||||
"operationId": "createImage",
|
||||
"tags": [
|
||||
"图片生成/编辑(Images)/Qwen千问"
|
||||
"图片生成/Qwen千问"
|
||||
],
|
||||
"parameters": [],
|
||||
"requestBody": {
|
||||
|
||||
@@ -142,7 +142,7 @@ type GeminiThinkingConfig struct {
|
||||
IncludeThoughts bool `json:"includeThoughts,omitempty"`
|
||||
ThinkingBudget *int `json:"thinkingBudget,omitempty"`
|
||||
// TODO Conflict with thinkingbudget.
|
||||
ThinkingLevel json.RawMessage `json:"thinkingLevel,omitempty"`
|
||||
ThinkingLevel string `json:"thinkingLevel,omitempty"`
|
||||
}
|
||||
|
||||
// UnmarshalJSON allows GeminiThinkingConfig to accept both snake_case and camelCase fields.
|
||||
@@ -150,9 +150,9 @@ func (c *GeminiThinkingConfig) UnmarshalJSON(data []byte) error {
|
||||
type Alias GeminiThinkingConfig
|
||||
var aux struct {
|
||||
Alias
|
||||
IncludeThoughtsSnake *bool `json:"include_thoughts,omitempty"`
|
||||
ThinkingBudgetSnake *int `json:"thinking_budget,omitempty"`
|
||||
ThinkingLevelSnake json.RawMessage `json:"thinking_level,omitempty"`
|
||||
IncludeThoughtsSnake *bool `json:"include_thoughts,omitempty"`
|
||||
ThinkingBudgetSnake *int `json:"thinking_budget,omitempty"`
|
||||
ThinkingLevelSnake string `json:"thinking_level,omitempty"`
|
||||
}
|
||||
|
||||
if err := common.Unmarshal(data, &aux); err != nil {
|
||||
@@ -169,7 +169,7 @@ func (c *GeminiThinkingConfig) UnmarshalJSON(data []byte) error {
|
||||
c.ThinkingBudget = aux.ThinkingBudgetSnake
|
||||
}
|
||||
|
||||
if len(aux.ThinkingLevelSnake) > 0 {
|
||||
if aux.ThinkingLevelSnake != "" {
|
||||
c.ThinkingLevel = aux.ThinkingLevelSnake
|
||||
}
|
||||
|
||||
|
||||
@@ -181,6 +181,10 @@ func getModelRequest(c *gin.Context) (*ModelRequest, bool, error) {
|
||||
}
|
||||
c.Set("platform", string(constant.TaskPlatformSuno))
|
||||
c.Set("relay_mode", relayMode)
|
||||
} else if strings.Contains(c.Request.URL.Path, "/v1/videos/") && strings.HasSuffix(c.Request.URL.Path, "/remix") {
|
||||
relayMode := relayconstant.RelayModeVideoSubmit
|
||||
c.Set("relay_mode", relayMode)
|
||||
shouldSelectChannel = false
|
||||
} else if strings.Contains(c.Request.URL.Path, "/v1/videos") {
|
||||
//curl https://api.openai.com/v1/videos \
|
||||
// -H "Authorization: Bearer $OPENAI_API_KEY" \
|
||||
|
||||
@@ -673,7 +673,7 @@ func HandleStreamResponseData(c *gin.Context, info *relaycommon.RelayInfo, claud
|
||||
func HandleStreamFinalResponse(c *gin.Context, info *relaycommon.RelayInfo, claudeInfo *ClaudeResponseInfo, requestMode int) {
|
||||
|
||||
if requestMode == RequestModeCompletion {
|
||||
claudeInfo.Usage = service.ResponseText2Usage(c, claudeInfo.ResponseText.String(), info.UpstreamModelName, info.PromptTokens)
|
||||
claudeInfo.Usage = service.ResponseText2Usage(c, claudeInfo.ResponseText.String(), info.UpstreamModelName, info.GetEstimatePromptTokens())
|
||||
} else {
|
||||
if claudeInfo.Usage.PromptTokens == 0 {
|
||||
//上游出错
|
||||
@@ -734,10 +734,7 @@ func HandleClaudeResponseData(c *gin.Context, info *relaycommon.RelayInfo, claud
|
||||
return types.WithClaudeError(*claudeError, http.StatusInternalServerError)
|
||||
}
|
||||
if requestMode == RequestModeCompletion {
|
||||
completionTokens := service.CountTextToken(claudeResponse.Completion, info.OriginModelName)
|
||||
claudeInfo.Usage.PromptTokens = info.PromptTokens
|
||||
claudeInfo.Usage.CompletionTokens = completionTokens
|
||||
claudeInfo.Usage.TotalTokens = info.PromptTokens + completionTokens
|
||||
claudeInfo.Usage = service.ResponseText2Usage(c, claudeResponse.Completion, info.UpstreamModelName, info.GetEstimatePromptTokens())
|
||||
} else {
|
||||
claudeInfo.Usage.PromptTokens = claudeResponse.Usage.InputTokens
|
||||
claudeInfo.Usage.CompletionTokens = claudeResponse.Usage.OutputTokens
|
||||
|
||||
@@ -74,7 +74,7 @@ func cfStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Res
|
||||
if err := scanner.Err(); err != nil {
|
||||
logger.LogError(c, "error_scanning_stream_response: "+err.Error())
|
||||
}
|
||||
usage := service.ResponseText2Usage(c, responseText, info.UpstreamModelName, info.PromptTokens)
|
||||
usage := service.ResponseText2Usage(c, responseText, info.UpstreamModelName, info.GetEstimatePromptTokens())
|
||||
if info.ShouldIncludeUsage {
|
||||
response := helper.GenerateFinalUsageResponse(id, info.StartTime.Unix(), info.UpstreamModelName, *usage)
|
||||
err := helper.ObjectData(c, response)
|
||||
@@ -105,7 +105,7 @@ func cfHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response)
|
||||
for _, choice := range response.Choices {
|
||||
responseText += choice.Message.StringContent()
|
||||
}
|
||||
usage := service.ResponseText2Usage(c, responseText, info.UpstreamModelName, info.PromptTokens)
|
||||
usage := service.ResponseText2Usage(c, responseText, info.UpstreamModelName, info.GetEstimatePromptTokens())
|
||||
response.Usage = *usage
|
||||
response.Id = helper.GetResponseID(c)
|
||||
jsonResponse, err := json.Marshal(response)
|
||||
@@ -142,10 +142,6 @@ func cfSTTHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Respon
|
||||
c.Writer.WriteHeader(resp.StatusCode)
|
||||
_, _ = c.Writer.Write(jsonResponse)
|
||||
|
||||
usage := &dto.Usage{}
|
||||
usage.PromptTokens = info.PromptTokens
|
||||
usage.CompletionTokens = service.CountTextToken(cfResp.Result.Text, info.UpstreamModelName)
|
||||
usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
|
||||
|
||||
usage := service.ResponseText2Usage(c, cfResp.Result.Text, info.UpstreamModelName, info.GetEstimatePromptTokens())
|
||||
return nil, usage
|
||||
}
|
||||
|
||||
@@ -165,7 +165,7 @@ func cohereStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http
|
||||
}
|
||||
})
|
||||
if usage.PromptTokens == 0 {
|
||||
usage = service.ResponseText2Usage(c, responseText, info.UpstreamModelName, info.PromptTokens)
|
||||
usage = service.ResponseText2Usage(c, responseText, info.UpstreamModelName, info.GetEstimatePromptTokens())
|
||||
}
|
||||
return usage, nil
|
||||
}
|
||||
@@ -225,9 +225,9 @@ func cohereRerankHandler(c *gin.Context, resp *http.Response, info *relaycommon.
|
||||
}
|
||||
usage := dto.Usage{}
|
||||
if cohereResp.Meta.BilledUnits.InputTokens == 0 {
|
||||
usage.PromptTokens = info.PromptTokens
|
||||
usage.PromptTokens = info.GetEstimatePromptTokens()
|
||||
usage.CompletionTokens = 0
|
||||
usage.TotalTokens = info.PromptTokens
|
||||
usage.TotalTokens = info.GetEstimatePromptTokens()
|
||||
} else {
|
||||
usage.PromptTokens = cohereResp.Meta.BilledUnits.InputTokens
|
||||
usage.CompletionTokens = cohereResp.Meta.BilledUnits.OutputTokens
|
||||
|
||||
@@ -246,7 +246,7 @@ func difyStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.R
|
||||
})
|
||||
helper.Done(c)
|
||||
if usage.TotalTokens == 0 {
|
||||
usage = service.ResponseText2Usage(c, responseText, info.UpstreamModelName, info.PromptTokens)
|
||||
usage = service.ResponseText2Usage(c, responseText, info.UpstreamModelName, info.GetEstimatePromptTokens())
|
||||
}
|
||||
usage.CompletionTokens += nodeToken
|
||||
return usage, nil
|
||||
|
||||
@@ -137,6 +137,8 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
info.UpstreamModelName = strings.TrimSuffix(info.UpstreamModelName, "-thinking")
|
||||
} else if strings.HasSuffix(info.UpstreamModelName, "-nothinking") {
|
||||
info.UpstreamModelName = strings.TrimSuffix(info.UpstreamModelName, "-nothinking")
|
||||
} else if baseModel, level := parseThinkingLevelSuffix(info.UpstreamModelName); level != "" {
|
||||
info.UpstreamModelName = baseModel
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -5,7 +5,6 @@ import (
|
||||
"net/http"
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
"github.com/QuantumNous/new-api/constant"
|
||||
"github.com/QuantumNous/new-api/dto"
|
||||
"github.com/QuantumNous/new-api/logger"
|
||||
relaycommon "github.com/QuantumNous/new-api/relay/common"
|
||||
@@ -70,12 +69,7 @@ func NativeGeminiEmbeddingHandler(c *gin.Context, resp *http.Response, info *rel
|
||||
println(string(responseBody))
|
||||
}
|
||||
|
||||
usage := &dto.Usage{
|
||||
PromptTokens: info.PromptTokens,
|
||||
TotalTokens: info.PromptTokens,
|
||||
}
|
||||
|
||||
common.SetContextKey(c, constant.ContextKeyLocalCountTokens, true)
|
||||
usage := service.ResponseText2Usage(c, "", info.UpstreamModelName, info.GetEstimatePromptTokens())
|
||||
|
||||
if info.IsGeminiBatchEmbedding {
|
||||
var geminiResponse dto.GeminiBatchEmbeddingResponse
|
||||
|
||||
@@ -19,8 +19,8 @@ import (
|
||||
"github.com/QuantumNous/new-api/relay/helper"
|
||||
"github.com/QuantumNous/new-api/service"
|
||||
"github.com/QuantumNous/new-api/setting/model_setting"
|
||||
"github.com/QuantumNous/new-api/setting/reasoning"
|
||||
"github.com/QuantumNous/new-api/types"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
@@ -122,6 +122,14 @@ func clampThinkingBudgetByEffort(modelName string, effort string) int {
|
||||
return clampThinkingBudget(modelName, maxBudget)
|
||||
}
|
||||
|
||||
func parseThinkingLevelSuffix(modelName string) (string, string) {
|
||||
base, level, ok := reasoning.TrimEffortSuffix(modelName)
|
||||
if !ok {
|
||||
return modelName, ""
|
||||
}
|
||||
return base, level
|
||||
}
|
||||
|
||||
func ThinkingAdaptor(geminiRequest *dto.GeminiChatRequest, info *relaycommon.RelayInfo, oaiRequest ...dto.GeneralOpenAIRequest) {
|
||||
if model_setting.GetGeminiSettings().ThinkingAdapterEnabled {
|
||||
modelName := info.UpstreamModelName
|
||||
@@ -178,6 +186,12 @@ func ThinkingAdaptor(geminiRequest *dto.GeminiChatRequest, info *relaycommon.Rel
|
||||
ThinkingBudget: common.GetPointer(0),
|
||||
}
|
||||
}
|
||||
} else if _, level := parseThinkingLevelSuffix(modelName); level != "" {
|
||||
geminiRequest.GenerationConfig.ThinkingConfig = &dto.GeminiThinkingConfig{
|
||||
IncludeThoughts: true,
|
||||
ThinkingLevel: level,
|
||||
}
|
||||
info.ReasoningEffort = level
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1115,7 +1129,7 @@ func geminiStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http
|
||||
if usage.CompletionTokens <= 0 {
|
||||
str := responseText.String()
|
||||
if len(str) > 0 {
|
||||
usage = service.ResponseText2Usage(c, responseText.String(), info.UpstreamModelName, info.PromptTokens)
|
||||
usage = service.ResponseText2Usage(c, responseText.String(), info.UpstreamModelName, info.GetEstimatePromptTokens())
|
||||
} else {
|
||||
usage = &dto.Usage{}
|
||||
}
|
||||
@@ -1288,11 +1302,7 @@ func GeminiEmbeddingHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *h
|
||||
// Google has not yet clarified how embedding models will be billed
|
||||
// refer to openai billing method to use input tokens billing
|
||||
// https://platform.openai.com/docs/guides/embeddings#what-are-embeddings
|
||||
usage := &dto.Usage{
|
||||
PromptTokens: info.PromptTokens,
|
||||
CompletionTokens: 0,
|
||||
TotalTokens: info.PromptTokens,
|
||||
}
|
||||
usage := service.ResponseText2Usage(c, "", info.UpstreamModelName, info.GetEstimatePromptTokens())
|
||||
openAIResponse.Usage = *usage
|
||||
|
||||
jsonResponse, jsonErr := common.Marshal(openAIResponse)
|
||||
|
||||
@@ -163,7 +163,7 @@ func handleTTSResponse(c *gin.Context, resp *http.Response, info *relaycommon.Re
|
||||
}
|
||||
|
||||
usage = &dto.Usage{
|
||||
PromptTokens: info.PromptTokens,
|
||||
PromptTokens: info.GetEstimatePromptTokens(),
|
||||
CompletionTokens: 0,
|
||||
TotalTokens: int(minimaxResp.ExtraInfo.UsageCharacters),
|
||||
}
|
||||
|
||||
@@ -183,7 +183,7 @@ func OaiStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Re
|
||||
}
|
||||
|
||||
if !containStreamUsage {
|
||||
usage = service.ResponseText2Usage(c, responseTextBuilder.String(), info.UpstreamModelName, info.PromptTokens)
|
||||
usage = service.ResponseText2Usage(c, responseTextBuilder.String(), info.UpstreamModelName, info.GetEstimatePromptTokens())
|
||||
usage.CompletionTokens += toolCount * 7
|
||||
}
|
||||
|
||||
@@ -245,9 +245,9 @@ func OpenaiHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Respo
|
||||
}
|
||||
}
|
||||
simpleResponse.Usage = dto.Usage{
|
||||
PromptTokens: info.PromptTokens,
|
||||
PromptTokens: info.GetEstimatePromptTokens(),
|
||||
CompletionTokens: completionTokens,
|
||||
TotalTokens: info.PromptTokens + completionTokens,
|
||||
TotalTokens: info.GetEstimatePromptTokens() + completionTokens,
|
||||
}
|
||||
usageModified = true
|
||||
}
|
||||
@@ -336,8 +336,8 @@ func OpenaiTTSHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
|
||||
// and can be terminated directly.
|
||||
defer service.CloseResponseBodyGracefully(resp)
|
||||
usage := &dto.Usage{}
|
||||
usage.PromptTokens = info.PromptTokens
|
||||
usage.TotalTokens = info.PromptTokens
|
||||
usage.PromptTokens = info.GetEstimatePromptTokens()
|
||||
usage.TotalTokens = info.GetEstimatePromptTokens()
|
||||
for k, v := range resp.Header {
|
||||
c.Writer.Header().Set(k, v[0])
|
||||
}
|
||||
@@ -383,7 +383,7 @@ func OpenaiSTTHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
|
||||
}
|
||||
|
||||
usage := &dto.Usage{}
|
||||
usage.PromptTokens = info.PromptTokens
|
||||
usage.PromptTokens = info.GetEstimatePromptTokens()
|
||||
usage.CompletionTokens = 0
|
||||
usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
|
||||
return nil, usage
|
||||
|
||||
@@ -141,7 +141,7 @@ func OaiResponsesStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp
|
||||
}
|
||||
|
||||
if usage.PromptTokens == 0 && usage.CompletionTokens != 0 {
|
||||
usage.PromptTokens = info.PromptTokens
|
||||
usage.PromptTokens = info.GetEstimatePromptTokens()
|
||||
}
|
||||
|
||||
usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
|
||||
|
||||
@@ -81,7 +81,7 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
|
||||
if info.IsStream {
|
||||
var responseText string
|
||||
err, responseText = palmStreamHandler(c, resp)
|
||||
usage = service.ResponseText2Usage(c, responseText, info.UpstreamModelName, info.PromptTokens)
|
||||
usage = service.ResponseText2Usage(c, responseText, info.UpstreamModelName, info.GetEstimatePromptTokens())
|
||||
} else {
|
||||
usage, err = palmHandler(c, info, resp)
|
||||
}
|
||||
|
||||
@@ -121,13 +121,8 @@ func palmHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Respons
|
||||
}, resp.StatusCode)
|
||||
}
|
||||
fullTextResponse := responsePaLM2OpenAI(&palmResponse)
|
||||
completionTokens := service.CountTextToken(palmResponse.Candidates[0].Content, info.UpstreamModelName)
|
||||
usage := dto.Usage{
|
||||
PromptTokens: info.PromptTokens,
|
||||
CompletionTokens: completionTokens,
|
||||
TotalTokens: info.PromptTokens + completionTokens,
|
||||
}
|
||||
fullTextResponse.Usage = usage
|
||||
usage := service.ResponseText2Usage(c, palmResponse.Candidates[0].Content, info.UpstreamModelName, info.GetEstimatePromptTokens())
|
||||
fullTextResponse.Usage = *usage
|
||||
jsonResponse, err := common.Marshal(fullTextResponse)
|
||||
if err != nil {
|
||||
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
|
||||
@@ -135,5 +130,5 @@ func palmHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Respons
|
||||
c.Writer.Header().Set("Content-Type", "application/json")
|
||||
c.Writer.WriteHeader(resp.StatusCode)
|
||||
service.IOCopyBytesGracefully(c, resp, jsonResponse)
|
||||
return &usage, nil
|
||||
return usage, nil
|
||||
}
|
||||
|
||||
@@ -5,8 +5,10 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
"github.com/QuantumNous/new-api/constant"
|
||||
"github.com/QuantumNous/new-api/dto"
|
||||
"github.com/QuantumNous/new-api/model"
|
||||
"github.com/QuantumNous/new-api/relay/channel"
|
||||
@@ -67,11 +69,30 @@ func (a *TaskAdaptor) Init(info *relaycommon.RelayInfo) {
|
||||
a.apiKey = info.ApiKey
|
||||
}
|
||||
|
||||
func validateRemixRequest(c *gin.Context) *dto.TaskError {
|
||||
var req struct {
|
||||
Prompt string `json:"prompt"`
|
||||
}
|
||||
if err := common.UnmarshalBodyReusable(c, &req); err != nil {
|
||||
return service.TaskErrorWrapperLocal(err, "invalid_request", http.StatusBadRequest)
|
||||
}
|
||||
if strings.TrimSpace(req.Prompt) == "" {
|
||||
return service.TaskErrorWrapperLocal(fmt.Errorf("field prompt is required"), "invalid_request", http.StatusBadRequest)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.RelayInfo) (taskErr *dto.TaskError) {
|
||||
if info.Action == constant.TaskActionRemix {
|
||||
return validateRemixRequest(c)
|
||||
}
|
||||
return relaycommon.ValidateMultipartDirect(c, info)
|
||||
}
|
||||
|
||||
func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
if info.Action == constant.TaskActionRemix {
|
||||
return fmt.Sprintf("%s/v1/videos/%s/remix", a.baseURL, info.OriginTaskID), nil
|
||||
}
|
||||
return fmt.Sprintf("%s/v1/videos", a.baseURL), nil
|
||||
}
|
||||
|
||||
|
||||
@@ -105,7 +105,7 @@ func tencentStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *htt
|
||||
data = strings.TrimPrefix(data, "data:")
|
||||
|
||||
var tencentResponse TencentChatResponse
|
||||
err := json.Unmarshal([]byte(data), &tencentResponse)
|
||||
err := common.Unmarshal([]byte(data), &tencentResponse)
|
||||
if err != nil {
|
||||
common.SysLog("error unmarshalling stream response: " + err.Error())
|
||||
continue
|
||||
@@ -130,7 +130,7 @@ func tencentStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *htt
|
||||
|
||||
service.CloseResponseBodyGracefully(resp)
|
||||
|
||||
return service.ResponseText2Usage(c, responseText, info.UpstreamModelName, info.PromptTokens), nil
|
||||
return service.ResponseText2Usage(c, responseText, info.UpstreamModelName, info.GetEstimatePromptTokens()), nil
|
||||
}
|
||||
|
||||
func tencentHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
|
||||
|
||||
@@ -184,9 +184,9 @@ func handleTTSResponse(c *gin.Context, resp *http.Response, info *relaycommon.Re
|
||||
c.Data(http.StatusOK, contentType, audioData)
|
||||
|
||||
usage = &dto.Usage{
|
||||
PromptTokens: info.PromptTokens,
|
||||
PromptTokens: info.GetEstimatePromptTokens(),
|
||||
CompletionTokens: 0,
|
||||
TotalTokens: info.PromptTokens,
|
||||
TotalTokens: info.GetEstimatePromptTokens(),
|
||||
}
|
||||
|
||||
return usage, nil
|
||||
@@ -284,9 +284,9 @@ func handleTTSWebSocketResponse(c *gin.Context, requestURL string, volcRequest V
|
||||
if msg.Sequence < 0 {
|
||||
c.Status(http.StatusOK)
|
||||
usage = &dto.Usage{
|
||||
PromptTokens: info.PromptTokens,
|
||||
PromptTokens: info.GetEstimatePromptTokens(),
|
||||
CompletionTokens: 0,
|
||||
TotalTokens: info.PromptTokens,
|
||||
TotalTokens: info.GetEstimatePromptTokens(),
|
||||
}
|
||||
return usage, nil
|
||||
}
|
||||
@@ -297,9 +297,9 @@ func handleTTSWebSocketResponse(c *gin.Context, requestURL string, volcRequest V
|
||||
|
||||
c.Status(http.StatusOK)
|
||||
usage = &dto.Usage{
|
||||
PromptTokens: info.PromptTokens,
|
||||
PromptTokens: info.GetEstimatePromptTokens(),
|
||||
CompletionTokens: 0,
|
||||
TotalTokens: info.PromptTokens,
|
||||
TotalTokens: info.GetEstimatePromptTokens(),
|
||||
}
|
||||
return usage, nil
|
||||
}
|
||||
|
||||
@@ -70,7 +70,7 @@ func xAIStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Re
|
||||
})
|
||||
|
||||
if !containStreamUsage {
|
||||
usage = service.ResponseText2Usage(c, responseTextBuilder.String(), info.UpstreamModelName, info.PromptTokens)
|
||||
usage = service.ResponseText2Usage(c, responseTextBuilder.String(), info.UpstreamModelName, info.GetEstimatePromptTokens())
|
||||
usage.CompletionTokens += toolCount * 7
|
||||
}
|
||||
|
||||
|
||||
@@ -73,6 +73,11 @@ type ChannelMeta struct {
|
||||
SupportStreamOptions bool // 是否支持流式选项
|
||||
}
|
||||
|
||||
type TokenCountMeta struct {
|
||||
//promptTokens int
|
||||
estimatePromptTokens int
|
||||
}
|
||||
|
||||
type RelayInfo struct {
|
||||
TokenId int
|
||||
TokenKey string
|
||||
@@ -91,7 +96,6 @@ type RelayInfo struct {
|
||||
RelayMode int
|
||||
OriginModelName string
|
||||
RequestURLPath string
|
||||
PromptTokens int
|
||||
ShouldIncludeUsage bool
|
||||
DisablePing bool // 是否禁止向下游发送自定义 Ping
|
||||
ClientWs *websocket.Conn
|
||||
@@ -115,6 +119,7 @@ type RelayInfo struct {
|
||||
Request dto.Request
|
||||
|
||||
ThinkingContentInfo
|
||||
TokenCountMeta
|
||||
*ClaudeConvertInfo
|
||||
*RerankerInfo
|
||||
*ResponsesUsageInfo
|
||||
@@ -189,7 +194,7 @@ func (info *RelayInfo) ToString() string {
|
||||
fmt.Fprintf(b, "IsPlayground: %t, ", info.IsPlayground)
|
||||
fmt.Fprintf(b, "RequestURLPath: %q, ", info.RequestURLPath)
|
||||
fmt.Fprintf(b, "OriginModelName: %q, ", info.OriginModelName)
|
||||
fmt.Fprintf(b, "PromptTokens: %d, ", info.PromptTokens)
|
||||
fmt.Fprintf(b, "EstimatePromptTokens: %d, ", info.estimatePromptTokens)
|
||||
fmt.Fprintf(b, "ShouldIncludeUsage: %t, ", info.ShouldIncludeUsage)
|
||||
fmt.Fprintf(b, "DisablePing: %t, ", info.DisablePing)
|
||||
fmt.Fprintf(b, "SendResponseCount: %d, ", info.SendResponseCount)
|
||||
@@ -391,7 +396,6 @@ func genBaseRelayInfo(c *gin.Context, request dto.Request) *RelayInfo {
|
||||
UserEmail: common.GetContextKeyString(c, constant.ContextKeyUserEmail),
|
||||
|
||||
OriginModelName: common.GetContextKeyString(c, constant.ContextKeyOriginalModel),
|
||||
PromptTokens: common.GetContextKeyInt(c, constant.ContextKeyPromptTokens),
|
||||
|
||||
TokenId: common.GetContextKeyInt(c, constant.ContextKeyTokenId),
|
||||
TokenKey: common.GetContextKeyString(c, constant.ContextKeyTokenKey),
|
||||
@@ -408,6 +412,10 @@ func genBaseRelayInfo(c *gin.Context, request dto.Request) *RelayInfo {
|
||||
IsFirstThinkingContent: true,
|
||||
SendLastThinkingContent: false,
|
||||
},
|
||||
TokenCountMeta: TokenCountMeta{
|
||||
//promptTokens: common.GetContextKeyInt(c, constant.ContextKeyPromptTokens),
|
||||
estimatePromptTokens: common.GetContextKeyInt(c, constant.ContextKeyEstimatedTokens),
|
||||
},
|
||||
}
|
||||
|
||||
if info.RelayMode == relayconstant.RelayModeUnknown {
|
||||
@@ -463,8 +471,16 @@ func GenRelayInfo(c *gin.Context, relayFormat types.RelayFormat, request dto.Req
|
||||
}
|
||||
}
|
||||
|
||||
func (info *RelayInfo) SetPromptTokens(promptTokens int) {
|
||||
info.PromptTokens = promptTokens
|
||||
//func (info *RelayInfo) SetPromptTokens(promptTokens int) {
|
||||
// info.promptTokens = promptTokens
|
||||
//}
|
||||
|
||||
func (info *RelayInfo) SetEstimatePromptTokens(promptTokens int) {
|
||||
info.estimatePromptTokens = promptTokens
|
||||
}
|
||||
|
||||
func (info *RelayInfo) GetEstimatePromptTokens() int {
|
||||
return info.estimatePromptTokens
|
||||
}
|
||||
|
||||
func (info *RelayInfo) SetFirstResponseTime() {
|
||||
|
||||
@@ -57,8 +57,8 @@ func RerankHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Respo
|
||||
jinaResp = dto.RerankResponse{
|
||||
Results: jinaRespResults,
|
||||
Usage: dto.Usage{
|
||||
PromptTokens: info.PromptTokens,
|
||||
TotalTokens: info.PromptTokens,
|
||||
PromptTokens: info.GetEstimatePromptTokens(),
|
||||
TotalTokens: info.GetEstimatePromptTokens(),
|
||||
},
|
||||
}
|
||||
} else {
|
||||
|
||||
@@ -192,9 +192,9 @@ func TextHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *types
|
||||
func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage *dto.Usage, extraContent string) {
|
||||
if usage == nil {
|
||||
usage = &dto.Usage{
|
||||
PromptTokens: relayInfo.PromptTokens,
|
||||
PromptTokens: relayInfo.GetEstimatePromptTokens(),
|
||||
CompletionTokens: 0,
|
||||
TotalTokens: relayInfo.PromptTokens,
|
||||
TotalTokens: relayInfo.GetEstimatePromptTokens(),
|
||||
}
|
||||
extraContent += "(可能是请求出错)"
|
||||
}
|
||||
|
||||
@@ -99,7 +99,10 @@ func ModelPriceHelper(c *gin.Context, info *relaycommon.RelayInfo, promptTokens
|
||||
// check if free model pre-consume is disabled
|
||||
if !operation_setting.GetQuotaSetting().EnableFreeModelPreConsume {
|
||||
// if model price or ratio is 0, do not pre-consume quota
|
||||
if usePrice {
|
||||
if groupRatioInfo.GroupRatio == 0 {
|
||||
preConsumedQuota = 0
|
||||
freeModel = true
|
||||
} else if usePrice {
|
||||
if modelPrice == 0 {
|
||||
preConsumedQuota = 0
|
||||
freeModel = true
|
||||
|
||||
@@ -72,6 +72,8 @@ func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon
|
||||
if common.DebugEnabled {
|
||||
// print timeout and ping interval for debugging
|
||||
println("relay timeout seconds:", common.RelayTimeout)
|
||||
println("relay max idle conns:", common.RelayMaxIdleConns)
|
||||
println("relay max idle conns per host:", common.RelayMaxIdleConnsPerHost)
|
||||
println("streaming timeout seconds:", int64(streamingTimeout.Seconds()))
|
||||
println("ping interval seconds:", int64(pingInterval.Seconds()))
|
||||
}
|
||||
|
||||
@@ -32,7 +32,86 @@ func RelayTaskSubmit(c *gin.Context, info *relaycommon.RelayInfo) (taskErr *dto.
|
||||
if info.TaskRelayInfo == nil {
|
||||
info.TaskRelayInfo = &relaycommon.TaskRelayInfo{}
|
||||
}
|
||||
path := c.Request.URL.Path
|
||||
if strings.Contains(path, "/v1/videos/") && strings.HasSuffix(path, "/remix") {
|
||||
info.Action = constant.TaskActionRemix
|
||||
}
|
||||
|
||||
// 提取 remix 任务的 video_id
|
||||
if info.Action == constant.TaskActionRemix {
|
||||
videoID := c.Param("video_id")
|
||||
if strings.TrimSpace(videoID) == "" {
|
||||
return service.TaskErrorWrapperLocal(fmt.Errorf("video_id is required"), "invalid_request", http.StatusBadRequest)
|
||||
}
|
||||
info.OriginTaskID = videoID
|
||||
}
|
||||
|
||||
platform := constant.TaskPlatform(c.GetString("platform"))
|
||||
|
||||
// 获取原始任务信息
|
||||
if info.OriginTaskID != "" {
|
||||
originTask, exist, err := model.GetByTaskId(info.UserId, info.OriginTaskID)
|
||||
if err != nil {
|
||||
taskErr = service.TaskErrorWrapper(err, "get_origin_task_failed", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
if !exist {
|
||||
taskErr = service.TaskErrorWrapperLocal(errors.New("task_origin_not_exist"), "task_not_exist", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if info.OriginModelName == "" {
|
||||
if originTask.Properties.OriginModelName != "" {
|
||||
info.OriginModelName = originTask.Properties.OriginModelName
|
||||
} else if originTask.Properties.UpstreamModelName != "" {
|
||||
info.OriginModelName = originTask.Properties.UpstreamModelName
|
||||
} else {
|
||||
var taskData map[string]interface{}
|
||||
_ = json.Unmarshal(originTask.Data, &taskData)
|
||||
if m, ok := taskData["model"].(string); ok && m != "" {
|
||||
info.OriginModelName = m
|
||||
platform = originTask.Platform
|
||||
}
|
||||
}
|
||||
}
|
||||
if originTask.ChannelId != info.ChannelId {
|
||||
channel, err := model.GetChannelById(originTask.ChannelId, true)
|
||||
if err != nil {
|
||||
taskErr = service.TaskErrorWrapperLocal(err, "channel_not_found", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if channel.Status != common.ChannelStatusEnabled {
|
||||
taskErr = service.TaskErrorWrapperLocal(errors.New("the channel of the origin task is disabled"), "task_channel_disable", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
c.Set("base_url", channel.GetBaseURL())
|
||||
c.Set("channel_id", originTask.ChannelId)
|
||||
c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key))
|
||||
|
||||
info.ChannelBaseUrl = channel.GetBaseURL()
|
||||
info.ChannelId = originTask.ChannelId
|
||||
platform = originTask.Platform
|
||||
}
|
||||
|
||||
// 使用原始任务的参数
|
||||
if info.Action == constant.TaskActionRemix {
|
||||
var taskData map[string]interface{}
|
||||
_ = json.Unmarshal(originTask.Data, &taskData)
|
||||
secondsStr, _ := taskData["seconds"].(string)
|
||||
seconds, _ := strconv.Atoi(secondsStr)
|
||||
if seconds <= 0 {
|
||||
seconds = 4
|
||||
}
|
||||
sizeStr, _ := taskData["size"].(string)
|
||||
if info.PriceData.OtherRatios == nil {
|
||||
info.PriceData.OtherRatios = map[string]float64{}
|
||||
}
|
||||
info.PriceData.OtherRatios["seconds"] = float64(seconds)
|
||||
info.PriceData.OtherRatios["size"] = 1
|
||||
if sizeStr == "1792x1024" || sizeStr == "1024x1792" {
|
||||
info.PriceData.OtherRatios["size"] = 1.666667
|
||||
}
|
||||
}
|
||||
}
|
||||
if platform == "" {
|
||||
platform = GetTaskPlatform(c)
|
||||
}
|
||||
@@ -94,34 +173,6 @@ func RelayTaskSubmit(c *gin.Context, info *relaycommon.RelayInfo) (taskErr *dto.
|
||||
return
|
||||
}
|
||||
|
||||
if info.OriginTaskID != "" {
|
||||
originTask, exist, err := model.GetByTaskId(info.UserId, info.OriginTaskID)
|
||||
if err != nil {
|
||||
taskErr = service.TaskErrorWrapper(err, "get_origin_task_failed", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
if !exist {
|
||||
taskErr = service.TaskErrorWrapperLocal(errors.New("task_origin_not_exist"), "task_not_exist", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if originTask.ChannelId != info.ChannelId {
|
||||
channel, err := model.GetChannelById(originTask.ChannelId, true)
|
||||
if err != nil {
|
||||
taskErr = service.TaskErrorWrapperLocal(err, "channel_not_found", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if channel.Status != common.ChannelStatusEnabled {
|
||||
return service.TaskErrorWrapperLocal(errors.New("该任务所属渠道已被禁用"), "task_channel_disable", http.StatusBadRequest)
|
||||
}
|
||||
c.Set("base_url", channel.GetBaseURL())
|
||||
c.Set("channel_id", originTask.ChannelId)
|
||||
c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key))
|
||||
|
||||
info.ChannelBaseUrl = channel.GetBaseURL()
|
||||
info.ChannelId = originTask.ChannelId
|
||||
}
|
||||
}
|
||||
|
||||
// build body
|
||||
requestBody, err := adaptor.BuildRequestBody(c, info)
|
||||
if err != nil {
|
||||
|
||||
@@ -14,6 +14,7 @@ func SetVideoRouter(router *gin.Engine) {
|
||||
videoV1Router.GET("/videos/:task_id/content", controller.VideoProxy)
|
||||
videoV1Router.POST("/video/generations", controller.RelayTask)
|
||||
videoV1Router.GET("/video/generations/:task_id", controller.RelayTask)
|
||||
videoV1Router.POST("/videos/:video_id/remix", controller.RelayTask)
|
||||
}
|
||||
// openai compatible API video routes
|
||||
// docs: https://platform.openai.com/docs/api-reference/videos/create
|
||||
|
||||
@@ -209,7 +209,7 @@ func StreamResponseOpenAI2Claude(openAIResponse *dto.ChatCompletionsStreamRespon
|
||||
Type: "message",
|
||||
Role: "assistant",
|
||||
Usage: &dto.ClaudeUsage{
|
||||
InputTokens: info.PromptTokens,
|
||||
InputTokens: info.GetEstimatePromptTokens(),
|
||||
OutputTokens: 0,
|
||||
},
|
||||
}
|
||||
@@ -734,12 +734,18 @@ func StreamResponseOpenAI2Gemini(openAIResponse *dto.ChatCompletionsStreamRespon
|
||||
geminiResponse := &dto.GeminiChatResponse{
|
||||
Candidates: make([]dto.GeminiChatCandidate, 0, len(openAIResponse.Choices)),
|
||||
UsageMetadata: dto.GeminiUsageMetadata{
|
||||
PromptTokenCount: info.PromptTokens,
|
||||
PromptTokenCount: info.GetEstimatePromptTokens(),
|
||||
CandidatesTokenCount: 0, // 流式响应中可能没有完整的 usage 信息
|
||||
TotalTokenCount: info.PromptTokens,
|
||||
TotalTokenCount: info.GetEstimatePromptTokens(),
|
||||
},
|
||||
}
|
||||
|
||||
if openAIResponse.Usage != nil {
|
||||
geminiResponse.UsageMetadata.PromptTokenCount = openAIResponse.Usage.PromptTokens
|
||||
geminiResponse.UsageMetadata.CandidatesTokenCount = openAIResponse.Usage.CompletionTokens
|
||||
geminiResponse.UsageMetadata.TotalTokenCount = openAIResponse.Usage.TotalTokens
|
||||
}
|
||||
|
||||
for _, choice := range openAIResponse.Choices {
|
||||
candidate := dto.GeminiChatCandidate{
|
||||
Index: int64(choice.Index),
|
||||
|
||||
@@ -34,12 +34,20 @@ func checkRedirect(req *http.Request, via []*http.Request) error {
|
||||
}
|
||||
|
||||
func InitHttpClient() {
|
||||
transport := &http.Transport{
|
||||
MaxIdleConns: common.RelayMaxIdleConns,
|
||||
MaxIdleConnsPerHost: common.RelayMaxIdleConnsPerHost,
|
||||
ForceAttemptHTTP2: true,
|
||||
}
|
||||
|
||||
if common.RelayTimeout == 0 {
|
||||
httpClient = &http.Client{
|
||||
Transport: transport,
|
||||
CheckRedirect: checkRedirect,
|
||||
}
|
||||
} else {
|
||||
httpClient = &http.Client{
|
||||
Transport: transport,
|
||||
Timeout: time.Duration(common.RelayTimeout) * time.Second,
|
||||
CheckRedirect: checkRedirect,
|
||||
}
|
||||
@@ -84,6 +92,9 @@ func NewProxyHttpClient(proxyURL string) (*http.Client, error) {
|
||||
case "http", "https":
|
||||
client := &http.Client{
|
||||
Transport: &http.Transport{
|
||||
MaxIdleConns: common.RelayMaxIdleConns,
|
||||
MaxIdleConnsPerHost: common.RelayMaxIdleConnsPerHost,
|
||||
ForceAttemptHTTP2: true,
|
||||
Proxy: http.ProxyURL(parsedURL),
|
||||
},
|
||||
CheckRedirect: checkRedirect,
|
||||
@@ -116,6 +127,9 @@ func NewProxyHttpClient(proxyURL string) (*http.Client, error) {
|
||||
|
||||
client := &http.Client{
|
||||
Transport: &http.Transport{
|
||||
MaxIdleConns: common.RelayMaxIdleConns,
|
||||
MaxIdleConnsPerHost: common.RelayMaxIdleConnsPerHost,
|
||||
ForceAttemptHTTP2: true,
|
||||
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
return dialer.Dial(network, addr)
|
||||
},
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"image"
|
||||
@@ -12,7 +11,6 @@ import (
|
||||
"math"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
"unicode/utf8"
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
@@ -23,64 +21,8 @@ import (
|
||||
"github.com/QuantumNous/new-api/types"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/tiktoken-go/tokenizer"
|
||||
"github.com/tiktoken-go/tokenizer/codec"
|
||||
)
|
||||
|
||||
// tokenEncoderMap won't grow after initialization
|
||||
var defaultTokenEncoder tokenizer.Codec
|
||||
|
||||
// tokenEncoderMap is used to store token encoders for different models
|
||||
var tokenEncoderMap = make(map[string]tokenizer.Codec)
|
||||
|
||||
// tokenEncoderMutex protects tokenEncoderMap for concurrent access
|
||||
var tokenEncoderMutex sync.RWMutex
|
||||
|
||||
func InitTokenEncoders() {
|
||||
common.SysLog("initializing token encoders")
|
||||
defaultTokenEncoder = codec.NewCl100kBase()
|
||||
common.SysLog("token encoders initialized")
|
||||
}
|
||||
|
||||
func getTokenEncoder(model string) tokenizer.Codec {
|
||||
// First, try to get the encoder from cache with read lock
|
||||
tokenEncoderMutex.RLock()
|
||||
if encoder, exists := tokenEncoderMap[model]; exists {
|
||||
tokenEncoderMutex.RUnlock()
|
||||
return encoder
|
||||
}
|
||||
tokenEncoderMutex.RUnlock()
|
||||
|
||||
// If not in cache, create new encoder with write lock
|
||||
tokenEncoderMutex.Lock()
|
||||
defer tokenEncoderMutex.Unlock()
|
||||
|
||||
// Double-check if another goroutine already created the encoder
|
||||
if encoder, exists := tokenEncoderMap[model]; exists {
|
||||
return encoder
|
||||
}
|
||||
|
||||
// Create new encoder
|
||||
modelCodec, err := tokenizer.ForModel(tokenizer.Model(model))
|
||||
if err != nil {
|
||||
// Cache the default encoder for this model to avoid repeated failures
|
||||
tokenEncoderMap[model] = defaultTokenEncoder
|
||||
return defaultTokenEncoder
|
||||
}
|
||||
|
||||
// Cache the new encoder
|
||||
tokenEncoderMap[model] = modelCodec
|
||||
return modelCodec
|
||||
}
|
||||
|
||||
func getTokenNum(tokenEncoder tokenizer.Codec, text string) int {
|
||||
if text == "" {
|
||||
return 0
|
||||
}
|
||||
tkm, _ := tokenEncoder.Count(text)
|
||||
return tkm
|
||||
}
|
||||
|
||||
func getImageToken(fileMeta *types.FileMeta, model string, stream bool) (int, error) {
|
||||
if fileMeta == nil {
|
||||
return 0, fmt.Errorf("image_url_is_nil")
|
||||
@@ -257,7 +199,7 @@ func getImageToken(fileMeta *types.FileMeta, model string, stream bool) (int, er
|
||||
return tiles*tileTokens + baseTokens, nil
|
||||
}
|
||||
|
||||
func CountRequestToken(c *gin.Context, meta *types.TokenCountMeta, info *relaycommon.RelayInfo) (int, error) {
|
||||
func EstimateRequestToken(c *gin.Context, meta *types.TokenCountMeta, info *relaycommon.RelayInfo) (int, error) {
|
||||
// 是否统计token
|
||||
if !constant.CountToken {
|
||||
return 0, nil
|
||||
@@ -375,14 +317,14 @@ func CountRequestToken(c *gin.Context, meta *types.TokenCountMeta, info *relayco
|
||||
for i, file := range meta.Files {
|
||||
switch file.FileType {
|
||||
case types.FileTypeImage:
|
||||
if info.RelayFormat == types.RelayFormatGemini {
|
||||
tkm += 520 // gemini per input image tokens
|
||||
} else {
|
||||
if common.IsOpenAITextModel(info.OriginModelName) {
|
||||
token, err := getImageToken(file, model, info.IsStream)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("error counting image token, media index[%d], original data[%s], err: %v", i, file.OriginData, err)
|
||||
}
|
||||
tkm += token
|
||||
} else {
|
||||
tkm += 520
|
||||
}
|
||||
case types.FileTypeAudio:
|
||||
tkm += 256
|
||||
@@ -399,111 +341,6 @@ func CountRequestToken(c *gin.Context, meta *types.TokenCountMeta, info *relayco
|
||||
return tkm, nil
|
||||
}
|
||||
|
||||
func CountTokenClaudeRequest(request dto.ClaudeRequest, model string) (int, error) {
|
||||
tkm := 0
|
||||
|
||||
// Count tokens in messages
|
||||
msgTokens, err := CountTokenClaudeMessages(request.Messages, model, request.Stream)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
tkm += msgTokens
|
||||
|
||||
// Count tokens in system message
|
||||
if request.System != "" {
|
||||
systemTokens := CountTokenInput(request.System, model)
|
||||
tkm += systemTokens
|
||||
}
|
||||
|
||||
if request.Tools != nil {
|
||||
// check is array
|
||||
if tools, ok := request.Tools.([]any); ok {
|
||||
if len(tools) > 0 {
|
||||
parsedTools, err1 := common.Any2Type[[]dto.Tool](request.Tools)
|
||||
if err1 != nil {
|
||||
return 0, fmt.Errorf("tools: Input should be a valid list: %v", err)
|
||||
}
|
||||
toolTokens, err2 := CountTokenClaudeTools(parsedTools, model)
|
||||
if err2 != nil {
|
||||
return 0, fmt.Errorf("tools: %v", err)
|
||||
}
|
||||
tkm += toolTokens
|
||||
}
|
||||
} else {
|
||||
return 0, errors.New("tools: Input should be a valid list")
|
||||
}
|
||||
}
|
||||
|
||||
return tkm, nil
|
||||
}
|
||||
|
||||
func CountTokenClaudeMessages(messages []dto.ClaudeMessage, model string, stream bool) (int, error) {
|
||||
tokenEncoder := getTokenEncoder(model)
|
||||
tokenNum := 0
|
||||
|
||||
for _, message := range messages {
|
||||
// Count tokens for role
|
||||
tokenNum += getTokenNum(tokenEncoder, message.Role)
|
||||
if message.IsStringContent() {
|
||||
tokenNum += getTokenNum(tokenEncoder, message.GetStringContent())
|
||||
} else {
|
||||
content, err := message.ParseContent()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
for _, mediaMessage := range content {
|
||||
switch mediaMessage.Type {
|
||||
case "text":
|
||||
tokenNum += getTokenNum(tokenEncoder, mediaMessage.GetText())
|
||||
case "image":
|
||||
//imageTokenNum, err := getClaudeImageToken(mediaMsg.Source, model, stream)
|
||||
//if err != nil {
|
||||
// return 0, err
|
||||
//}
|
||||
tokenNum += 1000
|
||||
case "tool_use":
|
||||
if mediaMessage.Input != nil {
|
||||
tokenNum += getTokenNum(tokenEncoder, mediaMessage.Name)
|
||||
inputJSON, _ := json.Marshal(mediaMessage.Input)
|
||||
tokenNum += getTokenNum(tokenEncoder, string(inputJSON))
|
||||
}
|
||||
case "tool_result":
|
||||
if mediaMessage.Content != nil {
|
||||
contentJSON, _ := json.Marshal(mediaMessage.Content)
|
||||
tokenNum += getTokenNum(tokenEncoder, string(contentJSON))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Add a constant for message formatting (this may need adjustment based on Claude's exact formatting)
|
||||
tokenNum += len(messages) * 2 // Assuming 2 tokens per message for formatting
|
||||
|
||||
return tokenNum, nil
|
||||
}
|
||||
|
||||
func CountTokenClaudeTools(tools []dto.Tool, model string) (int, error) {
|
||||
tokenEncoder := getTokenEncoder(model)
|
||||
tokenNum := 0
|
||||
|
||||
for _, tool := range tools {
|
||||
tokenNum += getTokenNum(tokenEncoder, tool.Name)
|
||||
tokenNum += getTokenNum(tokenEncoder, tool.Description)
|
||||
|
||||
schemaJSON, err := json.Marshal(tool.InputSchema)
|
||||
if err != nil {
|
||||
return 0, errors.New(fmt.Sprintf("marshal_tool_schema_fail: %s", err.Error()))
|
||||
}
|
||||
tokenNum += getTokenNum(tokenEncoder, string(schemaJSON))
|
||||
}
|
||||
|
||||
// Add a constant for tool formatting (this may need adjustment based on Claude's exact formatting)
|
||||
tokenNum += len(tools) * 3 // Assuming 3 tokens per tool for formatting
|
||||
|
||||
return tokenNum, nil
|
||||
}
|
||||
|
||||
func CountTokenRealtime(info *relaycommon.RelayInfo, request dto.RealtimeEvent, model string) (int, int, error) {
|
||||
audioToken := 0
|
||||
textToken := 0
|
||||
@@ -578,31 +415,6 @@ func CountTokenInput(input any, model string) int {
|
||||
return CountTokenInput(fmt.Sprintf("%v", input), model)
|
||||
}
|
||||
|
||||
func CountTokenStreamChoices(messages []dto.ChatCompletionsStreamResponseChoice, model string) int {
|
||||
tokens := 0
|
||||
for _, message := range messages {
|
||||
tkm := CountTokenInput(message.Delta.GetContentString(), model)
|
||||
tokens += tkm
|
||||
if message.Delta.ToolCalls != nil {
|
||||
for _, tool := range message.Delta.ToolCalls {
|
||||
tkm := CountTokenInput(tool.Function.Name, model)
|
||||
tokens += tkm
|
||||
tkm = CountTokenInput(tool.Function.Arguments, model)
|
||||
tokens += tkm
|
||||
}
|
||||
}
|
||||
}
|
||||
return tokens
|
||||
}
|
||||
|
||||
func CountTTSToken(text string, model string) int {
|
||||
if strings.HasPrefix(model, "tts") {
|
||||
return utf8.RuneCountInString(text)
|
||||
} else {
|
||||
return CountTextToken(text, model)
|
||||
}
|
||||
}
|
||||
|
||||
func CountAudioTokenInput(audioBase64 string, audioFormat string) (int, error) {
|
||||
if audioBase64 == "" {
|
||||
return 0, nil
|
||||
@@ -625,17 +437,16 @@ func CountAudioTokenOutput(audioBase64 string, audioFormat string) (int, error)
|
||||
return int(duration / 60 * 200 / 0.24), nil
|
||||
}
|
||||
|
||||
//func CountAudioToken(sec float64, audioType string) {
|
||||
// if audioType == "input" {
|
||||
//
|
||||
// }
|
||||
//}
|
||||
|
||||
// CountTextToken 统计文本的token数量,仅当文本包含敏感词,返回错误,同时返回token数量
|
||||
// CountTextToken 统计文本的token数量,仅OpenAI模型使用tokenizer,其余模型使用估算
|
||||
func CountTextToken(text string, model string) int {
|
||||
if text == "" {
|
||||
return 0
|
||||
}
|
||||
tokenEncoder := getTokenEncoder(model)
|
||||
return getTokenNum(tokenEncoder, text)
|
||||
if common.IsOpenAITextModel(model) {
|
||||
tokenEncoder := getTokenEncoder(model)
|
||||
return getTokenNum(tokenEncoder, text)
|
||||
} else {
|
||||
// 非openai模型,使用tiktoken-go计算没有意义,使用估算节省资源
|
||||
return EstimateTokenByModel(model, text)
|
||||
}
|
||||
}
|
||||
|
||||
230
service/token_estimator.go
Normal file
230
service/token_estimator.go
Normal file
@@ -0,0 +1,230 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"math"
|
||||
"strings"
|
||||
"sync"
|
||||
"unicode"
|
||||
)
|
||||
|
||||
// Provider 定义模型厂商大类
|
||||
type Provider string
|
||||
|
||||
const (
|
||||
OpenAI Provider = "openai" // 代表 GPT-3.5, GPT-4, GPT-4o
|
||||
Gemini Provider = "gemini" // 代表 Gemini 1.0, 1.5 Pro/Flash
|
||||
Claude Provider = "claude" // 代表 Claude 3, 3.5 Sonnet
|
||||
Unknown Provider = "unknown" // 兜底默认
|
||||
)
|
||||
|
||||
// multipliers 定义不同厂商的计费权重
|
||||
type multipliers struct {
|
||||
Word float64 // 英文单词 (每词)
|
||||
Number float64 // 数字 (每连续数字串)
|
||||
CJK float64 // 中日韩字符 (每字)
|
||||
Symbol float64 // 普通标点符号 (每个)
|
||||
MathSymbol float64 // 数学符号 (∑,∫,∂,√等,每个)
|
||||
URLDelim float64 // URL分隔符 (/,:,?,&,=,#,%) - tokenizer优化好
|
||||
AtSign float64 // @符号 - 导致单词切分,消耗较高
|
||||
Emoji float64 // Emoji表情 (每个)
|
||||
Newline float64 // 换行符/制表符 (每个)
|
||||
Space float64 // 空格 (每个)
|
||||
BasePad int // 基础起步消耗 (Start/End tokens)
|
||||
}
|
||||
|
||||
var (
|
||||
multipliersMap = map[Provider]multipliers{
|
||||
Gemini: {
|
||||
Word: 1.15, Number: 2.8, CJK: 0.68, Symbol: 0.38, MathSymbol: 1.05, URLDelim: 1.2, AtSign: 2.5, Emoji: 1.08, Newline: 1.15, Space: 0.2, BasePad: 0,
|
||||
},
|
||||
Claude: {
|
||||
Word: 1.13, Number: 1.63, CJK: 1.21, Symbol: 0.4, MathSymbol: 4.52, URLDelim: 1.26, AtSign: 2.82, Emoji: 2.6, Newline: 0.89, Space: 0.39, BasePad: 0,
|
||||
},
|
||||
OpenAI: {
|
||||
Word: 1.02, Number: 1.55, CJK: 0.85, Symbol: 0.4, MathSymbol: 2.68, URLDelim: 1.0, AtSign: 2.0, Emoji: 2.12, Newline: 0.5, Space: 0.42, BasePad: 0,
|
||||
},
|
||||
}
|
||||
multipliersLock sync.RWMutex
|
||||
)
|
||||
|
||||
// getMultipliers 根据厂商获取权重配置
|
||||
func getMultipliers(p Provider) multipliers {
|
||||
multipliersLock.RLock()
|
||||
defer multipliersLock.RUnlock()
|
||||
|
||||
switch p {
|
||||
case Gemini:
|
||||
return multipliersMap[Gemini]
|
||||
case Claude:
|
||||
return multipliersMap[Claude]
|
||||
case OpenAI:
|
||||
return multipliersMap[OpenAI]
|
||||
default:
|
||||
// 默认兜底 (按 OpenAI 的算)
|
||||
return multipliersMap[OpenAI]
|
||||
}
|
||||
}
|
||||
|
||||
// EstimateToken 计算 Token 数量
|
||||
func EstimateToken(provider Provider, text string) int {
|
||||
m := getMultipliers(provider)
|
||||
var count float64
|
||||
|
||||
// 状态机变量
|
||||
type WordType int
|
||||
const (
|
||||
None WordType = iota
|
||||
Latin
|
||||
Number
|
||||
)
|
||||
currentWordType := None
|
||||
|
||||
for _, r := range text {
|
||||
// 1. 处理空格和换行符
|
||||
if unicode.IsSpace(r) {
|
||||
currentWordType = None
|
||||
// 换行符和制表符使用Newline权重
|
||||
if r == '\n' || r == '\t' {
|
||||
count += m.Newline
|
||||
} else {
|
||||
// 普通空格使用Space权重
|
||||
count += m.Space
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
// 2. 处理 CJK (中日韩) - 按字符计费
|
||||
if isCJK(r) {
|
||||
currentWordType = None
|
||||
count += m.CJK
|
||||
continue
|
||||
}
|
||||
|
||||
// 3. 处理Emoji - 使用专门的Emoji权重
|
||||
if isEmoji(r) {
|
||||
currentWordType = None
|
||||
count += m.Emoji
|
||||
continue
|
||||
}
|
||||
|
||||
// 4. 处理拉丁字母/数字 (英文单词)
|
||||
if isLatinOrNumber(r) {
|
||||
isNum := unicode.IsNumber(r)
|
||||
newType := Latin
|
||||
if isNum {
|
||||
newType = Number
|
||||
}
|
||||
|
||||
// 如果之前不在单词中,或者类型发生变化(字母<->数字),则视为新token
|
||||
// 注意:对于OpenAI,通常"version 3.5"会切分,"abc123xyz"有时也会切分
|
||||
// 这里简单起见,字母和数字切换时增加权重
|
||||
if currentWordType == None || currentWordType != newType {
|
||||
if newType == Number {
|
||||
count += m.Number
|
||||
} else {
|
||||
count += m.Word
|
||||
}
|
||||
currentWordType = newType
|
||||
}
|
||||
// 单词中间的字符不额外计费
|
||||
continue
|
||||
}
|
||||
|
||||
// 5. 处理标点符号/特殊字符 - 按类型使用不同权重
|
||||
currentWordType = None
|
||||
if isMathSymbol(r) {
|
||||
count += m.MathSymbol
|
||||
} else if r == '@' {
|
||||
count += m.AtSign
|
||||
} else if isURLDelim(r) {
|
||||
count += m.URLDelim
|
||||
} else {
|
||||
count += m.Symbol
|
||||
}
|
||||
}
|
||||
|
||||
// 向上取整并加上基础 padding
|
||||
return int(math.Ceil(count)) + m.BasePad
|
||||
}
|
||||
|
||||
// 辅助:判断是否为 CJK 字符
|
||||
func isCJK(r rune) bool {
|
||||
return unicode.Is(unicode.Han, r) ||
|
||||
(r >= 0x3040 && r <= 0x30FF) || // 日文
|
||||
(r >= 0xAC00 && r <= 0xD7A3) // 韩文
|
||||
}
|
||||
|
||||
// 辅助:判断是否为单词主体 (字母或数字)
|
||||
func isLatinOrNumber(r rune) bool {
|
||||
return unicode.IsLetter(r) || unicode.IsNumber(r)
|
||||
}
|
||||
|
||||
// 辅助:判断是否为Emoji字符
|
||||
func isEmoji(r rune) bool {
|
||||
// Emoji的Unicode范围
|
||||
// 基本范围:0x1F300-0x1F9FF (Emoticons, Symbols, Pictographs)
|
||||
// 补充范围:0x2600-0x26FF (Misc Symbols), 0x2700-0x27BF (Dingbats)
|
||||
// 表情符号:0x1F600-0x1F64F (Emoticons)
|
||||
// 其他:0x1F900-0x1F9FF (Supplemental Symbols and Pictographs)
|
||||
return (r >= 0x1F300 && r <= 0x1F9FF) ||
|
||||
(r >= 0x2600 && r <= 0x26FF) ||
|
||||
(r >= 0x2700 && r <= 0x27BF) ||
|
||||
(r >= 0x1F600 && r <= 0x1F64F) ||
|
||||
(r >= 0x1F900 && r <= 0x1F9FF) ||
|
||||
(r >= 0x1FA00 && r <= 0x1FAFF) // Symbols and Pictographs Extended-A
|
||||
}
|
||||
|
||||
// 辅助:判断是否为数学符号
|
||||
func isMathSymbol(r rune) bool {
|
||||
// 数学运算符和符号
|
||||
// 基本数学符号:∑ ∫ ∂ √ ∞ ≤ ≥ ≠ ≈ ± × ÷
|
||||
// 上下标数字:² ³ ¹ ⁴ ⁵ ⁶ ⁷ ⁸ ⁹ ⁰
|
||||
// 希腊字母等也常用于数学
|
||||
mathSymbols := "∑∫∂√∞≤≥≠≈±×÷∈∉∋∌⊂⊃⊆⊇∪∩∧∨¬∀∃∄∅∆∇∝∟∠∡∢°′″‴⁺⁻⁼⁽⁾ⁿ₀₁₂₃₄₅₆₇₈₉₊₋₌₍₎²³¹⁴⁵⁶⁷⁸⁹⁰"
|
||||
for _, m := range mathSymbols {
|
||||
if r == m {
|
||||
return true
|
||||
}
|
||||
}
|
||||
// Mathematical Operators (U+2200–U+22FF)
|
||||
if r >= 0x2200 && r <= 0x22FF {
|
||||
return true
|
||||
}
|
||||
// Supplemental Mathematical Operators (U+2A00–U+2AFF)
|
||||
if r >= 0x2A00 && r <= 0x2AFF {
|
||||
return true
|
||||
}
|
||||
// Mathematical Alphanumeric Symbols (U+1D400–U+1D7FF)
|
||||
if r >= 0x1D400 && r <= 0x1D7FF {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// 辅助:判断是否为URL分隔符(tokenizer对这些优化较好)
|
||||
func isURLDelim(r rune) bool {
|
||||
// URL中常见的分隔符,tokenizer通常优化处理
|
||||
urlDelims := "/:?&=;#%"
|
||||
for _, d := range urlDelims {
|
||||
if r == d {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func EstimateTokenByModel(model, text string) int {
|
||||
// strings.Contains(model, "gpt-4o")
|
||||
if text == "" {
|
||||
return 0
|
||||
}
|
||||
|
||||
model = strings.ToLower(model)
|
||||
if strings.Contains(model, "gemini") {
|
||||
return EstimateToken(Gemini, text)
|
||||
} else if strings.Contains(model, "claude") {
|
||||
return EstimateToken(Claude, text)
|
||||
} else {
|
||||
return EstimateToken(OpenAI, text)
|
||||
}
|
||||
}
|
||||
63
service/tokenizer.go
Normal file
63
service/tokenizer.go
Normal file
@@ -0,0 +1,63 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"sync"
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
"github.com/tiktoken-go/tokenizer"
|
||||
"github.com/tiktoken-go/tokenizer/codec"
|
||||
)
|
||||
|
||||
// tokenEncoderMap won't grow after initialization
|
||||
var defaultTokenEncoder tokenizer.Codec
|
||||
|
||||
// tokenEncoderMap is used to store token encoders for different models
|
||||
var tokenEncoderMap = make(map[string]tokenizer.Codec)
|
||||
|
||||
// tokenEncoderMutex protects tokenEncoderMap for concurrent access
|
||||
var tokenEncoderMutex sync.RWMutex
|
||||
|
||||
func InitTokenEncoders() {
|
||||
common.SysLog("initializing token encoders")
|
||||
defaultTokenEncoder = codec.NewCl100kBase()
|
||||
common.SysLog("token encoders initialized")
|
||||
}
|
||||
|
||||
func getTokenEncoder(model string) tokenizer.Codec {
|
||||
// First, try to get the encoder from cache with read lock
|
||||
tokenEncoderMutex.RLock()
|
||||
if encoder, exists := tokenEncoderMap[model]; exists {
|
||||
tokenEncoderMutex.RUnlock()
|
||||
return encoder
|
||||
}
|
||||
tokenEncoderMutex.RUnlock()
|
||||
|
||||
// If not in cache, create new encoder with write lock
|
||||
tokenEncoderMutex.Lock()
|
||||
defer tokenEncoderMutex.Unlock()
|
||||
|
||||
// Double-check if another goroutine already created the encoder
|
||||
if encoder, exists := tokenEncoderMap[model]; exists {
|
||||
return encoder
|
||||
}
|
||||
|
||||
// Create new encoder
|
||||
modelCodec, err := tokenizer.ForModel(tokenizer.Model(model))
|
||||
if err != nil {
|
||||
// Cache the default encoder for this model to avoid repeated failures
|
||||
tokenEncoderMap[model] = defaultTokenEncoder
|
||||
return defaultTokenEncoder
|
||||
}
|
||||
|
||||
// Cache the new encoder
|
||||
tokenEncoderMap[model] = modelCodec
|
||||
return modelCodec
|
||||
}
|
||||
|
||||
func getTokenNum(tokenEncoder tokenizer.Codec, text string) int {
|
||||
if text == "" {
|
||||
return 0
|
||||
}
|
||||
tkm, _ := tokenEncoder.Count(text)
|
||||
return tkm
|
||||
}
|
||||
@@ -23,8 +23,7 @@ func ResponseText2Usage(c *gin.Context, responseText string, modeName string, pr
|
||||
common.SetContextKey(c, constant.ContextKeyLocalCountTokens, true)
|
||||
usage := &dto.Usage{}
|
||||
usage.PromptTokens = promptTokens
|
||||
ctkm := CountTextToken(responseText, modeName)
|
||||
usage.CompletionTokens = ctkm
|
||||
usage.CompletionTokens = EstimateTokenByModel(modeName, responseText)
|
||||
usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
|
||||
return usage
|
||||
}
|
||||
|
||||
@@ -32,7 +32,7 @@ func GetGlobalSettings() *GlobalSettings {
|
||||
return &globalSettings
|
||||
}
|
||||
|
||||
// ShouldPreserveThinkingSuffix 判断模型是否配置为保留 thinking/-nothinking 后缀
|
||||
// ShouldPreserveThinkingSuffix 判断模型是否配置为保留 thinking/-nothinking/-low/-high/-medium 后缀
|
||||
func ShouldPreserveThinkingSuffix(modelName string) bool {
|
||||
target := strings.TrimSpace(modelName)
|
||||
if target == "" {
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
"github.com/QuantumNous/new-api/setting/operation_setting"
|
||||
"github.com/QuantumNous/new-api/setting/reasoning"
|
||||
)
|
||||
|
||||
// from songquanpeng/one-api
|
||||
@@ -821,6 +822,10 @@ func FormatMatchingModelName(name string) string {
|
||||
name = handleThinkingBudgetModel(name, "gemini-2.5-pro", "gemini-2.5-pro-thinking-*")
|
||||
}
|
||||
|
||||
if base, _, ok := reasoning.TrimEffortSuffix(name); ok {
|
||||
name = base
|
||||
}
|
||||
|
||||
if strings.HasPrefix(name, "gpt-4-gizmo") {
|
||||
name = "gpt-4-gizmo-*"
|
||||
}
|
||||
|
||||
20
setting/reasoning/suffix.go
Normal file
20
setting/reasoning/suffix.go
Normal file
@@ -0,0 +1,20 @@
|
||||
package reasoning
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"github.com/samber/lo"
|
||||
)
|
||||
|
||||
var EffortSuffixes = []string{"-high", "-medium", "-low"}
|
||||
|
||||
// TrimEffortSuffix -> modelName level(low) exists
|
||||
func TrimEffortSuffix(modelName string) (string, string, bool) {
|
||||
suffix, found := lo.Find(EffortSuffixes, func(s string) bool {
|
||||
return strings.HasSuffix(modelName, s)
|
||||
})
|
||||
if !found {
|
||||
return modelName, "", false
|
||||
}
|
||||
return strings.TrimSuffix(modelName, suffix), strings.TrimPrefix(suffix, "-"), true
|
||||
}
|
||||
@@ -39,6 +39,7 @@ import {
|
||||
TASK_ACTION_GENERATE,
|
||||
TASK_ACTION_REFERENCE_GENERATE,
|
||||
TASK_ACTION_TEXT_GENERATE,
|
||||
TASK_ACTION_REMIX_GENERATE,
|
||||
} from '../../../constants/common.constant';
|
||||
import { CHANNEL_OPTIONS } from '../../../constants/channel.constants';
|
||||
|
||||
@@ -125,6 +126,12 @@ const renderType = (type, t) => {
|
||||
{t('参照生视频')}
|
||||
</Tag>
|
||||
);
|
||||
case TASK_ACTION_REMIX_GENERATE:
|
||||
return (
|
||||
<Tag color='blue' shape='circle' prefixIcon={<Sparkles size={14} />}>
|
||||
{t('remix生视频')}
|
||||
</Tag>
|
||||
);
|
||||
default:
|
||||
return (
|
||||
<Tag color='white' shape='circle' prefixIcon={<HelpCircle size={14} />}>
|
||||
@@ -359,7 +366,8 @@ export const getTaskLogsColumns = ({
|
||||
record.action === TASK_ACTION_GENERATE ||
|
||||
record.action === TASK_ACTION_TEXT_GENERATE ||
|
||||
record.action === TASK_ACTION_FIRST_TAIL_GENERATE ||
|
||||
record.action === TASK_ACTION_REFERENCE_GENERATE;
|
||||
record.action === TASK_ACTION_REFERENCE_GENERATE ||
|
||||
record.action === TASK_ACTION_REMIX_GENERATE;
|
||||
const isSuccess = record.status === 'SUCCESS';
|
||||
const isUrl = typeof text === 'string' && /^https?:\/\//.test(text);
|
||||
if (isSuccess && isVideoTask && isUrl) {
|
||||
|
||||
@@ -42,3 +42,4 @@ export const TASK_ACTION_GENERATE = 'generate';
|
||||
export const TASK_ACTION_TEXT_GENERATE = 'textGenerate';
|
||||
export const TASK_ACTION_FIRST_TAIL_GENERATE = 'firstTailGenerate';
|
||||
export const TASK_ACTION_REFERENCE_GENERATE = 'referenceGenerate';
|
||||
export const TASK_ACTION_REMIX_GENERATE = 'remixGenerate';
|
||||
|
||||
@@ -1996,7 +1996,7 @@
|
||||
"适用于个人使用的场景,不需要设置模型价格": "Suitable for personal use, no need to set model price.",
|
||||
"适用于为多个用户提供服务的场景": "Suitable for scenarios where multiple users are provided.",
|
||||
"适用于展示系统功能的场景,提供基础功能演示": "Suitable for scenarios where the system functions are displayed, providing basic feature demonstrations.",
|
||||
"适配 -thinking、-thinking-预算数字 和 -nothinking 后缀": "Adapt to -thinking, -thinking-budget number, and -nothinking suffixes",
|
||||
"适配 -thinking、-thinking-预算数字 和 -nothinking 后缀": "Adapt to -thinking, -thinking-budget number, -nothinking, and -low/-medium/-high suffixes",
|
||||
"选择充值套餐": "Choose a top-up package",
|
||||
"选择充值额度": "Select recharge amount",
|
||||
"选择分组": "Select group",
|
||||
@@ -2178,4 +2178,4 @@
|
||||
"默认测试模型": "Default Test Model",
|
||||
"默认补全倍率": "Default completion ratio"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2006,7 +2006,7 @@
|
||||
"适用于个人使用的场景,不需要设置模型价格": "Adapté à un usage personnel, pas besoin de définir le prix du modèle.",
|
||||
"适用于为多个用户提供服务的场景": "Adapté aux scénarios où plusieurs utilisateurs sont fournis.",
|
||||
"适用于展示系统功能的场景,提供基础功能演示": "Adapté aux scénarios où les fonctions du système sont affichées, fournissant des démonstrations de fonctionnalités de base.",
|
||||
"适配 -thinking、-thinking-预算数字 和 -nothinking 后缀": "Adapter les suffixes -thinking, -thinking-budget et -nothinking",
|
||||
"适配 -thinking、-thinking-预算数字 和 -nothinking 后缀": "Adapter les suffixes -thinking, -thinking-budget, -nothinking et -low/-medium/-high",
|
||||
"选择充值额度": "Sélectionner le montant de la recharge",
|
||||
"选择分组": "Sélectionner un groupe",
|
||||
"选择同步来源": "Sélectionner la source de synchronisation",
|
||||
@@ -2227,4 +2227,4 @@
|
||||
"随机种子 (留空为随机)": "Graine aléatoire (laisser vide pour aléatoire)",
|
||||
"默认补全倍率": "Taux de complétion par défaut"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1903,7 +1903,7 @@
|
||||
"适用于个人使用的场景,不需要设置模型价格": "個人利用のシナリオに適しており、モデル料金の設定は不要です",
|
||||
"适用于为多个用户提供服务的场景": "複数のユーザーにサービスを提供するシナリオに適しています",
|
||||
"适用于展示系统功能的场景,提供基础功能演示": "システムの機能を紹介するシナリオに適しており、基本的な機能のデモンストレーションを提供します",
|
||||
"适配 -thinking、-thinking-预算数字 和 -nothinking 后缀": "-thinking、-thinking-予算数値、-nothinkingサフィックスに対応",
|
||||
"适配 -thinking、-thinking-预算数字 和 -nothinking 后缀": "-thinking、-thinking-予算数値、-nothinking、および -low/-medium/-high サフィックスに対応",
|
||||
"选择充值额度": "チャージ額を選択",
|
||||
"选择分组": "グループを選択",
|
||||
"选择同步来源": "同期ソースを選択",
|
||||
@@ -2126,4 +2126,4 @@
|
||||
"可选,用于复现结果": "オプション、結果の再現用",
|
||||
"随机种子 (留空为随机)": "ランダムシード(空欄でランダム)"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2017,7 +2017,7 @@
|
||||
"适用于个人使用的场景,不需要设置模型价格": "Подходит для сценариев личного использования, не требует установки цен на модели",
|
||||
"适用于为多个用户提供服务的场景": "Подходит для сценариев предоставления услуг нескольким пользователям",
|
||||
"适用于展示系统功能的场景,提供基础功能演示": "Подходит для сценариев демонстрации системных функций, предоставляет демонстрацию базовых функций",
|
||||
"适配 -thinking、-thinking-预算数字 和 -nothinking 后缀": "Адаптация суффиксов -thinking, -thinking-бюджетные-цифры и -nothinking",
|
||||
"适配 -thinking、-thinking-预算数字 和 -nothinking 后缀": "Адаптация суффиксов -thinking, -thinking-бюджетные-цифры, -nothinking и -low/-medium/-high",
|
||||
"选择充值额度": "Выберите сумму пополнения",
|
||||
"选择分组": "Выберите группу",
|
||||
"选择同步来源": "Выберите источник синхронизации",
|
||||
@@ -2237,4 +2237,4 @@
|
||||
"可选,用于复现结果": "Необязательно, для воспроизводимых результатов",
|
||||
"随机种子 (留空为随机)": "Случайное зерно (оставьте пустым для случайного)"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2197,7 +2197,7 @@
|
||||
"适用于个人使用的场景,不需要设置模型价格": "Phù hợp cho mục đích sử dụng cá nhân, không cần đặt giá mô hình.",
|
||||
"适用于为多个用户提供服务的场景": "Phù hợp cho các kịch bản cung cấp dịch vụ cho nhiều người dùng.",
|
||||
"适用于展示系统功能的场景,提供基础功能演示": "Phù hợp cho các kịch bản hiển thị chức năng hệ thống, cung cấp bản demo chức năng cơ bản.",
|
||||
"适配 -thinking、-thinking-预算数字 和 -nothinking 后缀": "Thích ứng với các hậu tố -thinking, -thinking-budget number và -nothinking",
|
||||
"适配 -thinking、-thinking-预算数字 和 -nothinking 后缀": "Thích ứng với các hậu tố -thinking, -thinking-budget number, -nothinking và -low/-medium/-high",
|
||||
"选择充值额度": "Chọn hạn ngạch nạp tiền",
|
||||
"选择同步来源": "Chọn nguồn đồng bộ",
|
||||
"选择同步渠道": "Chọn kênh đồng bộ",
|
||||
@@ -2737,4 +2737,4 @@
|
||||
"可选,用于复现结果": "Tùy chọn, để tái tạo kết quả",
|
||||
"随机种子 (留空为随机)": "Hạt giống ngẫu nhiên (để trống cho ngẫu nhiên)"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1984,7 +1984,7 @@
|
||||
"适用于个人使用的场景,不需要设置模型价格": "适用于个人使用的场景,不需要设置模型价格",
|
||||
"适用于为多个用户提供服务的场景": "适用于为多个用户提供服务的场景",
|
||||
"适用于展示系统功能的场景,提供基础功能演示": "适用于展示系统功能的场景,提供基础功能演示",
|
||||
"适配 -thinking、-thinking-预算数字 和 -nothinking 后缀": "适配 -thinking、-thinking-预算数字 和 -nothinking 后缀",
|
||||
"适配 -thinking、-thinking-预算数字 和 -nothinking 后缀": "适配 -thinking、-thinking-预算数字、-nothinking 以及 -low/-medium/-high 后缀",
|
||||
"选择充值额度": "选择充值额度",
|
||||
"选择分组": "选择分组",
|
||||
"选择同步来源": "选择同步来源",
|
||||
@@ -2204,4 +2204,4 @@
|
||||
"可选,用于复现结果": "可选,用于复现结果",
|
||||
"随机种子 (留空为随机)": "随机种子 (留空为随机)"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user