mirror of
https://github.com/QuantumNous/new-api.git
synced 2026-04-03 07:51:07 +00:00
Compare commits
65 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
4e69c98b42 | ||
|
|
ca29fc5702 | ||
|
|
fca015c6c4 | ||
|
|
23292a5ae9 | ||
|
|
e346f0bf16 | ||
|
|
cae05c068c | ||
|
|
78c10209c0 | ||
|
|
4ffd54c50d | ||
|
|
08466358b2 | ||
|
|
5212fbd73d | ||
|
|
b0e120dcab | ||
|
|
9561c7b50f | ||
|
|
1cb2b6f882 | ||
|
|
5889571108 | ||
|
|
2e33948842 | ||
|
|
d1aaa07ad7 | ||
|
|
ea70c20f8e | ||
|
|
c7539d11a0 | ||
|
|
3ebc713327 | ||
|
|
72d2a94b0d | ||
|
|
12a5c7ce5e | ||
|
|
5eae6a3874 | ||
|
|
7b108a6900 | ||
|
|
3d282ac548 | ||
|
|
121746a79e | ||
|
|
c3c119a9b4 | ||
|
|
6d6e5b3337 | ||
|
|
d64205e35a | ||
|
|
0b9f6a58bc | ||
|
|
293a5de0f8 | ||
|
|
c07347f24f | ||
|
|
896e4ac671 | ||
|
|
7d1bad1b37 | ||
|
|
8e7be25429 | ||
|
|
2e37347851 | ||
|
|
45556c961f | ||
|
|
ffc45a756e | ||
|
|
48635360cd | ||
|
|
e7e5cc2c05 | ||
|
|
0c051e968f | ||
|
|
f5b409d74f | ||
|
|
509d1f633a | ||
|
|
0c6d890f6e | ||
|
|
2f7eebcd10 | ||
|
|
3954feb993 | ||
|
|
d3ca454c3b | ||
|
|
46aca8fad3 | ||
|
|
86aeb72549 | ||
|
|
4dbdbdec1d | ||
|
|
b6a02d8303 | ||
|
|
36a739e777 | ||
|
|
98f92f990a | ||
|
|
3f7ea1fd83 | ||
|
|
f6e7a2344b | ||
|
|
3257723a55 | ||
|
|
b19b2d62df | ||
|
|
f9c8624f2c | ||
|
|
6c8253156b | ||
|
|
a66b314f5b | ||
|
|
e29ff0060d | ||
|
|
d4a2c2ab54 | ||
|
|
ded463ee57 | ||
|
|
e337936227 | ||
|
|
c6125eccb1 | ||
|
|
138810f19c |
@@ -14,7 +14,7 @@ ENV GO111MODULE=on CGO_ENABLED=0
|
||||
ARG TARGETOS
|
||||
ARG TARGETARCH
|
||||
ENV GOOS=${TARGETOS:-linux} GOARCH=${TARGETARCH:-amd64}
|
||||
|
||||
ENV GOEXPERIMENT=greenteagc
|
||||
|
||||
WORKDIR /build
|
||||
|
||||
@@ -25,10 +25,11 @@ COPY . .
|
||||
COPY --from=builder /build/dist ./web/dist
|
||||
RUN go build -ldflags "-s -w -X 'github.com/QuantumNous/new-api/common.Version=$(cat VERSION)'" -o new-api
|
||||
|
||||
FROM alpine
|
||||
FROM debian:bookworm-slim
|
||||
|
||||
RUN apk upgrade --no-cache \
|
||||
&& apk add --no-cache ca-certificates tzdata \
|
||||
RUN apt-get update \
|
||||
&& apt-get install -y --no-install-recommends ca-certificates tzdata libasan8 \
|
||||
&& rm -rf /var/lib/apt/lists/* \
|
||||
&& update-ca-certificates
|
||||
|
||||
COPY --from=builder2 /build/new-api /
|
||||
|
||||
@@ -238,6 +238,7 @@ docker run --name new-api -d --restart always \
|
||||
- `gemini-2.5-flash-nothinking` - Disable thinking mode
|
||||
- `gemini-2.5-pro-thinking` - Enable thinking mode
|
||||
- `gemini-2.5-pro-thinking-128` - Enable thinking mode with thinking budget of 128 tokens
|
||||
- You can also append `-low`, `-medium`, or `-high` to any Gemini model name to request the corresponding reasoning effort (no extra thinking-budget suffix needed).
|
||||
|
||||
</details>
|
||||
|
||||
|
||||
@@ -234,6 +234,7 @@ docker run --name new-api -d --restart always \
|
||||
- `gemini-2.5-flash-nothinking` - Désactiver le mode de pensée
|
||||
- `gemini-2.5-pro-thinking` - Activer le mode de pensée
|
||||
- `gemini-2.5-pro-thinking-128` - Activer le mode de pensée avec budget de pensée de 128 tokens
|
||||
- Vous pouvez également ajouter les suffixes `-low`, `-medium` ou `-high` aux modèles Gemini pour fixer le niveau d’effort de raisonnement (sans suffixe de budget supplémentaire).
|
||||
|
||||
</details>
|
||||
|
||||
|
||||
@@ -243,6 +243,7 @@ docker run --name new-api -d --restart always \
|
||||
- `gemini-2.5-flash-nothinking` - 思考モードを無効にする
|
||||
- `gemini-2.5-pro-thinking` - 思考モードを有効にする
|
||||
- `gemini-2.5-pro-thinking-128` - 思考モードを有効にし、思考予算を128トークンに設定する
|
||||
- Gemini モデル名の末尾に `-low` / `-medium` / `-high` を付けることで推論強度を直接指定できます(追加の思考予算サフィックスは不要です)。
|
||||
|
||||
</details>
|
||||
|
||||
|
||||
@@ -239,6 +239,7 @@ docker run --name new-api -d --restart always \
|
||||
- `gemini-2.5-flash-nothinking` - 禁用思考模式
|
||||
- `gemini-2.5-pro-thinking` - 启用思考模式
|
||||
- `gemini-2.5-pro-thinking-128` - 启用思考模式,并设置思考预算为128tokens
|
||||
- 也可以直接在 Gemini 模型名称后追加 `-low` / `-medium` / `-high` 来控制思考力度(无需再设置思考预算后缀)
|
||||
|
||||
</details>
|
||||
|
||||
|
||||
@@ -121,6 +121,9 @@ var BatchUpdateInterval int
|
||||
|
||||
var RelayTimeout int // unit is second
|
||||
|
||||
var RelayMaxIdleConns int
|
||||
var RelayMaxIdleConnsPerHost int
|
||||
|
||||
var GeminiSafetySetting string
|
||||
|
||||
// https://docs.cohere.com/docs/safety-modes Type; NONE/CONTEXTUAL/STRICT
|
||||
|
||||
@@ -90,6 +90,8 @@ func InitEnv() {
|
||||
SyncFrequency = GetEnvOrDefault("SYNC_FREQUENCY", 60)
|
||||
BatchUpdateInterval = GetEnvOrDefault("BATCH_UPDATE_INTERVAL", 5)
|
||||
RelayTimeout = GetEnvOrDefault("RELAY_TIMEOUT", 0)
|
||||
RelayMaxIdleConns = GetEnvOrDefault("RELAY_MAX_IDLE_CONNS", 500)
|
||||
RelayMaxIdleConnsPerHost = GetEnvOrDefault("RELAY_MAX_IDLE_CONNS_PER_HOST", 100)
|
||||
|
||||
// Initialize string variables with GetEnvOrDefaultString
|
||||
GeminiSafetySetting = GetEnvOrDefaultString("GEMINI_SAFETY_SETTING", "BLOCK_NONE")
|
||||
@@ -129,6 +131,8 @@ func initConstantEnv() {
|
||||
constant.GenerateDefaultToken = GetEnvOrDefaultBool("GENERATE_DEFAULT_TOKEN", false)
|
||||
// 是否启用错误日志
|
||||
constant.ErrorLogEnabled = GetEnvOrDefaultBool("ERROR_LOG_ENABLED", false)
|
||||
// 任务轮询时查询的最大数量
|
||||
constant.TaskQueryLimit = GetEnvOrDefault("TASK_QUERY_LIMIT", 1000)
|
||||
|
||||
soraPatchStr := GetEnvOrDefaultString("TASK_PRICE_PATCH", "")
|
||||
if soraPatchStr != "" {
|
||||
|
||||
@@ -23,11 +23,11 @@ func Marshal(v any) ([]byte, error) {
|
||||
}
|
||||
|
||||
func GetJsonType(data json.RawMessage) string {
|
||||
data = bytes.TrimSpace(data)
|
||||
if len(data) == 0 {
|
||||
trimmed := bytes.TrimSpace(data)
|
||||
if len(trimmed) == 0 {
|
||||
return "unknown"
|
||||
}
|
||||
firstChar := bytes.TrimSpace(data)[0]
|
||||
firstChar := trimmed[0]
|
||||
switch firstChar {
|
||||
case '{':
|
||||
return "object"
|
||||
|
||||
@@ -17,6 +17,13 @@ var (
|
||||
"flux-",
|
||||
"flux.1-",
|
||||
}
|
||||
OpenAITextModels = []string{
|
||||
"gpt-",
|
||||
"o1",
|
||||
"o3",
|
||||
"o4",
|
||||
"chatgpt",
|
||||
}
|
||||
)
|
||||
|
||||
func IsOpenAIResponseOnlyModel(modelName string) bool {
|
||||
@@ -40,3 +47,13 @@ func IsImageGenerationModel(modelName string) bool {
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func IsOpenAITextModel(modelName string) bool {
|
||||
modelName = strings.ToLower(modelName)
|
||||
for _, m := range OpenAITextModels {
|
||||
if strings.Contains(modelName, m) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -3,12 +3,19 @@ package common
|
||||
import (
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"math/rand"
|
||||
"net/url"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
"unsafe"
|
||||
|
||||
"github.com/samber/lo"
|
||||
)
|
||||
|
||||
var (
|
||||
maskURLPattern = regexp.MustCompile(`(http|https)://[^\s/$.?#].[^\s]*`)
|
||||
maskDomainPattern = regexp.MustCompile(`\b(?:[a-zA-Z0-9](?:[a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?\.)+[a-zA-Z]{2,}\b`)
|
||||
maskIPPattern = regexp.MustCompile(`\b(?:\d{1,3}\.){3}\d{1,3}\b`)
|
||||
)
|
||||
|
||||
func GetStringIfEmpty(str string, defaultValue string) string {
|
||||
@@ -19,12 +26,10 @@ func GetStringIfEmpty(str string, defaultValue string) string {
|
||||
}
|
||||
|
||||
func GetRandomString(length int) string {
|
||||
//rand.Seed(time.Now().UnixNano())
|
||||
key := make([]byte, length)
|
||||
for i := 0; i < length; i++ {
|
||||
key[i] = keyChars[rand.Intn(len(keyChars))]
|
||||
if length <= 0 {
|
||||
return ""
|
||||
}
|
||||
return string(key)
|
||||
return lo.RandomString(length, lo.AlphanumericCharset)
|
||||
}
|
||||
|
||||
func MapToJsonStr(m map[string]interface{}) string {
|
||||
@@ -170,8 +175,7 @@ func maskHostForPlainDomain(domain string) string {
|
||||
// api.openai.com -> ***.***.com
|
||||
func MaskSensitiveInfo(str string) string {
|
||||
// Mask URLs
|
||||
urlPattern := regexp.MustCompile(`(http|https)://[^\s/$.?#].[^\s]*`)
|
||||
str = urlPattern.ReplaceAllStringFunc(str, func(urlStr string) string {
|
||||
str = maskURLPattern.ReplaceAllStringFunc(str, func(urlStr string) string {
|
||||
u, err := url.Parse(urlStr)
|
||||
if err != nil {
|
||||
return urlStr
|
||||
@@ -224,14 +228,12 @@ func MaskSensitiveInfo(str string) string {
|
||||
})
|
||||
|
||||
// Mask domain names without protocol (like openai.com, www.openai.com)
|
||||
domainPattern := regexp.MustCompile(`\b(?:[a-zA-Z0-9](?:[a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?\.)+[a-zA-Z]{2,}\b`)
|
||||
str = domainPattern.ReplaceAllStringFunc(str, func(domain string) string {
|
||||
str = maskDomainPattern.ReplaceAllStringFunc(str, func(domain string) string {
|
||||
return maskHostForPlainDomain(domain)
|
||||
})
|
||||
|
||||
// Mask IP addresses
|
||||
ipPattern := regexp.MustCompile(`\b(?:\d{1,3}\.){3}\d{1,3}\b`)
|
||||
str = ipPattern.ReplaceAllString(str, "***.***.***.***")
|
||||
str = maskIPPattern.ReplaceAllString(str, "***.***.***.***")
|
||||
|
||||
return str
|
||||
}
|
||||
|
||||
@@ -3,8 +3,9 @@ package constant
|
||||
type ContextKey string
|
||||
|
||||
const (
|
||||
ContextKeyTokenCountMeta ContextKey = "token_count_meta"
|
||||
ContextKeyPromptTokens ContextKey = "prompt_tokens"
|
||||
ContextKeyTokenCountMeta ContextKey = "token_count_meta"
|
||||
ContextKeyPromptTokens ContextKey = "prompt_tokens"
|
||||
ContextKeyEstimatedTokens ContextKey = "estimated_tokens"
|
||||
|
||||
ContextKeyOriginalModel ContextKey = "original_model"
|
||||
ContextKeyRequestStartTime ContextKey = "request_start_time"
|
||||
|
||||
@@ -15,6 +15,7 @@ var NotifyLimitCount int
|
||||
var NotificationLimitDurationMinute int
|
||||
var GenerateDefaultToken bool
|
||||
var ErrorLogEnabled bool
|
||||
var TaskQueryLimit int
|
||||
|
||||
// temporary variable for sora patch, will be removed in future
|
||||
var TaskPricePatches []string
|
||||
|
||||
@@ -15,6 +15,7 @@ const (
|
||||
TaskActionTextGenerate = "textGenerate"
|
||||
TaskActionFirstTailGenerate = "firstTailGenerate"
|
||||
TaskActionReferenceGenerate = "referenceGenerate"
|
||||
TaskActionRemix = "remixGenerate"
|
||||
)
|
||||
|
||||
var SunoModel2Action = map[string]string{
|
||||
|
||||
@@ -351,7 +351,7 @@ func testChannel(channel *model.Channel, testModel string, endpointType string)
|
||||
newAPIError: types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError),
|
||||
}
|
||||
}
|
||||
info.PromptTokens = usage.PromptTokens
|
||||
info.SetEstimatePromptTokens(usage.PromptTokens)
|
||||
|
||||
quota := 0
|
||||
if !priceData.UsePrice {
|
||||
|
||||
@@ -165,6 +165,30 @@ func GetAllChannels(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
func buildFetchModelsHeaders(channel *model.Channel, key string) (http.Header, error) {
|
||||
var headers http.Header
|
||||
switch channel.Type {
|
||||
case constant.ChannelTypeAnthropic:
|
||||
headers = GetClaudeAuthHeader(key)
|
||||
default:
|
||||
headers = GetAuthHeader(key)
|
||||
}
|
||||
|
||||
headerOverride := channel.GetHeaderOverride()
|
||||
for k, v := range headerOverride {
|
||||
str, ok := v.(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid header override for key %s", k)
|
||||
}
|
||||
if strings.Contains(str, "{api_key}") {
|
||||
str = strings.ReplaceAll(str, "{api_key}", key)
|
||||
}
|
||||
headers.Set(k, str)
|
||||
}
|
||||
|
||||
return headers, nil
|
||||
}
|
||||
|
||||
func FetchUpstreamModels(c *gin.Context) {
|
||||
id, err := strconv.Atoi(c.Param("id"))
|
||||
if err != nil {
|
||||
@@ -223,14 +247,13 @@ func FetchUpstreamModels(c *gin.Context) {
|
||||
}
|
||||
key = strings.TrimSpace(key)
|
||||
|
||||
// 获取响应体 - 根据渠道类型决定是否添加 AuthHeader
|
||||
var body []byte
|
||||
switch channel.Type {
|
||||
case constant.ChannelTypeAnthropic:
|
||||
body, err = GetResponseBody("GET", url, channel, GetClaudeAuthHeader(key))
|
||||
default:
|
||||
body, err = GetResponseBody("GET", url, channel, GetAuthHeader(key))
|
||||
headers, err := buildFetchModelsHeaders(channel, key)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
body, err := GetResponseBody("GET", url, channel, headers)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
|
||||
@@ -125,13 +125,13 @@ func Relay(c *gin.Context, relayFormat types.RelayFormat) {
|
||||
}
|
||||
}
|
||||
|
||||
tokens, err := service.CountRequestToken(c, meta, relayInfo)
|
||||
tokens, err := service.EstimateRequestToken(c, meta, relayInfo)
|
||||
if err != nil {
|
||||
newAPIError = types.NewError(err, types.ErrorCodeCountTokenFailed)
|
||||
return
|
||||
}
|
||||
|
||||
relayInfo.SetPromptTokens(tokens)
|
||||
relayInfo.SetEstimatePromptTokens(tokens)
|
||||
|
||||
priceData, err := helper.ModelPriceHelper(c, relayInfo, tokens, meta)
|
||||
if err != nil {
|
||||
@@ -285,7 +285,7 @@ func processChannelError(c *gin.Context, channelError types.ChannelError, err *t
|
||||
logger.LogError(c, fmt.Sprintf("channel error (channel #%d, status code: %d): %s", channelError.ChannelId, err.StatusCode, err.Error()))
|
||||
// 不要使用context获取渠道信息,异步处理时可能会出现渠道信息不一致的情况
|
||||
// do not use context to get channel info, there may be inconsistent channel info when processing asynchronously
|
||||
if service.ShouldDisableChannel(channelError.ChannelId, err) && channelError.AutoBan {
|
||||
if service.ShouldDisableChannel(channelError.ChannelType, err) && channelError.AutoBan {
|
||||
gopool.Go(func() {
|
||||
service.DisableChannel(channelError, err.Error())
|
||||
})
|
||||
|
||||
@@ -29,7 +29,7 @@ func UpdateTaskBulk() {
|
||||
time.Sleep(time.Duration(15) * time.Second)
|
||||
common.SysLog("任务进度轮询开始")
|
||||
ctx := context.TODO()
|
||||
allTasks := model.GetAllUnFinishSyncTasks(500)
|
||||
allTasks := model.GetAllUnFinishSyncTasks(constant.TaskQueryLimit)
|
||||
platformTask := make(map[constant.TaskPlatform][]*model.Task)
|
||||
for _, t := range allTasks {
|
||||
platformTask[t.Platform] = append(platformTask[t.Platform], t)
|
||||
@@ -116,9 +116,10 @@ func updateSunoTaskAll(ctx context.Context, channelId int, taskIds []string, tas
|
||||
if adaptor == nil {
|
||||
return errors.New("adaptor not found")
|
||||
}
|
||||
proxy := channel.GetSetting().Proxy
|
||||
resp, err := adaptor.FetchTask(*channel.BaseURL, channel.Key, map[string]any{
|
||||
"ids": taskIds,
|
||||
})
|
||||
}, proxy)
|
||||
if err != nil {
|
||||
common.SysLog(fmt.Sprintf("Get Task Do req error: %v", err))
|
||||
return err
|
||||
|
||||
@@ -67,6 +67,7 @@ func updateVideoSingleTask(ctx context.Context, adaptor channel.TaskAdaptor, cha
|
||||
if channel.GetBaseURL() != "" {
|
||||
baseURL = channel.GetBaseURL()
|
||||
}
|
||||
proxy := channel.GetSetting().Proxy
|
||||
|
||||
task := taskM[taskId]
|
||||
if task == nil {
|
||||
@@ -76,7 +77,7 @@ func updateVideoSingleTask(ctx context.Context, adaptor channel.TaskAdaptor, cha
|
||||
resp, err := adaptor.FetchTask(baseURL, channel.Key, map[string]any{
|
||||
"task_id": taskId,
|
||||
"action": task.Action,
|
||||
})
|
||||
}, proxy)
|
||||
if err != nil {
|
||||
return fmt.Errorf("fetchTask failed for task %s: %w", taskId, err)
|
||||
}
|
||||
|
||||
@@ -142,7 +142,7 @@ func AddToken(c *gin.Context) {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
if len(token.Name) > 30 {
|
||||
if len(token.Name) > 50 {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "令牌名称过长",
|
||||
@@ -208,7 +208,7 @@ func UpdateToken(c *gin.Context) {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
if len(token.Name) > 30 {
|
||||
if len(token.Name) > 50 {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "令牌名称过长",
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
@@ -10,6 +11,7 @@ import (
|
||||
"github.com/QuantumNous/new-api/constant"
|
||||
"github.com/QuantumNous/new-api/logger"
|
||||
"github.com/QuantumNous/new-api/model"
|
||||
"github.com/QuantumNous/new-api/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
@@ -75,11 +77,22 @@ func VideoProxy(c *gin.Context) {
|
||||
}
|
||||
|
||||
var videoURL string
|
||||
client := &http.Client{
|
||||
Timeout: 60 * time.Second,
|
||||
proxy := channel.GetSetting().Proxy
|
||||
client, err := service.GetHttpClientWithProxy(proxy)
|
||||
if err != nil {
|
||||
logger.LogError(c.Request.Context(), fmt.Sprintf("Failed to create proxy client for task %s: %s", taskID, err.Error()))
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"error": gin.H{
|
||||
"message": "Failed to create proxy client",
|
||||
"type": "server_error",
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(c.Request.Context(), http.MethodGet, "", nil)
|
||||
ctx, cancel := context.WithTimeout(c.Request.Context(), 60*time.Second)
|
||||
defer cancel()
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, "", nil)
|
||||
if err != nil {
|
||||
logger.LogError(c.Request.Context(), fmt.Sprintf("Failed to create request: %s", err.Error()))
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
|
||||
@@ -35,10 +35,11 @@ func getGeminiVideoURL(channel *model.Channel, task *model.Task, apiKey string)
|
||||
return "", fmt.Errorf("api key not available for task")
|
||||
}
|
||||
|
||||
proxy := channel.GetSetting().Proxy
|
||||
resp, err := adaptor.FetchTask(baseURL, apiKey, map[string]any{
|
||||
"task_id": task.TaskID,
|
||||
"action": task.Action,
|
||||
})
|
||||
}, proxy)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("fetch task failed: %w", err)
|
||||
}
|
||||
|
||||
@@ -1,53 +0,0 @@
|
||||
# API 鉴权文档
|
||||
|
||||
## 认证方式
|
||||
|
||||
### Access Token
|
||||
|
||||
对于需要鉴权的 API 接口,必须同时提供以下两个请求头来进行 Access Token 认证:
|
||||
|
||||
1. **请求头中的 `Authorization` 字段**
|
||||
|
||||
将 Access Token 放置于 HTTP 请求头部的 `Authorization` 字段中,格式如下:
|
||||
|
||||
```
|
||||
Authorization: <your_access_token>
|
||||
```
|
||||
|
||||
其中 `<your_access_token>` 需要替换为实际的 Access Token 值。
|
||||
|
||||
2. **请求头中的 `New-Api-User` 字段**
|
||||
|
||||
将用户 ID 放置于 HTTP 请求头部的 `New-Api-User` 字段中,格式如下:
|
||||
|
||||
```
|
||||
New-Api-User: <your_user_id>
|
||||
```
|
||||
|
||||
其中 `<your_user_id>` 需要替换为实际的用户 ID。
|
||||
|
||||
**注意:**
|
||||
|
||||
* **必须同时提供 `Authorization` 和 `New-Api-User` 两个请求头才能通过鉴权。**
|
||||
* 如果只提供其中一个请求头,或者两个请求头都未提供,则会返回 `401 Unauthorized` 错误。
|
||||
* 如果 `Authorization` 中的 Access Token 无效,则会返回 `401 Unauthorized` 错误,并提示“无权进行此操作,access token 无效”。
|
||||
* 如果 `New-Api-User` 中的用户 ID 与 Access Token 不匹配,则会返回 `401 Unauthorized` 错误,并提示“无权进行此操作,与登录用户不匹配,请重新登录”。
|
||||
* 如果没有提供 `New-Api-User` 请求头,则会返回 `401 Unauthorized` 错误,并提示“无权进行此操作,未提供 New-Api-User”。
|
||||
* 如果 `New-Api-User` 请求头格式错误,则会返回 `401 Unauthorized` 错误,并提示“无权进行此操作,New-Api-User 格式错误”。
|
||||
* 如果用户已被禁用,则会返回 `403 Forbidden` 错误,并提示“用户已被封禁”。
|
||||
* 如果用户权限不足,则会返回 `403 Forbidden` 错误,并提示“无权进行此操作,权限不足”。
|
||||
* 如果用户信息无效,则会返回 `403 Forbidden` 错误,并提示“无权进行此操作,用户信息无效”。
|
||||
|
||||
## Curl 示例
|
||||
|
||||
假设您的 Access Token 为 `access_token`,用户 ID 为 `123`,要访问的 API 接口为 `/api/user/self`,则可以使用以下 curl 命令:
|
||||
|
||||
```bash
|
||||
curl -X GET \
|
||||
-H "Authorization: access_token" \
|
||||
-H "New-Api-User: 123" \
|
||||
https://your-domain.com/api/user/self
|
||||
```
|
||||
|
||||
请将 `access_token`、`123` 和 `https://your-domain.com` 替换为实际的值。
|
||||
|
||||
@@ -1,198 +0,0 @@
|
||||
# New API – Web 界面后端接口文档
|
||||
|
||||
> 本文档汇总了 **New API** 后端提供给前端 Web 界面的全部 REST 接口(不含 *Relay* 相关接口)。
|
||||
>
|
||||
> 接口前缀统一为 `https://<your-domain>`,以下仅列出 **路径**、**HTTP 方法**、**鉴权要求** 与 **功能简介**。
|
||||
>
|
||||
> 鉴权级别说明:
|
||||
> * **公开** – 不需要登录即可调用
|
||||
> * **用户** – 需携带用户 Token(`middleware.UserAuth`)
|
||||
> * **管理员** – 需管理员 Token(`middleware.AdminAuth`)
|
||||
> * **Root** – 仅限最高权限 Root 用户(`middleware.RootAuth`)
|
||||
|
||||
---
|
||||
|
||||
## 1. 初始化 / 系统状态
|
||||
| 方法 | 路径 | 鉴权 | 说明 |
|
||||
|------|------|------|------|
|
||||
| GET | /api/setup | 公开 | 获取系统初始化状态 |
|
||||
| POST | /api/setup | 公开 | 完成首次安装向导 |
|
||||
| GET | /api/status | 公开 | 获取运行状态摘要 |
|
||||
| GET | /api/uptime/status | 公开 | Uptime-Kuma 兼容状态探针 |
|
||||
| GET | /api/status/test | 管理员 | 测试后端与依赖组件是否正常 |
|
||||
|
||||
## 2. 公共信息
|
||||
| 方法 | 路径 | 鉴权 | 说明 |
|
||||
|------|------|------|------|
|
||||
| GET | /api/models | 用户 | 获取前端可用模型列表 |
|
||||
| GET | /api/notice | 公开 | 获取公告栏内容 |
|
||||
| GET | /api/about | 公开 | 关于页面信息 |
|
||||
| GET | /api/home_page_content | 公开 | 首页自定义内容 |
|
||||
| GET | /api/pricing | 可匿名/用户 | 价格与套餐信息 |
|
||||
| GET | /api/ratio_config | 公开 | 模型倍率配置(仅公开字段) |
|
||||
|
||||
## 3. 邮件 / 身份验证
|
||||
| 方法 | 路径 | 鉴权 | 说明 |
|
||||
|------|------|------|------|
|
||||
| GET | /api/verification | 公开 (限流) | 发送邮箱验证邮件 |
|
||||
| GET | /api/reset_password | 公开 (限流) | 发送重置密码邮件 |
|
||||
| POST | /api/user/reset | 公开 | 提交重置密码请求 |
|
||||
|
||||
## 4. OAuth / 第三方登录
|
||||
| 方法 | 路径 | 鉴权 | 说明 |
|
||||
|------|------|------|------|
|
||||
| GET | /api/oauth/github | 公开 | GitHub OAuth 跳转 |
|
||||
| GET | /api/oauth/discord | 公开 | Discord 通用 OAuth 跳转 |
|
||||
| GET | /api/oauth/oidc | 公开 | OIDC 通用 OAuth 跳转 |
|
||||
| GET | /api/oauth/linuxdo | 公开 | LinuxDo OAuth 跳转 |
|
||||
| GET | /api/oauth/wechat | 公开 | 微信扫码登录跳转 |
|
||||
| GET | /api/oauth/wechat/bind | 公开 | 微信账户绑定 |
|
||||
| GET | /api/oauth/email/bind | 公开 | 邮箱绑定 |
|
||||
| GET | /api/oauth/telegram/login | 公开 | Telegram 登录 |
|
||||
| GET | /api/oauth/telegram/bind | 公开 | Telegram 账户绑定 |
|
||||
| GET | /api/oauth/state | 公开 | 获取随机 state(防 CSRF) |
|
||||
|
||||
## 5. 用户模块
|
||||
### 5.1 账号注册/登录
|
||||
| 方法 | 路径 | 鉴权 | 说明 |
|
||||
|------|------|------|------|
|
||||
| POST | /api/user/register | 公开 | 注册新账号 |
|
||||
| POST | /api/user/login | 公开 | 用户登录 |
|
||||
| GET | /api/user/logout | 用户 | 退出登录 |
|
||||
| GET | /api/user/epay/notify | 公开 | Epay 支付回调 |
|
||||
| GET | /api/user/groups | 公开 | 列出所有分组(无鉴权版) |
|
||||
|
||||
### 5.2 用户自身操作 (需登录)
|
||||
| 方法 | 路径 | 鉴权 | 说明 |
|
||||
|------|------|------|------|
|
||||
| GET | /api/user/self/groups | 用户 | 获取自己所在分组 |
|
||||
| GET | /api/user/self | 用户 | 获取个人资料 |
|
||||
| GET | /api/user/models | 用户 | 获取模型可见性 |
|
||||
| PUT | /api/user/self | 用户 | 修改个人资料 |
|
||||
| DELETE | /api/user/self | 用户 | 注销账号 |
|
||||
| GET | /api/user/token | 用户 | 生成用户级别 Access Token |
|
||||
| GET | /api/user/aff | 用户 | 获取推广码信息 |
|
||||
| POST | /api/user/topup | 用户 | 余额直充 |
|
||||
| POST | /api/user/pay | 用户 | 提交支付订单 |
|
||||
| POST | /api/user/amount | 用户 | 余额支付 |
|
||||
| POST | /api/user/aff_transfer | 用户 | 推广额度转账 |
|
||||
| PUT | /api/user/setting | 用户 | 更新用户设置 |
|
||||
|
||||
### 5.3 管理员用户管理
|
||||
| 方法 | 路径 | 鉴权 | 说明 |
|
||||
|------|------|------|------|
|
||||
| GET | /api/user/ | 管理员 | 获取全部用户列表 |
|
||||
| GET | /api/user/search | 管理员 | 搜索用户 |
|
||||
| GET | /api/user/:id | 管理员 | 获取单个用户信息 |
|
||||
| POST | /api/user/ | 管理员 | 创建用户 |
|
||||
| POST | /api/user/manage | 管理员 | 冻结/重置等管理操作 |
|
||||
| PUT | /api/user/ | 管理员 | 更新用户 |
|
||||
| DELETE | /api/user/:id | 管理员 | 删除用户 |
|
||||
|
||||
## 6. 站点选项 (Root)
|
||||
| 方法 | 路径 | 鉴权 | 说明 |
|
||||
|------|------|------|------|
|
||||
| GET | /api/option/ | Root | 获取全局配置 |
|
||||
| PUT | /api/option/ | Root | 更新全局配置 |
|
||||
| POST | /api/option/rest_model_ratio | Root | 重置模型倍率 |
|
||||
| POST | /api/option/migrate_console_setting | Root | 迁移旧版控制台配置 |
|
||||
|
||||
## 7. 模型倍率同步 (Root)
|
||||
| 方法 | 路径 | 鉴权 | 说明 |
|
||||
|------|------|------|------|
|
||||
| GET | /api/ratio_sync/channels | Root | 获取可同步渠道列表 |
|
||||
| POST | /api/ratio_sync/fetch | Root | 从上游拉取倍率 |
|
||||
|
||||
## 8. 渠道管理 (管理员)
|
||||
| 方法 | 路径 | 说明 |
|
||||
|------|------|------|
|
||||
| GET | /api/channel/ | 获取渠道列表 |
|
||||
| GET | /api/channel/search | 搜索渠道 |
|
||||
| GET | /api/channel/models | 查询渠道模型能力 |
|
||||
| GET | /api/channel/models_enabled | 查询启用模型能力 |
|
||||
| GET | /api/channel/:id | 获取单个渠道 |
|
||||
| GET | /api/channel/test | 批量测试渠道连通性 |
|
||||
| GET | /api/channel/test/:id | 单个渠道测试 |
|
||||
| GET | /api/channel/update_balance | 批量刷新余额 |
|
||||
| GET | /api/channel/update_balance/:id | 单个刷新余额 |
|
||||
| POST | /api/channel/ | 新增渠道 |
|
||||
| PUT | /api/channel/ | 更新渠道 |
|
||||
| DELETE | /api/channel/disabled | 删除已禁用渠道 |
|
||||
| POST | /api/channel/tag/disabled | 批量禁用标签渠道 |
|
||||
| POST | /api/channel/tag/enabled | 批量启用标签渠道 |
|
||||
| PUT | /api/channel/tag | 编辑渠道标签 |
|
||||
| DELETE | /api/channel/:id | 删除渠道 |
|
||||
| POST | /api/channel/batch | 批量删除渠道 |
|
||||
| POST | /api/channel/fix | 修复渠道能力表 |
|
||||
| GET | /api/channel/fetch_models/:id | 拉取单渠道模型 |
|
||||
| POST | /api/channel/fetch_models | 拉取全部渠道模型 |
|
||||
| POST | /api/channel/batch/tag | 批量设置渠道标签 |
|
||||
| GET | /api/channel/tag/models | 根据标签获取模型 |
|
||||
| POST | /api/channel/copy/:id | 复制渠道 |
|
||||
|
||||
## 9. Token 管理
|
||||
| 方法 | 路径 | 鉴权 | 说明 |
|
||||
|------|------|------|------|
|
||||
| GET | /api/token/ | 用户 | 获取全部 Token |
|
||||
| GET | /api/token/search | 用户 | 搜索 Token |
|
||||
| GET | /api/token/:id | 用户 | 获取单个 Token |
|
||||
| POST | /api/token/ | 用户 | 创建 Token |
|
||||
| PUT | /api/token/ | 用户 | 更新 Token |
|
||||
| DELETE | /api/token/:id | 用户 | 删除 Token |
|
||||
| POST | /api/token/batch | 用户 | 批量删除 Token |
|
||||
|
||||
## 10. 兑换码管理 (管理员)
|
||||
| 方法 | 路径 | 说明 |
|
||||
|------|------|------|
|
||||
| GET | /api/redemption/ | 获取兑换码列表 |
|
||||
| GET | /api/redemption/search | 搜索兑换码 |
|
||||
| GET | /api/redemption/:id | 获取单个兑换码 |
|
||||
| POST | /api/redemption/ | 创建兑换码 |
|
||||
| PUT | /api/redemption/ | 更新兑换码 |
|
||||
| DELETE | /api/redemption/invalid | 删除无效兑换码 |
|
||||
| DELETE | /api/redemption/:id | 删除兑换码 |
|
||||
|
||||
## 11. 日志
|
||||
| 方法 | 路径 | 鉴权 | 说明 |
|
||||
|------|------|------|------|
|
||||
| GET | /api/log/ | 管理员 | 获取全部日志 |
|
||||
| DELETE | /api/log/ | 管理员 | 删除历史日志 |
|
||||
| GET | /api/log/stat | 管理员 | 日志统计 |
|
||||
| GET | /api/log/self/stat | 用户 | 我的日志统计 |
|
||||
| GET | /api/log/search | 管理员 | 搜索全部日志 |
|
||||
| GET | /api/log/self | 用户 | 获取我的日志 |
|
||||
| GET | /api/log/self/search | 用户 | 搜索我的日志 |
|
||||
| GET | /api/log/token | 公开 | 根据 Token 查询日志(支持 CORS) |
|
||||
|
||||
## 12. 数据统计
|
||||
| 方法 | 路径 | 鉴权 | 说明 |
|
||||
|------|------|------|------|
|
||||
| GET | /api/data/ | 管理员 | 全站用量按日期统计 |
|
||||
| GET | /api/data/self | 用户 | 我的用量按日期统计 |
|
||||
|
||||
## 13. 分组
|
||||
| GET | /api/group/ | 管理员 | 获取全部分组列表 |
|
||||
|
||||
## 14. Midjourney 任务
|
||||
| 方法 | 路径 | 鉴权 | 说明 |
|
||||
|------|------|------|------|
|
||||
| GET | /api/mj/self | 用户 | 获取自己的 MJ 任务 |
|
||||
| GET | /api/mj/ | 管理员 | 获取全部 MJ 任务 |
|
||||
|
||||
## 15. 任务中心
|
||||
| 方法 | 路径 | 鉴权 | 说明 |
|
||||
|------|------|------|------|
|
||||
| GET | /api/task/self | 用户 | 获取我的任务 |
|
||||
| GET | /api/task/ | 管理员 | 获取全部任务 |
|
||||
|
||||
## 16. 账户计费面板 (Dashboard)
|
||||
| 方法 | 路径 | 鉴权 | 说明 |
|
||||
|------|------|------|------|
|
||||
| GET | /dashboard/billing/subscription | 用户 Token | 获取订阅额度信息 |
|
||||
| GET | /v1/dashboard/billing/subscription | 同上 | 兼容 OpenAI SDK 路径 |
|
||||
| GET | /dashboard/billing/usage | 用户 Token | 获取使用量信息 |
|
||||
| GET | /v1/dashboard/billing/usage | 同上 | 兼容 OpenAI SDK 路径 |
|
||||
|
||||
---
|
||||
|
||||
> **更新日期**:2025.07.17
|
||||
@@ -1,82 +0,0 @@
|
||||
# Midjourney Proxy API文档
|
||||
|
||||
**简介**:Midjourney Proxy API文档
|
||||
|
||||
## 接口列表
|
||||
支持的接口如下:
|
||||
+ [x] /mj/submit/imagine
|
||||
+ [x] /mj/submit/change
|
||||
+ [x] /mj/submit/blend
|
||||
+ [x] /mj/submit/describe
|
||||
+ [x] /mj/image/{id} (通过此接口获取图片,**请必须在系统设置中填写服务器地址!!**)
|
||||
+ [x] /mj/task/{id}/fetch (此接口返回的图片地址为经过One API转发的地址)
|
||||
+ [x] /task/list-by-condition
|
||||
+ [x] /mj/submit/action (仅midjourney-proxy-plus支持,下同)
|
||||
+ [x] /mj/submit/modal
|
||||
+ [x] /mj/submit/shorten
|
||||
+ [x] /mj/task/{id}/image-seed
|
||||
+ [x] /mj/insight-face/swap (InsightFace)
|
||||
|
||||
## 模型列表
|
||||
|
||||
### midjourney-proxy支持
|
||||
|
||||
- mj_imagine (绘图)
|
||||
- mj_variation (变换)
|
||||
- mj_reroll (重绘)
|
||||
- mj_blend (混合)
|
||||
- mj_upscale (放大)
|
||||
- mj_describe (图生文)
|
||||
|
||||
### 仅midjourney-proxy-plus支持
|
||||
|
||||
- mj_zoom (比例变焦)
|
||||
- mj_shorten (提示词缩短)
|
||||
- mj_modal (窗口提交,局部重绘和自定义比例变焦必须和mj_modal一同添加)
|
||||
- mj_inpaint (局部重绘提交,必须和mj_modal一同添加)
|
||||
- mj_custom_zoom (自定义比例变焦,必须和mj_modal一同添加)
|
||||
- mj_high_variation (强变换)
|
||||
- mj_low_variation (弱变换)
|
||||
- mj_pan (平移)
|
||||
- swap_face (换脸)
|
||||
|
||||
## 模型价格设置(在设置-运营设置-模型固定价格设置中设置)
|
||||
```json
|
||||
{
|
||||
"mj_imagine": 0.1,
|
||||
"mj_variation": 0.1,
|
||||
"mj_reroll": 0.1,
|
||||
"mj_blend": 0.1,
|
||||
"mj_modal": 0.1,
|
||||
"mj_zoom": 0.1,
|
||||
"mj_shorten": 0.1,
|
||||
"mj_high_variation": 0.1,
|
||||
"mj_low_variation": 0.1,
|
||||
"mj_pan": 0.1,
|
||||
"mj_inpaint": 0,
|
||||
"mj_custom_zoom": 0,
|
||||
"mj_describe": 0.05,
|
||||
"mj_upscale": 0.05,
|
||||
"swap_face": 0.05
|
||||
}
|
||||
```
|
||||
其中mj_inpaint和mj_custom_zoom的价格设置为0,是因为这两个模型需要搭配mj_modal使用,所以价格由mj_modal决定。
|
||||
|
||||
## 渠道设置
|
||||
|
||||
### 对接 midjourney-proxy(plus)
|
||||
|
||||
1.
|
||||
|
||||
部署Midjourney-Proxy,并配置好midjourney账号等(强烈建议设置密钥),[项目地址](https://github.com/novicezk/midjourney-proxy)
|
||||
|
||||
2. 在渠道管理中添加渠道,渠道类型选择**Midjourney Proxy**,如果是plus版本选择**Midjourney Proxy Plus**
|
||||
,模型请参考上方模型列表
|
||||
3. **代理**填写midjourney-proxy部署的地址,例如:http://localhost:8080
|
||||
4. 密钥填写midjourney-proxy的密钥,如果没有设置密钥,可以随便填
|
||||
|
||||
### 对接上游new api
|
||||
|
||||
1. 在渠道管理中添加渠道,渠道类型选择**Midjourney Proxy Plus**,模型请参考上方模型列表
|
||||
2. **代理**填写上游new api的地址,例如:http://localhost:3000
|
||||
3. 密钥填写上游new api的密钥
|
||||
@@ -1,62 +0,0 @@
|
||||
# Rerank API文档
|
||||
|
||||
**简介**:Rerank API文档
|
||||
|
||||
## 接入Dify
|
||||
模型供应商选择Jina,按要求填写模型信息即可接入Dify。
|
||||
|
||||
## 请求方式
|
||||
|
||||
Post: /v1/rerank
|
||||
|
||||
Request:
|
||||
|
||||
```json
|
||||
{
|
||||
"model": "jina-reranker-v2-base-multilingual",
|
||||
"query": "What is the capital of the United States?",
|
||||
"top_n": 3,
|
||||
"documents": [
|
||||
"Carson City is the capital city of the American state of Nevada.",
|
||||
"The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean. Its capital is Saipan.",
|
||||
"Washington, D.C. (also known as simply Washington or D.C., and officially as the District of Columbia) is the capital of the United States. It is a federal district.",
|
||||
"Capitalization or capitalisation in English grammar is the use of a capital letter at the start of a word. English usage varies from capitalization in other languages.",
|
||||
"Capital punishment (the death penalty) has existed in the United States since beforethe United States was a country. As of 2017, capital punishment is legal in 30 of the 50 states."
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
Response:
|
||||
|
||||
```json
|
||||
{
|
||||
"results": [
|
||||
{
|
||||
"document": {
|
||||
"text": "Washington, D.C. (also known as simply Washington or D.C., and officially as the District of Columbia) is the capital of the United States. It is a federal district."
|
||||
},
|
||||
"index": 2,
|
||||
"relevance_score": 0.9999702
|
||||
},
|
||||
{
|
||||
"document": {
|
||||
"text": "Carson City is the capital city of the American state of Nevada."
|
||||
},
|
||||
"index": 0,
|
||||
"relevance_score": 0.67800725
|
||||
},
|
||||
{
|
||||
"document": {
|
||||
"text": "Capitalization or capitalisation in English grammar is the use of a capital letter at the start of a word. English usage varies from capitalization in other languages."
|
||||
},
|
||||
"index": 3,
|
||||
"relevance_score": 0.02800752
|
||||
}
|
||||
],
|
||||
"usage": {
|
||||
"prompt_tokens": 158,
|
||||
"completion_tokens": 0,
|
||||
"total_tokens": 158
|
||||
}
|
||||
}
|
||||
```
|
||||
@@ -1,44 +0,0 @@
|
||||
# Suno API文档
|
||||
|
||||
**简介**:Suno API文档
|
||||
|
||||
## 接口列表
|
||||
支持的接口如下:
|
||||
+ [x] /suno/submit/music
|
||||
+ [x] /suno/submit/lyrics
|
||||
+ [x] /suno/fetch
|
||||
+ [x] /suno/fetch/:id
|
||||
|
||||
## 模型列表
|
||||
|
||||
### Suno API支持
|
||||
|
||||
- suno_music (自定义模式、灵感模式、续写)
|
||||
- suno_lyrics (生成歌词)
|
||||
|
||||
|
||||
## 模型价格设置(在设置-运营设置-模型固定价格设置中设置)
|
||||
```json
|
||||
{
|
||||
"suno_music": 0.3,
|
||||
"suno_lyrics": 0.01
|
||||
}
|
||||
```
|
||||
|
||||
## 渠道设置
|
||||
|
||||
### 对接 Suno API
|
||||
|
||||
1.
|
||||
部署 Suno API,并配置好suno账号等(强烈建议设置密钥),[项目地址](https://github.com/Suno-API/Suno-API)
|
||||
|
||||
2. 在渠道管理中添加渠道,渠道类型选择**Suno API**
|
||||
,模型请参考上方模型列表
|
||||
3. **代理**填写 Suno API 部署的地址,例如:http://localhost:8080
|
||||
4. 密钥填写 Suno API 的密钥,如果没有设置密钥,可以随便填
|
||||
|
||||
### 对接上游new api
|
||||
|
||||
1. 在渠道管理中添加渠道,渠道类型选择**Suno API**,或任意类型,只需模型包含上方模型列表的模型
|
||||
2. **代理**填写上游new api的地址,例如:http://localhost:3000
|
||||
3. 密钥填写上游new api的密钥
|
||||
7818
docs/openapi/api.json
Normal file
7818
docs/openapi/api.json
Normal file
File diff suppressed because it is too large
Load Diff
7141
docs/openapi/relay.json
Normal file
7141
docs/openapi/relay.json
Normal file
File diff suppressed because it is too large
Load Diff
@@ -142,7 +142,7 @@ type GeminiThinkingConfig struct {
|
||||
IncludeThoughts bool `json:"includeThoughts,omitempty"`
|
||||
ThinkingBudget *int `json:"thinkingBudget,omitempty"`
|
||||
// TODO Conflict with thinkingbudget.
|
||||
ThinkingLevel json.RawMessage `json:"thinkingLevel,omitempty"`
|
||||
ThinkingLevel string `json:"thinkingLevel,omitempty"`
|
||||
}
|
||||
|
||||
// UnmarshalJSON allows GeminiThinkingConfig to accept both snake_case and camelCase fields.
|
||||
@@ -150,9 +150,9 @@ func (c *GeminiThinkingConfig) UnmarshalJSON(data []byte) error {
|
||||
type Alias GeminiThinkingConfig
|
||||
var aux struct {
|
||||
Alias
|
||||
IncludeThoughtsSnake *bool `json:"include_thoughts,omitempty"`
|
||||
ThinkingBudgetSnake *int `json:"thinking_budget,omitempty"`
|
||||
ThinkingLevelSnake json.RawMessage `json:"thinking_level,omitempty"`
|
||||
IncludeThoughtsSnake *bool `json:"include_thoughts,omitempty"`
|
||||
ThinkingBudgetSnake *int `json:"thinking_budget,omitempty"`
|
||||
ThinkingLevelSnake string `json:"thinking_level,omitempty"`
|
||||
}
|
||||
|
||||
if err := common.Unmarshal(data, &aux); err != nil {
|
||||
@@ -169,7 +169,7 @@ func (c *GeminiThinkingConfig) UnmarshalJSON(data []byte) error {
|
||||
c.ThinkingBudget = aux.ThinkingBudgetSnake
|
||||
}
|
||||
|
||||
if len(aux.ThinkingLevelSnake) > 0 {
|
||||
if aux.ThinkingLevelSnake != "" {
|
||||
c.ThinkingLevel = aux.ThinkingLevelSnake
|
||||
}
|
||||
|
||||
|
||||
@@ -27,8 +27,11 @@ type ImageRequest struct {
|
||||
OutputCompression json.RawMessage `json:"output_compression,omitempty"`
|
||||
PartialImages json.RawMessage `json:"partial_images,omitempty"`
|
||||
// Stream bool `json:"stream,omitempty"`
|
||||
Watermark *bool `json:"watermark,omitempty"`
|
||||
Image json.RawMessage `json:"image,omitempty"`
|
||||
Watermark *bool `json:"watermark,omitempty"`
|
||||
// zhipu 4v
|
||||
WatermarkEnabled json.RawMessage `json:"watermark_enabled,omitempty"`
|
||||
UserId json.RawMessage `json:"user_id,omitempty"`
|
||||
Image json.RawMessage `json:"image,omitempty"`
|
||||
// 用匿名参数接收额外参数
|
||||
Extra map[string]json.RawMessage `json:"-"`
|
||||
}
|
||||
@@ -169,7 +172,7 @@ type ImageResponse struct {
|
||||
Extra any `json:"extra,omitempty"`
|
||||
}
|
||||
type ImageData struct {
|
||||
Url string `json:"url,omitempty"`
|
||||
B64Json string `json:"b64_json,omitempty"`
|
||||
RevisedPrompt string `json:"revised_prompt,omitempty"`
|
||||
Url string `json:"url"`
|
||||
B64Json string `json:"b64_json"`
|
||||
RevisedPrompt string `json:"revised_prompt"`
|
||||
}
|
||||
|
||||
@@ -83,6 +83,7 @@ type GeneralOpenAIRequest struct {
|
||||
// Ali Qwen Params
|
||||
VlHighResolutionImages json.RawMessage `json:"vl_high_resolution_images,omitempty"`
|
||||
EnableThinking any `json:"enable_thinking,omitempty"`
|
||||
ChatTemplateKwargs json.RawMessage `json:"chat_template_kwargs,omitempty"`
|
||||
// ollama Params
|
||||
Think json.RawMessage `json:"think,omitempty"`
|
||||
// baidu v2
|
||||
|
||||
13
go.mod
13
go.mod
@@ -33,7 +33,7 @@ require (
|
||||
github.com/mewkiz/flac v1.0.13
|
||||
github.com/pkg/errors v0.9.1
|
||||
github.com/pquerna/otp v1.5.0
|
||||
github.com/samber/lo v1.39.0
|
||||
github.com/samber/lo v1.52.0
|
||||
github.com/shirou/gopsutil v3.21.11+incompatible
|
||||
github.com/shopspring/decimal v1.4.0
|
||||
github.com/stripe/stripe-go/v81 v81.4.0
|
||||
@@ -99,6 +99,7 @@ require (
|
||||
github.com/mitchellh/mapstructure v1.5.0 // indirect
|
||||
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
|
||||
github.com/modern-go/reflect2 v1.0.2 // indirect
|
||||
github.com/ncruces/go-strftime v0.1.9 // indirect
|
||||
github.com/pelletier/go-toml/v2 v2.2.1 // indirect
|
||||
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect
|
||||
github.com/tidwall/match v1.1.1 // indirect
|
||||
@@ -110,13 +111,13 @@ require (
|
||||
github.com/x448/float16 v0.8.4 // indirect
|
||||
github.com/yusufpapurcu/wmi v1.2.3 // indirect
|
||||
golang.org/x/arch v0.21.0 // indirect
|
||||
golang.org/x/exp v0.0.0-20240404231335-c0f41cb1a7a0 // indirect
|
||||
golang.org/x/exp v0.0.0-20250620022241-b7579e27df2b // indirect
|
||||
golang.org/x/sys v0.38.0 // indirect
|
||||
golang.org/x/text v0.31.0 // indirect
|
||||
google.golang.org/protobuf v1.34.2 // indirect
|
||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||
modernc.org/libc v1.22.5 // indirect
|
||||
modernc.org/mathutil v1.5.0 // indirect
|
||||
modernc.org/memory v1.5.0 // indirect
|
||||
modernc.org/sqlite v1.23.1 // indirect
|
||||
modernc.org/libc v1.66.10 // indirect
|
||||
modernc.org/mathutil v1.7.1 // indirect
|
||||
modernc.org/memory v1.11.0 // indirect
|
||||
modernc.org/sqlite v1.40.1 // indirect
|
||||
)
|
||||
|
||||
15
go.sum
15
go.sum
@@ -120,6 +120,7 @@ github.com/google/go-tpm v0.9.5/go.mod h1:h9jEsEECg7gtLis0upRBQU+GhYVH6jMjrFxI8u
|
||||
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
|
||||
github.com/google/pprof v0.0.0-20221118152302-e6195bd50e26 h1:Xim43kblpZXfIBQsbuBVKCudVG457BR2GZFIz3uw3hQ=
|
||||
github.com/google/pprof v0.0.0-20221118152302-e6195bd50e26/go.mod h1:dDKJzRmX4S37WGHujM7tX//fmj1uioxKzKxz3lo4HJo=
|
||||
github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e h1:ijClszYn+mADRFY17kjQEVQ1XRhq2/JR1M3sGqeJoxs=
|
||||
github.com/google/uuid v1.1.2/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
||||
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
@@ -193,6 +194,8 @@ github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJ
|
||||
github.com/modern-go/reflect2 v0.0.0-20180701023420-4b7aa43c6742/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0=
|
||||
github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M=
|
||||
github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk=
|
||||
github.com/ncruces/go-strftime v0.1.9 h1:bY0MQC28UADQmHmaF5dgpLmImcShSi2kHU9XLdhx/f4=
|
||||
github.com/ncruces/go-strftime v0.1.9/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls=
|
||||
github.com/nxadm/tail v1.4.8 h1:nPr65rt6Y5JFSKQO7qToXr7pePgD6Gwiw05lkbyAQTE=
|
||||
github.com/nxadm/tail v1.4.8/go.mod h1:+ncqLTQzXmGhMZNUePPaPqPvBxHAIsmXswZKocGu+AU=
|
||||
github.com/onsi/ginkgo v1.16.5 h1:8xi0RTUf59SOSfEtZMvwTvXYMzG4gV23XVHOZiXNtnE=
|
||||
@@ -219,6 +222,8 @@ github.com/rogpeppe/go-internal v1.8.0 h1:FCbCCtXNOY3UtUuHUYaghJg4y7Fd14rXifAYUA
|
||||
github.com/rogpeppe/go-internal v1.8.0/go.mod h1:WmiCO8CzOY8rg0OYDC4/i/2WRWAB6poM+XZ2dLUbcbE=
|
||||
github.com/samber/lo v1.39.0 h1:4gTz1wUhNYLhFSKl6O+8peW0v2F4BCY034GRpU9WnuA=
|
||||
github.com/samber/lo v1.39.0/go.mod h1:+m/ZKRl6ClXCE2Lgf3MsQlWfh4bn1bz6CXEOxnEXnEA=
|
||||
github.com/samber/lo v1.52.0 h1:Rvi+3BFHES3A8meP33VPAxiBZX/Aws5RxrschYGjomw=
|
||||
github.com/samber/lo v1.52.0/go.mod h1:4+MXEGsJzbKGaUEQFKBq2xtfuznW9oz/WrgyzMzRoM0=
|
||||
github.com/shirou/gopsutil v3.21.11+incompatible h1:+1+c1VGhc88SSonWP6foOcLhvnKlUeu/erjjvaPEYiI=
|
||||
github.com/shirou/gopsutil v3.21.11+incompatible/go.mod h1:5b4v6he4MtMOwMlS0TUMTu2PcXUg8+E1lC7eC3UO/RA=
|
||||
github.com/shopspring/decimal v1.4.0 h1:bxl37RwXBklmTi0C79JfXCEBD1cqqHt0bbgBAGFp81k=
|
||||
@@ -285,6 +290,8 @@ golang.org/x/crypto v0.45.0 h1:jMBrvKuj23MTlT0bQEOBcAE0mjg8mK9RXFhRH6nyF3Q=
|
||||
golang.org/x/crypto v0.45.0/go.mod h1:XTGrrkGJve7CYK7J8PEww4aY7gM3qMCElcJQ8n8JdX4=
|
||||
golang.org/x/exp v0.0.0-20240404231335-c0f41cb1a7a0 h1:985EYyeCOxTpcgOTJpflJUwOeEz0CQOdPt73OzpE9F8=
|
||||
golang.org/x/exp v0.0.0-20240404231335-c0f41cb1a7a0/go.mod h1:/lliqkxwWAhPjf5oSOIJup2XcqJaw8RGS6k3TGEc7GI=
|
||||
golang.org/x/exp v0.0.0-20250620022241-b7579e27df2b h1:M2rDM6z3Fhozi9O7NWsxAkg/yqS/lQJ6PmkyIV3YP+o=
|
||||
golang.org/x/exp v0.0.0-20250620022241-b7579e27df2b/go.mod h1:3//PLf8L/X+8b4vuAfHzxeRUl04Adcb341+IGKfnqS8=
|
||||
golang.org/x/image v0.23.0 h1:HseQ7c2OpPKTPVzNjG5fwJsOTCiiwS4QdsYi5XU6H68=
|
||||
golang.org/x/image v0.23.0/go.mod h1:wJJBTdLfCCf3tiHa1fNxpZmUI4mmoZvwMCPP0ddoNKY=
|
||||
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
|
||||
@@ -345,9 +352,17 @@ gorm.io/gorm v1.25.2 h1:gs1o6Vsa+oVKG/a9ElL3XgyGfghFfkKA2SInQaCyMho=
|
||||
gorm.io/gorm v1.25.2/go.mod h1:L4uxeKpfBml98NYqVqwAdmV1a2nBtAec/cf3fpucW/k=
|
||||
modernc.org/libc v1.22.5 h1:91BNch/e5B0uPbJFgqbxXuOnxBQjlS//icfQEGmvyjE=
|
||||
modernc.org/libc v1.22.5/go.mod h1:jj+Z7dTNX8fBScMVNRAYZ/jF91K8fdT2hYMThc3YjBY=
|
||||
modernc.org/libc v1.66.10 h1:yZkb3YeLx4oynyR+iUsXsybsX4Ubx7MQlSYEw4yj59A=
|
||||
modernc.org/libc v1.66.10/go.mod h1:8vGSEwvoUoltr4dlywvHqjtAqHBaw0j1jI7iFBTAr2I=
|
||||
modernc.org/mathutil v1.5.0 h1:rV0Ko/6SfM+8G+yKiyI830l3Wuz1zRutdslNoQ0kfiQ=
|
||||
modernc.org/mathutil v1.5.0/go.mod h1:mZW8CKdRPY1v87qxC/wUdX5O1qDzXMP5TH3wjfpga6E=
|
||||
modernc.org/mathutil v1.7.1 h1:GCZVGXdaN8gTqB1Mf/usp1Y/hSqgI2vAGGP4jZMCxOU=
|
||||
modernc.org/mathutil v1.7.1/go.mod h1:4p5IwJITfppl0G4sUEDtCr4DthTaT47/N3aT6MhfgJg=
|
||||
modernc.org/memory v1.5.0 h1:N+/8c5rE6EqugZwHii4IFsaJ7MUhoWX07J5tC/iI5Ds=
|
||||
modernc.org/memory v1.5.0/go.mod h1:PkUhL0Mugw21sHPeskwZW4D6VscE/GQJOnIpCnW6pSU=
|
||||
modernc.org/memory v1.11.0 h1:o4QC8aMQzmcwCK3t3Ux/ZHmwFPzE6hf2Y5LbkRs+hbI=
|
||||
modernc.org/memory v1.11.0/go.mod h1:/JP4VbVC+K5sU2wZi9bHoq2MAkCnrt2r98UGeSK7Mjw=
|
||||
modernc.org/sqlite v1.23.1 h1:nrSBg4aRQQwq59JpvGEQ15tNxoO5pX/kUjcRNwSAGQM=
|
||||
modernc.org/sqlite v1.23.1/go.mod h1:OrDj17Mggn6MhE+iPbBNf7RGKODDE9NFT0f3EwDzJqk=
|
||||
modernc.org/sqlite v1.40.1 h1:VfuXcxcUWWKRBuP8+BR9L7VnmusMgBNNnBYGEe9w/iY=
|
||||
modernc.org/sqlite v1.40.1/go.mod h1:9fjQZ0mB1LLP0GYrp39oOJXx/I2sxEnZtzCmEQIKvGE=
|
||||
|
||||
@@ -181,6 +181,10 @@ func getModelRequest(c *gin.Context) (*ModelRequest, bool, error) {
|
||||
}
|
||||
c.Set("platform", string(constant.TaskPlatformSuno))
|
||||
c.Set("relay_mode", relayMode)
|
||||
} else if strings.Contains(c.Request.URL.Path, "/v1/videos/") && strings.HasSuffix(c.Request.URL.Path, "/remix") {
|
||||
relayMode := relayconstant.RelayModeVideoSubmit
|
||||
c.Set("relay_mode", relayMode)
|
||||
shouldSelectChannel = false
|
||||
} else if strings.Contains(c.Request.URL.Path, "/v1/videos") {
|
||||
//curl https://api.openai.com/v1/videos \
|
||||
// -H "Authorization: Bearer $OPENAI_API_KEY" \
|
||||
|
||||
@@ -47,7 +47,7 @@ type TaskAdaptor interface {
|
||||
GetChannelName() string
|
||||
|
||||
// FetchTask
|
||||
FetchTask(baseUrl, key string, body map[string]any) (*http.Response, error)
|
||||
FetchTask(baseUrl, key string, body map[string]any, proxy string) (*http.Response, error)
|
||||
|
||||
ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, error)
|
||||
}
|
||||
|
||||
@@ -27,8 +27,6 @@ import (
|
||||
func SetupApiRequestHeader(info *common.RelayInfo, c *gin.Context, req *http.Header) {
|
||||
if info.RelayMode == constant.RelayModeAudioTranscription || info.RelayMode == constant.RelayModeAudioTranslation {
|
||||
// multipart/form-data
|
||||
} else if info.RelayMode == constant.RelayModeImagesEdits {
|
||||
// multipart/form-data
|
||||
} else if info.RelayMode == constant.RelayModeRealtime {
|
||||
// websocket
|
||||
} else {
|
||||
|
||||
@@ -9,6 +9,7 @@ var ModelList = []string{
|
||||
"claude-3-opus-20240229",
|
||||
"claude-3-haiku-20240307",
|
||||
"claude-3-5-haiku-20241022",
|
||||
"claude-haiku-4-5-20251001",
|
||||
"claude-3-5-sonnet-20240620",
|
||||
"claude-3-5-sonnet-20241022",
|
||||
"claude-3-7-sonnet-20250219",
|
||||
|
||||
@@ -673,7 +673,7 @@ func HandleStreamResponseData(c *gin.Context, info *relaycommon.RelayInfo, claud
|
||||
func HandleStreamFinalResponse(c *gin.Context, info *relaycommon.RelayInfo, claudeInfo *ClaudeResponseInfo, requestMode int) {
|
||||
|
||||
if requestMode == RequestModeCompletion {
|
||||
claudeInfo.Usage = service.ResponseText2Usage(c, claudeInfo.ResponseText.String(), info.UpstreamModelName, info.PromptTokens)
|
||||
claudeInfo.Usage = service.ResponseText2Usage(c, claudeInfo.ResponseText.String(), info.UpstreamModelName, info.GetEstimatePromptTokens())
|
||||
} else {
|
||||
if claudeInfo.Usage.PromptTokens == 0 {
|
||||
//上游出错
|
||||
@@ -734,10 +734,7 @@ func HandleClaudeResponseData(c *gin.Context, info *relaycommon.RelayInfo, claud
|
||||
return types.WithClaudeError(*claudeError, http.StatusInternalServerError)
|
||||
}
|
||||
if requestMode == RequestModeCompletion {
|
||||
completionTokens := service.CountTextToken(claudeResponse.Completion, info.OriginModelName)
|
||||
claudeInfo.Usage.PromptTokens = info.PromptTokens
|
||||
claudeInfo.Usage.CompletionTokens = completionTokens
|
||||
claudeInfo.Usage.TotalTokens = info.PromptTokens + completionTokens
|
||||
claudeInfo.Usage = service.ResponseText2Usage(c, claudeResponse.Completion, info.UpstreamModelName, info.GetEstimatePromptTokens())
|
||||
} else {
|
||||
claudeInfo.Usage.PromptTokens = claudeResponse.Usage.InputTokens
|
||||
claudeInfo.Usage.CompletionTokens = claudeResponse.Usage.OutputTokens
|
||||
|
||||
@@ -74,7 +74,7 @@ func cfStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Res
|
||||
if err := scanner.Err(); err != nil {
|
||||
logger.LogError(c, "error_scanning_stream_response: "+err.Error())
|
||||
}
|
||||
usage := service.ResponseText2Usage(c, responseText, info.UpstreamModelName, info.PromptTokens)
|
||||
usage := service.ResponseText2Usage(c, responseText, info.UpstreamModelName, info.GetEstimatePromptTokens())
|
||||
if info.ShouldIncludeUsage {
|
||||
response := helper.GenerateFinalUsageResponse(id, info.StartTime.Unix(), info.UpstreamModelName, *usage)
|
||||
err := helper.ObjectData(c, response)
|
||||
@@ -105,7 +105,7 @@ func cfHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response)
|
||||
for _, choice := range response.Choices {
|
||||
responseText += choice.Message.StringContent()
|
||||
}
|
||||
usage := service.ResponseText2Usage(c, responseText, info.UpstreamModelName, info.PromptTokens)
|
||||
usage := service.ResponseText2Usage(c, responseText, info.UpstreamModelName, info.GetEstimatePromptTokens())
|
||||
response.Usage = *usage
|
||||
response.Id = helper.GetResponseID(c)
|
||||
jsonResponse, err := json.Marshal(response)
|
||||
@@ -142,10 +142,6 @@ func cfSTTHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Respon
|
||||
c.Writer.WriteHeader(resp.StatusCode)
|
||||
_, _ = c.Writer.Write(jsonResponse)
|
||||
|
||||
usage := &dto.Usage{}
|
||||
usage.PromptTokens = info.PromptTokens
|
||||
usage.CompletionTokens = service.CountTextToken(cfResp.Result.Text, info.UpstreamModelName)
|
||||
usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
|
||||
|
||||
usage := service.ResponseText2Usage(c, cfResp.Result.Text, info.UpstreamModelName, info.GetEstimatePromptTokens())
|
||||
return nil, usage
|
||||
}
|
||||
|
||||
@@ -165,7 +165,7 @@ func cohereStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http
|
||||
}
|
||||
})
|
||||
if usage.PromptTokens == 0 {
|
||||
usage = service.ResponseText2Usage(c, responseText, info.UpstreamModelName, info.PromptTokens)
|
||||
usage = service.ResponseText2Usage(c, responseText, info.UpstreamModelName, info.GetEstimatePromptTokens())
|
||||
}
|
||||
return usage, nil
|
||||
}
|
||||
@@ -225,9 +225,9 @@ func cohereRerankHandler(c *gin.Context, resp *http.Response, info *relaycommon.
|
||||
}
|
||||
usage := dto.Usage{}
|
||||
if cohereResp.Meta.BilledUnits.InputTokens == 0 {
|
||||
usage.PromptTokens = info.PromptTokens
|
||||
usage.PromptTokens = info.GetEstimatePromptTokens()
|
||||
usage.CompletionTokens = 0
|
||||
usage.TotalTokens = info.PromptTokens
|
||||
usage.TotalTokens = info.GetEstimatePromptTokens()
|
||||
} else {
|
||||
usage.PromptTokens = cohereResp.Meta.BilledUnits.InputTokens
|
||||
usage.CompletionTokens = cohereResp.Meta.BilledUnits.OutputTokens
|
||||
|
||||
@@ -246,7 +246,7 @@ func difyStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.R
|
||||
})
|
||||
helper.Done(c)
|
||||
if usage.TotalTokens == 0 {
|
||||
usage = service.ResponseText2Usage(c, responseText, info.UpstreamModelName, info.PromptTokens)
|
||||
usage = service.ResponseText2Usage(c, responseText, info.UpstreamModelName, info.GetEstimatePromptTokens())
|
||||
}
|
||||
usage.CompletionTokens += nodeToken
|
||||
return usage, nil
|
||||
|
||||
@@ -1,12 +1,10 @@
|
||||
package gemini
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
"github.com/QuantumNous/new-api/dto"
|
||||
@@ -57,139 +55,9 @@ func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInf
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
type ImageConfig struct {
|
||||
AspectRatio string `json:"aspectRatio,omitempty"`
|
||||
ImageSize string `json:"imageSize,omitempty"`
|
||||
}
|
||||
|
||||
type SizeMapping struct {
|
||||
AspectRatio string
|
||||
ImageSize string
|
||||
}
|
||||
|
||||
type QualityMapping struct {
|
||||
Standard string
|
||||
HD string
|
||||
High string
|
||||
FourK string
|
||||
Auto string
|
||||
}
|
||||
|
||||
func getImageSizeMapping() QualityMapping {
|
||||
return QualityMapping{
|
||||
Standard: "1K",
|
||||
HD: "2K",
|
||||
High: "2K",
|
||||
FourK: "4K",
|
||||
Auto: "1K",
|
||||
}
|
||||
}
|
||||
|
||||
func getSizeMappings() map[string]SizeMapping {
|
||||
return map[string]SizeMapping{
|
||||
// Gemini 2.5 Flash Image - default 1K resolutions
|
||||
"1024x1024": {AspectRatio: "1:1", ImageSize: ""},
|
||||
"832x1248": {AspectRatio: "2:3", ImageSize: ""},
|
||||
"1248x832": {AspectRatio: "3:2", ImageSize: ""},
|
||||
"864x1184": {AspectRatio: "3:4", ImageSize: ""},
|
||||
"1184x864": {AspectRatio: "4:3", ImageSize: ""},
|
||||
"896x1152": {AspectRatio: "4:5", ImageSize: ""},
|
||||
"1152x896": {AspectRatio: "5:4", ImageSize: ""},
|
||||
"768x1344": {AspectRatio: "9:16", ImageSize: ""},
|
||||
"1344x768": {AspectRatio: "16:9", ImageSize: ""},
|
||||
"1536x672": {AspectRatio: "21:9", ImageSize: ""},
|
||||
|
||||
// Gemini 3 Pro Image Preview resolutions
|
||||
"1536x1024": {AspectRatio: "3:2", ImageSize: ""},
|
||||
"1024x1536": {AspectRatio: "2:3", ImageSize: ""},
|
||||
"1024x1792": {AspectRatio: "9:16", ImageSize: ""},
|
||||
"1792x1024": {AspectRatio: "16:9", ImageSize: ""},
|
||||
"2048x2048": {AspectRatio: "1:1", ImageSize: "2K"},
|
||||
"4096x4096": {AspectRatio: "1:1", ImageSize: "4K"},
|
||||
}
|
||||
}
|
||||
|
||||
func processSizeParameters(size, quality string) ImageConfig {
|
||||
config := ImageConfig{} // 默认为空值
|
||||
|
||||
if size != "" {
|
||||
if strings.Contains(size, ":") {
|
||||
config.AspectRatio = size // 直接设置,不与默认值比较
|
||||
} else {
|
||||
if mapping, exists := getSizeMappings()[size]; exists {
|
||||
if mapping.AspectRatio != "" {
|
||||
config.AspectRatio = mapping.AspectRatio
|
||||
}
|
||||
if mapping.ImageSize != "" {
|
||||
config.ImageSize = mapping.ImageSize
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if quality != "" {
|
||||
qualityMapping := getImageSizeMapping()
|
||||
switch strings.ToLower(strings.TrimSpace(quality)) {
|
||||
case "hd", "high":
|
||||
config.ImageSize = qualityMapping.HD
|
||||
case "4k":
|
||||
config.ImageSize = qualityMapping.FourK
|
||||
case "standard", "medium", "low", "auto", "1k":
|
||||
config.ImageSize = qualityMapping.Standard
|
||||
}
|
||||
}
|
||||
|
||||
return config
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
|
||||
if model_setting.IsGeminiModelSupportImagine(info.UpstreamModelName) {
|
||||
var content any
|
||||
if base64Data, err := relaycommon.GetImageBase64sFromForm(c); err == nil {
|
||||
content = []any{
|
||||
dto.MediaContent{
|
||||
Type: dto.ContentTypeText,
|
||||
Text: request.Prompt,
|
||||
},
|
||||
dto.MediaContent{
|
||||
Type: dto.ContentTypeFile,
|
||||
File: &dto.MessageFile{
|
||||
FileData: base64Data.String(),
|
||||
},
|
||||
},
|
||||
}
|
||||
} else {
|
||||
content = request.Prompt
|
||||
}
|
||||
|
||||
chatRequest := dto.GeneralOpenAIRequest{
|
||||
Model: request.Model,
|
||||
Messages: []dto.Message{
|
||||
{Role: "user", Content: content},
|
||||
},
|
||||
N: int(request.N),
|
||||
}
|
||||
|
||||
config := processSizeParameters(strings.TrimSpace(request.Size), request.Quality)
|
||||
|
||||
// 兼容 nano-banana 传quality[imageSize]会报错 An internal error has occurred. Please retry or report in https://developers.generativeai.google/guide/troubleshooting
|
||||
if slices.Contains([]string{"nano-banana", "gemini-2.5-flash-image"}, info.UpstreamModelName) {
|
||||
config.ImageSize = ""
|
||||
}
|
||||
|
||||
googleGenerationConfig := map[string]interface{}{
|
||||
"responseModalities": []string{"TEXT", "IMAGE"},
|
||||
"imageConfig": config,
|
||||
}
|
||||
|
||||
extraBody := map[string]interface{}{
|
||||
"google": map[string]interface{}{
|
||||
"generationConfig": googleGenerationConfig,
|
||||
},
|
||||
}
|
||||
chatRequest.ExtraBody, _ = json.Marshal(extraBody)
|
||||
|
||||
return a.ConvertOpenAIRequest(c, info, &chatRequest)
|
||||
if !strings.HasPrefix(info.UpstreamModelName, "imagen") {
|
||||
return nil, errors.New("not supported model for image generation")
|
||||
}
|
||||
|
||||
// convert size to aspect ratio but allow user to specify aspect ratio
|
||||
@@ -199,8 +67,17 @@ func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInf
|
||||
if strings.Contains(size, ":") {
|
||||
aspectRatio = size
|
||||
} else {
|
||||
if mapping, exists := getSizeMappings()[size]; exists && mapping.AspectRatio != "" {
|
||||
aspectRatio = mapping.AspectRatio
|
||||
switch size {
|
||||
case "256x256", "512x512", "1024x1024":
|
||||
aspectRatio = "1:1"
|
||||
case "1536x1024":
|
||||
aspectRatio = "3:2"
|
||||
case "1024x1536":
|
||||
aspectRatio = "2:3"
|
||||
case "1024x1792":
|
||||
aspectRatio = "9:16"
|
||||
case "1792x1024":
|
||||
aspectRatio = "16:9"
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -260,6 +137,8 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
info.UpstreamModelName = strings.TrimSuffix(info.UpstreamModelName, "-thinking")
|
||||
} else if strings.HasSuffix(info.UpstreamModelName, "-nothinking") {
|
||||
info.UpstreamModelName = strings.TrimSuffix(info.UpstreamModelName, "-nothinking")
|
||||
} else if baseModel, level := parseThinkingLevelSuffix(info.UpstreamModelName); level != "" {
|
||||
info.UpstreamModelName = baseModel
|
||||
}
|
||||
}
|
||||
|
||||
@@ -381,10 +260,6 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
|
||||
return GeminiImageHandler(c, info, resp)
|
||||
}
|
||||
|
||||
if model_setting.IsGeminiModelSupportImagine(info.UpstreamModelName) {
|
||||
return ChatImageHandler(c, info, resp)
|
||||
}
|
||||
|
||||
// check if the model is an embedding model
|
||||
if strings.HasPrefix(info.UpstreamModelName, "text-embedding") ||
|
||||
strings.HasPrefix(info.UpstreamModelName, "embedding") ||
|
||||
|
||||
@@ -5,7 +5,6 @@ import (
|
||||
"net/http"
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
"github.com/QuantumNous/new-api/constant"
|
||||
"github.com/QuantumNous/new-api/dto"
|
||||
"github.com/QuantumNous/new-api/logger"
|
||||
relaycommon "github.com/QuantumNous/new-api/relay/common"
|
||||
@@ -70,12 +69,7 @@ func NativeGeminiEmbeddingHandler(c *gin.Context, resp *http.Response, info *rel
|
||||
println(string(responseBody))
|
||||
}
|
||||
|
||||
usage := &dto.Usage{
|
||||
PromptTokens: info.PromptTokens,
|
||||
TotalTokens: info.PromptTokens,
|
||||
}
|
||||
|
||||
common.SetContextKey(c, constant.ContextKeyLocalCountTokens, true)
|
||||
usage := service.ResponseText2Usage(c, "", info.UpstreamModelName, info.GetEstimatePromptTokens())
|
||||
|
||||
if info.IsGeminiBatchEmbedding {
|
||||
var geminiResponse dto.GeminiBatchEmbeddingResponse
|
||||
|
||||
@@ -19,8 +19,8 @@ import (
|
||||
"github.com/QuantumNous/new-api/relay/helper"
|
||||
"github.com/QuantumNous/new-api/service"
|
||||
"github.com/QuantumNous/new-api/setting/model_setting"
|
||||
"github.com/QuantumNous/new-api/setting/reasoning"
|
||||
"github.com/QuantumNous/new-api/types"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
@@ -122,6 +122,14 @@ func clampThinkingBudgetByEffort(modelName string, effort string) int {
|
||||
return clampThinkingBudget(modelName, maxBudget)
|
||||
}
|
||||
|
||||
func parseThinkingLevelSuffix(modelName string) (string, string) {
|
||||
base, level, ok := reasoning.TrimEffortSuffix(modelName)
|
||||
if !ok {
|
||||
return modelName, ""
|
||||
}
|
||||
return base, level
|
||||
}
|
||||
|
||||
func ThinkingAdaptor(geminiRequest *dto.GeminiChatRequest, info *relaycommon.RelayInfo, oaiRequest ...dto.GeneralOpenAIRequest) {
|
||||
if model_setting.GetGeminiSettings().ThinkingAdapterEnabled {
|
||||
modelName := info.UpstreamModelName
|
||||
@@ -178,12 +186,18 @@ func ThinkingAdaptor(geminiRequest *dto.GeminiChatRequest, info *relaycommon.Rel
|
||||
ThinkingBudget: common.GetPointer(0),
|
||||
}
|
||||
}
|
||||
} else if _, level := parseThinkingLevelSuffix(modelName); level != "" {
|
||||
geminiRequest.GenerationConfig.ThinkingConfig = &dto.GeminiThinkingConfig{
|
||||
IncludeThoughts: true,
|
||||
ThinkingLevel: level,
|
||||
}
|
||||
info.ReasoningEffort = level
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Setting safety to the lowest possible values since Gemini is already powerless enough
|
||||
func CovertOpenAI2Gemini(c *gin.Context, textRequest dto.GeneralOpenAIRequest, info *relaycommon.RelayInfo, base64Data ...*relaycommon.Base64Data) (*dto.GeminiChatRequest, error) {
|
||||
func CovertOpenAI2Gemini(c *gin.Context, textRequest dto.GeneralOpenAIRequest, info *relaycommon.RelayInfo) (*dto.GeminiChatRequest, error) {
|
||||
|
||||
geminiRequest := dto.GeminiChatRequest{
|
||||
Contents: make([]dto.GeminiChatContent, 0, len(textRequest.Messages)),
|
||||
@@ -208,6 +222,7 @@ func CovertOpenAI2Gemini(c *gin.Context, textRequest dto.GeneralOpenAIRequest, i
|
||||
|
||||
adaptorWithExtraBody := false
|
||||
|
||||
// patch extra_body
|
||||
if len(textRequest.ExtraBody) > 0 {
|
||||
if !strings.HasSuffix(info.UpstreamModelName, "-nothinking") {
|
||||
var extraBody map[string]interface{}
|
||||
@@ -240,13 +255,36 @@ func CovertOpenAI2Gemini(c *gin.Context, textRequest dto.GeneralOpenAIRequest, i
|
||||
}
|
||||
}
|
||||
|
||||
if generationConfig, ok := googleBody["generationConfig"].(map[string]any); ok {
|
||||
generationConfigBytes, err := json.Marshal(generationConfig)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal generationConfig: %w", err)
|
||||
// check error param name like imageConfig, should be image_config
|
||||
if _, hasErrorParam := googleBody["imageConfig"]; hasErrorParam {
|
||||
return nil, errors.New("extra_body.google.imageConfig is not supported, use extra_body.google.image_config instead")
|
||||
}
|
||||
|
||||
if imageConfig, ok := googleBody["image_config"].(map[string]interface{}); ok {
|
||||
// check error param name like aspectRatio, should be aspect_ratio
|
||||
if _, hasErrorParam := imageConfig["aspectRatio"]; hasErrorParam {
|
||||
return nil, errors.New("extra_body.google.image_config.aspectRatio is not supported, use extra_body.google.image_config.aspect_ratio instead")
|
||||
}
|
||||
if err := json.Unmarshal(generationConfigBytes, &geminiRequest.GenerationConfig); err != nil {
|
||||
return nil, fmt.Errorf("failed to unmarshal generationConfig: %w", err)
|
||||
// check error param name like imageSize, should be image_size
|
||||
if _, hasErrorParam := imageConfig["imageSize"]; hasErrorParam {
|
||||
return nil, errors.New("extra_body.google.image_config.imageSize is not supported, use extra_body.google.image_config.image_size instead")
|
||||
}
|
||||
|
||||
// convert snake_case to camelCase for Gemini API
|
||||
geminiImageConfig := make(map[string]interface{})
|
||||
if aspectRatio, ok := imageConfig["aspect_ratio"]; ok {
|
||||
geminiImageConfig["aspectRatio"] = aspectRatio
|
||||
}
|
||||
if imageSize, ok := imageConfig["image_size"]; ok {
|
||||
geminiImageConfig["imageSize"] = imageSize
|
||||
}
|
||||
|
||||
if len(geminiImageConfig) > 0 {
|
||||
imageConfigBytes, err := common.Marshal(geminiImageConfig)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal image_config: %w", err)
|
||||
}
|
||||
geminiRequest.GenerationConfig.ImageConfig = imageConfigBytes
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -422,9 +460,68 @@ func CovertOpenAI2Gemini(c *gin.Context, textRequest dto.GeneralOpenAIRequest, i
|
||||
if part.Text == "" {
|
||||
continue
|
||||
}
|
||||
parts = append(parts, dto.GeminiPart{
|
||||
Text: part.Text,
|
||||
})
|
||||
// check markdown image 
|
||||
// 使用字符串查找而非正则,避免大文本性能问题
|
||||
text := part.Text
|
||||
hasMarkdownImage := false
|
||||
for {
|
||||
// 快速检查是否包含 markdown 图片标记
|
||||
startIdx := strings.Index(text, "
|
||||
if bracketIdx == -1 {
|
||||
break
|
||||
}
|
||||
bracketIdx += startIdx
|
||||
// 找到闭合的 )
|
||||
closeIdx := strings.Index(text[bracketIdx+2:], ")")
|
||||
if closeIdx == -1 {
|
||||
break
|
||||
}
|
||||
closeIdx += bracketIdx + 2
|
||||
|
||||
hasMarkdownImage = true
|
||||
// 添加图片前的文本
|
||||
if startIdx > 0 {
|
||||
textBefore := text[:startIdx]
|
||||
if textBefore != "" {
|
||||
parts = append(parts, dto.GeminiPart{
|
||||
Text: textBefore,
|
||||
})
|
||||
}
|
||||
}
|
||||
// 提取 data URL (从 "](" 后面开始,到 ")" 之前)
|
||||
dataUrl := text[bracketIdx+2 : closeIdx]
|
||||
imageNum += 1
|
||||
if constant.GeminiVisionMaxImageNum != -1 && imageNum > constant.GeminiVisionMaxImageNum {
|
||||
return nil, fmt.Errorf("too many images in the message, max allowed is %d", constant.GeminiVisionMaxImageNum)
|
||||
}
|
||||
format, base64String, err := service.DecodeBase64FileData(dataUrl)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("decode markdown base64 image data failed: %s", err.Error())
|
||||
}
|
||||
imgPart := dto.GeminiPart{
|
||||
InlineData: &dto.GeminiInlineData{
|
||||
MimeType: format,
|
||||
Data: base64String,
|
||||
},
|
||||
}
|
||||
if shouldAttachThoughtSignature {
|
||||
imgPart.ThoughtSignature = json.RawMessage(strconv.Quote(thoughtSignatureBypassValue))
|
||||
}
|
||||
parts = append(parts, imgPart)
|
||||
// 继续处理剩余文本
|
||||
text = text[closeIdx+1:]
|
||||
}
|
||||
// 添加剩余文本或原始文本(如果没有找到 markdown 图片)
|
||||
if !hasMarkdownImage {
|
||||
parts = append(parts, dto.GeminiPart{
|
||||
Text: part.Text,
|
||||
})
|
||||
}
|
||||
} else if part.Type == dto.ContentTypeImageURL {
|
||||
imageNum += 1
|
||||
|
||||
@@ -464,11 +561,10 @@ func CovertOpenAI2Gemini(c *gin.Context, textRequest dto.GeneralOpenAIRequest, i
|
||||
})
|
||||
}
|
||||
} else if part.Type == dto.ContentTypeFile {
|
||||
file := part.GetFile()
|
||||
if file.FileId != "" {
|
||||
if part.GetFile().FileId != "" {
|
||||
return nil, fmt.Errorf("only base64 file is supported in gemini")
|
||||
}
|
||||
format, base64String, err := service.DecodeBase64FileData(file.FileData)
|
||||
format, base64String, err := service.DecodeBase64FileData(part.GetFile().FileData)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("decode base64 file data failed: %s", err.Error())
|
||||
}
|
||||
@@ -1033,7 +1129,7 @@ func geminiStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http
|
||||
if usage.CompletionTokens <= 0 {
|
||||
str := responseText.String()
|
||||
if len(str) > 0 {
|
||||
usage = service.ResponseText2Usage(c, responseText.String(), info.UpstreamModelName, info.PromptTokens)
|
||||
usage = service.ResponseText2Usage(c, responseText.String(), info.UpstreamModelName, info.GetEstimatePromptTokens())
|
||||
} else {
|
||||
usage = &dto.Usage{}
|
||||
}
|
||||
@@ -1206,11 +1302,7 @@ func GeminiEmbeddingHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *h
|
||||
// Google has not yet clarified how embedding models will be billed
|
||||
// refer to openai billing method to use input tokens billing
|
||||
// https://platform.openai.com/docs/guides/embeddings#what-are-embeddings
|
||||
usage := &dto.Usage{
|
||||
PromptTokens: info.PromptTokens,
|
||||
CompletionTokens: 0,
|
||||
TotalTokens: info.PromptTokens,
|
||||
}
|
||||
usage := service.ResponseText2Usage(c, "", info.UpstreamModelName, info.GetEstimatePromptTokens())
|
||||
openAIResponse.Usage = *usage
|
||||
|
||||
jsonResponse, jsonErr := common.Marshal(openAIResponse)
|
||||
@@ -1275,70 +1367,3 @@ func GeminiImageHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.
|
||||
|
||||
return usage, nil
|
||||
}
|
||||
|
||||
func convertToOaiImageResponse(geminiResponse *dto.GeminiChatResponse) (*dto.ImageResponse, error) {
|
||||
openAIResponse := &dto.ImageResponse{
|
||||
Created: common.GetTimestamp(),
|
||||
Data: make([]dto.ImageData, 0),
|
||||
}
|
||||
|
||||
// extract images from candidates' inlineData
|
||||
for _, candidate := range geminiResponse.Candidates {
|
||||
for _, part := range candidate.Content.Parts {
|
||||
if part.InlineData != nil && strings.HasPrefix(part.InlineData.MimeType, "image") {
|
||||
openAIResponse.Data = append(openAIResponse.Data, dto.ImageData{
|
||||
B64Json: part.InlineData.Data,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len(openAIResponse.Data) == 0 {
|
||||
return nil, errors.New("no images found in response")
|
||||
}
|
||||
|
||||
return openAIResponse, nil
|
||||
}
|
||||
|
||||
func ChatImageHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
|
||||
responseBody, readErr := io.ReadAll(resp.Body)
|
||||
if readErr != nil {
|
||||
return nil, types.NewOpenAIError(readErr, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
|
||||
}
|
||||
service.CloseResponseBodyGracefully(resp)
|
||||
|
||||
if common.DebugEnabled {
|
||||
println("ChatImageHandler response:", string(responseBody))
|
||||
}
|
||||
|
||||
var geminiResponse dto.GeminiChatResponse
|
||||
if jsonErr := common.Unmarshal(responseBody, &geminiResponse); jsonErr != nil {
|
||||
return nil, types.NewOpenAIError(jsonErr, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
if len(geminiResponse.Candidates) == 0 {
|
||||
return nil, types.NewOpenAIError(errors.New("no images generated"), types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
openAIResponse, err := convertToOaiImageResponse(&geminiResponse)
|
||||
if err != nil {
|
||||
return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
jsonResponse, jsonErr := json.Marshal(openAIResponse)
|
||||
if jsonErr != nil {
|
||||
return nil, types.NewError(jsonErr, types.ErrorCodeBadResponseBody)
|
||||
}
|
||||
|
||||
c.Writer.Header().Set("Content-Type", "application/json")
|
||||
c.Writer.WriteHeader(resp.StatusCode)
|
||||
_, _ = c.Writer.Write(jsonResponse)
|
||||
|
||||
usage := &dto.Usage{
|
||||
PromptTokens: geminiResponse.UsageMetadata.PromptTokenCount,
|
||||
CompletionTokens: geminiResponse.UsageMetadata.CandidatesTokenCount,
|
||||
TotalTokens: geminiResponse.UsageMetadata.TotalTokenCount,
|
||||
}
|
||||
|
||||
return usage, nil
|
||||
}
|
||||
|
||||
@@ -163,7 +163,7 @@ func handleTTSResponse(c *gin.Context, resp *http.Response, info *relaycommon.Re
|
||||
}
|
||||
|
||||
usage = &dto.Usage{
|
||||
PromptTokens: info.PromptTokens,
|
||||
PromptTokens: info.GetEstimatePromptTokens(),
|
||||
CompletionTokens: 0,
|
||||
TotalTokens: int(minimaxResp.ExtraInfo.UsageCharacters),
|
||||
}
|
||||
|
||||
@@ -306,10 +306,11 @@ func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayIn
|
||||
request.Temperature = nil
|
||||
}
|
||||
|
||||
// gpt-5系列模型适配 归零不再支持的参数
|
||||
if strings.HasPrefix(info.UpstreamModelName, "gpt-5") {
|
||||
if info.UpstreamModelName != "gpt-5-chat-latest" {
|
||||
request.Temperature = nil
|
||||
}
|
||||
request.Temperature = nil
|
||||
request.TopP = 0 // oai 的 top_p 默认值是 1.0,但是为了 omitempty 属性直接不传,这里显式设置为 0
|
||||
request.LogProbs = false
|
||||
}
|
||||
|
||||
// 转换模型推理力度后缀
|
||||
|
||||
@@ -183,7 +183,7 @@ func OaiStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Re
|
||||
}
|
||||
|
||||
if !containStreamUsage {
|
||||
usage = service.ResponseText2Usage(c, responseTextBuilder.String(), info.UpstreamModelName, info.PromptTokens)
|
||||
usage = service.ResponseText2Usage(c, responseTextBuilder.String(), info.UpstreamModelName, info.GetEstimatePromptTokens())
|
||||
usage.CompletionTokens += toolCount * 7
|
||||
}
|
||||
|
||||
@@ -245,9 +245,9 @@ func OpenaiHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Respo
|
||||
}
|
||||
}
|
||||
simpleResponse.Usage = dto.Usage{
|
||||
PromptTokens: info.PromptTokens,
|
||||
PromptTokens: info.GetEstimatePromptTokens(),
|
||||
CompletionTokens: completionTokens,
|
||||
TotalTokens: info.PromptTokens + completionTokens,
|
||||
TotalTokens: info.GetEstimatePromptTokens() + completionTokens,
|
||||
}
|
||||
usageModified = true
|
||||
}
|
||||
@@ -336,8 +336,8 @@ func OpenaiTTSHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
|
||||
// and can be terminated directly.
|
||||
defer service.CloseResponseBodyGracefully(resp)
|
||||
usage := &dto.Usage{}
|
||||
usage.PromptTokens = info.PromptTokens
|
||||
usage.TotalTokens = info.PromptTokens
|
||||
usage.PromptTokens = info.GetEstimatePromptTokens()
|
||||
usage.TotalTokens = info.GetEstimatePromptTokens()
|
||||
for k, v := range resp.Header {
|
||||
c.Writer.Header().Set(k, v[0])
|
||||
}
|
||||
@@ -383,7 +383,7 @@ func OpenaiSTTHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
|
||||
}
|
||||
|
||||
usage := &dto.Usage{}
|
||||
usage.PromptTokens = info.PromptTokens
|
||||
usage.PromptTokens = info.GetEstimatePromptTokens()
|
||||
usage.CompletionTokens = 0
|
||||
usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
|
||||
return nil, usage
|
||||
|
||||
@@ -141,7 +141,7 @@ func OaiResponsesStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp
|
||||
}
|
||||
|
||||
if usage.PromptTokens == 0 && usage.CompletionTokens != 0 {
|
||||
usage.PromptTokens = info.PromptTokens
|
||||
usage.PromptTokens = info.GetEstimatePromptTokens()
|
||||
}
|
||||
|
||||
usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
|
||||
|
||||
@@ -81,7 +81,7 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
|
||||
if info.IsStream {
|
||||
var responseText string
|
||||
err, responseText = palmStreamHandler(c, resp)
|
||||
usage = service.ResponseText2Usage(c, responseText, info.UpstreamModelName, info.PromptTokens)
|
||||
usage = service.ResponseText2Usage(c, responseText, info.UpstreamModelName, info.GetEstimatePromptTokens())
|
||||
} else {
|
||||
usage, err = palmHandler(c, info, resp)
|
||||
}
|
||||
|
||||
@@ -121,13 +121,8 @@ func palmHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Respons
|
||||
}, resp.StatusCode)
|
||||
}
|
||||
fullTextResponse := responsePaLM2OpenAI(&palmResponse)
|
||||
completionTokens := service.CountTextToken(palmResponse.Candidates[0].Content, info.UpstreamModelName)
|
||||
usage := dto.Usage{
|
||||
PromptTokens: info.PromptTokens,
|
||||
CompletionTokens: completionTokens,
|
||||
TotalTokens: info.PromptTokens + completionTokens,
|
||||
}
|
||||
fullTextResponse.Usage = usage
|
||||
usage := service.ResponseText2Usage(c, palmResponse.Candidates[0].Content, info.UpstreamModelName, info.GetEstimatePromptTokens())
|
||||
fullTextResponse.Usage = *usage
|
||||
jsonResponse, err := common.Marshal(fullTextResponse)
|
||||
if err != nil {
|
||||
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
|
||||
@@ -135,5 +130,5 @@ func palmHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Respons
|
||||
c.Writer.Header().Set("Content-Type", "application/json")
|
||||
c.Writer.WriteHeader(resp.StatusCode)
|
||||
service.IOCopyBytesGracefully(c, resp, jsonResponse)
|
||||
return &usage, nil
|
||||
return usage, nil
|
||||
}
|
||||
|
||||
@@ -393,7 +393,7 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *rela
|
||||
}
|
||||
|
||||
// FetchTask 查询任务状态
|
||||
func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http.Response, error) {
|
||||
func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any, proxy string) (*http.Response, error) {
|
||||
taskID, ok := body["task_id"].(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid task_id")
|
||||
@@ -408,7 +408,11 @@ func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http
|
||||
|
||||
req.Header.Set("Authorization", "Bearer "+key)
|
||||
|
||||
return service.GetHttpClient().Do(req)
|
||||
client, err := service.GetHttpClientWithProxy(proxy)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("new proxy http client failed: %w", err)
|
||||
}
|
||||
return client.Do(req)
|
||||
}
|
||||
|
||||
func (a *TaskAdaptor) GetModelList() []string {
|
||||
|
||||
@@ -146,7 +146,7 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *rela
|
||||
}
|
||||
|
||||
// FetchTask fetch task status
|
||||
func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http.Response, error) {
|
||||
func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any, proxy string) (*http.Response, error) {
|
||||
taskID, ok := body["task_id"].(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid task_id")
|
||||
@@ -163,7 +163,11 @@ func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", "Bearer "+key)
|
||||
|
||||
return service.GetHttpClient().Do(req)
|
||||
client, err := service.GetHttpClientWithProxy(proxy)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("new proxy http client failed: %w", err)
|
||||
}
|
||||
return client.Do(req)
|
||||
}
|
||||
|
||||
func (a *TaskAdaptor) GetModelList() []string {
|
||||
|
||||
@@ -24,9 +24,13 @@ import (
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
// VideoGenerationConfig represents the video generation configuration
|
||||
// ============================
|
||||
// Request / Response structures
|
||||
// ============================
|
||||
|
||||
// GeminiVideoGenerationConfig represents the video generation configuration
|
||||
// Based on: https://ai.google.dev/gemini-api/docs/video
|
||||
type VideoGenerationConfig struct {
|
||||
type GeminiVideoGenerationConfig struct {
|
||||
AspectRatio string `json:"aspectRatio,omitempty"` // "16:9" or "9:16"
|
||||
DurationSeconds float64 `json:"durationSeconds,omitempty"` // 4, 6, or 8 (as number)
|
||||
NegativePrompt string `json:"negativePrompt,omitempty"` // unwanted elements
|
||||
@@ -34,21 +38,15 @@ type VideoGenerationConfig struct {
|
||||
Resolution string `json:"resolution,omitempty"` // video resolution
|
||||
}
|
||||
|
||||
type Image struct {
|
||||
BytesBase64Encoded string `json:"bytesBase64Encoded,omitempty"`
|
||||
MimeType string `json:"mimeType,omitempty"`
|
||||
// GeminiVideoRequest represents a single video generation instance
|
||||
type GeminiVideoRequest struct {
|
||||
Prompt string `json:"prompt"`
|
||||
}
|
||||
|
||||
type VideoRequest struct {
|
||||
Prompt string `json:"prompt"`
|
||||
Image *Image `json:"image,omitempty"`
|
||||
LastFrame *Image `json:"lastFrame,omitempty"`
|
||||
}
|
||||
|
||||
// VideoPayload represents the complete video generation request payload
|
||||
type VideoPayload struct {
|
||||
Instances []VideoRequest `json:"instances"`
|
||||
Parameters VideoGenerationConfig `json:"parameters,omitempty"`
|
||||
// GeminiVideoPayload represents the complete video generation request payload
|
||||
type GeminiVideoPayload struct {
|
||||
Instances []GeminiVideoRequest `json:"instances"`
|
||||
Parameters GeminiVideoGenerationConfig `json:"parameters,omitempty"`
|
||||
}
|
||||
|
||||
type submitResponse struct {
|
||||
@@ -77,8 +75,6 @@ type operationResponse struct {
|
||||
URI string `json:"uri"`
|
||||
} `json:"video"`
|
||||
} `json:"generatedSamples"`
|
||||
RaiMediaFilteredCount int `json:"raiMediaFilteredCount"`
|
||||
RaiMediaFilteredReasons []string `json:"raiMediaFilteredReasons"`
|
||||
} `json:"generateVideoResponse"`
|
||||
} `json:"response"`
|
||||
Error struct {
|
||||
@@ -104,7 +100,8 @@ func (a *TaskAdaptor) Init(info *relaycommon.RelayInfo) {
|
||||
|
||||
// ValidateRequestAndSetAction parses body, validates fields and sets default action.
|
||||
func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.RelayInfo) (taskErr *dto.TaskError) {
|
||||
return relaycommon.ValidateBasicTaskRequest(c, info, constant.TaskActionGenerate)
|
||||
// Use the standard validation method for TaskSubmitReq
|
||||
return relaycommon.ValidateBasicTaskRequest(c, info, constant.TaskActionTextGenerate)
|
||||
}
|
||||
|
||||
// BuildRequestURL constructs the upstream URL.
|
||||
@@ -140,21 +137,13 @@ func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayIn
|
||||
}
|
||||
|
||||
// Create structured video generation request
|
||||
body := VideoPayload{
|
||||
Instances: []VideoRequest{
|
||||
body := GeminiVideoPayload{
|
||||
Instances: []GeminiVideoRequest{
|
||||
{Prompt: req.Prompt},
|
||||
},
|
||||
Parameters: VideoGenerationConfig{},
|
||||
Parameters: GeminiVideoGenerationConfig{},
|
||||
}
|
||||
|
||||
if len(req.Images) > 0 {
|
||||
body.Instances[0].Image = a.convertImage(req.Images[0])
|
||||
}
|
||||
if len(req.Images) > 1 {
|
||||
body.Instances[0].LastFrame = a.convertImage(req.Images[1])
|
||||
}
|
||||
|
||||
// Parse metadata for additional configuration
|
||||
metadata := req.Metadata
|
||||
medaBytes, err := json.Marshal(metadata)
|
||||
if err != nil {
|
||||
@@ -211,7 +200,7 @@ func (a *TaskAdaptor) GetChannelName() string {
|
||||
}
|
||||
|
||||
// FetchTask fetch task status
|
||||
func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http.Response, error) {
|
||||
func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any, proxy string) (*http.Response, error) {
|
||||
taskID, ok := body["task_id"].(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid task_id")
|
||||
@@ -234,7 +223,11 @@ func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http
|
||||
req.Header.Set("Accept", "application/json")
|
||||
req.Header.Set("x-goog-api-key", key)
|
||||
|
||||
return service.GetHttpClient().Do(req)
|
||||
client, err := service.GetHttpClientWithProxy(proxy)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("new proxy http client failed: %w", err)
|
||||
}
|
||||
return client.Do(req)
|
||||
}
|
||||
|
||||
func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, error) {
|
||||
@@ -258,19 +251,20 @@ func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, e
|
||||
return ti, nil
|
||||
}
|
||||
|
||||
if len(op.Response.GenerateVideoResponse.GeneratedSamples) == 0 {
|
||||
ti.Status = model.TaskStatusFailure
|
||||
ti.Reason = fmt.Sprintf("no generated video url found: %s", strings.Join(op.Response.GenerateVideoResponse.RaiMediaFilteredReasons, "; "))
|
||||
} else {
|
||||
if uri := op.Response.GenerateVideoResponse.GeneratedSamples[0].Video.URI; uri != "" {
|
||||
ti.RemoteUrl = uri
|
||||
}
|
||||
ti.Status = model.TaskStatusSuccess
|
||||
}
|
||||
ti.Status = model.TaskStatusSuccess
|
||||
ti.Progress = "100%"
|
||||
|
||||
taskID := encodeLocalTaskID(op.Name)
|
||||
ti.TaskID = taskID
|
||||
ti.Url = fmt.Sprintf("%s/v1/videos/%s/content", system_setting.ServerAddress, taskID)
|
||||
|
||||
// Extract URL from generateVideoResponse if available
|
||||
if len(op.Response.GenerateVideoResponse.GeneratedSamples) > 0 {
|
||||
if uri := op.Response.GenerateVideoResponse.GeneratedSamples[0].Video.URI; uri != "" {
|
||||
ti.RemoteUrl = uri
|
||||
}
|
||||
}
|
||||
|
||||
return ti, nil
|
||||
}
|
||||
|
||||
@@ -299,30 +293,6 @@ func (a *TaskAdaptor) ConvertToOpenAIVideo(task *model.Task) ([]byte, error) {
|
||||
return common.Marshal(video)
|
||||
}
|
||||
|
||||
func (a *TaskAdaptor) convertImage(imageStr string) *Image {
|
||||
if strings.TrimSpace(imageStr) == "" {
|
||||
return nil
|
||||
}
|
||||
img := &Image{
|
||||
MimeType: "image/png",
|
||||
BytesBase64Encoded: imageStr,
|
||||
}
|
||||
if strings.HasPrefix(imageStr, "data:image/") {
|
||||
parts := strings.Split(imageStr, ";base64,")
|
||||
if len(parts) == 2 {
|
||||
img.MimeType = strings.TrimPrefix(parts[0], "data:")
|
||||
img.BytesBase64Encoded = parts[1]
|
||||
}
|
||||
} else if strings.HasPrefix(imageStr, "http") {
|
||||
mimeType, data, err := service.GetImageFromUrl(imageStr)
|
||||
if err == nil {
|
||||
img.MimeType = mimeType
|
||||
img.BytesBase64Encoded = data
|
||||
}
|
||||
}
|
||||
return img
|
||||
}
|
||||
|
||||
// ============================
|
||||
// helpers
|
||||
// ============================
|
||||
|
||||
@@ -110,7 +110,7 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *rela
|
||||
return hResp.TaskID, responseBody, nil
|
||||
}
|
||||
|
||||
func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http.Response, error) {
|
||||
func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any, proxy string) (*http.Response, error) {
|
||||
taskID, ok := body["task_id"].(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid task_id")
|
||||
@@ -126,7 +126,11 @@ func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http
|
||||
req.Header.Set("Accept", "application/json")
|
||||
req.Header.Set("Authorization", "Bearer "+key)
|
||||
|
||||
return service.GetHttpClient().Do(req)
|
||||
client, err := service.GetHttpClientWithProxy(proxy)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("new proxy http client failed: %w", err)
|
||||
}
|
||||
return client.Do(req)
|
||||
}
|
||||
|
||||
func (a *TaskAdaptor) GetModelList() []string {
|
||||
|
||||
@@ -210,7 +210,7 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *rela
|
||||
}
|
||||
|
||||
// FetchTask fetch task status
|
||||
func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http.Response, error) {
|
||||
func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any, proxy string) (*http.Response, error) {
|
||||
taskID, ok := body["task_id"].(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid task_id")
|
||||
@@ -251,7 +251,11 @@ func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http
|
||||
return nil, errors.Wrap(err, "sign request failed")
|
||||
}
|
||||
}
|
||||
return service.GetHttpClient().Do(req)
|
||||
client, err := service.GetHttpClientWithProxy(proxy)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("new proxy http client failed: %w", err)
|
||||
}
|
||||
return client.Do(req)
|
||||
}
|
||||
|
||||
func (a *TaskAdaptor) GetModelList() []string {
|
||||
|
||||
@@ -199,7 +199,7 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *rela
|
||||
}
|
||||
|
||||
// FetchTask fetch task status
|
||||
func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http.Response, error) {
|
||||
func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any, proxy string) (*http.Response, error) {
|
||||
taskID, ok := body["task_id"].(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid task_id")
|
||||
@@ -228,7 +228,11 @@ func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http
|
||||
req.Header.Set("Authorization", "Bearer "+token)
|
||||
req.Header.Set("User-Agent", "kling-sdk/1.0")
|
||||
|
||||
return service.GetHttpClient().Do(req)
|
||||
client, err := service.GetHttpClientWithProxy(proxy)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("new proxy http client failed: %w", err)
|
||||
}
|
||||
return client.Do(req)
|
||||
}
|
||||
|
||||
func (a *TaskAdaptor) GetModelList() []string {
|
||||
|
||||
@@ -5,8 +5,10 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
"github.com/QuantumNous/new-api/constant"
|
||||
"github.com/QuantumNous/new-api/dto"
|
||||
"github.com/QuantumNous/new-api/model"
|
||||
"github.com/QuantumNous/new-api/relay/channel"
|
||||
@@ -67,11 +69,30 @@ func (a *TaskAdaptor) Init(info *relaycommon.RelayInfo) {
|
||||
a.apiKey = info.ApiKey
|
||||
}
|
||||
|
||||
func validateRemixRequest(c *gin.Context) *dto.TaskError {
|
||||
var req struct {
|
||||
Prompt string `json:"prompt"`
|
||||
}
|
||||
if err := common.UnmarshalBodyReusable(c, &req); err != nil {
|
||||
return service.TaskErrorWrapperLocal(err, "invalid_request", http.StatusBadRequest)
|
||||
}
|
||||
if strings.TrimSpace(req.Prompt) == "" {
|
||||
return service.TaskErrorWrapperLocal(fmt.Errorf("field prompt is required"), "invalid_request", http.StatusBadRequest)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.RelayInfo) (taskErr *dto.TaskError) {
|
||||
if info.Action == constant.TaskActionRemix {
|
||||
return validateRemixRequest(c)
|
||||
}
|
||||
return relaycommon.ValidateMultipartDirect(c, info)
|
||||
}
|
||||
|
||||
func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
if info.Action == constant.TaskActionRemix {
|
||||
return fmt.Sprintf("%s/v1/videos/%s/remix", a.baseURL, info.OriginTaskID), nil
|
||||
}
|
||||
return fmt.Sprintf("%s/v1/videos", a.baseURL), nil
|
||||
}
|
||||
|
||||
@@ -125,7 +146,7 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, _ *relayco
|
||||
}
|
||||
|
||||
// FetchTask fetch task status
|
||||
func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http.Response, error) {
|
||||
func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any, proxy string) (*http.Response, error) {
|
||||
taskID, ok := body["task_id"].(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid task_id")
|
||||
@@ -140,7 +161,11 @@ func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http
|
||||
|
||||
req.Header.Set("Authorization", "Bearer "+key)
|
||||
|
||||
return service.GetHttpClient().Do(req)
|
||||
client, err := service.GetHttpClientWithProxy(proxy)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("new proxy http client failed: %w", err)
|
||||
}
|
||||
return client.Do(req)
|
||||
}
|
||||
|
||||
func (a *TaskAdaptor) GetModelList() []string {
|
||||
|
||||
@@ -132,7 +132,7 @@ func (a *TaskAdaptor) GetChannelName() string {
|
||||
return ChannelName
|
||||
}
|
||||
|
||||
func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http.Response, error) {
|
||||
func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any, proxy string) (*http.Response, error) {
|
||||
requestUrl := fmt.Sprintf("%s/suno/fetch", baseUrl)
|
||||
byteBody, err := json.Marshal(body)
|
||||
if err != nil {
|
||||
@@ -153,11 +153,11 @@ func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http
|
||||
req = req.WithContext(ctx)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", "Bearer "+key)
|
||||
resp, err := service.GetHttpClient().Do(req)
|
||||
client, err := service.GetHttpClientWithProxy(proxy)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, fmt.Errorf("new proxy http client failed: %w", err)
|
||||
}
|
||||
return resp, nil
|
||||
return client.Do(req)
|
||||
}
|
||||
|
||||
func actionValidate(c *gin.Context, sunoRequest *dto.SunoSubmitReq, action string) (err error) {
|
||||
|
||||
@@ -12,7 +12,6 @@ import (
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
"github.com/QuantumNous/new-api/model"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
"github.com/QuantumNous/new-api/constant"
|
||||
@@ -121,7 +120,11 @@ func (a *TaskAdaptor) BuildRequestHeader(c *gin.Context, req *http.Request, info
|
||||
return fmt.Errorf("failed to decode credentials: %w", err)
|
||||
}
|
||||
|
||||
token, err := vertexcore.AcquireAccessToken(*adc, "")
|
||||
proxy := ""
|
||||
if info != nil {
|
||||
proxy = info.ChannelSetting.Proxy
|
||||
}
|
||||
token, err := vertexcore.AcquireAccessToken(*adc, proxy)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to acquire access token: %w", err)
|
||||
}
|
||||
@@ -147,13 +150,40 @@ func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayIn
|
||||
body.Parameters["storageUri"] = v
|
||||
}
|
||||
if v, ok := req.Metadata["sampleCount"]; ok {
|
||||
body.Parameters["sampleCount"] = v
|
||||
if i, ok := v.(int); ok {
|
||||
body.Parameters["sampleCount"] = i
|
||||
}
|
||||
if f, ok := v.(float64); ok {
|
||||
body.Parameters["sampleCount"] = int(f)
|
||||
}
|
||||
}
|
||||
}
|
||||
if _, ok := body.Parameters["sampleCount"]; !ok {
|
||||
body.Parameters["sampleCount"] = 1
|
||||
}
|
||||
|
||||
if body.Parameters["sampleCount"].(int) <= 0 {
|
||||
return nil, fmt.Errorf("sampleCount must be greater than 0")
|
||||
}
|
||||
|
||||
// if req.Duration > 0 {
|
||||
// body.Parameters["durationSeconds"] = req.Duration
|
||||
// } else if req.Seconds != "" {
|
||||
// seconds, err := strconv.Atoi(req.Seconds)
|
||||
// if err != nil {
|
||||
// return nil, errors.Wrap(err, "convert seconds to int failed")
|
||||
// }
|
||||
// body.Parameters["durationSeconds"] = seconds
|
||||
// }
|
||||
|
||||
info.PriceData.OtherRatios = map[string]float64{
|
||||
"sampleCount": float64(body.Parameters["sampleCount"].(int)),
|
||||
}
|
||||
|
||||
// if v, ok := body.Parameters["durationSeconds"]; ok {
|
||||
// info.PriceData.OtherRatios["durationSeconds"] = float64(v.(int))
|
||||
// }
|
||||
|
||||
data, err := json.Marshal(body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -190,7 +220,7 @@ func (a *TaskAdaptor) GetModelList() []string { return []string{"veo-3.0-generat
|
||||
func (a *TaskAdaptor) GetChannelName() string { return "vertex" }
|
||||
|
||||
// FetchTask fetch task status
|
||||
func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http.Response, error) {
|
||||
func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any, proxy string) (*http.Response, error) {
|
||||
taskID, ok := body["task_id"].(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid task_id")
|
||||
@@ -223,7 +253,7 @@ func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http
|
||||
if err := json.Unmarshal([]byte(key), adc); err != nil {
|
||||
return nil, fmt.Errorf("failed to decode credentials: %w", err)
|
||||
}
|
||||
token, err := vertexcore.AcquireAccessToken(*adc, "")
|
||||
token, err := vertexcore.AcquireAccessToken(*adc, proxy)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to acquire access token: %w", err)
|
||||
}
|
||||
@@ -235,7 +265,11 @@ func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http
|
||||
req.Header.Set("Accept", "application/json")
|
||||
req.Header.Set("Authorization", "Bearer "+token)
|
||||
req.Header.Set("x-goog-user-project", adc.ProjectID)
|
||||
return service.GetHttpClient().Do(req)
|
||||
client, err := service.GetHttpClientWithProxy(proxy)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("new proxy http client failed: %w", err)
|
||||
}
|
||||
return client.Do(req)
|
||||
}
|
||||
|
||||
func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, error) {
|
||||
|
||||
@@ -188,7 +188,7 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *rela
|
||||
return vResp.TaskId, responseBody, nil
|
||||
}
|
||||
|
||||
func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http.Response, error) {
|
||||
func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any, proxy string) (*http.Response, error) {
|
||||
taskID, ok := body["task_id"].(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid task_id")
|
||||
@@ -204,7 +204,11 @@ func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http
|
||||
req.Header.Set("Accept", "application/json")
|
||||
req.Header.Set("Authorization", "Token "+key)
|
||||
|
||||
return service.GetHttpClient().Do(req)
|
||||
client, err := service.GetHttpClientWithProxy(proxy)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("new proxy http client failed: %w", err)
|
||||
}
|
||||
return client.Do(req)
|
||||
}
|
||||
|
||||
func (a *TaskAdaptor) GetModelList() []string {
|
||||
|
||||
@@ -105,7 +105,7 @@ func tencentStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *htt
|
||||
data = strings.TrimPrefix(data, "data:")
|
||||
|
||||
var tencentResponse TencentChatResponse
|
||||
err := json.Unmarshal([]byte(data), &tencentResponse)
|
||||
err := common.Unmarshal([]byte(data), &tencentResponse)
|
||||
if err != nil {
|
||||
common.SysLog("error unmarshalling stream response: " + err.Error())
|
||||
continue
|
||||
@@ -130,7 +130,7 @@ func tencentStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *htt
|
||||
|
||||
service.CloseResponseBodyGracefully(resp)
|
||||
|
||||
return service.ResponseText2Usage(c, responseText, info.UpstreamModelName, info.PromptTokens), nil
|
||||
return service.ResponseText2Usage(c, responseText, info.UpstreamModelName, info.GetEstimatePromptTokens()), nil
|
||||
}
|
||||
|
||||
func tencentHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
|
||||
|
||||
@@ -17,6 +17,7 @@ import (
|
||||
relaycommon "github.com/QuantumNous/new-api/relay/common"
|
||||
"github.com/QuantumNous/new-api/relay/constant"
|
||||
"github.com/QuantumNous/new-api/setting/model_setting"
|
||||
"github.com/QuantumNous/new-api/setting/reasoning"
|
||||
"github.com/QuantumNous/new-api/types"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
@@ -181,6 +182,8 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
info.UpstreamModelName = strings.TrimSuffix(info.UpstreamModelName, "-thinking")
|
||||
} else if strings.HasSuffix(info.UpstreamModelName, "-nothinking") {
|
||||
info.UpstreamModelName = strings.TrimSuffix(info.UpstreamModelName, "-nothinking")
|
||||
} else if baseModel, level, ok := reasoning.TrimEffortSuffix(info.UpstreamModelName); ok && level != "" {
|
||||
info.UpstreamModelName = baseModel
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -184,9 +184,9 @@ func handleTTSResponse(c *gin.Context, resp *http.Response, info *relaycommon.Re
|
||||
c.Data(http.StatusOK, contentType, audioData)
|
||||
|
||||
usage = &dto.Usage{
|
||||
PromptTokens: info.PromptTokens,
|
||||
PromptTokens: info.GetEstimatePromptTokens(),
|
||||
CompletionTokens: 0,
|
||||
TotalTokens: info.PromptTokens,
|
||||
TotalTokens: info.GetEstimatePromptTokens(),
|
||||
}
|
||||
|
||||
return usage, nil
|
||||
@@ -284,9 +284,9 @@ func handleTTSWebSocketResponse(c *gin.Context, requestURL string, volcRequest V
|
||||
if msg.Sequence < 0 {
|
||||
c.Status(http.StatusOK)
|
||||
usage = &dto.Usage{
|
||||
PromptTokens: info.PromptTokens,
|
||||
PromptTokens: info.GetEstimatePromptTokens(),
|
||||
CompletionTokens: 0,
|
||||
TotalTokens: info.PromptTokens,
|
||||
TotalTokens: info.GetEstimatePromptTokens(),
|
||||
}
|
||||
return usage, nil
|
||||
}
|
||||
@@ -297,9 +297,9 @@ func handleTTSWebSocketResponse(c *gin.Context, requestURL string, volcRequest V
|
||||
|
||||
c.Status(http.StatusOK)
|
||||
usage = &dto.Usage{
|
||||
PromptTokens: info.PromptTokens,
|
||||
PromptTokens: info.GetEstimatePromptTokens(),
|
||||
CompletionTokens: 0,
|
||||
TotalTokens: info.PromptTokens,
|
||||
TotalTokens: info.GetEstimatePromptTokens(),
|
||||
}
|
||||
return usage, nil
|
||||
}
|
||||
|
||||
@@ -70,7 +70,7 @@ func xAIStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Re
|
||||
})
|
||||
|
||||
if !containStreamUsage {
|
||||
usage = service.ResponseText2Usage(c, responseTextBuilder.String(), info.UpstreamModelName, info.PromptTokens)
|
||||
usage = service.ResponseText2Usage(c, responseTextBuilder.String(), info.UpstreamModelName, info.GetEstimatePromptTokens())
|
||||
usage.CompletionTokens += toolCount * 7
|
||||
}
|
||||
|
||||
|
||||
@@ -36,8 +36,7 @@ func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInf
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
|
||||
//TODO implement me
|
||||
return nil, errors.New("not implemented")
|
||||
return request, nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
|
||||
@@ -63,6 +62,8 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
return fmt.Sprintf("%s/embeddings", specialPlan.OpenAIBaseURL), nil
|
||||
}
|
||||
return fmt.Sprintf("%s/api/paas/v4/embeddings", baseURL), nil
|
||||
case relayconstant.RelayModeImagesGenerations:
|
||||
return fmt.Sprintf("%s/api/paas/v4/images/generations", baseURL), nil
|
||||
default:
|
||||
if hasSpecialPlan && specialPlan.OpenAIBaseURL != "" {
|
||||
return fmt.Sprintf("%s/chat/completions", specialPlan.OpenAIBaseURL), nil
|
||||
@@ -114,6 +115,9 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
|
||||
return claude.ClaudeHandler(c, resp, info, claude.RequestModeMessage)
|
||||
}
|
||||
default:
|
||||
if info.RelayMode == relayconstant.RelayModeImagesGenerations {
|
||||
return zhipu4vImageHandler(c, resp, info)
|
||||
}
|
||||
adaptor := openai.Adaptor{}
|
||||
return adaptor.DoResponse(c, resp, info)
|
||||
}
|
||||
|
||||
127
relay/channel/zhipu_4v/image.go
Normal file
127
relay/channel/zhipu_4v/image.go
Normal file
@@ -0,0 +1,127 @@
|
||||
package zhipu_4v
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net/http"
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
"github.com/QuantumNous/new-api/dto"
|
||||
"github.com/QuantumNous/new-api/logger"
|
||||
relaycommon "github.com/QuantumNous/new-api/relay/common"
|
||||
"github.com/QuantumNous/new-api/service"
|
||||
"github.com/QuantumNous/new-api/types"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type zhipuImageRequest struct {
|
||||
Model string `json:"model"`
|
||||
Prompt string `json:"prompt"`
|
||||
Quality string `json:"quality,omitempty"`
|
||||
Size string `json:"size,omitempty"`
|
||||
WatermarkEnabled *bool `json:"watermark_enabled,omitempty"`
|
||||
UserID string `json:"user_id,omitempty"`
|
||||
}
|
||||
|
||||
type zhipuImageResponse struct {
|
||||
Created *int64 `json:"created,omitempty"`
|
||||
Data []zhipuImageData `json:"data,omitempty"`
|
||||
ContentFilter any `json:"content_filter,omitempty"`
|
||||
Usage *dto.Usage `json:"usage,omitempty"`
|
||||
Error *zhipuImageError `json:"error,omitempty"`
|
||||
RequestID string `json:"request_id,omitempty"`
|
||||
ExtendParam map[string]string `json:"extendParam,omitempty"`
|
||||
}
|
||||
|
||||
type zhipuImageError struct {
|
||||
Code string `json:"code"`
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
type zhipuImageData struct {
|
||||
Url string `json:"url,omitempty"`
|
||||
ImageUrl string `json:"image_url,omitempty"`
|
||||
B64Json string `json:"b64_json,omitempty"`
|
||||
B64Image string `json:"b64_image,omitempty"`
|
||||
}
|
||||
|
||||
type openAIImagePayload struct {
|
||||
Created int64 `json:"created"`
|
||||
Data []openAIImageData `json:"data"`
|
||||
}
|
||||
|
||||
type openAIImageData struct {
|
||||
B64Json string `json:"b64_json"`
|
||||
}
|
||||
|
||||
func zhipu4vImageHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.Usage, *types.NewAPIError) {
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError)
|
||||
}
|
||||
service.CloseResponseBodyGracefully(resp)
|
||||
|
||||
var zhipuResp zhipuImageResponse
|
||||
if err := common.Unmarshal(responseBody, &zhipuResp); err != nil {
|
||||
return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
if zhipuResp.Error != nil && zhipuResp.Error.Message != "" {
|
||||
return nil, types.WithOpenAIError(types.OpenAIError{
|
||||
Message: zhipuResp.Error.Message,
|
||||
Type: "zhipu_image_error",
|
||||
Code: zhipuResp.Error.Code,
|
||||
}, resp.StatusCode)
|
||||
}
|
||||
|
||||
payload := openAIImagePayload{}
|
||||
if zhipuResp.Created != nil && *zhipuResp.Created != 0 {
|
||||
payload.Created = *zhipuResp.Created
|
||||
} else {
|
||||
payload.Created = info.StartTime.Unix()
|
||||
}
|
||||
for _, data := range zhipuResp.Data {
|
||||
url := data.Url
|
||||
if url == "" {
|
||||
url = data.ImageUrl
|
||||
}
|
||||
if url == "" {
|
||||
logger.LogWarn(c, "zhipu_image_missing_url")
|
||||
continue
|
||||
}
|
||||
|
||||
var b64 string
|
||||
switch {
|
||||
case data.B64Json != "":
|
||||
b64 = data.B64Json
|
||||
case data.B64Image != "":
|
||||
b64 = data.B64Image
|
||||
default:
|
||||
_, downloaded, err := service.GetImageFromUrl(url)
|
||||
if err != nil {
|
||||
logger.LogError(c, "zhipu_image_get_b64_failed: "+err.Error())
|
||||
continue
|
||||
}
|
||||
b64 = downloaded
|
||||
}
|
||||
|
||||
if b64 == "" {
|
||||
logger.LogWarn(c, "zhipu_image_empty_b64")
|
||||
continue
|
||||
}
|
||||
|
||||
imageData := openAIImageData{
|
||||
B64Json: b64,
|
||||
}
|
||||
payload.Data = append(payload.Data, imageData)
|
||||
}
|
||||
|
||||
jsonResp, err := common.Marshal(payload)
|
||||
if err != nil {
|
||||
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
|
||||
}
|
||||
|
||||
service.IOCopyBytesGracefully(c, resp, jsonResp)
|
||||
|
||||
return &dto.Usage{}, nil
|
||||
}
|
||||
@@ -11,6 +11,8 @@ import (
|
||||
"github.com/tidwall/sjson"
|
||||
)
|
||||
|
||||
var negativeIndexRegexp = regexp.MustCompile(`\.(-\d+)`)
|
||||
|
||||
type ConditionOperation struct {
|
||||
Path string `json:"path"` // JSON路径
|
||||
Mode string `json:"mode"` // full, prefix, suffix, contains, gt, gte, lt, lte
|
||||
@@ -186,8 +188,7 @@ func checkSingleCondition(jsonStr, contextJSON string, condition ConditionOperat
|
||||
}
|
||||
|
||||
func processNegativeIndex(jsonStr string, path string) string {
|
||||
re := regexp.MustCompile(`\.(-\d+)`)
|
||||
matches := re.FindAllStringSubmatch(path, -1)
|
||||
matches := negativeIndexRegexp.FindAllStringSubmatch(path, -1)
|
||||
|
||||
if len(matches) == 0 {
|
||||
return path
|
||||
|
||||
@@ -73,6 +73,11 @@ type ChannelMeta struct {
|
||||
SupportStreamOptions bool // 是否支持流式选项
|
||||
}
|
||||
|
||||
type TokenCountMeta struct {
|
||||
//promptTokens int
|
||||
estimatePromptTokens int
|
||||
}
|
||||
|
||||
type RelayInfo struct {
|
||||
TokenId int
|
||||
TokenKey string
|
||||
@@ -91,7 +96,6 @@ type RelayInfo struct {
|
||||
RelayMode int
|
||||
OriginModelName string
|
||||
RequestURLPath string
|
||||
PromptTokens int
|
||||
ShouldIncludeUsage bool
|
||||
DisablePing bool // 是否禁止向下游发送自定义 Ping
|
||||
ClientWs *websocket.Conn
|
||||
@@ -115,6 +119,7 @@ type RelayInfo struct {
|
||||
Request dto.Request
|
||||
|
||||
ThinkingContentInfo
|
||||
TokenCountMeta
|
||||
*ClaudeConvertInfo
|
||||
*RerankerInfo
|
||||
*ResponsesUsageInfo
|
||||
@@ -189,7 +194,7 @@ func (info *RelayInfo) ToString() string {
|
||||
fmt.Fprintf(b, "IsPlayground: %t, ", info.IsPlayground)
|
||||
fmt.Fprintf(b, "RequestURLPath: %q, ", info.RequestURLPath)
|
||||
fmt.Fprintf(b, "OriginModelName: %q, ", info.OriginModelName)
|
||||
fmt.Fprintf(b, "PromptTokens: %d, ", info.PromptTokens)
|
||||
fmt.Fprintf(b, "EstimatePromptTokens: %d, ", info.estimatePromptTokens)
|
||||
fmt.Fprintf(b, "ShouldIncludeUsage: %t, ", info.ShouldIncludeUsage)
|
||||
fmt.Fprintf(b, "DisablePing: %t, ", info.DisablePing)
|
||||
fmt.Fprintf(b, "SendResponseCount: %d, ", info.SendResponseCount)
|
||||
@@ -391,7 +396,6 @@ func genBaseRelayInfo(c *gin.Context, request dto.Request) *RelayInfo {
|
||||
UserEmail: common.GetContextKeyString(c, constant.ContextKeyUserEmail),
|
||||
|
||||
OriginModelName: common.GetContextKeyString(c, constant.ContextKeyOriginalModel),
|
||||
PromptTokens: common.GetContextKeyInt(c, constant.ContextKeyPromptTokens),
|
||||
|
||||
TokenId: common.GetContextKeyInt(c, constant.ContextKeyTokenId),
|
||||
TokenKey: common.GetContextKeyString(c, constant.ContextKeyTokenKey),
|
||||
@@ -408,6 +412,10 @@ func genBaseRelayInfo(c *gin.Context, request dto.Request) *RelayInfo {
|
||||
IsFirstThinkingContent: true,
|
||||
SendLastThinkingContent: false,
|
||||
},
|
||||
TokenCountMeta: TokenCountMeta{
|
||||
//promptTokens: common.GetContextKeyInt(c, constant.ContextKeyPromptTokens),
|
||||
estimatePromptTokens: common.GetContextKeyInt(c, constant.ContextKeyEstimatedTokens),
|
||||
},
|
||||
}
|
||||
|
||||
if info.RelayMode == relayconstant.RelayModeUnknown {
|
||||
@@ -463,8 +471,16 @@ func GenRelayInfo(c *gin.Context, relayFormat types.RelayFormat, request dto.Req
|
||||
}
|
||||
}
|
||||
|
||||
func (info *RelayInfo) SetPromptTokens(promptTokens int) {
|
||||
info.PromptTokens = promptTokens
|
||||
//func (info *RelayInfo) SetPromptTokens(promptTokens int) {
|
||||
// info.promptTokens = promptTokens
|
||||
//}
|
||||
|
||||
func (info *RelayInfo) SetEstimatePromptTokens(promptTokens int) {
|
||||
info.estimatePromptTokens = promptTokens
|
||||
}
|
||||
|
||||
func (info *RelayInfo) GetEstimatePromptTokens() int {
|
||||
return info.estimatePromptTokens
|
||||
}
|
||||
|
||||
func (info *RelayInfo) SetFirstResponseTime() {
|
||||
|
||||
@@ -1,10 +1,7 @@
|
||||
package common
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
@@ -229,54 +226,3 @@ func ValidateBasicTaskRequest(c *gin.Context, info *RelayInfo, action string) *d
|
||||
storeTaskRequest(c, info, action, req)
|
||||
return nil
|
||||
}
|
||||
func GetImagesBase64sFromForm(c *gin.Context) ([]*Base64Data, error) {
|
||||
return GetBase64sFromForm(c, "image")
|
||||
}
|
||||
func GetImageBase64sFromForm(c *gin.Context) (*Base64Data, error) {
|
||||
base64s, err := GetImagesBase64sFromForm(c)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return base64s[0], nil
|
||||
}
|
||||
|
||||
type Base64Data struct {
|
||||
MimeType string
|
||||
Data string
|
||||
}
|
||||
|
||||
func (m Base64Data) String() string {
|
||||
return fmt.Sprintf("data:%s;base64,%s", m.MimeType, m.Data)
|
||||
}
|
||||
func GetBase64sFromForm(c *gin.Context, fieldName string) ([]*Base64Data, error) {
|
||||
mf := c.Request.MultipartForm
|
||||
if mf == nil {
|
||||
if _, err := c.MultipartForm(); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse image edit form request: %w", err)
|
||||
}
|
||||
mf = c.Request.MultipartForm
|
||||
}
|
||||
imageFiles, exists := mf.File[fieldName]
|
||||
if !exists || len(imageFiles) == 0 {
|
||||
return nil, errors.New("field " + fieldName + " is not found or empty")
|
||||
}
|
||||
var imageBase64s []*Base64Data
|
||||
for _, file := range imageFiles {
|
||||
image, err := file.Open()
|
||||
if err != nil {
|
||||
return nil, errors.New("failed to open image file")
|
||||
}
|
||||
defer image.Close()
|
||||
imageData, err := io.ReadAll(image)
|
||||
if err != nil {
|
||||
return nil, errors.New("failed to read image file")
|
||||
}
|
||||
mimeType := http.DetectContentType(imageData)
|
||||
base64Data := base64.StdEncoding.EncodeToString(imageData)
|
||||
imageBase64s = append(imageBase64s, &Base64Data{
|
||||
MimeType: mimeType,
|
||||
Data: base64Data,
|
||||
})
|
||||
}
|
||||
return imageBase64s, nil
|
||||
}
|
||||
|
||||
@@ -57,8 +57,8 @@ func RerankHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Respo
|
||||
jinaResp = dto.RerankResponse{
|
||||
Results: jinaRespResults,
|
||||
Usage: dto.Usage{
|
||||
PromptTokens: info.PromptTokens,
|
||||
TotalTokens: info.PromptTokens,
|
||||
PromptTokens: info.GetEstimatePromptTokens(),
|
||||
TotalTokens: info.GetEstimatePromptTokens(),
|
||||
},
|
||||
}
|
||||
} else {
|
||||
|
||||
@@ -192,9 +192,9 @@ func TextHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *types
|
||||
func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage *dto.Usage, extraContent string) {
|
||||
if usage == nil {
|
||||
usage = &dto.Usage{
|
||||
PromptTokens: relayInfo.PromptTokens,
|
||||
PromptTokens: relayInfo.GetEstimatePromptTokens(),
|
||||
CompletionTokens: 0,
|
||||
TotalTokens: relayInfo.PromptTokens,
|
||||
TotalTokens: relayInfo.GetEstimatePromptTokens(),
|
||||
}
|
||||
extraContent += "(可能是请求出错)"
|
||||
}
|
||||
|
||||
@@ -99,7 +99,10 @@ func ModelPriceHelper(c *gin.Context, info *relaycommon.RelayInfo, promptTokens
|
||||
// check if free model pre-consume is disabled
|
||||
if !operation_setting.GetQuotaSetting().EnableFreeModelPreConsume {
|
||||
// if model price or ratio is 0, do not pre-consume quota
|
||||
if usePrice {
|
||||
if groupRatioInfo.GroupRatio == 0 {
|
||||
preConsumedQuota = 0
|
||||
freeModel = true
|
||||
} else if usePrice {
|
||||
if modelPrice == 0 {
|
||||
preConsumedQuota = 0
|
||||
freeModel = true
|
||||
|
||||
@@ -72,6 +72,8 @@ func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon
|
||||
if common.DebugEnabled {
|
||||
// print timeout and ping interval for debugging
|
||||
println("relay timeout seconds:", common.RelayTimeout)
|
||||
println("relay max idle conns:", common.RelayMaxIdleConns)
|
||||
println("relay max idle conns per host:", common.RelayMaxIdleConnsPerHost)
|
||||
println("streaming timeout seconds:", int64(streamingTimeout.Seconds()))
|
||||
println("ping interval seconds:", int64(pingInterval.Seconds()))
|
||||
}
|
||||
|
||||
@@ -141,7 +141,6 @@ func GetAndValidOpenAIImageRequest(c *gin.Context, relayMode int) (*dto.ImageReq
|
||||
imageRequest.N = uint(common.String2Int(formData.Get("n")))
|
||||
imageRequest.Quality = formData.Get("quality")
|
||||
imageRequest.Size = formData.Get("size")
|
||||
imageRequest.ResponseFormat = formData.Get("response_format")
|
||||
if imageValue := formData.Get("image"); imageValue != "" {
|
||||
imageRequest.Image, _ = json.Marshal(imageValue)
|
||||
}
|
||||
|
||||
@@ -32,7 +32,94 @@ func RelayTaskSubmit(c *gin.Context, info *relaycommon.RelayInfo) (taskErr *dto.
|
||||
if info.TaskRelayInfo == nil {
|
||||
info.TaskRelayInfo = &relaycommon.TaskRelayInfo{}
|
||||
}
|
||||
path := c.Request.URL.Path
|
||||
if strings.Contains(path, "/v1/videos/") && strings.HasSuffix(path, "/remix") {
|
||||
info.Action = constant.TaskActionRemix
|
||||
}
|
||||
|
||||
// 提取 remix 任务的 video_id
|
||||
if info.Action == constant.TaskActionRemix {
|
||||
videoID := c.Param("video_id")
|
||||
if strings.TrimSpace(videoID) == "" {
|
||||
return service.TaskErrorWrapperLocal(fmt.Errorf("video_id is required"), "invalid_request", http.StatusBadRequest)
|
||||
}
|
||||
info.OriginTaskID = videoID
|
||||
}
|
||||
|
||||
platform := constant.TaskPlatform(c.GetString("platform"))
|
||||
|
||||
// 获取原始任务信息
|
||||
if info.OriginTaskID != "" {
|
||||
originTask, exist, err := model.GetByTaskId(info.UserId, info.OriginTaskID)
|
||||
if err != nil {
|
||||
taskErr = service.TaskErrorWrapper(err, "get_origin_task_failed", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
if !exist {
|
||||
taskErr = service.TaskErrorWrapperLocal(errors.New("task_origin_not_exist"), "task_not_exist", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if info.OriginModelName == "" {
|
||||
if originTask.Properties.OriginModelName != "" {
|
||||
info.OriginModelName = originTask.Properties.OriginModelName
|
||||
} else if originTask.Properties.UpstreamModelName != "" {
|
||||
info.OriginModelName = originTask.Properties.UpstreamModelName
|
||||
} else {
|
||||
var taskData map[string]interface{}
|
||||
_ = json.Unmarshal(originTask.Data, &taskData)
|
||||
if m, ok := taskData["model"].(string); ok && m != "" {
|
||||
info.OriginModelName = m
|
||||
platform = originTask.Platform
|
||||
}
|
||||
}
|
||||
}
|
||||
if originTask.ChannelId != info.ChannelId {
|
||||
channel, err := model.GetChannelById(originTask.ChannelId, true)
|
||||
if err != nil {
|
||||
taskErr = service.TaskErrorWrapperLocal(err, "channel_not_found", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if channel.Status != common.ChannelStatusEnabled {
|
||||
taskErr = service.TaskErrorWrapperLocal(errors.New("the channel of the origin task is disabled"), "task_channel_disable", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
key, _, newAPIError := channel.GetNextEnabledKey()
|
||||
if newAPIError != nil {
|
||||
taskErr = service.TaskErrorWrapper(newAPIError, "channel_no_available_key", newAPIError.StatusCode)
|
||||
return
|
||||
}
|
||||
common.SetContextKey(c, constant.ContextKeyChannelKey, key)
|
||||
common.SetContextKey(c, constant.ContextKeyChannelType, channel.Type)
|
||||
common.SetContextKey(c, constant.ContextKeyChannelBaseUrl, channel.GetBaseURL())
|
||||
common.SetContextKey(c, constant.ContextKeyChannelId, originTask.ChannelId)
|
||||
|
||||
info.ChannelBaseUrl = channel.GetBaseURL()
|
||||
info.ChannelId = originTask.ChannelId
|
||||
info.ChannelType = channel.Type
|
||||
info.ApiKey = key
|
||||
platform = originTask.Platform
|
||||
}
|
||||
|
||||
// 使用原始任务的参数
|
||||
if info.Action == constant.TaskActionRemix {
|
||||
var taskData map[string]interface{}
|
||||
_ = json.Unmarshal(originTask.Data, &taskData)
|
||||
secondsStr, _ := taskData["seconds"].(string)
|
||||
seconds, _ := strconv.Atoi(secondsStr)
|
||||
if seconds <= 0 {
|
||||
seconds = 4
|
||||
}
|
||||
sizeStr, _ := taskData["size"].(string)
|
||||
if info.PriceData.OtherRatios == nil {
|
||||
info.PriceData.OtherRatios = map[string]float64{}
|
||||
}
|
||||
info.PriceData.OtherRatios["seconds"] = float64(seconds)
|
||||
info.PriceData.OtherRatios["size"] = 1
|
||||
if sizeStr == "1792x1024" || sizeStr == "1024x1792" {
|
||||
info.PriceData.OtherRatios["size"] = 1.666667
|
||||
}
|
||||
}
|
||||
}
|
||||
if platform == "" {
|
||||
platform = GetTaskPlatform(c)
|
||||
}
|
||||
@@ -94,34 +181,6 @@ func RelayTaskSubmit(c *gin.Context, info *relaycommon.RelayInfo) (taskErr *dto.
|
||||
return
|
||||
}
|
||||
|
||||
if info.OriginTaskID != "" {
|
||||
originTask, exist, err := model.GetByTaskId(info.UserId, info.OriginTaskID)
|
||||
if err != nil {
|
||||
taskErr = service.TaskErrorWrapper(err, "get_origin_task_failed", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
if !exist {
|
||||
taskErr = service.TaskErrorWrapperLocal(errors.New("task_origin_not_exist"), "task_not_exist", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if originTask.ChannelId != info.ChannelId {
|
||||
channel, err := model.GetChannelById(originTask.ChannelId, true)
|
||||
if err != nil {
|
||||
taskErr = service.TaskErrorWrapperLocal(err, "channel_not_found", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if channel.Status != common.ChannelStatusEnabled {
|
||||
return service.TaskErrorWrapperLocal(errors.New("该任务所属渠道已被禁用"), "task_channel_disable", http.StatusBadRequest)
|
||||
}
|
||||
c.Set("base_url", channel.GetBaseURL())
|
||||
c.Set("channel_id", originTask.ChannelId)
|
||||
c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key))
|
||||
|
||||
info.ChannelBaseUrl = channel.GetBaseURL()
|
||||
info.ChannelId = originTask.ChannelId
|
||||
}
|
||||
}
|
||||
|
||||
// build body
|
||||
requestBody, err := adaptor.BuildRequestBody(c, info)
|
||||
if err != nil {
|
||||
@@ -326,6 +385,7 @@ func videoFetchByIDRespBodyBuilder(c *gin.Context) (respBody []byte, taskResp *d
|
||||
if channelModel.GetBaseURL() != "" {
|
||||
baseURL = channelModel.GetBaseURL()
|
||||
}
|
||||
proxy := channelModel.GetSetting().Proxy
|
||||
adaptor := GetTaskAdaptor(constant.TaskPlatform(strconv.Itoa(channelModel.Type)))
|
||||
if adaptor == nil {
|
||||
return
|
||||
@@ -333,7 +393,7 @@ func videoFetchByIDRespBodyBuilder(c *gin.Context) (respBody []byte, taskResp *d
|
||||
resp, err2 := adaptor.FetchTask(baseURL, channelModel.Key, map[string]any{
|
||||
"task_id": originTask.TaskID,
|
||||
"action": originTask.Action,
|
||||
})
|
||||
}, proxy)
|
||||
if err2 != nil || resp == nil {
|
||||
return
|
||||
}
|
||||
|
||||
@@ -14,6 +14,7 @@ func SetVideoRouter(router *gin.Engine) {
|
||||
videoV1Router.GET("/videos/:task_id/content", controller.VideoProxy)
|
||||
videoV1Router.POST("/video/generations", controller.RelayTask)
|
||||
videoV1Router.GET("/video/generations/:task_id", controller.RelayTask)
|
||||
videoV1Router.POST("/videos/:video_id/remix", controller.RelayTask)
|
||||
}
|
||||
// openai compatible API video routes
|
||||
// docs: https://platform.openai.com/docs/api-reference/videos/create
|
||||
|
||||
@@ -201,6 +201,10 @@ func generateStopBlock(index int) *dto.ClaudeResponse {
|
||||
}
|
||||
|
||||
func StreamResponseOpenAI2Claude(openAIResponse *dto.ChatCompletionsStreamResponse, info *relaycommon.RelayInfo) []*dto.ClaudeResponse {
|
||||
if info.ClaudeConvertInfo.Done {
|
||||
return nil
|
||||
}
|
||||
|
||||
var claudeResponses []*dto.ClaudeResponse
|
||||
if info.SendResponseCount == 1 {
|
||||
msg := &dto.ClaudeMediaMessage{
|
||||
@@ -209,7 +213,7 @@ func StreamResponseOpenAI2Claude(openAIResponse *dto.ChatCompletionsStreamRespon
|
||||
Type: "message",
|
||||
Role: "assistant",
|
||||
Usage: &dto.ClaudeUsage{
|
||||
InputTokens: info.PromptTokens,
|
||||
InputTokens: info.GetEstimatePromptTokens(),
|
||||
OutputTokens: 0,
|
||||
},
|
||||
}
|
||||
@@ -218,45 +222,117 @@ func StreamResponseOpenAI2Claude(openAIResponse *dto.ChatCompletionsStreamRespon
|
||||
Type: "message_start",
|
||||
Message: msg,
|
||||
})
|
||||
claudeResponses = append(claudeResponses)
|
||||
//claudeResponses = append(claudeResponses, &dto.ClaudeResponse{
|
||||
// Type: "ping",
|
||||
//})
|
||||
if openAIResponse.IsToolCall() {
|
||||
info.ClaudeConvertInfo.LastMessagesType = relaycommon.LastMessageTypeTools
|
||||
var toolCall dto.ToolCallResponse
|
||||
if len(openAIResponse.Choices) > 0 && len(openAIResponse.Choices[0].Delta.ToolCalls) > 0 {
|
||||
toolCall = openAIResponse.Choices[0].Delta.ToolCalls[0]
|
||||
} else {
|
||||
first := openAIResponse.GetFirstToolCall()
|
||||
if first != nil {
|
||||
toolCall = *first
|
||||
} else {
|
||||
toolCall = dto.ToolCallResponse{}
|
||||
}
|
||||
}
|
||||
resp := &dto.ClaudeResponse{
|
||||
Type: "content_block_start",
|
||||
ContentBlock: &dto.ClaudeMediaMessage{
|
||||
Id: openAIResponse.GetFirstToolCall().ID,
|
||||
Id: toolCall.ID,
|
||||
Type: "tool_use",
|
||||
Name: openAIResponse.GetFirstToolCall().Function.Name,
|
||||
Name: toolCall.Function.Name,
|
||||
Input: map[string]interface{}{},
|
||||
},
|
||||
}
|
||||
resp.SetIndex(0)
|
||||
claudeResponses = append(claudeResponses, resp)
|
||||
// 首块包含工具 delta,则追加 input_json_delta
|
||||
if toolCall.Function.Arguments != "" {
|
||||
claudeResponses = append(claudeResponses, &dto.ClaudeResponse{
|
||||
Index: &info.ClaudeConvertInfo.Index,
|
||||
Type: "content_block_delta",
|
||||
Delta: &dto.ClaudeMediaMessage{
|
||||
Type: "input_json_delta",
|
||||
PartialJson: &toolCall.Function.Arguments,
|
||||
},
|
||||
})
|
||||
}
|
||||
} else {
|
||||
|
||||
}
|
||||
// 判断首个响应是否存在内容(非标准的 OpenAI 响应)
|
||||
if len(openAIResponse.Choices) > 0 && len(openAIResponse.Choices[0].Delta.GetContentString()) > 0 {
|
||||
if len(openAIResponse.Choices) > 0 {
|
||||
reasoning := openAIResponse.Choices[0].Delta.GetReasoningContent()
|
||||
content := openAIResponse.Choices[0].Delta.GetContentString()
|
||||
|
||||
if reasoning != "" {
|
||||
claudeResponses = append(claudeResponses, &dto.ClaudeResponse{
|
||||
Index: &info.ClaudeConvertInfo.Index,
|
||||
Type: "content_block_start",
|
||||
ContentBlock: &dto.ClaudeMediaMessage{
|
||||
Type: "thinking",
|
||||
Thinking: common.GetPointer[string](""),
|
||||
},
|
||||
})
|
||||
claudeResponses = append(claudeResponses, &dto.ClaudeResponse{
|
||||
Index: &info.ClaudeConvertInfo.Index,
|
||||
Type: "content_block_delta",
|
||||
Delta: &dto.ClaudeMediaMessage{
|
||||
Type: "thinking_delta",
|
||||
Thinking: &reasoning,
|
||||
},
|
||||
})
|
||||
info.ClaudeConvertInfo.LastMessagesType = relaycommon.LastMessageTypeThinking
|
||||
} else if content != "" {
|
||||
claudeResponses = append(claudeResponses, &dto.ClaudeResponse{
|
||||
Index: &info.ClaudeConvertInfo.Index,
|
||||
Type: "content_block_start",
|
||||
ContentBlock: &dto.ClaudeMediaMessage{
|
||||
Type: "text",
|
||||
Text: common.GetPointer[string](""),
|
||||
},
|
||||
})
|
||||
claudeResponses = append(claudeResponses, &dto.ClaudeResponse{
|
||||
Index: &info.ClaudeConvertInfo.Index,
|
||||
Type: "content_block_delta",
|
||||
Delta: &dto.ClaudeMediaMessage{
|
||||
Type: "text_delta",
|
||||
Text: common.GetPointer[string](content),
|
||||
},
|
||||
})
|
||||
info.ClaudeConvertInfo.LastMessagesType = relaycommon.LastMessageTypeText
|
||||
}
|
||||
}
|
||||
|
||||
// 如果首块就带 finish_reason,需要立即发送停止块
|
||||
if len(openAIResponse.Choices) > 0 && openAIResponse.Choices[0].FinishReason != nil && *openAIResponse.Choices[0].FinishReason != "" {
|
||||
info.FinishReason = *openAIResponse.Choices[0].FinishReason
|
||||
claudeResponses = append(claudeResponses, generateStopBlock(info.ClaudeConvertInfo.Index))
|
||||
oaiUsage := openAIResponse.Usage
|
||||
if oaiUsage == nil {
|
||||
oaiUsage = info.ClaudeConvertInfo.Usage
|
||||
}
|
||||
if oaiUsage != nil {
|
||||
claudeResponses = append(claudeResponses, &dto.ClaudeResponse{
|
||||
Type: "message_delta",
|
||||
Usage: &dto.ClaudeUsage{
|
||||
InputTokens: oaiUsage.PromptTokens,
|
||||
OutputTokens: oaiUsage.CompletionTokens,
|
||||
CacheCreationInputTokens: oaiUsage.PromptTokensDetails.CachedCreationTokens,
|
||||
CacheReadInputTokens: oaiUsage.PromptTokensDetails.CachedTokens,
|
||||
},
|
||||
Delta: &dto.ClaudeMediaMessage{
|
||||
StopReason: common.GetPointer[string](stopReasonOpenAI2Claude(info.FinishReason)),
|
||||
},
|
||||
})
|
||||
}
|
||||
claudeResponses = append(claudeResponses, &dto.ClaudeResponse{
|
||||
Index: &info.ClaudeConvertInfo.Index,
|
||||
Type: "content_block_start",
|
||||
ContentBlock: &dto.ClaudeMediaMessage{
|
||||
Type: "text",
|
||||
Text: common.GetPointer[string](""),
|
||||
},
|
||||
Type: "message_stop",
|
||||
})
|
||||
claudeResponses = append(claudeResponses, &dto.ClaudeResponse{
|
||||
Index: &info.ClaudeConvertInfo.Index,
|
||||
Type: "content_block_delta",
|
||||
Delta: &dto.ClaudeMediaMessage{
|
||||
Type: "text_delta",
|
||||
Text: common.GetPointer[string](openAIResponse.Choices[0].Delta.GetContentString()),
|
||||
},
|
||||
})
|
||||
info.ClaudeConvertInfo.LastMessagesType = relaycommon.LastMessageTypeText
|
||||
info.ClaudeConvertInfo.Done = true
|
||||
}
|
||||
return claudeResponses
|
||||
}
|
||||
@@ -264,7 +340,7 @@ func StreamResponseOpenAI2Claude(openAIResponse *dto.ChatCompletionsStreamRespon
|
||||
if len(openAIResponse.Choices) == 0 {
|
||||
// no choices
|
||||
// 可能为非标准的 OpenAI 响应,判断是否已经完成
|
||||
if info.Done {
|
||||
if info.ClaudeConvertInfo.Done {
|
||||
claudeResponses = append(claudeResponses, generateStopBlock(info.ClaudeConvertInfo.Index))
|
||||
oaiUsage := info.ClaudeConvertInfo.Usage
|
||||
if oaiUsage != nil {
|
||||
@@ -288,16 +364,110 @@ func StreamResponseOpenAI2Claude(openAIResponse *dto.ChatCompletionsStreamRespon
|
||||
return claudeResponses
|
||||
} else {
|
||||
chosenChoice := openAIResponse.Choices[0]
|
||||
if chosenChoice.FinishReason != nil && *chosenChoice.FinishReason != "" {
|
||||
// should be done
|
||||
doneChunk := chosenChoice.FinishReason != nil && *chosenChoice.FinishReason != ""
|
||||
if doneChunk {
|
||||
info.FinishReason = *chosenChoice.FinishReason
|
||||
if !info.Done {
|
||||
return claudeResponses
|
||||
}
|
||||
|
||||
var claudeResponse dto.ClaudeResponse
|
||||
var isEmpty bool
|
||||
claudeResponse.Type = "content_block_delta"
|
||||
if len(chosenChoice.Delta.ToolCalls) > 0 {
|
||||
toolCalls := chosenChoice.Delta.ToolCalls
|
||||
if info.ClaudeConvertInfo.LastMessagesType != relaycommon.LastMessageTypeTools {
|
||||
claudeResponses = append(claudeResponses, generateStopBlock(info.ClaudeConvertInfo.Index))
|
||||
info.ClaudeConvertInfo.Index++
|
||||
}
|
||||
info.ClaudeConvertInfo.LastMessagesType = relaycommon.LastMessageTypeTools
|
||||
|
||||
for i, toolCall := range toolCalls {
|
||||
blockIndex := info.ClaudeConvertInfo.Index
|
||||
if toolCall.Index != nil {
|
||||
blockIndex = *toolCall.Index
|
||||
} else if len(toolCalls) > 1 {
|
||||
blockIndex = info.ClaudeConvertInfo.Index + i
|
||||
}
|
||||
|
||||
idx := blockIndex
|
||||
claudeResponses = append(claudeResponses, &dto.ClaudeResponse{
|
||||
Index: &idx,
|
||||
Type: "content_block_start",
|
||||
ContentBlock: &dto.ClaudeMediaMessage{
|
||||
Id: toolCall.ID,
|
||||
Type: "tool_use",
|
||||
Name: toolCall.Function.Name,
|
||||
Input: map[string]interface{}{},
|
||||
},
|
||||
})
|
||||
|
||||
claudeResponses = append(claudeResponses, &dto.ClaudeResponse{
|
||||
Index: &idx,
|
||||
Type: "content_block_delta",
|
||||
Delta: &dto.ClaudeMediaMessage{
|
||||
Type: "input_json_delta",
|
||||
PartialJson: &toolCall.Function.Arguments,
|
||||
},
|
||||
})
|
||||
|
||||
info.ClaudeConvertInfo.Index = blockIndex
|
||||
}
|
||||
} else {
|
||||
reasoning := chosenChoice.Delta.GetReasoningContent()
|
||||
textContent := chosenChoice.Delta.GetContentString()
|
||||
if reasoning != "" || textContent != "" {
|
||||
if reasoning != "" {
|
||||
if info.ClaudeConvertInfo.LastMessagesType != relaycommon.LastMessageTypeThinking {
|
||||
claudeResponses = append(claudeResponses, &dto.ClaudeResponse{
|
||||
Index: &info.ClaudeConvertInfo.Index,
|
||||
Type: "content_block_start",
|
||||
ContentBlock: &dto.ClaudeMediaMessage{
|
||||
Type: "thinking",
|
||||
Thinking: common.GetPointer[string](""),
|
||||
},
|
||||
})
|
||||
}
|
||||
info.ClaudeConvertInfo.LastMessagesType = relaycommon.LastMessageTypeThinking
|
||||
claudeResponse.Delta = &dto.ClaudeMediaMessage{
|
||||
Type: "thinking_delta",
|
||||
Thinking: &reasoning,
|
||||
}
|
||||
} else {
|
||||
if info.ClaudeConvertInfo.LastMessagesType != relaycommon.LastMessageTypeText {
|
||||
if info.ClaudeConvertInfo.LastMessagesType == relaycommon.LastMessageTypeThinking || info.ClaudeConvertInfo.LastMessagesType == relaycommon.LastMessageTypeTools {
|
||||
claudeResponses = append(claudeResponses, generateStopBlock(info.ClaudeConvertInfo.Index))
|
||||
info.ClaudeConvertInfo.Index++
|
||||
}
|
||||
claudeResponses = append(claudeResponses, &dto.ClaudeResponse{
|
||||
Index: &info.ClaudeConvertInfo.Index,
|
||||
Type: "content_block_start",
|
||||
ContentBlock: &dto.ClaudeMediaMessage{
|
||||
Type: "text",
|
||||
Text: common.GetPointer[string](""),
|
||||
},
|
||||
})
|
||||
}
|
||||
info.ClaudeConvertInfo.LastMessagesType = relaycommon.LastMessageTypeText
|
||||
claudeResponse.Delta = &dto.ClaudeMediaMessage{
|
||||
Type: "text_delta",
|
||||
Text: common.GetPointer[string](textContent),
|
||||
}
|
||||
}
|
||||
} else {
|
||||
isEmpty = true
|
||||
}
|
||||
}
|
||||
if info.Done {
|
||||
|
||||
claudeResponse.Index = &info.ClaudeConvertInfo.Index
|
||||
if !isEmpty && claudeResponse.Delta != nil {
|
||||
claudeResponses = append(claudeResponses, &claudeResponse)
|
||||
}
|
||||
|
||||
if doneChunk || info.ClaudeConvertInfo.Done {
|
||||
claudeResponses = append(claudeResponses, generateStopBlock(info.ClaudeConvertInfo.Index))
|
||||
oaiUsage := info.ClaudeConvertInfo.Usage
|
||||
oaiUsage := openAIResponse.Usage
|
||||
if oaiUsage == nil {
|
||||
oaiUsage = info.ClaudeConvertInfo.Usage
|
||||
}
|
||||
if oaiUsage != nil {
|
||||
claudeResponses = append(claudeResponses, &dto.ClaudeResponse{
|
||||
Type: "message_delta",
|
||||
@@ -315,83 +485,8 @@ func StreamResponseOpenAI2Claude(openAIResponse *dto.ChatCompletionsStreamRespon
|
||||
claudeResponses = append(claudeResponses, &dto.ClaudeResponse{
|
||||
Type: "message_stop",
|
||||
})
|
||||
} else {
|
||||
var claudeResponse dto.ClaudeResponse
|
||||
var isEmpty bool
|
||||
claudeResponse.Type = "content_block_delta"
|
||||
if len(chosenChoice.Delta.ToolCalls) > 0 {
|
||||
if info.ClaudeConvertInfo.LastMessagesType != relaycommon.LastMessageTypeTools {
|
||||
claudeResponses = append(claudeResponses, generateStopBlock(info.ClaudeConvertInfo.Index))
|
||||
info.ClaudeConvertInfo.Index++
|
||||
claudeResponses = append(claudeResponses, &dto.ClaudeResponse{
|
||||
Index: &info.ClaudeConvertInfo.Index,
|
||||
Type: "content_block_start",
|
||||
ContentBlock: &dto.ClaudeMediaMessage{
|
||||
Id: openAIResponse.GetFirstToolCall().ID,
|
||||
Type: "tool_use",
|
||||
Name: openAIResponse.GetFirstToolCall().Function.Name,
|
||||
Input: map[string]interface{}{},
|
||||
},
|
||||
})
|
||||
}
|
||||
info.ClaudeConvertInfo.LastMessagesType = relaycommon.LastMessageTypeTools
|
||||
// tools delta
|
||||
claudeResponse.Delta = &dto.ClaudeMediaMessage{
|
||||
Type: "input_json_delta",
|
||||
PartialJson: &chosenChoice.Delta.ToolCalls[0].Function.Arguments,
|
||||
}
|
||||
} else {
|
||||
reasoning := chosenChoice.Delta.GetReasoningContent()
|
||||
textContent := chosenChoice.Delta.GetContentString()
|
||||
if reasoning != "" || textContent != "" {
|
||||
if reasoning != "" {
|
||||
if info.ClaudeConvertInfo.LastMessagesType != relaycommon.LastMessageTypeThinking {
|
||||
//info.ClaudeConvertInfo.Index++
|
||||
claudeResponses = append(claudeResponses, &dto.ClaudeResponse{
|
||||
Index: &info.ClaudeConvertInfo.Index,
|
||||
Type: "content_block_start",
|
||||
ContentBlock: &dto.ClaudeMediaMessage{
|
||||
Type: "thinking",
|
||||
Thinking: common.GetPointer[string](""),
|
||||
},
|
||||
})
|
||||
}
|
||||
info.ClaudeConvertInfo.LastMessagesType = relaycommon.LastMessageTypeThinking
|
||||
// text delta
|
||||
claudeResponse.Delta = &dto.ClaudeMediaMessage{
|
||||
Type: "thinking_delta",
|
||||
Thinking: &reasoning,
|
||||
}
|
||||
} else {
|
||||
if info.ClaudeConvertInfo.LastMessagesType != relaycommon.LastMessageTypeText {
|
||||
if info.LastMessagesType == relaycommon.LastMessageTypeThinking || info.LastMessagesType == relaycommon.LastMessageTypeTools {
|
||||
claudeResponses = append(claudeResponses, generateStopBlock(info.ClaudeConvertInfo.Index))
|
||||
info.ClaudeConvertInfo.Index++
|
||||
}
|
||||
claudeResponses = append(claudeResponses, &dto.ClaudeResponse{
|
||||
Index: &info.ClaudeConvertInfo.Index,
|
||||
Type: "content_block_start",
|
||||
ContentBlock: &dto.ClaudeMediaMessage{
|
||||
Type: "text",
|
||||
Text: common.GetPointer[string](""),
|
||||
},
|
||||
})
|
||||
}
|
||||
info.ClaudeConvertInfo.LastMessagesType = relaycommon.LastMessageTypeText
|
||||
// text delta
|
||||
claudeResponse.Delta = &dto.ClaudeMediaMessage{
|
||||
Type: "text_delta",
|
||||
Text: common.GetPointer[string](textContent),
|
||||
}
|
||||
}
|
||||
} else {
|
||||
isEmpty = true
|
||||
}
|
||||
}
|
||||
claudeResponse.Index = &info.ClaudeConvertInfo.Index
|
||||
if !isEmpty {
|
||||
claudeResponses = append(claudeResponses, &claudeResponse)
|
||||
}
|
||||
info.ClaudeConvertInfo.Done = true
|
||||
return claudeResponses
|
||||
}
|
||||
}
|
||||
|
||||
@@ -734,12 +829,18 @@ func StreamResponseOpenAI2Gemini(openAIResponse *dto.ChatCompletionsStreamRespon
|
||||
geminiResponse := &dto.GeminiChatResponse{
|
||||
Candidates: make([]dto.GeminiChatCandidate, 0, len(openAIResponse.Choices)),
|
||||
UsageMetadata: dto.GeminiUsageMetadata{
|
||||
PromptTokenCount: info.PromptTokens,
|
||||
PromptTokenCount: info.GetEstimatePromptTokens(),
|
||||
CandidatesTokenCount: 0, // 流式响应中可能没有完整的 usage 信息
|
||||
TotalTokenCount: info.PromptTokens,
|
||||
TotalTokenCount: info.GetEstimatePromptTokens(),
|
||||
},
|
||||
}
|
||||
|
||||
if openAIResponse.Usage != nil {
|
||||
geminiResponse.UsageMetadata.PromptTokenCount = openAIResponse.Usage.PromptTokens
|
||||
geminiResponse.UsageMetadata.CandidatesTokenCount = openAIResponse.Usage.CompletionTokens
|
||||
geminiResponse.UsageMetadata.TotalTokenCount = openAIResponse.Usage.TotalTokens
|
||||
}
|
||||
|
||||
for _, choice := range openAIResponse.Choices {
|
||||
candidate := dto.GeminiChatCandidate{
|
||||
Index: int64(choice.Index),
|
||||
|
||||
@@ -34,12 +34,20 @@ func checkRedirect(req *http.Request, via []*http.Request) error {
|
||||
}
|
||||
|
||||
func InitHttpClient() {
|
||||
transport := &http.Transport{
|
||||
MaxIdleConns: common.RelayMaxIdleConns,
|
||||
MaxIdleConnsPerHost: common.RelayMaxIdleConnsPerHost,
|
||||
ForceAttemptHTTP2: true,
|
||||
}
|
||||
|
||||
if common.RelayTimeout == 0 {
|
||||
httpClient = &http.Client{
|
||||
Transport: transport,
|
||||
CheckRedirect: checkRedirect,
|
||||
}
|
||||
} else {
|
||||
httpClient = &http.Client{
|
||||
Transport: transport,
|
||||
Timeout: time.Duration(common.RelayTimeout) * time.Second,
|
||||
CheckRedirect: checkRedirect,
|
||||
}
|
||||
@@ -50,6 +58,14 @@ func GetHttpClient() *http.Client {
|
||||
return httpClient
|
||||
}
|
||||
|
||||
// GetHttpClientWithProxy returns the default client or a proxy-enabled one when proxyURL is provided.
|
||||
func GetHttpClientWithProxy(proxyURL string) (*http.Client, error) {
|
||||
if proxyURL == "" {
|
||||
return GetHttpClient(), nil
|
||||
}
|
||||
return NewProxyHttpClient(proxyURL)
|
||||
}
|
||||
|
||||
// ResetProxyClientCache 清空代理客户端缓存,确保下次使用时重新初始化
|
||||
func ResetProxyClientCache() {
|
||||
proxyClientLock.Lock()
|
||||
@@ -84,7 +100,10 @@ func NewProxyHttpClient(proxyURL string) (*http.Client, error) {
|
||||
case "http", "https":
|
||||
client := &http.Client{
|
||||
Transport: &http.Transport{
|
||||
Proxy: http.ProxyURL(parsedURL),
|
||||
MaxIdleConns: common.RelayMaxIdleConns,
|
||||
MaxIdleConnsPerHost: common.RelayMaxIdleConnsPerHost,
|
||||
ForceAttemptHTTP2: true,
|
||||
Proxy: http.ProxyURL(parsedURL),
|
||||
},
|
||||
CheckRedirect: checkRedirect,
|
||||
}
|
||||
@@ -116,6 +135,9 @@ func NewProxyHttpClient(proxyURL string) (*http.Client, error) {
|
||||
|
||||
client := &http.Client{
|
||||
Transport: &http.Transport{
|
||||
MaxIdleConns: common.RelayMaxIdleConns,
|
||||
MaxIdleConnsPerHost: common.RelayMaxIdleConnsPerHost,
|
||||
ForceAttemptHTTP2: true,
|
||||
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
return dialer.Dial(network, addr)
|
||||
},
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"image"
|
||||
@@ -12,7 +11,6 @@ import (
|
||||
"math"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
"unicode/utf8"
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
@@ -23,64 +21,8 @@ import (
|
||||
"github.com/QuantumNous/new-api/types"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/tiktoken-go/tokenizer"
|
||||
"github.com/tiktoken-go/tokenizer/codec"
|
||||
)
|
||||
|
||||
// tokenEncoderMap won't grow after initialization
|
||||
var defaultTokenEncoder tokenizer.Codec
|
||||
|
||||
// tokenEncoderMap is used to store token encoders for different models
|
||||
var tokenEncoderMap = make(map[string]tokenizer.Codec)
|
||||
|
||||
// tokenEncoderMutex protects tokenEncoderMap for concurrent access
|
||||
var tokenEncoderMutex sync.RWMutex
|
||||
|
||||
func InitTokenEncoders() {
|
||||
common.SysLog("initializing token encoders")
|
||||
defaultTokenEncoder = codec.NewCl100kBase()
|
||||
common.SysLog("token encoders initialized")
|
||||
}
|
||||
|
||||
func getTokenEncoder(model string) tokenizer.Codec {
|
||||
// First, try to get the encoder from cache with read lock
|
||||
tokenEncoderMutex.RLock()
|
||||
if encoder, exists := tokenEncoderMap[model]; exists {
|
||||
tokenEncoderMutex.RUnlock()
|
||||
return encoder
|
||||
}
|
||||
tokenEncoderMutex.RUnlock()
|
||||
|
||||
// If not in cache, create new encoder with write lock
|
||||
tokenEncoderMutex.Lock()
|
||||
defer tokenEncoderMutex.Unlock()
|
||||
|
||||
// Double-check if another goroutine already created the encoder
|
||||
if encoder, exists := tokenEncoderMap[model]; exists {
|
||||
return encoder
|
||||
}
|
||||
|
||||
// Create new encoder
|
||||
modelCodec, err := tokenizer.ForModel(tokenizer.Model(model))
|
||||
if err != nil {
|
||||
// Cache the default encoder for this model to avoid repeated failures
|
||||
tokenEncoderMap[model] = defaultTokenEncoder
|
||||
return defaultTokenEncoder
|
||||
}
|
||||
|
||||
// Cache the new encoder
|
||||
tokenEncoderMap[model] = modelCodec
|
||||
return modelCodec
|
||||
}
|
||||
|
||||
func getTokenNum(tokenEncoder tokenizer.Codec, text string) int {
|
||||
if text == "" {
|
||||
return 0
|
||||
}
|
||||
tkm, _ := tokenEncoder.Count(text)
|
||||
return tkm
|
||||
}
|
||||
|
||||
func getImageToken(fileMeta *types.FileMeta, model string, stream bool) (int, error) {
|
||||
if fileMeta == nil {
|
||||
return 0, fmt.Errorf("image_url_is_nil")
|
||||
@@ -257,7 +199,7 @@ func getImageToken(fileMeta *types.FileMeta, model string, stream bool) (int, er
|
||||
return tiles*tileTokens + baseTokens, nil
|
||||
}
|
||||
|
||||
func CountRequestToken(c *gin.Context, meta *types.TokenCountMeta, info *relaycommon.RelayInfo) (int, error) {
|
||||
func EstimateRequestToken(c *gin.Context, meta *types.TokenCountMeta, info *relaycommon.RelayInfo) (int, error) {
|
||||
// 是否统计token
|
||||
if !constant.CountToken {
|
||||
return 0, nil
|
||||
@@ -375,14 +317,14 @@ func CountRequestToken(c *gin.Context, meta *types.TokenCountMeta, info *relayco
|
||||
for i, file := range meta.Files {
|
||||
switch file.FileType {
|
||||
case types.FileTypeImage:
|
||||
if info.RelayFormat == types.RelayFormatGemini {
|
||||
tkm += 520 // gemini per input image tokens
|
||||
} else {
|
||||
if common.IsOpenAITextModel(info.OriginModelName) {
|
||||
token, err := getImageToken(file, model, info.IsStream)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("error counting image token, media index[%d], original data[%s], err: %v", i, file.OriginData, err)
|
||||
}
|
||||
tkm += token
|
||||
} else {
|
||||
tkm += 520
|
||||
}
|
||||
case types.FileTypeAudio:
|
||||
tkm += 256
|
||||
@@ -399,111 +341,6 @@ func CountRequestToken(c *gin.Context, meta *types.TokenCountMeta, info *relayco
|
||||
return tkm, nil
|
||||
}
|
||||
|
||||
func CountTokenClaudeRequest(request dto.ClaudeRequest, model string) (int, error) {
|
||||
tkm := 0
|
||||
|
||||
// Count tokens in messages
|
||||
msgTokens, err := CountTokenClaudeMessages(request.Messages, model, request.Stream)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
tkm += msgTokens
|
||||
|
||||
// Count tokens in system message
|
||||
if request.System != "" {
|
||||
systemTokens := CountTokenInput(request.System, model)
|
||||
tkm += systemTokens
|
||||
}
|
||||
|
||||
if request.Tools != nil {
|
||||
// check is array
|
||||
if tools, ok := request.Tools.([]any); ok {
|
||||
if len(tools) > 0 {
|
||||
parsedTools, err1 := common.Any2Type[[]dto.Tool](request.Tools)
|
||||
if err1 != nil {
|
||||
return 0, fmt.Errorf("tools: Input should be a valid list: %v", err)
|
||||
}
|
||||
toolTokens, err2 := CountTokenClaudeTools(parsedTools, model)
|
||||
if err2 != nil {
|
||||
return 0, fmt.Errorf("tools: %v", err)
|
||||
}
|
||||
tkm += toolTokens
|
||||
}
|
||||
} else {
|
||||
return 0, errors.New("tools: Input should be a valid list")
|
||||
}
|
||||
}
|
||||
|
||||
return tkm, nil
|
||||
}
|
||||
|
||||
func CountTokenClaudeMessages(messages []dto.ClaudeMessage, model string, stream bool) (int, error) {
|
||||
tokenEncoder := getTokenEncoder(model)
|
||||
tokenNum := 0
|
||||
|
||||
for _, message := range messages {
|
||||
// Count tokens for role
|
||||
tokenNum += getTokenNum(tokenEncoder, message.Role)
|
||||
if message.IsStringContent() {
|
||||
tokenNum += getTokenNum(tokenEncoder, message.GetStringContent())
|
||||
} else {
|
||||
content, err := message.ParseContent()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
for _, mediaMessage := range content {
|
||||
switch mediaMessage.Type {
|
||||
case "text":
|
||||
tokenNum += getTokenNum(tokenEncoder, mediaMessage.GetText())
|
||||
case "image":
|
||||
//imageTokenNum, err := getClaudeImageToken(mediaMsg.Source, model, stream)
|
||||
//if err != nil {
|
||||
// return 0, err
|
||||
//}
|
||||
tokenNum += 1000
|
||||
case "tool_use":
|
||||
if mediaMessage.Input != nil {
|
||||
tokenNum += getTokenNum(tokenEncoder, mediaMessage.Name)
|
||||
inputJSON, _ := json.Marshal(mediaMessage.Input)
|
||||
tokenNum += getTokenNum(tokenEncoder, string(inputJSON))
|
||||
}
|
||||
case "tool_result":
|
||||
if mediaMessage.Content != nil {
|
||||
contentJSON, _ := json.Marshal(mediaMessage.Content)
|
||||
tokenNum += getTokenNum(tokenEncoder, string(contentJSON))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Add a constant for message formatting (this may need adjustment based on Claude's exact formatting)
|
||||
tokenNum += len(messages) * 2 // Assuming 2 tokens per message for formatting
|
||||
|
||||
return tokenNum, nil
|
||||
}
|
||||
|
||||
func CountTokenClaudeTools(tools []dto.Tool, model string) (int, error) {
|
||||
tokenEncoder := getTokenEncoder(model)
|
||||
tokenNum := 0
|
||||
|
||||
for _, tool := range tools {
|
||||
tokenNum += getTokenNum(tokenEncoder, tool.Name)
|
||||
tokenNum += getTokenNum(tokenEncoder, tool.Description)
|
||||
|
||||
schemaJSON, err := json.Marshal(tool.InputSchema)
|
||||
if err != nil {
|
||||
return 0, errors.New(fmt.Sprintf("marshal_tool_schema_fail: %s", err.Error()))
|
||||
}
|
||||
tokenNum += getTokenNum(tokenEncoder, string(schemaJSON))
|
||||
}
|
||||
|
||||
// Add a constant for tool formatting (this may need adjustment based on Claude's exact formatting)
|
||||
tokenNum += len(tools) * 3 // Assuming 3 tokens per tool for formatting
|
||||
|
||||
return tokenNum, nil
|
||||
}
|
||||
|
||||
func CountTokenRealtime(info *relaycommon.RelayInfo, request dto.RealtimeEvent, model string) (int, int, error) {
|
||||
audioToken := 0
|
||||
textToken := 0
|
||||
@@ -578,31 +415,6 @@ func CountTokenInput(input any, model string) int {
|
||||
return CountTokenInput(fmt.Sprintf("%v", input), model)
|
||||
}
|
||||
|
||||
func CountTokenStreamChoices(messages []dto.ChatCompletionsStreamResponseChoice, model string) int {
|
||||
tokens := 0
|
||||
for _, message := range messages {
|
||||
tkm := CountTokenInput(message.Delta.GetContentString(), model)
|
||||
tokens += tkm
|
||||
if message.Delta.ToolCalls != nil {
|
||||
for _, tool := range message.Delta.ToolCalls {
|
||||
tkm := CountTokenInput(tool.Function.Name, model)
|
||||
tokens += tkm
|
||||
tkm = CountTokenInput(tool.Function.Arguments, model)
|
||||
tokens += tkm
|
||||
}
|
||||
}
|
||||
}
|
||||
return tokens
|
||||
}
|
||||
|
||||
func CountTTSToken(text string, model string) int {
|
||||
if strings.HasPrefix(model, "tts") {
|
||||
return utf8.RuneCountInString(text)
|
||||
} else {
|
||||
return CountTextToken(text, model)
|
||||
}
|
||||
}
|
||||
|
||||
func CountAudioTokenInput(audioBase64 string, audioFormat string) (int, error) {
|
||||
if audioBase64 == "" {
|
||||
return 0, nil
|
||||
@@ -625,17 +437,16 @@ func CountAudioTokenOutput(audioBase64 string, audioFormat string) (int, error)
|
||||
return int(duration / 60 * 200 / 0.24), nil
|
||||
}
|
||||
|
||||
//func CountAudioToken(sec float64, audioType string) {
|
||||
// if audioType == "input" {
|
||||
//
|
||||
// }
|
||||
//}
|
||||
|
||||
// CountTextToken 统计文本的token数量,仅当文本包含敏感词,返回错误,同时返回token数量
|
||||
// CountTextToken 统计文本的token数量,仅OpenAI模型使用tokenizer,其余模型使用估算
|
||||
func CountTextToken(text string, model string) int {
|
||||
if text == "" {
|
||||
return 0
|
||||
}
|
||||
tokenEncoder := getTokenEncoder(model)
|
||||
return getTokenNum(tokenEncoder, text)
|
||||
if common.IsOpenAITextModel(model) {
|
||||
tokenEncoder := getTokenEncoder(model)
|
||||
return getTokenNum(tokenEncoder, text)
|
||||
} else {
|
||||
// 非openai模型,使用tiktoken-go计算没有意义,使用估算节省资源
|
||||
return EstimateTokenByModel(model, text)
|
||||
}
|
||||
}
|
||||
|
||||
230
service/token_estimator.go
Normal file
230
service/token_estimator.go
Normal file
@@ -0,0 +1,230 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"math"
|
||||
"strings"
|
||||
"sync"
|
||||
"unicode"
|
||||
)
|
||||
|
||||
// Provider 定义模型厂商大类
|
||||
type Provider string
|
||||
|
||||
const (
|
||||
OpenAI Provider = "openai" // 代表 GPT-3.5, GPT-4, GPT-4o
|
||||
Gemini Provider = "gemini" // 代表 Gemini 1.0, 1.5 Pro/Flash
|
||||
Claude Provider = "claude" // 代表 Claude 3, 3.5 Sonnet
|
||||
Unknown Provider = "unknown" // 兜底默认
|
||||
)
|
||||
|
||||
// multipliers 定义不同厂商的计费权重
|
||||
type multipliers struct {
|
||||
Word float64 // 英文单词 (每词)
|
||||
Number float64 // 数字 (每连续数字串)
|
||||
CJK float64 // 中日韩字符 (每字)
|
||||
Symbol float64 // 普通标点符号 (每个)
|
||||
MathSymbol float64 // 数学符号 (∑,∫,∂,√等,每个)
|
||||
URLDelim float64 // URL分隔符 (/,:,?,&,=,#,%) - tokenizer优化好
|
||||
AtSign float64 // @符号 - 导致单词切分,消耗较高
|
||||
Emoji float64 // Emoji表情 (每个)
|
||||
Newline float64 // 换行符/制表符 (每个)
|
||||
Space float64 // 空格 (每个)
|
||||
BasePad int // 基础起步消耗 (Start/End tokens)
|
||||
}
|
||||
|
||||
var (
|
||||
multipliersMap = map[Provider]multipliers{
|
||||
Gemini: {
|
||||
Word: 1.15, Number: 2.8, CJK: 0.68, Symbol: 0.38, MathSymbol: 1.05, URLDelim: 1.2, AtSign: 2.5, Emoji: 1.08, Newline: 1.15, Space: 0.2, BasePad: 0,
|
||||
},
|
||||
Claude: {
|
||||
Word: 1.13, Number: 1.63, CJK: 1.21, Symbol: 0.4, MathSymbol: 4.52, URLDelim: 1.26, AtSign: 2.82, Emoji: 2.6, Newline: 0.89, Space: 0.39, BasePad: 0,
|
||||
},
|
||||
OpenAI: {
|
||||
Word: 1.02, Number: 1.55, CJK: 0.85, Symbol: 0.4, MathSymbol: 2.68, URLDelim: 1.0, AtSign: 2.0, Emoji: 2.12, Newline: 0.5, Space: 0.42, BasePad: 0,
|
||||
},
|
||||
}
|
||||
multipliersLock sync.RWMutex
|
||||
)
|
||||
|
||||
// getMultipliers 根据厂商获取权重配置
|
||||
func getMultipliers(p Provider) multipliers {
|
||||
multipliersLock.RLock()
|
||||
defer multipliersLock.RUnlock()
|
||||
|
||||
switch p {
|
||||
case Gemini:
|
||||
return multipliersMap[Gemini]
|
||||
case Claude:
|
||||
return multipliersMap[Claude]
|
||||
case OpenAI:
|
||||
return multipliersMap[OpenAI]
|
||||
default:
|
||||
// 默认兜底 (按 OpenAI 的算)
|
||||
return multipliersMap[OpenAI]
|
||||
}
|
||||
}
|
||||
|
||||
// EstimateToken 计算 Token 数量
|
||||
func EstimateToken(provider Provider, text string) int {
|
||||
m := getMultipliers(provider)
|
||||
var count float64
|
||||
|
||||
// 状态机变量
|
||||
type WordType int
|
||||
const (
|
||||
None WordType = iota
|
||||
Latin
|
||||
Number
|
||||
)
|
||||
currentWordType := None
|
||||
|
||||
for _, r := range text {
|
||||
// 1. 处理空格和换行符
|
||||
if unicode.IsSpace(r) {
|
||||
currentWordType = None
|
||||
// 换行符和制表符使用Newline权重
|
||||
if r == '\n' || r == '\t' {
|
||||
count += m.Newline
|
||||
} else {
|
||||
// 普通空格使用Space权重
|
||||
count += m.Space
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
// 2. 处理 CJK (中日韩) - 按字符计费
|
||||
if isCJK(r) {
|
||||
currentWordType = None
|
||||
count += m.CJK
|
||||
continue
|
||||
}
|
||||
|
||||
// 3. 处理Emoji - 使用专门的Emoji权重
|
||||
if isEmoji(r) {
|
||||
currentWordType = None
|
||||
count += m.Emoji
|
||||
continue
|
||||
}
|
||||
|
||||
// 4. 处理拉丁字母/数字 (英文单词)
|
||||
if isLatinOrNumber(r) {
|
||||
isNum := unicode.IsNumber(r)
|
||||
newType := Latin
|
||||
if isNum {
|
||||
newType = Number
|
||||
}
|
||||
|
||||
// 如果之前不在单词中,或者类型发生变化(字母<->数字),则视为新token
|
||||
// 注意:对于OpenAI,通常"version 3.5"会切分,"abc123xyz"有时也会切分
|
||||
// 这里简单起见,字母和数字切换时增加权重
|
||||
if currentWordType == None || currentWordType != newType {
|
||||
if newType == Number {
|
||||
count += m.Number
|
||||
} else {
|
||||
count += m.Word
|
||||
}
|
||||
currentWordType = newType
|
||||
}
|
||||
// 单词中间的字符不额外计费
|
||||
continue
|
||||
}
|
||||
|
||||
// 5. 处理标点符号/特殊字符 - 按类型使用不同权重
|
||||
currentWordType = None
|
||||
if isMathSymbol(r) {
|
||||
count += m.MathSymbol
|
||||
} else if r == '@' {
|
||||
count += m.AtSign
|
||||
} else if isURLDelim(r) {
|
||||
count += m.URLDelim
|
||||
} else {
|
||||
count += m.Symbol
|
||||
}
|
||||
}
|
||||
|
||||
// 向上取整并加上基础 padding
|
||||
return int(math.Ceil(count)) + m.BasePad
|
||||
}
|
||||
|
||||
// 辅助:判断是否为 CJK 字符
|
||||
func isCJK(r rune) bool {
|
||||
return unicode.Is(unicode.Han, r) ||
|
||||
(r >= 0x3040 && r <= 0x30FF) || // 日文
|
||||
(r >= 0xAC00 && r <= 0xD7A3) // 韩文
|
||||
}
|
||||
|
||||
// 辅助:判断是否为单词主体 (字母或数字)
|
||||
func isLatinOrNumber(r rune) bool {
|
||||
return unicode.IsLetter(r) || unicode.IsNumber(r)
|
||||
}
|
||||
|
||||
// 辅助:判断是否为Emoji字符
|
||||
func isEmoji(r rune) bool {
|
||||
// Emoji的Unicode范围
|
||||
// 基本范围:0x1F300-0x1F9FF (Emoticons, Symbols, Pictographs)
|
||||
// 补充范围:0x2600-0x26FF (Misc Symbols), 0x2700-0x27BF (Dingbats)
|
||||
// 表情符号:0x1F600-0x1F64F (Emoticons)
|
||||
// 其他:0x1F900-0x1F9FF (Supplemental Symbols and Pictographs)
|
||||
return (r >= 0x1F300 && r <= 0x1F9FF) ||
|
||||
(r >= 0x2600 && r <= 0x26FF) ||
|
||||
(r >= 0x2700 && r <= 0x27BF) ||
|
||||
(r >= 0x1F600 && r <= 0x1F64F) ||
|
||||
(r >= 0x1F900 && r <= 0x1F9FF) ||
|
||||
(r >= 0x1FA00 && r <= 0x1FAFF) // Symbols and Pictographs Extended-A
|
||||
}
|
||||
|
||||
// 辅助:判断是否为数学符号
|
||||
func isMathSymbol(r rune) bool {
|
||||
// 数学运算符和符号
|
||||
// 基本数学符号:∑ ∫ ∂ √ ∞ ≤ ≥ ≠ ≈ ± × ÷
|
||||
// 上下标数字:² ³ ¹ ⁴ ⁵ ⁶ ⁷ ⁸ ⁹ ⁰
|
||||
// 希腊字母等也常用于数学
|
||||
mathSymbols := "∑∫∂√∞≤≥≠≈±×÷∈∉∋∌⊂⊃⊆⊇∪∩∧∨¬∀∃∄∅∆∇∝∟∠∡∢°′″‴⁺⁻⁼⁽⁾ⁿ₀₁₂₃₄₅₆₇₈₉₊₋₌₍₎²³¹⁴⁵⁶⁷⁸⁹⁰"
|
||||
for _, m := range mathSymbols {
|
||||
if r == m {
|
||||
return true
|
||||
}
|
||||
}
|
||||
// Mathematical Operators (U+2200–U+22FF)
|
||||
if r >= 0x2200 && r <= 0x22FF {
|
||||
return true
|
||||
}
|
||||
// Supplemental Mathematical Operators (U+2A00–U+2AFF)
|
||||
if r >= 0x2A00 && r <= 0x2AFF {
|
||||
return true
|
||||
}
|
||||
// Mathematical Alphanumeric Symbols (U+1D400–U+1D7FF)
|
||||
if r >= 0x1D400 && r <= 0x1D7FF {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// 辅助:判断是否为URL分隔符(tokenizer对这些优化较好)
|
||||
func isURLDelim(r rune) bool {
|
||||
// URL中常见的分隔符,tokenizer通常优化处理
|
||||
urlDelims := "/:?&=;#%"
|
||||
for _, d := range urlDelims {
|
||||
if r == d {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func EstimateTokenByModel(model, text string) int {
|
||||
// strings.Contains(model, "gpt-4o")
|
||||
if text == "" {
|
||||
return 0
|
||||
}
|
||||
|
||||
model = strings.ToLower(model)
|
||||
if strings.Contains(model, "gemini") {
|
||||
return EstimateToken(Gemini, text)
|
||||
} else if strings.Contains(model, "claude") {
|
||||
return EstimateToken(Claude, text)
|
||||
} else {
|
||||
return EstimateToken(OpenAI, text)
|
||||
}
|
||||
}
|
||||
63
service/tokenizer.go
Normal file
63
service/tokenizer.go
Normal file
@@ -0,0 +1,63 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"sync"
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
"github.com/tiktoken-go/tokenizer"
|
||||
"github.com/tiktoken-go/tokenizer/codec"
|
||||
)
|
||||
|
||||
// tokenEncoderMap won't grow after initialization
|
||||
var defaultTokenEncoder tokenizer.Codec
|
||||
|
||||
// tokenEncoderMap is used to store token encoders for different models
|
||||
var tokenEncoderMap = make(map[string]tokenizer.Codec)
|
||||
|
||||
// tokenEncoderMutex protects tokenEncoderMap for concurrent access
|
||||
var tokenEncoderMutex sync.RWMutex
|
||||
|
||||
func InitTokenEncoders() {
|
||||
common.SysLog("initializing token encoders")
|
||||
defaultTokenEncoder = codec.NewCl100kBase()
|
||||
common.SysLog("token encoders initialized")
|
||||
}
|
||||
|
||||
func getTokenEncoder(model string) tokenizer.Codec {
|
||||
// First, try to get the encoder from cache with read lock
|
||||
tokenEncoderMutex.RLock()
|
||||
if encoder, exists := tokenEncoderMap[model]; exists {
|
||||
tokenEncoderMutex.RUnlock()
|
||||
return encoder
|
||||
}
|
||||
tokenEncoderMutex.RUnlock()
|
||||
|
||||
// If not in cache, create new encoder with write lock
|
||||
tokenEncoderMutex.Lock()
|
||||
defer tokenEncoderMutex.Unlock()
|
||||
|
||||
// Double-check if another goroutine already created the encoder
|
||||
if encoder, exists := tokenEncoderMap[model]; exists {
|
||||
return encoder
|
||||
}
|
||||
|
||||
// Create new encoder
|
||||
modelCodec, err := tokenizer.ForModel(tokenizer.Model(model))
|
||||
if err != nil {
|
||||
// Cache the default encoder for this model to avoid repeated failures
|
||||
tokenEncoderMap[model] = defaultTokenEncoder
|
||||
return defaultTokenEncoder
|
||||
}
|
||||
|
||||
// Cache the new encoder
|
||||
tokenEncoderMap[model] = modelCodec
|
||||
return modelCodec
|
||||
}
|
||||
|
||||
func getTokenNum(tokenEncoder tokenizer.Codec, text string) int {
|
||||
if text == "" {
|
||||
return 0
|
||||
}
|
||||
tkm, _ := tokenEncoder.Count(text)
|
||||
return tkm
|
||||
}
|
||||
@@ -23,8 +23,7 @@ func ResponseText2Usage(c *gin.Context, responseText string, modeName string, pr
|
||||
common.SetContextKey(c, constant.ContextKeyLocalCountTokens, true)
|
||||
usage := &dto.Usage{}
|
||||
usage.PromptTokens = promptTokens
|
||||
ctkm := CountTextToken(responseText, modeName)
|
||||
usage.CompletionTokens = ctkm
|
||||
usage.CompletionTokens = EstimateTokenByModel(modeName, responseText)
|
||||
usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
|
||||
return usage
|
||||
}
|
||||
|
||||
@@ -26,10 +26,6 @@ var defaultGeminiSettings = GeminiSettings{
|
||||
SupportedImagineModels: []string{
|
||||
"gemini-2.0-flash-exp-image-generation",
|
||||
"gemini-2.0-flash-exp",
|
||||
"gemini-3-pro-image-preview",
|
||||
"gemini-2.5-flash-image",
|
||||
"nano-banana",
|
||||
"nano-banana-pro",
|
||||
},
|
||||
ThinkingAdapterEnabled: false,
|
||||
ThinkingAdapterBudgetTokensPercentage: 0.6,
|
||||
|
||||
@@ -32,7 +32,7 @@ func GetGlobalSettings() *GlobalSettings {
|
||||
return &globalSettings
|
||||
}
|
||||
|
||||
// ShouldPreserveThinkingSuffix 判断模型是否配置为保留 thinking/-nothinking 后缀
|
||||
// ShouldPreserveThinkingSuffix 判断模型是否配置为保留 thinking/-nothinking/-low/-high/-medium 后缀
|
||||
func ShouldPreserveThinkingSuffix(modelName string) bool {
|
||||
target := strings.TrimSpace(modelName)
|
||||
if target == "" {
|
||||
|
||||
@@ -43,6 +43,7 @@ var defaultCacheRatio = map[string]float64{
|
||||
"claude-3-opus-20240229": 0.1,
|
||||
"claude-3-haiku-20240307": 0.1,
|
||||
"claude-3-5-haiku-20241022": 0.1,
|
||||
"claude-haiku-4-5-20251001": 0.1,
|
||||
"claude-3-5-sonnet-20240620": 0.1,
|
||||
"claude-3-5-sonnet-20241022": 0.1,
|
||||
"claude-3-7-sonnet-20250219": 0.1,
|
||||
@@ -64,6 +65,7 @@ var defaultCreateCacheRatio = map[string]float64{
|
||||
"claude-3-opus-20240229": 1.25,
|
||||
"claude-3-haiku-20240307": 1.25,
|
||||
"claude-3-5-haiku-20241022": 1.25,
|
||||
"claude-haiku-4-5-20251001": 1.25,
|
||||
"claude-3-5-sonnet-20240620": 1.25,
|
||||
"claude-3-5-sonnet-20241022": 1.25,
|
||||
"claude-3-7-sonnet-20250219": 1.25,
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
"github.com/QuantumNous/new-api/setting/operation_setting"
|
||||
"github.com/QuantumNous/new-api/setting/reasoning"
|
||||
)
|
||||
|
||||
// from songquanpeng/one-api
|
||||
@@ -136,6 +137,7 @@ var defaultModelRatio = map[string]float64{
|
||||
"claude-2.1": 4, // $8 / 1M tokens
|
||||
"claude-3-haiku-20240307": 0.125, // $0.25 / 1M tokens
|
||||
"claude-3-5-haiku-20241022": 0.5, // $1 / 1M tokens
|
||||
"claude-haiku-4-5-20251001": 0.5, // $1 / 1M tokens
|
||||
"claude-3-sonnet-20240229": 1.5, // $3 / 1M tokens
|
||||
"claude-3-5-sonnet-20240620": 1.5,
|
||||
"claude-3-5-sonnet-20241022": 1.5,
|
||||
@@ -559,7 +561,7 @@ func getHardcodedCompletionModelRatio(name string) (float64, bool) {
|
||||
|
||||
if strings.Contains(name, "claude-3") {
|
||||
return 5, true
|
||||
} else if strings.Contains(name, "claude-sonnet-4") || strings.Contains(name, "claude-opus-4") {
|
||||
} else if strings.Contains(name, "claude-sonnet-4") || strings.Contains(name, "claude-opus-4") || strings.Contains(name, "claude-haiku-4") {
|
||||
return 5, true
|
||||
} else if strings.Contains(name, "claude-instant-1") || strings.Contains(name, "claude-2") {
|
||||
return 3, true
|
||||
@@ -821,6 +823,10 @@ func FormatMatchingModelName(name string) string {
|
||||
name = handleThinkingBudgetModel(name, "gemini-2.5-pro", "gemini-2.5-pro-thinking-*")
|
||||
}
|
||||
|
||||
if base, _, ok := reasoning.TrimEffortSuffix(name); ok {
|
||||
name = base
|
||||
}
|
||||
|
||||
if strings.HasPrefix(name, "gpt-4-gizmo") {
|
||||
name = "gpt-4-gizmo-*"
|
||||
}
|
||||
|
||||
20
setting/reasoning/suffix.go
Normal file
20
setting/reasoning/suffix.go
Normal file
@@ -0,0 +1,20 @@
|
||||
package reasoning
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"github.com/samber/lo"
|
||||
)
|
||||
|
||||
var EffortSuffixes = []string{"-high", "-medium", "-low"}
|
||||
|
||||
// TrimEffortSuffix -> modelName level(low) exists
|
||||
func TrimEffortSuffix(modelName string) (string, string, bool) {
|
||||
suffix, found := lo.Find(EffortSuffixes, func(s string) bool {
|
||||
return strings.HasSuffix(modelName, s)
|
||||
})
|
||||
if !found {
|
||||
return modelName, "", false
|
||||
}
|
||||
return strings.TrimSuffix(modelName, suffix), strings.TrimPrefix(suffix, "-"), true
|
||||
}
|
||||
@@ -294,7 +294,7 @@ const LoginForm = () => {
|
||||
setGithubButtonDisabled(true);
|
||||
}, 20000);
|
||||
try {
|
||||
onGitHubOAuthClicked(status.github_client_id);
|
||||
onGitHubOAuthClicked(status.github_client_id, { shouldLogout: true });
|
||||
} finally {
|
||||
// 由于重定向,这里不会执行到,但为了完整性添加
|
||||
setTimeout(() => setGithubLoading(false), 3000);
|
||||
@@ -309,7 +309,7 @@ const LoginForm = () => {
|
||||
}
|
||||
setDiscordLoading(true);
|
||||
try {
|
||||
onDiscordOAuthClicked(status.discord_client_id);
|
||||
onDiscordOAuthClicked(status.discord_client_id, { shouldLogout: true });
|
||||
} finally {
|
||||
// 由于重定向,这里不会执行到,但为了完整性添加
|
||||
setTimeout(() => setDiscordLoading(false), 3000);
|
||||
@@ -324,7 +324,12 @@ const LoginForm = () => {
|
||||
}
|
||||
setOidcLoading(true);
|
||||
try {
|
||||
onOIDCClicked(status.oidc_authorization_endpoint, status.oidc_client_id);
|
||||
onOIDCClicked(
|
||||
status.oidc_authorization_endpoint,
|
||||
status.oidc_client_id,
|
||||
false,
|
||||
{ shouldLogout: true },
|
||||
);
|
||||
} finally {
|
||||
// 由于重定向,这里不会执行到,但为了完整性添加
|
||||
setTimeout(() => setOidcLoading(false), 3000);
|
||||
@@ -339,7 +344,7 @@ const LoginForm = () => {
|
||||
}
|
||||
setLinuxdoLoading(true);
|
||||
try {
|
||||
onLinuxDOOAuthClicked(status.linuxdo_client_id);
|
||||
onLinuxDOOAuthClicked(status.linuxdo_client_id, { shouldLogout: true });
|
||||
} finally {
|
||||
// 由于重定向,这里不会执行到,但为了完整性添加
|
||||
setTimeout(() => setLinuxdoLoading(false), 3000);
|
||||
|
||||
@@ -261,7 +261,7 @@ const RegisterForm = () => {
|
||||
setGithubButtonDisabled(true);
|
||||
}, 20000);
|
||||
try {
|
||||
onGitHubOAuthClicked(status.github_client_id);
|
||||
onGitHubOAuthClicked(status.github_client_id, { shouldLogout: true });
|
||||
} finally {
|
||||
setTimeout(() => setGithubLoading(false), 3000);
|
||||
}
|
||||
@@ -270,7 +270,7 @@ const RegisterForm = () => {
|
||||
const handleDiscordClick = () => {
|
||||
setDiscordLoading(true);
|
||||
try {
|
||||
onDiscordOAuthClicked(status.discord_client_id);
|
||||
onDiscordOAuthClicked(status.discord_client_id, { shouldLogout: true });
|
||||
} finally {
|
||||
setTimeout(() => setDiscordLoading(false), 3000);
|
||||
}
|
||||
@@ -279,7 +279,12 @@ const RegisterForm = () => {
|
||||
const handleOIDCClick = () => {
|
||||
setOidcLoading(true);
|
||||
try {
|
||||
onOIDCClicked(status.oidc_authorization_endpoint, status.oidc_client_id);
|
||||
onOIDCClicked(
|
||||
status.oidc_authorization_endpoint,
|
||||
status.oidc_client_id,
|
||||
false,
|
||||
{ shouldLogout: true },
|
||||
);
|
||||
} finally {
|
||||
setTimeout(() => setOidcLoading(false), 3000);
|
||||
}
|
||||
@@ -288,7 +293,7 @@ const RegisterForm = () => {
|
||||
const handleLinuxDOClick = () => {
|
||||
setLinuxdoLoading(true);
|
||||
try {
|
||||
onLinuxDOOAuthClicked(status.linuxdo_client_id);
|
||||
onLinuxDOOAuthClicked(status.linuxdo_client_id, { shouldLogout: true });
|
||||
} finally {
|
||||
setTimeout(() => setLinuxdoLoading(false), 3000);
|
||||
}
|
||||
|
||||
@@ -377,7 +377,6 @@ const SiderBar = ({ onNavigate = () => {} }) => {
|
||||
className='sidebar-container'
|
||||
style={{
|
||||
width: 'var(--sidebar-current-width)',
|
||||
background: 'var(--semi-color-bg-0)',
|
||||
}}
|
||||
>
|
||||
<SkeletonWrapper
|
||||
|
||||
@@ -39,6 +39,7 @@ import {
|
||||
TASK_ACTION_GENERATE,
|
||||
TASK_ACTION_REFERENCE_GENERATE,
|
||||
TASK_ACTION_TEXT_GENERATE,
|
||||
TASK_ACTION_REMIX_GENERATE,
|
||||
} from '../../../constants/common.constant';
|
||||
import { CHANNEL_OPTIONS } from '../../../constants/channel.constants';
|
||||
|
||||
@@ -125,6 +126,12 @@ const renderType = (type, t) => {
|
||||
{t('参照生视频')}
|
||||
</Tag>
|
||||
);
|
||||
case TASK_ACTION_REMIX_GENERATE:
|
||||
return (
|
||||
<Tag color='blue' shape='circle' prefixIcon={<Sparkles size={14} />}>
|
||||
{t('视频Remix')}
|
||||
</Tag>
|
||||
);
|
||||
default:
|
||||
return (
|
||||
<Tag color='white' shape='circle' prefixIcon={<HelpCircle size={14} />}>
|
||||
@@ -359,7 +366,8 @@ export const getTaskLogsColumns = ({
|
||||
record.action === TASK_ACTION_GENERATE ||
|
||||
record.action === TASK_ACTION_TEXT_GENERATE ||
|
||||
record.action === TASK_ACTION_FIRST_TAIL_GENERATE ||
|
||||
record.action === TASK_ACTION_REFERENCE_GENERATE;
|
||||
record.action === TASK_ACTION_REFERENCE_GENERATE ||
|
||||
record.action === TASK_ACTION_REMIX_GENERATE;
|
||||
const isSuccess = record.status === 'SUCCESS';
|
||||
const isUrl = typeof text === 'string' && /^https?:\/\//.test(text);
|
||||
if (isSuccess && isVideoTask && isUrl) {
|
||||
|
||||
@@ -42,3 +42,4 @@ export const TASK_ACTION_GENERATE = 'generate';
|
||||
export const TASK_ACTION_TEXT_GENERATE = 'textGenerate';
|
||||
export const TASK_ACTION_FIRST_TAIL_GENERATE = 'firstTailGenerate';
|
||||
export const TASK_ACTION_REFERENCE_GENERATE = 'referenceGenerate';
|
||||
export const TASK_ACTION_REMIX_GENERATE = 'remixGenerate';
|
||||
|
||||
@@ -231,8 +231,22 @@ export async function getOAuthState() {
|
||||
}
|
||||
}
|
||||
|
||||
export async function onDiscordOAuthClicked(client_id) {
|
||||
const state = await getOAuthState();
|
||||
async function prepareOAuthState(options = {}) {
|
||||
const { shouldLogout = false } = options;
|
||||
if (shouldLogout) {
|
||||
try {
|
||||
await API.get('/api/user/logout', { skipErrorHandler: true });
|
||||
} catch (err) {
|
||||
|
||||
}
|
||||
localStorage.removeItem('user');
|
||||
updateAPI();
|
||||
}
|
||||
return await getOAuthState();
|
||||
}
|
||||
|
||||
export async function onDiscordOAuthClicked(client_id, options = {}) {
|
||||
const state = await prepareOAuthState(options);
|
||||
if (!state) return;
|
||||
const redirect_uri = `${window.location.origin}/oauth/discord`;
|
||||
const response_type = 'code';
|
||||
@@ -242,8 +256,13 @@ export async function onDiscordOAuthClicked(client_id) {
|
||||
);
|
||||
}
|
||||
|
||||
export async function onOIDCClicked(auth_url, client_id, openInNewTab = false) {
|
||||
const state = await getOAuthState();
|
||||
export async function onOIDCClicked(
|
||||
auth_url,
|
||||
client_id,
|
||||
openInNewTab = false,
|
||||
options = {},
|
||||
) {
|
||||
const state = await prepareOAuthState(options);
|
||||
if (!state) return;
|
||||
const url = new URL(auth_url);
|
||||
url.searchParams.set('client_id', client_id);
|
||||
@@ -258,16 +277,19 @@ export async function onOIDCClicked(auth_url, client_id, openInNewTab = false) {
|
||||
}
|
||||
}
|
||||
|
||||
export async function onGitHubOAuthClicked(github_client_id) {
|
||||
const state = await getOAuthState();
|
||||
export async function onGitHubOAuthClicked(github_client_id, options = {}) {
|
||||
const state = await prepareOAuthState(options);
|
||||
if (!state) return;
|
||||
window.open(
|
||||
`https://github.com/login/oauth/authorize?client_id=${github_client_id}&state=${state}&scope=user:email`,
|
||||
);
|
||||
}
|
||||
|
||||
export async function onLinuxDOOAuthClicked(linuxdo_client_id) {
|
||||
const state = await getOAuthState();
|
||||
export async function onLinuxDOOAuthClicked(
|
||||
linuxdo_client_id,
|
||||
options = { shouldLogout: false },
|
||||
) {
|
||||
const state = await prepareOAuthState(options);
|
||||
if (!state) return;
|
||||
window.open(
|
||||
`https://connect.linux.do/oauth2/authorize?response_type=code&client_id=${linuxdo_client_id}&state=${state}`,
|
||||
|
||||
@@ -1086,9 +1086,12 @@ function renderPriceSimpleCore({
|
||||
);
|
||||
const finalGroupRatio = effectiveGroupRatio;
|
||||
|
||||
const { symbol, rate } = getCurrencyConfig();
|
||||
if (modelPrice !== -1) {
|
||||
return i18next.t('价格:${{price}} * {{ratioType}}:{{ratio}}', {
|
||||
price: modelPrice,
|
||||
const displayPrice = (modelPrice * rate).toFixed(6);
|
||||
return i18next.t('价格:{{symbol}}{{price}} * {{ratioType}}:{{ratio}}', {
|
||||
symbol: symbol,
|
||||
price: displayPrice,
|
||||
ratioType: ratioLabel,
|
||||
ratio: finalGroupRatio,
|
||||
});
|
||||
|
||||
@@ -548,6 +548,7 @@
|
||||
"参数值": "Parameter value",
|
||||
"参数覆盖": "Parameters override",
|
||||
"参照生视频": "Reference video generation",
|
||||
"视频Remix": "Video remix",
|
||||
"友情链接": "Friendly links",
|
||||
"发布日期": "Publish Date",
|
||||
"发布时间": "Publish Time",
|
||||
@@ -1996,7 +1997,7 @@
|
||||
"适用于个人使用的场景,不需要设置模型价格": "Suitable for personal use, no need to set model price.",
|
||||
"适用于为多个用户提供服务的场景": "Suitable for scenarios where multiple users are provided.",
|
||||
"适用于展示系统功能的场景,提供基础功能演示": "Suitable for scenarios where the system functions are displayed, providing basic feature demonstrations.",
|
||||
"适配 -thinking、-thinking-预算数字 和 -nothinking 后缀": "Adapt to -thinking, -thinking-budget number, and -nothinking suffixes",
|
||||
"适配 -thinking、-thinking-预算数字 和 -nothinking 后缀": "Adapt to -thinking, -thinking-budget number, -nothinking, and -low/-medium/-high suffixes",
|
||||
"选择充值套餐": "Choose a top-up package",
|
||||
"选择充值额度": "Select recharge amount",
|
||||
"选择分组": "Select group",
|
||||
@@ -2178,4 +2179,4 @@
|
||||
"默认测试模型": "Default Test Model",
|
||||
"默认补全倍率": "Default completion ratio"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -551,6 +551,7 @@
|
||||
"参数值": "Valeur du paramètre",
|
||||
"参数覆盖": "Remplacement des paramètres",
|
||||
"参照生视频": "Générer une vidéo par référence",
|
||||
"视频Remix": "Remix vidéo",
|
||||
"友情链接": "Liens amicaux",
|
||||
"发布日期": "Date de publication",
|
||||
"发布时间": "Heure de publication",
|
||||
@@ -2006,7 +2007,7 @@
|
||||
"适用于个人使用的场景,不需要设置模型价格": "Adapté à un usage personnel, pas besoin de définir le prix du modèle.",
|
||||
"适用于为多个用户提供服务的场景": "Adapté aux scénarios où plusieurs utilisateurs sont fournis.",
|
||||
"适用于展示系统功能的场景,提供基础功能演示": "Adapté aux scénarios où les fonctions du système sont affichées, fournissant des démonstrations de fonctionnalités de base.",
|
||||
"适配 -thinking、-thinking-预算数字 和 -nothinking 后缀": "Adapter les suffixes -thinking, -thinking-budget et -nothinking",
|
||||
"适配 -thinking、-thinking-预算数字 和 -nothinking 后缀": "Adapter les suffixes -thinking, -thinking-budget, -nothinking et -low/-medium/-high",
|
||||
"选择充值额度": "Sélectionner le montant de la recharge",
|
||||
"选择分组": "Sélectionner un groupe",
|
||||
"选择同步来源": "Sélectionner la source de synchronisation",
|
||||
@@ -2227,4 +2228,4 @@
|
||||
"随机种子 (留空为随机)": "Graine aléatoire (laisser vide pour aléatoire)",
|
||||
"默认补全倍率": "Taux de complétion par défaut"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -510,6 +510,7 @@
|
||||
"参数值": "パラメータ値",
|
||||
"参数覆盖": "パラメータの上書き",
|
||||
"参照生视频": "参照動画生成",
|
||||
"视频Remix": "動画リミックス",
|
||||
"友情链接": "関連リンク",
|
||||
"发布日期": "公開日",
|
||||
"发布时间": "公開日時",
|
||||
@@ -1903,7 +1904,7 @@
|
||||
"适用于个人使用的场景,不需要设置模型价格": "個人利用のシナリオに適しており、モデル料金の設定は不要です",
|
||||
"适用于为多个用户提供服务的场景": "複数のユーザーにサービスを提供するシナリオに適しています",
|
||||
"适用于展示系统功能的场景,提供基础功能演示": "システムの機能を紹介するシナリオに適しており、基本的な機能のデモンストレーションを提供します",
|
||||
"适配 -thinking、-thinking-预算数字 和 -nothinking 后缀": "-thinking、-thinking-予算数値、-nothinkingサフィックスに対応",
|
||||
"适配 -thinking、-thinking-预算数字 和 -nothinking 后缀": "-thinking、-thinking-予算数値、-nothinking、および -low/-medium/-high サフィックスに対応",
|
||||
"选择充值额度": "チャージ額を選択",
|
||||
"选择分组": "グループを選択",
|
||||
"选择同步来源": "同期ソースを選択",
|
||||
@@ -2126,4 +2127,4 @@
|
||||
"可选,用于复现结果": "オプション、結果の再現用",
|
||||
"随机种子 (留空为随机)": "ランダムシード(空欄でランダム)"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -555,6 +555,7 @@
|
||||
"参数值": "Значение параметра",
|
||||
"参数覆盖": "Переопределение параметров",
|
||||
"参照生视频": "Ссылка на генерацию видео",
|
||||
"视频Remix": "Видео ремикс",
|
||||
"友情链接": "Дружественные ссылки",
|
||||
"发布日期": "Дата публикации",
|
||||
"发布时间": "Время публикации",
|
||||
@@ -2017,7 +2018,7 @@
|
||||
"适用于个人使用的场景,不需要设置模型价格": "Подходит для сценариев личного использования, не требует установки цен на модели",
|
||||
"适用于为多个用户提供服务的场景": "Подходит для сценариев предоставления услуг нескольким пользователям",
|
||||
"适用于展示系统功能的场景,提供基础功能演示": "Подходит для сценариев демонстрации системных функций, предоставляет демонстрацию базовых функций",
|
||||
"适配 -thinking、-thinking-预算数字 和 -nothinking 后缀": "Адаптация суффиксов -thinking, -thinking-бюджетные-цифры и -nothinking",
|
||||
"适配 -thinking、-thinking-预算数字 和 -nothinking 后缀": "Адаптация суффиксов -thinking, -thinking-бюджетные-цифры, -nothinking и -low/-medium/-high",
|
||||
"选择充值额度": "Выберите сумму пополнения",
|
||||
"选择分组": "Выберите группу",
|
||||
"选择同步来源": "Выберите источник синхронизации",
|
||||
@@ -2237,4 +2238,4 @@
|
||||
"可选,用于复现结果": "Необязательно, для воспроизводимых результатов",
|
||||
"随机种子 (留空为随机)": "Случайное зерно (оставьте пустым для случайного)"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -510,6 +510,7 @@
|
||||
"参数值": "Giá trị tham số",
|
||||
"参数覆盖": "Ghi đè tham số",
|
||||
"参照生视频": "Tạo video tham chiếu",
|
||||
"视频Remix": "Remix video",
|
||||
"友情链接": "Liên kết thân thiện",
|
||||
"发布日期": "Ngày xuất bản",
|
||||
"发布时间": "Thời gian xuất bản",
|
||||
@@ -2197,7 +2198,7 @@
|
||||
"适用于个人使用的场景,不需要设置模型价格": "Phù hợp cho mục đích sử dụng cá nhân, không cần đặt giá mô hình.",
|
||||
"适用于为多个用户提供服务的场景": "Phù hợp cho các kịch bản cung cấp dịch vụ cho nhiều người dùng.",
|
||||
"适用于展示系统功能的场景,提供基础功能演示": "Phù hợp cho các kịch bản hiển thị chức năng hệ thống, cung cấp bản demo chức năng cơ bản.",
|
||||
"适配 -thinking、-thinking-预算数字 和 -nothinking 后缀": "Thích ứng với các hậu tố -thinking, -thinking-budget number và -nothinking",
|
||||
"适配 -thinking、-thinking-预算数字 和 -nothinking 后缀": "Thích ứng với các hậu tố -thinking, -thinking-budget number, -nothinking và -low/-medium/-high",
|
||||
"选择充值额度": "Chọn hạn ngạch nạp tiền",
|
||||
"选择同步来源": "Chọn nguồn đồng bộ",
|
||||
"选择同步渠道": "Chọn kênh đồng bộ",
|
||||
@@ -2737,4 +2738,4 @@
|
||||
"可选,用于复现结果": "Tùy chọn, để tái tạo kết quả",
|
||||
"随机种子 (留空为随机)": "Hạt giống ngẫu nhiên (để trống cho ngẫu nhiên)"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -543,6 +543,7 @@
|
||||
"参数值": "参数值",
|
||||
"参数覆盖": "参数覆盖",
|
||||
"参照生视频": "参照生视频",
|
||||
"视频Remix": "视频 Remix",
|
||||
"友情链接": "友情链接",
|
||||
"发布日期": "发布日期",
|
||||
"发布时间": "发布时间",
|
||||
@@ -1984,7 +1985,7 @@
|
||||
"适用于个人使用的场景,不需要设置模型价格": "适用于个人使用的场景,不需要设置模型价格",
|
||||
"适用于为多个用户提供服务的场景": "适用于为多个用户提供服务的场景",
|
||||
"适用于展示系统功能的场景,提供基础功能演示": "适用于展示系统功能的场景,提供基础功能演示",
|
||||
"适配 -thinking、-thinking-预算数字 和 -nothinking 后缀": "适配 -thinking、-thinking-预算数字 和 -nothinking 后缀",
|
||||
"适配 -thinking、-thinking-预算数字 和 -nothinking 后缀": "适配 -thinking、-thinking-预算数字、-nothinking 以及 -low/-medium/-high 后缀",
|
||||
"选择充值额度": "选择充值额度",
|
||||
"选择分组": "选择分组",
|
||||
"选择同步来源": "选择同步来源",
|
||||
@@ -2204,4 +2205,4 @@
|
||||
"可选,用于复现结果": "可选,用于复现结果",
|
||||
"随机种子 (留空为随机)": "随机种子 (留空为随机)"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user