mirror of
https://github.com/QuantumNous/new-api.git
synced 2026-04-17 23:47:26 +00:00
Compare commits
44 Commits
v0.10.0-al
...
v0.10.1-al
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
147659fb6e | ||
|
|
e9fb2ccdd1 | ||
|
|
48a17efade | ||
|
|
7e1d1350c7 | ||
|
|
01b4039e96 | ||
|
|
e1bee48152 | ||
|
|
4e69c98b42 | ||
|
|
ca29fc5702 | ||
|
|
fca015c6c4 | ||
|
|
23292a5ae9 | ||
|
|
e346f0bf16 | ||
|
|
cae05c068c | ||
|
|
78c10209c0 | ||
|
|
4ffd54c50d | ||
|
|
08466358b2 | ||
|
|
5212fbd73d | ||
|
|
b0e120dcab | ||
|
|
9561c7b50f | ||
|
|
1cb2b6f882 | ||
|
|
5889571108 | ||
|
|
2e33948842 | ||
|
|
d1aaa07ad7 | ||
|
|
ea70c20f8e | ||
|
|
c7539d11a0 | ||
|
|
3ebc713327 | ||
|
|
72d2a94b0d | ||
|
|
12a5c7ce5e | ||
|
|
5eae6a3874 | ||
|
|
7b108a6900 | ||
|
|
3d282ac548 | ||
|
|
121746a79e | ||
|
|
c3c119a9b4 | ||
|
|
6d6e5b3337 | ||
|
|
d64205e35a | ||
|
|
0b9f6a58bc | ||
|
|
293a5de0f8 | ||
|
|
c07347f24f | ||
|
|
896e4ac671 | ||
|
|
7d1bad1b37 | ||
|
|
8e7be25429 | ||
|
|
2e37347851 | ||
|
|
45556c961f | ||
|
|
c6125eccb1 | ||
|
|
138810f19c |
@@ -14,7 +14,7 @@ ENV GO111MODULE=on CGO_ENABLED=0
|
||||
ARG TARGETOS
|
||||
ARG TARGETARCH
|
||||
ENV GOOS=${TARGETOS:-linux} GOARCH=${TARGETARCH:-amd64}
|
||||
|
||||
ENV GOEXPERIMENT=greenteagc
|
||||
|
||||
WORKDIR /build
|
||||
|
||||
@@ -25,10 +25,11 @@ COPY . .
|
||||
COPY --from=builder /build/dist ./web/dist
|
||||
RUN go build -ldflags "-s -w -X 'github.com/QuantumNous/new-api/common.Version=$(cat VERSION)'" -o new-api
|
||||
|
||||
FROM alpine
|
||||
FROM debian:bookworm-slim
|
||||
|
||||
RUN apk upgrade --no-cache \
|
||||
&& apk add --no-cache ca-certificates tzdata \
|
||||
RUN apt-get update \
|
||||
&& apt-get install -y --no-install-recommends ca-certificates tzdata libasan8 wget \
|
||||
&& rm -rf /var/lib/apt/lists/* \
|
||||
&& update-ca-certificates
|
||||
|
||||
COPY --from=builder2 /build/new-api /
|
||||
|
||||
@@ -131,6 +131,8 @@ func initConstantEnv() {
|
||||
constant.GenerateDefaultToken = GetEnvOrDefaultBool("GENERATE_DEFAULT_TOKEN", false)
|
||||
// 是否启用错误日志
|
||||
constant.ErrorLogEnabled = GetEnvOrDefaultBool("ERROR_LOG_ENABLED", false)
|
||||
// 任务轮询时查询的最大数量
|
||||
constant.TaskQueryLimit = GetEnvOrDefault("TASK_QUERY_LIMIT", 1000)
|
||||
|
||||
soraPatchStr := GetEnvOrDefaultString("TASK_PRICE_PATCH", "")
|
||||
if soraPatchStr != "" {
|
||||
|
||||
@@ -23,11 +23,11 @@ func Marshal(v any) ([]byte, error) {
|
||||
}
|
||||
|
||||
func GetJsonType(data json.RawMessage) string {
|
||||
data = bytes.TrimSpace(data)
|
||||
if len(data) == 0 {
|
||||
trimmed := bytes.TrimSpace(data)
|
||||
if len(trimmed) == 0 {
|
||||
return "unknown"
|
||||
}
|
||||
firstChar := bytes.TrimSpace(data)[0]
|
||||
firstChar := trimmed[0]
|
||||
switch firstChar {
|
||||
case '{':
|
||||
return "object"
|
||||
|
||||
@@ -3,12 +3,19 @@ package common
|
||||
import (
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"math/rand"
|
||||
"net/url"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
"unsafe"
|
||||
|
||||
"github.com/samber/lo"
|
||||
)
|
||||
|
||||
var (
|
||||
maskURLPattern = regexp.MustCompile(`(http|https)://[^\s/$.?#].[^\s]*`)
|
||||
maskDomainPattern = regexp.MustCompile(`\b(?:[a-zA-Z0-9](?:[a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?\.)+[a-zA-Z]{2,}\b`)
|
||||
maskIPPattern = regexp.MustCompile(`\b(?:\d{1,3}\.){3}\d{1,3}\b`)
|
||||
)
|
||||
|
||||
func GetStringIfEmpty(str string, defaultValue string) string {
|
||||
@@ -19,12 +26,10 @@ func GetStringIfEmpty(str string, defaultValue string) string {
|
||||
}
|
||||
|
||||
func GetRandomString(length int) string {
|
||||
//rand.Seed(time.Now().UnixNano())
|
||||
key := make([]byte, length)
|
||||
for i := 0; i < length; i++ {
|
||||
key[i] = keyChars[rand.Intn(len(keyChars))]
|
||||
if length <= 0 {
|
||||
return ""
|
||||
}
|
||||
return string(key)
|
||||
return lo.RandomString(length, lo.AlphanumericCharset)
|
||||
}
|
||||
|
||||
func MapToJsonStr(m map[string]interface{}) string {
|
||||
@@ -170,8 +175,7 @@ func maskHostForPlainDomain(domain string) string {
|
||||
// api.openai.com -> ***.***.com
|
||||
func MaskSensitiveInfo(str string) string {
|
||||
// Mask URLs
|
||||
urlPattern := regexp.MustCompile(`(http|https)://[^\s/$.?#].[^\s]*`)
|
||||
str = urlPattern.ReplaceAllStringFunc(str, func(urlStr string) string {
|
||||
str = maskURLPattern.ReplaceAllStringFunc(str, func(urlStr string) string {
|
||||
u, err := url.Parse(urlStr)
|
||||
if err != nil {
|
||||
return urlStr
|
||||
@@ -224,14 +228,12 @@ func MaskSensitiveInfo(str string) string {
|
||||
})
|
||||
|
||||
// Mask domain names without protocol (like openai.com, www.openai.com)
|
||||
domainPattern := regexp.MustCompile(`\b(?:[a-zA-Z0-9](?:[a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?\.)+[a-zA-Z]{2,}\b`)
|
||||
str = domainPattern.ReplaceAllStringFunc(str, func(domain string) string {
|
||||
str = maskDomainPattern.ReplaceAllStringFunc(str, func(domain string) string {
|
||||
return maskHostForPlainDomain(domain)
|
||||
})
|
||||
|
||||
// Mask IP addresses
|
||||
ipPattern := regexp.MustCompile(`\b(?:\d{1,3}\.){3}\d{1,3}\b`)
|
||||
str = ipPattern.ReplaceAllString(str, "***.***.***.***")
|
||||
str = maskIPPattern.ReplaceAllString(str, "***.***.***.***")
|
||||
|
||||
return str
|
||||
}
|
||||
|
||||
@@ -18,8 +18,10 @@ const (
|
||||
ContextKeyTokenSpecificChannelId ContextKey = "specific_channel_id"
|
||||
ContextKeyTokenModelLimitEnabled ContextKey = "token_model_limit_enabled"
|
||||
ContextKeyTokenModelLimit ContextKey = "token_model_limit"
|
||||
ContextKeyTokenCrossGroupRetry ContextKey = "token_cross_group_retry"
|
||||
|
||||
/* channel related keys */
|
||||
ContextKeyAutoGroupIndex ContextKey = "auto_group_index"
|
||||
ContextKeyChannelId ContextKey = "channel_id"
|
||||
ContextKeyChannelName ContextKey = "channel_name"
|
||||
ContextKeyChannelCreateTime ContextKey = "channel_create_time"
|
||||
|
||||
@@ -15,6 +15,7 @@ var NotifyLimitCount int
|
||||
var NotificationLimitDurationMinute int
|
||||
var GenerateDefaultToken bool
|
||||
var ErrorLogEnabled bool
|
||||
var TaskQueryLimit int
|
||||
|
||||
// temporary variable for sora patch, will be removed in future
|
||||
var TaskPricePatches []string
|
||||
|
||||
@@ -15,6 +15,7 @@ const (
|
||||
TaskActionTextGenerate = "textGenerate"
|
||||
TaskActionFirstTailGenerate = "firstTailGenerate"
|
||||
TaskActionReferenceGenerate = "referenceGenerate"
|
||||
TaskActionRemix = "remixGenerate"
|
||||
)
|
||||
|
||||
var SunoModel2Action = map[string]string{
|
||||
|
||||
@@ -165,6 +165,30 @@ func GetAllChannels(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
func buildFetchModelsHeaders(channel *model.Channel, key string) (http.Header, error) {
|
||||
var headers http.Header
|
||||
switch channel.Type {
|
||||
case constant.ChannelTypeAnthropic:
|
||||
headers = GetClaudeAuthHeader(key)
|
||||
default:
|
||||
headers = GetAuthHeader(key)
|
||||
}
|
||||
|
||||
headerOverride := channel.GetHeaderOverride()
|
||||
for k, v := range headerOverride {
|
||||
str, ok := v.(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid header override for key %s", k)
|
||||
}
|
||||
if strings.Contains(str, "{api_key}") {
|
||||
str = strings.ReplaceAll(str, "{api_key}", key)
|
||||
}
|
||||
headers.Set(k, str)
|
||||
}
|
||||
|
||||
return headers, nil
|
||||
}
|
||||
|
||||
func FetchUpstreamModels(c *gin.Context) {
|
||||
id, err := strconv.Atoi(c.Param("id"))
|
||||
if err != nil {
|
||||
@@ -223,14 +247,13 @@ func FetchUpstreamModels(c *gin.Context) {
|
||||
}
|
||||
key = strings.TrimSpace(key)
|
||||
|
||||
// 获取响应体 - 根据渠道类型决定是否添加 AuthHeader
|
||||
var body []byte
|
||||
switch channel.Type {
|
||||
case constant.ChannelTypeAnthropic:
|
||||
body, err = GetResponseBody("GET", url, channel, GetClaudeAuthHeader(key))
|
||||
default:
|
||||
body, err = GetResponseBody("GET", url, channel, GetAuthHeader(key))
|
||||
headers, err := buildFetchModelsHeaders(channel, key)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
body, err := GetResponseBody("GET", url, channel, headers)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
|
||||
@@ -285,7 +285,7 @@ func processChannelError(c *gin.Context, channelError types.ChannelError, err *t
|
||||
logger.LogError(c, fmt.Sprintf("channel error (channel #%d, status code: %d): %s", channelError.ChannelId, err.StatusCode, err.Error()))
|
||||
// 不要使用context获取渠道信息,异步处理时可能会出现渠道信息不一致的情况
|
||||
// do not use context to get channel info, there may be inconsistent channel info when processing asynchronously
|
||||
if service.ShouldDisableChannel(channelError.ChannelId, err) && channelError.AutoBan {
|
||||
if service.ShouldDisableChannel(channelError.ChannelType, err) && channelError.AutoBan {
|
||||
gopool.Go(func() {
|
||||
service.DisableChannel(channelError, err.Error())
|
||||
})
|
||||
|
||||
@@ -29,7 +29,7 @@ func UpdateTaskBulk() {
|
||||
time.Sleep(time.Duration(15) * time.Second)
|
||||
common.SysLog("任务进度轮询开始")
|
||||
ctx := context.TODO()
|
||||
allTasks := model.GetAllUnFinishSyncTasks(500)
|
||||
allTasks := model.GetAllUnFinishSyncTasks(constant.TaskQueryLimit)
|
||||
platformTask := make(map[constant.TaskPlatform][]*model.Task)
|
||||
for _, t := range allTasks {
|
||||
platformTask[t.Platform] = append(platformTask[t.Platform], t)
|
||||
@@ -116,9 +116,10 @@ func updateSunoTaskAll(ctx context.Context, channelId int, taskIds []string, tas
|
||||
if adaptor == nil {
|
||||
return errors.New("adaptor not found")
|
||||
}
|
||||
proxy := channel.GetSetting().Proxy
|
||||
resp, err := adaptor.FetchTask(*channel.BaseURL, channel.Key, map[string]any{
|
||||
"ids": taskIds,
|
||||
})
|
||||
}, proxy)
|
||||
if err != nil {
|
||||
common.SysLog(fmt.Sprintf("Get Task Do req error: %v", err))
|
||||
return err
|
||||
|
||||
@@ -67,6 +67,7 @@ func updateVideoSingleTask(ctx context.Context, adaptor channel.TaskAdaptor, cha
|
||||
if channel.GetBaseURL() != "" {
|
||||
baseURL = channel.GetBaseURL()
|
||||
}
|
||||
proxy := channel.GetSetting().Proxy
|
||||
|
||||
task := taskM[taskId]
|
||||
if task == nil {
|
||||
@@ -76,7 +77,7 @@ func updateVideoSingleTask(ctx context.Context, adaptor channel.TaskAdaptor, cha
|
||||
resp, err := adaptor.FetchTask(baseURL, channel.Key, map[string]any{
|
||||
"task_id": taskId,
|
||||
"action": task.Action,
|
||||
})
|
||||
}, proxy)
|
||||
if err != nil {
|
||||
return fmt.Errorf("fetchTask failed for task %s: %w", taskId, err)
|
||||
}
|
||||
|
||||
@@ -142,7 +142,7 @@ func AddToken(c *gin.Context) {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
if len(token.Name) > 30 {
|
||||
if len(token.Name) > 50 {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "令牌名称过长",
|
||||
@@ -208,7 +208,7 @@ func UpdateToken(c *gin.Context) {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
if len(token.Name) > 30 {
|
||||
if len(token.Name) > 50 {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "令牌名称过长",
|
||||
@@ -248,6 +248,7 @@ func UpdateToken(c *gin.Context) {
|
||||
cleanToken.ModelLimits = token.ModelLimits
|
||||
cleanToken.AllowIps = token.AllowIps
|
||||
cleanToken.Group = token.Group
|
||||
cleanToken.CrossGroupRetry = token.CrossGroupRetry
|
||||
}
|
||||
err = cleanToken.Update()
|
||||
if err != nil {
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
@@ -10,6 +11,7 @@ import (
|
||||
"github.com/QuantumNous/new-api/constant"
|
||||
"github.com/QuantumNous/new-api/logger"
|
||||
"github.com/QuantumNous/new-api/model"
|
||||
"github.com/QuantumNous/new-api/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
@@ -75,11 +77,22 @@ func VideoProxy(c *gin.Context) {
|
||||
}
|
||||
|
||||
var videoURL string
|
||||
client := &http.Client{
|
||||
Timeout: 60 * time.Second,
|
||||
proxy := channel.GetSetting().Proxy
|
||||
client, err := service.GetHttpClientWithProxy(proxy)
|
||||
if err != nil {
|
||||
logger.LogError(c.Request.Context(), fmt.Sprintf("Failed to create proxy client for task %s: %s", taskID, err.Error()))
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"error": gin.H{
|
||||
"message": "Failed to create proxy client",
|
||||
"type": "server_error",
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(c.Request.Context(), http.MethodGet, "", nil)
|
||||
ctx, cancel := context.WithTimeout(c.Request.Context(), 60*time.Second)
|
||||
defer cancel()
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, "", nil)
|
||||
if err != nil {
|
||||
logger.LogError(c.Request.Context(), fmt.Sprintf("Failed to create request: %s", err.Error()))
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
|
||||
@@ -35,10 +35,11 @@ func getGeminiVideoURL(channel *model.Channel, task *model.Task, apiKey string)
|
||||
return "", fmt.Errorf("api key not available for task")
|
||||
}
|
||||
|
||||
proxy := channel.GetSetting().Proxy
|
||||
resp, err := adaptor.FetchTask(baseURL, apiKey, map[string]any{
|
||||
"task_id": task.TaskID,
|
||||
"action": task.Action,
|
||||
})
|
||||
}, proxy)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("fetch task failed: %w", err)
|
||||
}
|
||||
|
||||
@@ -27,8 +27,11 @@ type ImageRequest struct {
|
||||
OutputCompression json.RawMessage `json:"output_compression,omitempty"`
|
||||
PartialImages json.RawMessage `json:"partial_images,omitempty"`
|
||||
// Stream bool `json:"stream,omitempty"`
|
||||
Watermark *bool `json:"watermark,omitempty"`
|
||||
Image json.RawMessage `json:"image,omitempty"`
|
||||
Watermark *bool `json:"watermark,omitempty"`
|
||||
// zhipu 4v
|
||||
WatermarkEnabled json.RawMessage `json:"watermark_enabled,omitempty"`
|
||||
UserId json.RawMessage `json:"user_id,omitempty"`
|
||||
Image json.RawMessage `json:"image,omitempty"`
|
||||
// 用匿名参数接收额外参数
|
||||
Extra map[string]json.RawMessage `json:"-"`
|
||||
}
|
||||
|
||||
@@ -83,6 +83,7 @@ type GeneralOpenAIRequest struct {
|
||||
// Ali Qwen Params
|
||||
VlHighResolutionImages json.RawMessage `json:"vl_high_resolution_images,omitempty"`
|
||||
EnableThinking any `json:"enable_thinking,omitempty"`
|
||||
ChatTemplateKwargs json.RawMessage `json:"chat_template_kwargs,omitempty"`
|
||||
// ollama Params
|
||||
Think json.RawMessage `json:"think,omitempty"`
|
||||
// baidu v2
|
||||
|
||||
13
go.mod
13
go.mod
@@ -33,7 +33,7 @@ require (
|
||||
github.com/mewkiz/flac v1.0.13
|
||||
github.com/pkg/errors v0.9.1
|
||||
github.com/pquerna/otp v1.5.0
|
||||
github.com/samber/lo v1.39.0
|
||||
github.com/samber/lo v1.52.0
|
||||
github.com/shirou/gopsutil v3.21.11+incompatible
|
||||
github.com/shopspring/decimal v1.4.0
|
||||
github.com/stripe/stripe-go/v81 v81.4.0
|
||||
@@ -99,6 +99,7 @@ require (
|
||||
github.com/mitchellh/mapstructure v1.5.0 // indirect
|
||||
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
|
||||
github.com/modern-go/reflect2 v1.0.2 // indirect
|
||||
github.com/ncruces/go-strftime v0.1.9 // indirect
|
||||
github.com/pelletier/go-toml/v2 v2.2.1 // indirect
|
||||
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect
|
||||
github.com/tidwall/match v1.1.1 // indirect
|
||||
@@ -110,13 +111,13 @@ require (
|
||||
github.com/x448/float16 v0.8.4 // indirect
|
||||
github.com/yusufpapurcu/wmi v1.2.3 // indirect
|
||||
golang.org/x/arch v0.21.0 // indirect
|
||||
golang.org/x/exp v0.0.0-20240404231335-c0f41cb1a7a0 // indirect
|
||||
golang.org/x/exp v0.0.0-20250620022241-b7579e27df2b // indirect
|
||||
golang.org/x/sys v0.38.0 // indirect
|
||||
golang.org/x/text v0.31.0 // indirect
|
||||
google.golang.org/protobuf v1.34.2 // indirect
|
||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||
modernc.org/libc v1.22.5 // indirect
|
||||
modernc.org/mathutil v1.5.0 // indirect
|
||||
modernc.org/memory v1.5.0 // indirect
|
||||
modernc.org/sqlite v1.23.1 // indirect
|
||||
modernc.org/libc v1.66.10 // indirect
|
||||
modernc.org/mathutil v1.7.1 // indirect
|
||||
modernc.org/memory v1.11.0 // indirect
|
||||
modernc.org/sqlite v1.40.1 // indirect
|
||||
)
|
||||
|
||||
15
go.sum
15
go.sum
@@ -120,6 +120,7 @@ github.com/google/go-tpm v0.9.5/go.mod h1:h9jEsEECg7gtLis0upRBQU+GhYVH6jMjrFxI8u
|
||||
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
|
||||
github.com/google/pprof v0.0.0-20221118152302-e6195bd50e26 h1:Xim43kblpZXfIBQsbuBVKCudVG457BR2GZFIz3uw3hQ=
|
||||
github.com/google/pprof v0.0.0-20221118152302-e6195bd50e26/go.mod h1:dDKJzRmX4S37WGHujM7tX//fmj1uioxKzKxz3lo4HJo=
|
||||
github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e h1:ijClszYn+mADRFY17kjQEVQ1XRhq2/JR1M3sGqeJoxs=
|
||||
github.com/google/uuid v1.1.2/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
||||
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
@@ -193,6 +194,8 @@ github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJ
|
||||
github.com/modern-go/reflect2 v0.0.0-20180701023420-4b7aa43c6742/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0=
|
||||
github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M=
|
||||
github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk=
|
||||
github.com/ncruces/go-strftime v0.1.9 h1:bY0MQC28UADQmHmaF5dgpLmImcShSi2kHU9XLdhx/f4=
|
||||
github.com/ncruces/go-strftime v0.1.9/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls=
|
||||
github.com/nxadm/tail v1.4.8 h1:nPr65rt6Y5JFSKQO7qToXr7pePgD6Gwiw05lkbyAQTE=
|
||||
github.com/nxadm/tail v1.4.8/go.mod h1:+ncqLTQzXmGhMZNUePPaPqPvBxHAIsmXswZKocGu+AU=
|
||||
github.com/onsi/ginkgo v1.16.5 h1:8xi0RTUf59SOSfEtZMvwTvXYMzG4gV23XVHOZiXNtnE=
|
||||
@@ -219,6 +222,8 @@ github.com/rogpeppe/go-internal v1.8.0 h1:FCbCCtXNOY3UtUuHUYaghJg4y7Fd14rXifAYUA
|
||||
github.com/rogpeppe/go-internal v1.8.0/go.mod h1:WmiCO8CzOY8rg0OYDC4/i/2WRWAB6poM+XZ2dLUbcbE=
|
||||
github.com/samber/lo v1.39.0 h1:4gTz1wUhNYLhFSKl6O+8peW0v2F4BCY034GRpU9WnuA=
|
||||
github.com/samber/lo v1.39.0/go.mod h1:+m/ZKRl6ClXCE2Lgf3MsQlWfh4bn1bz6CXEOxnEXnEA=
|
||||
github.com/samber/lo v1.52.0 h1:Rvi+3BFHES3A8meP33VPAxiBZX/Aws5RxrschYGjomw=
|
||||
github.com/samber/lo v1.52.0/go.mod h1:4+MXEGsJzbKGaUEQFKBq2xtfuznW9oz/WrgyzMzRoM0=
|
||||
github.com/shirou/gopsutil v3.21.11+incompatible h1:+1+c1VGhc88SSonWP6foOcLhvnKlUeu/erjjvaPEYiI=
|
||||
github.com/shirou/gopsutil v3.21.11+incompatible/go.mod h1:5b4v6he4MtMOwMlS0TUMTu2PcXUg8+E1lC7eC3UO/RA=
|
||||
github.com/shopspring/decimal v1.4.0 h1:bxl37RwXBklmTi0C79JfXCEBD1cqqHt0bbgBAGFp81k=
|
||||
@@ -285,6 +290,8 @@ golang.org/x/crypto v0.45.0 h1:jMBrvKuj23MTlT0bQEOBcAE0mjg8mK9RXFhRH6nyF3Q=
|
||||
golang.org/x/crypto v0.45.0/go.mod h1:XTGrrkGJve7CYK7J8PEww4aY7gM3qMCElcJQ8n8JdX4=
|
||||
golang.org/x/exp v0.0.0-20240404231335-c0f41cb1a7a0 h1:985EYyeCOxTpcgOTJpflJUwOeEz0CQOdPt73OzpE9F8=
|
||||
golang.org/x/exp v0.0.0-20240404231335-c0f41cb1a7a0/go.mod h1:/lliqkxwWAhPjf5oSOIJup2XcqJaw8RGS6k3TGEc7GI=
|
||||
golang.org/x/exp v0.0.0-20250620022241-b7579e27df2b h1:M2rDM6z3Fhozi9O7NWsxAkg/yqS/lQJ6PmkyIV3YP+o=
|
||||
golang.org/x/exp v0.0.0-20250620022241-b7579e27df2b/go.mod h1:3//PLf8L/X+8b4vuAfHzxeRUl04Adcb341+IGKfnqS8=
|
||||
golang.org/x/image v0.23.0 h1:HseQ7c2OpPKTPVzNjG5fwJsOTCiiwS4QdsYi5XU6H68=
|
||||
golang.org/x/image v0.23.0/go.mod h1:wJJBTdLfCCf3tiHa1fNxpZmUI4mmoZvwMCPP0ddoNKY=
|
||||
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
|
||||
@@ -345,9 +352,17 @@ gorm.io/gorm v1.25.2 h1:gs1o6Vsa+oVKG/a9ElL3XgyGfghFfkKA2SInQaCyMho=
|
||||
gorm.io/gorm v1.25.2/go.mod h1:L4uxeKpfBml98NYqVqwAdmV1a2nBtAec/cf3fpucW/k=
|
||||
modernc.org/libc v1.22.5 h1:91BNch/e5B0uPbJFgqbxXuOnxBQjlS//icfQEGmvyjE=
|
||||
modernc.org/libc v1.22.5/go.mod h1:jj+Z7dTNX8fBScMVNRAYZ/jF91K8fdT2hYMThc3YjBY=
|
||||
modernc.org/libc v1.66.10 h1:yZkb3YeLx4oynyR+iUsXsybsX4Ubx7MQlSYEw4yj59A=
|
||||
modernc.org/libc v1.66.10/go.mod h1:8vGSEwvoUoltr4dlywvHqjtAqHBaw0j1jI7iFBTAr2I=
|
||||
modernc.org/mathutil v1.5.0 h1:rV0Ko/6SfM+8G+yKiyI830l3Wuz1zRutdslNoQ0kfiQ=
|
||||
modernc.org/mathutil v1.5.0/go.mod h1:mZW8CKdRPY1v87qxC/wUdX5O1qDzXMP5TH3wjfpga6E=
|
||||
modernc.org/mathutil v1.7.1 h1:GCZVGXdaN8gTqB1Mf/usp1Y/hSqgI2vAGGP4jZMCxOU=
|
||||
modernc.org/mathutil v1.7.1/go.mod h1:4p5IwJITfppl0G4sUEDtCr4DthTaT47/N3aT6MhfgJg=
|
||||
modernc.org/memory v1.5.0 h1:N+/8c5rE6EqugZwHii4IFsaJ7MUhoWX07J5tC/iI5Ds=
|
||||
modernc.org/memory v1.5.0/go.mod h1:PkUhL0Mugw21sHPeskwZW4D6VscE/GQJOnIpCnW6pSU=
|
||||
modernc.org/memory v1.11.0 h1:o4QC8aMQzmcwCK3t3Ux/ZHmwFPzE6hf2Y5LbkRs+hbI=
|
||||
modernc.org/memory v1.11.0/go.mod h1:/JP4VbVC+K5sU2wZi9bHoq2MAkCnrt2r98UGeSK7Mjw=
|
||||
modernc.org/sqlite v1.23.1 h1:nrSBg4aRQQwq59JpvGEQ15tNxoO5pX/kUjcRNwSAGQM=
|
||||
modernc.org/sqlite v1.23.1/go.mod h1:OrDj17Mggn6MhE+iPbBNf7RGKODDE9NFT0f3EwDzJqk=
|
||||
modernc.org/sqlite v1.40.1 h1:VfuXcxcUWWKRBuP8+BR9L7VnmusMgBNNnBYGEe9w/iY=
|
||||
modernc.org/sqlite v1.40.1/go.mod h1:9fjQZ0mB1LLP0GYrp39oOJXx/I2sxEnZtzCmEQIKvGE=
|
||||
|
||||
@@ -308,6 +308,7 @@ func SetupContextForToken(c *gin.Context, token *model.Token, parts ...string) e
|
||||
c.Set("token_model_limit_enabled", false)
|
||||
}
|
||||
c.Set("token_group", token.Group)
|
||||
c.Set("token_cross_group_retry", token.CrossGroupRetry)
|
||||
if len(parts) > 1 {
|
||||
if model.IsAdmin(token.UserId) {
|
||||
c.Set("specific_channel_id", parts[1])
|
||||
|
||||
@@ -181,6 +181,10 @@ func getModelRequest(c *gin.Context) (*ModelRequest, bool, error) {
|
||||
}
|
||||
c.Set("platform", string(constant.TaskPlatformSuno))
|
||||
c.Set("relay_mode", relayMode)
|
||||
} else if strings.Contains(c.Request.URL.Path, "/v1/videos/") && strings.HasSuffix(c.Request.URL.Path, "/remix") {
|
||||
relayMode := relayconstant.RelayModeVideoSubmit
|
||||
c.Set("relay_mode", relayMode)
|
||||
shouldSelectChannel = false
|
||||
} else if strings.Contains(c.Request.URL.Path, "/v1/videos") {
|
||||
//curl https://api.openai.com/v1/videos \
|
||||
// -H "Authorization: Bearer $OPENAI_API_KEY" \
|
||||
|
||||
@@ -27,6 +27,7 @@ type Token struct {
|
||||
AllowIps *string `json:"allow_ips" gorm:"default:''"`
|
||||
UsedQuota int `json:"used_quota" gorm:"default:0"` // used quota
|
||||
Group string `json:"group" gorm:"default:''"`
|
||||
CrossGroupRetry bool `json:"cross_group_retry" gorm:"default:false"` // 跨分组重试,仅auto分组有效
|
||||
DeletedAt gorm.DeletedAt `gorm:"index"`
|
||||
}
|
||||
|
||||
@@ -185,7 +186,7 @@ func (token *Token) Update() (err error) {
|
||||
}
|
||||
}()
|
||||
err = DB.Model(token).Select("name", "status", "expired_time", "remain_quota", "unlimited_quota",
|
||||
"model_limits_enabled", "model_limits", "allow_ips", "group").Updates(token).Error
|
||||
"model_limits_enabled", "model_limits", "allow_ips", "group", "cross_group_retry").Updates(token).Error
|
||||
return err
|
||||
}
|
||||
|
||||
|
||||
@@ -47,7 +47,7 @@ type TaskAdaptor interface {
|
||||
GetChannelName() string
|
||||
|
||||
// FetchTask
|
||||
FetchTask(baseUrl, key string, body map[string]any) (*http.Response, error)
|
||||
FetchTask(baseUrl, key string, body map[string]any, proxy string) (*http.Response, error)
|
||||
|
||||
ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, error)
|
||||
}
|
||||
|
||||
@@ -18,6 +18,7 @@ import (
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/pkg/errors"
|
||||
|
||||
"github.com/QuantumNous/new-api/setting/model_setting"
|
||||
"github.com/aws/aws-sdk-go-v2/aws"
|
||||
"github.com/aws/aws-sdk-go-v2/credentials"
|
||||
"github.com/aws/aws-sdk-go-v2/service/bedrockruntime"
|
||||
@@ -129,7 +130,7 @@ func doAwsClientRequest(c *gin.Context, info *relaycommon.RelayInfo, a *Adaptor,
|
||||
Accept: aws.String("application/json"),
|
||||
ContentType: aws.String("application/json"),
|
||||
}
|
||||
awsReq.Body, err = common.Marshal(awsClaudeReq)
|
||||
awsReq.Body, err = buildAwsRequestBody(c, info, awsClaudeReq)
|
||||
if err != nil {
|
||||
return nil, types.NewError(errors.Wrap(err, "marshal aws request fail"), types.ErrorCodeBadRequestBody)
|
||||
}
|
||||
@@ -141,7 +142,7 @@ func doAwsClientRequest(c *gin.Context, info *relaycommon.RelayInfo, a *Adaptor,
|
||||
Accept: aws.String("application/json"),
|
||||
ContentType: aws.String("application/json"),
|
||||
}
|
||||
awsReq.Body, err = common.Marshal(awsClaudeReq)
|
||||
awsReq.Body, err = buildAwsRequestBody(c, info, awsClaudeReq)
|
||||
if err != nil {
|
||||
return nil, types.NewError(errors.Wrap(err, "marshal aws request fail"), types.ErrorCodeBadRequestBody)
|
||||
}
|
||||
@@ -151,6 +152,24 @@ func doAwsClientRequest(c *gin.Context, info *relaycommon.RelayInfo, a *Adaptor,
|
||||
}
|
||||
}
|
||||
|
||||
// buildAwsRequestBody prepares the payload for AWS requests, applying passthrough rules when enabled.
|
||||
func buildAwsRequestBody(c *gin.Context, info *relaycommon.RelayInfo, awsClaudeReq any) ([]byte, error) {
|
||||
if model_setting.GetGlobalSettings().PassThroughRequestEnabled || info.ChannelSetting.PassThroughBodyEnabled {
|
||||
body, err := common.GetRequestBody(c)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "get request body for pass-through fail")
|
||||
}
|
||||
var data map[string]interface{}
|
||||
if err := common.Unmarshal(body, &data); err != nil {
|
||||
return nil, errors.Wrap(err, "pass-through unmarshal request body fail")
|
||||
}
|
||||
delete(data, "model")
|
||||
delete(data, "stream")
|
||||
return common.Marshal(data)
|
||||
}
|
||||
return common.Marshal(awsClaudeReq)
|
||||
}
|
||||
|
||||
func getAwsRegionPrefix(awsRegionId string) string {
|
||||
parts := strings.Split(awsRegionId, "-")
|
||||
regionPrefix := ""
|
||||
|
||||
@@ -9,6 +9,7 @@ var ModelList = []string{
|
||||
"claude-3-opus-20240229",
|
||||
"claude-3-haiku-20240307",
|
||||
"claude-3-5-haiku-20241022",
|
||||
"claude-haiku-4-5-20251001",
|
||||
"claude-3-5-sonnet-20240620",
|
||||
"claude-3-5-sonnet-20241022",
|
||||
"claude-3-7-sonnet-20250219",
|
||||
|
||||
@@ -306,10 +306,11 @@ func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayIn
|
||||
request.Temperature = nil
|
||||
}
|
||||
|
||||
// gpt-5系列模型适配 归零不再支持的参数
|
||||
if strings.HasPrefix(info.UpstreamModelName, "gpt-5") {
|
||||
if info.UpstreamModelName != "gpt-5-chat-latest" {
|
||||
request.Temperature = nil
|
||||
}
|
||||
request.Temperature = nil
|
||||
request.TopP = 0 // oai 的 top_p 默认值是 1.0,但是为了 omitempty 属性直接不传,这里显式设置为 0
|
||||
request.LogProbs = false
|
||||
}
|
||||
|
||||
// 转换模型推理力度后缀
|
||||
|
||||
@@ -172,7 +172,7 @@ func handleLastResponse(lastStreamData string, responseId *string, createAt *int
|
||||
shouldSendLastResp *bool) error {
|
||||
|
||||
var lastStreamResponse dto.ChatCompletionsStreamResponse
|
||||
if err := json.Unmarshal(common.StringToByteSlice(lastStreamData), &lastStreamResponse); err != nil {
|
||||
if err := common.Unmarshal(common.StringToByteSlice(lastStreamData), &lastStreamResponse); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
|
||||
@@ -393,7 +393,7 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *rela
|
||||
}
|
||||
|
||||
// FetchTask 查询任务状态
|
||||
func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http.Response, error) {
|
||||
func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any, proxy string) (*http.Response, error) {
|
||||
taskID, ok := body["task_id"].(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid task_id")
|
||||
@@ -408,7 +408,11 @@ func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http
|
||||
|
||||
req.Header.Set("Authorization", "Bearer "+key)
|
||||
|
||||
return service.GetHttpClient().Do(req)
|
||||
client, err := service.GetHttpClientWithProxy(proxy)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("new proxy http client failed: %w", err)
|
||||
}
|
||||
return client.Do(req)
|
||||
}
|
||||
|
||||
func (a *TaskAdaptor) GetModelList() []string {
|
||||
|
||||
@@ -146,7 +146,7 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *rela
|
||||
}
|
||||
|
||||
// FetchTask fetch task status
|
||||
func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http.Response, error) {
|
||||
func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any, proxy string) (*http.Response, error) {
|
||||
taskID, ok := body["task_id"].(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid task_id")
|
||||
@@ -163,7 +163,11 @@ func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", "Bearer "+key)
|
||||
|
||||
return service.GetHttpClient().Do(req)
|
||||
client, err := service.GetHttpClientWithProxy(proxy)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("new proxy http client failed: %w", err)
|
||||
}
|
||||
return client.Do(req)
|
||||
}
|
||||
|
||||
func (a *TaskAdaptor) GetModelList() []string {
|
||||
|
||||
@@ -200,7 +200,7 @@ func (a *TaskAdaptor) GetChannelName() string {
|
||||
}
|
||||
|
||||
// FetchTask fetch task status
|
||||
func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http.Response, error) {
|
||||
func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any, proxy string) (*http.Response, error) {
|
||||
taskID, ok := body["task_id"].(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid task_id")
|
||||
@@ -223,7 +223,11 @@ func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http
|
||||
req.Header.Set("Accept", "application/json")
|
||||
req.Header.Set("x-goog-api-key", key)
|
||||
|
||||
return service.GetHttpClient().Do(req)
|
||||
client, err := service.GetHttpClientWithProxy(proxy)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("new proxy http client failed: %w", err)
|
||||
}
|
||||
return client.Do(req)
|
||||
}
|
||||
|
||||
func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, error) {
|
||||
|
||||
@@ -110,7 +110,7 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *rela
|
||||
return hResp.TaskID, responseBody, nil
|
||||
}
|
||||
|
||||
func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http.Response, error) {
|
||||
func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any, proxy string) (*http.Response, error) {
|
||||
taskID, ok := body["task_id"].(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid task_id")
|
||||
@@ -126,7 +126,11 @@ func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http
|
||||
req.Header.Set("Accept", "application/json")
|
||||
req.Header.Set("Authorization", "Bearer "+key)
|
||||
|
||||
return service.GetHttpClient().Do(req)
|
||||
client, err := service.GetHttpClientWithProxy(proxy)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("new proxy http client failed: %w", err)
|
||||
}
|
||||
return client.Do(req)
|
||||
}
|
||||
|
||||
func (a *TaskAdaptor) GetModelList() []string {
|
||||
|
||||
@@ -210,7 +210,7 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *rela
|
||||
}
|
||||
|
||||
// FetchTask fetch task status
|
||||
func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http.Response, error) {
|
||||
func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any, proxy string) (*http.Response, error) {
|
||||
taskID, ok := body["task_id"].(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid task_id")
|
||||
@@ -251,7 +251,11 @@ func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http
|
||||
return nil, errors.Wrap(err, "sign request failed")
|
||||
}
|
||||
}
|
||||
return service.GetHttpClient().Do(req)
|
||||
client, err := service.GetHttpClientWithProxy(proxy)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("new proxy http client failed: %w", err)
|
||||
}
|
||||
return client.Do(req)
|
||||
}
|
||||
|
||||
func (a *TaskAdaptor) GetModelList() []string {
|
||||
|
||||
@@ -199,7 +199,7 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *rela
|
||||
}
|
||||
|
||||
// FetchTask fetch task status
|
||||
func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http.Response, error) {
|
||||
func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any, proxy string) (*http.Response, error) {
|
||||
taskID, ok := body["task_id"].(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid task_id")
|
||||
@@ -228,7 +228,11 @@ func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http
|
||||
req.Header.Set("Authorization", "Bearer "+token)
|
||||
req.Header.Set("User-Agent", "kling-sdk/1.0")
|
||||
|
||||
return service.GetHttpClient().Do(req)
|
||||
client, err := service.GetHttpClientWithProxy(proxy)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("new proxy http client failed: %w", err)
|
||||
}
|
||||
return client.Do(req)
|
||||
}
|
||||
|
||||
func (a *TaskAdaptor) GetModelList() []string {
|
||||
|
||||
@@ -5,8 +5,10 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
"github.com/QuantumNous/new-api/constant"
|
||||
"github.com/QuantumNous/new-api/dto"
|
||||
"github.com/QuantumNous/new-api/model"
|
||||
"github.com/QuantumNous/new-api/relay/channel"
|
||||
@@ -67,11 +69,30 @@ func (a *TaskAdaptor) Init(info *relaycommon.RelayInfo) {
|
||||
a.apiKey = info.ApiKey
|
||||
}
|
||||
|
||||
func validateRemixRequest(c *gin.Context) *dto.TaskError {
|
||||
var req struct {
|
||||
Prompt string `json:"prompt"`
|
||||
}
|
||||
if err := common.UnmarshalBodyReusable(c, &req); err != nil {
|
||||
return service.TaskErrorWrapperLocal(err, "invalid_request", http.StatusBadRequest)
|
||||
}
|
||||
if strings.TrimSpace(req.Prompt) == "" {
|
||||
return service.TaskErrorWrapperLocal(fmt.Errorf("field prompt is required"), "invalid_request", http.StatusBadRequest)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.RelayInfo) (taskErr *dto.TaskError) {
|
||||
if info.Action == constant.TaskActionRemix {
|
||||
return validateRemixRequest(c)
|
||||
}
|
||||
return relaycommon.ValidateMultipartDirect(c, info)
|
||||
}
|
||||
|
||||
func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
if info.Action == constant.TaskActionRemix {
|
||||
return fmt.Sprintf("%s/v1/videos/%s/remix", a.baseURL, info.OriginTaskID), nil
|
||||
}
|
||||
return fmt.Sprintf("%s/v1/videos", a.baseURL), nil
|
||||
}
|
||||
|
||||
@@ -125,7 +146,7 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, _ *relayco
|
||||
}
|
||||
|
||||
// FetchTask fetch task status
|
||||
func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http.Response, error) {
|
||||
func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any, proxy string) (*http.Response, error) {
|
||||
taskID, ok := body["task_id"].(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid task_id")
|
||||
@@ -140,7 +161,11 @@ func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http
|
||||
|
||||
req.Header.Set("Authorization", "Bearer "+key)
|
||||
|
||||
return service.GetHttpClient().Do(req)
|
||||
client, err := service.GetHttpClientWithProxy(proxy)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("new proxy http client failed: %w", err)
|
||||
}
|
||||
return client.Do(req)
|
||||
}
|
||||
|
||||
func (a *TaskAdaptor) GetModelList() []string {
|
||||
|
||||
@@ -132,7 +132,7 @@ func (a *TaskAdaptor) GetChannelName() string {
|
||||
return ChannelName
|
||||
}
|
||||
|
||||
func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http.Response, error) {
|
||||
func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any, proxy string) (*http.Response, error) {
|
||||
requestUrl := fmt.Sprintf("%s/suno/fetch", baseUrl)
|
||||
byteBody, err := json.Marshal(body)
|
||||
if err != nil {
|
||||
@@ -153,11 +153,11 @@ func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http
|
||||
req = req.WithContext(ctx)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", "Bearer "+key)
|
||||
resp, err := service.GetHttpClient().Do(req)
|
||||
client, err := service.GetHttpClientWithProxy(proxy)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, fmt.Errorf("new proxy http client failed: %w", err)
|
||||
}
|
||||
return resp, nil
|
||||
return client.Do(req)
|
||||
}
|
||||
|
||||
func actionValidate(c *gin.Context, sunoRequest *dto.SunoSubmitReq, action string) (err error) {
|
||||
|
||||
@@ -120,7 +120,11 @@ func (a *TaskAdaptor) BuildRequestHeader(c *gin.Context, req *http.Request, info
|
||||
return fmt.Errorf("failed to decode credentials: %w", err)
|
||||
}
|
||||
|
||||
token, err := vertexcore.AcquireAccessToken(*adc, "")
|
||||
proxy := ""
|
||||
if info != nil {
|
||||
proxy = info.ChannelSetting.Proxy
|
||||
}
|
||||
token, err := vertexcore.AcquireAccessToken(*adc, proxy)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to acquire access token: %w", err)
|
||||
}
|
||||
@@ -216,7 +220,7 @@ func (a *TaskAdaptor) GetModelList() []string { return []string{"veo-3.0-generat
|
||||
func (a *TaskAdaptor) GetChannelName() string { return "vertex" }
|
||||
|
||||
// FetchTask fetch task status
|
||||
func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http.Response, error) {
|
||||
func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any, proxy string) (*http.Response, error) {
|
||||
taskID, ok := body["task_id"].(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid task_id")
|
||||
@@ -249,7 +253,7 @@ func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http
|
||||
if err := json.Unmarshal([]byte(key), adc); err != nil {
|
||||
return nil, fmt.Errorf("failed to decode credentials: %w", err)
|
||||
}
|
||||
token, err := vertexcore.AcquireAccessToken(*adc, "")
|
||||
token, err := vertexcore.AcquireAccessToken(*adc, proxy)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to acquire access token: %w", err)
|
||||
}
|
||||
@@ -261,7 +265,11 @@ func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http
|
||||
req.Header.Set("Accept", "application/json")
|
||||
req.Header.Set("Authorization", "Bearer "+token)
|
||||
req.Header.Set("x-goog-user-project", adc.ProjectID)
|
||||
return service.GetHttpClient().Do(req)
|
||||
client, err := service.GetHttpClientWithProxy(proxy)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("new proxy http client failed: %w", err)
|
||||
}
|
||||
return client.Do(req)
|
||||
}
|
||||
|
||||
func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, error) {
|
||||
|
||||
@@ -188,7 +188,7 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *rela
|
||||
return vResp.TaskId, responseBody, nil
|
||||
}
|
||||
|
||||
func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http.Response, error) {
|
||||
func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any, proxy string) (*http.Response, error) {
|
||||
taskID, ok := body["task_id"].(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid task_id")
|
||||
@@ -204,7 +204,11 @@ func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http
|
||||
req.Header.Set("Accept", "application/json")
|
||||
req.Header.Set("Authorization", "Token "+key)
|
||||
|
||||
return service.GetHttpClient().Do(req)
|
||||
client, err := service.GetHttpClientWithProxy(proxy)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("new proxy http client failed: %w", err)
|
||||
}
|
||||
return client.Do(req)
|
||||
}
|
||||
|
||||
func (a *TaskAdaptor) GetModelList() []string {
|
||||
|
||||
@@ -17,6 +17,7 @@ import (
|
||||
relaycommon "github.com/QuantumNous/new-api/relay/common"
|
||||
"github.com/QuantumNous/new-api/relay/constant"
|
||||
"github.com/QuantumNous/new-api/setting/model_setting"
|
||||
"github.com/QuantumNous/new-api/setting/reasoning"
|
||||
"github.com/QuantumNous/new-api/types"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
@@ -181,6 +182,8 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
info.UpstreamModelName = strings.TrimSuffix(info.UpstreamModelName, "-thinking")
|
||||
} else if strings.HasSuffix(info.UpstreamModelName, "-nothinking") {
|
||||
info.UpstreamModelName = strings.TrimSuffix(info.UpstreamModelName, "-nothinking")
|
||||
} else if baseModel, level, ok := reasoning.TrimEffortSuffix(info.UpstreamModelName); ok && level != "" {
|
||||
info.UpstreamModelName = baseModel
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -36,8 +36,7 @@ func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInf
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
|
||||
//TODO implement me
|
||||
return nil, errors.New("not implemented")
|
||||
return request, nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
|
||||
@@ -63,6 +62,8 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
return fmt.Sprintf("%s/embeddings", specialPlan.OpenAIBaseURL), nil
|
||||
}
|
||||
return fmt.Sprintf("%s/api/paas/v4/embeddings", baseURL), nil
|
||||
case relayconstant.RelayModeImagesGenerations:
|
||||
return fmt.Sprintf("%s/api/paas/v4/images/generations", baseURL), nil
|
||||
default:
|
||||
if hasSpecialPlan && specialPlan.OpenAIBaseURL != "" {
|
||||
return fmt.Sprintf("%s/chat/completions", specialPlan.OpenAIBaseURL), nil
|
||||
@@ -114,6 +115,9 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
|
||||
return claude.ClaudeHandler(c, resp, info, claude.RequestModeMessage)
|
||||
}
|
||||
default:
|
||||
if info.RelayMode == relayconstant.RelayModeImagesGenerations {
|
||||
return zhipu4vImageHandler(c, resp, info)
|
||||
}
|
||||
adaptor := openai.Adaptor{}
|
||||
return adaptor.DoResponse(c, resp, info)
|
||||
}
|
||||
|
||||
127
relay/channel/zhipu_4v/image.go
Normal file
127
relay/channel/zhipu_4v/image.go
Normal file
@@ -0,0 +1,127 @@
|
||||
package zhipu_4v
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net/http"
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
"github.com/QuantumNous/new-api/dto"
|
||||
"github.com/QuantumNous/new-api/logger"
|
||||
relaycommon "github.com/QuantumNous/new-api/relay/common"
|
||||
"github.com/QuantumNous/new-api/service"
|
||||
"github.com/QuantumNous/new-api/types"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type zhipuImageRequest struct {
|
||||
Model string `json:"model"`
|
||||
Prompt string `json:"prompt"`
|
||||
Quality string `json:"quality,omitempty"`
|
||||
Size string `json:"size,omitempty"`
|
||||
WatermarkEnabled *bool `json:"watermark_enabled,omitempty"`
|
||||
UserID string `json:"user_id,omitempty"`
|
||||
}
|
||||
|
||||
type zhipuImageResponse struct {
|
||||
Created *int64 `json:"created,omitempty"`
|
||||
Data []zhipuImageData `json:"data,omitempty"`
|
||||
ContentFilter any `json:"content_filter,omitempty"`
|
||||
Usage *dto.Usage `json:"usage,omitempty"`
|
||||
Error *zhipuImageError `json:"error,omitempty"`
|
||||
RequestID string `json:"request_id,omitempty"`
|
||||
ExtendParam map[string]string `json:"extendParam,omitempty"`
|
||||
}
|
||||
|
||||
type zhipuImageError struct {
|
||||
Code string `json:"code"`
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
type zhipuImageData struct {
|
||||
Url string `json:"url,omitempty"`
|
||||
ImageUrl string `json:"image_url,omitempty"`
|
||||
B64Json string `json:"b64_json,omitempty"`
|
||||
B64Image string `json:"b64_image,omitempty"`
|
||||
}
|
||||
|
||||
type openAIImagePayload struct {
|
||||
Created int64 `json:"created"`
|
||||
Data []openAIImageData `json:"data"`
|
||||
}
|
||||
|
||||
type openAIImageData struct {
|
||||
B64Json string `json:"b64_json"`
|
||||
}
|
||||
|
||||
func zhipu4vImageHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.Usage, *types.NewAPIError) {
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError)
|
||||
}
|
||||
service.CloseResponseBodyGracefully(resp)
|
||||
|
||||
var zhipuResp zhipuImageResponse
|
||||
if err := common.Unmarshal(responseBody, &zhipuResp); err != nil {
|
||||
return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
if zhipuResp.Error != nil && zhipuResp.Error.Message != "" {
|
||||
return nil, types.WithOpenAIError(types.OpenAIError{
|
||||
Message: zhipuResp.Error.Message,
|
||||
Type: "zhipu_image_error",
|
||||
Code: zhipuResp.Error.Code,
|
||||
}, resp.StatusCode)
|
||||
}
|
||||
|
||||
payload := openAIImagePayload{}
|
||||
if zhipuResp.Created != nil && *zhipuResp.Created != 0 {
|
||||
payload.Created = *zhipuResp.Created
|
||||
} else {
|
||||
payload.Created = info.StartTime.Unix()
|
||||
}
|
||||
for _, data := range zhipuResp.Data {
|
||||
url := data.Url
|
||||
if url == "" {
|
||||
url = data.ImageUrl
|
||||
}
|
||||
if url == "" {
|
||||
logger.LogWarn(c, "zhipu_image_missing_url")
|
||||
continue
|
||||
}
|
||||
|
||||
var b64 string
|
||||
switch {
|
||||
case data.B64Json != "":
|
||||
b64 = data.B64Json
|
||||
case data.B64Image != "":
|
||||
b64 = data.B64Image
|
||||
default:
|
||||
_, downloaded, err := service.GetImageFromUrl(url)
|
||||
if err != nil {
|
||||
logger.LogError(c, "zhipu_image_get_b64_failed: "+err.Error())
|
||||
continue
|
||||
}
|
||||
b64 = downloaded
|
||||
}
|
||||
|
||||
if b64 == "" {
|
||||
logger.LogWarn(c, "zhipu_image_empty_b64")
|
||||
continue
|
||||
}
|
||||
|
||||
imageData := openAIImageData{
|
||||
B64Json: b64,
|
||||
}
|
||||
payload.Data = append(payload.Data, imageData)
|
||||
}
|
||||
|
||||
jsonResp, err := common.Marshal(payload)
|
||||
if err != nil {
|
||||
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
|
||||
}
|
||||
|
||||
service.IOCopyBytesGracefully(c, resp, jsonResp)
|
||||
|
||||
return &dto.Usage{}, nil
|
||||
}
|
||||
@@ -11,6 +11,8 @@ import (
|
||||
"github.com/tidwall/sjson"
|
||||
)
|
||||
|
||||
var negativeIndexRegexp = regexp.MustCompile(`\.(-\d+)`)
|
||||
|
||||
type ConditionOperation struct {
|
||||
Path string `json:"path"` // JSON路径
|
||||
Mode string `json:"mode"` // full, prefix, suffix, contains, gt, gte, lt, lte
|
||||
@@ -186,8 +188,7 @@ func checkSingleCondition(jsonStr, contextJSON string, condition ConditionOperat
|
||||
}
|
||||
|
||||
func processNegativeIndex(jsonStr string, path string) string {
|
||||
re := regexp.MustCompile(`\.(-\d+)`)
|
||||
matches := re.FindAllStringSubmatch(path, -1)
|
||||
matches := negativeIndexRegexp.FindAllStringSubmatch(path, -1)
|
||||
|
||||
if len(matches) == 0 {
|
||||
return path
|
||||
|
||||
@@ -99,7 +99,10 @@ func ModelPriceHelper(c *gin.Context, info *relaycommon.RelayInfo, promptTokens
|
||||
// check if free model pre-consume is disabled
|
||||
if !operation_setting.GetQuotaSetting().EnableFreeModelPreConsume {
|
||||
// if model price or ratio is 0, do not pre-consume quota
|
||||
if usePrice {
|
||||
if groupRatioInfo.GroupRatio == 0 {
|
||||
preConsumedQuota = 0
|
||||
freeModel = true
|
||||
} else if usePrice {
|
||||
if modelPrice == 0 {
|
||||
preConsumedQuota = 0
|
||||
freeModel = true
|
||||
|
||||
@@ -32,7 +32,94 @@ func RelayTaskSubmit(c *gin.Context, info *relaycommon.RelayInfo) (taskErr *dto.
|
||||
if info.TaskRelayInfo == nil {
|
||||
info.TaskRelayInfo = &relaycommon.TaskRelayInfo{}
|
||||
}
|
||||
path := c.Request.URL.Path
|
||||
if strings.Contains(path, "/v1/videos/") && strings.HasSuffix(path, "/remix") {
|
||||
info.Action = constant.TaskActionRemix
|
||||
}
|
||||
|
||||
// 提取 remix 任务的 video_id
|
||||
if info.Action == constant.TaskActionRemix {
|
||||
videoID := c.Param("video_id")
|
||||
if strings.TrimSpace(videoID) == "" {
|
||||
return service.TaskErrorWrapperLocal(fmt.Errorf("video_id is required"), "invalid_request", http.StatusBadRequest)
|
||||
}
|
||||
info.OriginTaskID = videoID
|
||||
}
|
||||
|
||||
platform := constant.TaskPlatform(c.GetString("platform"))
|
||||
|
||||
// 获取原始任务信息
|
||||
if info.OriginTaskID != "" {
|
||||
originTask, exist, err := model.GetByTaskId(info.UserId, info.OriginTaskID)
|
||||
if err != nil {
|
||||
taskErr = service.TaskErrorWrapper(err, "get_origin_task_failed", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
if !exist {
|
||||
taskErr = service.TaskErrorWrapperLocal(errors.New("task_origin_not_exist"), "task_not_exist", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if info.OriginModelName == "" {
|
||||
if originTask.Properties.OriginModelName != "" {
|
||||
info.OriginModelName = originTask.Properties.OriginModelName
|
||||
} else if originTask.Properties.UpstreamModelName != "" {
|
||||
info.OriginModelName = originTask.Properties.UpstreamModelName
|
||||
} else {
|
||||
var taskData map[string]interface{}
|
||||
_ = json.Unmarshal(originTask.Data, &taskData)
|
||||
if m, ok := taskData["model"].(string); ok && m != "" {
|
||||
info.OriginModelName = m
|
||||
platform = originTask.Platform
|
||||
}
|
||||
}
|
||||
}
|
||||
if originTask.ChannelId != info.ChannelId {
|
||||
channel, err := model.GetChannelById(originTask.ChannelId, true)
|
||||
if err != nil {
|
||||
taskErr = service.TaskErrorWrapperLocal(err, "channel_not_found", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if channel.Status != common.ChannelStatusEnabled {
|
||||
taskErr = service.TaskErrorWrapperLocal(errors.New("the channel of the origin task is disabled"), "task_channel_disable", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
key, _, newAPIError := channel.GetNextEnabledKey()
|
||||
if newAPIError != nil {
|
||||
taskErr = service.TaskErrorWrapper(newAPIError, "channel_no_available_key", newAPIError.StatusCode)
|
||||
return
|
||||
}
|
||||
common.SetContextKey(c, constant.ContextKeyChannelKey, key)
|
||||
common.SetContextKey(c, constant.ContextKeyChannelType, channel.Type)
|
||||
common.SetContextKey(c, constant.ContextKeyChannelBaseUrl, channel.GetBaseURL())
|
||||
common.SetContextKey(c, constant.ContextKeyChannelId, originTask.ChannelId)
|
||||
|
||||
info.ChannelBaseUrl = channel.GetBaseURL()
|
||||
info.ChannelId = originTask.ChannelId
|
||||
info.ChannelType = channel.Type
|
||||
info.ApiKey = key
|
||||
platform = originTask.Platform
|
||||
}
|
||||
|
||||
// 使用原始任务的参数
|
||||
if info.Action == constant.TaskActionRemix {
|
||||
var taskData map[string]interface{}
|
||||
_ = json.Unmarshal(originTask.Data, &taskData)
|
||||
secondsStr, _ := taskData["seconds"].(string)
|
||||
seconds, _ := strconv.Atoi(secondsStr)
|
||||
if seconds <= 0 {
|
||||
seconds = 4
|
||||
}
|
||||
sizeStr, _ := taskData["size"].(string)
|
||||
if info.PriceData.OtherRatios == nil {
|
||||
info.PriceData.OtherRatios = map[string]float64{}
|
||||
}
|
||||
info.PriceData.OtherRatios["seconds"] = float64(seconds)
|
||||
info.PriceData.OtherRatios["size"] = 1
|
||||
if sizeStr == "1792x1024" || sizeStr == "1024x1792" {
|
||||
info.PriceData.OtherRatios["size"] = 1.666667
|
||||
}
|
||||
}
|
||||
}
|
||||
if platform == "" {
|
||||
platform = GetTaskPlatform(c)
|
||||
}
|
||||
@@ -94,34 +181,6 @@ func RelayTaskSubmit(c *gin.Context, info *relaycommon.RelayInfo) (taskErr *dto.
|
||||
return
|
||||
}
|
||||
|
||||
if info.OriginTaskID != "" {
|
||||
originTask, exist, err := model.GetByTaskId(info.UserId, info.OriginTaskID)
|
||||
if err != nil {
|
||||
taskErr = service.TaskErrorWrapper(err, "get_origin_task_failed", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
if !exist {
|
||||
taskErr = service.TaskErrorWrapperLocal(errors.New("task_origin_not_exist"), "task_not_exist", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if originTask.ChannelId != info.ChannelId {
|
||||
channel, err := model.GetChannelById(originTask.ChannelId, true)
|
||||
if err != nil {
|
||||
taskErr = service.TaskErrorWrapperLocal(err, "channel_not_found", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if channel.Status != common.ChannelStatusEnabled {
|
||||
return service.TaskErrorWrapperLocal(errors.New("该任务所属渠道已被禁用"), "task_channel_disable", http.StatusBadRequest)
|
||||
}
|
||||
c.Set("base_url", channel.GetBaseURL())
|
||||
c.Set("channel_id", originTask.ChannelId)
|
||||
c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key))
|
||||
|
||||
info.ChannelBaseUrl = channel.GetBaseURL()
|
||||
info.ChannelId = originTask.ChannelId
|
||||
}
|
||||
}
|
||||
|
||||
// build body
|
||||
requestBody, err := adaptor.BuildRequestBody(c, info)
|
||||
if err != nil {
|
||||
@@ -326,6 +385,7 @@ func videoFetchByIDRespBodyBuilder(c *gin.Context) (respBody []byte, taskResp *d
|
||||
if channelModel.GetBaseURL() != "" {
|
||||
baseURL = channelModel.GetBaseURL()
|
||||
}
|
||||
proxy := channelModel.GetSetting().Proxy
|
||||
adaptor := GetTaskAdaptor(constant.TaskPlatform(strconv.Itoa(channelModel.Type)))
|
||||
if adaptor == nil {
|
||||
return
|
||||
@@ -333,7 +393,7 @@ func videoFetchByIDRespBodyBuilder(c *gin.Context) (respBody []byte, taskResp *d
|
||||
resp, err2 := adaptor.FetchTask(baseURL, channelModel.Key, map[string]any{
|
||||
"task_id": originTask.TaskID,
|
||||
"action": originTask.Action,
|
||||
})
|
||||
}, proxy)
|
||||
if err2 != nil || resp == nil {
|
||||
return
|
||||
}
|
||||
|
||||
@@ -14,6 +14,7 @@ func SetVideoRouter(router *gin.Engine) {
|
||||
videoV1Router.GET("/videos/:task_id/content", controller.VideoProxy)
|
||||
videoV1Router.POST("/video/generations", controller.RelayTask)
|
||||
videoV1Router.GET("/video/generations/:task_id", controller.RelayTask)
|
||||
videoV1Router.POST("/videos/:video_id/remix", controller.RelayTask)
|
||||
}
|
||||
// openai compatible API video routes
|
||||
// docs: https://platform.openai.com/docs/api-reference/videos/create
|
||||
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// CacheGetRandomSatisfiedChannel tries to get a random channel that satisfies the requirements.
|
||||
func CacheGetRandomSatisfiedChannel(c *gin.Context, group string, modelName string, retry int) (*model.Channel, string, error) {
|
||||
var channel *model.Channel
|
||||
var err error
|
||||
@@ -20,15 +21,30 @@ func CacheGetRandomSatisfiedChannel(c *gin.Context, group string, modelName stri
|
||||
if len(setting.GetAutoGroups()) == 0 {
|
||||
return nil, selectGroup, errors.New("auto groups is not enabled")
|
||||
}
|
||||
for _, autoGroup := range GetUserAutoGroup(userGroup) {
|
||||
logger.LogDebug(c, "Auto selecting group:", autoGroup)
|
||||
channel, _ = model.GetRandomSatisfiedChannel(autoGroup, modelName, retry)
|
||||
autoGroups := GetUserAutoGroup(userGroup)
|
||||
// 如果 token 启用了跨分组重试,获取上次失败的 auto group 索引,从下一个开始尝试
|
||||
startIndex := 0
|
||||
crossGroupRetry := common.GetContextKeyBool(c, constant.ContextKeyTokenCrossGroupRetry)
|
||||
if crossGroupRetry && retry > 0 {
|
||||
logger.LogDebug(c, "Auto group retry cross group, retry: %d", retry)
|
||||
if lastIndex, exists := common.GetContextKey(c, constant.ContextKeyAutoGroupIndex); exists {
|
||||
if idx, ok := lastIndex.(int); ok {
|
||||
startIndex = idx + 1
|
||||
}
|
||||
}
|
||||
logger.LogDebug(c, "Auto group retry cross group, start index: %d", startIndex)
|
||||
}
|
||||
for i := startIndex; i < len(autoGroups); i++ {
|
||||
autoGroup := autoGroups[i]
|
||||
logger.LogDebug(c, "Auto selecting group: %s", autoGroup)
|
||||
channel, _ = model.GetRandomSatisfiedChannel(autoGroup, modelName, 0)
|
||||
if channel == nil {
|
||||
continue
|
||||
} else {
|
||||
c.Set("auto_group", autoGroup)
|
||||
common.SetContextKey(c, constant.ContextKeyAutoGroupIndex, i)
|
||||
selectGroup = autoGroup
|
||||
logger.LogDebug(c, "Auto selected group:", autoGroup)
|
||||
logger.LogDebug(c, "Auto selected group: %s", autoGroup)
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
@@ -201,6 +201,10 @@ func generateStopBlock(index int) *dto.ClaudeResponse {
|
||||
}
|
||||
|
||||
func StreamResponseOpenAI2Claude(openAIResponse *dto.ChatCompletionsStreamResponse, info *relaycommon.RelayInfo) []*dto.ClaudeResponse {
|
||||
if info.ClaudeConvertInfo.Done {
|
||||
return nil
|
||||
}
|
||||
|
||||
var claudeResponses []*dto.ClaudeResponse
|
||||
if info.SendResponseCount == 1 {
|
||||
msg := &dto.ClaudeMediaMessage{
|
||||
@@ -218,45 +222,117 @@ func StreamResponseOpenAI2Claude(openAIResponse *dto.ChatCompletionsStreamRespon
|
||||
Type: "message_start",
|
||||
Message: msg,
|
||||
})
|
||||
claudeResponses = append(claudeResponses)
|
||||
//claudeResponses = append(claudeResponses, &dto.ClaudeResponse{
|
||||
// Type: "ping",
|
||||
//})
|
||||
if openAIResponse.IsToolCall() {
|
||||
info.ClaudeConvertInfo.LastMessagesType = relaycommon.LastMessageTypeTools
|
||||
var toolCall dto.ToolCallResponse
|
||||
if len(openAIResponse.Choices) > 0 && len(openAIResponse.Choices[0].Delta.ToolCalls) > 0 {
|
||||
toolCall = openAIResponse.Choices[0].Delta.ToolCalls[0]
|
||||
} else {
|
||||
first := openAIResponse.GetFirstToolCall()
|
||||
if first != nil {
|
||||
toolCall = *first
|
||||
} else {
|
||||
toolCall = dto.ToolCallResponse{}
|
||||
}
|
||||
}
|
||||
resp := &dto.ClaudeResponse{
|
||||
Type: "content_block_start",
|
||||
ContentBlock: &dto.ClaudeMediaMessage{
|
||||
Id: openAIResponse.GetFirstToolCall().ID,
|
||||
Id: toolCall.ID,
|
||||
Type: "tool_use",
|
||||
Name: openAIResponse.GetFirstToolCall().Function.Name,
|
||||
Name: toolCall.Function.Name,
|
||||
Input: map[string]interface{}{},
|
||||
},
|
||||
}
|
||||
resp.SetIndex(0)
|
||||
claudeResponses = append(claudeResponses, resp)
|
||||
// 首块包含工具 delta,则追加 input_json_delta
|
||||
if toolCall.Function.Arguments != "" {
|
||||
claudeResponses = append(claudeResponses, &dto.ClaudeResponse{
|
||||
Index: &info.ClaudeConvertInfo.Index,
|
||||
Type: "content_block_delta",
|
||||
Delta: &dto.ClaudeMediaMessage{
|
||||
Type: "input_json_delta",
|
||||
PartialJson: &toolCall.Function.Arguments,
|
||||
},
|
||||
})
|
||||
}
|
||||
} else {
|
||||
|
||||
}
|
||||
// 判断首个响应是否存在内容(非标准的 OpenAI 响应)
|
||||
if len(openAIResponse.Choices) > 0 && len(openAIResponse.Choices[0].Delta.GetContentString()) > 0 {
|
||||
if len(openAIResponse.Choices) > 0 {
|
||||
reasoning := openAIResponse.Choices[0].Delta.GetReasoningContent()
|
||||
content := openAIResponse.Choices[0].Delta.GetContentString()
|
||||
|
||||
if reasoning != "" {
|
||||
claudeResponses = append(claudeResponses, &dto.ClaudeResponse{
|
||||
Index: &info.ClaudeConvertInfo.Index,
|
||||
Type: "content_block_start",
|
||||
ContentBlock: &dto.ClaudeMediaMessage{
|
||||
Type: "thinking",
|
||||
Thinking: common.GetPointer[string](""),
|
||||
},
|
||||
})
|
||||
claudeResponses = append(claudeResponses, &dto.ClaudeResponse{
|
||||
Index: &info.ClaudeConvertInfo.Index,
|
||||
Type: "content_block_delta",
|
||||
Delta: &dto.ClaudeMediaMessage{
|
||||
Type: "thinking_delta",
|
||||
Thinking: &reasoning,
|
||||
},
|
||||
})
|
||||
info.ClaudeConvertInfo.LastMessagesType = relaycommon.LastMessageTypeThinking
|
||||
} else if content != "" {
|
||||
claudeResponses = append(claudeResponses, &dto.ClaudeResponse{
|
||||
Index: &info.ClaudeConvertInfo.Index,
|
||||
Type: "content_block_start",
|
||||
ContentBlock: &dto.ClaudeMediaMessage{
|
||||
Type: "text",
|
||||
Text: common.GetPointer[string](""),
|
||||
},
|
||||
})
|
||||
claudeResponses = append(claudeResponses, &dto.ClaudeResponse{
|
||||
Index: &info.ClaudeConvertInfo.Index,
|
||||
Type: "content_block_delta",
|
||||
Delta: &dto.ClaudeMediaMessage{
|
||||
Type: "text_delta",
|
||||
Text: common.GetPointer[string](content),
|
||||
},
|
||||
})
|
||||
info.ClaudeConvertInfo.LastMessagesType = relaycommon.LastMessageTypeText
|
||||
}
|
||||
}
|
||||
|
||||
// 如果首块就带 finish_reason,需要立即发送停止块
|
||||
if len(openAIResponse.Choices) > 0 && openAIResponse.Choices[0].FinishReason != nil && *openAIResponse.Choices[0].FinishReason != "" {
|
||||
info.FinishReason = *openAIResponse.Choices[0].FinishReason
|
||||
claudeResponses = append(claudeResponses, generateStopBlock(info.ClaudeConvertInfo.Index))
|
||||
oaiUsage := openAIResponse.Usage
|
||||
if oaiUsage == nil {
|
||||
oaiUsage = info.ClaudeConvertInfo.Usage
|
||||
}
|
||||
if oaiUsage != nil {
|
||||
claudeResponses = append(claudeResponses, &dto.ClaudeResponse{
|
||||
Type: "message_delta",
|
||||
Usage: &dto.ClaudeUsage{
|
||||
InputTokens: oaiUsage.PromptTokens,
|
||||
OutputTokens: oaiUsage.CompletionTokens,
|
||||
CacheCreationInputTokens: oaiUsage.PromptTokensDetails.CachedCreationTokens,
|
||||
CacheReadInputTokens: oaiUsage.PromptTokensDetails.CachedTokens,
|
||||
},
|
||||
Delta: &dto.ClaudeMediaMessage{
|
||||
StopReason: common.GetPointer[string](stopReasonOpenAI2Claude(info.FinishReason)),
|
||||
},
|
||||
})
|
||||
}
|
||||
claudeResponses = append(claudeResponses, &dto.ClaudeResponse{
|
||||
Index: &info.ClaudeConvertInfo.Index,
|
||||
Type: "content_block_start",
|
||||
ContentBlock: &dto.ClaudeMediaMessage{
|
||||
Type: "text",
|
||||
Text: common.GetPointer[string](""),
|
||||
},
|
||||
Type: "message_stop",
|
||||
})
|
||||
claudeResponses = append(claudeResponses, &dto.ClaudeResponse{
|
||||
Index: &info.ClaudeConvertInfo.Index,
|
||||
Type: "content_block_delta",
|
||||
Delta: &dto.ClaudeMediaMessage{
|
||||
Type: "text_delta",
|
||||
Text: common.GetPointer[string](openAIResponse.Choices[0].Delta.GetContentString()),
|
||||
},
|
||||
})
|
||||
info.ClaudeConvertInfo.LastMessagesType = relaycommon.LastMessageTypeText
|
||||
info.ClaudeConvertInfo.Done = true
|
||||
}
|
||||
return claudeResponses
|
||||
}
|
||||
@@ -264,7 +340,7 @@ func StreamResponseOpenAI2Claude(openAIResponse *dto.ChatCompletionsStreamRespon
|
||||
if len(openAIResponse.Choices) == 0 {
|
||||
// no choices
|
||||
// 可能为非标准的 OpenAI 响应,判断是否已经完成
|
||||
if info.Done {
|
||||
if info.ClaudeConvertInfo.Done {
|
||||
claudeResponses = append(claudeResponses, generateStopBlock(info.ClaudeConvertInfo.Index))
|
||||
oaiUsage := info.ClaudeConvertInfo.Usage
|
||||
if oaiUsage != nil {
|
||||
@@ -288,16 +364,110 @@ func StreamResponseOpenAI2Claude(openAIResponse *dto.ChatCompletionsStreamRespon
|
||||
return claudeResponses
|
||||
} else {
|
||||
chosenChoice := openAIResponse.Choices[0]
|
||||
if chosenChoice.FinishReason != nil && *chosenChoice.FinishReason != "" {
|
||||
// should be done
|
||||
doneChunk := chosenChoice.FinishReason != nil && *chosenChoice.FinishReason != ""
|
||||
if doneChunk {
|
||||
info.FinishReason = *chosenChoice.FinishReason
|
||||
if !info.Done {
|
||||
return claudeResponses
|
||||
}
|
||||
|
||||
var claudeResponse dto.ClaudeResponse
|
||||
var isEmpty bool
|
||||
claudeResponse.Type = "content_block_delta"
|
||||
if len(chosenChoice.Delta.ToolCalls) > 0 {
|
||||
toolCalls := chosenChoice.Delta.ToolCalls
|
||||
if info.ClaudeConvertInfo.LastMessagesType != relaycommon.LastMessageTypeTools {
|
||||
claudeResponses = append(claudeResponses, generateStopBlock(info.ClaudeConvertInfo.Index))
|
||||
info.ClaudeConvertInfo.Index++
|
||||
}
|
||||
info.ClaudeConvertInfo.LastMessagesType = relaycommon.LastMessageTypeTools
|
||||
|
||||
for i, toolCall := range toolCalls {
|
||||
blockIndex := info.ClaudeConvertInfo.Index
|
||||
if toolCall.Index != nil {
|
||||
blockIndex = *toolCall.Index
|
||||
} else if len(toolCalls) > 1 {
|
||||
blockIndex = info.ClaudeConvertInfo.Index + i
|
||||
}
|
||||
|
||||
idx := blockIndex
|
||||
claudeResponses = append(claudeResponses, &dto.ClaudeResponse{
|
||||
Index: &idx,
|
||||
Type: "content_block_start",
|
||||
ContentBlock: &dto.ClaudeMediaMessage{
|
||||
Id: toolCall.ID,
|
||||
Type: "tool_use",
|
||||
Name: toolCall.Function.Name,
|
||||
Input: map[string]interface{}{},
|
||||
},
|
||||
})
|
||||
|
||||
claudeResponses = append(claudeResponses, &dto.ClaudeResponse{
|
||||
Index: &idx,
|
||||
Type: "content_block_delta",
|
||||
Delta: &dto.ClaudeMediaMessage{
|
||||
Type: "input_json_delta",
|
||||
PartialJson: &toolCall.Function.Arguments,
|
||||
},
|
||||
})
|
||||
|
||||
info.ClaudeConvertInfo.Index = blockIndex
|
||||
}
|
||||
} else {
|
||||
reasoning := chosenChoice.Delta.GetReasoningContent()
|
||||
textContent := chosenChoice.Delta.GetContentString()
|
||||
if reasoning != "" || textContent != "" {
|
||||
if reasoning != "" {
|
||||
if info.ClaudeConvertInfo.LastMessagesType != relaycommon.LastMessageTypeThinking {
|
||||
claudeResponses = append(claudeResponses, &dto.ClaudeResponse{
|
||||
Index: &info.ClaudeConvertInfo.Index,
|
||||
Type: "content_block_start",
|
||||
ContentBlock: &dto.ClaudeMediaMessage{
|
||||
Type: "thinking",
|
||||
Thinking: common.GetPointer[string](""),
|
||||
},
|
||||
})
|
||||
}
|
||||
info.ClaudeConvertInfo.LastMessagesType = relaycommon.LastMessageTypeThinking
|
||||
claudeResponse.Delta = &dto.ClaudeMediaMessage{
|
||||
Type: "thinking_delta",
|
||||
Thinking: &reasoning,
|
||||
}
|
||||
} else {
|
||||
if info.ClaudeConvertInfo.LastMessagesType != relaycommon.LastMessageTypeText {
|
||||
if info.ClaudeConvertInfo.LastMessagesType == relaycommon.LastMessageTypeThinking || info.ClaudeConvertInfo.LastMessagesType == relaycommon.LastMessageTypeTools {
|
||||
claudeResponses = append(claudeResponses, generateStopBlock(info.ClaudeConvertInfo.Index))
|
||||
info.ClaudeConvertInfo.Index++
|
||||
}
|
||||
claudeResponses = append(claudeResponses, &dto.ClaudeResponse{
|
||||
Index: &info.ClaudeConvertInfo.Index,
|
||||
Type: "content_block_start",
|
||||
ContentBlock: &dto.ClaudeMediaMessage{
|
||||
Type: "text",
|
||||
Text: common.GetPointer[string](""),
|
||||
},
|
||||
})
|
||||
}
|
||||
info.ClaudeConvertInfo.LastMessagesType = relaycommon.LastMessageTypeText
|
||||
claudeResponse.Delta = &dto.ClaudeMediaMessage{
|
||||
Type: "text_delta",
|
||||
Text: common.GetPointer[string](textContent),
|
||||
}
|
||||
}
|
||||
} else {
|
||||
isEmpty = true
|
||||
}
|
||||
}
|
||||
if info.Done {
|
||||
|
||||
claudeResponse.Index = &info.ClaudeConvertInfo.Index
|
||||
if !isEmpty && claudeResponse.Delta != nil {
|
||||
claudeResponses = append(claudeResponses, &claudeResponse)
|
||||
}
|
||||
|
||||
if doneChunk || info.ClaudeConvertInfo.Done {
|
||||
claudeResponses = append(claudeResponses, generateStopBlock(info.ClaudeConvertInfo.Index))
|
||||
oaiUsage := info.ClaudeConvertInfo.Usage
|
||||
oaiUsage := openAIResponse.Usage
|
||||
if oaiUsage == nil {
|
||||
oaiUsage = info.ClaudeConvertInfo.Usage
|
||||
}
|
||||
if oaiUsage != nil {
|
||||
claudeResponses = append(claudeResponses, &dto.ClaudeResponse{
|
||||
Type: "message_delta",
|
||||
@@ -315,83 +485,8 @@ func StreamResponseOpenAI2Claude(openAIResponse *dto.ChatCompletionsStreamRespon
|
||||
claudeResponses = append(claudeResponses, &dto.ClaudeResponse{
|
||||
Type: "message_stop",
|
||||
})
|
||||
} else {
|
||||
var claudeResponse dto.ClaudeResponse
|
||||
var isEmpty bool
|
||||
claudeResponse.Type = "content_block_delta"
|
||||
if len(chosenChoice.Delta.ToolCalls) > 0 {
|
||||
if info.ClaudeConvertInfo.LastMessagesType != relaycommon.LastMessageTypeTools {
|
||||
claudeResponses = append(claudeResponses, generateStopBlock(info.ClaudeConvertInfo.Index))
|
||||
info.ClaudeConvertInfo.Index++
|
||||
claudeResponses = append(claudeResponses, &dto.ClaudeResponse{
|
||||
Index: &info.ClaudeConvertInfo.Index,
|
||||
Type: "content_block_start",
|
||||
ContentBlock: &dto.ClaudeMediaMessage{
|
||||
Id: openAIResponse.GetFirstToolCall().ID,
|
||||
Type: "tool_use",
|
||||
Name: openAIResponse.GetFirstToolCall().Function.Name,
|
||||
Input: map[string]interface{}{},
|
||||
},
|
||||
})
|
||||
}
|
||||
info.ClaudeConvertInfo.LastMessagesType = relaycommon.LastMessageTypeTools
|
||||
// tools delta
|
||||
claudeResponse.Delta = &dto.ClaudeMediaMessage{
|
||||
Type: "input_json_delta",
|
||||
PartialJson: &chosenChoice.Delta.ToolCalls[0].Function.Arguments,
|
||||
}
|
||||
} else {
|
||||
reasoning := chosenChoice.Delta.GetReasoningContent()
|
||||
textContent := chosenChoice.Delta.GetContentString()
|
||||
if reasoning != "" || textContent != "" {
|
||||
if reasoning != "" {
|
||||
if info.ClaudeConvertInfo.LastMessagesType != relaycommon.LastMessageTypeThinking {
|
||||
//info.ClaudeConvertInfo.Index++
|
||||
claudeResponses = append(claudeResponses, &dto.ClaudeResponse{
|
||||
Index: &info.ClaudeConvertInfo.Index,
|
||||
Type: "content_block_start",
|
||||
ContentBlock: &dto.ClaudeMediaMessage{
|
||||
Type: "thinking",
|
||||
Thinking: common.GetPointer[string](""),
|
||||
},
|
||||
})
|
||||
}
|
||||
info.ClaudeConvertInfo.LastMessagesType = relaycommon.LastMessageTypeThinking
|
||||
// text delta
|
||||
claudeResponse.Delta = &dto.ClaudeMediaMessage{
|
||||
Type: "thinking_delta",
|
||||
Thinking: &reasoning,
|
||||
}
|
||||
} else {
|
||||
if info.ClaudeConvertInfo.LastMessagesType != relaycommon.LastMessageTypeText {
|
||||
if info.LastMessagesType == relaycommon.LastMessageTypeThinking || info.LastMessagesType == relaycommon.LastMessageTypeTools {
|
||||
claudeResponses = append(claudeResponses, generateStopBlock(info.ClaudeConvertInfo.Index))
|
||||
info.ClaudeConvertInfo.Index++
|
||||
}
|
||||
claudeResponses = append(claudeResponses, &dto.ClaudeResponse{
|
||||
Index: &info.ClaudeConvertInfo.Index,
|
||||
Type: "content_block_start",
|
||||
ContentBlock: &dto.ClaudeMediaMessage{
|
||||
Type: "text",
|
||||
Text: common.GetPointer[string](""),
|
||||
},
|
||||
})
|
||||
}
|
||||
info.ClaudeConvertInfo.LastMessagesType = relaycommon.LastMessageTypeText
|
||||
// text delta
|
||||
claudeResponse.Delta = &dto.ClaudeMediaMessage{
|
||||
Type: "text_delta",
|
||||
Text: common.GetPointer[string](textContent),
|
||||
}
|
||||
}
|
||||
} else {
|
||||
isEmpty = true
|
||||
}
|
||||
}
|
||||
claudeResponse.Index = &info.ClaudeConvertInfo.Index
|
||||
if !isEmpty {
|
||||
claudeResponses = append(claudeResponses, &claudeResponse)
|
||||
}
|
||||
info.ClaudeConvertInfo.Done = true
|
||||
return claudeResponses
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -35,9 +35,9 @@ func checkRedirect(req *http.Request, via []*http.Request) error {
|
||||
|
||||
func InitHttpClient() {
|
||||
transport := &http.Transport{
|
||||
MaxIdleConns: common.RelayMaxIdleConns,
|
||||
MaxIdleConnsPerHost: common.RelayMaxIdleConnsPerHost,
|
||||
ForceAttemptHTTP2: true,
|
||||
MaxIdleConns: common.RelayMaxIdleConns,
|
||||
MaxIdleConnsPerHost: common.RelayMaxIdleConnsPerHost,
|
||||
ForceAttemptHTTP2: true,
|
||||
}
|
||||
|
||||
if common.RelayTimeout == 0 {
|
||||
@@ -58,6 +58,14 @@ func GetHttpClient() *http.Client {
|
||||
return httpClient
|
||||
}
|
||||
|
||||
// GetHttpClientWithProxy returns the default client or a proxy-enabled one when proxyURL is provided.
|
||||
func GetHttpClientWithProxy(proxyURL string) (*http.Client, error) {
|
||||
if proxyURL == "" {
|
||||
return GetHttpClient(), nil
|
||||
}
|
||||
return NewProxyHttpClient(proxyURL)
|
||||
}
|
||||
|
||||
// ResetProxyClientCache 清空代理客户端缓存,确保下次使用时重新初始化
|
||||
func ResetProxyClientCache() {
|
||||
proxyClientLock.Lock()
|
||||
@@ -92,10 +100,10 @@ func NewProxyHttpClient(proxyURL string) (*http.Client, error) {
|
||||
case "http", "https":
|
||||
client := &http.Client{
|
||||
Transport: &http.Transport{
|
||||
MaxIdleConns: common.RelayMaxIdleConns,
|
||||
MaxIdleConnsPerHost: common.RelayMaxIdleConnsPerHost,
|
||||
ForceAttemptHTTP2: true,
|
||||
Proxy: http.ProxyURL(parsedURL),
|
||||
MaxIdleConns: common.RelayMaxIdleConns,
|
||||
MaxIdleConnsPerHost: common.RelayMaxIdleConnsPerHost,
|
||||
ForceAttemptHTTP2: true,
|
||||
Proxy: http.ProxyURL(parsedURL),
|
||||
},
|
||||
CheckRedirect: checkRedirect,
|
||||
}
|
||||
@@ -127,9 +135,9 @@ func NewProxyHttpClient(proxyURL string) (*http.Client, error) {
|
||||
|
||||
client := &http.Client{
|
||||
Transport: &http.Transport{
|
||||
MaxIdleConns: common.RelayMaxIdleConns,
|
||||
MaxIdleConnsPerHost: common.RelayMaxIdleConnsPerHost,
|
||||
ForceAttemptHTTP2: true,
|
||||
MaxIdleConns: common.RelayMaxIdleConns,
|
||||
MaxIdleConnsPerHost: common.RelayMaxIdleConnsPerHost,
|
||||
ForceAttemptHTTP2: true,
|
||||
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
return dialer.Dial(network, addr)
|
||||
},
|
||||
|
||||
@@ -317,7 +317,7 @@ func EstimateRequestToken(c *gin.Context, meta *types.TokenCountMeta, info *rela
|
||||
for i, file := range meta.Files {
|
||||
switch file.FileType {
|
||||
case types.FileTypeImage:
|
||||
if common.IsOpenAITextModel(info.UpstreamModelName) {
|
||||
if common.IsOpenAITextModel(model) {
|
||||
token, err := getImageToken(file, model, info.IsStream)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("error counting image token, media index[%d], original data[%s], err: %v", i, file.OriginData, err)
|
||||
|
||||
@@ -43,6 +43,7 @@ var defaultCacheRatio = map[string]float64{
|
||||
"claude-3-opus-20240229": 0.1,
|
||||
"claude-3-haiku-20240307": 0.1,
|
||||
"claude-3-5-haiku-20241022": 0.1,
|
||||
"claude-haiku-4-5-20251001": 0.1,
|
||||
"claude-3-5-sonnet-20240620": 0.1,
|
||||
"claude-3-5-sonnet-20241022": 0.1,
|
||||
"claude-3-7-sonnet-20250219": 0.1,
|
||||
@@ -64,6 +65,7 @@ var defaultCreateCacheRatio = map[string]float64{
|
||||
"claude-3-opus-20240229": 1.25,
|
||||
"claude-3-haiku-20240307": 1.25,
|
||||
"claude-3-5-haiku-20241022": 1.25,
|
||||
"claude-haiku-4-5-20251001": 1.25,
|
||||
"claude-3-5-sonnet-20240620": 1.25,
|
||||
"claude-3-5-sonnet-20241022": 1.25,
|
||||
"claude-3-7-sonnet-20250219": 1.25,
|
||||
|
||||
@@ -137,6 +137,7 @@ var defaultModelRatio = map[string]float64{
|
||||
"claude-2.1": 4, // $8 / 1M tokens
|
||||
"claude-3-haiku-20240307": 0.125, // $0.25 / 1M tokens
|
||||
"claude-3-5-haiku-20241022": 0.5, // $1 / 1M tokens
|
||||
"claude-haiku-4-5-20251001": 0.5, // $1 / 1M tokens
|
||||
"claude-3-sonnet-20240229": 1.5, // $3 / 1M tokens
|
||||
"claude-3-5-sonnet-20240620": 1.5,
|
||||
"claude-3-5-sonnet-20241022": 1.5,
|
||||
@@ -560,7 +561,7 @@ func getHardcodedCompletionModelRatio(name string) (float64, bool) {
|
||||
|
||||
if strings.Contains(name, "claude-3") {
|
||||
return 5, true
|
||||
} else if strings.Contains(name, "claude-sonnet-4") || strings.Contains(name, "claude-opus-4") {
|
||||
} else if strings.Contains(name, "claude-sonnet-4") || strings.Contains(name, "claude-opus-4") || strings.Contains(name, "claude-haiku-4") {
|
||||
return 5, true
|
||||
} else if strings.Contains(name, "claude-instant-1") || strings.Contains(name, "claude-2") {
|
||||
return 3, true
|
||||
|
||||
@@ -294,7 +294,7 @@ const LoginForm = () => {
|
||||
setGithubButtonDisabled(true);
|
||||
}, 20000);
|
||||
try {
|
||||
onGitHubOAuthClicked(status.github_client_id);
|
||||
onGitHubOAuthClicked(status.github_client_id, { shouldLogout: true });
|
||||
} finally {
|
||||
// 由于重定向,这里不会执行到,但为了完整性添加
|
||||
setTimeout(() => setGithubLoading(false), 3000);
|
||||
@@ -309,7 +309,7 @@ const LoginForm = () => {
|
||||
}
|
||||
setDiscordLoading(true);
|
||||
try {
|
||||
onDiscordOAuthClicked(status.discord_client_id);
|
||||
onDiscordOAuthClicked(status.discord_client_id, { shouldLogout: true });
|
||||
} finally {
|
||||
// 由于重定向,这里不会执行到,但为了完整性添加
|
||||
setTimeout(() => setDiscordLoading(false), 3000);
|
||||
@@ -324,7 +324,12 @@ const LoginForm = () => {
|
||||
}
|
||||
setOidcLoading(true);
|
||||
try {
|
||||
onOIDCClicked(status.oidc_authorization_endpoint, status.oidc_client_id);
|
||||
onOIDCClicked(
|
||||
status.oidc_authorization_endpoint,
|
||||
status.oidc_client_id,
|
||||
false,
|
||||
{ shouldLogout: true },
|
||||
);
|
||||
} finally {
|
||||
// 由于重定向,这里不会执行到,但为了完整性添加
|
||||
setTimeout(() => setOidcLoading(false), 3000);
|
||||
@@ -339,7 +344,7 @@ const LoginForm = () => {
|
||||
}
|
||||
setLinuxdoLoading(true);
|
||||
try {
|
||||
onLinuxDOOAuthClicked(status.linuxdo_client_id);
|
||||
onLinuxDOOAuthClicked(status.linuxdo_client_id, { shouldLogout: true });
|
||||
} finally {
|
||||
// 由于重定向,这里不会执行到,但为了完整性添加
|
||||
setTimeout(() => setLinuxdoLoading(false), 3000);
|
||||
|
||||
@@ -261,7 +261,7 @@ const RegisterForm = () => {
|
||||
setGithubButtonDisabled(true);
|
||||
}, 20000);
|
||||
try {
|
||||
onGitHubOAuthClicked(status.github_client_id);
|
||||
onGitHubOAuthClicked(status.github_client_id, { shouldLogout: true });
|
||||
} finally {
|
||||
setTimeout(() => setGithubLoading(false), 3000);
|
||||
}
|
||||
@@ -270,7 +270,7 @@ const RegisterForm = () => {
|
||||
const handleDiscordClick = () => {
|
||||
setDiscordLoading(true);
|
||||
try {
|
||||
onDiscordOAuthClicked(status.discord_client_id);
|
||||
onDiscordOAuthClicked(status.discord_client_id, { shouldLogout: true });
|
||||
} finally {
|
||||
setTimeout(() => setDiscordLoading(false), 3000);
|
||||
}
|
||||
@@ -279,7 +279,12 @@ const RegisterForm = () => {
|
||||
const handleOIDCClick = () => {
|
||||
setOidcLoading(true);
|
||||
try {
|
||||
onOIDCClicked(status.oidc_authorization_endpoint, status.oidc_client_id);
|
||||
onOIDCClicked(
|
||||
status.oidc_authorization_endpoint,
|
||||
status.oidc_client_id,
|
||||
false,
|
||||
{ shouldLogout: true },
|
||||
);
|
||||
} finally {
|
||||
setTimeout(() => setOidcLoading(false), 3000);
|
||||
}
|
||||
@@ -288,7 +293,7 @@ const RegisterForm = () => {
|
||||
const handleLinuxDOClick = () => {
|
||||
setLinuxdoLoading(true);
|
||||
try {
|
||||
onLinuxDOOAuthClicked(status.linuxdo_client_id);
|
||||
onLinuxDOOAuthClicked(status.linuxdo_client_id, { shouldLogout: true });
|
||||
} finally {
|
||||
setTimeout(() => setLinuxdoLoading(false), 3000);
|
||||
}
|
||||
|
||||
@@ -377,7 +377,6 @@ const SiderBar = ({ onNavigate = () => {} }) => {
|
||||
className='sidebar-container'
|
||||
style={{
|
||||
width: 'var(--sidebar-current-width)',
|
||||
background: 'var(--semi-color-bg-0)',
|
||||
}}
|
||||
>
|
||||
<SkeletonWrapper
|
||||
|
||||
@@ -39,6 +39,7 @@ import {
|
||||
TASK_ACTION_GENERATE,
|
||||
TASK_ACTION_REFERENCE_GENERATE,
|
||||
TASK_ACTION_TEXT_GENERATE,
|
||||
TASK_ACTION_REMIX_GENERATE,
|
||||
} from '../../../constants/common.constant';
|
||||
import { CHANNEL_OPTIONS } from '../../../constants/channel.constants';
|
||||
|
||||
@@ -125,6 +126,12 @@ const renderType = (type, t) => {
|
||||
{t('参照生视频')}
|
||||
</Tag>
|
||||
);
|
||||
case TASK_ACTION_REMIX_GENERATE:
|
||||
return (
|
||||
<Tag color='blue' shape='circle' prefixIcon={<Sparkles size={14} />}>
|
||||
{t('视频Remix')}
|
||||
</Tag>
|
||||
);
|
||||
default:
|
||||
return (
|
||||
<Tag color='white' shape='circle' prefixIcon={<HelpCircle size={14} />}>
|
||||
@@ -359,7 +366,8 @@ export const getTaskLogsColumns = ({
|
||||
record.action === TASK_ACTION_GENERATE ||
|
||||
record.action === TASK_ACTION_TEXT_GENERATE ||
|
||||
record.action === TASK_ACTION_FIRST_TAIL_GENERATE ||
|
||||
record.action === TASK_ACTION_REFERENCE_GENERATE;
|
||||
record.action === TASK_ACTION_REFERENCE_GENERATE ||
|
||||
record.action === TASK_ACTION_REMIX_GENERATE;
|
||||
const isSuccess = record.status === 'SUCCESS';
|
||||
const isUrl = typeof text === 'string' && /^https?:\/\//.test(text);
|
||||
if (isSuccess && isVideoTask && isUrl) {
|
||||
|
||||
@@ -88,7 +88,7 @@ const renderStatus = (text, record, t) => {
|
||||
};
|
||||
|
||||
// Render group column
|
||||
const renderGroupColumn = (text, t) => {
|
||||
const renderGroupColumn = (text, record, t) => {
|
||||
if (text === 'auto') {
|
||||
return (
|
||||
<Tooltip
|
||||
@@ -98,8 +98,8 @@ const renderGroupColumn = (text, t) => {
|
||||
position='top'
|
||||
>
|
||||
<Tag color='white' shape='circle'>
|
||||
{' '}
|
||||
{t('智能熔断')}{' '}
|
||||
{t('智能熔断')}
|
||||
{record && record.cross_group_retry ? `(${t('跨分组')})` : ''}
|
||||
</Tag>
|
||||
</Tooltip>
|
||||
);
|
||||
@@ -455,7 +455,7 @@ export const getTokensColumns = ({
|
||||
title: t('分组'),
|
||||
dataIndex: 'group',
|
||||
key: 'group',
|
||||
render: (text) => renderGroupColumn(text, t),
|
||||
render: (text, record) => renderGroupColumn(text, record, t),
|
||||
},
|
||||
{
|
||||
title: t('密钥'),
|
||||
|
||||
@@ -73,6 +73,7 @@ const EditTokenModal = (props) => {
|
||||
model_limits: [],
|
||||
allow_ips: '',
|
||||
group: '',
|
||||
cross_group_retry: false,
|
||||
tokenCount: 1,
|
||||
});
|
||||
|
||||
@@ -377,6 +378,16 @@ const EditTokenModal = (props) => {
|
||||
/>
|
||||
)}
|
||||
</Col>
|
||||
<Col span={24} style={{ display: values.group === 'auto' ? 'block' : 'none' }}>
|
||||
<Form.Switch
|
||||
field='cross_group_retry'
|
||||
label={t('跨分组重试')}
|
||||
size='default'
|
||||
extraText={t(
|
||||
'开启后,当前分组渠道失败时会按顺序尝试下一个分组的渠道',
|
||||
)}
|
||||
/>
|
||||
</Col>
|
||||
<Col xs={24} sm={24} md={24} lg={10} xl={10}>
|
||||
<Form.DatePicker
|
||||
field='expired_time'
|
||||
@@ -499,7 +510,7 @@ const EditTokenModal = (props) => {
|
||||
<Form.Switch
|
||||
field='unlimited_quota'
|
||||
label={t('无限额度')}
|
||||
size='large'
|
||||
size='default'
|
||||
extraText={t(
|
||||
'令牌的额度仅用于限制令牌本身的最大额度使用量,实际的使用受到账户的剩余额度限制',
|
||||
)}
|
||||
|
||||
@@ -42,3 +42,4 @@ export const TASK_ACTION_GENERATE = 'generate';
|
||||
export const TASK_ACTION_TEXT_GENERATE = 'textGenerate';
|
||||
export const TASK_ACTION_FIRST_TAIL_GENERATE = 'firstTailGenerate';
|
||||
export const TASK_ACTION_REFERENCE_GENERATE = 'referenceGenerate';
|
||||
export const TASK_ACTION_REMIX_GENERATE = 'remixGenerate';
|
||||
|
||||
@@ -231,8 +231,22 @@ export async function getOAuthState() {
|
||||
}
|
||||
}
|
||||
|
||||
export async function onDiscordOAuthClicked(client_id) {
|
||||
const state = await getOAuthState();
|
||||
async function prepareOAuthState(options = {}) {
|
||||
const { shouldLogout = false } = options;
|
||||
if (shouldLogout) {
|
||||
try {
|
||||
await API.get('/api/user/logout', { skipErrorHandler: true });
|
||||
} catch (err) {
|
||||
|
||||
}
|
||||
localStorage.removeItem('user');
|
||||
updateAPI();
|
||||
}
|
||||
return await getOAuthState();
|
||||
}
|
||||
|
||||
export async function onDiscordOAuthClicked(client_id, options = {}) {
|
||||
const state = await prepareOAuthState(options);
|
||||
if (!state) return;
|
||||
const redirect_uri = `${window.location.origin}/oauth/discord`;
|
||||
const response_type = 'code';
|
||||
@@ -242,8 +256,13 @@ export async function onDiscordOAuthClicked(client_id) {
|
||||
);
|
||||
}
|
||||
|
||||
export async function onOIDCClicked(auth_url, client_id, openInNewTab = false) {
|
||||
const state = await getOAuthState();
|
||||
export async function onOIDCClicked(
|
||||
auth_url,
|
||||
client_id,
|
||||
openInNewTab = false,
|
||||
options = {},
|
||||
) {
|
||||
const state = await prepareOAuthState(options);
|
||||
if (!state) return;
|
||||
const url = new URL(auth_url);
|
||||
url.searchParams.set('client_id', client_id);
|
||||
@@ -258,16 +277,19 @@ export async function onOIDCClicked(auth_url, client_id, openInNewTab = false) {
|
||||
}
|
||||
}
|
||||
|
||||
export async function onGitHubOAuthClicked(github_client_id) {
|
||||
const state = await getOAuthState();
|
||||
export async function onGitHubOAuthClicked(github_client_id, options = {}) {
|
||||
const state = await prepareOAuthState(options);
|
||||
if (!state) return;
|
||||
window.open(
|
||||
`https://github.com/login/oauth/authorize?client_id=${github_client_id}&state=${state}&scope=user:email`,
|
||||
);
|
||||
}
|
||||
|
||||
export async function onLinuxDOOAuthClicked(linuxdo_client_id) {
|
||||
const state = await getOAuthState();
|
||||
export async function onLinuxDOOAuthClicked(
|
||||
linuxdo_client_id,
|
||||
options = { shouldLogout: false },
|
||||
) {
|
||||
const state = await prepareOAuthState(options);
|
||||
if (!state) return;
|
||||
window.open(
|
||||
`https://connect.linux.do/oauth2/authorize?response_type=code&client_id=${linuxdo_client_id}&state=${state}`,
|
||||
|
||||
@@ -1086,9 +1086,12 @@ function renderPriceSimpleCore({
|
||||
);
|
||||
const finalGroupRatio = effectiveGroupRatio;
|
||||
|
||||
const { symbol, rate } = getCurrencyConfig();
|
||||
if (modelPrice !== -1) {
|
||||
return i18next.t('价格:${{price}} * {{ratioType}}:{{ratio}}', {
|
||||
price: modelPrice,
|
||||
const displayPrice = (modelPrice * rate).toFixed(6);
|
||||
return i18next.t('价格:{{symbol}}{{price}} * {{ratioType}}:{{ratio}}', {
|
||||
symbol: symbol,
|
||||
price: displayPrice,
|
||||
ratioType: ratioLabel,
|
||||
ratio: finalGroupRatio,
|
||||
});
|
||||
|
||||
@@ -548,6 +548,7 @@
|
||||
"参数值": "Parameter value",
|
||||
"参数覆盖": "Parameters override",
|
||||
"参照生视频": "Reference video generation",
|
||||
"视频Remix": "Video remix",
|
||||
"友情链接": "Friendly links",
|
||||
"发布日期": "Publish Date",
|
||||
"发布时间": "Publish Time",
|
||||
@@ -2176,6 +2177,9 @@
|
||||
"默认区域,如: us-central1": "Default region, e.g.: us-central1",
|
||||
"默认折叠侧边栏": "Default collapse sidebar",
|
||||
"默认测试模型": "Default Test Model",
|
||||
"默认补全倍率": "Default completion ratio"
|
||||
"默认补全倍率": "Default completion ratio",
|
||||
"跨分组重试": "Cross-group retry",
|
||||
"跨分组": "Cross-group",
|
||||
"开启后,当前分组渠道失败时会按顺序尝试下一个分组的渠道": "After enabling, when the current group channel fails, it will try the next group's channel in order"
|
||||
}
|
||||
}
|
||||
|
||||
@@ -551,6 +551,7 @@
|
||||
"参数值": "Valeur du paramètre",
|
||||
"参数覆盖": "Remplacement des paramètres",
|
||||
"参照生视频": "Générer une vidéo par référence",
|
||||
"视频Remix": "Remix vidéo",
|
||||
"友情链接": "Liens amicaux",
|
||||
"发布日期": "Date de publication",
|
||||
"发布时间": "Heure de publication",
|
||||
@@ -2225,6 +2226,9 @@
|
||||
"默认助手消息": "Bonjour ! Comment puis-je vous aider aujourd'hui ?",
|
||||
"可选,用于复现结果": "Optionnel, pour des résultats reproductibles",
|
||||
"随机种子 (留空为随机)": "Graine aléatoire (laisser vide pour aléatoire)",
|
||||
"默认补全倍率": "Taux de complétion par défaut"
|
||||
"默认补全倍率": "Taux de complétion par défaut",
|
||||
"跨分组重试": "Nouvelle tentative inter-groupes",
|
||||
"跨分组": "Inter-groupes",
|
||||
"开启后,当前分组渠道失败时会按顺序尝试下一个分组的渠道": "Après activation, lorsque le canal du groupe actuel échoue, il essaiera le canal du groupe suivant dans l'ordre"
|
||||
}
|
||||
}
|
||||
|
||||
@@ -510,6 +510,7 @@
|
||||
"参数值": "パラメータ値",
|
||||
"参数覆盖": "パラメータの上書き",
|
||||
"参照生视频": "参照動画生成",
|
||||
"视频Remix": "動画リミックス",
|
||||
"友情链接": "関連リンク",
|
||||
"发布日期": "公開日",
|
||||
"发布时间": "公開日時",
|
||||
@@ -2124,6 +2125,9 @@
|
||||
"默认用户消息": "こんにちは",
|
||||
"默认助手消息": "こんにちは!何かお手伝いできることはありますか?",
|
||||
"可选,用于复现结果": "オプション、結果の再現用",
|
||||
"随机种子 (留空为随机)": "ランダムシード(空欄でランダム)"
|
||||
"随机种子 (留空为随机)": "ランダムシード(空欄でランダム)",
|
||||
"跨分组重试": "グループ間リトライ",
|
||||
"跨分组": "グループ間",
|
||||
"开启后,当前分组渠道失败时会按顺序尝试下一个分组的渠道": "有効にすると、現在のグループチャネルが失敗した場合、次のグループのチャネルを順番に試行します"
|
||||
}
|
||||
}
|
||||
|
||||
@@ -555,6 +555,7 @@
|
||||
"参数值": "Значение параметра",
|
||||
"参数覆盖": "Переопределение параметров",
|
||||
"参照生视频": "Ссылка на генерацию видео",
|
||||
"视频Remix": "Видео ремикс",
|
||||
"友情链接": "Дружественные ссылки",
|
||||
"发布日期": "Дата публикации",
|
||||
"发布时间": "Время публикации",
|
||||
@@ -2235,6 +2236,9 @@
|
||||
"默认用户消息": "Здравствуйте",
|
||||
"默认助手消息": "Здравствуйте! Чем я могу вам помочь?",
|
||||
"可选,用于复现结果": "Необязательно, для воспроизводимых результатов",
|
||||
"随机种子 (留空为随机)": "Случайное зерно (оставьте пустым для случайного)"
|
||||
"随机种子 (留空为随机)": "Случайное зерно (оставьте пустым для случайного)",
|
||||
"跨分组重试": "Повторная попытка между группами",
|
||||
"跨分组": "Межгрупповой",
|
||||
"开启后,当前分组渠道失败时会按顺序尝试下一个分组的渠道": "После включения, когда канал текущей группы не работает, он будет пытаться использовать канал следующей группы по порядку"
|
||||
}
|
||||
}
|
||||
|
||||
@@ -510,6 +510,7 @@
|
||||
"参数值": "Giá trị tham số",
|
||||
"参数覆盖": "Ghi đè tham số",
|
||||
"参照生视频": "Tạo video tham chiếu",
|
||||
"视频Remix": "Remix video",
|
||||
"友情链接": "Liên kết thân thiện",
|
||||
"发布日期": "Ngày xuất bản",
|
||||
"发布时间": "Thời gian xuất bản",
|
||||
@@ -2735,6 +2736,9 @@
|
||||
"默认用户消息": "Xin chào",
|
||||
"默认助手消息": "Xin chào! Tôi có thể giúp gì cho bạn?",
|
||||
"可选,用于复现结果": "Tùy chọn, để tái tạo kết quả",
|
||||
"随机种子 (留空为随机)": "Hạt giống ngẫu nhiên (để trống cho ngẫu nhiên)"
|
||||
"随机种子 (留空为随机)": "Hạt giống ngẫu nhiên (để trống cho ngẫu nhiên)",
|
||||
"跨分组重试": "Thử lại giữa các nhóm",
|
||||
"跨分组": "Giữa các nhóm",
|
||||
"开启后,当前分组渠道失败时会按顺序尝试下一个分组的渠道": "Sau khi bật, khi kênh nhóm hiện tại thất bại, nó sẽ thử kênh của nhóm tiếp theo theo thứ tự"
|
||||
}
|
||||
}
|
||||
|
||||
@@ -543,6 +543,7 @@
|
||||
"参数值": "参数值",
|
||||
"参数覆盖": "参数覆盖",
|
||||
"参照生视频": "参照生视频",
|
||||
"视频Remix": "视频 Remix",
|
||||
"友情链接": "友情链接",
|
||||
"发布日期": "发布日期",
|
||||
"发布时间": "发布时间",
|
||||
@@ -2202,6 +2203,9 @@
|
||||
"默认用户消息": "你好",
|
||||
"默认助手消息": "你好!有什么我可以帮助你的吗?",
|
||||
"可选,用于复现结果": "可选,用于复现结果",
|
||||
"随机种子 (留空为随机)": "随机种子 (留空为随机)"
|
||||
"随机种子 (留空为随机)": "随机种子 (留空为随机)",
|
||||
"跨分组重试": "跨分组重试",
|
||||
"跨分组": "跨分组",
|
||||
"开启后,当前分组渠道失败时会按顺序尝试下一个分组的渠道": "开启后,当前分组渠道失败时会按顺序尝试下一个分组的渠道"
|
||||
}
|
||||
}
|
||||
|
||||
@@ -108,6 +108,7 @@ code {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
transition: width 0.3s ease;
|
||||
background: var(--semi-color-bg-0);
|
||||
}
|
||||
|
||||
.sidebar-nav {
|
||||
@@ -221,6 +222,22 @@ code {
|
||||
padding-top: 12px;
|
||||
}
|
||||
|
||||
@media (max-width: 767px) {
|
||||
.sidebar-container {
|
||||
background: var(--semi-color-bg-1);
|
||||
border-right: 1px solid var(--semi-color-border);
|
||||
}
|
||||
|
||||
.sidebar-nav {
|
||||
background: var(--semi-color-bg-1);
|
||||
}
|
||||
|
||||
.sidebar-collapse-button {
|
||||
background-color: var(--semi-color-bg-1);
|
||||
box-shadow: 0 -10px 10px -5px var(--semi-color-bg-1);
|
||||
}
|
||||
}
|
||||
|
||||
/* ==================== 聊天界面样式 ==================== */
|
||||
.semi-chat {
|
||||
padding-top: 0 !important;
|
||||
|
||||
Reference in New Issue
Block a user