mirror of
https://github.com/QuantumNous/new-api.git
synced 2026-04-17 23:47:26 +00:00
Compare commits
226 Commits
v0.8.8.0-a
...
v0.8.9.1
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
1cc81deb69 | ||
|
|
1d578b73ce | ||
|
|
ca1f3c6e4c | ||
|
|
f942361f7b | ||
|
|
02fd80b703 | ||
|
|
d6b03d4760 | ||
|
|
cb75e25a1a | ||
|
|
9572e16dcb | ||
|
|
459fce196f | ||
|
|
ada434fb20 | ||
|
|
71ba3fa310 | ||
|
|
0727353afa | ||
|
|
fd2ff2a973 | ||
|
|
50f9195f2d | ||
|
|
a47fc5a76b | ||
|
|
72ffe61ad1 | ||
|
|
ea8cac7c10 | ||
|
|
8639699d49 | ||
|
|
f242220132 | ||
|
|
55dbdba636 | ||
|
|
03b670971b | ||
|
|
24860fdc05 | ||
|
|
229dd3a123 | ||
|
|
919eacd907 | ||
|
|
4cec55c9a4 | ||
|
|
aa8ec92976 | ||
|
|
44da9c9a28 | ||
|
|
c776a1edff | ||
|
|
a5cbef1a61 | ||
|
|
ae22ba593a | ||
|
|
4ad4ad7088 | ||
|
|
8bccda5649 | ||
|
|
2a804b6c02 | ||
|
|
3b61617cb1 | ||
|
|
ec28671aed | ||
|
|
c7c7229b8b | ||
|
|
2efc133997 | ||
|
|
df72ac1215 | ||
|
|
2fc0d7b2a7 | ||
|
|
3a9e394814 | ||
|
|
3d9d3da1ae | ||
|
|
8abd764eca | ||
|
|
7a31e481a6 | ||
|
|
b70d2655ed | ||
|
|
15cb2f1a9e | ||
|
|
2471367c92 | ||
|
|
962c40c1a7 | ||
|
|
f6c7828160 | ||
|
|
8b57da9a2b | ||
|
|
daa7a13505 | ||
|
|
cda4790219 | ||
|
|
c6bb1dcc0e | ||
|
|
f8e1b084cd | ||
|
|
7d869c9af1 | ||
|
|
1690b05629 | ||
|
|
563825492e | ||
|
|
eee37017e1 | ||
|
|
29ec328f46 | ||
|
|
b843bb8286 | ||
|
|
77975529fe | ||
|
|
9de65184ab | ||
|
|
4912b1e632 | ||
|
|
cf91cf1b14 | ||
|
|
d0fb54fbfe | ||
|
|
346b869d60 | ||
|
|
ac158e227e | ||
|
|
d96f846648 | ||
|
|
473f3b6f3e | ||
|
|
7f1a471751 | ||
|
|
bbac342f3a | ||
|
|
4b3702987f | ||
|
|
6341847203 | ||
|
|
4e75a9b3b3 | ||
|
|
26f44b8d4b | ||
|
|
8fba0017c7 | ||
|
|
7f4056abc9 | ||
|
|
0257918571 | ||
|
|
1d4e746c4f | ||
|
|
677a02c632 | ||
|
|
177b891905 | ||
|
|
c4dcc6df9c | ||
|
|
7ddd314015 | ||
|
|
ba7325c884 | ||
|
|
3c4b1ef127 | ||
|
|
18c630e5e4 | ||
|
|
0ea0a432bf | ||
|
|
8a964efbed | ||
|
|
865bb7aad8 | ||
|
|
d9c1fb5244 | ||
|
|
71c39c9893 | ||
|
|
38067f1ddc | ||
|
|
7cfeb6e87c | ||
|
|
0a231a8acc | ||
|
|
1cea7a0314 | ||
|
|
ed95a9f2b2 | ||
|
|
76d71a032a | ||
|
|
38bff1a0e0 | ||
|
|
0c0caad827 | ||
|
|
4445e5891f | ||
|
|
f46cefbd39 | ||
|
|
feef022303 | ||
|
|
6a80c18189 | ||
|
|
6616bb4048 | ||
|
|
ac5f51c3d5 | ||
|
|
587888a688 | ||
|
|
7370b4fbcd | ||
|
|
94506bee99 | ||
|
|
7c814a5fd9 | ||
|
|
24aa29598a | ||
|
|
d61a862fa2 | ||
|
|
e29c6b44c7 | ||
|
|
327a0ca323 | ||
|
|
a746309a8e | ||
|
|
d247f90571 | ||
|
|
edbe18b157 | ||
|
|
d951485431 | ||
|
|
306a1a3f57 | ||
|
|
2431de78fa | ||
|
|
49abd6aaf3 | ||
|
|
f3a1f98add | ||
|
|
1ccc728e5d | ||
|
|
11ee80d377 | ||
|
|
512850e83d | ||
|
|
0e9c3cde7c | ||
|
|
43263a3bc8 | ||
|
|
8cce3cc84a | ||
|
|
faaa5a2949 | ||
|
|
c00f5a17c8 | ||
|
|
9c079d04a8 | ||
|
|
c9d4cdc57e | ||
|
|
12b4e80d4b | ||
|
|
6e2a04f374 | ||
|
|
3feeca627c | ||
|
|
8357b15fec | ||
|
|
ecdd9d1ccb | ||
|
|
fc69f4f757 | ||
|
|
5e70274003 | ||
|
|
57b194c63f | ||
|
|
10b04416c1 | ||
|
|
9f6027325c | ||
|
|
b64c8ea56b | ||
|
|
e74d3f4a8f | ||
|
|
8a2aebf845 | ||
|
|
984c8ee477 | ||
|
|
398ae7156b | ||
|
|
d85eeabf11 | ||
|
|
6a62654759 | ||
|
|
c056a7ad7c | ||
|
|
c784a70277 | ||
|
|
e6c87907d5 | ||
|
|
71e9290142 | ||
|
|
74ec34da67 | ||
|
|
7188749cb3 | ||
|
|
c28add55db | ||
|
|
78f34a8245 | ||
|
|
97d6f10f15 | ||
|
|
afefc4caca | ||
|
|
6abbd036f8 | ||
|
|
ef0db0f914 | ||
|
|
e01986fdd4 | ||
|
|
a0c6ebe2d8 | ||
|
|
d2183af23f | ||
|
|
953f1bdc3c | ||
|
|
e2429f20f8 | ||
|
|
f0945da4fb | ||
|
|
8df3de9ae5 | ||
|
|
277cc1cac8 | ||
|
|
07a92293e4 | ||
|
|
9730b9ba2d | ||
|
|
508799c452 | ||
|
|
5e81ef4a44 | ||
|
|
eb42eb6f27 | ||
|
|
232612898b | ||
|
|
6a37efb871 | ||
|
|
af59b61f8a | ||
|
|
f995e31d04 | ||
|
|
9758a9e60d | ||
|
|
6f56696af2 | ||
|
|
345fbdf3d2 | ||
|
|
ce031f7d15 | ||
|
|
bd6b811183 | ||
|
|
196bafff03 | ||
|
|
82bf149ade | ||
|
|
f20b558e22 | ||
|
|
54447bf227 | ||
|
|
fc09051d8b | ||
|
|
1f5ef24ecd | ||
|
|
b1faf42529 | ||
|
|
6a85206e32 | ||
|
|
e3d3e697d3 | ||
|
|
db9b333930 | ||
|
|
f7b284ad73 | ||
|
|
1c1e3386f8 | ||
|
|
a385c8a6f8 | ||
|
|
4cc76f2deb | ||
|
|
b41c24d653 | ||
|
|
fe9acb6c59 | ||
|
|
75548c449b | ||
|
|
bca78beb1b | ||
|
|
9110611489 | ||
|
|
0b1a1ca064 | ||
|
|
52a9cee0e1 | ||
|
|
34d45bb3b8 | ||
|
|
9b73696a98 | ||
|
|
aecdbfacf3 | ||
|
|
1c25e29999 | ||
|
|
5ceb898676 | ||
|
|
2fe3706ef0 | ||
|
|
1880164e29 | ||
|
|
e417c269eb | ||
|
|
59a76b3970 | ||
|
|
53be79a00e | ||
|
|
c4b69b341a | ||
|
|
a99dbc78c9 | ||
|
|
8a54512037 | ||
|
|
3f96bd9509 | ||
|
|
6d06cb8fb3 | ||
|
|
4247883173 | ||
|
|
bf491d6fe7 | ||
|
|
c15e753a0a | ||
|
|
902aee4e6b | ||
|
|
b964f755ec | ||
|
|
a044070e1d | ||
|
|
6103888610 | ||
|
|
7af3fb5ae4 | ||
|
|
3ac54b2178 |
@@ -65,6 +65,8 @@ func ChannelType2APIType(channelType int) (int, bool) {
|
||||
apiType = constant.APITypeCoze
|
||||
case constant.ChannelTypeJimeng:
|
||||
apiType = constant.APITypeJimeng
|
||||
case constant.ChannelTypeMoonshot:
|
||||
apiType = constant.APITypeMoonshot
|
||||
}
|
||||
if apiType == -1 {
|
||||
return constant.APITypeOpenAI, false
|
||||
|
||||
@@ -83,6 +83,7 @@ var GitHubClientId = ""
|
||||
var GitHubClientSecret = ""
|
||||
var LinuxDOClientId = ""
|
||||
var LinuxDOClientSecret = ""
|
||||
var LinuxDOMinimumTrustLevel = 0
|
||||
|
||||
var WeChatServerAddress = ""
|
||||
var WeChatServerToken = ""
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
)
|
||||
|
||||
type stringWriter interface {
|
||||
@@ -52,6 +53,8 @@ type CustomEvent struct {
|
||||
Id string
|
||||
Retry uint
|
||||
Data interface{}
|
||||
|
||||
Mutex sync.Mutex
|
||||
}
|
||||
|
||||
func encode(writer io.Writer, event CustomEvent) error {
|
||||
@@ -73,6 +76,8 @@ func (r CustomEvent) Render(w http.ResponseWriter) error {
|
||||
}
|
||||
|
||||
func (r CustomEvent) WriteContentType(w http.ResponseWriter) {
|
||||
r.Mutex.Lock()
|
||||
defer r.Mutex.Unlock()
|
||||
header := w.Header()
|
||||
header["Content-Type"] = contentType
|
||||
|
||||
|
||||
32
common/endpoint_defaults.go
Normal file
32
common/endpoint_defaults.go
Normal file
@@ -0,0 +1,32 @@
|
||||
package common
|
||||
|
||||
import "one-api/constant"
|
||||
|
||||
// EndpointInfo 描述单个端点的默认请求信息
|
||||
// path: 上游路径
|
||||
// method: HTTP 请求方式,例如 POST/GET
|
||||
// 目前均为 POST,后续可扩展
|
||||
//
|
||||
// json 标签用于直接序列化到 API 输出
|
||||
// 例如:{"path":"/v1/chat/completions","method":"POST"}
|
||||
|
||||
type EndpointInfo struct {
|
||||
Path string `json:"path"`
|
||||
Method string `json:"method"`
|
||||
}
|
||||
|
||||
// defaultEndpointInfoMap 保存内置端点的默认 Path 与 Method
|
||||
var defaultEndpointInfoMap = map[constant.EndpointType]EndpointInfo{
|
||||
constant.EndpointTypeOpenAI: {Path: "/v1/chat/completions", Method: "POST"},
|
||||
constant.EndpointTypeOpenAIResponse: {Path: "/v1/responses", Method: "POST"},
|
||||
constant.EndpointTypeAnthropic: {Path: "/v1/messages", Method: "POST"},
|
||||
constant.EndpointTypeGemini: {Path: "/v1beta/models/{model}:generateContent", Method: "POST"},
|
||||
constant.EndpointTypeJinaRerank: {Path: "/rerank", Method: "POST"},
|
||||
constant.EndpointTypeImageGeneration: {Path: "/v1/images/generations", Method: "POST"},
|
||||
}
|
||||
|
||||
// GetDefaultEndpointInfo 返回指定端点类型的默认信息以及是否存在
|
||||
func GetDefaultEndpointInfo(et constant.EndpointType) (EndpointInfo, bool) {
|
||||
info, ok := defaultEndpointInfoMap[et]
|
||||
return info, ok
|
||||
}
|
||||
@@ -31,6 +31,9 @@ func UnmarshalBodyReusable(c *gin.Context, v any) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
//if DebugEnabled {
|
||||
// println("UnmarshalBodyReusable request body:", string(requestBody))
|
||||
//}
|
||||
contentType := c.Request.Header.Get("Content-Type")
|
||||
if strings.HasPrefix(contentType, "application/json") {
|
||||
err = Unmarshal(requestBody, &v)
|
||||
|
||||
@@ -4,7 +4,10 @@ import (
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"math/rand"
|
||||
"net/url"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
@@ -95,3 +98,95 @@ func GetJsonString(data any) string {
|
||||
b, _ := json.Marshal(data)
|
||||
return string(b)
|
||||
}
|
||||
|
||||
// MaskSensitiveInfo masks sensitive information like URLs, IPs in a string
|
||||
// Example:
|
||||
// http://example.com -> http://***.com
|
||||
// https://api.test.org/v1/users/123?key=secret -> https://***.org/***/***/?key=***
|
||||
// https://sub.domain.co.uk/path/to/resource -> https://***.co.uk/***/***
|
||||
// 192.168.1.1 -> ***.***.***.***
|
||||
func MaskSensitiveInfo(str string) string {
|
||||
// Mask URLs
|
||||
urlPattern := regexp.MustCompile(`(http|https)://[^\s/$.?#].[^\s]*`)
|
||||
str = urlPattern.ReplaceAllStringFunc(str, func(urlStr string) string {
|
||||
u, err := url.Parse(urlStr)
|
||||
if err != nil {
|
||||
return urlStr
|
||||
}
|
||||
|
||||
host := u.Host
|
||||
if host == "" {
|
||||
return urlStr
|
||||
}
|
||||
|
||||
// Split host by dots
|
||||
parts := strings.Split(host, ".")
|
||||
if len(parts) < 2 {
|
||||
// If less than 2 parts, just mask the whole host
|
||||
return u.Scheme + "://***" + u.Path
|
||||
}
|
||||
|
||||
// Keep the TLD (Top Level Domain) and mask the rest
|
||||
var maskedHost string
|
||||
if len(parts) == 2 {
|
||||
// example.com -> ***.com
|
||||
maskedHost = "***." + parts[len(parts)-1]
|
||||
} else {
|
||||
// Handle cases like sub.domain.co.uk or api.example.com
|
||||
// Keep last 2 parts if they look like country code TLD (co.uk, com.cn, etc.)
|
||||
lastPart := parts[len(parts)-1]
|
||||
secondLastPart := parts[len(parts)-2]
|
||||
|
||||
if len(lastPart) == 2 && len(secondLastPart) <= 3 {
|
||||
// Likely country code TLD like co.uk, com.cn
|
||||
maskedHost = "***." + secondLastPart + "." + lastPart
|
||||
} else {
|
||||
// Regular TLD like .com, .org
|
||||
maskedHost = "***." + lastPart
|
||||
}
|
||||
}
|
||||
|
||||
result := u.Scheme + "://" + maskedHost
|
||||
|
||||
// Mask path
|
||||
if u.Path != "" && u.Path != "/" {
|
||||
pathParts := strings.Split(strings.Trim(u.Path, "/"), "/")
|
||||
maskedPathParts := make([]string, len(pathParts))
|
||||
for i := range pathParts {
|
||||
if pathParts[i] != "" {
|
||||
maskedPathParts[i] = "***"
|
||||
}
|
||||
}
|
||||
if len(maskedPathParts) > 0 {
|
||||
result += "/" + strings.Join(maskedPathParts, "/")
|
||||
}
|
||||
} else if u.Path == "/" {
|
||||
result += "/"
|
||||
}
|
||||
|
||||
// Mask query parameters
|
||||
if u.RawQuery != "" {
|
||||
values, err := url.ParseQuery(u.RawQuery)
|
||||
if err != nil {
|
||||
// If can't parse query, just mask the whole query string
|
||||
result += "?***"
|
||||
} else {
|
||||
maskedParams := make([]string, 0, len(values))
|
||||
for key := range values {
|
||||
maskedParams = append(maskedParams, key+"=***")
|
||||
}
|
||||
if len(maskedParams) > 0 {
|
||||
result += "?" + strings.Join(maskedParams, "&")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
})
|
||||
|
||||
// Mask IP addresses
|
||||
ipPattern := regexp.MustCompile(`\b(?:\d{1,3}\.){3}\d{1,3}\b`)
|
||||
str = ipPattern.ReplaceAllString(str, "***.***.***.***")
|
||||
|
||||
return str
|
||||
}
|
||||
|
||||
150
common/totp.go
Normal file
150
common/totp.go
Normal file
@@ -0,0 +1,150 @@
|
||||
package common
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"fmt"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/pquerna/otp"
|
||||
"github.com/pquerna/otp/totp"
|
||||
)
|
||||
|
||||
const (
|
||||
// 备用码配置
|
||||
BackupCodeLength = 8 // 备用码长度
|
||||
BackupCodeCount = 4 // 生成备用码数量
|
||||
|
||||
// 限制配置
|
||||
MaxFailAttempts = 5 // 最大失败尝试次数
|
||||
LockoutDuration = 300 // 锁定时间(秒)
|
||||
)
|
||||
|
||||
// GenerateTOTPSecret 生成TOTP密钥和配置
|
||||
func GenerateTOTPSecret(accountName string) (*otp.Key, error) {
|
||||
issuer := Get2FAIssuer()
|
||||
return totp.Generate(totp.GenerateOpts{
|
||||
Issuer: issuer,
|
||||
AccountName: accountName,
|
||||
Period: 30,
|
||||
Digits: otp.DigitsSix,
|
||||
Algorithm: otp.AlgorithmSHA1,
|
||||
})
|
||||
}
|
||||
|
||||
// ValidateTOTPCode 验证TOTP验证码
|
||||
func ValidateTOTPCode(secret, code string) bool {
|
||||
// 清理验证码格式
|
||||
cleanCode := strings.ReplaceAll(code, " ", "")
|
||||
if len(cleanCode) != 6 {
|
||||
return false
|
||||
}
|
||||
|
||||
// 验证验证码
|
||||
return totp.Validate(cleanCode, secret)
|
||||
}
|
||||
|
||||
// GenerateBackupCodes 生成备用恢复码
|
||||
func GenerateBackupCodes() ([]string, error) {
|
||||
codes := make([]string, BackupCodeCount)
|
||||
|
||||
for i := 0; i < BackupCodeCount; i++ {
|
||||
code, err := generateRandomBackupCode()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
codes[i] = code
|
||||
}
|
||||
|
||||
return codes, nil
|
||||
}
|
||||
|
||||
// generateRandomBackupCode 生成单个备用码
|
||||
func generateRandomBackupCode() (string, error) {
|
||||
const charset = "ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
|
||||
code := make([]byte, BackupCodeLength)
|
||||
|
||||
for i := range code {
|
||||
randomBytes := make([]byte, 1)
|
||||
_, err := rand.Read(randomBytes)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
code[i] = charset[int(randomBytes[0])%len(charset)]
|
||||
}
|
||||
|
||||
// 格式化为 XXXX-XXXX 格式
|
||||
return fmt.Sprintf("%s-%s", string(code[:4]), string(code[4:])), nil
|
||||
}
|
||||
|
||||
// ValidateBackupCode 验证备用码格式
|
||||
func ValidateBackupCode(code string) bool {
|
||||
// 移除所有分隔符并转为大写
|
||||
cleanCode := strings.ToUpper(strings.ReplaceAll(code, "-", ""))
|
||||
if len(cleanCode) != BackupCodeLength {
|
||||
return false
|
||||
}
|
||||
|
||||
// 检查字符是否合法
|
||||
for _, char := range cleanCode {
|
||||
if !((char >= 'A' && char <= 'Z') || (char >= '0' && char <= '9')) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// NormalizeBackupCode 标准化备用码格式
|
||||
func NormalizeBackupCode(code string) string {
|
||||
cleanCode := strings.ToUpper(strings.ReplaceAll(code, "-", ""))
|
||||
if len(cleanCode) == BackupCodeLength {
|
||||
return fmt.Sprintf("%s-%s", cleanCode[:4], cleanCode[4:])
|
||||
}
|
||||
return code
|
||||
}
|
||||
|
||||
// HashBackupCode 对备用码进行哈希
|
||||
func HashBackupCode(code string) (string, error) {
|
||||
normalizedCode := NormalizeBackupCode(code)
|
||||
return Password2Hash(normalizedCode)
|
||||
}
|
||||
|
||||
// Get2FAIssuer 获取2FA发行者名称
|
||||
func Get2FAIssuer() string {
|
||||
return SystemName
|
||||
}
|
||||
|
||||
// getEnvOrDefault 获取环境变量或默认值
|
||||
func getEnvOrDefault(key, defaultValue string) string {
|
||||
if value, exists := os.LookupEnv(key); exists {
|
||||
return value
|
||||
}
|
||||
return defaultValue
|
||||
}
|
||||
|
||||
// ValidateNumericCode 验证数字验证码格式
|
||||
func ValidateNumericCode(code string) (string, error) {
|
||||
// 移除空格
|
||||
code = strings.ReplaceAll(code, " ", "")
|
||||
|
||||
if len(code) != 6 {
|
||||
return "", fmt.Errorf("验证码必须是6位数字")
|
||||
}
|
||||
|
||||
// 检查是否为纯数字
|
||||
if _, err := strconv.Atoi(code); err != nil {
|
||||
return "", fmt.Errorf("验证码只能包含数字")
|
||||
}
|
||||
|
||||
return code, nil
|
||||
}
|
||||
|
||||
// GenerateQRCodeData 生成二维码数据
|
||||
func GenerateQRCodeData(secret, username string) string {
|
||||
issuer := Get2FAIssuer()
|
||||
accountName := fmt.Sprintf("%s (%s)", username, issuer)
|
||||
return fmt.Sprintf("otpauth://totp/%s:%s?secret=%s&issuer=%s&digits=6&period=30",
|
||||
issuer, accountName, secret, issuer)
|
||||
}
|
||||
@@ -31,5 +31,6 @@ const (
|
||||
APITypeXai
|
||||
APITypeCoze
|
||||
APITypeJimeng
|
||||
APITypeDummy // this one is only for count, do not add any channel after this
|
||||
APITypeMoonshot // this one is only for count, do not add any channel after this
|
||||
APITypeDummy // this one is only for count, do not add any channel after this
|
||||
)
|
||||
|
||||
@@ -11,7 +11,6 @@ const (
|
||||
ContextKeyTokenKey ContextKey = "token_key"
|
||||
ContextKeyTokenId ContextKey = "token_id"
|
||||
ContextKeyTokenGroup ContextKey = "token_group"
|
||||
ContextKeyTokenAllowIps ContextKey = "allow_ips"
|
||||
ContextKeyTokenSpecificChannelId ContextKey = "specific_channel_id"
|
||||
ContextKeyTokenModelLimitEnabled ContextKey = "token_model_limit_enabled"
|
||||
ContextKeyTokenModelLimit ContextKey = "token_model_limit"
|
||||
@@ -41,4 +40,6 @@ const (
|
||||
ContextKeyUserGroup ContextKey = "user_group"
|
||||
ContextKeyUsingGroup ContextKey = "group"
|
||||
ContextKeyUserName ContextKey = "username"
|
||||
|
||||
ContextKeySystemPromptOverride ContextKey = "system_prompt_override"
|
||||
)
|
||||
|
||||
@@ -161,7 +161,7 @@ func testChannel(channel *model.Channel, testModel string) testResult {
|
||||
logInfo.ApiKey = ""
|
||||
common.SysLog(fmt.Sprintf("testing channel %d with model %s , info %+v ", channel.Id, testModel, logInfo))
|
||||
|
||||
priceData, err := helper.ModelPriceHelper(c, info, 0, int(request.MaxTokens))
|
||||
priceData, err := helper.ModelPriceHelper(c, info, 0, int(request.GetMaxTokens()))
|
||||
if err != nil {
|
||||
return testResult{
|
||||
context: c,
|
||||
@@ -275,7 +275,7 @@ func testChannel(channel *model.Channel, testModel string) testResult {
|
||||
Quota: quota,
|
||||
Content: "模型测试",
|
||||
UseTimeSeconds: int(consumedTime),
|
||||
IsStream: false,
|
||||
IsStream: info.IsStream,
|
||||
Group: info.UsingGroup,
|
||||
Other: other,
|
||||
})
|
||||
@@ -332,8 +332,11 @@ func TestChannel(c *gin.Context) {
|
||||
}
|
||||
channel, err := model.CacheGetChannel(channelId)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
channel, err = model.GetChannelById(channelId, true)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
}
|
||||
//defer func() {
|
||||
// if channel.ChannelInfo.IsMultiKey {
|
||||
|
||||
@@ -52,6 +52,13 @@ func parseStatusFilter(statusParam string) int {
|
||||
}
|
||||
}
|
||||
|
||||
func clearChannelInfo(channel *model.Channel) {
|
||||
if channel.ChannelInfo.IsMultiKey {
|
||||
channel.ChannelInfo.MultiKeyDisabledReason = nil
|
||||
channel.ChannelInfo.MultiKeyDisabledTime = nil
|
||||
}
|
||||
}
|
||||
|
||||
func GetAllChannels(c *gin.Context) {
|
||||
pageInfo := common.GetPageQuery(c)
|
||||
channelData := make([]*model.Channel, 0)
|
||||
@@ -126,6 +133,10 @@ func GetAllChannels(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
for _, datum := range channelData {
|
||||
clearChannelInfo(datum)
|
||||
}
|
||||
|
||||
countQuery := model.DB.Model(&model.Channel{})
|
||||
if statusFilter == common.ChannelStatusEnabled {
|
||||
countQuery = countQuery.Where("status = ?", common.ChannelStatusEnabled)
|
||||
@@ -168,14 +179,26 @@ func FetchUpstreamModels(c *gin.Context) {
|
||||
if channel.GetBaseURL() != "" {
|
||||
baseURL = channel.GetBaseURL()
|
||||
}
|
||||
url := fmt.Sprintf("%s/v1/models", baseURL)
|
||||
|
||||
var url string
|
||||
switch channel.Type {
|
||||
case constant.ChannelTypeGemini:
|
||||
url = fmt.Sprintf("%s/v1beta/openai/models", baseURL)
|
||||
// curl https://example.com/v1beta/models?key=$GEMINI_API_KEY
|
||||
url = fmt.Sprintf("%s/v1beta/openai/models", baseURL) // Remove key in url since we need to use AuthHeader
|
||||
case constant.ChannelTypeAli:
|
||||
url = fmt.Sprintf("%s/compatible-mode/v1/models", baseURL)
|
||||
default:
|
||||
url = fmt.Sprintf("%s/v1/models", baseURL)
|
||||
}
|
||||
|
||||
// 获取响应体 - 根据渠道类型决定是否添加 AuthHeader
|
||||
var body []byte
|
||||
key := strings.Split(channel.Key, "\n")[0]
|
||||
if channel.Type == constant.ChannelTypeGemini {
|
||||
body, err = GetResponseBody("GET", url, channel, GetAuthHeader(key)) // Use AuthHeader since Gemini now forces it
|
||||
} else {
|
||||
body, err = GetResponseBody("GET", url, channel, GetAuthHeader(key))
|
||||
}
|
||||
body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key))
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
@@ -319,6 +342,10 @@ func SearchChannels(c *gin.Context) {
|
||||
|
||||
pagedData := channelData[startIdx:endIdx]
|
||||
|
||||
for _, datum := range pagedData {
|
||||
clearChannelInfo(datum)
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
@@ -342,6 +369,9 @@ func GetChannel(c *gin.Context) {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
if channel != nil {
|
||||
clearChannelInfo(channel)
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
@@ -669,6 +699,7 @@ func DeleteChannelBatch(c *gin.Context) {
|
||||
type PatchChannel struct {
|
||||
model.Channel
|
||||
MultiKeyMode *string `json:"multi_key_mode"`
|
||||
KeyMode *string `json:"key_mode"` // 多key模式下密钥覆盖或者追加
|
||||
}
|
||||
|
||||
func UpdateChannel(c *gin.Context) {
|
||||
@@ -688,7 +719,7 @@ func UpdateChannel(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
// Preserve existing ChannelInfo to ensure multi-key channels keep correct state even if the client does not send ChannelInfo in the request.
|
||||
originChannel, err := model.GetChannelById(channel.Id, false)
|
||||
originChannel, err := model.GetChannelById(channel.Id, true)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
@@ -704,6 +735,69 @@ func UpdateChannel(c *gin.Context) {
|
||||
if channel.MultiKeyMode != nil && *channel.MultiKeyMode != "" {
|
||||
channel.ChannelInfo.MultiKeyMode = constant.MultiKeyMode(*channel.MultiKeyMode)
|
||||
}
|
||||
|
||||
// 处理多key模式下的密钥追加/覆盖逻辑
|
||||
if channel.KeyMode != nil && channel.ChannelInfo.IsMultiKey {
|
||||
switch *channel.KeyMode {
|
||||
case "append":
|
||||
// 追加模式:将新密钥添加到现有密钥列表
|
||||
if originChannel.Key != "" {
|
||||
var newKeys []string
|
||||
var existingKeys []string
|
||||
|
||||
// 解析现有密钥
|
||||
if strings.HasPrefix(strings.TrimSpace(originChannel.Key), "[") {
|
||||
// JSON数组格式
|
||||
var arr []json.RawMessage
|
||||
if err := json.Unmarshal([]byte(strings.TrimSpace(originChannel.Key)), &arr); err == nil {
|
||||
existingKeys = make([]string, len(arr))
|
||||
for i, v := range arr {
|
||||
existingKeys[i] = string(v)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// 换行分隔格式
|
||||
existingKeys = strings.Split(strings.Trim(originChannel.Key, "\n"), "\n")
|
||||
}
|
||||
|
||||
// 处理 Vertex AI 的特殊情况
|
||||
if channel.Type == constant.ChannelTypeVertexAi {
|
||||
// 尝试解析新密钥为JSON数组
|
||||
if strings.HasPrefix(strings.TrimSpace(channel.Key), "[") {
|
||||
array, err := getVertexArrayKeys(channel.Key)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "追加密钥解析失败: " + err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
newKeys = array
|
||||
} else {
|
||||
// 单个JSON密钥
|
||||
newKeys = []string{channel.Key}
|
||||
}
|
||||
// 合并密钥
|
||||
allKeys := append(existingKeys, newKeys...)
|
||||
channel.Key = strings.Join(allKeys, "\n")
|
||||
} else {
|
||||
// 普通渠道的处理
|
||||
inputKeys := strings.Split(channel.Key, "\n")
|
||||
for _, key := range inputKeys {
|
||||
key = strings.TrimSpace(key)
|
||||
if key != "" {
|
||||
newKeys = append(newKeys, key)
|
||||
}
|
||||
}
|
||||
// 合并密钥
|
||||
allKeys := append(existingKeys, newKeys...)
|
||||
channel.Key = strings.Join(allKeys, "\n")
|
||||
}
|
||||
}
|
||||
case "replace":
|
||||
// 覆盖模式:直接使用新密钥(默认行为,不需要特殊处理)
|
||||
}
|
||||
}
|
||||
err = channel.Update()
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
@@ -711,6 +805,7 @@ func UpdateChannel(c *gin.Context) {
|
||||
}
|
||||
model.InitChannelCache()
|
||||
channel.Key = ""
|
||||
clearChannelInfo(&channel.Channel)
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
@@ -914,3 +1009,413 @@ func CopyChannel(c *gin.Context) {
|
||||
// success
|
||||
c.JSON(http.StatusOK, gin.H{"success": true, "message": "", "data": gin.H{"id": clone.Id}})
|
||||
}
|
||||
|
||||
// MultiKeyManageRequest represents the request for multi-key management operations
|
||||
type MultiKeyManageRequest struct {
|
||||
ChannelId int `json:"channel_id"`
|
||||
Action string `json:"action"` // "disable_key", "enable_key", "delete_disabled_keys", "get_key_status"
|
||||
KeyIndex *int `json:"key_index,omitempty"` // for disable_key and enable_key actions
|
||||
Page int `json:"page,omitempty"` // for get_key_status pagination
|
||||
PageSize int `json:"page_size,omitempty"` // for get_key_status pagination
|
||||
Status *int `json:"status,omitempty"` // for get_key_status filtering: 1=enabled, 2=manual_disabled, 3=auto_disabled, nil=all
|
||||
}
|
||||
|
||||
// MultiKeyStatusResponse represents the response for key status query
|
||||
type MultiKeyStatusResponse struct {
|
||||
Keys []KeyStatus `json:"keys"`
|
||||
Total int `json:"total"`
|
||||
Page int `json:"page"`
|
||||
PageSize int `json:"page_size"`
|
||||
TotalPages int `json:"total_pages"`
|
||||
// Statistics
|
||||
EnabledCount int `json:"enabled_count"`
|
||||
ManualDisabledCount int `json:"manual_disabled_count"`
|
||||
AutoDisabledCount int `json:"auto_disabled_count"`
|
||||
}
|
||||
|
||||
type KeyStatus struct {
|
||||
Index int `json:"index"`
|
||||
Status int `json:"status"` // 1: enabled, 2: disabled
|
||||
DisabledTime int64 `json:"disabled_time,omitempty"`
|
||||
Reason string `json:"reason,omitempty"`
|
||||
KeyPreview string `json:"key_preview"` // first 10 chars of key for identification
|
||||
}
|
||||
|
||||
// ManageMultiKeys handles multi-key management operations
|
||||
func ManageMultiKeys(c *gin.Context) {
|
||||
request := MultiKeyManageRequest{}
|
||||
err := c.ShouldBindJSON(&request)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
channel, err := model.GetChannelById(request.ChannelId, true)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "渠道不存在",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
if !channel.ChannelInfo.IsMultiKey {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "该渠道不是多密钥模式",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
lock := model.GetChannelPollingLock(channel.Id)
|
||||
lock.Lock()
|
||||
defer lock.Unlock()
|
||||
|
||||
switch request.Action {
|
||||
case "get_key_status":
|
||||
keys := channel.GetKeys()
|
||||
|
||||
// Default pagination parameters
|
||||
page := request.Page
|
||||
pageSize := request.PageSize
|
||||
if page <= 0 {
|
||||
page = 1
|
||||
}
|
||||
if pageSize <= 0 {
|
||||
pageSize = 50 // Default page size
|
||||
}
|
||||
|
||||
// Statistics for all keys (unchanged by filtering)
|
||||
var enabledCount, manualDisabledCount, autoDisabledCount int
|
||||
|
||||
// Build all key status data first
|
||||
var allKeyStatusList []KeyStatus
|
||||
for i, key := range keys {
|
||||
status := 1 // default enabled
|
||||
var disabledTime int64
|
||||
var reason string
|
||||
|
||||
if channel.ChannelInfo.MultiKeyStatusList != nil {
|
||||
if s, exists := channel.ChannelInfo.MultiKeyStatusList[i]; exists {
|
||||
status = s
|
||||
}
|
||||
}
|
||||
|
||||
// Count for statistics (all keys)
|
||||
switch status {
|
||||
case 1:
|
||||
enabledCount++
|
||||
case 2:
|
||||
manualDisabledCount++
|
||||
case 3:
|
||||
autoDisabledCount++
|
||||
}
|
||||
|
||||
if status != 1 {
|
||||
if channel.ChannelInfo.MultiKeyDisabledTime != nil {
|
||||
disabledTime = channel.ChannelInfo.MultiKeyDisabledTime[i]
|
||||
}
|
||||
if channel.ChannelInfo.MultiKeyDisabledReason != nil {
|
||||
reason = channel.ChannelInfo.MultiKeyDisabledReason[i]
|
||||
}
|
||||
}
|
||||
|
||||
// Create key preview (first 10 chars)
|
||||
keyPreview := key
|
||||
if len(key) > 10 {
|
||||
keyPreview = key[:10] + "..."
|
||||
}
|
||||
|
||||
allKeyStatusList = append(allKeyStatusList, KeyStatus{
|
||||
Index: i,
|
||||
Status: status,
|
||||
DisabledTime: disabledTime,
|
||||
Reason: reason,
|
||||
KeyPreview: keyPreview,
|
||||
})
|
||||
}
|
||||
|
||||
// Apply status filter if specified
|
||||
var filteredKeyStatusList []KeyStatus
|
||||
if request.Status != nil {
|
||||
for _, keyStatus := range allKeyStatusList {
|
||||
if keyStatus.Status == *request.Status {
|
||||
filteredKeyStatusList = append(filteredKeyStatusList, keyStatus)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
filteredKeyStatusList = allKeyStatusList
|
||||
}
|
||||
|
||||
// Calculate pagination based on filtered results
|
||||
filteredTotal := len(filteredKeyStatusList)
|
||||
totalPages := (filteredTotal + pageSize - 1) / pageSize
|
||||
if totalPages == 0 {
|
||||
totalPages = 1
|
||||
}
|
||||
if page > totalPages {
|
||||
page = totalPages
|
||||
}
|
||||
|
||||
// Calculate range for current page
|
||||
start := (page - 1) * pageSize
|
||||
end := start + pageSize
|
||||
if end > filteredTotal {
|
||||
end = filteredTotal
|
||||
}
|
||||
|
||||
// Get the page data
|
||||
var pageKeyStatusList []KeyStatus
|
||||
if start < filteredTotal {
|
||||
pageKeyStatusList = filteredKeyStatusList[start:end]
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": MultiKeyStatusResponse{
|
||||
Keys: pageKeyStatusList,
|
||||
Total: filteredTotal, // Total of filtered results
|
||||
Page: page,
|
||||
PageSize: pageSize,
|
||||
TotalPages: totalPages,
|
||||
EnabledCount: enabledCount, // Overall statistics
|
||||
ManualDisabledCount: manualDisabledCount, // Overall statistics
|
||||
AutoDisabledCount: autoDisabledCount, // Overall statistics
|
||||
},
|
||||
})
|
||||
return
|
||||
|
||||
case "disable_key":
|
||||
if request.KeyIndex == nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "未指定要禁用的密钥索引",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
keyIndex := *request.KeyIndex
|
||||
if keyIndex < 0 || keyIndex >= channel.ChannelInfo.MultiKeySize {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "密钥索引超出范围",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
if channel.ChannelInfo.MultiKeyStatusList == nil {
|
||||
channel.ChannelInfo.MultiKeyStatusList = make(map[int]int)
|
||||
}
|
||||
if channel.ChannelInfo.MultiKeyDisabledTime == nil {
|
||||
channel.ChannelInfo.MultiKeyDisabledTime = make(map[int]int64)
|
||||
}
|
||||
if channel.ChannelInfo.MultiKeyDisabledReason == nil {
|
||||
channel.ChannelInfo.MultiKeyDisabledReason = make(map[int]string)
|
||||
}
|
||||
|
||||
channel.ChannelInfo.MultiKeyStatusList[keyIndex] = 2 // disabled
|
||||
|
||||
err = channel.Update()
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
model.InitChannelCache()
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "密钥已禁用",
|
||||
})
|
||||
return
|
||||
|
||||
case "enable_key":
|
||||
if request.KeyIndex == nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "未指定要启用的密钥索引",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
keyIndex := *request.KeyIndex
|
||||
if keyIndex < 0 || keyIndex >= channel.ChannelInfo.MultiKeySize {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "密钥索引超出范围",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 从状态列表中删除该密钥的记录,使其回到默认启用状态
|
||||
if channel.ChannelInfo.MultiKeyStatusList != nil {
|
||||
delete(channel.ChannelInfo.MultiKeyStatusList, keyIndex)
|
||||
}
|
||||
if channel.ChannelInfo.MultiKeyDisabledTime != nil {
|
||||
delete(channel.ChannelInfo.MultiKeyDisabledTime, keyIndex)
|
||||
}
|
||||
if channel.ChannelInfo.MultiKeyDisabledReason != nil {
|
||||
delete(channel.ChannelInfo.MultiKeyDisabledReason, keyIndex)
|
||||
}
|
||||
|
||||
err = channel.Update()
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
model.InitChannelCache()
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "密钥已启用",
|
||||
})
|
||||
return
|
||||
|
||||
case "enable_all_keys":
|
||||
// 清空所有禁用状态,使所有密钥回到默认启用状态
|
||||
var enabledCount int
|
||||
if channel.ChannelInfo.MultiKeyStatusList != nil {
|
||||
enabledCount = len(channel.ChannelInfo.MultiKeyStatusList)
|
||||
}
|
||||
|
||||
channel.ChannelInfo.MultiKeyStatusList = make(map[int]int)
|
||||
channel.ChannelInfo.MultiKeyDisabledTime = make(map[int]int64)
|
||||
channel.ChannelInfo.MultiKeyDisabledReason = make(map[int]string)
|
||||
|
||||
err = channel.Update()
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
model.InitChannelCache()
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": fmt.Sprintf("已启用 %d 个密钥", enabledCount),
|
||||
})
|
||||
return
|
||||
|
||||
case "disable_all_keys":
|
||||
// 禁用所有启用的密钥
|
||||
if channel.ChannelInfo.MultiKeyStatusList == nil {
|
||||
channel.ChannelInfo.MultiKeyStatusList = make(map[int]int)
|
||||
}
|
||||
if channel.ChannelInfo.MultiKeyDisabledTime == nil {
|
||||
channel.ChannelInfo.MultiKeyDisabledTime = make(map[int]int64)
|
||||
}
|
||||
if channel.ChannelInfo.MultiKeyDisabledReason == nil {
|
||||
channel.ChannelInfo.MultiKeyDisabledReason = make(map[int]string)
|
||||
}
|
||||
|
||||
var disabledCount int
|
||||
for i := 0; i < channel.ChannelInfo.MultiKeySize; i++ {
|
||||
status := 1 // default enabled
|
||||
if s, exists := channel.ChannelInfo.MultiKeyStatusList[i]; exists {
|
||||
status = s
|
||||
}
|
||||
|
||||
// 只禁用当前启用的密钥
|
||||
if status == 1 {
|
||||
channel.ChannelInfo.MultiKeyStatusList[i] = 2 // disabled
|
||||
disabledCount++
|
||||
}
|
||||
}
|
||||
|
||||
if disabledCount == 0 {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "没有可禁用的密钥",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
err = channel.Update()
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
model.InitChannelCache()
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": fmt.Sprintf("已禁用 %d 个密钥", disabledCount),
|
||||
})
|
||||
return
|
||||
|
||||
case "delete_disabled_keys":
|
||||
keys := channel.GetKeys()
|
||||
var remainingKeys []string
|
||||
var deletedCount int
|
||||
var newStatusList = make(map[int]int)
|
||||
var newDisabledTime = make(map[int]int64)
|
||||
var newDisabledReason = make(map[int]string)
|
||||
|
||||
newIndex := 0
|
||||
for i, key := range keys {
|
||||
status := 1 // default enabled
|
||||
if channel.ChannelInfo.MultiKeyStatusList != nil {
|
||||
if s, exists := channel.ChannelInfo.MultiKeyStatusList[i]; exists {
|
||||
status = s
|
||||
}
|
||||
}
|
||||
|
||||
// 只删除自动禁用(status == 3)的密钥,保留启用(status == 1)和手动禁用(status == 2)的密钥
|
||||
if status == 3 {
|
||||
deletedCount++
|
||||
} else {
|
||||
remainingKeys = append(remainingKeys, key)
|
||||
// 保留非自动禁用密钥的状态信息,重新索引
|
||||
if status != 1 {
|
||||
newStatusList[newIndex] = status
|
||||
if channel.ChannelInfo.MultiKeyDisabledTime != nil {
|
||||
if t, exists := channel.ChannelInfo.MultiKeyDisabledTime[i]; exists {
|
||||
newDisabledTime[newIndex] = t
|
||||
}
|
||||
}
|
||||
if channel.ChannelInfo.MultiKeyDisabledReason != nil {
|
||||
if r, exists := channel.ChannelInfo.MultiKeyDisabledReason[i]; exists {
|
||||
newDisabledReason[newIndex] = r
|
||||
}
|
||||
}
|
||||
}
|
||||
newIndex++
|
||||
}
|
||||
}
|
||||
|
||||
if deletedCount == 0 {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "没有需要删除的自动禁用密钥",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Update channel with remaining keys
|
||||
channel.Key = strings.Join(remainingKeys, "\n")
|
||||
channel.ChannelInfo.MultiKeySize = len(remainingKeys)
|
||||
channel.ChannelInfo.MultiKeyStatusList = newStatusList
|
||||
channel.ChannelInfo.MultiKeyDisabledTime = newDisabledTime
|
||||
channel.ChannelInfo.MultiKeyDisabledReason = newDisabledReason
|
||||
|
||||
err = channel.Update()
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
model.InitChannelCache()
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": fmt.Sprintf("已删除 %d 个自动禁用的密钥", deletedCount),
|
||||
"data": deletedCount,
|
||||
})
|
||||
return
|
||||
|
||||
default:
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "不支持的操作",
|
||||
})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
@@ -220,21 +220,29 @@ func LinuxdoOAuth(c *gin.Context) {
|
||||
}
|
||||
} else {
|
||||
if common.RegisterEnabled {
|
||||
user.Username = "linuxdo_" + strconv.Itoa(model.GetMaxUserId()+1)
|
||||
user.DisplayName = linuxdoUser.Name
|
||||
user.Role = common.RoleCommonUser
|
||||
user.Status = common.UserStatusEnabled
|
||||
if linuxdoUser.TrustLevel >= common.LinuxDOMinimumTrustLevel {
|
||||
user.Username = "linuxdo_" + strconv.Itoa(model.GetMaxUserId()+1)
|
||||
user.DisplayName = linuxdoUser.Name
|
||||
user.Role = common.RoleCommonUser
|
||||
user.Status = common.UserStatusEnabled
|
||||
|
||||
affCode := session.Get("aff")
|
||||
inviterId := 0
|
||||
if affCode != nil {
|
||||
inviterId, _ = model.GetUserIdByAffCode(affCode.(string))
|
||||
}
|
||||
affCode := session.Get("aff")
|
||||
inviterId := 0
|
||||
if affCode != nil {
|
||||
inviterId, _ = model.GetUserIdByAffCode(affCode.(string))
|
||||
}
|
||||
|
||||
if err := user.Insert(inviterId); err != nil {
|
||||
if err := user.Insert(inviterId); err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
} else {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
"message": "Linux DO 信任等级未达到管理员设置的最低信任等级",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
@@ -145,6 +145,22 @@ func UpdateMidjourneyTaskBulk() {
|
||||
buttonStr, _ := json.Marshal(responseItem.Buttons)
|
||||
task.Buttons = string(buttonStr)
|
||||
}
|
||||
// 映射 VideoUrl
|
||||
task.VideoUrl = responseItem.VideoUrl
|
||||
|
||||
// 映射 VideoUrls - 将数组序列化为 JSON 字符串
|
||||
if responseItem.VideoUrls != nil && len(responseItem.VideoUrls) > 0 {
|
||||
videoUrlsStr, err := json.Marshal(responseItem.VideoUrls)
|
||||
if err != nil {
|
||||
common.LogError(ctx, fmt.Sprintf("序列化 VideoUrls 失败: %v", err))
|
||||
task.VideoUrls = "[]" // 失败时设置为空数组
|
||||
} else {
|
||||
task.VideoUrls = string(videoUrlsStr)
|
||||
}
|
||||
} else {
|
||||
task.VideoUrls = "" // 空值时清空字段
|
||||
}
|
||||
|
||||
shouldReturnQuota := false
|
||||
if (task.Progress != "100%" && responseItem.FailReason != "") || (task.Progress == "100%" && task.Status == "FAILURE") {
|
||||
common.LogInfo(ctx, task.MjId+" 构建失败,"+task.FailReason)
|
||||
@@ -208,6 +224,20 @@ func checkMjTaskNeedUpdate(oldTask *model.Midjourney, newTask dto.MidjourneyDto)
|
||||
if oldTask.Progress != "100%" && newTask.FailReason != "" {
|
||||
return true
|
||||
}
|
||||
// 检查 VideoUrl 是否需要更新
|
||||
if oldTask.VideoUrl != newTask.VideoUrl {
|
||||
return true
|
||||
}
|
||||
// 检查 VideoUrls 是否需要更新
|
||||
if newTask.VideoUrls != nil && len(newTask.VideoUrls) > 0 {
|
||||
newVideoUrlsStr, _ := json.Marshal(newTask.VideoUrls)
|
||||
if oldTask.VideoUrls != string(newVideoUrlsStr) {
|
||||
return true
|
||||
}
|
||||
} else if oldTask.VideoUrls != "" {
|
||||
// 如果新数据没有 VideoUrls 但旧数据有,需要更新(清空)
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -41,46 +41,47 @@ func GetStatus(c *gin.Context) {
|
||||
cs := console_setting.GetConsoleSetting()
|
||||
|
||||
data := gin.H{
|
||||
"version": common.Version,
|
||||
"start_time": common.StartTime,
|
||||
"email_verification": common.EmailVerificationEnabled,
|
||||
"github_oauth": common.GitHubOAuthEnabled,
|
||||
"github_client_id": common.GitHubClientId,
|
||||
"linuxdo_oauth": common.LinuxDOOAuthEnabled,
|
||||
"linuxdo_client_id": common.LinuxDOClientId,
|
||||
"telegram_oauth": common.TelegramOAuthEnabled,
|
||||
"telegram_bot_name": common.TelegramBotName,
|
||||
"system_name": common.SystemName,
|
||||
"logo": common.Logo,
|
||||
"footer_html": common.Footer,
|
||||
"wechat_qrcode": common.WeChatAccountQRCodeImageURL,
|
||||
"wechat_login": common.WeChatAuthEnabled,
|
||||
"server_address": setting.ServerAddress,
|
||||
"price": setting.Price,
|
||||
"stripe_unit_price": setting.StripeUnitPrice,
|
||||
"min_topup": setting.MinTopUp,
|
||||
"stripe_min_topup": setting.StripeMinTopUp,
|
||||
"turnstile_check": common.TurnstileCheckEnabled,
|
||||
"turnstile_site_key": common.TurnstileSiteKey,
|
||||
"top_up_link": common.TopUpLink,
|
||||
"docs_link": operation_setting.GetGeneralSetting().DocsLink,
|
||||
"quota_per_unit": common.QuotaPerUnit,
|
||||
"display_in_currency": common.DisplayInCurrencyEnabled,
|
||||
"enable_batch_update": common.BatchUpdateEnabled,
|
||||
"enable_drawing": common.DrawingEnabled,
|
||||
"enable_task": common.TaskEnabled,
|
||||
"enable_data_export": common.DataExportEnabled,
|
||||
"data_export_default_time": common.DataExportDefaultTime,
|
||||
"default_collapse_sidebar": common.DefaultCollapseSidebar,
|
||||
"enable_online_topup": setting.PayAddress != "" && setting.EpayId != "" && setting.EpayKey != "",
|
||||
"enable_stripe_topup": setting.StripeApiSecret != "" && setting.StripeWebhookSecret != "" && setting.StripePriceId != "",
|
||||
"mj_notify_enabled": setting.MjNotifyEnabled,
|
||||
"chats": setting.Chats,
|
||||
"demo_site_enabled": operation_setting.DemoSiteEnabled,
|
||||
"self_use_mode_enabled": operation_setting.SelfUseModeEnabled,
|
||||
"default_use_auto_group": setting.DefaultUseAutoGroup,
|
||||
"pay_methods": setting.PayMethods,
|
||||
"usd_exchange_rate": setting.USDExchangeRate,
|
||||
"version": common.Version,
|
||||
"start_time": common.StartTime,
|
||||
"email_verification": common.EmailVerificationEnabled,
|
||||
"github_oauth": common.GitHubOAuthEnabled,
|
||||
"github_client_id": common.GitHubClientId,
|
||||
"linuxdo_oauth": common.LinuxDOOAuthEnabled,
|
||||
"linuxdo_client_id": common.LinuxDOClientId,
|
||||
"linuxdo_minimum_trust_level": common.LinuxDOMinimumTrustLevel,
|
||||
"telegram_oauth": common.TelegramOAuthEnabled,
|
||||
"telegram_bot_name": common.TelegramBotName,
|
||||
"system_name": common.SystemName,
|
||||
"logo": common.Logo,
|
||||
"footer_html": common.Footer,
|
||||
"wechat_qrcode": common.WeChatAccountQRCodeImageURL,
|
||||
"wechat_login": common.WeChatAuthEnabled,
|
||||
"server_address": setting.ServerAddress,
|
||||
"price": setting.Price,
|
||||
"stripe_unit_price": setting.StripeUnitPrice,
|
||||
"min_topup": setting.MinTopUp,
|
||||
"stripe_min_topup": setting.StripeMinTopUp,
|
||||
"turnstile_check": common.TurnstileCheckEnabled,
|
||||
"turnstile_site_key": common.TurnstileSiteKey,
|
||||
"top_up_link": common.TopUpLink,
|
||||
"docs_link": operation_setting.GetGeneralSetting().DocsLink,
|
||||
"quota_per_unit": common.QuotaPerUnit,
|
||||
"display_in_currency": common.DisplayInCurrencyEnabled,
|
||||
"enable_batch_update": common.BatchUpdateEnabled,
|
||||
"enable_drawing": common.DrawingEnabled,
|
||||
"enable_task": common.TaskEnabled,
|
||||
"enable_data_export": common.DataExportEnabled,
|
||||
"data_export_default_time": common.DataExportDefaultTime,
|
||||
"default_collapse_sidebar": common.DefaultCollapseSidebar,
|
||||
"enable_online_topup": setting.PayAddress != "" && setting.EpayId != "" && setting.EpayKey != "",
|
||||
"enable_stripe_topup": setting.StripeApiSecret != "" && setting.StripeWebhookSecret != "" && setting.StripePriceId != "",
|
||||
"mj_notify_enabled": setting.MjNotifyEnabled,
|
||||
"chats": setting.Chats,
|
||||
"demo_site_enabled": operation_setting.DemoSiteEnabled,
|
||||
"self_use_mode_enabled": operation_setting.SelfUseModeEnabled,
|
||||
"default_use_auto_group": setting.DefaultUseAutoGroup,
|
||||
"pay_methods": setting.PayMethods,
|
||||
"usd_exchange_rate": setting.USDExchangeRate,
|
||||
|
||||
// 面板启用开关
|
||||
"api_info_enabled": cs.ApiInfoEnabled,
|
||||
|
||||
27
controller/missing_models.go
Normal file
27
controller/missing_models.go
Normal file
@@ -0,0 +1,27 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"one-api/model"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// GetMissingModels returns the list of model names that are referenced by channels
|
||||
// but do not have corresponding records in the models meta table.
|
||||
// This helps administrators quickly discover models that need configuration.
|
||||
func GetMissingModels(c *gin.Context) {
|
||||
missing, err := model.GetMissingModels()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"data": missing,
|
||||
})
|
||||
}
|
||||
178
controller/model_meta.go
Normal file
178
controller/model_meta.go
Normal file
@@ -0,0 +1,178 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"strconv"
|
||||
|
||||
"one-api/common"
|
||||
"one-api/model"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// GetAllModelsMeta 获取模型列表(分页)
|
||||
func GetAllModelsMeta(c *gin.Context) {
|
||||
|
||||
pageInfo := common.GetPageQuery(c)
|
||||
modelsMeta, err := model.GetAllModels(pageInfo.GetStartIdx(), pageInfo.GetPageSize())
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
// 填充附加字段
|
||||
for _, m := range modelsMeta {
|
||||
fillModelExtra(m)
|
||||
}
|
||||
var total int64
|
||||
model.DB.Model(&model.Model{}).Count(&total)
|
||||
|
||||
// 统计供应商计数(全部数据,不受分页影响)
|
||||
vendorCounts, _ := model.GetVendorModelCounts()
|
||||
|
||||
pageInfo.SetTotal(int(total))
|
||||
pageInfo.SetItems(modelsMeta)
|
||||
common.ApiSuccess(c, gin.H{
|
||||
"items": modelsMeta,
|
||||
"total": total,
|
||||
"page": pageInfo.GetPage(),
|
||||
"page_size": pageInfo.GetPageSize(),
|
||||
"vendor_counts": vendorCounts,
|
||||
})
|
||||
}
|
||||
|
||||
// SearchModelsMeta 搜索模型列表
|
||||
func SearchModelsMeta(c *gin.Context) {
|
||||
|
||||
keyword := c.Query("keyword")
|
||||
vendor := c.Query("vendor")
|
||||
pageInfo := common.GetPageQuery(c)
|
||||
|
||||
modelsMeta, total, err := model.SearchModels(keyword, vendor, pageInfo.GetStartIdx(), pageInfo.GetPageSize())
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
for _, m := range modelsMeta {
|
||||
fillModelExtra(m)
|
||||
}
|
||||
pageInfo.SetTotal(int(total))
|
||||
pageInfo.SetItems(modelsMeta)
|
||||
common.ApiSuccess(c, pageInfo)
|
||||
}
|
||||
|
||||
// GetModelMeta 根据 ID 获取单条模型信息
|
||||
func GetModelMeta(c *gin.Context) {
|
||||
idStr := c.Param("id")
|
||||
id, err := strconv.Atoi(idStr)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
var m model.Model
|
||||
if err := model.DB.First(&m, id).Error; err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
fillModelExtra(&m)
|
||||
common.ApiSuccess(c, &m)
|
||||
}
|
||||
|
||||
// CreateModelMeta 新建模型
|
||||
func CreateModelMeta(c *gin.Context) {
|
||||
var m model.Model
|
||||
if err := c.ShouldBindJSON(&m); err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
if m.ModelName == "" {
|
||||
common.ApiErrorMsg(c, "模型名称不能为空")
|
||||
return
|
||||
}
|
||||
// 名称冲突检查
|
||||
if dup, err := model.IsModelNameDuplicated(0, m.ModelName); err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
} else if dup {
|
||||
common.ApiErrorMsg(c, "模型名称已存在")
|
||||
return
|
||||
}
|
||||
|
||||
if err := m.Insert(); err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
model.RefreshPricing()
|
||||
common.ApiSuccess(c, &m)
|
||||
}
|
||||
|
||||
// UpdateModelMeta 更新模型
|
||||
func UpdateModelMeta(c *gin.Context) {
|
||||
statusOnly := c.Query("status_only") == "true"
|
||||
|
||||
var m model.Model
|
||||
if err := c.ShouldBindJSON(&m); err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
if m.Id == 0 {
|
||||
common.ApiErrorMsg(c, "缺少模型 ID")
|
||||
return
|
||||
}
|
||||
|
||||
if statusOnly {
|
||||
// 只更新状态,防止误清空其他字段
|
||||
if err := model.DB.Model(&model.Model{}).Where("id = ?", m.Id).Update("status", m.Status).Error; err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
} else {
|
||||
// 名称冲突检查
|
||||
if dup, err := model.IsModelNameDuplicated(m.Id, m.ModelName); err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
} else if dup {
|
||||
common.ApiErrorMsg(c, "模型名称已存在")
|
||||
return
|
||||
}
|
||||
|
||||
if err := m.Update(); err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
}
|
||||
model.RefreshPricing()
|
||||
common.ApiSuccess(c, &m)
|
||||
}
|
||||
|
||||
// DeleteModelMeta 删除模型
|
||||
func DeleteModelMeta(c *gin.Context) {
|
||||
idStr := c.Param("id")
|
||||
id, err := strconv.Atoi(idStr)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
if err := model.DB.Delete(&model.Model{}, id).Error; err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
model.RefreshPricing()
|
||||
common.ApiSuccess(c, nil)
|
||||
}
|
||||
|
||||
// 辅助函数:填充 Endpoints 和 BoundChannels 和 EnableGroups
|
||||
func fillModelExtra(m *model.Model) {
|
||||
if m.Endpoints == "" {
|
||||
eps := model.GetModelSupportEndpointTypes(m.ModelName)
|
||||
if b, err := json.Marshal(eps); err == nil {
|
||||
m.Endpoints = string(b)
|
||||
}
|
||||
}
|
||||
if channels, err := model.GetBoundChannels(m.ModelName); err == nil {
|
||||
m.BoundChannels = channels
|
||||
}
|
||||
// 填充启用分组
|
||||
m.EnableGroups = model.GetModelEnableGroups(m.ModelName)
|
||||
// 填充计费类型
|
||||
m.QuotaType = model.GetModelQuotaType(m.ModelName)
|
||||
}
|
||||
@@ -5,10 +5,8 @@ import (
|
||||
"fmt"
|
||||
"one-api/common"
|
||||
"one-api/constant"
|
||||
"one-api/dto"
|
||||
"one-api/middleware"
|
||||
"one-api/model"
|
||||
"one-api/setting"
|
||||
"one-api/types"
|
||||
"time"
|
||||
|
||||
@@ -28,41 +26,19 @@ func Playground(c *gin.Context) {
|
||||
|
||||
useAccessToken := c.GetBool("use_access_token")
|
||||
if useAccessToken {
|
||||
newAPIError = types.NewError(errors.New("暂不支持使用 access token"), types.ErrorCodeAccessDenied)
|
||||
newAPIError = types.NewError(errors.New("暂不支持使用 access token"), types.ErrorCodeAccessDenied, types.ErrOptionWithSkipRetry())
|
||||
return
|
||||
}
|
||||
|
||||
playgroundRequest := &dto.PlayGroundRequest{}
|
||||
err := common.UnmarshalBodyReusable(c, playgroundRequest)
|
||||
if err != nil {
|
||||
newAPIError = types.NewError(err, types.ErrorCodeInvalidRequest)
|
||||
return
|
||||
}
|
||||
|
||||
if playgroundRequest.Model == "" {
|
||||
newAPIError = types.NewError(errors.New("请选择模型"), types.ErrorCodeInvalidRequest)
|
||||
return
|
||||
}
|
||||
c.Set("original_model", playgroundRequest.Model)
|
||||
group := playgroundRequest.Group
|
||||
userGroup := c.GetString("group")
|
||||
|
||||
if group == "" {
|
||||
group = userGroup
|
||||
} else {
|
||||
if !setting.GroupInUserUsableGroups(group) && group != userGroup {
|
||||
newAPIError = types.NewError(errors.New("无权访问该分组"), types.ErrorCodeAccessDenied)
|
||||
return
|
||||
}
|
||||
c.Set("group", group)
|
||||
}
|
||||
group := c.GetString("group")
|
||||
modelName := c.GetString("original_model")
|
||||
|
||||
userId := c.GetInt("id")
|
||||
|
||||
// Write user context to ensure acceptUnsetRatio is available
|
||||
userCache, err := model.GetUserCache(userId)
|
||||
if err != nil {
|
||||
newAPIError = types.NewError(err, types.ErrorCodeQueryDataError)
|
||||
newAPIError = types.NewError(err, types.ErrorCodeQueryDataError, types.ErrOptionWithSkipRetry())
|
||||
return
|
||||
}
|
||||
userCache.WriteContext(c)
|
||||
@@ -73,7 +49,7 @@ func Playground(c *gin.Context) {
|
||||
Group: group,
|
||||
}
|
||||
_ = middleware.SetupContextForToken(c, tempToken)
|
||||
_, newAPIError = getChannel(c, group, playgroundRequest.Model, 0)
|
||||
_, newAPIError = getChannel(c, group, modelName, 0)
|
||||
if newAPIError != nil {
|
||||
return
|
||||
}
|
||||
|
||||
90
controller/prefill_group.go
Normal file
90
controller/prefill_group.go
Normal file
@@ -0,0 +1,90 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
|
||||
"one-api/common"
|
||||
"one-api/model"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// GetPrefillGroups 获取预填组列表,可通过 ?type=xxx 过滤
|
||||
func GetPrefillGroups(c *gin.Context) {
|
||||
groupType := c.Query("type")
|
||||
groups, err := model.GetAllPrefillGroups(groupType)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
common.ApiSuccess(c, groups)
|
||||
}
|
||||
|
||||
// CreatePrefillGroup 创建新的预填组
|
||||
func CreatePrefillGroup(c *gin.Context) {
|
||||
var g model.PrefillGroup
|
||||
if err := c.ShouldBindJSON(&g); err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
if g.Name == "" || g.Type == "" {
|
||||
common.ApiErrorMsg(c, "组名称和类型不能为空")
|
||||
return
|
||||
}
|
||||
// 创建前检查名称
|
||||
if dup, err := model.IsPrefillGroupNameDuplicated(0, g.Name); err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
} else if dup {
|
||||
common.ApiErrorMsg(c, "组名称已存在")
|
||||
return
|
||||
}
|
||||
|
||||
if err := g.Insert(); err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
common.ApiSuccess(c, &g)
|
||||
}
|
||||
|
||||
// UpdatePrefillGroup 更新预填组
|
||||
func UpdatePrefillGroup(c *gin.Context) {
|
||||
var g model.PrefillGroup
|
||||
if err := c.ShouldBindJSON(&g); err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
if g.Id == 0 {
|
||||
common.ApiErrorMsg(c, "缺少组 ID")
|
||||
return
|
||||
}
|
||||
// 名称冲突检查
|
||||
if dup, err := model.IsPrefillGroupNameDuplicated(g.Id, g.Name); err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
} else if dup {
|
||||
common.ApiErrorMsg(c, "组名称已存在")
|
||||
return
|
||||
}
|
||||
|
||||
if err := g.Update(); err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
common.ApiSuccess(c, &g)
|
||||
}
|
||||
|
||||
// DeletePrefillGroup 删除预填组
|
||||
func DeletePrefillGroup(c *gin.Context) {
|
||||
idStr := c.Param("id")
|
||||
id, err := strconv.Atoi(idStr)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
if err := model.DeletePrefillGroupByID(id); err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
common.ApiSuccess(c, nil)
|
||||
}
|
||||
@@ -39,10 +39,13 @@ func GetPricing(c *gin.Context) {
|
||||
}
|
||||
|
||||
c.JSON(200, gin.H{
|
||||
"success": true,
|
||||
"data": pricing,
|
||||
"group_ratio": groupRatio,
|
||||
"usable_group": usableGroup,
|
||||
"success": true,
|
||||
"data": pricing,
|
||||
"vendors": model.GetVendors(),
|
||||
"group_ratio": groupRatio,
|
||||
"usable_group": usableGroup,
|
||||
"supported_endpoint": model.GetSupportedEndpointMap(),
|
||||
"auto_groups": setting.AutoGroups,
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
@@ -42,12 +42,16 @@ func relayHandler(c *gin.Context, relayMode int) *types.NewAPIError {
|
||||
case relayconstant.RelayModeResponses:
|
||||
err = relay.ResponsesHelper(c)
|
||||
case relayconstant.RelayModeGemini:
|
||||
err = relay.GeminiHelper(c)
|
||||
if strings.Contains(c.Request.URL.Path, "embed") {
|
||||
err = relay.GeminiEmbeddingHandler(c)
|
||||
} else {
|
||||
err = relay.GeminiHelper(c)
|
||||
}
|
||||
default:
|
||||
err = relay.TextHelper(c)
|
||||
}
|
||||
|
||||
if constant2.ErrorLogEnabled && err != nil {
|
||||
if constant2.ErrorLogEnabled && err != nil && types.IsRecordErrorLog(err) {
|
||||
// 保存错误日志到mysql中
|
||||
userId := c.GetInt("id")
|
||||
tokenName := c.GetString("token_name")
|
||||
@@ -62,8 +66,15 @@ func relayHandler(c *gin.Context, relayMode int) *types.NewAPIError {
|
||||
other["channel_id"] = channelId
|
||||
other["channel_name"] = c.GetString("channel_name")
|
||||
other["channel_type"] = c.GetInt("channel_type")
|
||||
|
||||
model.RecordErrorLog(c, userId, channelId, modelName, tokenName, err.Error(), tokenId, 0, false, userGroup, other)
|
||||
adminInfo := make(map[string]interface{})
|
||||
adminInfo["use_channel"] = c.GetStringSlice("use_channel")
|
||||
isMultiKey := common.GetContextKeyBool(c, constant.ContextKeyChannelIsMultiKey)
|
||||
if isMultiKey {
|
||||
adminInfo["is_multi_key"] = true
|
||||
adminInfo["multi_key_index"] = common.GetContextKeyInt(c, constant.ContextKeyChannelMultiKeyIndex)
|
||||
}
|
||||
other["admin_info"] = adminInfo
|
||||
model.RecordErrorLog(c, userId, channelId, modelName, tokenName, err.MaskSensitiveError(), tokenId, 0, false, userGroup, other)
|
||||
}
|
||||
|
||||
return err
|
||||
@@ -128,7 +139,7 @@ func WssRelay(c *gin.Context) {
|
||||
defer ws.Close()
|
||||
|
||||
if err != nil {
|
||||
helper.WssError(c, ws, types.NewError(err, types.ErrorCodeGetChannelFailed).ToOpenAIError())
|
||||
helper.WssError(c, ws, types.NewError(err, types.ErrorCodeGetChannelFailed, types.ErrOptionWithSkipRetry()).ToOpenAIError())
|
||||
return
|
||||
}
|
||||
|
||||
@@ -259,10 +270,10 @@ func getChannel(c *gin.Context, group, originalModel string, retryCount int) (*m
|
||||
}
|
||||
channel, selectGroup, err := model.CacheGetRandomSatisfiedChannel(c, group, originalModel, retryCount)
|
||||
if err != nil {
|
||||
return nil, types.NewError(errors.New(fmt.Sprintf("获取分组 %s 下模型 %s 的可用渠道失败(retry): %s", selectGroup, originalModel, err.Error())), types.ErrorCodeGetChannelFailed)
|
||||
return nil, types.NewError(errors.New(fmt.Sprintf("获取分组 %s 下模型 %s 的可用渠道失败(retry): %s", selectGroup, originalModel, err.Error())), types.ErrorCodeGetChannelFailed, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
if channel == nil {
|
||||
return nil, types.NewError(errors.New(fmt.Sprintf("分组 %s 下模型 %s 的可用渠道不存在(数据库一致性已被破坏,retry)", selectGroup, originalModel)), types.ErrorCodeGetChannelFailed)
|
||||
return nil, types.NewError(errors.New(fmt.Sprintf("分组 %s 下模型 %s 的可用渠道不存在(数据库一致性已被破坏,retry)", selectGroup, originalModel)), types.ErrorCodeGetChannelFailed, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
newAPIError := middleware.SetupContextForSelectedChannel(c, channel, originalModel)
|
||||
if newAPIError != nil {
|
||||
@@ -278,7 +289,7 @@ func shouldRetry(c *gin.Context, openaiErr *types.NewAPIError, retryTimes int) b
|
||||
if types.IsChannelError(openaiErr) {
|
||||
return true
|
||||
}
|
||||
if types.IsLocalError(openaiErr) {
|
||||
if types.IsSkipRetryError(openaiErr) {
|
||||
return false
|
||||
}
|
||||
if retryTimes <= 0 {
|
||||
|
||||
553
controller/twofa.go
Normal file
553
controller/twofa.go
Normal file
@@ -0,0 +1,553 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/model"
|
||||
"strconv"
|
||||
|
||||
"github.com/gin-contrib/sessions"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// Setup2FARequest 设置2FA请求结构
|
||||
type Setup2FARequest struct {
|
||||
Code string `json:"code" binding:"required"`
|
||||
}
|
||||
|
||||
// Verify2FARequest 验证2FA请求结构
|
||||
type Verify2FARequest struct {
|
||||
Code string `json:"code" binding:"required"`
|
||||
}
|
||||
|
||||
// Setup2FAResponse 设置2FA响应结构
|
||||
type Setup2FAResponse struct {
|
||||
Secret string `json:"secret"`
|
||||
QRCodeData string `json:"qr_code_data"`
|
||||
BackupCodes []string `json:"backup_codes"`
|
||||
}
|
||||
|
||||
// Setup2FA 初始化2FA设置
|
||||
func Setup2FA(c *gin.Context) {
|
||||
userId := c.GetInt("id")
|
||||
|
||||
// 检查用户是否已经启用2FA
|
||||
existing, err := model.GetTwoFAByUserId(userId)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
if existing != nil && existing.IsEnabled {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "用户已启用2FA,请先禁用后重新设置",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 如果存在已禁用的2FA记录,先删除它
|
||||
if existing != nil && !existing.IsEnabled {
|
||||
if err := existing.Delete(); err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
existing = nil // 重置为nil,后续将创建新记录
|
||||
}
|
||||
|
||||
// 获取用户信息
|
||||
user, err := model.GetUserById(userId, false)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
// 生成TOTP密钥
|
||||
key, err := common.GenerateTOTPSecret(user.Username)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "生成2FA密钥失败",
|
||||
})
|
||||
common.SysError("生成TOTP密钥失败: " + err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// 生成备用码
|
||||
backupCodes, err := common.GenerateBackupCodes()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "生成备用码失败",
|
||||
})
|
||||
common.SysError("生成备用码失败: " + err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// 生成二维码数据
|
||||
qrCodeData := common.GenerateQRCodeData(key.Secret(), user.Username)
|
||||
|
||||
// 创建或更新2FA记录(暂未启用)
|
||||
twoFA := &model.TwoFA{
|
||||
UserId: userId,
|
||||
Secret: key.Secret(),
|
||||
IsEnabled: false,
|
||||
}
|
||||
|
||||
if existing != nil {
|
||||
// 更新现有记录
|
||||
twoFA.Id = existing.Id
|
||||
err = twoFA.Update()
|
||||
} else {
|
||||
// 创建新记录
|
||||
err = twoFA.Create()
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
// 创建备用码记录
|
||||
if err := model.CreateBackupCodes(userId, backupCodes); err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "保存备用码失败",
|
||||
})
|
||||
common.SysError("保存备用码失败: " + err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// 记录操作日志
|
||||
model.RecordLog(userId, model.LogTypeSystem, "开始设置两步验证")
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "2FA设置初始化成功,请使用认证器扫描二维码并输入验证码完成设置",
|
||||
"data": Setup2FAResponse{
|
||||
Secret: key.Secret(),
|
||||
QRCodeData: qrCodeData,
|
||||
BackupCodes: backupCodes,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
// Enable2FA 启用2FA
|
||||
func Enable2FA(c *gin.Context) {
|
||||
var req Setup2FARequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "参数错误",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
userId := c.GetInt("id")
|
||||
|
||||
// 获取2FA记录
|
||||
twoFA, err := model.GetTwoFAByUserId(userId)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
if twoFA == nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "请先完成2FA初始化设置",
|
||||
})
|
||||
return
|
||||
}
|
||||
if twoFA.IsEnabled {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "2FA已经启用",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 验证TOTP验证码
|
||||
cleanCode, err := common.ValidateNumericCode(req.Code)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
if !common.ValidateTOTPCode(twoFA.Secret, cleanCode) {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "验证码或备用码错误,请重试",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 启用2FA
|
||||
if err := twoFA.Enable(); err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
// 记录操作日志
|
||||
model.RecordLog(userId, model.LogTypeSystem, "成功启用两步验证")
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "两步验证启用成功",
|
||||
})
|
||||
}
|
||||
|
||||
// Disable2FA 禁用2FA
|
||||
func Disable2FA(c *gin.Context) {
|
||||
var req Verify2FARequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "参数错误",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
userId := c.GetInt("id")
|
||||
|
||||
// 获取2FA记录
|
||||
twoFA, err := model.GetTwoFAByUserId(userId)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
if twoFA == nil || !twoFA.IsEnabled {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "用户未启用2FA",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 验证TOTP验证码或备用码
|
||||
cleanCode, err := common.ValidateNumericCode(req.Code)
|
||||
isValidTOTP := false
|
||||
isValidBackup := false
|
||||
|
||||
if err == nil {
|
||||
// 尝试验证TOTP
|
||||
isValidTOTP, _ = twoFA.ValidateTOTPAndUpdateUsage(cleanCode)
|
||||
}
|
||||
|
||||
if !isValidTOTP {
|
||||
// 尝试验证备用码
|
||||
isValidBackup, err = twoFA.ValidateBackupCodeAndUpdateUsage(req.Code)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if !isValidTOTP && !isValidBackup {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "验证码或备用码错误,请重试",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 禁用2FA
|
||||
if err := model.DisableTwoFA(userId); err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
// 记录操作日志
|
||||
model.RecordLog(userId, model.LogTypeSystem, "禁用两步验证")
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "两步验证已禁用",
|
||||
})
|
||||
}
|
||||
|
||||
// Get2FAStatus 获取用户2FA状态
|
||||
func Get2FAStatus(c *gin.Context) {
|
||||
userId := c.GetInt("id")
|
||||
|
||||
twoFA, err := model.GetTwoFAByUserId(userId)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
status := map[string]interface{}{
|
||||
"enabled": false,
|
||||
"locked": false,
|
||||
}
|
||||
|
||||
if twoFA != nil {
|
||||
status["enabled"] = twoFA.IsEnabled
|
||||
status["locked"] = twoFA.IsLocked()
|
||||
if twoFA.IsEnabled {
|
||||
// 获取剩余备用码数量
|
||||
backupCount, err := model.GetUnusedBackupCodeCount(userId)
|
||||
if err != nil {
|
||||
common.SysError("获取备用码数量失败: " + err.Error())
|
||||
} else {
|
||||
status["backup_codes_remaining"] = backupCount
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": status,
|
||||
})
|
||||
}
|
||||
|
||||
// RegenerateBackupCodes 重新生成备用码
|
||||
func RegenerateBackupCodes(c *gin.Context) {
|
||||
var req Verify2FARequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "参数错误",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
userId := c.GetInt("id")
|
||||
|
||||
// 获取2FA记录
|
||||
twoFA, err := model.GetTwoFAByUserId(userId)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
if twoFA == nil || !twoFA.IsEnabled {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "用户未启用2FA",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 验证TOTP验证码
|
||||
cleanCode, err := common.ValidateNumericCode(req.Code)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
valid, err := twoFA.ValidateTOTPAndUpdateUsage(cleanCode)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
if !valid {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "验证码或备用码错误,请重试",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 生成新的备用码
|
||||
backupCodes, err := common.GenerateBackupCodes()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "生成备用码失败",
|
||||
})
|
||||
common.SysError("生成备用码失败: " + err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// 保存新的备用码
|
||||
if err := model.CreateBackupCodes(userId, backupCodes); err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "保存备用码失败",
|
||||
})
|
||||
common.SysError("保存备用码失败: " + err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// 记录操作日志
|
||||
model.RecordLog(userId, model.LogTypeSystem, "重新生成两步验证备用码")
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "备用码重新生成成功",
|
||||
"data": map[string]interface{}{
|
||||
"backup_codes": backupCodes,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
// Verify2FALogin 登录时验证2FA
|
||||
func Verify2FALogin(c *gin.Context) {
|
||||
var req Verify2FARequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "参数错误",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 从会话中获取pending用户信息
|
||||
session := sessions.Default(c)
|
||||
pendingUserId := session.Get("pending_user_id")
|
||||
if pendingUserId == nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "会话已过期,请重新登录",
|
||||
})
|
||||
return
|
||||
}
|
||||
userId, ok := pendingUserId.(int)
|
||||
if !ok {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "会话数据无效,请重新登录",
|
||||
})
|
||||
return
|
||||
}
|
||||
// 获取用户信息
|
||||
user, err := model.GetUserById(userId, false)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "用户不存在",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 获取2FA记录
|
||||
twoFA, err := model.GetTwoFAByUserId(user.Id)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
if twoFA == nil || !twoFA.IsEnabled {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "用户未启用2FA",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 验证TOTP验证码或备用码
|
||||
cleanCode, err := common.ValidateNumericCode(req.Code)
|
||||
isValidTOTP := false
|
||||
isValidBackup := false
|
||||
|
||||
if err == nil {
|
||||
// 尝试验证TOTP
|
||||
isValidTOTP, _ = twoFA.ValidateTOTPAndUpdateUsage(cleanCode)
|
||||
}
|
||||
|
||||
if !isValidTOTP {
|
||||
// 尝试验证备用码
|
||||
isValidBackup, err = twoFA.ValidateBackupCodeAndUpdateUsage(req.Code)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if !isValidTOTP && !isValidBackup {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "验证码或备用码错误,请重试",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 2FA验证成功,清理pending会话信息并完成登录
|
||||
session.Delete("pending_username")
|
||||
session.Delete("pending_user_id")
|
||||
session.Save()
|
||||
|
||||
setupLogin(user, c)
|
||||
}
|
||||
|
||||
// Admin2FAStats 管理员获取2FA统计信息
|
||||
func Admin2FAStats(c *gin.Context) {
|
||||
stats, err := model.GetTwoFAStats()
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": stats,
|
||||
})
|
||||
}
|
||||
|
||||
// AdminDisable2FA 管理员强制禁用用户2FA
|
||||
func AdminDisable2FA(c *gin.Context) {
|
||||
userIdStr := c.Param("id")
|
||||
userId, err := strconv.Atoi(userIdStr)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "用户ID格式错误",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 检查目标用户权限
|
||||
targetUser, err := model.GetUserById(userId, false)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
myRole := c.GetInt("role")
|
||||
if myRole <= targetUser.Role && myRole != common.RoleRootUser {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "无权操作同级或更高级用户的2FA设置",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 禁用2FA
|
||||
if err := model.DisableTwoFA(userId); err != nil {
|
||||
if errors.Is(err, model.ErrTwoFANotEnabled) {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "用户未启用2FA",
|
||||
})
|
||||
return
|
||||
}
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
// 记录操作日志
|
||||
adminId := c.GetInt("id")
|
||||
model.RecordLog(userId, model.LogTypeManage,
|
||||
fmt.Sprintf("管理员(ID:%d)强制禁用了用户的两步验证", adminId))
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "用户2FA已被强制禁用",
|
||||
})
|
||||
}
|
||||
@@ -62,6 +62,32 @@ func Login(c *gin.Context) {
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 检查是否启用2FA
|
||||
if model.IsTwoFAEnabled(user.Id) {
|
||||
// 设置pending session,等待2FA验证
|
||||
session := sessions.Default(c)
|
||||
session.Set("pending_username", user.Username)
|
||||
session.Set("pending_user_id", user.Id)
|
||||
err := session.Save()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"message": "无法保存会话信息,请重试",
|
||||
"success": false,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"message": "请输入两步验证码",
|
||||
"success": true,
|
||||
"data": map[string]interface{}{
|
||||
"require_2fa": true,
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
setupLogin(&user, c)
|
||||
}
|
||||
|
||||
|
||||
124
controller/vendor_meta.go
Normal file
124
controller/vendor_meta.go
Normal file
@@ -0,0 +1,124 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
|
||||
"one-api/common"
|
||||
"one-api/model"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// GetAllVendors 获取供应商列表(分页)
|
||||
func GetAllVendors(c *gin.Context) {
|
||||
pageInfo := common.GetPageQuery(c)
|
||||
vendors, err := model.GetAllVendors(pageInfo.GetStartIdx(), pageInfo.GetPageSize())
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
var total int64
|
||||
model.DB.Model(&model.Vendor{}).Count(&total)
|
||||
pageInfo.SetTotal(int(total))
|
||||
pageInfo.SetItems(vendors)
|
||||
common.ApiSuccess(c, pageInfo)
|
||||
}
|
||||
|
||||
// SearchVendors 搜索供应商
|
||||
func SearchVendors(c *gin.Context) {
|
||||
keyword := c.Query("keyword")
|
||||
pageInfo := common.GetPageQuery(c)
|
||||
vendors, total, err := model.SearchVendors(keyword, pageInfo.GetStartIdx(), pageInfo.GetPageSize())
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
pageInfo.SetTotal(int(total))
|
||||
pageInfo.SetItems(vendors)
|
||||
common.ApiSuccess(c, pageInfo)
|
||||
}
|
||||
|
||||
// GetVendorMeta 根据 ID 获取供应商
|
||||
func GetVendorMeta(c *gin.Context) {
|
||||
idStr := c.Param("id")
|
||||
id, err := strconv.Atoi(idStr)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
v, err := model.GetVendorByID(id)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
common.ApiSuccess(c, v)
|
||||
}
|
||||
|
||||
// CreateVendorMeta 新建供应商
|
||||
func CreateVendorMeta(c *gin.Context) {
|
||||
var v model.Vendor
|
||||
if err := c.ShouldBindJSON(&v); err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
if v.Name == "" {
|
||||
common.ApiErrorMsg(c, "供应商名称不能为空")
|
||||
return
|
||||
}
|
||||
// 创建前先检查名称
|
||||
if dup, err := model.IsVendorNameDuplicated(0, v.Name); err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
} else if dup {
|
||||
common.ApiErrorMsg(c, "供应商名称已存在")
|
||||
return
|
||||
}
|
||||
|
||||
if err := v.Insert(); err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
common.ApiSuccess(c, &v)
|
||||
}
|
||||
|
||||
// UpdateVendorMeta 更新供应商
|
||||
func UpdateVendorMeta(c *gin.Context) {
|
||||
var v model.Vendor
|
||||
if err := c.ShouldBindJSON(&v); err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
if v.Id == 0 {
|
||||
common.ApiErrorMsg(c, "缺少供应商 ID")
|
||||
return
|
||||
}
|
||||
// 名称冲突检查
|
||||
if dup, err := model.IsVendorNameDuplicated(v.Id, v.Name); err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
} else if dup {
|
||||
common.ApiErrorMsg(c, "供应商名称已存在")
|
||||
return
|
||||
}
|
||||
|
||||
if err := v.Update(); err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
common.ApiSuccess(c, &v)
|
||||
}
|
||||
|
||||
// DeleteVendorMeta 删除供应商
|
||||
func DeleteVendorMeta(c *gin.Context) {
|
||||
idStr := c.Param("id")
|
||||
id, err := strconv.Atoi(idStr)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
if err := model.DB.Delete(&model.Vendor{}, id).Error; err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
common.ApiSuccess(c, nil)
|
||||
}
|
||||
@@ -6,4 +6,5 @@ type ChannelSettings struct {
|
||||
Proxy string `json:"proxy"`
|
||||
PassThroughBodyEnabled bool `json:"pass_through_body_enabled,omitempty"`
|
||||
SystemPrompt string `json:"system_prompt,omitempty"`
|
||||
SystemPromptOverride bool `json:"system_prompt_override,omitempty"`
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@ package dto
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"one-api/common"
|
||||
"one-api/types"
|
||||
)
|
||||
@@ -198,6 +199,18 @@ type ClaudeRequest struct {
|
||||
Thinking *Thinking `json:"thinking,omitempty"`
|
||||
}
|
||||
|
||||
func (c *ClaudeRequest) SearchToolNameByToolCallId(toolCallId string) string {
|
||||
for _, message := range c.Messages {
|
||||
content, _ := message.ParseContent()
|
||||
for _, mediaMessage := range content {
|
||||
if mediaMessage.Id == toolCallId {
|
||||
return mediaMessage.Name
|
||||
}
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// AddTool 添加工具到请求中
|
||||
func (c *ClaudeRequest) AddTool(tool any) {
|
||||
if c.Tools == nil {
|
||||
@@ -284,14 +297,9 @@ func (c *ClaudeRequest) ParseSystem() []ClaudeMediaMessage {
|
||||
return mediaContent
|
||||
}
|
||||
|
||||
type ClaudeError struct {
|
||||
Type string `json:"type,omitempty"`
|
||||
Message string `json:"message,omitempty"`
|
||||
}
|
||||
|
||||
type ClaudeErrorWithStatusCode struct {
|
||||
Error ClaudeError `json:"error"`
|
||||
StatusCode int `json:"status_code"`
|
||||
Error types.ClaudeError `json:"error"`
|
||||
StatusCode int `json:"status_code"`
|
||||
LocalError bool
|
||||
}
|
||||
|
||||
@@ -303,7 +311,7 @@ type ClaudeResponse struct {
|
||||
Completion string `json:"completion,omitempty"`
|
||||
StopReason string `json:"stop_reason,omitempty"`
|
||||
Model string `json:"model,omitempty"`
|
||||
Error *types.ClaudeError `json:"error,omitempty"`
|
||||
Error any `json:"error,omitempty"`
|
||||
Usage *ClaudeUsage `json:"usage,omitempty"`
|
||||
Index *int `json:"index,omitempty"`
|
||||
ContentBlock *ClaudeMediaMessage `json:"content_block,omitempty"`
|
||||
@@ -324,12 +332,48 @@ func (c *ClaudeResponse) GetIndex() int {
|
||||
return *c.Index
|
||||
}
|
||||
|
||||
// GetClaudeError 从动态错误类型中提取ClaudeError结构
|
||||
func (c *ClaudeResponse) GetClaudeError() *types.ClaudeError {
|
||||
if c.Error == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
switch err := c.Error.(type) {
|
||||
case types.ClaudeError:
|
||||
return &err
|
||||
case *types.ClaudeError:
|
||||
return err
|
||||
case map[string]interface{}:
|
||||
// 处理从JSON解析来的map结构
|
||||
claudeErr := &types.ClaudeError{}
|
||||
if errType, ok := err["type"].(string); ok {
|
||||
claudeErr.Type = errType
|
||||
}
|
||||
if errMsg, ok := err["message"].(string); ok {
|
||||
claudeErr.Message = errMsg
|
||||
}
|
||||
return claudeErr
|
||||
case string:
|
||||
// 处理简单字符串错误
|
||||
return &types.ClaudeError{
|
||||
Type: "error",
|
||||
Message: err,
|
||||
}
|
||||
default:
|
||||
// 未知类型,尝试转换为字符串
|
||||
return &types.ClaudeError{
|
||||
Type: "unknown_error",
|
||||
Message: fmt.Sprintf("%v", err),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type ClaudeUsage struct {
|
||||
InputTokens int `json:"input_tokens"`
|
||||
CacheCreationInputTokens int `json:"cache_creation_input_tokens"`
|
||||
CacheReadInputTokens int `json:"cache_read_input_tokens"`
|
||||
OutputTokens int `json:"output_tokens"`
|
||||
ServerToolUse *ClaudeServerToolUse `json:"server_tool_use"`
|
||||
ServerToolUse *ClaudeServerToolUse `json:"server_tool_use,omitempty"`
|
||||
}
|
||||
|
||||
type ClaudeServerToolUse struct {
|
||||
|
||||
@@ -1,6 +1,9 @@
|
||||
package gemini
|
||||
package dto
|
||||
|
||||
import "encoding/json"
|
||||
import (
|
||||
"encoding/json"
|
||||
"one-api/common"
|
||||
)
|
||||
|
||||
type GeminiChatRequest struct {
|
||||
Contents []GeminiChatContent `json:"contents"`
|
||||
@@ -32,7 +35,7 @@ func (g *GeminiInlineData) UnmarshalJSON(data []byte) error {
|
||||
MimeTypeSnake string `json:"mime_type"`
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(data, &aux); err != nil {
|
||||
if err := common.Unmarshal(data, &aux); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -53,7 +56,7 @@ type FunctionCall struct {
|
||||
Arguments any `json:"args"`
|
||||
}
|
||||
|
||||
type FunctionResponse struct {
|
||||
type GeminiFunctionResponse struct {
|
||||
Name string `json:"name"`
|
||||
Response map[string]interface{} `json:"response"`
|
||||
}
|
||||
@@ -78,7 +81,7 @@ type GeminiPart struct {
|
||||
Thought bool `json:"thought,omitempty"`
|
||||
InlineData *GeminiInlineData `json:"inlineData,omitempty"`
|
||||
FunctionCall *FunctionCall `json:"functionCall,omitempty"`
|
||||
FunctionResponse *FunctionResponse `json:"functionResponse,omitempty"`
|
||||
FunctionResponse *GeminiFunctionResponse `json:"functionResponse,omitempty"`
|
||||
FileData *GeminiFileData `json:"fileData,omitempty"`
|
||||
ExecutableCode *GeminiPartExecutableCode `json:"executableCode,omitempty"`
|
||||
CodeExecutionResult *GeminiPartCodeExecutionResult `json:"codeExecutionResult,omitempty"`
|
||||
@@ -93,7 +96,7 @@ func (p *GeminiPart) UnmarshalJSON(data []byte) error {
|
||||
InlineDataSnake *GeminiInlineData `json:"inline_data,omitempty"` // snake_case variant
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(data, &aux); err != nil {
|
||||
if err := common.Unmarshal(data, &aux); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -207,16 +210,25 @@ type GeminiImagePrediction struct {
|
||||
|
||||
// Embedding related structs
|
||||
type GeminiEmbeddingRequest struct {
|
||||
Model string `json:"model,omitempty"`
|
||||
Content GeminiChatContent `json:"content"`
|
||||
TaskType string `json:"taskType,omitempty"`
|
||||
Title string `json:"title,omitempty"`
|
||||
OutputDimensionality int `json:"outputDimensionality,omitempty"`
|
||||
}
|
||||
|
||||
type GeminiBatchEmbeddingRequest struct {
|
||||
Requests []*GeminiEmbeddingRequest `json:"requests"`
|
||||
}
|
||||
|
||||
type GeminiEmbeddingResponse struct {
|
||||
Embedding ContentEmbedding `json:"embedding"`
|
||||
}
|
||||
|
||||
type GeminiBatchEmbeddingResponse struct {
|
||||
Embeddings []*ContentEmbedding `json:"embeddings"`
|
||||
}
|
||||
|
||||
type ContentEmbedding struct {
|
||||
Values []float64 `json:"values"`
|
||||
}
|
||||
@@ -29,6 +29,7 @@ type GeneralOpenAIRequest struct {
|
||||
MaxTokens uint `json:"max_tokens,omitempty"`
|
||||
MaxCompletionTokens uint `json:"max_completion_tokens,omitempty"`
|
||||
ReasoningEffort string `json:"reasoning_effort,omitempty"`
|
||||
Verbosity json.RawMessage `json:"verbosity,omitempty"` // gpt-5
|
||||
Temperature *float64 `json:"temperature,omitempty"`
|
||||
TopP float64 `json:"top_p,omitempty"`
|
||||
TopK int `json:"top_k,omitempty"`
|
||||
@@ -78,6 +79,8 @@ func (r *GeneralOpenAIRequest) GetSystemRoleName() string {
|
||||
if !strings.HasPrefix(r.Model, "o1-mini") && !strings.HasPrefix(r.Model, "o1-preview") {
|
||||
return "developer"
|
||||
}
|
||||
} else if strings.HasPrefix(r.Model, "gpt-5") {
|
||||
return "developer"
|
||||
}
|
||||
return "system"
|
||||
}
|
||||
@@ -99,8 +102,11 @@ type StreamOptions struct {
|
||||
IncludeUsage bool `json:"include_usage,omitempty"`
|
||||
}
|
||||
|
||||
func (r *GeneralOpenAIRequest) GetMaxTokens() int {
|
||||
return int(r.MaxTokens)
|
||||
func (r *GeneralOpenAIRequest) GetMaxTokens() uint {
|
||||
if r.MaxCompletionTokens != 0 {
|
||||
return r.MaxCompletionTokens
|
||||
}
|
||||
return r.MaxTokens
|
||||
}
|
||||
|
||||
func (r *GeneralOpenAIRequest) ParseInput() []string {
|
||||
|
||||
@@ -2,12 +2,18 @@ package dto
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"one-api/types"
|
||||
)
|
||||
|
||||
type SimpleResponse struct {
|
||||
Usage `json:"usage"`
|
||||
Error *OpenAIError `json:"error"`
|
||||
Error any `json:"error"`
|
||||
}
|
||||
|
||||
// GetOpenAIError 从动态错误类型中提取OpenAIError结构
|
||||
func (s *SimpleResponse) GetOpenAIError() *types.OpenAIError {
|
||||
return GetOpenAIError(s.Error)
|
||||
}
|
||||
|
||||
type TextResponse struct {
|
||||
@@ -31,10 +37,15 @@ type OpenAITextResponse struct {
|
||||
Object string `json:"object"`
|
||||
Created any `json:"created"`
|
||||
Choices []OpenAITextResponseChoice `json:"choices"`
|
||||
Error *types.OpenAIError `json:"error,omitempty"`
|
||||
Error any `json:"error,omitempty"`
|
||||
Usage `json:"usage"`
|
||||
}
|
||||
|
||||
// GetOpenAIError 从动态错误类型中提取OpenAIError结构
|
||||
func (o *OpenAITextResponse) GetOpenAIError() *types.OpenAIError {
|
||||
return GetOpenAIError(o.Error)
|
||||
}
|
||||
|
||||
type OpenAIEmbeddingResponseItem struct {
|
||||
Object string `json:"object"`
|
||||
Index int `json:"index"`
|
||||
@@ -132,6 +143,13 @@ type ChatCompletionsStreamResponse struct {
|
||||
Usage *Usage `json:"usage"`
|
||||
}
|
||||
|
||||
func (c *ChatCompletionsStreamResponse) IsFinished() bool {
|
||||
if len(c.Choices) == 0 {
|
||||
return false
|
||||
}
|
||||
return c.Choices[0].FinishReason != nil && *c.Choices[0].FinishReason != ""
|
||||
}
|
||||
|
||||
func (c *ChatCompletionsStreamResponse) IsToolCall() bool {
|
||||
if len(c.Choices) == 0 {
|
||||
return false
|
||||
@@ -146,6 +164,19 @@ func (c *ChatCompletionsStreamResponse) GetFirstToolCall() *ToolCallResponse {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *ChatCompletionsStreamResponse) ClearToolCalls() {
|
||||
if !c.IsToolCall() {
|
||||
return
|
||||
}
|
||||
for choiceIdx := range c.Choices {
|
||||
for callIdx := range c.Choices[choiceIdx].Delta.ToolCalls {
|
||||
c.Choices[choiceIdx].Delta.ToolCalls[callIdx].ID = ""
|
||||
c.Choices[choiceIdx].Delta.ToolCalls[callIdx].Type = nil
|
||||
c.Choices[choiceIdx].Delta.ToolCalls[callIdx].Function.Name = ""
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *ChatCompletionsStreamResponse) Copy() *ChatCompletionsStreamResponse {
|
||||
choices := make([]ChatCompletionsStreamResponseChoice, len(c.Choices))
|
||||
copy(choices, c.Choices)
|
||||
@@ -217,7 +248,7 @@ type OpenAIResponsesResponse struct {
|
||||
Object string `json:"object"`
|
||||
CreatedAt int `json:"created_at"`
|
||||
Status string `json:"status"`
|
||||
Error *types.OpenAIError `json:"error,omitempty"`
|
||||
Error any `json:"error,omitempty"`
|
||||
IncompleteDetails *IncompleteDetails `json:"incomplete_details,omitempty"`
|
||||
Instructions string `json:"instructions"`
|
||||
MaxOutputTokens int `json:"max_output_tokens"`
|
||||
@@ -237,6 +268,11 @@ type OpenAIResponsesResponse struct {
|
||||
Metadata json.RawMessage `json:"metadata"`
|
||||
}
|
||||
|
||||
// GetOpenAIError 从动态错误类型中提取OpenAIError结构
|
||||
func (o *OpenAIResponsesResponse) GetOpenAIError() *types.OpenAIError {
|
||||
return GetOpenAIError(o.Error)
|
||||
}
|
||||
|
||||
type IncompleteDetails struct {
|
||||
Reasoning string `json:"reasoning"`
|
||||
}
|
||||
@@ -276,3 +312,45 @@ type ResponsesStreamResponse struct {
|
||||
Delta string `json:"delta,omitempty"`
|
||||
Item *ResponsesOutput `json:"item,omitempty"`
|
||||
}
|
||||
|
||||
// GetOpenAIError 从动态错误类型中提取OpenAIError结构
|
||||
func GetOpenAIError(errorField any) *types.OpenAIError {
|
||||
if errorField == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
switch err := errorField.(type) {
|
||||
case types.OpenAIError:
|
||||
return &err
|
||||
case *types.OpenAIError:
|
||||
return err
|
||||
case map[string]interface{}:
|
||||
// 处理从JSON解析来的map结构
|
||||
openaiErr := &types.OpenAIError{}
|
||||
if errType, ok := err["type"].(string); ok {
|
||||
openaiErr.Type = errType
|
||||
}
|
||||
if errMsg, ok := err["message"].(string); ok {
|
||||
openaiErr.Message = errMsg
|
||||
}
|
||||
if errParam, ok := err["param"].(string); ok {
|
||||
openaiErr.Param = errParam
|
||||
}
|
||||
if errCode, ok := err["code"]; ok {
|
||||
openaiErr.Code = errCode
|
||||
}
|
||||
return openaiErr
|
||||
case string:
|
||||
// 处理简单字符串错误
|
||||
return &types.OpenAIError{
|
||||
Type: "error",
|
||||
Message: err,
|
||||
}
|
||||
default:
|
||||
// 未知类型,尝试转换为字符串
|
||||
return &types.OpenAIError{
|
||||
Type: "unknown_error",
|
||||
Message: fmt.Sprintf("%v", err),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
14
go.mod
14
go.mod
@@ -7,9 +7,10 @@ require (
|
||||
github.com/Calcium-Ion/go-epay v0.0.4
|
||||
github.com/andybalholm/brotli v1.1.1
|
||||
github.com/anknown/ahocorasick v0.0.0-20190904063843-d75dbd5169c0
|
||||
github.com/aws/aws-sdk-go-v2 v1.26.1
|
||||
github.com/aws/aws-sdk-go-v2 v1.37.2
|
||||
github.com/aws/aws-sdk-go-v2/credentials v1.17.11
|
||||
github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.7.4
|
||||
github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.33.0
|
||||
github.com/aws/smithy-go v1.22.5
|
||||
github.com/bytedance/gopkg v0.0.0-20220118071334-3db87571198b
|
||||
github.com/gin-contrib/cors v1.7.2
|
||||
github.com/gin-contrib/gzip v0.0.6
|
||||
@@ -24,6 +25,7 @@ require (
|
||||
github.com/gorilla/websocket v1.5.0
|
||||
github.com/joho/godotenv v1.5.1
|
||||
github.com/pkg/errors v0.9.1
|
||||
github.com/pquerna/otp v1.5.0
|
||||
github.com/samber/lo v1.39.0
|
||||
github.com/shirou/gopsutil v3.21.11+incompatible
|
||||
github.com/shopspring/decimal v1.4.0
|
||||
@@ -41,10 +43,10 @@ require (
|
||||
|
||||
require (
|
||||
github.com/anknown/darts v0.0.0-20151216065714-83ff685239e6 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.2 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.5 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.5 // indirect
|
||||
github.com/aws/smithy-go v1.20.2 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.0 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.2 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.2 // indirect
|
||||
github.com/boombuler/barcode v1.1.0 // indirect
|
||||
github.com/bytedance/sonic v1.11.6 // indirect
|
||||
github.com/bytedance/sonic/loader v0.1.1 // indirect
|
||||
github.com/cespare/xxhash/v2 v2.3.0 // indirect
|
||||
|
||||
29
go.sum
29
go.sum
@@ -6,20 +6,23 @@ github.com/anknown/ahocorasick v0.0.0-20190904063843-d75dbd5169c0 h1:onfun1RA+Kc
|
||||
github.com/anknown/ahocorasick v0.0.0-20190904063843-d75dbd5169c0/go.mod h1:4yg+jNTYlDEzBjhGS96v+zjyA3lfXlFd5CiTLIkPBLI=
|
||||
github.com/anknown/darts v0.0.0-20151216065714-83ff685239e6 h1:HblK3eJHq54yET63qPCTJnks3loDse5xRmmqHgHzwoI=
|
||||
github.com/anknown/darts v0.0.0-20151216065714-83ff685239e6/go.mod h1:pbiaLIeYLUbgMY1kwEAdwO6UKD5ZNwdPGQlwokS9fe8=
|
||||
github.com/aws/aws-sdk-go-v2 v1.26.1 h1:5554eUqIYVWpU0YmeeYZ0wU64H2VLBs8TlhRB2L+EkA=
|
||||
github.com/aws/aws-sdk-go-v2 v1.26.1/go.mod h1:ffIFB97e2yNsv4aTSGkqtHnppsIJzw7G7BReUZ3jCXM=
|
||||
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.2 h1:x6xsQXGSmW6frevwDA+vi/wqhp1ct18mVXYN08/93to=
|
||||
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.2/go.mod h1:lPprDr1e6cJdyYeGXnRaJoP4Md+cDBvi2eOj00BlGmg=
|
||||
github.com/aws/aws-sdk-go-v2 v1.37.2 h1:xkW1iMYawzcmYFYEV0UCMxc8gSsjCGEhBXQkdQywVbo=
|
||||
github.com/aws/aws-sdk-go-v2 v1.37.2/go.mod h1:9Q0OoGQoboYIAJyslFyF1f5K1Ryddop8gqMhWx/n4Wg=
|
||||
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.0 h1:6GMWV6CNpA/6fbFHnoAjrv4+LGfyTqZz2LtCHnspgDg=
|
||||
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.0/go.mod h1:/mXlTIVG9jbxkqDnr5UQNQxW1HRYxeGklkM9vAFeabg=
|
||||
github.com/aws/aws-sdk-go-v2/credentials v1.17.11 h1:YuIB1dJNf1Re822rriUOTxopaHHvIq0l/pX3fwO+Tzs=
|
||||
github.com/aws/aws-sdk-go-v2/credentials v1.17.11/go.mod h1:AQtFPsDH9bI2O+71anW6EKL+NcD7LG3dpKGMV4SShgo=
|
||||
github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.5 h1:aw39xVGeRWlWx9EzGVnhOR4yOjQDHPQ6o6NmBlscyQg=
|
||||
github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.5/go.mod h1:FSaRudD0dXiMPK2UjknVwwTYyZMRsHv3TtkabsZih5I=
|
||||
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.5 h1:PG1F3OD1szkuQPzDw3CIQsRIrtTlUC3lP84taWzHlq0=
|
||||
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.5/go.mod h1:jU1li6RFryMz+so64PpKtudI+QzbKoIEivqdf6LNpOc=
|
||||
github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.7.4 h1:JgHnonzbnA3pbqj76wYsSZIZZQYBxkmMEjvL6GHy8XU=
|
||||
github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.7.4/go.mod h1:nZspkhg+9p8iApLFoyAqfyuMP0F38acy2Hm3r5r95Cg=
|
||||
github.com/aws/smithy-go v1.20.2 h1:tbp628ireGtzcHDDmLT/6ADHidqnwgF57XOXZe6tp4Q=
|
||||
github.com/aws/smithy-go v1.20.2/go.mod h1:krry+ya/rV9RDcV/Q16kpu6ypI4K2czasz0NC3qS14E=
|
||||
github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.2 h1:sPiRHLVUIIQcoVZTNwqQcdtjkqkPopyYmIX0M5ElRf4=
|
||||
github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.2/go.mod h1:ik86P3sgV+Bk7c1tBFCwI3VxMoSEwl4YkRB9xn1s340=
|
||||
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.2 h1:ZdzDAg075H6stMZtbD2o+PyB933M/f20e9WmCBC17wA=
|
||||
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.2/go.mod h1:eE1IIzXG9sdZCB0pNNpMpsYTLl4YdOQD3njiVN1e/E4=
|
||||
github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.33.0 h1:JzidOz4Hcn2RbP5fvIS1iAP+DcRv5VJtgixbEYDsI5g=
|
||||
github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.33.0/go.mod h1:9A4/PJYlWjvjEzzoOLGQjkLt4bYK9fRWi7uz1GSsAcA=
|
||||
github.com/aws/smithy-go v1.22.5 h1:P9ATCXPMb2mPjYBgueqJNCA5S9UfktsW0tTxi+a7eqw=
|
||||
github.com/aws/smithy-go v1.22.5/go.mod h1:t1ufH5HMublsJYulve2RKmHDC15xu1f26kHCp/HgceI=
|
||||
github.com/boombuler/barcode v1.0.1-0.20190219062509-6c824513bacc/go.mod h1:paBWMcWSl3LHKBqUq+rly7CNSldXjb2rDl3JlRe0mD8=
|
||||
github.com/boombuler/barcode v1.1.0 h1:ChaYjBR63fr4LFyGn8E8nt7dBSt3MiU3zMOZqFvVkHo=
|
||||
github.com/boombuler/barcode v1.1.0/go.mod h1:paBWMcWSl3LHKBqUq+rly7CNSldXjb2rDl3JlRe0mD8=
|
||||
github.com/bytedance/gopkg v0.0.0-20220118071334-3db87571198b h1:LTGVFpNmNHhj0vhOlfgWueFJ32eK9blaIlHR2ciXOT0=
|
||||
github.com/bytedance/gopkg v0.0.0-20220118071334-3db87571198b/go.mod h1:2ZlV9BaUH4+NXIBF0aMdKKAnHTzqH+iMU4KUjAbL23Q=
|
||||
github.com/bytedance/sonic v1.11.6 h1:oUp34TzMlL+OY1OUWxHqsdkgC/Zfc85zGqw9siXjrc0=
|
||||
@@ -169,6 +172,8 @@ github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
|
||||
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/pquerna/otp v1.5.0 h1:NMMR+WrmaqXU4EzdGJEE1aUUI0AMRzsp96fFFWNPwxs=
|
||||
github.com/pquerna/otp v1.5.0/go.mod h1:dkJfzwRKNiegxyNb54X/3fLwhCynbMspSyWKnvi1AEg=
|
||||
github.com/remyoudompheng/bigfft v0.0.0-20200410134404-eec4a21b6bb0/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
|
||||
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE=
|
||||
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
|
||||
|
||||
@@ -4,7 +4,10 @@ import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/constant"
|
||||
"one-api/model"
|
||||
"one-api/setting"
|
||||
"one-api/setting/ratio_setting"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
@@ -234,6 +237,16 @@ func TokenAuth() func(c *gin.Context) {
|
||||
abortWithOpenAiMessage(c, http.StatusUnauthorized, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
allowIpsMap := token.GetIpLimitsMap()
|
||||
if len(allowIpsMap) != 0 {
|
||||
clientIp := c.ClientIP()
|
||||
if _, ok := allowIpsMap[clientIp]; !ok {
|
||||
abortWithOpenAiMessage(c, http.StatusForbidden, "您的 IP 不在令牌允许访问的列表中")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
userCache, err := model.GetUserCache(token.UserId)
|
||||
if err != nil {
|
||||
abortWithOpenAiMessage(c, http.StatusInternalServerError, err.Error())
|
||||
@@ -247,6 +260,25 @@ func TokenAuth() func(c *gin.Context) {
|
||||
|
||||
userCache.WriteContext(c)
|
||||
|
||||
userGroup := userCache.Group
|
||||
tokenGroup := token.Group
|
||||
if tokenGroup != "" {
|
||||
// check common.UserUsableGroups[userGroup]
|
||||
if _, ok := setting.GetUserUsableGroups(userGroup)[tokenGroup]; !ok {
|
||||
abortWithOpenAiMessage(c, http.StatusForbidden, fmt.Sprintf("令牌分组 %s 已被禁用", tokenGroup))
|
||||
return
|
||||
}
|
||||
// check group in common.GroupRatio
|
||||
if !ratio_setting.ContainsGroupRatio(tokenGroup) {
|
||||
if tokenGroup != "auto" {
|
||||
abortWithOpenAiMessage(c, http.StatusForbidden, fmt.Sprintf("分组 %s 已被弃用", tokenGroup))
|
||||
return
|
||||
}
|
||||
}
|
||||
userGroup = tokenGroup
|
||||
}
|
||||
common.SetContextKey(c, constant.ContextKeyUsingGroup, userGroup)
|
||||
|
||||
err = SetupContextForToken(c, token, parts...)
|
||||
if err != nil {
|
||||
return
|
||||
@@ -273,7 +305,6 @@ func SetupContextForToken(c *gin.Context, token *model.Token, parts ...string) e
|
||||
} else {
|
||||
c.Set("token_model_limit_enabled", false)
|
||||
}
|
||||
c.Set("allow_ips", token.GetIpLimitsMap())
|
||||
c.Set("token_group", token.Group)
|
||||
if len(parts) > 1 {
|
||||
if model.IsAdmin(token.UserId) {
|
||||
|
||||
@@ -27,14 +27,6 @@ type ModelRequest struct {
|
||||
|
||||
func Distribute() func(c *gin.Context) {
|
||||
return func(c *gin.Context) {
|
||||
allowIpsMap := common.GetContextKeyStringMap(c, constant.ContextKeyTokenAllowIps)
|
||||
if len(allowIpsMap) != 0 {
|
||||
clientIp := c.ClientIP()
|
||||
if _, ok := allowIpsMap[clientIp]; !ok {
|
||||
abortWithOpenAiMessage(c, http.StatusForbidden, "您的 IP 不在令牌允许访问的列表中")
|
||||
return
|
||||
}
|
||||
}
|
||||
var channel *model.Channel
|
||||
channelId, ok := common.GetContextKey(c, constant.ContextKeyTokenSpecificChannelId)
|
||||
modelRequest, shouldSelectChannel, err := getModelRequest(c)
|
||||
@@ -42,24 +34,6 @@ func Distribute() func(c *gin.Context) {
|
||||
abortWithOpenAiMessage(c, http.StatusBadRequest, "Invalid request, "+err.Error())
|
||||
return
|
||||
}
|
||||
userGroup := common.GetContextKeyString(c, constant.ContextKeyUserGroup)
|
||||
tokenGroup := common.GetContextKeyString(c, constant.ContextKeyTokenGroup)
|
||||
if tokenGroup != "" {
|
||||
// check common.UserUsableGroups[userGroup]
|
||||
if _, ok := setting.GetUserUsableGroups(userGroup)[tokenGroup]; !ok {
|
||||
abortWithOpenAiMessage(c, http.StatusForbidden, fmt.Sprintf("令牌分组 %s 已被禁用", tokenGroup))
|
||||
return
|
||||
}
|
||||
// check group in common.GroupRatio
|
||||
if !ratio_setting.ContainsGroupRatio(tokenGroup) {
|
||||
if tokenGroup != "auto" {
|
||||
abortWithOpenAiMessage(c, http.StatusForbidden, fmt.Sprintf("分组 %s 已被弃用", tokenGroup))
|
||||
return
|
||||
}
|
||||
}
|
||||
userGroup = tokenGroup
|
||||
}
|
||||
common.SetContextKey(c, constant.ContextKeyUsingGroup, userGroup)
|
||||
if ok {
|
||||
id, err := strconv.Atoi(channelId.(string))
|
||||
if err != nil {
|
||||
@@ -81,22 +55,21 @@ func Distribute() func(c *gin.Context) {
|
||||
modelLimitEnable := common.GetContextKeyBool(c, constant.ContextKeyTokenModelLimitEnabled)
|
||||
if modelLimitEnable {
|
||||
s, ok := common.GetContextKey(c, constant.ContextKeyTokenModelLimit)
|
||||
var tokenModelLimit map[string]bool
|
||||
if ok {
|
||||
tokenModelLimit = s.(map[string]bool)
|
||||
} else {
|
||||
tokenModelLimit = map[string]bool{}
|
||||
}
|
||||
if tokenModelLimit != nil {
|
||||
if _, ok := tokenModelLimit[modelRequest.Model]; !ok {
|
||||
abortWithOpenAiMessage(c, http.StatusForbidden, "该令牌无权访问模型 "+modelRequest.Model)
|
||||
return
|
||||
}
|
||||
} else {
|
||||
if !ok {
|
||||
// token model limit is empty, all models are not allowed
|
||||
abortWithOpenAiMessage(c, http.StatusForbidden, "该令牌无权访问任何模型")
|
||||
return
|
||||
}
|
||||
var tokenModelLimit map[string]bool
|
||||
tokenModelLimit, ok = s.(map[string]bool)
|
||||
if !ok {
|
||||
tokenModelLimit = map[string]bool{}
|
||||
}
|
||||
matchName := ratio_setting.FormatMatchingModelName(modelRequest.Model) // match gpts & thinking-*
|
||||
if _, ok := tokenModelLimit[matchName]; !ok {
|
||||
abortWithOpenAiMessage(c, http.StatusForbidden, "该令牌无权访问模型 "+modelRequest.Model)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if shouldSelectChannel {
|
||||
@@ -105,6 +78,23 @@ func Distribute() func(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
var selectGroup string
|
||||
userGroup := common.GetContextKeyString(c, constant.ContextKeyUsingGroup)
|
||||
// check path is /pg/chat/completions
|
||||
if strings.HasPrefix(c.Request.URL.Path, "/pg/chat/completions") {
|
||||
playgroundRequest := &dto.PlayGroundRequest{}
|
||||
err = common.UnmarshalBodyReusable(c, playgroundRequest)
|
||||
if err != nil {
|
||||
abortWithOpenAiMessage(c, http.StatusBadRequest, "无效的请求, "+err.Error())
|
||||
return
|
||||
}
|
||||
if playgroundRequest.Group != "" {
|
||||
if !setting.GroupInUserUsableGroups(playgroundRequest.Group) && playgroundRequest.Group != userGroup {
|
||||
abortWithOpenAiMessage(c, http.StatusForbidden, "无权访问该分组")
|
||||
return
|
||||
}
|
||||
userGroup = playgroundRequest.Group
|
||||
}
|
||||
}
|
||||
channel, selectGroup, err = model.CacheGetRandomSatisfiedChannel(c, userGroup, modelRequest.Model, 0)
|
||||
if err != nil {
|
||||
showGroup := userGroup
|
||||
@@ -247,7 +237,7 @@ func getModelRequest(c *gin.Context) (*ModelRequest, bool, error) {
|
||||
func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, modelName string) *types.NewAPIError {
|
||||
c.Set("original_model", modelName) // for retry
|
||||
if channel == nil {
|
||||
return types.NewError(errors.New("channel is nil"), types.ErrorCodeGetChannelFailed)
|
||||
return types.NewError(errors.New("channel is nil"), types.ErrorCodeGetChannelFailed, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
common.SetContextKey(c, constant.ContextKeyChannelId, channel.Id)
|
||||
common.SetContextKey(c, constant.ContextKeyChannelName, channel.Name)
|
||||
@@ -269,11 +259,16 @@ func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, mode
|
||||
if channel.ChannelInfo.IsMultiKey {
|
||||
common.SetContextKey(c, constant.ContextKeyChannelIsMultiKey, true)
|
||||
common.SetContextKey(c, constant.ContextKeyChannelMultiKeyIndex, index)
|
||||
} else {
|
||||
// 必须设置为 false,否则在重试到单个 key 的时候会导致日志显示错误
|
||||
common.SetContextKey(c, constant.ContextKeyChannelIsMultiKey, false)
|
||||
}
|
||||
// c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", key))
|
||||
common.SetContextKey(c, constant.ContextKeyChannelKey, key)
|
||||
common.SetContextKey(c, constant.ContextKeyChannelBaseUrl, channel.GetBaseURL())
|
||||
|
||||
common.SetContextKey(c, constant.ContextKeySystemPromptOverride, false)
|
||||
|
||||
// TODO: api_version统一
|
||||
switch channel.Type {
|
||||
case constant.ChannelTypeAzure:
|
||||
|
||||
@@ -142,7 +142,7 @@ func GetRandomSatisfiedChannel(group string, model string, retry int) (*Channel,
|
||||
return &channel, err
|
||||
}
|
||||
|
||||
func (channel *Channel) AddAbilities() error {
|
||||
func (channel *Channel) AddAbilities(tx *gorm.DB) error {
|
||||
models_ := strings.Split(channel.Models, ",")
|
||||
groups_ := strings.Split(channel.Group, ",")
|
||||
abilitySet := make(map[string]struct{})
|
||||
@@ -169,8 +169,13 @@ func (channel *Channel) AddAbilities() error {
|
||||
if len(abilities) == 0 {
|
||||
return nil
|
||||
}
|
||||
// choose DB or provided tx
|
||||
useDB := DB
|
||||
if tx != nil {
|
||||
useDB = tx
|
||||
}
|
||||
for _, chunk := range lo.Chunk(abilities, 50) {
|
||||
err := DB.Clauses(clause.OnConflict{DoNothing: true}).Create(&chunk).Error
|
||||
err := useDB.Clauses(clause.OnConflict{DoNothing: true}).Create(&chunk).Error
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -284,6 +289,21 @@ func FixAbility() (int, int, error) {
|
||||
return 0, 0, errors.New("已经有一个修复任务在运行中,请稍后再试")
|
||||
}
|
||||
defer fixLock.Unlock()
|
||||
|
||||
// truncate abilities table
|
||||
if common.UsingSQLite {
|
||||
err := DB.Exec("DELETE FROM abilities").Error
|
||||
if err != nil {
|
||||
common.SysError(fmt.Sprintf("Delete abilities failed: %s", err.Error()))
|
||||
return 0, 0, err
|
||||
}
|
||||
} else {
|
||||
err := DB.Exec("TRUNCATE TABLE abilities").Error
|
||||
if err != nil {
|
||||
common.SysError(fmt.Sprintf("Truncate abilities failed: %s", err.Error()))
|
||||
return 0, 0, err
|
||||
}
|
||||
}
|
||||
var channels []*Channel
|
||||
// Find all channels
|
||||
err := DB.Model(&Channel{}).Find(&channels).Error
|
||||
@@ -306,7 +326,7 @@ func FixAbility() (int, int, error) {
|
||||
}
|
||||
// Then add new abilities
|
||||
for _, channel := range chunk {
|
||||
err = channel.AddAbilities()
|
||||
err = channel.AddAbilities(nil)
|
||||
if err != nil {
|
||||
common.SysError(fmt.Sprintf("Add abilities for channel %d failed: %s", channel.Id, err.Error()))
|
||||
failCount++
|
||||
|
||||
128
model/channel.go
128
model/channel.go
@@ -13,6 +13,7 @@ import (
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/samber/lo"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
@@ -41,19 +42,25 @@ type Channel struct {
|
||||
Priority *int64 `json:"priority" gorm:"bigint;default:0"`
|
||||
AutoBan *int `json:"auto_ban" gorm:"default:1"`
|
||||
OtherInfo string `json:"other_info"`
|
||||
Settings string `json:"settings"`
|
||||
Tag *string `json:"tag" gorm:"index"`
|
||||
Setting *string `json:"setting" gorm:"type:text"` // 渠道额外设置
|
||||
ParamOverride *string `json:"param_override" gorm:"type:text"`
|
||||
// add after v0.8.5
|
||||
ChannelInfo ChannelInfo `json:"channel_info" gorm:"type:json"`
|
||||
|
||||
// cache info
|
||||
Keys []string `json:"-" gorm:"-"`
|
||||
}
|
||||
|
||||
type ChannelInfo struct {
|
||||
IsMultiKey bool `json:"is_multi_key"` // 是否多Key模式
|
||||
MultiKeySize int `json:"multi_key_size"` // 多Key模式下的Key数量
|
||||
MultiKeyStatusList map[int]int `json:"multi_key_status_list"` // key状态列表,key index -> status
|
||||
MultiKeyPollingIndex int `json:"multi_key_polling_index"` // 多Key模式下轮询的key索引
|
||||
MultiKeyMode constant.MultiKeyMode `json:"multi_key_mode"`
|
||||
IsMultiKey bool `json:"is_multi_key"` // 是否多Key模式
|
||||
MultiKeySize int `json:"multi_key_size"` // 多Key模式下的Key数量
|
||||
MultiKeyStatusList map[int]int `json:"multi_key_status_list"` // key状态列表,key index -> status
|
||||
MultiKeyDisabledReason map[int]string `json:"multi_key_disabled_reason,omitempty"` // key禁用原因列表,key index -> reason
|
||||
MultiKeyDisabledTime map[int]int64 `json:"multi_key_disabled_time,omitempty"` // key禁用时间列表,key index -> time
|
||||
MultiKeyPollingIndex int `json:"multi_key_polling_index"` // 多Key模式下轮询的key索引
|
||||
MultiKeyMode constant.MultiKeyMode `json:"multi_key_mode"`
|
||||
}
|
||||
|
||||
// Value implements driver.Valuer interface
|
||||
@@ -67,15 +74,18 @@ func (c *ChannelInfo) Scan(value interface{}) error {
|
||||
return common.Unmarshal(bytesValue, c)
|
||||
}
|
||||
|
||||
func (channel *Channel) getKeys() []string {
|
||||
func (channel *Channel) GetKeys() []string {
|
||||
if channel.Key == "" {
|
||||
return []string{}
|
||||
}
|
||||
if len(channel.Keys) > 0 {
|
||||
return channel.Keys
|
||||
}
|
||||
trimmed := strings.TrimSpace(channel.Key)
|
||||
// If the key starts with '[', try to parse it as a JSON array (e.g., for Vertex AI scenarios)
|
||||
if strings.HasPrefix(trimmed, "[") {
|
||||
var arr []json.RawMessage
|
||||
if err := json.Unmarshal([]byte(trimmed), &arr); err == nil {
|
||||
if err := common.Unmarshal([]byte(trimmed), &arr); err == nil {
|
||||
res := make([]string, len(arr))
|
||||
for i, v := range arr {
|
||||
res[i] = string(v)
|
||||
@@ -95,7 +105,7 @@ func (channel *Channel) GetNextEnabledKey() (string, int, *types.NewAPIError) {
|
||||
}
|
||||
|
||||
// Obtain all keys (split by \n)
|
||||
keys := channel.getKeys()
|
||||
keys := channel.GetKeys()
|
||||
if len(keys) == 0 {
|
||||
// No keys available, return error, should disable the channel
|
||||
return "", 0, types.NewError(errors.New("no keys available"), types.ErrorCodeChannelNoAvailableKey)
|
||||
@@ -132,13 +142,13 @@ func (channel *Channel) GetNextEnabledKey() (string, int, *types.NewAPIError) {
|
||||
return keys[selectedIdx], selectedIdx, nil
|
||||
case constant.MultiKeyModePolling:
|
||||
// Use channel-specific lock to ensure thread-safe polling
|
||||
lock := getChannelPollingLock(channel.Id)
|
||||
lock := GetChannelPollingLock(channel.Id)
|
||||
lock.Lock()
|
||||
defer lock.Unlock()
|
||||
|
||||
channelInfo, err := CacheGetChannelInfo(channel.Id)
|
||||
if err != nil {
|
||||
return "", 0, types.NewError(err, types.ErrorCodeGetChannelFailed)
|
||||
return "", 0, types.NewError(err, types.ErrorCodeGetChannelFailed, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
//println("before polling index:", channel.ChannelInfo.MultiKeyPollingIndex)
|
||||
defer func() {
|
||||
@@ -197,7 +207,7 @@ func (channel *Channel) GetGroups() []string {
|
||||
func (channel *Channel) GetOtherInfo() map[string]interface{} {
|
||||
otherInfo := make(map[string]interface{})
|
||||
if channel.OtherInfo != "" {
|
||||
err := json.Unmarshal([]byte(channel.OtherInfo), &otherInfo)
|
||||
err := common.Unmarshal([]byte(channel.OtherInfo), &otherInfo)
|
||||
if err != nil {
|
||||
common.SysError("failed to unmarshal other info: " + err.Error())
|
||||
}
|
||||
@@ -328,38 +338,54 @@ func GetChannelById(id int, selectAll bool) (*Channel, error) {
|
||||
}
|
||||
|
||||
func BatchInsertChannels(channels []Channel) error {
|
||||
var err error
|
||||
err = DB.Create(&channels).Error
|
||||
if err != nil {
|
||||
return err
|
||||
if len(channels) == 0 {
|
||||
return nil
|
||||
}
|
||||
for _, channel_ := range channels {
|
||||
err = channel_.AddAbilities()
|
||||
if err != nil {
|
||||
tx := DB.Begin()
|
||||
if tx.Error != nil {
|
||||
return tx.Error
|
||||
}
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
tx.Rollback()
|
||||
}
|
||||
}()
|
||||
|
||||
for _, chunk := range lo.Chunk(channels, 50) {
|
||||
if err := tx.Create(&chunk).Error; err != nil {
|
||||
tx.Rollback()
|
||||
return err
|
||||
}
|
||||
for _, channel_ := range chunk {
|
||||
if err := channel_.AddAbilities(tx); err != nil {
|
||||
tx.Rollback()
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
return tx.Commit().Error
|
||||
}
|
||||
|
||||
func BatchDeleteChannels(ids []int) error {
|
||||
//使用事务 删除channel表和channel_ability表
|
||||
if len(ids) == 0 {
|
||||
return nil
|
||||
}
|
||||
// 使用事务 分批删除channel表和abilities表
|
||||
tx := DB.Begin()
|
||||
err := tx.Where("id in (?)", ids).Delete(&Channel{}).Error
|
||||
if err != nil {
|
||||
// 回滚事务
|
||||
tx.Rollback()
|
||||
return err
|
||||
if tx.Error != nil {
|
||||
return tx.Error
|
||||
}
|
||||
err = tx.Where("channel_id in (?)", ids).Delete(&Ability{}).Error
|
||||
if err != nil {
|
||||
// 回滚事务
|
||||
tx.Rollback()
|
||||
return err
|
||||
for _, chunk := range lo.Chunk(ids, 200) {
|
||||
if err := tx.Where("id in (?)", chunk).Delete(&Channel{}).Error; err != nil {
|
||||
tx.Rollback()
|
||||
return err
|
||||
}
|
||||
if err := tx.Where("channel_id in (?)", chunk).Delete(&Ability{}).Error; err != nil {
|
||||
tx.Rollback()
|
||||
return err
|
||||
}
|
||||
}
|
||||
// 提交事务
|
||||
tx.Commit()
|
||||
return err
|
||||
return tx.Commit().Error
|
||||
}
|
||||
|
||||
func (channel *Channel) GetPriority() int64 {
|
||||
@@ -403,7 +429,7 @@ func (channel *Channel) Insert() error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = channel.AddAbilities()
|
||||
err = channel.AddAbilities(nil)
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -425,7 +451,7 @@ func (channel *Channel) Update() error {
|
||||
trimmed := strings.TrimSpace(keyStr)
|
||||
if strings.HasPrefix(trimmed, "[") {
|
||||
var arr []json.RawMessage
|
||||
if err := json.Unmarshal([]byte(trimmed), &arr); err == nil {
|
||||
if err := common.Unmarshal([]byte(trimmed), &arr); err == nil {
|
||||
keys = make([]string, len(arr))
|
||||
for i, v := range arr {
|
||||
keys[i] = string(v)
|
||||
@@ -491,8 +517,8 @@ var channelStatusLock sync.Mutex
|
||||
// channelPollingLocks stores locks for each channel.id to ensure thread-safe polling
|
||||
var channelPollingLocks sync.Map
|
||||
|
||||
// getChannelPollingLock returns or creates a mutex for the given channel ID
|
||||
func getChannelPollingLock(channelId int) *sync.Mutex {
|
||||
// GetChannelPollingLock returns or creates a mutex for the given channel ID
|
||||
func GetChannelPollingLock(channelId int) *sync.Mutex {
|
||||
if lock, exists := channelPollingLocks.Load(channelId); exists {
|
||||
return lock.(*sync.Mutex)
|
||||
}
|
||||
@@ -522,8 +548,8 @@ func CleanupChannelPollingLocks() {
|
||||
})
|
||||
}
|
||||
|
||||
func handlerMultiKeyUpdate(channel *Channel, usingKey string, status int) {
|
||||
keys := channel.getKeys()
|
||||
func handlerMultiKeyUpdate(channel *Channel, usingKey string, status int, reason string) {
|
||||
keys := channel.GetKeys()
|
||||
if len(keys) == 0 {
|
||||
channel.Status = status
|
||||
} else {
|
||||
@@ -541,6 +567,14 @@ func handlerMultiKeyUpdate(channel *Channel, usingKey string, status int) {
|
||||
delete(channel.ChannelInfo.MultiKeyStatusList, keyIndex)
|
||||
} else {
|
||||
channel.ChannelInfo.MultiKeyStatusList[keyIndex] = status
|
||||
if channel.ChannelInfo.MultiKeyDisabledReason == nil {
|
||||
channel.ChannelInfo.MultiKeyDisabledReason = make(map[int]string)
|
||||
}
|
||||
if channel.ChannelInfo.MultiKeyDisabledTime == nil {
|
||||
channel.ChannelInfo.MultiKeyDisabledTime = make(map[int]int64)
|
||||
}
|
||||
channel.ChannelInfo.MultiKeyDisabledReason[keyIndex] = reason
|
||||
channel.ChannelInfo.MultiKeyDisabledTime[keyIndex] = common.GetTimestamp()
|
||||
}
|
||||
if len(channel.ChannelInfo.MultiKeyStatusList) >= channel.ChannelInfo.MultiKeySize {
|
||||
channel.Status = common.ChannelStatusAutoDisabled
|
||||
@@ -563,7 +597,7 @@ func UpdateChannelStatus(channelId int, usingKey string, status int, reason stri
|
||||
}
|
||||
if channelCache.ChannelInfo.IsMultiKey {
|
||||
// 如果是多Key模式,更新缓存中的状态
|
||||
handlerMultiKeyUpdate(channelCache, usingKey, status)
|
||||
handlerMultiKeyUpdate(channelCache, usingKey, status, reason)
|
||||
//CacheUpdateChannel(channelCache)
|
||||
//return true
|
||||
} else {
|
||||
@@ -571,10 +605,6 @@ func UpdateChannelStatus(channelId int, usingKey string, status int, reason stri
|
||||
if channelCache.Status == status {
|
||||
return false
|
||||
}
|
||||
// 如果缓存渠道不存在(说明已经被禁用),且要设置的状态不为启用,直接返回
|
||||
if status != common.ChannelStatusEnabled {
|
||||
return false
|
||||
}
|
||||
CacheUpdateChannelStatus(channelId, status)
|
||||
}
|
||||
}
|
||||
@@ -598,7 +628,7 @@ func UpdateChannelStatus(channelId int, usingKey string, status int, reason stri
|
||||
|
||||
if channel.ChannelInfo.IsMultiKey {
|
||||
beforeStatus := channel.Status
|
||||
handlerMultiKeyUpdate(channel, usingKey, status)
|
||||
handlerMultiKeyUpdate(channel, usingKey, status, reason)
|
||||
if beforeStatus != channel.Status {
|
||||
shouldUpdateAbilities = true
|
||||
}
|
||||
@@ -778,7 +808,7 @@ func SearchTags(keyword string, group string, model string, idSort bool) ([]*str
|
||||
func (channel *Channel) ValidateSettings() error {
|
||||
channelParams := &dto.ChannelSettings{}
|
||||
if channel.Setting != nil && *channel.Setting != "" {
|
||||
err := json.Unmarshal([]byte(*channel.Setting), channelParams)
|
||||
err := common.Unmarshal([]byte(*channel.Setting), channelParams)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -789,7 +819,7 @@ func (channel *Channel) ValidateSettings() error {
|
||||
func (channel *Channel) GetSetting() dto.ChannelSettings {
|
||||
setting := dto.ChannelSettings{}
|
||||
if channel.Setting != nil && *channel.Setting != "" {
|
||||
err := json.Unmarshal([]byte(*channel.Setting), &setting)
|
||||
err := common.Unmarshal([]byte(*channel.Setting), &setting)
|
||||
if err != nil {
|
||||
common.SysError("failed to unmarshal setting: " + err.Error())
|
||||
channel.Setting = nil // 清空设置以避免后续错误
|
||||
@@ -800,7 +830,7 @@ func (channel *Channel) GetSetting() dto.ChannelSettings {
|
||||
}
|
||||
|
||||
func (channel *Channel) SetSetting(setting dto.ChannelSettings) {
|
||||
settingBytes, err := json.Marshal(setting)
|
||||
settingBytes, err := common.Marshal(setting)
|
||||
if err != nil {
|
||||
common.SysError("failed to marshal setting: " + err.Error())
|
||||
return
|
||||
@@ -811,7 +841,7 @@ func (channel *Channel) SetSetting(setting dto.ChannelSettings) {
|
||||
func (channel *Channel) GetParamOverride() map[string]interface{} {
|
||||
paramOverride := make(map[string]interface{})
|
||||
if channel.ParamOverride != nil && *channel.ParamOverride != "" {
|
||||
err := json.Unmarshal([]byte(*channel.ParamOverride), ¶mOverride)
|
||||
err := common.Unmarshal([]byte(*channel.ParamOverride), ¶mOverride)
|
||||
if err != nil {
|
||||
common.SysError("failed to unmarshal param override: " + err.Error())
|
||||
}
|
||||
|
||||
@@ -5,7 +5,9 @@ import (
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"one-api/common"
|
||||
"one-api/constant"
|
||||
"one-api/setting"
|
||||
"one-api/setting/ratio_setting"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
@@ -66,6 +68,20 @@ func InitChannelCache() {
|
||||
|
||||
channelSyncLock.Lock()
|
||||
group2model2channels = newGroup2model2channels
|
||||
//channelsIDM = newChannelId2channel
|
||||
for i, channel := range newChannelId2channel {
|
||||
if channel.ChannelInfo.IsMultiKey {
|
||||
channel.Keys = channel.GetKeys()
|
||||
if channel.ChannelInfo.MultiKeyMode == constant.MultiKeyModePolling {
|
||||
if oldChannel, ok := channelsIDM[i]; ok {
|
||||
// 存在旧的渠道,如果是多key且轮询,保留轮询索引信息
|
||||
if oldChannel.ChannelInfo.IsMultiKey && oldChannel.ChannelInfo.MultiKeyMode == constant.MultiKeyModePolling {
|
||||
channel.ChannelInfo.MultiKeyPollingIndex = oldChannel.ChannelInfo.MultiKeyPollingIndex
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
channelsIDM = newChannelId2channel
|
||||
channelSyncLock.Unlock()
|
||||
common.SysLog("channels synced from database")
|
||||
@@ -113,13 +129,6 @@ func CacheGetRandomSatisfiedChannel(c *gin.Context, group string, model string,
|
||||
}
|
||||
|
||||
func getRandomSatisfiedChannel(group string, model string, retry int) (*Channel, error) {
|
||||
if strings.HasPrefix(model, "gpt-4-gizmo") {
|
||||
model = "gpt-4-gizmo-*"
|
||||
}
|
||||
if strings.HasPrefix(model, "gpt-4o-gizmo") {
|
||||
model = "gpt-4o-gizmo-*"
|
||||
}
|
||||
|
||||
// if memory cache is disabled, get channel directly from database
|
||||
if !common.MemoryCacheEnabled {
|
||||
return GetRandomSatisfiedChannel(group, model, retry)
|
||||
@@ -127,8 +136,16 @@ func getRandomSatisfiedChannel(group string, model string, retry int) (*Channel,
|
||||
|
||||
channelSyncLock.RLock()
|
||||
defer channelSyncLock.RUnlock()
|
||||
|
||||
// First, try to find channels with the exact model name.
|
||||
channels := group2model2channels[group][model]
|
||||
|
||||
// If no channels found, try to find channels with the normalized model name.
|
||||
if len(channels) == 0 {
|
||||
normalizedModel := ratio_setting.FormatMatchingModelName(model)
|
||||
channels = group2model2channels[group][normalizedModel]
|
||||
}
|
||||
|
||||
if len(channels) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
@@ -203,9 +220,6 @@ func CacheGetChannel(id int) (*Channel, error) {
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("渠道# %d,已不存在", id)
|
||||
}
|
||||
if c.Status != common.ChannelStatusEnabled {
|
||||
return nil, fmt.Errorf("渠道# %d,已被禁用", id)
|
||||
}
|
||||
return c, nil
|
||||
}
|
||||
|
||||
@@ -224,9 +238,6 @@ func CacheGetChannelInfo(id int) (*ChannelInfo, error) {
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("渠道# %d,已不存在", id)
|
||||
}
|
||||
if c.Status != common.ChannelStatusEnabled {
|
||||
return nil, fmt.Errorf("渠道# %d,已被禁用", id)
|
||||
}
|
||||
return &c.ChannelInfo, nil
|
||||
}
|
||||
|
||||
@@ -239,6 +250,20 @@ func CacheUpdateChannelStatus(id int, status int) {
|
||||
if channel, ok := channelsIDM[id]; ok {
|
||||
channel.Status = status
|
||||
}
|
||||
if status != common.ChannelStatusEnabled {
|
||||
// delete the channel from group2model2channels
|
||||
for group, model2channels := range group2model2channels {
|
||||
for model, channels := range model2channels {
|
||||
for i, channelId := range channels {
|
||||
if channelId == id {
|
||||
// remove the channel from the slice
|
||||
group2model2channels[group][model] = append(channels[:i], channels[i+1:]...)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func CacheUpdateChannel(channel *Channel) {
|
||||
|
||||
@@ -64,6 +64,22 @@ var DB *gorm.DB
|
||||
|
||||
var LOG_DB *gorm.DB
|
||||
|
||||
// dropIndexIfExists drops a MySQL index only if it exists to avoid noisy 1091 errors
|
||||
func dropIndexIfExists(tableName string, indexName string) {
|
||||
if !common.UsingMySQL {
|
||||
return
|
||||
}
|
||||
var count int64
|
||||
// Check index existence via information_schema
|
||||
err := DB.Raw(
|
||||
"SELECT COUNT(1) FROM information_schema.statistics WHERE table_schema = DATABASE() AND table_name = ? AND index_name = ?",
|
||||
tableName, indexName,
|
||||
).Scan(&count).Error
|
||||
if err == nil && count > 0 {
|
||||
_ = DB.Exec("ALTER TABLE " + tableName + " DROP INDEX " + indexName + ";").Error
|
||||
}
|
||||
}
|
||||
|
||||
func createRootAccountIfNeed() error {
|
||||
var user User
|
||||
//if user.Status != common.UserStatusEnabled {
|
||||
@@ -235,6 +251,9 @@ func InitLogDB() (err error) {
|
||||
}
|
||||
|
||||
func migrateDB() error {
|
||||
// 修复旧版本留下的唯一索引,允许软删除后重新插入同名记录
|
||||
dropIndexIfExists("models", "uk_model_name")
|
||||
dropIndexIfExists("vendors", "uk_vendor_name")
|
||||
if !common.UsingPostgreSQL {
|
||||
return migrateDBFast()
|
||||
}
|
||||
@@ -250,7 +269,12 @@ func migrateDB() error {
|
||||
&TopUp{},
|
||||
&QuotaData{},
|
||||
&Task{},
|
||||
&Model{},
|
||||
&Vendor{},
|
||||
&PrefillGroup{},
|
||||
&Setup{},
|
||||
&TwoFA{},
|
||||
&TwoFABackupCode{},
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -259,6 +283,10 @@ func migrateDB() error {
|
||||
}
|
||||
|
||||
func migrateDBFast() error {
|
||||
// 修复旧版本留下的唯一索引,允许软删除后重新插入同名记录
|
||||
dropIndexIfExists("models", "uk_model_name")
|
||||
dropIndexIfExists("vendors", "uk_vendor_name")
|
||||
|
||||
var wg sync.WaitGroup
|
||||
|
||||
migrations := []struct {
|
||||
@@ -276,7 +304,12 @@ func migrateDBFast() error {
|
||||
{&TopUp{}, "TopUp"},
|
||||
{&QuotaData{}, "QuotaData"},
|
||||
{&Task{}, "Task"},
|
||||
{&Model{}, "Model"},
|
||||
{&Vendor{}, "Vendor"},
|
||||
{&PrefillGroup{}, "PrefillGroup"},
|
||||
{&Setup{}, "Setup"},
|
||||
{&TwoFA{}, "TwoFA"},
|
||||
{&TwoFABackupCode{}, "TwoFABackupCode"},
|
||||
}
|
||||
// 动态计算migration数量,确保errChan缓冲区足够大
|
||||
errChan := make(chan error, len(migrations))
|
||||
|
||||
30
model/missing_models.go
Normal file
30
model/missing_models.go
Normal file
@@ -0,0 +1,30 @@
|
||||
package model
|
||||
|
||||
// GetMissingModels returns model names that are referenced in the system
|
||||
func GetMissingModels() ([]string, error) {
|
||||
// 1. 获取所有已启用模型(去重)
|
||||
models := GetEnabledModels()
|
||||
if len(models) == 0 {
|
||||
return []string{}, nil
|
||||
}
|
||||
|
||||
// 2. 查询已有的元数据模型名
|
||||
var existing []string
|
||||
if err := DB.Model(&Model{}).Where("model_name IN ?", models).Pluck("model_name", &existing).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
existingSet := make(map[string]struct{}, len(existing))
|
||||
for _, e := range existing {
|
||||
existingSet[e] = struct{}{}
|
||||
}
|
||||
|
||||
// 3. 收集缺失模型
|
||||
var missing []string
|
||||
for _, name := range models {
|
||||
if _, ok := existingSet[name]; !ok {
|
||||
missing = append(missing, name)
|
||||
}
|
||||
}
|
||||
return missing, nil
|
||||
}
|
||||
34
model/model_extra.go
Normal file
34
model/model_extra.go
Normal file
@@ -0,0 +1,34 @@
|
||||
package model
|
||||
|
||||
// GetModelEnableGroups 返回指定模型名称可用的用户分组列表。
|
||||
// 使用在 updatePricing() 中维护的缓存映射,O(1) 读取,适合高并发场景。
|
||||
func GetModelEnableGroups(modelName string) []string {
|
||||
// 确保缓存最新
|
||||
GetPricing()
|
||||
|
||||
if modelName == "" {
|
||||
return make([]string, 0)
|
||||
}
|
||||
|
||||
modelEnableGroupsLock.RLock()
|
||||
groups, ok := modelEnableGroups[modelName]
|
||||
modelEnableGroupsLock.RUnlock()
|
||||
if !ok {
|
||||
return make([]string, 0)
|
||||
}
|
||||
return groups
|
||||
}
|
||||
|
||||
// GetModelQuotaType 返回指定模型的计费类型(quota_type)。
|
||||
// 同样使用缓存映射,避免每次遍历定价切片。
|
||||
func GetModelQuotaType(modelName string) int {
|
||||
GetPricing()
|
||||
|
||||
modelEnableGroupsLock.RLock()
|
||||
quota, ok := modelQuotaTypeMap[modelName]
|
||||
modelEnableGroupsLock.RUnlock()
|
||||
if !ok {
|
||||
return 0
|
||||
}
|
||||
return quota
|
||||
}
|
||||
208
model/model_meta.go
Normal file
208
model/model_meta.go
Normal file
@@ -0,0 +1,208 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"one-api/common"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// Model 用于存储模型的元数据,例如描述、标签等
|
||||
// ModelName 字段具有唯一性约束,确保每个模型只会出现一次
|
||||
// Tags 字段使用逗号分隔的字符串保存标签集合,后期可根据需要扩展为 JSON 类型
|
||||
// Status: 1 表示启用,0 表示禁用,保留以便后续功能扩展
|
||||
// CreatedTime 和 UpdatedTime 使用 Unix 时间戳(秒)保存方便跨数据库移植
|
||||
// DeletedAt 采用 GORM 的软删除特性,便于后续数据恢复
|
||||
//
|
||||
// 该表设计遵循第三范式(3NF):
|
||||
// 1. 每一列都与主键(Id 或 ModelName)直接相关
|
||||
// 2. 不存在部分依赖(ModelName 是唯一键)
|
||||
// 3. 不存在传递依赖(描述、标签等都依赖于 ModelName,而非依赖于其他非主键列)
|
||||
// 这样既保证了数据一致性,也方便后期扩展
|
||||
|
||||
// 模型名称匹配规则
|
||||
const (
|
||||
NameRuleExact = iota // 0 精确匹配
|
||||
NameRulePrefix // 1 前缀匹配
|
||||
NameRuleContains // 2 包含匹配
|
||||
NameRuleSuffix // 3 后缀匹配
|
||||
)
|
||||
|
||||
type BoundChannel struct {
|
||||
Name string `json:"name"`
|
||||
Type int `json:"type"`
|
||||
}
|
||||
|
||||
type Model struct {
|
||||
Id int `json:"id"`
|
||||
ModelName string `json:"model_name" gorm:"size:128;not null;uniqueIndex:uk_model_name,priority:1"`
|
||||
Description string `json:"description,omitempty" gorm:"type:text"`
|
||||
Icon string `json:"icon,omitempty" gorm:"type:varchar(128)"`
|
||||
Tags string `json:"tags,omitempty" gorm:"type:varchar(255)"`
|
||||
VendorID int `json:"vendor_id,omitempty" gorm:"index"`
|
||||
Endpoints string `json:"endpoints,omitempty" gorm:"type:text"`
|
||||
Status int `json:"status" gorm:"default:1"`
|
||||
CreatedTime int64 `json:"created_time" gorm:"bigint"`
|
||||
UpdatedTime int64 `json:"updated_time" gorm:"bigint"`
|
||||
DeletedAt gorm.DeletedAt `json:"-" gorm:"index;uniqueIndex:uk_model_name,priority:2"`
|
||||
|
||||
BoundChannels []BoundChannel `json:"bound_channels,omitempty" gorm:"-"`
|
||||
EnableGroups []string `json:"enable_groups,omitempty" gorm:"-"`
|
||||
QuotaType int `json:"quota_type" gorm:"-"`
|
||||
NameRule int `json:"name_rule" gorm:"default:0"`
|
||||
}
|
||||
|
||||
// Insert 创建新的模型元数据记录
|
||||
func (mi *Model) Insert() error {
|
||||
now := common.GetTimestamp()
|
||||
mi.CreatedTime = now
|
||||
mi.UpdatedTime = now
|
||||
return DB.Create(mi).Error
|
||||
}
|
||||
|
||||
// IsModelNameDuplicated 检查模型名称是否重复(排除自身 ID)
|
||||
func IsModelNameDuplicated(id int, name string) (bool, error) {
|
||||
if name == "" {
|
||||
return false, nil
|
||||
}
|
||||
var cnt int64
|
||||
err := DB.Model(&Model{}).Where("model_name = ? AND id <> ?", name, id).Count(&cnt).Error
|
||||
return cnt > 0, err
|
||||
}
|
||||
|
||||
// Update 更新现有模型记录
|
||||
func (mi *Model) Update() error {
|
||||
mi.UpdatedTime = common.GetTimestamp()
|
||||
// 使用 Session 配置并选择所有字段,允许零值(如空字符串)也能被更新
|
||||
return DB.Session(&gorm.Session{AllowGlobalUpdate: false, FullSaveAssociations: false}).
|
||||
Model(&Model{}).
|
||||
Where("id = ?", mi.Id).
|
||||
Omit("created_time").
|
||||
Select("*").
|
||||
Updates(mi).Error
|
||||
}
|
||||
|
||||
// Delete 软删除模型记录
|
||||
func (mi *Model) Delete() error {
|
||||
return DB.Delete(mi).Error
|
||||
}
|
||||
|
||||
// GetModelByName 根据模型名称查询元数据
|
||||
func GetModelByName(name string) (*Model, error) {
|
||||
var mi Model
|
||||
err := DB.Where("model_name = ?", name).First(&mi).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &mi, nil
|
||||
}
|
||||
|
||||
// GetVendorModelCounts 统计每个供应商下模型数量(不受分页影响)
|
||||
func GetVendorModelCounts() (map[int64]int64, error) {
|
||||
var stats []struct {
|
||||
VendorID int64
|
||||
Count int64
|
||||
}
|
||||
if err := DB.Model(&Model{}).
|
||||
Select("vendor_id as vendor_id, count(*) as count").
|
||||
Group("vendor_id").
|
||||
Scan(&stats).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
m := make(map[int64]int64, len(stats))
|
||||
for _, s := range stats {
|
||||
m[s.VendorID] = s.Count
|
||||
}
|
||||
return m, nil
|
||||
}
|
||||
|
||||
// GetAllModels 分页获取所有模型元数据
|
||||
func GetAllModels(offset int, limit int) ([]*Model, error) {
|
||||
var models []*Model
|
||||
err := DB.Offset(offset).Limit(limit).Find(&models).Error
|
||||
return models, err
|
||||
}
|
||||
|
||||
// GetBoundChannels 查询支持该模型的渠道(名称+类型)
|
||||
func GetBoundChannels(modelName string) ([]BoundChannel, error) {
|
||||
var channels []BoundChannel
|
||||
err := DB.Table("channels").
|
||||
Select("channels.name, channels.type").
|
||||
Joins("join abilities on abilities.channel_id = channels.id").
|
||||
Where("abilities.model = ? AND abilities.enabled = ?", modelName, true).
|
||||
Group("channels.id").
|
||||
Scan(&channels).Error
|
||||
return channels, err
|
||||
}
|
||||
|
||||
// FindModelByNameWithRule 根据模型名称和匹配规则查找模型元数据,优先级:精确 > 前缀 > 后缀 > 包含
|
||||
func FindModelByNameWithRule(name string) (*Model, error) {
|
||||
// 1. 精确匹配
|
||||
if m, err := GetModelByName(name); err == nil {
|
||||
return m, nil
|
||||
}
|
||||
// 2. 规则匹配
|
||||
var models []*Model
|
||||
if err := DB.Where("name_rule <> ?", NameRuleExact).Find(&models).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var prefixMatch, suffixMatch, containsMatch *Model
|
||||
for _, m := range models {
|
||||
switch m.NameRule {
|
||||
case NameRulePrefix:
|
||||
if strings.HasPrefix(name, m.ModelName) {
|
||||
if prefixMatch == nil || len(m.ModelName) > len(prefixMatch.ModelName) {
|
||||
prefixMatch = m
|
||||
}
|
||||
}
|
||||
case NameRuleSuffix:
|
||||
if strings.HasSuffix(name, m.ModelName) {
|
||||
if suffixMatch == nil || len(m.ModelName) > len(suffixMatch.ModelName) {
|
||||
suffixMatch = m
|
||||
}
|
||||
}
|
||||
case NameRuleContains:
|
||||
if strings.Contains(name, m.ModelName) {
|
||||
if containsMatch == nil || len(m.ModelName) > len(containsMatch.ModelName) {
|
||||
containsMatch = m
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if prefixMatch != nil {
|
||||
return prefixMatch, nil
|
||||
}
|
||||
if suffixMatch != nil {
|
||||
return suffixMatch, nil
|
||||
}
|
||||
if containsMatch != nil {
|
||||
return containsMatch, nil
|
||||
}
|
||||
return nil, gorm.ErrRecordNotFound
|
||||
}
|
||||
|
||||
// SearchModels 根据关键词和供应商搜索模型,支持分页
|
||||
func SearchModels(keyword string, vendor string, offset int, limit int) ([]*Model, int64, error) {
|
||||
var models []*Model
|
||||
db := DB.Model(&Model{})
|
||||
if keyword != "" {
|
||||
like := "%" + keyword + "%"
|
||||
db = db.Where("model_name LIKE ? OR description LIKE ? OR tags LIKE ?", like, like, like)
|
||||
}
|
||||
if vendor != "" {
|
||||
// 如果是数字,按供应商 ID 精确匹配;否则按名称模糊匹配
|
||||
if vid, err := strconv.Atoi(vendor); err == nil {
|
||||
db = db.Where("models.vendor_id = ?", vid)
|
||||
} else {
|
||||
db = db.Joins("JOIN vendors ON vendors.id = models.vendor_id").Where("vendors.name LIKE ?", "%"+vendor+"%")
|
||||
}
|
||||
}
|
||||
var total int64
|
||||
err := db.Count(&total).Error
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
err = db.Offset(offset).Limit(limit).Order("models.id DESC").Find(&models).Error
|
||||
return models, total, err
|
||||
}
|
||||
@@ -336,6 +336,8 @@ func updateOptionMap(key string, value string) (err error) {
|
||||
common.LinuxDOClientId = value
|
||||
case "LinuxDOClientSecret":
|
||||
common.LinuxDOClientSecret = value
|
||||
case "LinuxDOMinimumTrustLevel":
|
||||
common.LinuxDOMinimumTrustLevel, _ = strconv.Atoi(value)
|
||||
case "Footer":
|
||||
common.Footer = value
|
||||
case "SystemName":
|
||||
|
||||
126
model/prefill_group.go
Normal file
126
model/prefill_group.go
Normal file
@@ -0,0 +1,126 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"database/sql/driver"
|
||||
"encoding/json"
|
||||
"one-api/common"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// PrefillGroup 用于存储可复用的“组”信息,例如模型组、标签组、端点组等。
|
||||
// Name 字段保持唯一,用于在前端下拉框中展示。
|
||||
// Type 字段用于区分组的类别,可选值如:model、tag、endpoint。
|
||||
// Items 字段使用 JSON 数组保存对应类型的字符串集合,示例:
|
||||
// ["gpt-4o", "gpt-3.5-turbo"]
|
||||
// 设计遵循 3NF,避免冗余,提供灵活扩展能力。
|
||||
|
||||
// JSONValue 基于 json.RawMessage 实现,支持从数据库的 []byte 和 string 两种类型读取
|
||||
type JSONValue json.RawMessage
|
||||
|
||||
// Value 实现 driver.Valuer 接口,用于数据库写入
|
||||
func (j JSONValue) Value() (driver.Value, error) {
|
||||
if j == nil {
|
||||
return nil, nil
|
||||
}
|
||||
return []byte(j), nil
|
||||
}
|
||||
|
||||
// Scan 实现 sql.Scanner 接口,兼容不同驱动返回的类型
|
||||
func (j *JSONValue) Scan(value interface{}) error {
|
||||
switch v := value.(type) {
|
||||
case nil:
|
||||
*j = nil
|
||||
return nil
|
||||
case []byte:
|
||||
// 拷贝底层字节,避免保留底层缓冲区
|
||||
b := make([]byte, len(v))
|
||||
copy(b, v)
|
||||
*j = JSONValue(b)
|
||||
return nil
|
||||
case string:
|
||||
*j = JSONValue([]byte(v))
|
||||
return nil
|
||||
default:
|
||||
// 其他类型尝试序列化为 JSON
|
||||
b, err := json.Marshal(v)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
*j = JSONValue(b)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// MarshalJSON 确保在对外编码时与 json.RawMessage 行为一致
|
||||
func (j JSONValue) MarshalJSON() ([]byte, error) {
|
||||
if j == nil {
|
||||
return []byte("null"), nil
|
||||
}
|
||||
return j, nil
|
||||
}
|
||||
|
||||
// UnmarshalJSON 确保在对外解码时与 json.RawMessage 行为一致
|
||||
func (j *JSONValue) UnmarshalJSON(data []byte) error {
|
||||
if data == nil {
|
||||
*j = nil
|
||||
return nil
|
||||
}
|
||||
b := make([]byte, len(data))
|
||||
copy(b, data)
|
||||
*j = JSONValue(b)
|
||||
return nil
|
||||
}
|
||||
|
||||
type PrefillGroup struct {
|
||||
Id int `json:"id"`
|
||||
Name string `json:"name" gorm:"size:64;not null;uniqueIndex:uk_prefill_name,where:deleted_at IS NULL"`
|
||||
Type string `json:"type" gorm:"size:32;index;not null"`
|
||||
Items JSONValue `json:"items" gorm:"type:json"`
|
||||
Description string `json:"description,omitempty" gorm:"type:varchar(255)"`
|
||||
CreatedTime int64 `json:"created_time" gorm:"bigint"`
|
||||
UpdatedTime int64 `json:"updated_time" gorm:"bigint"`
|
||||
DeletedAt gorm.DeletedAt `json:"-" gorm:"index"`
|
||||
}
|
||||
|
||||
// Insert 新建组
|
||||
func (g *PrefillGroup) Insert() error {
|
||||
now := common.GetTimestamp()
|
||||
g.CreatedTime = now
|
||||
g.UpdatedTime = now
|
||||
return DB.Create(g).Error
|
||||
}
|
||||
|
||||
// IsPrefillGroupNameDuplicated 检查组名称是否重复(排除自身 ID)
|
||||
func IsPrefillGroupNameDuplicated(id int, name string) (bool, error) {
|
||||
if name == "" {
|
||||
return false, nil
|
||||
}
|
||||
var cnt int64
|
||||
err := DB.Model(&PrefillGroup{}).Where("name = ? AND id <> ?", name, id).Count(&cnt).Error
|
||||
return cnt > 0, err
|
||||
}
|
||||
|
||||
// Update 更新组
|
||||
func (g *PrefillGroup) Update() error {
|
||||
g.UpdatedTime = common.GetTimestamp()
|
||||
return DB.Save(g).Error
|
||||
}
|
||||
|
||||
// DeleteByID 根据 ID 删除组
|
||||
func DeletePrefillGroupByID(id int) error {
|
||||
return DB.Delete(&PrefillGroup{}, id).Error
|
||||
}
|
||||
|
||||
// GetAllPrefillGroups 获取全部组,可按类型过滤(为空则返回全部)
|
||||
func GetAllPrefillGroups(groupType string) ([]*PrefillGroup, error) {
|
||||
var groups []*PrefillGroup
|
||||
query := DB.Model(&PrefillGroup{})
|
||||
if groupType != "" {
|
||||
query = query.Where("type = ?", groupType)
|
||||
}
|
||||
if err := query.Order("updated_time DESC").Find(&groups).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return groups, nil
|
||||
}
|
||||
198
model/pricing.go
198
model/pricing.go
@@ -1,7 +1,10 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"one-api/common"
|
||||
"one-api/constant"
|
||||
"one-api/setting/ratio_setting"
|
||||
@@ -12,6 +15,10 @@ import (
|
||||
|
||||
type Pricing struct {
|
||||
ModelName string `json:"model_name"`
|
||||
Description string `json:"description,omitempty"`
|
||||
Icon string `json:"icon,omitempty"`
|
||||
Tags string `json:"tags,omitempty"`
|
||||
VendorID int `json:"vendor_id,omitempty"`
|
||||
QuotaType int `json:"quota_type"`
|
||||
ModelRatio float64 `json:"model_ratio"`
|
||||
ModelPrice float64 `json:"model_price"`
|
||||
@@ -21,10 +28,24 @@ type Pricing struct {
|
||||
SupportedEndpointTypes []constant.EndpointType `json:"supported_endpoint_types"`
|
||||
}
|
||||
|
||||
type PricingVendor struct {
|
||||
ID int `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description,omitempty"`
|
||||
Icon string `json:"icon,omitempty"`
|
||||
}
|
||||
|
||||
var (
|
||||
pricingMap []Pricing
|
||||
lastGetPricingTime time.Time
|
||||
updatePricingLock sync.Mutex
|
||||
pricingMap []Pricing
|
||||
vendorsList []PricingVendor
|
||||
supportedEndpointMap map[string]common.EndpointInfo
|
||||
lastGetPricingTime time.Time
|
||||
updatePricingLock sync.Mutex
|
||||
|
||||
// 缓存映射:模型名 -> 启用分组 / 计费类型
|
||||
modelEnableGroups = make(map[string][]string)
|
||||
modelQuotaTypeMap = make(map[string]int)
|
||||
modelEnableGroupsLock = sync.RWMutex{}
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -46,6 +67,15 @@ func GetPricing() []Pricing {
|
||||
return pricingMap
|
||||
}
|
||||
|
||||
// GetVendors 返回当前定价接口使用到的供应商信息
|
||||
func GetVendors() []PricingVendor {
|
||||
if time.Since(lastGetPricingTime) > time.Minute*1 || len(pricingMap) == 0 {
|
||||
// 保证先刷新一次
|
||||
GetPricing()
|
||||
}
|
||||
return vendorsList
|
||||
}
|
||||
|
||||
func GetModelSupportEndpointTypes(model string) []constant.EndpointType {
|
||||
if model == "" {
|
||||
return make([]constant.EndpointType, 0)
|
||||
@@ -65,6 +95,77 @@ func updatePricing() {
|
||||
common.SysError(fmt.Sprintf("GetAllEnableAbilityWithChannels error: %v", err))
|
||||
return
|
||||
}
|
||||
// 预加载模型元数据与供应商一次,避免循环查询
|
||||
var allMeta []Model
|
||||
_ = DB.Find(&allMeta).Error
|
||||
metaMap := make(map[string]*Model)
|
||||
prefixList := make([]*Model, 0)
|
||||
suffixList := make([]*Model, 0)
|
||||
containsList := make([]*Model, 0)
|
||||
for i := range allMeta {
|
||||
m := &allMeta[i]
|
||||
if m.NameRule == NameRuleExact {
|
||||
metaMap[m.ModelName] = m
|
||||
} else {
|
||||
switch m.NameRule {
|
||||
case NameRulePrefix:
|
||||
prefixList = append(prefixList, m)
|
||||
case NameRuleSuffix:
|
||||
suffixList = append(suffixList, m)
|
||||
case NameRuleContains:
|
||||
containsList = append(containsList, m)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 将非精确规则模型匹配到 metaMap
|
||||
for _, m := range prefixList {
|
||||
for _, pricingModel := range enableAbilities {
|
||||
if strings.HasPrefix(pricingModel.Model, m.ModelName) {
|
||||
if _, exists := metaMap[pricingModel.Model]; !exists {
|
||||
metaMap[pricingModel.Model] = m
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
for _, m := range suffixList {
|
||||
for _, pricingModel := range enableAbilities {
|
||||
if strings.HasSuffix(pricingModel.Model, m.ModelName) {
|
||||
if _, exists := metaMap[pricingModel.Model]; !exists {
|
||||
metaMap[pricingModel.Model] = m
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
for _, m := range containsList {
|
||||
for _, pricingModel := range enableAbilities {
|
||||
if strings.Contains(pricingModel.Model, m.ModelName) {
|
||||
if _, exists := metaMap[pricingModel.Model]; !exists {
|
||||
metaMap[pricingModel.Model] = m
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 预加载供应商
|
||||
var vendors []Vendor
|
||||
_ = DB.Find(&vendors).Error
|
||||
vendorMap := make(map[int]*Vendor)
|
||||
for i := range vendors {
|
||||
vendorMap[vendors[i].Id] = &vendors[i]
|
||||
}
|
||||
|
||||
// 构建对前端友好的供应商列表
|
||||
vendorsList = make([]PricingVendor, 0, len(vendors))
|
||||
for _, v := range vendors {
|
||||
vendorsList = append(vendorsList, PricingVendor{
|
||||
ID: v.Id,
|
||||
Name: v.Name,
|
||||
Description: v.Description,
|
||||
Icon: v.Icon,
|
||||
})
|
||||
}
|
||||
|
||||
modelGroupsMap := make(map[string]*types.Set[string])
|
||||
|
||||
for _, ability := range enableAbilities {
|
||||
@@ -79,12 +180,9 @@ func updatePricing() {
|
||||
//这里使用切片而不是Set,因为一个模型可能支持多个端点类型,并且第一个端点是优先使用端点
|
||||
modelSupportEndpointsStr := make(map[string][]string)
|
||||
|
||||
// 先根据已有能力填充原生端点
|
||||
for _, ability := range enableAbilities {
|
||||
endpoints, ok := modelSupportEndpointsStr[ability.Model]
|
||||
if !ok {
|
||||
endpoints = make([]string, 0)
|
||||
modelSupportEndpointsStr[ability.Model] = endpoints
|
||||
}
|
||||
endpoints := modelSupportEndpointsStr[ability.Model]
|
||||
channelTypes := common.GetEndpointTypesByChannelType(ability.ChannelType, ability.Model)
|
||||
for _, channelType := range channelTypes {
|
||||
if !common.StringsContains(endpoints, string(channelType)) {
|
||||
@@ -94,6 +192,23 @@ func updatePricing() {
|
||||
modelSupportEndpointsStr[ability.Model] = endpoints
|
||||
}
|
||||
|
||||
// 再补充模型自定义端点
|
||||
for modelName, meta := range metaMap {
|
||||
if strings.TrimSpace(meta.Endpoints) == "" {
|
||||
continue
|
||||
}
|
||||
var raw map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(meta.Endpoints), &raw); err == nil {
|
||||
endpoints := modelSupportEndpointsStr[modelName]
|
||||
for k := range raw {
|
||||
if !common.StringsContains(endpoints, k) {
|
||||
endpoints = append(endpoints, k)
|
||||
}
|
||||
}
|
||||
modelSupportEndpointsStr[modelName] = endpoints
|
||||
}
|
||||
}
|
||||
|
||||
modelSupportEndpointTypes = make(map[string][]constant.EndpointType)
|
||||
for model, endpoints := range modelSupportEndpointsStr {
|
||||
supportedEndpoints := make([]constant.EndpointType, 0)
|
||||
@@ -104,6 +219,45 @@ func updatePricing() {
|
||||
modelSupportEndpointTypes[model] = supportedEndpoints
|
||||
}
|
||||
|
||||
// 构建全局 supportedEndpointMap(默认 + 自定义覆盖)
|
||||
supportedEndpointMap = make(map[string]common.EndpointInfo)
|
||||
// 1. 默认端点
|
||||
for _, endpoints := range modelSupportEndpointTypes {
|
||||
for _, et := range endpoints {
|
||||
if info, ok := common.GetDefaultEndpointInfo(et); ok {
|
||||
if _, exists := supportedEndpointMap[string(et)]; !exists {
|
||||
supportedEndpointMap[string(et)] = info
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
// 2. 自定义端点(models 表)覆盖默认
|
||||
for _, meta := range metaMap {
|
||||
if strings.TrimSpace(meta.Endpoints) == "" {
|
||||
continue
|
||||
}
|
||||
var raw map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(meta.Endpoints), &raw); err == nil {
|
||||
for k, v := range raw {
|
||||
switch val := v.(type) {
|
||||
case string:
|
||||
supportedEndpointMap[k] = common.EndpointInfo{Path: val, Method: "POST"}
|
||||
case map[string]interface{}:
|
||||
ep := common.EndpointInfo{Method: "POST"}
|
||||
if p, ok := val["path"].(string); ok {
|
||||
ep.Path = p
|
||||
}
|
||||
if m, ok := val["method"].(string); ok {
|
||||
ep.Method = strings.ToUpper(m)
|
||||
}
|
||||
supportedEndpointMap[k] = ep
|
||||
default:
|
||||
// ignore unsupported types
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pricingMap = make([]Pricing, 0)
|
||||
for model, groups := range modelGroupsMap {
|
||||
pricing := Pricing{
|
||||
@@ -111,6 +265,18 @@ func updatePricing() {
|
||||
EnableGroup: groups.Items(),
|
||||
SupportedEndpointTypes: modelSupportEndpointTypes[model],
|
||||
}
|
||||
|
||||
// 补充模型元数据(描述、标签、供应商、状态)
|
||||
if meta, ok := metaMap[model]; ok {
|
||||
// 若模型被禁用(status!=1),则直接跳过,不返回给前端
|
||||
if meta.Status != 1 {
|
||||
continue
|
||||
}
|
||||
pricing.Description = meta.Description
|
||||
pricing.Icon = meta.Icon
|
||||
pricing.Tags = meta.Tags
|
||||
pricing.VendorID = meta.VendorID
|
||||
}
|
||||
modelPrice, findPrice := ratio_setting.GetModelPrice(model, false)
|
||||
if findPrice {
|
||||
pricing.ModelPrice = modelPrice
|
||||
@@ -123,5 +289,21 @@ func updatePricing() {
|
||||
}
|
||||
pricingMap = append(pricingMap, pricing)
|
||||
}
|
||||
|
||||
// 刷新缓存映射,供高并发快速查询
|
||||
modelEnableGroupsLock.Lock()
|
||||
modelEnableGroups = make(map[string][]string)
|
||||
modelQuotaTypeMap = make(map[string]int)
|
||||
for _, p := range pricingMap {
|
||||
modelEnableGroups[p.ModelName] = p.EnableGroup
|
||||
modelQuotaTypeMap[p.ModelName] = p.QuotaType
|
||||
}
|
||||
modelEnableGroupsLock.Unlock()
|
||||
|
||||
lastGetPricingTime = time.Now()
|
||||
}
|
||||
|
||||
// GetSupportedEndpointMap 返回全局端点到路径的映射
|
||||
func GetSupportedEndpointMap() map[string]common.EndpointInfo {
|
||||
return supportedEndpointMap
|
||||
}
|
||||
|
||||
14
model/pricing_refresh.go
Normal file
14
model/pricing_refresh.go
Normal file
@@ -0,0 +1,14 @@
|
||||
package model
|
||||
|
||||
// RefreshPricing 强制立即重新计算与定价相关的缓存。
|
||||
// 该方法用于需要最新数据的内部管理 API,
|
||||
// 因此会绕过默认的 1 分钟延迟刷新。
|
||||
func RefreshPricing() {
|
||||
updatePricingLock.Lock()
|
||||
defer updatePricingLock.Unlock()
|
||||
|
||||
modelSupportEndpointsLock.Lock()
|
||||
defer modelSupportEndpointsLock.Unlock()
|
||||
|
||||
updatePricing()
|
||||
}
|
||||
322
model/twofa.go
Normal file
322
model/twofa.go
Normal file
@@ -0,0 +1,322 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"one-api/common"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
var ErrTwoFANotEnabled = errors.New("用户未启用2FA")
|
||||
|
||||
// TwoFA 用户2FA设置表
|
||||
type TwoFA struct {
|
||||
Id int `json:"id" gorm:"primaryKey"`
|
||||
UserId int `json:"user_id" gorm:"unique;not null;index"`
|
||||
Secret string `json:"-" gorm:"type:varchar(255);not null"` // TOTP密钥,不返回给前端
|
||||
IsEnabled bool `json:"is_enabled" gorm:"default:false"`
|
||||
FailedAttempts int `json:"failed_attempts" gorm:"default:0"`
|
||||
LockedUntil *time.Time `json:"locked_until,omitempty"`
|
||||
LastUsedAt *time.Time `json:"last_used_at,omitempty"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
DeletedAt gorm.DeletedAt `json:"-" gorm:"index"`
|
||||
}
|
||||
|
||||
// TwoFABackupCode 备用码使用记录表
|
||||
type TwoFABackupCode struct {
|
||||
Id int `json:"id" gorm:"primaryKey"`
|
||||
UserId int `json:"user_id" gorm:"not null;index"`
|
||||
CodeHash string `json:"-" gorm:"type:varchar(255);not null"` // 备用码哈希
|
||||
IsUsed bool `json:"is_used" gorm:"default:false"`
|
||||
UsedAt *time.Time `json:"used_at,omitempty"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
DeletedAt gorm.DeletedAt `json:"-" gorm:"index"`
|
||||
}
|
||||
|
||||
// GetTwoFAByUserId 根据用户ID获取2FA设置
|
||||
func GetTwoFAByUserId(userId int) (*TwoFA, error) {
|
||||
if userId == 0 {
|
||||
return nil, errors.New("用户ID不能为空")
|
||||
}
|
||||
|
||||
var twoFA TwoFA
|
||||
err := DB.Where("user_id = ?", userId).First(&twoFA).Error
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, nil // 返回nil表示未设置2FA
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &twoFA, nil
|
||||
}
|
||||
|
||||
// IsTwoFAEnabled 检查用户是否启用了2FA
|
||||
func IsTwoFAEnabled(userId int) bool {
|
||||
twoFA, err := GetTwoFAByUserId(userId)
|
||||
if err != nil || twoFA == nil {
|
||||
return false
|
||||
}
|
||||
return twoFA.IsEnabled
|
||||
}
|
||||
|
||||
// CreateTwoFA 创建2FA设置
|
||||
func (t *TwoFA) Create() error {
|
||||
// 检查用户是否已存在2FA设置
|
||||
existing, err := GetTwoFAByUserId(t.UserId)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if existing != nil {
|
||||
return errors.New("用户已存在2FA设置")
|
||||
}
|
||||
|
||||
// 验证用户存在
|
||||
var user User
|
||||
if err := DB.First(&user, t.UserId).Error; err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return errors.New("用户不存在")
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
return DB.Create(t).Error
|
||||
}
|
||||
|
||||
// Update 更新2FA设置
|
||||
func (t *TwoFA) Update() error {
|
||||
if t.Id == 0 {
|
||||
return errors.New("2FA记录ID不能为空")
|
||||
}
|
||||
return DB.Save(t).Error
|
||||
}
|
||||
|
||||
// Delete 删除2FA设置
|
||||
func (t *TwoFA) Delete() error {
|
||||
if t.Id == 0 {
|
||||
return errors.New("2FA记录ID不能为空")
|
||||
}
|
||||
|
||||
// 使用事务确保原子性
|
||||
return DB.Transaction(func(tx *gorm.DB) error {
|
||||
// 同时删除相关的备用码记录(硬删除)
|
||||
if err := tx.Unscoped().Where("user_id = ?", t.UserId).Delete(&TwoFABackupCode{}).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 硬删除2FA记录
|
||||
return tx.Unscoped().Delete(t).Error
|
||||
})
|
||||
}
|
||||
|
||||
// ResetFailedAttempts 重置失败尝试次数
|
||||
func (t *TwoFA) ResetFailedAttempts() error {
|
||||
t.FailedAttempts = 0
|
||||
t.LockedUntil = nil
|
||||
return t.Update()
|
||||
}
|
||||
|
||||
// IncrementFailedAttempts 增加失败尝试次数
|
||||
func (t *TwoFA) IncrementFailedAttempts() error {
|
||||
t.FailedAttempts++
|
||||
|
||||
// 检查是否需要锁定
|
||||
if t.FailedAttempts >= common.MaxFailAttempts {
|
||||
lockUntil := time.Now().Add(time.Duration(common.LockoutDuration) * time.Second)
|
||||
t.LockedUntil = &lockUntil
|
||||
}
|
||||
|
||||
return t.Update()
|
||||
}
|
||||
|
||||
// IsLocked 检查账户是否被锁定
|
||||
func (t *TwoFA) IsLocked() bool {
|
||||
if t.LockedUntil == nil {
|
||||
return false
|
||||
}
|
||||
return time.Now().Before(*t.LockedUntil)
|
||||
}
|
||||
|
||||
// CreateBackupCodes 创建备用码
|
||||
func CreateBackupCodes(userId int, codes []string) error {
|
||||
return DB.Transaction(func(tx *gorm.DB) error {
|
||||
// 先删除现有的备用码
|
||||
if err := tx.Where("user_id = ?", userId).Delete(&TwoFABackupCode{}).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 创建新的备用码记录
|
||||
for _, code := range codes {
|
||||
hashedCode, err := common.HashBackupCode(code)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
backupCode := TwoFABackupCode{
|
||||
UserId: userId,
|
||||
CodeHash: hashedCode,
|
||||
IsUsed: false,
|
||||
}
|
||||
|
||||
if err := tx.Create(&backupCode).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
// ValidateBackupCode 验证并使用备用码
|
||||
func ValidateBackupCode(userId int, code string) (bool, error) {
|
||||
if !common.ValidateBackupCode(code) {
|
||||
return false, errors.New("验证码或备用码不正确")
|
||||
}
|
||||
|
||||
normalizedCode := common.NormalizeBackupCode(code)
|
||||
|
||||
// 查找未使用的备用码
|
||||
var backupCodes []TwoFABackupCode
|
||||
if err := DB.Where("user_id = ? AND is_used = false", userId).Find(&backupCodes).Error; err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
// 验证备用码
|
||||
for _, bc := range backupCodes {
|
||||
if common.ValidatePasswordAndHash(normalizedCode, bc.CodeHash) {
|
||||
// 标记为已使用
|
||||
now := time.Now()
|
||||
bc.IsUsed = true
|
||||
bc.UsedAt = &now
|
||||
|
||||
if err := DB.Save(&bc).Error; err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// GetUnusedBackupCodeCount 获取未使用的备用码数量
|
||||
func GetUnusedBackupCodeCount(userId int) (int, error) {
|
||||
var count int64
|
||||
err := DB.Model(&TwoFABackupCode{}).Where("user_id = ? AND is_used = false", userId).Count(&count).Error
|
||||
return int(count), err
|
||||
}
|
||||
|
||||
// DisableTwoFA 禁用用户的2FA
|
||||
func DisableTwoFA(userId int) error {
|
||||
twoFA, err := GetTwoFAByUserId(userId)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if twoFA == nil {
|
||||
return ErrTwoFANotEnabled
|
||||
}
|
||||
|
||||
// 删除2FA设置和备用码
|
||||
return twoFA.Delete()
|
||||
}
|
||||
|
||||
// EnableTwoFA 启用2FA
|
||||
func (t *TwoFA) Enable() error {
|
||||
t.IsEnabled = true
|
||||
t.FailedAttempts = 0
|
||||
t.LockedUntil = nil
|
||||
return t.Update()
|
||||
}
|
||||
|
||||
// ValidateTOTPAndUpdateUsage 验证TOTP并更新使用记录
|
||||
func (t *TwoFA) ValidateTOTPAndUpdateUsage(code string) (bool, error) {
|
||||
// 检查是否被锁定
|
||||
if t.IsLocked() {
|
||||
return false, fmt.Errorf("账户已被锁定,请在%v后重试", t.LockedUntil.Format("2006-01-02 15:04:05"))
|
||||
}
|
||||
|
||||
// 验证TOTP码
|
||||
if !common.ValidateTOTPCode(t.Secret, code) {
|
||||
// 增加失败次数
|
||||
if err := t.IncrementFailedAttempts(); err != nil {
|
||||
common.SysError("更新2FA失败次数失败: " + err.Error())
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// 验证成功,重置失败次数并更新最后使用时间
|
||||
now := time.Now()
|
||||
t.FailedAttempts = 0
|
||||
t.LockedUntil = nil
|
||||
t.LastUsedAt = &now
|
||||
|
||||
if err := t.Update(); err != nil {
|
||||
common.SysError("更新2FA使用记录失败: " + err.Error())
|
||||
}
|
||||
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// ValidateBackupCodeAndUpdateUsage 验证备用码并更新使用记录
|
||||
func (t *TwoFA) ValidateBackupCodeAndUpdateUsage(code string) (bool, error) {
|
||||
// 检查是否被锁定
|
||||
if t.IsLocked() {
|
||||
return false, fmt.Errorf("账户已被锁定,请在%v后重试", t.LockedUntil.Format("2006-01-02 15:04:05"))
|
||||
}
|
||||
|
||||
// 验证备用码
|
||||
valid, err := ValidateBackupCode(t.UserId, code)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
if !valid {
|
||||
// 增加失败次数
|
||||
if err := t.IncrementFailedAttempts(); err != nil {
|
||||
common.SysError("更新2FA失败次数失败: " + err.Error())
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// 验证成功,重置失败次数并更新最后使用时间
|
||||
now := time.Now()
|
||||
t.FailedAttempts = 0
|
||||
t.LockedUntil = nil
|
||||
t.LastUsedAt = &now
|
||||
|
||||
if err := t.Update(); err != nil {
|
||||
common.SysError("更新2FA使用记录失败: " + err.Error())
|
||||
}
|
||||
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// GetTwoFAStats 获取2FA统计信息(管理员使用)
|
||||
func GetTwoFAStats() (map[string]interface{}, error) {
|
||||
var totalUsers, enabledUsers int64
|
||||
|
||||
// 总用户数
|
||||
if err := DB.Model(&User{}).Count(&totalUsers).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 启用2FA的用户数
|
||||
if err := DB.Model(&TwoFA{}).Where("is_enabled = true").Count(&enabledUsers).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
enabledRate := float64(0)
|
||||
if totalUsers > 0 {
|
||||
enabledRate = float64(enabledUsers) / float64(totalUsers) * 100
|
||||
}
|
||||
|
||||
return map[string]interface{}{
|
||||
"total_users": totalUsers,
|
||||
"enabled_users": enabledUsers,
|
||||
"enabled_rate": fmt.Sprintf("%.1f%%", enabledRate),
|
||||
}, nil
|
||||
}
|
||||
88
model/vendor_meta.go
Normal file
88
model/vendor_meta.go
Normal file
@@ -0,0 +1,88 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"one-api/common"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// Vendor 用于存储供应商信息,供模型引用
|
||||
// Name 唯一,用于在模型中关联
|
||||
// Icon 采用 @lobehub/icons 的图标名,前端可直接渲染
|
||||
// Status 预留字段,1 表示启用
|
||||
// 本表同样遵循 3NF 设计范式
|
||||
|
||||
type Vendor struct {
|
||||
Id int `json:"id"`
|
||||
Name string `json:"name" gorm:"size:128;not null;uniqueIndex:uk_vendor_name,priority:1"`
|
||||
Description string `json:"description,omitempty" gorm:"type:text"`
|
||||
Icon string `json:"icon,omitempty" gorm:"type:varchar(128)"`
|
||||
Status int `json:"status" gorm:"default:1"`
|
||||
CreatedTime int64 `json:"created_time" gorm:"bigint"`
|
||||
UpdatedTime int64 `json:"updated_time" gorm:"bigint"`
|
||||
DeletedAt gorm.DeletedAt `json:"-" gorm:"index;uniqueIndex:uk_vendor_name,priority:2"`
|
||||
}
|
||||
|
||||
// Insert 创建新的供应商记录
|
||||
func (v *Vendor) Insert() error {
|
||||
now := common.GetTimestamp()
|
||||
v.CreatedTime = now
|
||||
v.UpdatedTime = now
|
||||
return DB.Create(v).Error
|
||||
}
|
||||
|
||||
// IsVendorNameDuplicated 检查供应商名称是否重复(排除自身 ID)
|
||||
func IsVendorNameDuplicated(id int, name string) (bool, error) {
|
||||
if name == "" {
|
||||
return false, nil
|
||||
}
|
||||
var cnt int64
|
||||
err := DB.Model(&Vendor{}).Where("name = ? AND id <> ?", name, id).Count(&cnt).Error
|
||||
return cnt > 0, err
|
||||
}
|
||||
|
||||
// Update 更新供应商记录
|
||||
func (v *Vendor) Update() error {
|
||||
v.UpdatedTime = common.GetTimestamp()
|
||||
return DB.Save(v).Error
|
||||
}
|
||||
|
||||
// Delete 软删除供应商
|
||||
func (v *Vendor) Delete() error {
|
||||
return DB.Delete(v).Error
|
||||
}
|
||||
|
||||
// GetVendorByID 根据 ID 获取供应商
|
||||
func GetVendorByID(id int) (*Vendor, error) {
|
||||
var v Vendor
|
||||
err := DB.First(&v, id).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &v, nil
|
||||
}
|
||||
|
||||
// GetAllVendors 获取全部供应商(分页)
|
||||
func GetAllVendors(offset int, limit int) ([]*Vendor, error) {
|
||||
var vendors []*Vendor
|
||||
err := DB.Offset(offset).Limit(limit).Find(&vendors).Error
|
||||
return vendors, err
|
||||
}
|
||||
|
||||
// SearchVendors 按关键字搜索供应商
|
||||
func SearchVendors(keyword string, offset int, limit int) ([]*Vendor, int64, error) {
|
||||
db := DB.Model(&Vendor{})
|
||||
if keyword != "" {
|
||||
like := "%" + keyword + "%"
|
||||
db = db.Where("name LIKE ? OR description LIKE ?", like, like)
|
||||
}
|
||||
var total int64
|
||||
if err := db.Count(&total).Error; err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
var vendors []*Vendor
|
||||
if err := db.Offset(offset).Limit(limit).Order("id DESC").Find(&vendors).Error; err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
return vendors, total, nil
|
||||
}
|
||||
@@ -62,7 +62,7 @@ func AudioHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
|
||||
|
||||
if err != nil {
|
||||
common.LogError(c, fmt.Sprintf("getAndValidAudioRequest failed: %s", err.Error()))
|
||||
return types.NewError(err, types.ErrorCodeInvalidRequest)
|
||||
return types.NewError(err, types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
|
||||
promptTokens := 0
|
||||
@@ -75,7 +75,7 @@ func AudioHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
|
||||
|
||||
priceData, err := helper.ModelPriceHelper(c, relayInfo, preConsumedTokens, 0)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeModelPriceError)
|
||||
return types.NewError(err, types.ErrorCodeModelPriceError, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
|
||||
preConsumedQuota, userQuota, openaiErr := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo)
|
||||
@@ -90,18 +90,18 @@ func AudioHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
|
||||
|
||||
err = helper.ModelMappedHelper(c, relayInfo, audioRequest)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeChannelModelMappedError)
|
||||
return types.NewError(err, types.ErrorCodeChannelModelMappedError, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
|
||||
adaptor := GetAdaptor(relayInfo.ApiType)
|
||||
if adaptor == nil {
|
||||
return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType)
|
||||
return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
adaptor.Init(relayInfo)
|
||||
|
||||
ioReader, err := adaptor.ConvertAudioRequest(c, relayInfo, *audioRequest)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeConvertRequestFailed)
|
||||
return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
|
||||
resp, err := adaptor.DoRequest(c, relayInfo, ioReader)
|
||||
|
||||
@@ -26,6 +26,7 @@ type Adaptor interface {
|
||||
GetModelList() []string
|
||||
GetChannelName() string
|
||||
ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.ClaudeRequest) (any, error)
|
||||
ConvertGeminiRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeminiChatRequest) (any, error)
|
||||
}
|
||||
|
||||
type TaskAdaptor interface {
|
||||
|
||||
@@ -3,25 +3,29 @@ package ali
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/dto"
|
||||
"one-api/relay/channel"
|
||||
"one-api/relay/channel/claude"
|
||||
"one-api/relay/channel/openai"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/relay/constant"
|
||||
"one-api/types"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type Adaptor struct {
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
|
||||
func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
|
||||
//TODO implement me
|
||||
panic("implement me")
|
||||
return nil, nil
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, req *dto.ClaudeRequest) (any, error) {
|
||||
return req, nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
|
||||
@@ -29,18 +33,24 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
|
||||
|
||||
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
var fullRequestURL string
|
||||
switch info.RelayMode {
|
||||
case constant.RelayModeEmbeddings:
|
||||
fullRequestURL = fmt.Sprintf("%s/compatible-mode/v1/embeddings", info.BaseUrl)
|
||||
case constant.RelayModeRerank:
|
||||
fullRequestURL = fmt.Sprintf("%s/api/v1/services/rerank/text-rerank/text-rerank", info.BaseUrl)
|
||||
case constant.RelayModeImagesGenerations:
|
||||
fullRequestURL = fmt.Sprintf("%s/api/v1/services/aigc/text2image/image-synthesis", info.BaseUrl)
|
||||
case constant.RelayModeCompletions:
|
||||
fullRequestURL = fmt.Sprintf("%s/compatible-mode/v1/completions", info.BaseUrl)
|
||||
switch info.RelayFormat {
|
||||
case relaycommon.RelayFormatClaude:
|
||||
fullRequestURL = fmt.Sprintf("%s/api/v2/apps/claude-code-proxy/v1/messages", info.BaseUrl)
|
||||
default:
|
||||
fullRequestURL = fmt.Sprintf("%s/compatible-mode/v1/chat/completions", info.BaseUrl)
|
||||
switch info.RelayMode {
|
||||
case constant.RelayModeEmbeddings:
|
||||
fullRequestURL = fmt.Sprintf("%s/compatible-mode/v1/embeddings", info.BaseUrl)
|
||||
case constant.RelayModeRerank:
|
||||
fullRequestURL = fmt.Sprintf("%s/api/v1/services/rerank/text-rerank/text-rerank", info.BaseUrl)
|
||||
case constant.RelayModeImagesGenerations:
|
||||
fullRequestURL = fmt.Sprintf("%s/api/v1/services/aigc/text2image/image-synthesis", info.BaseUrl)
|
||||
case constant.RelayModeCompletions:
|
||||
fullRequestURL = fmt.Sprintf("%s/compatible-mode/v1/completions", info.BaseUrl)
|
||||
default:
|
||||
fullRequestURL = fmt.Sprintf("%s/compatible-mode/v1/chat/completions", info.BaseUrl)
|
||||
}
|
||||
}
|
||||
|
||||
return fullRequestURL, nil
|
||||
}
|
||||
|
||||
@@ -60,7 +70,13 @@ func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayIn
|
||||
if request == nil {
|
||||
return nil, errors.New("request is nil")
|
||||
}
|
||||
|
||||
// docs: https://bailian.console.aliyun.com/?tab=api#/api/?type=model&url=2712216
|
||||
// fix: InternalError.Algo.InvalidParameter: The value of the enable_thinking parameter is restricted to True.
|
||||
if strings.Contains(request.Model, "thinking") {
|
||||
request.EnableThinking = true
|
||||
request.Stream = true
|
||||
info.IsStream = true
|
||||
}
|
||||
// fix: ali parameter.enable_thinking must be set to false for non-streaming calls
|
||||
if !info.IsStream {
|
||||
request.EnableThinking = false
|
||||
@@ -101,19 +117,16 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
|
||||
switch info.RelayMode {
|
||||
case constant.RelayModeImagesGenerations:
|
||||
err, usage = aliImageHandler(c, resp, info)
|
||||
case constant.RelayModeEmbeddings:
|
||||
err, usage = aliEmbeddingHandler(c, resp)
|
||||
case constant.RelayModeRerank:
|
||||
err, usage = RerankHandler(c, resp, info)
|
||||
default:
|
||||
switch info.RelayFormat {
|
||||
case relaycommon.RelayFormatClaude:
|
||||
if info.IsStream {
|
||||
usage, err = openai.OaiStreamHandler(c, info, resp)
|
||||
err, usage = claude.ClaudeStreamHandler(c, resp, info, claude.RequestModeMessage)
|
||||
} else {
|
||||
usage, err = openai.OpenaiHandler(c, info, resp)
|
||||
err, usage = claude.ClaudeHandler(c, resp, info, claude.RequestModeMessage)
|
||||
}
|
||||
default:
|
||||
adaptor := openai.Adaptor{}
|
||||
return adaptor.DoResponse(c, resp, info)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
@@ -223,7 +223,7 @@ func doRequest(c *gin.Context, req *http.Request, info *common.RelayInfo) (*http
|
||||
helper.SetEventStreamHeaders(c)
|
||||
// 处理流式请求的 ping 保活
|
||||
generalSettings := operation_setting.GetGeneralSetting()
|
||||
if generalSettings.PingIntervalEnabled {
|
||||
if generalSettings.PingIntervalEnabled && !info.DisablePing {
|
||||
pingInterval := time.Duration(generalSettings.PingIntervalSeconds) * time.Second
|
||||
stopPinger = startPingKeepAlive(c, pingInterval)
|
||||
// 使用defer确保在任何情况下都能停止ping goroutine
|
||||
|
||||
@@ -22,6 +22,11 @@ type Adaptor struct {
|
||||
RequestMode int
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
|
||||
//TODO implement me
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.ClaudeRequest) (any, error) {
|
||||
c.Set("request_model", request.Model)
|
||||
c.Set("converted_request", request)
|
||||
|
||||
@@ -13,6 +13,7 @@ var awsModelIDMap = map[string]string{
|
||||
"claude-3-7-sonnet-20250219": "anthropic.claude-3-7-sonnet-20250219-v1:0",
|
||||
"claude-sonnet-4-20250514": "anthropic.claude-sonnet-4-20250514-v1:0",
|
||||
"claude-opus-4-20250514": "anthropic.claude-opus-4-20250514-v1:0",
|
||||
"claude-opus-4-1-20250805": "anthropic.claude-opus-4-1-20250805-v1:0",
|
||||
}
|
||||
|
||||
var awsModelCanCrossRegionMap = map[string]map[string]bool{
|
||||
@@ -54,6 +55,9 @@ var awsModelCanCrossRegionMap = map[string]map[string]bool{
|
||||
"anthropic.claude-opus-4-20250514-v1:0": {
|
||||
"us": true,
|
||||
},
|
||||
"anthropic.claude-opus-4-1-20250805-v1:0": {
|
||||
"us": true,
|
||||
},
|
||||
}
|
||||
|
||||
var awsRegionCrossModelPrefixMap = map[string]string{
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package aws
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
@@ -19,20 +18,31 @@ import (
|
||||
"github.com/aws/aws-sdk-go-v2/credentials"
|
||||
"github.com/aws/aws-sdk-go-v2/service/bedrockruntime"
|
||||
bedrockruntimeTypes "github.com/aws/aws-sdk-go-v2/service/bedrockruntime/types"
|
||||
"github.com/aws/smithy-go/auth/bearer"
|
||||
)
|
||||
|
||||
func newAwsClient(c *gin.Context, info *relaycommon.RelayInfo) (*bedrockruntime.Client, error) {
|
||||
awsSecret := strings.Split(info.ApiKey, "|")
|
||||
if len(awsSecret) != 3 {
|
||||
var client *bedrockruntime.Client
|
||||
switch len(awsSecret) {
|
||||
case 2:
|
||||
apiKey := awsSecret[0]
|
||||
region := awsSecret[1]
|
||||
client = bedrockruntime.New(bedrockruntime.Options{
|
||||
Region: region,
|
||||
BearerAuthTokenProvider: bearer.StaticTokenProvider{Token: bearer.Token{Value: apiKey}},
|
||||
})
|
||||
case 3:
|
||||
ak := awsSecret[0]
|
||||
sk := awsSecret[1]
|
||||
region := awsSecret[2]
|
||||
client = bedrockruntime.New(bedrockruntime.Options{
|
||||
Region: region,
|
||||
Credentials: aws.NewCredentialsCache(credentials.NewStaticCredentialsProvider(ak, sk, "")),
|
||||
})
|
||||
default:
|
||||
return nil, errors.New("invalid aws secret key")
|
||||
}
|
||||
ak := awsSecret[0]
|
||||
sk := awsSecret[1]
|
||||
region := awsSecret[2]
|
||||
client := bedrockruntime.New(bedrockruntime.Options{
|
||||
Region: region,
|
||||
Credentials: aws.NewCredentialsCache(credentials.NewStaticCredentialsProvider(ak, sk, "")),
|
||||
})
|
||||
|
||||
return client, nil
|
||||
}
|
||||
@@ -102,14 +112,14 @@ func awsHandler(c *gin.Context, info *relaycommon.RelayInfo, requestMode int) (*
|
||||
}
|
||||
claudeReq := claudeReq_.(*dto.ClaudeRequest)
|
||||
awsClaudeReq := copyRequest(claudeReq)
|
||||
awsReq.Body, err = json.Marshal(awsClaudeReq)
|
||||
awsReq.Body, err = common.Marshal(awsClaudeReq)
|
||||
if err != nil {
|
||||
return types.NewError(errors.Wrap(err, "marshal request"), types.ErrorCodeBadResponseBody), nil
|
||||
}
|
||||
|
||||
awsResp, err := awsCli.InvokeModel(c.Request.Context(), awsReq)
|
||||
if err != nil {
|
||||
return types.NewError(errors.Wrap(err, "InvokeModel"), types.ErrorCodeChannelAwsClientError), nil
|
||||
return types.NewOpenAIError(errors.Wrap(err, "InvokeModel"), types.ErrorCodeAwsInvokeError, http.StatusInternalServerError), nil
|
||||
}
|
||||
|
||||
claudeInfo := &claude.ClaudeResponseInfo{
|
||||
@@ -154,14 +164,14 @@ func awsStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
|
||||
claudeReq := claudeReq_.(*dto.ClaudeRequest)
|
||||
|
||||
awsClaudeReq := copyRequest(claudeReq)
|
||||
awsReq.Body, err = json.Marshal(awsClaudeReq)
|
||||
awsReq.Body, err = common.Marshal(awsClaudeReq)
|
||||
if err != nil {
|
||||
return types.NewError(errors.Wrap(err, "marshal request"), types.ErrorCodeBadResponseBody), nil
|
||||
}
|
||||
|
||||
awsResp, err := awsCli.InvokeModelWithResponseStream(c.Request.Context(), awsReq)
|
||||
if err != nil {
|
||||
return types.NewError(errors.Wrap(err, "InvokeModelWithResponseStream"), types.ErrorCodeChannelAwsClientError), nil
|
||||
return types.NewOpenAIError(errors.Wrap(err, "InvokeModelWithResponseStream"), types.ErrorCodeAwsInvokeError, http.StatusInternalServerError), nil
|
||||
}
|
||||
stream := awsResp.GetStream()
|
||||
defer stream.Close()
|
||||
|
||||
@@ -18,6 +18,11 @@ import (
|
||||
type Adaptor struct {
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
|
||||
//TODO implement me
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
|
||||
//TODO implement me
|
||||
panic("implement me")
|
||||
|
||||
@@ -34,9 +34,9 @@ func requestOpenAI2Baidu(request dto.GeneralOpenAIRequest) *BaiduChatRequest {
|
||||
EnableCitation: false,
|
||||
UserId: request.User,
|
||||
}
|
||||
if request.MaxTokens != 0 {
|
||||
maxTokens := int(request.MaxTokens)
|
||||
if request.MaxTokens == 1 {
|
||||
if request.GetMaxTokens() != 0 {
|
||||
maxTokens := int(request.GetMaxTokens())
|
||||
if request.GetMaxTokens() == 1 {
|
||||
maxTokens = 2
|
||||
}
|
||||
baiduRequest.MaxOutputTokens = &maxTokens
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"one-api/relay/channel"
|
||||
"one-api/relay/channel/openai"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/relay/constant"
|
||||
"one-api/types"
|
||||
"strings"
|
||||
|
||||
@@ -18,10 +19,14 @@ import (
|
||||
type Adaptor struct {
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
|
||||
func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
|
||||
//TODO implement me
|
||||
panic("implement me")
|
||||
return nil, nil
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, req *dto.ClaudeRequest) (any, error) {
|
||||
adaptor := openai.Adaptor{}
|
||||
return adaptor.ConvertClaudeRequest(c, info, req)
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
|
||||
@@ -38,20 +43,33 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
return fmt.Sprintf("%s/v2/chat/completions", info.BaseUrl), nil
|
||||
switch info.RelayMode {
|
||||
case constant.RelayModeChatCompletions:
|
||||
return fmt.Sprintf("%s/v2/chat/completions", info.BaseUrl), nil
|
||||
case constant.RelayModeEmbeddings:
|
||||
return fmt.Sprintf("%s/v2/embeddings", info.BaseUrl), nil
|
||||
case constant.RelayModeImagesGenerations:
|
||||
return fmt.Sprintf("%s/v2/images/generations", info.BaseUrl), nil
|
||||
case constant.RelayModeImagesEdits:
|
||||
return fmt.Sprintf("%s/v2/images/edits", info.BaseUrl), nil
|
||||
case constant.RelayModeRerank:
|
||||
return fmt.Sprintf("%s/v2/rerank", info.BaseUrl), nil
|
||||
default:
|
||||
}
|
||||
return "", fmt.Errorf("unsupported relay mode: %d", info.RelayMode)
|
||||
}
|
||||
|
||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
|
||||
channel.SetupApiRequestHeader(info, c, req)
|
||||
keyParts := strings.Split(info.ApiKey, "|")
|
||||
keyParts := strings.Split(info.ApiKey, "|")
|
||||
if len(keyParts) == 0 || keyParts[0] == "" {
|
||||
return errors.New("invalid API key: authorization token is required")
|
||||
}
|
||||
if len(keyParts) > 1 {
|
||||
if keyParts[1] != "" {
|
||||
req.Set("appid", keyParts[1])
|
||||
}
|
||||
}
|
||||
return errors.New("invalid API key: authorization token is required")
|
||||
}
|
||||
if len(keyParts) > 1 {
|
||||
if keyParts[1] != "" {
|
||||
req.Set("appid", keyParts[1])
|
||||
}
|
||||
}
|
||||
req.Set("Authorization", "Bearer "+keyParts[0])
|
||||
return nil
|
||||
}
|
||||
@@ -94,11 +112,8 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
|
||||
if info.IsStream {
|
||||
usage, err = openai.OaiStreamHandler(c, info, resp)
|
||||
} else {
|
||||
usage, err = openai.OpenaiHandler(c, info, resp)
|
||||
}
|
||||
adaptor := openai.Adaptor{}
|
||||
usage, err = adaptor.DoResponse(c, resp, info)
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
@@ -24,6 +24,11 @@ type Adaptor struct {
|
||||
RequestMode int
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
|
||||
//TODO implement me
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.ClaudeRequest) (any, error) {
|
||||
return request, nil
|
||||
}
|
||||
@@ -99,7 +104,7 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
|
||||
if info.IsStream {
|
||||
err, usage = ClaudeStreamHandler(c, resp, info, a.RequestMode)
|
||||
} else {
|
||||
err, usage = ClaudeHandler(c, resp, a.RequestMode, info)
|
||||
err, usage = ClaudeHandler(c, resp, info, a.RequestMode)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
@@ -17,6 +17,8 @@ var ModelList = []string{
|
||||
"claude-sonnet-4-20250514-thinking",
|
||||
"claude-opus-4-20250514",
|
||||
"claude-opus-4-20250514-thinking",
|
||||
"claude-opus-4-1-20250805",
|
||||
"claude-opus-4-1-20250805-thinking",
|
||||
}
|
||||
|
||||
var ChannelName = "claude"
|
||||
|
||||
@@ -149,7 +149,7 @@ func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*dto.Cla
|
||||
|
||||
claudeRequest := dto.ClaudeRequest{
|
||||
Model: textRequest.Model,
|
||||
MaxTokens: textRequest.MaxTokens,
|
||||
MaxTokens: textRequest.GetMaxTokens(),
|
||||
StopSequences: nil,
|
||||
Temperature: textRequest.Temperature,
|
||||
TopP: textRequest.TopP,
|
||||
@@ -612,8 +612,8 @@ func HandleStreamResponseData(c *gin.Context, info *relaycommon.RelayInfo, claud
|
||||
common.SysError("error unmarshalling stream response: " + err.Error())
|
||||
return types.NewError(err, types.ErrorCodeBadResponseBody)
|
||||
}
|
||||
if claudeResponse.Error != nil && claudeResponse.Error.Type != "" {
|
||||
return types.WithClaudeError(*claudeResponse.Error, http.StatusInternalServerError)
|
||||
if claudeError := claudeResponse.GetClaudeError(); claudeError != nil && claudeError.Type != "" {
|
||||
return types.WithClaudeError(*claudeError, http.StatusInternalServerError)
|
||||
}
|
||||
if info.RelayFormat == relaycommon.RelayFormatClaude {
|
||||
FormatClaudeResponseInfo(requestMode, &claudeResponse, nil, claudeInfo)
|
||||
@@ -704,8 +704,8 @@ func HandleClaudeResponseData(c *gin.Context, info *relaycommon.RelayInfo, claud
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeBadResponseBody)
|
||||
}
|
||||
if claudeResponse.Error != nil && claudeResponse.Error.Type != "" {
|
||||
return types.WithClaudeError(*claudeResponse.Error, http.StatusInternalServerError)
|
||||
if claudeError := claudeResponse.GetClaudeError(); claudeError != nil && claudeError.Type != "" {
|
||||
return types.WithClaudeError(*claudeError, http.StatusInternalServerError)
|
||||
}
|
||||
if requestMode == RequestModeCompletion {
|
||||
completionTokens := service.CountTextToken(claudeResponse.Completion, info.OriginModelName)
|
||||
@@ -740,7 +740,7 @@ func HandleClaudeResponseData(c *gin.Context, info *relaycommon.RelayInfo, claud
|
||||
return nil
|
||||
}
|
||||
|
||||
func ClaudeHandler(c *gin.Context, resp *http.Response, requestMode int, info *relaycommon.RelayInfo) (*types.NewAPIError, *dto.Usage) {
|
||||
func ClaudeHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, requestMode int) (*types.NewAPIError, *dto.Usage) {
|
||||
defer common.CloseResponseBodyGracefully(resp)
|
||||
|
||||
claudeInfo := &ClaudeResponseInfo{
|
||||
|
||||
@@ -18,6 +18,11 @@ import (
|
||||
type Adaptor struct {
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
|
||||
//TODO implement me
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
|
||||
//TODO implement me
|
||||
panic("implement me")
|
||||
|
||||
@@ -5,7 +5,7 @@ import "one-api/dto"
|
||||
type CfRequest struct {
|
||||
Messages []dto.Message `json:"messages,omitempty"`
|
||||
Lora string `json:"lora,omitempty"`
|
||||
MaxTokens int `json:"max_tokens,omitempty"`
|
||||
MaxTokens uint `json:"max_tokens,omitempty"`
|
||||
Prompt string `json:"prompt,omitempty"`
|
||||
Raw bool `json:"raw,omitempty"`
|
||||
Stream bool `json:"stream,omitempty"`
|
||||
|
||||
@@ -17,6 +17,11 @@ import (
|
||||
type Adaptor struct {
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
|
||||
//TODO implement me
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
|
||||
//TODO implement me
|
||||
panic("implement me")
|
||||
|
||||
@@ -7,7 +7,7 @@ type CohereRequest struct {
|
||||
ChatHistory []ChatHistory `json:"chat_history"`
|
||||
Message string `json:"message"`
|
||||
Stream bool `json:"stream"`
|
||||
MaxTokens int `json:"max_tokens"`
|
||||
MaxTokens uint `json:"max_tokens"`
|
||||
SafetyMode string `json:"safety_mode,omitempty"`
|
||||
}
|
||||
|
||||
|
||||
@@ -18,6 +18,11 @@ import (
|
||||
type Adaptor struct {
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *common.RelayInfo, *dto.GeminiChatRequest) (any, error) {
|
||||
//TODO implement me
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
// ConvertAudioRequest implements channel.Adaptor.
|
||||
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *common.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
|
||||
@@ -19,10 +19,14 @@ import (
|
||||
type Adaptor struct {
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
|
||||
func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
|
||||
//TODO implement me
|
||||
panic("implement me")
|
||||
return nil, nil
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, req *dto.ClaudeRequest) (any, error) {
|
||||
adaptor := openai.Adaptor{}
|
||||
return adaptor.ConvertClaudeRequest(c, info, req)
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
|
||||
|
||||
@@ -24,6 +24,11 @@ type Adaptor struct {
|
||||
BotType int
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
|
||||
//TODO implement me
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
|
||||
//TODO implement me
|
||||
panic("implement me")
|
||||
|
||||
@@ -20,6 +20,26 @@ import (
|
||||
type Adaptor struct {
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertGeminiRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeminiChatRequest) (any, error) {
|
||||
if len(request.Contents) > 0 {
|
||||
for i, content := range request.Contents {
|
||||
if i == 0 {
|
||||
if request.Contents[0].Role == "" {
|
||||
request.Contents[0].Role = "user"
|
||||
}
|
||||
}
|
||||
for _, part := range content.Parts {
|
||||
if part.FileData != nil {
|
||||
if part.FileData.MimeType == "" && strings.Contains(part.FileData.FileUri, "www.youtube.com") {
|
||||
part.FileData.MimeType = "video/webm"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return request, nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, req *dto.ClaudeRequest) (any, error) {
|
||||
adaptor := openai.Adaptor{}
|
||||
oaiReq, err := adaptor.ConvertClaudeRequest(c, info, req)
|
||||
@@ -51,13 +71,13 @@ func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInf
|
||||
}
|
||||
|
||||
// build gemini imagen request
|
||||
geminiRequest := GeminiImageRequest{
|
||||
Instances: []GeminiImageInstance{
|
||||
geminiRequest := dto.GeminiImageRequest{
|
||||
Instances: []dto.GeminiImageInstance{
|
||||
{
|
||||
Prompt: request.Prompt,
|
||||
},
|
||||
},
|
||||
Parameters: GeminiImageParameters{
|
||||
Parameters: dto.GeminiImageParameters{
|
||||
SampleCount: request.N,
|
||||
AspectRatio: aspectRatio,
|
||||
PersonGeneration: "allow_adult", // default allow adult
|
||||
@@ -94,12 +114,19 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
if strings.HasPrefix(info.UpstreamModelName, "text-embedding") ||
|
||||
strings.HasPrefix(info.UpstreamModelName, "embedding") ||
|
||||
strings.HasPrefix(info.UpstreamModelName, "gemini-embedding") {
|
||||
return fmt.Sprintf("%s/%s/models/%s:embedContent", info.BaseUrl, version, info.UpstreamModelName), nil
|
||||
action := "embedContent"
|
||||
if info.IsGeminiBatchEmbedding {
|
||||
action = "batchEmbedContents"
|
||||
}
|
||||
return fmt.Sprintf("%s/%s/models/%s:%s", info.BaseUrl, version, info.UpstreamModelName, action), nil
|
||||
}
|
||||
|
||||
action := "generateContent"
|
||||
if info.IsStream {
|
||||
action = "streamGenerateContent?alt=sse"
|
||||
if info.RelayMode == constant.RelayModeGemini {
|
||||
info.DisablePing = true
|
||||
}
|
||||
}
|
||||
return fmt.Sprintf("%s/%s/models/%s:%s", info.BaseUrl, version, info.UpstreamModelName, action), nil
|
||||
}
|
||||
@@ -136,29 +163,38 @@ func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.Rela
|
||||
if len(inputs) == 0 {
|
||||
return nil, errors.New("input is empty")
|
||||
}
|
||||
|
||||
// only process the first input
|
||||
geminiRequest := GeminiEmbeddingRequest{
|
||||
Content: GeminiChatContent{
|
||||
Parts: []GeminiPart{
|
||||
{
|
||||
Text: inputs[0],
|
||||
// We always build a batch-style payload with `requests`, so ensure we call the
|
||||
// batch endpoint upstream to avoid payload/endpoint mismatches.
|
||||
info.IsGeminiBatchEmbedding = true
|
||||
// process all inputs
|
||||
geminiRequests := make([]map[string]interface{}, 0, len(inputs))
|
||||
for _, input := range inputs {
|
||||
geminiRequest := map[string]interface{}{
|
||||
"model": fmt.Sprintf("models/%s", info.UpstreamModelName),
|
||||
"content": dto.GeminiChatContent{
|
||||
Parts: []dto.GeminiPart{
|
||||
{
|
||||
Text: input,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// set specific parameters for different models
|
||||
// https://ai.google.dev/api/embeddings?hl=zh-cn#method:-models.embedcontent
|
||||
switch info.UpstreamModelName {
|
||||
case "text-embedding-004":
|
||||
// except embedding-001 supports setting `OutputDimensionality`
|
||||
if request.Dimensions > 0 {
|
||||
geminiRequest.OutputDimensionality = request.Dimensions
|
||||
}
|
||||
|
||||
// set specific parameters for different models
|
||||
// https://ai.google.dev/api/embeddings?hl=zh-cn#method:-models.embedcontent
|
||||
switch info.UpstreamModelName {
|
||||
case "text-embedding-004", "gemini-embedding-exp-03-07", "gemini-embedding-001":
|
||||
// Only newer models introduced after 2024 support OutputDimensionality
|
||||
if request.Dimensions > 0 {
|
||||
geminiRequest["outputDimensionality"] = request.Dimensions
|
||||
}
|
||||
}
|
||||
geminiRequests = append(geminiRequests, geminiRequest)
|
||||
}
|
||||
|
||||
return geminiRequest, nil
|
||||
return map[string]interface{}{
|
||||
"requests": geminiRequests,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) {
|
||||
@@ -172,6 +208,10 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
|
||||
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
|
||||
if info.RelayMode == constant.RelayModeGemini {
|
||||
if strings.HasSuffix(info.RequestURLPath, ":embedContent") ||
|
||||
strings.HasSuffix(info.RequestURLPath, ":batchEmbedContents") {
|
||||
return NativeGeminiEmbeddingHandler(c, resp, info)
|
||||
}
|
||||
if info.IsStream {
|
||||
return GeminiTextGenerationStreamHandler(c, info, resp)
|
||||
} else {
|
||||
@@ -196,18 +236,6 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
|
||||
return GeminiChatHandler(c, info, resp)
|
||||
}
|
||||
|
||||
//if usage.(*dto.Usage).CompletionTokenDetails.ReasoningTokens > 100 {
|
||||
// // 没有请求-thinking的情况下,产生思考token,则按照思考模型计费
|
||||
// if !strings.HasSuffix(info.OriginModelName, "-thinking") &&
|
||||
// !strings.HasSuffix(info.OriginModelName, "-nothinking") {
|
||||
// thinkingModelName := info.OriginModelName + "-thinking"
|
||||
// if operation_setting.SelfUseModeEnabled || helper.ContainPriceOrRatio(thinkingModelName) {
|
||||
// info.OriginModelName = thinkingModelName
|
||||
// }
|
||||
// }
|
||||
//}
|
||||
|
||||
return nil, types.NewError(errors.New("not implemented"), types.ErrorCodeBadResponseBody)
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetModelList() []string {
|
||||
|
||||
@@ -11,6 +11,8 @@ import (
|
||||
"one-api/types"
|
||||
"strings"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
@@ -28,7 +30,7 @@ func GeminiTextGenerationHandler(c *gin.Context, info *relaycommon.RelayInfo, re
|
||||
}
|
||||
|
||||
// 解析为 Gemini 原生响应格式
|
||||
var geminiResponse GeminiChatResponse
|
||||
var geminiResponse dto.GeminiChatResponse
|
||||
err = common.Unmarshal(responseBody, &geminiResponse)
|
||||
if err != nil {
|
||||
return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
|
||||
@@ -62,6 +64,42 @@ func GeminiTextGenerationHandler(c *gin.Context, info *relaycommon.RelayInfo, re
|
||||
return &usage, nil
|
||||
}
|
||||
|
||||
func NativeGeminiEmbeddingHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.Usage, *types.NewAPIError) {
|
||||
defer common.CloseResponseBodyGracefully(resp)
|
||||
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
if common.DebugEnabled {
|
||||
println(string(responseBody))
|
||||
}
|
||||
|
||||
usage := &dto.Usage{
|
||||
PromptTokens: info.PromptTokens,
|
||||
TotalTokens: info.PromptTokens,
|
||||
}
|
||||
|
||||
if info.IsGeminiBatchEmbedding {
|
||||
var geminiResponse dto.GeminiBatchEmbeddingResponse
|
||||
err = common.Unmarshal(responseBody, &geminiResponse)
|
||||
if err != nil {
|
||||
return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
|
||||
}
|
||||
} else {
|
||||
var geminiResponse dto.GeminiEmbeddingResponse
|
||||
err = common.Unmarshal(responseBody, &geminiResponse)
|
||||
if err != nil {
|
||||
return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
|
||||
}
|
||||
}
|
||||
|
||||
common.IOCopyBytesGracefully(c, resp, responseBody)
|
||||
|
||||
return usage, nil
|
||||
}
|
||||
|
||||
func GeminiTextGenerationStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
|
||||
var usage = &dto.Usage{}
|
||||
var imageCount int
|
||||
@@ -71,7 +109,7 @@ func GeminiTextGenerationStreamHandler(c *gin.Context, info *relaycommon.RelayIn
|
||||
responseText := strings.Builder{}
|
||||
|
||||
helper.StreamScannerHandler(c, resp, info, func(data string) bool {
|
||||
var geminiResponse GeminiChatResponse
|
||||
var geminiResponse dto.GeminiChatResponse
|
||||
err := common.UnmarshalJsonStr(data, &geminiResponse)
|
||||
if err != nil {
|
||||
common.LogError(c, "error unmarshalling stream response: "+err.Error())
|
||||
@@ -110,10 +148,14 @@ func GeminiTextGenerationStreamHandler(c *gin.Context, info *relaycommon.RelayIn
|
||||
if err != nil {
|
||||
common.LogError(c, err.Error())
|
||||
}
|
||||
|
||||
info.SendResponseCount++
|
||||
return true
|
||||
})
|
||||
|
||||
if info.SendResponseCount == 0 {
|
||||
return nil, types.NewOpenAIError(errors.New("no response received from Gemini API"), types.ErrorCodeEmptyResponse, http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
if imageCount != 0 {
|
||||
if usage.CompletionTokens == 0 {
|
||||
usage.CompletionTokens = imageCount * 258
|
||||
|
||||
@@ -49,12 +49,20 @@ const (
|
||||
flash25LiteMaxBudget = 24576
|
||||
)
|
||||
|
||||
// clampThinkingBudget 根据模型名称将预算限制在允许的范围内
|
||||
func clampThinkingBudget(modelName string, budget int) int {
|
||||
isNew25Pro := strings.HasPrefix(modelName, "gemini-2.5-pro") &&
|
||||
func isNew25ProModel(modelName string) bool {
|
||||
return strings.HasPrefix(modelName, "gemini-2.5-pro") &&
|
||||
!strings.HasPrefix(modelName, "gemini-2.5-pro-preview-05-06") &&
|
||||
!strings.HasPrefix(modelName, "gemini-2.5-pro-preview-03-25")
|
||||
is25FlashLite := strings.HasPrefix(modelName, "gemini-2.5-flash-lite")
|
||||
}
|
||||
|
||||
func is25FlashLiteModel(modelName string) bool {
|
||||
return strings.HasPrefix(modelName, "gemini-2.5-flash-lite")
|
||||
}
|
||||
|
||||
// clampThinkingBudget 根据模型名称将预算限制在允许的范围内
|
||||
func clampThinkingBudget(modelName string, budget int) int {
|
||||
isNew25Pro := isNew25ProModel(modelName)
|
||||
is25FlashLite := is25FlashLiteModel(modelName)
|
||||
|
||||
if is25FlashLite {
|
||||
if budget < flash25LiteMinBudget {
|
||||
@@ -81,7 +89,34 @@ func clampThinkingBudget(modelName string, budget int) int {
|
||||
return budget
|
||||
}
|
||||
|
||||
func ThinkingAdaptor(geminiRequest *GeminiChatRequest, info *relaycommon.RelayInfo) {
|
||||
// "effort": "high" - Allocates a large portion of tokens for reasoning (approximately 80% of max_tokens)
|
||||
// "effort": "medium" - Allocates a moderate portion of tokens (approximately 50% of max_tokens)
|
||||
// "effort": "low" - Allocates a smaller portion of tokens (approximately 20% of max_tokens)
|
||||
func clampThinkingBudgetByEffort(modelName string, effort string) int {
|
||||
isNew25Pro := isNew25ProModel(modelName)
|
||||
is25FlashLite := is25FlashLiteModel(modelName)
|
||||
|
||||
maxBudget := 0
|
||||
if is25FlashLite {
|
||||
maxBudget = flash25LiteMaxBudget
|
||||
}
|
||||
if isNew25Pro {
|
||||
maxBudget = pro25MaxBudget
|
||||
} else {
|
||||
maxBudget = flash25MaxBudget
|
||||
}
|
||||
switch effort {
|
||||
case "high":
|
||||
maxBudget = maxBudget * 80 / 100
|
||||
case "medium":
|
||||
maxBudget = maxBudget * 50 / 100
|
||||
case "low":
|
||||
maxBudget = maxBudget * 20 / 100
|
||||
}
|
||||
return clampThinkingBudget(modelName, maxBudget)
|
||||
}
|
||||
|
||||
func ThinkingAdaptor(geminiRequest *dto.GeminiChatRequest, info *relaycommon.RelayInfo, oaiRequest ...dto.GeneralOpenAIRequest) {
|
||||
if model_setting.GetGeminiSettings().ThinkingAdapterEnabled {
|
||||
modelName := info.UpstreamModelName
|
||||
isNew25Pro := strings.HasPrefix(modelName, "gemini-2.5-pro") &&
|
||||
@@ -93,7 +128,7 @@ func ThinkingAdaptor(geminiRequest *GeminiChatRequest, info *relaycommon.RelayIn
|
||||
if len(parts) == 2 && parts[1] != "" {
|
||||
if budgetTokens, err := strconv.Atoi(parts[1]); err == nil {
|
||||
clampedBudget := clampThinkingBudget(modelName, budgetTokens)
|
||||
geminiRequest.GenerationConfig.ThinkingConfig = &GeminiThinkingConfig{
|
||||
geminiRequest.GenerationConfig.ThinkingConfig = &dto.GeminiThinkingConfig{
|
||||
ThinkingBudget: common.GetPointer(clampedBudget),
|
||||
IncludeThoughts: true,
|
||||
}
|
||||
@@ -113,22 +148,27 @@ func ThinkingAdaptor(geminiRequest *GeminiChatRequest, info *relaycommon.RelayIn
|
||||
}
|
||||
|
||||
if isUnsupported {
|
||||
geminiRequest.GenerationConfig.ThinkingConfig = &GeminiThinkingConfig{
|
||||
geminiRequest.GenerationConfig.ThinkingConfig = &dto.GeminiThinkingConfig{
|
||||
IncludeThoughts: true,
|
||||
}
|
||||
} else {
|
||||
geminiRequest.GenerationConfig.ThinkingConfig = &GeminiThinkingConfig{
|
||||
geminiRequest.GenerationConfig.ThinkingConfig = &dto.GeminiThinkingConfig{
|
||||
IncludeThoughts: true,
|
||||
}
|
||||
if geminiRequest.GenerationConfig.MaxOutputTokens > 0 {
|
||||
budgetTokens := model_setting.GetGeminiSettings().ThinkingAdapterBudgetTokensPercentage * float64(geminiRequest.GenerationConfig.MaxOutputTokens)
|
||||
clampedBudget := clampThinkingBudget(modelName, int(budgetTokens))
|
||||
geminiRequest.GenerationConfig.ThinkingConfig.ThinkingBudget = common.GetPointer(clampedBudget)
|
||||
} else {
|
||||
if len(oaiRequest) > 0 {
|
||||
// 如果有reasoningEffort参数,则根据其值设置思考预算
|
||||
geminiRequest.GenerationConfig.ThinkingConfig.ThinkingBudget = common.GetPointer(clampThinkingBudgetByEffort(modelName, oaiRequest[0].ReasoningEffort))
|
||||
}
|
||||
}
|
||||
}
|
||||
} else if strings.HasSuffix(modelName, "-nothinking") {
|
||||
if !isNew25Pro {
|
||||
geminiRequest.GenerationConfig.ThinkingConfig = &GeminiThinkingConfig{
|
||||
geminiRequest.GenerationConfig.ThinkingConfig = &dto.GeminiThinkingConfig{
|
||||
ThinkingBudget: common.GetPointer(0),
|
||||
}
|
||||
}
|
||||
@@ -137,14 +177,14 @@ func ThinkingAdaptor(geminiRequest *GeminiChatRequest, info *relaycommon.RelayIn
|
||||
}
|
||||
|
||||
// Setting safety to the lowest possible values since Gemini is already powerless enough
|
||||
func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon.RelayInfo) (*GeminiChatRequest, error) {
|
||||
func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon.RelayInfo) (*dto.GeminiChatRequest, error) {
|
||||
|
||||
geminiRequest := GeminiChatRequest{
|
||||
Contents: make([]GeminiChatContent, 0, len(textRequest.Messages)),
|
||||
GenerationConfig: GeminiChatGenerationConfig{
|
||||
geminiRequest := dto.GeminiChatRequest{
|
||||
Contents: make([]dto.GeminiChatContent, 0, len(textRequest.Messages)),
|
||||
GenerationConfig: dto.GeminiChatGenerationConfig{
|
||||
Temperature: textRequest.Temperature,
|
||||
TopP: textRequest.TopP,
|
||||
MaxOutputTokens: textRequest.MaxTokens,
|
||||
MaxOutputTokens: textRequest.GetMaxTokens(),
|
||||
Seed: int64(textRequest.Seed),
|
||||
},
|
||||
}
|
||||
@@ -156,11 +196,41 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon
|
||||
}
|
||||
}
|
||||
|
||||
ThinkingAdaptor(&geminiRequest, info)
|
||||
adaptorWithExtraBody := false
|
||||
|
||||
safetySettings := make([]GeminiChatSafetySettings, 0, len(SafetySettingList))
|
||||
if len(textRequest.ExtraBody) > 0 {
|
||||
if !strings.HasSuffix(info.UpstreamModelName, "-nothinking") {
|
||||
var extraBody map[string]interface{}
|
||||
if err := common.Unmarshal(textRequest.ExtraBody, &extraBody); err != nil {
|
||||
return nil, fmt.Errorf("invalid extra body: %w", err)
|
||||
}
|
||||
// eg. {"google":{"thinking_config":{"thinking_budget":5324,"include_thoughts":true}}}
|
||||
if googleBody, ok := extraBody["google"].(map[string]interface{}); ok {
|
||||
adaptorWithExtraBody = true
|
||||
if thinkingConfig, ok := googleBody["thinking_config"].(map[string]interface{}); ok {
|
||||
if budget, ok := thinkingConfig["thinking_budget"].(float64); ok {
|
||||
budgetInt := int(budget)
|
||||
geminiRequest.GenerationConfig.ThinkingConfig = &dto.GeminiThinkingConfig{
|
||||
ThinkingBudget: common.GetPointer(budgetInt),
|
||||
IncludeThoughts: true,
|
||||
}
|
||||
} else {
|
||||
geminiRequest.GenerationConfig.ThinkingConfig = &dto.GeminiThinkingConfig{
|
||||
IncludeThoughts: true,
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !adaptorWithExtraBody {
|
||||
ThinkingAdaptor(&geminiRequest, info, textRequest)
|
||||
}
|
||||
|
||||
safetySettings := make([]dto.GeminiChatSafetySettings, 0, len(SafetySettingList))
|
||||
for _, category := range SafetySettingList {
|
||||
safetySettings = append(safetySettings, GeminiChatSafetySettings{
|
||||
safetySettings = append(safetySettings, dto.GeminiChatSafetySettings{
|
||||
Category: category,
|
||||
Threshold: model_setting.GetGeminiSafetySetting(category),
|
||||
})
|
||||
@@ -198,17 +268,17 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon
|
||||
functions = append(functions, tool.Function)
|
||||
}
|
||||
if codeExecution {
|
||||
geminiRequest.Tools = append(geminiRequest.Tools, GeminiChatTool{
|
||||
geminiRequest.Tools = append(geminiRequest.Tools, dto.GeminiChatTool{
|
||||
CodeExecution: make(map[string]string),
|
||||
})
|
||||
}
|
||||
if googleSearch {
|
||||
geminiRequest.Tools = append(geminiRequest.Tools, GeminiChatTool{
|
||||
geminiRequest.Tools = append(geminiRequest.Tools, dto.GeminiChatTool{
|
||||
GoogleSearch: make(map[string]string),
|
||||
})
|
||||
}
|
||||
if len(functions) > 0 {
|
||||
geminiRequest.Tools = append(geminiRequest.Tools, GeminiChatTool{
|
||||
geminiRequest.Tools = append(geminiRequest.Tools, dto.GeminiChatTool{
|
||||
FunctionDeclarations: functions,
|
||||
})
|
||||
}
|
||||
@@ -238,7 +308,7 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon
|
||||
continue
|
||||
} else if message.Role == "tool" || message.Role == "function" {
|
||||
if len(geminiRequest.Contents) == 0 || geminiRequest.Contents[len(geminiRequest.Contents)-1].Role == "model" {
|
||||
geminiRequest.Contents = append(geminiRequest.Contents, GeminiChatContent{
|
||||
geminiRequest.Contents = append(geminiRequest.Contents, dto.GeminiChatContent{
|
||||
Role: "user",
|
||||
})
|
||||
}
|
||||
@@ -265,18 +335,18 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon
|
||||
}
|
||||
}
|
||||
|
||||
functionResp := &FunctionResponse{
|
||||
functionResp := &dto.GeminiFunctionResponse{
|
||||
Name: name,
|
||||
Response: contentMap,
|
||||
}
|
||||
|
||||
*parts = append(*parts, GeminiPart{
|
||||
*parts = append(*parts, dto.GeminiPart{
|
||||
FunctionResponse: functionResp,
|
||||
})
|
||||
continue
|
||||
}
|
||||
var parts []GeminiPart
|
||||
content := GeminiChatContent{
|
||||
var parts []dto.GeminiPart
|
||||
content := dto.GeminiChatContent{
|
||||
Role: message.Role,
|
||||
}
|
||||
// isToolCall := false
|
||||
@@ -290,8 +360,8 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon
|
||||
return nil, fmt.Errorf("invalid arguments for function %s, args: %s", call.Function.Name, call.Function.Arguments)
|
||||
}
|
||||
}
|
||||
toolCall := GeminiPart{
|
||||
FunctionCall: &FunctionCall{
|
||||
toolCall := dto.GeminiPart{
|
||||
FunctionCall: &dto.FunctionCall{
|
||||
FunctionName: call.Function.Name,
|
||||
Arguments: args,
|
||||
},
|
||||
@@ -308,7 +378,7 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon
|
||||
if part.Text == "" {
|
||||
continue
|
||||
}
|
||||
parts = append(parts, GeminiPart{
|
||||
parts = append(parts, dto.GeminiPart{
|
||||
Text: part.Text,
|
||||
})
|
||||
} else if part.Type == dto.ContentTypeImageURL {
|
||||
@@ -331,8 +401,8 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon
|
||||
return nil, fmt.Errorf("mime type is not supported by Gemini: '%s', url: '%s', supported types are: %v", fileData.MimeType, url, getSupportedMimeTypesList())
|
||||
}
|
||||
|
||||
parts = append(parts, GeminiPart{
|
||||
InlineData: &GeminiInlineData{
|
||||
parts = append(parts, dto.GeminiPart{
|
||||
InlineData: &dto.GeminiInlineData{
|
||||
MimeType: fileData.MimeType, // 使用原始的 MimeType,因为大小写可能对API有意义
|
||||
Data: fileData.Base64Data,
|
||||
},
|
||||
@@ -342,8 +412,8 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("decode base64 image data failed: %s", err.Error())
|
||||
}
|
||||
parts = append(parts, GeminiPart{
|
||||
InlineData: &GeminiInlineData{
|
||||
parts = append(parts, dto.GeminiPart{
|
||||
InlineData: &dto.GeminiInlineData{
|
||||
MimeType: format,
|
||||
Data: base64String,
|
||||
},
|
||||
@@ -357,8 +427,8 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("decode base64 file data failed: %s", err.Error())
|
||||
}
|
||||
parts = append(parts, GeminiPart{
|
||||
InlineData: &GeminiInlineData{
|
||||
parts = append(parts, dto.GeminiPart{
|
||||
InlineData: &dto.GeminiInlineData{
|
||||
MimeType: format,
|
||||
Data: base64String,
|
||||
},
|
||||
@@ -371,8 +441,8 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("decode base64 audio data failed: %s", err.Error())
|
||||
}
|
||||
parts = append(parts, GeminiPart{
|
||||
InlineData: &GeminiInlineData{
|
||||
parts = append(parts, dto.GeminiPart{
|
||||
InlineData: &dto.GeminiInlineData{
|
||||
MimeType: "audio/" + part.GetInputAudio().Format,
|
||||
Data: base64String,
|
||||
},
|
||||
@@ -392,8 +462,8 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon
|
||||
}
|
||||
|
||||
if len(system_content) > 0 {
|
||||
geminiRequest.SystemInstructions = &GeminiChatContent{
|
||||
Parts: []GeminiPart{
|
||||
geminiRequest.SystemInstructions = &dto.GeminiChatContent{
|
||||
Parts: []dto.GeminiPart{
|
||||
{
|
||||
Text: strings.Join(system_content, "\n"),
|
||||
},
|
||||
@@ -636,7 +706,7 @@ func unescapeMapOrSlice(data interface{}) interface{} {
|
||||
return data
|
||||
}
|
||||
|
||||
func getResponseToolCall(item *GeminiPart) *dto.ToolCallResponse {
|
||||
func getResponseToolCall(item *dto.GeminiPart) *dto.ToolCallResponse {
|
||||
var argsBytes []byte
|
||||
var err error
|
||||
if result, ok := item.FunctionCall.Arguments.(map[string]interface{}); ok {
|
||||
@@ -658,7 +728,7 @@ func getResponseToolCall(item *GeminiPart) *dto.ToolCallResponse {
|
||||
}
|
||||
}
|
||||
|
||||
func responseGeminiChat2OpenAI(c *gin.Context, response *GeminiChatResponse) *dto.OpenAITextResponse {
|
||||
func responseGeminiChat2OpenAI(c *gin.Context, response *dto.GeminiChatResponse) *dto.OpenAITextResponse {
|
||||
fullTextResponse := dto.OpenAITextResponse{
|
||||
Id: helper.GetResponseID(c),
|
||||
Object: "chat.completion",
|
||||
@@ -725,10 +795,9 @@ func responseGeminiChat2OpenAI(c *gin.Context, response *GeminiChatResponse) *dt
|
||||
return &fullTextResponse
|
||||
}
|
||||
|
||||
func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) (*dto.ChatCompletionsStreamResponse, bool, bool) {
|
||||
func streamResponseGeminiChat2OpenAI(geminiResponse *dto.GeminiChatResponse) (*dto.ChatCompletionsStreamResponse, bool) {
|
||||
choices := make([]dto.ChatCompletionsStreamResponseChoice, 0, len(geminiResponse.Candidates))
|
||||
isStop := false
|
||||
hasImage := false
|
||||
for _, candidate := range geminiResponse.Candidates {
|
||||
if candidate.FinishReason != nil && *candidate.FinishReason == "STOP" {
|
||||
isStop = true
|
||||
@@ -759,7 +828,6 @@ func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) (*dto.C
|
||||
if strings.HasPrefix(part.InlineData.MimeType, "image") {
|
||||
imgText := ""
|
||||
texts = append(texts, imgText)
|
||||
hasImage = true
|
||||
}
|
||||
} else if part.FunctionCall != nil {
|
||||
isTools = true
|
||||
@@ -767,6 +835,7 @@ func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) (*dto.C
|
||||
call.SetIndex(len(choice.Delta.ToolCalls))
|
||||
choice.Delta.ToolCalls = append(choice.Delta.ToolCalls, *call)
|
||||
}
|
||||
|
||||
} else if part.Thought {
|
||||
isThought = true
|
||||
texts = append(texts, part.Text)
|
||||
@@ -796,7 +865,7 @@ func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) (*dto.C
|
||||
var response dto.ChatCompletionsStreamResponse
|
||||
response.Object = "chat.completion.chunk"
|
||||
response.Choices = choices
|
||||
return &response, isStop, hasImage
|
||||
return &response, isStop
|
||||
}
|
||||
|
||||
func handleStream(c *gin.Context, info *relaycommon.RelayInfo, resp *dto.ChatCompletionsStreamResponse) error {
|
||||
@@ -816,7 +885,7 @@ func handleFinalStream(c *gin.Context, info *relaycommon.RelayInfo, resp *dto.Ch
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal stream response: %w", err)
|
||||
}
|
||||
openai.HandleFinalResponse(c, info, string(streamData), resp.Id, resp.Created, resp.Model, resp.GetSystemFingerprint(), resp.Usage, info.ShouldIncludeUsage)
|
||||
openai.HandleFinalResponse(c, info, string(streamData), resp.Id, resp.Created, resp.Model, resp.GetSystemFingerprint(), resp.Usage, false)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -824,23 +893,32 @@ func GeminiChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *
|
||||
// responseText := ""
|
||||
id := helper.GetResponseID(c)
|
||||
createAt := common.GetTimestamp()
|
||||
responseText := strings.Builder{}
|
||||
var usage = &dto.Usage{}
|
||||
var imageCount int
|
||||
|
||||
respCount := 0
|
||||
finishReason := constant.FinishReasonStop
|
||||
|
||||
helper.StreamScannerHandler(c, resp, info, func(data string) bool {
|
||||
var geminiResponse GeminiChatResponse
|
||||
var geminiResponse dto.GeminiChatResponse
|
||||
err := common.UnmarshalJsonStr(data, &geminiResponse)
|
||||
if err != nil {
|
||||
common.LogError(c, "error unmarshalling stream response: "+err.Error())
|
||||
return false
|
||||
}
|
||||
|
||||
response, isStop, hasImage := streamResponseGeminiChat2OpenAI(&geminiResponse)
|
||||
if hasImage {
|
||||
imageCount++
|
||||
for _, candidate := range geminiResponse.Candidates {
|
||||
for _, part := range candidate.Content.Parts {
|
||||
if part.InlineData != nil && part.InlineData.MimeType != "" {
|
||||
imageCount++
|
||||
}
|
||||
if part.Text != "" {
|
||||
responseText.WriteString(part.Text)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
response, isStop := streamResponseGeminiChat2OpenAI(&geminiResponse)
|
||||
|
||||
response.Id = id
|
||||
response.Created = createAt
|
||||
response.Model = info.UpstreamModelName
|
||||
@@ -858,11 +936,23 @@ func GeminiChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *
|
||||
}
|
||||
}
|
||||
|
||||
if respCount == 0 {
|
||||
if info.SendResponseCount == 0 {
|
||||
// send first response
|
||||
err = handleStream(c, info, helper.GenerateStartEmptyResponse(id, createAt, info.UpstreamModelName, nil))
|
||||
if err != nil {
|
||||
common.LogError(c, err.Error())
|
||||
emptyResponse := helper.GenerateStartEmptyResponse(id, createAt, info.UpstreamModelName, nil)
|
||||
if response.IsToolCall() {
|
||||
emptyResponse.Choices[0].Delta.ToolCalls = make([]dto.ToolCallResponse, 1)
|
||||
emptyResponse.Choices[0].Delta.ToolCalls[0] = *response.GetFirstToolCall()
|
||||
emptyResponse.Choices[0].Delta.ToolCalls[0].Function.Arguments = ""
|
||||
finishReason = constant.FinishReasonToolCalls
|
||||
err = handleStream(c, info, emptyResponse)
|
||||
if err != nil {
|
||||
common.LogError(c, err.Error())
|
||||
}
|
||||
|
||||
response.ClearToolCalls()
|
||||
if response.IsFinished() {
|
||||
response.Choices[0].FinishReason = nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -871,13 +961,12 @@ func GeminiChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *
|
||||
common.LogError(c, err.Error())
|
||||
}
|
||||
if isStop {
|
||||
_ = handleStream(c, info, helper.GenerateStopResponse(id, createAt, info.UpstreamModelName, constant.FinishReasonStop))
|
||||
_ = handleStream(c, info, helper.GenerateStopResponse(id, createAt, info.UpstreamModelName, finishReason))
|
||||
}
|
||||
respCount++
|
||||
return true
|
||||
})
|
||||
|
||||
if respCount == 0 {
|
||||
if info.SendResponseCount == 0 {
|
||||
// 空补全,报错不计费
|
||||
// empty response, throw an error
|
||||
return nil, types.NewOpenAIError(errors.New("no response received from Gemini API"), types.ErrorCodeEmptyResponse, http.StatusInternalServerError)
|
||||
@@ -892,6 +981,16 @@ func GeminiChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *
|
||||
usage.PromptTokensDetails.TextTokens = usage.PromptTokens
|
||||
usage.CompletionTokens = usage.TotalTokens - usage.PromptTokens
|
||||
|
||||
if usage.CompletionTokens == 0 {
|
||||
str := responseText.String()
|
||||
if len(str) > 0 {
|
||||
usage = service.ResponseText2Usage(responseText.String(), info.UpstreamModelName, info.PromptTokens)
|
||||
} else {
|
||||
// 空补全,不需要使用量
|
||||
usage = &dto.Usage{}
|
||||
}
|
||||
}
|
||||
|
||||
response := helper.GenerateFinalUsageResponse(id, createAt, info.UpstreamModelName, *usage)
|
||||
err := handleFinalStream(c, info, response)
|
||||
if err != nil {
|
||||
@@ -913,7 +1012,7 @@ func GeminiChatHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.R
|
||||
if common.DebugEnabled {
|
||||
println(string(responseBody))
|
||||
}
|
||||
var geminiResponse GeminiChatResponse
|
||||
var geminiResponse dto.GeminiChatResponse
|
||||
err = common.Unmarshal(responseBody, &geminiResponse)
|
||||
if err != nil {
|
||||
return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
|
||||
@@ -941,13 +1040,26 @@ func GeminiChatHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.R
|
||||
}
|
||||
|
||||
fullTextResponse.Usage = usage
|
||||
jsonResponse, err := json.Marshal(fullTextResponse)
|
||||
if err != nil {
|
||||
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
|
||||
|
||||
switch info.RelayFormat {
|
||||
case relaycommon.RelayFormatOpenAI:
|
||||
responseBody, err = common.Marshal(fullTextResponse)
|
||||
if err != nil {
|
||||
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
|
||||
}
|
||||
case relaycommon.RelayFormatClaude:
|
||||
claudeResp := service.ResponseOpenAI2Claude(fullTextResponse, info)
|
||||
claudeRespStr, err := common.Marshal(claudeResp)
|
||||
if err != nil {
|
||||
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
|
||||
}
|
||||
responseBody = claudeRespStr
|
||||
case relaycommon.RelayFormatGemini:
|
||||
break
|
||||
}
|
||||
c.Writer.Header().Set("Content-Type", "application/json")
|
||||
c.Writer.WriteHeader(resp.StatusCode)
|
||||
c.Writer.Write(jsonResponse)
|
||||
|
||||
common.IOCopyBytesGracefully(c, resp, responseBody)
|
||||
|
||||
return &usage, nil
|
||||
}
|
||||
|
||||
@@ -959,7 +1071,7 @@ func GeminiEmbeddingHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *h
|
||||
return nil, types.NewOpenAIError(readErr, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
var geminiResponse GeminiEmbeddingResponse
|
||||
var geminiResponse dto.GeminiBatchEmbeddingResponse
|
||||
if jsonErr := common.Unmarshal(responseBody, &geminiResponse); jsonErr != nil {
|
||||
return nil, types.NewOpenAIError(jsonErr, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
|
||||
}
|
||||
@@ -967,14 +1079,16 @@ func GeminiEmbeddingHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *h
|
||||
// convert to openai format response
|
||||
openAIResponse := dto.OpenAIEmbeddingResponse{
|
||||
Object: "list",
|
||||
Data: []dto.OpenAIEmbeddingResponseItem{
|
||||
{
|
||||
Object: "embedding",
|
||||
Embedding: geminiResponse.Embedding.Values,
|
||||
Index: 0,
|
||||
},
|
||||
},
|
||||
Model: info.UpstreamModelName,
|
||||
Data: make([]dto.OpenAIEmbeddingResponseItem, 0, len(geminiResponse.Embeddings)),
|
||||
Model: info.UpstreamModelName,
|
||||
}
|
||||
|
||||
for i, embedding := range geminiResponse.Embeddings {
|
||||
openAIResponse.Data = append(openAIResponse.Data, dto.OpenAIEmbeddingResponseItem{
|
||||
Object: "embedding",
|
||||
Embedding: embedding.Values,
|
||||
Index: i,
|
||||
})
|
||||
}
|
||||
|
||||
// calculate usage
|
||||
@@ -1005,7 +1119,7 @@ func GeminiImageHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.
|
||||
}
|
||||
_ = resp.Body.Close()
|
||||
|
||||
var geminiResponse GeminiImageResponse
|
||||
var geminiResponse dto.GeminiImageResponse
|
||||
if jsonErr := common.Unmarshal(responseBody, &geminiResponse); jsonErr != nil {
|
||||
return nil, types.NewOpenAIError(jsonErr, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
@@ -4,7 +4,6 @@ import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/dto"
|
||||
@@ -13,11 +12,18 @@ import (
|
||||
relaycommon "one-api/relay/common"
|
||||
relayconstant "one-api/relay/constant"
|
||||
"one-api/types"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type Adaptor struct {
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
|
||||
//TODO implement me
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
@@ -19,6 +19,11 @@ import (
|
||||
type Adaptor struct {
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
|
||||
//TODO implement me
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
|
||||
//TODO implement me
|
||||
panic("implement me")
|
||||
|
||||
@@ -16,6 +16,11 @@ import (
|
||||
type Adaptor struct {
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
|
||||
//TODO implement me
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
|
||||
//TODO implement me
|
||||
panic("implement me")
|
||||
|
||||
@@ -71,7 +71,7 @@ func requestOpenAI2Mistral(request *dto.GeneralOpenAIRequest) *dto.GeneralOpenAI
|
||||
Messages: messages,
|
||||
Temperature: request.Temperature,
|
||||
TopP: request.TopP,
|
||||
MaxTokens: request.MaxTokens,
|
||||
MaxTokens: request.GetMaxTokens(),
|
||||
Tools: request.Tools,
|
||||
ToolChoice: request.ToolChoice,
|
||||
}
|
||||
|
||||
@@ -18,6 +18,11 @@ import (
|
||||
type Adaptor struct {
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
|
||||
//TODO implement me
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
|
||||
//TODO implement me
|
||||
panic("implement me")
|
||||
|
||||
111
relay/channel/moonshot/adaptor.go
Normal file
111
relay/channel/moonshot/adaptor.go
Normal file
@@ -0,0 +1,111 @@
|
||||
package moonshot
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/dto"
|
||||
"one-api/relay/channel"
|
||||
"one-api/relay/channel/claude"
|
||||
"one-api/relay/channel/openai"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/relay/constant"
|
||||
"one-api/types"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type Adaptor struct {
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
|
||||
//TODO implement me
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, req *dto.ClaudeRequest) (any, error) {
|
||||
adaptor := openai.Adaptor{}
|
||||
return adaptor.ConvertClaudeRequest(c, info, req)
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
|
||||
//TODO implement me
|
||||
return nil, errors.New("not supported")
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
|
||||
adaptor := openai.Adaptor{}
|
||||
return adaptor.ConvertImageRequest(c, info, request)
|
||||
}
|
||||
|
||||
func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
switch info.RelayFormat {
|
||||
case relaycommon.RelayFormatClaude:
|
||||
return fmt.Sprintf("%s/anthropic/v1/messages", info.BaseUrl), nil
|
||||
default:
|
||||
if info.RelayMode == constant.RelayModeRerank {
|
||||
return fmt.Sprintf("%s/v1/rerank", info.BaseUrl), nil
|
||||
} else if info.RelayMode == constant.RelayModeEmbeddings {
|
||||
return fmt.Sprintf("%s/v1/embeddings", info.BaseUrl), nil
|
||||
} else if info.RelayMode == constant.RelayModeChatCompletions {
|
||||
return fmt.Sprintf("%s/v1/chat/completions", info.BaseUrl), nil
|
||||
} else if info.RelayMode == constant.RelayModeCompletions {
|
||||
return fmt.Sprintf("%s/v1/completions", info.BaseUrl), nil
|
||||
}
|
||||
return fmt.Sprintf("%s/v1/chat/completions", info.BaseUrl), nil
|
||||
}
|
||||
}
|
||||
|
||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
|
||||
channel.SetupApiRequestHeader(info, c, req)
|
||||
req.Set("Authorization", fmt.Sprintf("Bearer %s", info.ApiKey))
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
|
||||
return request, nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) {
|
||||
// TODO implement me
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
|
||||
return channel.DoApiRequest(a, c, info, requestBody)
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
|
||||
return request, nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
|
||||
return request, nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
|
||||
switch info.RelayFormat {
|
||||
case relaycommon.RelayFormatOpenAI:
|
||||
adaptor := openai.Adaptor{}
|
||||
return adaptor.DoResponse(c, resp, info)
|
||||
case relaycommon.RelayFormatClaude:
|
||||
if info.IsStream {
|
||||
err, usage = claude.ClaudeStreamHandler(c, resp, info, claude.RequestModeMessage)
|
||||
} else {
|
||||
err, usage = claude.ClaudeHandler(c, resp, info, claude.RequestModeMessage)
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetModelList() []string {
|
||||
return ModelList
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetChannelName() string {
|
||||
return ChannelName
|
||||
}
|
||||
@@ -17,6 +17,11 @@ import (
|
||||
type Adaptor struct {
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
|
||||
//TODO implement me
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.ClaudeRequest) (any, error) {
|
||||
openaiAdaptor := openai.Adaptor{}
|
||||
openaiRequest, err := openaiAdaptor.ConvertClaudeRequest(c, info, request)
|
||||
|
||||
@@ -60,7 +60,7 @@ func requestOpenAI2Ollama(request *dto.GeneralOpenAIRequest) (*OllamaRequest, er
|
||||
TopK: request.TopK,
|
||||
Stop: Stop,
|
||||
Tools: request.Tools,
|
||||
MaxTokens: request.MaxTokens,
|
||||
MaxTokens: request.GetMaxTokens(),
|
||||
ResponseFormat: request.ResponseFormat,
|
||||
FrequencyPenalty: request.FrequencyPenalty,
|
||||
PresencePenalty: request.PresencePenalty,
|
||||
|
||||
@@ -9,13 +9,13 @@ import (
|
||||
"mime/multipart"
|
||||
"net/http"
|
||||
"net/textproto"
|
||||
"one-api/common"
|
||||
"one-api/constant"
|
||||
"one-api/dto"
|
||||
"one-api/relay/channel"
|
||||
"one-api/relay/channel/ai360"
|
||||
"one-api/relay/channel/lingyiwanwu"
|
||||
"one-api/relay/channel/minimax"
|
||||
"one-api/relay/channel/moonshot"
|
||||
"one-api/relay/channel/openrouter"
|
||||
"one-api/relay/channel/xinference"
|
||||
relaycommon "one-api/relay/common"
|
||||
@@ -34,15 +34,55 @@ type Adaptor struct {
|
||||
ResponseFormat string
|
||||
}
|
||||
|
||||
// parseReasoningEffortFromModelSuffix 从模型名称中解析推理级别
|
||||
// support OAI models: o1-mini/o3-mini/o4-mini/o1/o3 etc...
|
||||
// minimal effort only available in gpt-5
|
||||
func parseReasoningEffortFromModelSuffix(model string) (string, string) {
|
||||
effortSuffixes := []string{"-high", "-minimal", "-low", "-medium"}
|
||||
for _, suffix := range effortSuffixes {
|
||||
if strings.HasSuffix(model, suffix) {
|
||||
effort := strings.TrimPrefix(suffix, "-")
|
||||
originModel := strings.TrimSuffix(model, suffix)
|
||||
return effort, originModel
|
||||
}
|
||||
}
|
||||
return "", model
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertGeminiRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeminiChatRequest) (any, error) {
|
||||
// 使用 service.GeminiToOpenAIRequest 转换请求格式
|
||||
openaiRequest, err := service.GeminiToOpenAIRequest(request, info)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return a.ConvertOpenAIRequest(c, info, openaiRequest)
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.ClaudeRequest) (any, error) {
|
||||
//if !strings.Contains(request.Model, "claude") {
|
||||
// return nil, fmt.Errorf("you are using openai channel type with path /v1/messages, only claude model supported convert, but got %s", request.Model)
|
||||
//}
|
||||
//if common.DebugEnabled {
|
||||
// bodyBytes := []byte(common.GetJsonString(request))
|
||||
// err := os.WriteFile(fmt.Sprintf("claude_request_%s.txt", c.GetString(common.RequestIdKey)), bodyBytes, 0644)
|
||||
// if err != nil {
|
||||
// println(fmt.Sprintf("failed to save request body to file: %v", err))
|
||||
// }
|
||||
//}
|
||||
aiRequest, err := service.ClaudeToOpenAIRequest(*request, info)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if info.SupportStreamOptions {
|
||||
//if common.DebugEnabled {
|
||||
// println(fmt.Sprintf("convert claude to openai request result: %s", common.GetJsonString(aiRequest)))
|
||||
// // Save request body to file for debugging
|
||||
// bodyBytes := []byte(common.GetJsonString(aiRequest))
|
||||
// err = os.WriteFile(fmt.Sprintf("claude_to_openai_request_%s.txt", c.GetString(common.RequestIdKey)), bodyBytes, 0644)
|
||||
// if err != nil {
|
||||
// println(fmt.Sprintf("failed to save request body to file: %v", err))
|
||||
// }
|
||||
//}
|
||||
if info.SupportStreamOptions && info.IsStream {
|
||||
aiRequest.StreamOptions = &dto.StreamOptions{
|
||||
IncludeUsage: true,
|
||||
}
|
||||
@@ -64,9 +104,6 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
if info.RelayFormat == relaycommon.RelayFormatClaude {
|
||||
return fmt.Sprintf("%s/v1/chat/completions", info.BaseUrl), nil
|
||||
}
|
||||
if info.RelayMode == relayconstant.RelayModeRealtime {
|
||||
if strings.HasPrefix(info.BaseUrl, "https://") {
|
||||
baseUrl := strings.TrimPrefix(info.BaseUrl, "https://")
|
||||
@@ -113,6 +150,9 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
url = strings.Replace(url, "{model}", info.UpstreamModelName, -1)
|
||||
return url, nil
|
||||
default:
|
||||
if info.RelayFormat == relaycommon.RelayFormatClaude || info.RelayFormat == relaycommon.RelayFormatGemini {
|
||||
return fmt.Sprintf("%s/v1/chat/completions", info.BaseUrl), nil
|
||||
}
|
||||
return relaycommon.GetFullRequestURL(info.BaseUrl, info.RequestURLPath, info.ChannelType), nil
|
||||
}
|
||||
}
|
||||
@@ -163,23 +203,65 @@ func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayIn
|
||||
if len(request.Usage) == 0 {
|
||||
request.Usage = json.RawMessage(`{"include":true}`)
|
||||
}
|
||||
// 适配 OpenRouter 的 thinking 后缀
|
||||
if strings.HasSuffix(info.UpstreamModelName, "-thinking") {
|
||||
info.UpstreamModelName = strings.TrimSuffix(info.UpstreamModelName, "-thinking")
|
||||
request.Model = info.UpstreamModelName
|
||||
if len(request.Reasoning) == 0 {
|
||||
reasoning := map[string]any{
|
||||
"enabled": true,
|
||||
}
|
||||
if request.ReasoningEffort != "" && request.ReasoningEffort != "none" {
|
||||
reasoning["effort"] = request.ReasoningEffort
|
||||
}
|
||||
marshal, err := common.Marshal(reasoning)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error marshalling reasoning: %w", err)
|
||||
}
|
||||
request.Reasoning = marshal
|
||||
}
|
||||
} else {
|
||||
if len(request.Reasoning) == 0 {
|
||||
// 适配 OpenAI 的 ReasoningEffort 格式
|
||||
if request.ReasoningEffort != "" {
|
||||
reasoning := map[string]any{
|
||||
"enabled": true,
|
||||
}
|
||||
if request.ReasoningEffort != "none" {
|
||||
reasoning["effort"] = request.ReasoningEffort
|
||||
marshal, err := common.Marshal(reasoning)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error marshalling reasoning: %w", err)
|
||||
}
|
||||
request.Reasoning = marshal
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if strings.HasPrefix(request.Model, "o") {
|
||||
if strings.HasPrefix(request.Model, "o") || strings.HasPrefix(request.Model, "gpt-5") {
|
||||
if request.MaxCompletionTokens == 0 && request.MaxTokens != 0 {
|
||||
request.MaxCompletionTokens = request.MaxTokens
|
||||
request.MaxTokens = 0
|
||||
}
|
||||
request.Temperature = nil
|
||||
if strings.HasSuffix(request.Model, "-high") {
|
||||
request.ReasoningEffort = "high"
|
||||
request.Model = strings.TrimSuffix(request.Model, "-high")
|
||||
} else if strings.HasSuffix(request.Model, "-low") {
|
||||
request.ReasoningEffort = "low"
|
||||
request.Model = strings.TrimSuffix(request.Model, "-low")
|
||||
} else if strings.HasSuffix(request.Model, "-medium") {
|
||||
request.ReasoningEffort = "medium"
|
||||
request.Model = strings.TrimSuffix(request.Model, "-medium")
|
||||
|
||||
if strings.HasPrefix(request.Model, "o") {
|
||||
request.Temperature = nil
|
||||
}
|
||||
|
||||
if strings.HasPrefix(request.Model, "gpt-5") {
|
||||
if request.Model != "gpt-5-chat-latest" {
|
||||
request.Temperature = nil
|
||||
}
|
||||
}
|
||||
|
||||
// 转换模型推理力度后缀
|
||||
effort, originModel := parseReasoningEffortFromModelSuffix(request.Model)
|
||||
if effort != "" {
|
||||
request.ReasoningEffort = effort
|
||||
request.Model = originModel
|
||||
}
|
||||
|
||||
info.ReasoningEffort = request.ReasoningEffort
|
||||
info.UpstreamModelName = request.Model
|
||||
|
||||
@@ -396,16 +478,11 @@ func detectImageMimeType(filename string) string {
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) {
|
||||
// 模型后缀转换 reasoning effort
|
||||
if strings.HasSuffix(request.Model, "-high") {
|
||||
request.Reasoning.Effort = "high"
|
||||
request.Model = strings.TrimSuffix(request.Model, "-high")
|
||||
} else if strings.HasSuffix(request.Model, "-low") {
|
||||
request.Reasoning.Effort = "low"
|
||||
request.Model = strings.TrimSuffix(request.Model, "-low")
|
||||
} else if strings.HasSuffix(request.Model, "-medium") {
|
||||
request.Reasoning.Effort = "medium"
|
||||
request.Model = strings.TrimSuffix(request.Model, "-medium")
|
||||
// 转换模型推理力度后缀
|
||||
effort, originModel := parseReasoningEffortFromModelSuffix(request.Model)
|
||||
if effort != "" {
|
||||
request.Reasoning.Effort = effort
|
||||
request.Model = originModel
|
||||
}
|
||||
return request, nil
|
||||
}
|
||||
@@ -456,8 +533,6 @@ func (a *Adaptor) GetModelList() []string {
|
||||
switch a.ChannelType {
|
||||
case constant.ChannelType360:
|
||||
return ai360.ModelList
|
||||
case constant.ChannelTypeMoonshot:
|
||||
return moonshot.ModelList
|
||||
case constant.ChannelTypeLingYiWanWu:
|
||||
return lingyiwanwu.ModelList
|
||||
case constant.ChannelTypeMiniMax:
|
||||
@@ -475,8 +550,6 @@ func (a *Adaptor) GetChannelName() string {
|
||||
switch a.ChannelType {
|
||||
case constant.ChannelType360:
|
||||
return ai360.ChannelName
|
||||
case constant.ChannelTypeMoonshot:
|
||||
return moonshot.ChannelName
|
||||
case constant.ChannelTypeLingYiWanWu:
|
||||
return lingyiwanwu.ChannelName
|
||||
case constant.ChannelTypeMiniMax:
|
||||
|
||||
@@ -18,6 +18,9 @@ var ModelList = []string{
|
||||
"o3-mini-high", "o3-mini-2025-01-31-high",
|
||||
"o3-mini-low", "o3-mini-2025-01-31-low",
|
||||
"o3-mini-medium", "o3-mini-2025-01-31-medium",
|
||||
"gpt-5", "gpt-5-2025-08-07", "gpt-5-chat-latest",
|
||||
"gpt-5-mini", "gpt-5-mini-2025-08-07",
|
||||
"gpt-5-nano", "gpt-5-nano-2025-08-07",
|
||||
"o1", "o1-2024-12-17",
|
||||
"gpt-4o-audio-preview", "gpt-4o-audio-preview-2024-10-01",
|
||||
"gpt-4o-realtime-preview", "gpt-4o-realtime-preview-2024-10-01", "gpt-4o-realtime-preview-2024-12-17",
|
||||
|
||||
@@ -2,6 +2,9 @@ package openai
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"github.com/samber/lo"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/dto"
|
||||
relaycommon "one-api/relay/common"
|
||||
@@ -16,11 +19,14 @@ import (
|
||||
// 辅助函数
|
||||
func HandleStreamFormat(c *gin.Context, info *relaycommon.RelayInfo, data string, forceFormat bool, thinkToContent bool) error {
|
||||
info.SendResponseCount++
|
||||
|
||||
switch info.RelayFormat {
|
||||
case relaycommon.RelayFormatOpenAI:
|
||||
return sendStreamData(c, info, data, forceFormat, thinkToContent)
|
||||
case relaycommon.RelayFormatClaude:
|
||||
return handleClaudeFormat(c, data, info)
|
||||
case relaycommon.RelayFormatGemini:
|
||||
return handleGeminiFormat(c, data, info)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -41,6 +47,36 @@ func handleClaudeFormat(c *gin.Context, data string, info *relaycommon.RelayInfo
|
||||
return nil
|
||||
}
|
||||
|
||||
func handleGeminiFormat(c *gin.Context, data string, info *relaycommon.RelayInfo) error {
|
||||
var streamResponse dto.ChatCompletionsStreamResponse
|
||||
if err := common.Unmarshal(common.StringToByteSlice(data), &streamResponse); err != nil {
|
||||
common.LogError(c, "failed to unmarshal stream response: "+err.Error())
|
||||
return err
|
||||
}
|
||||
|
||||
geminiResponse := service.StreamResponseOpenAI2Gemini(&streamResponse, info)
|
||||
|
||||
// 如果返回 nil,表示没有实际内容,跳过发送
|
||||
if geminiResponse == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
geminiResponseStr, err := common.Marshal(geminiResponse)
|
||||
if err != nil {
|
||||
common.LogError(c, "failed to marshal gemini response: "+err.Error())
|
||||
return err
|
||||
}
|
||||
|
||||
// send gemini format response
|
||||
c.Render(-1, common.CustomEvent{Data: "data: " + string(geminiResponseStr)})
|
||||
if flusher, ok := c.Writer.(http.Flusher); ok {
|
||||
flusher.Flush()
|
||||
} else {
|
||||
return errors.New("streaming error: flusher not found")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func ProcessStreamResponse(streamResponse dto.ChatCompletionsStreamResponse, responseTextBuilder *strings.Builder, toolCount *int) error {
|
||||
for _, choice := range streamResponse.Choices {
|
||||
responseTextBuilder.WriteString(choice.Delta.GetContentString())
|
||||
@@ -151,7 +187,9 @@ func handleLastResponse(lastStreamData string, responseId *string, createAt *int
|
||||
*containStreamUsage = true
|
||||
*usage = lastStreamResponse.Usage
|
||||
if !info.ShouldIncludeUsage {
|
||||
*shouldSendLastResp = false
|
||||
*shouldSendLastResp = lo.SomeBy(lastStreamResponse.Choices, func(choice dto.ChatCompletionsStreamResponseChoice) bool {
|
||||
return choice.Delta.GetContentString() != "" || choice.Delta.GetReasoningContent() != ""
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -185,6 +223,37 @@ func HandleFinalResponse(c *gin.Context, info *relaycommon.RelayInfo, lastStream
|
||||
for _, resp := range claudeResponses {
|
||||
_ = helper.ClaudeData(c, *resp)
|
||||
}
|
||||
|
||||
case relaycommon.RelayFormatGemini:
|
||||
var streamResponse dto.ChatCompletionsStreamResponse
|
||||
if err := common.Unmarshal(common.StringToByteSlice(lastStreamData), &streamResponse); err != nil {
|
||||
common.SysError("error unmarshalling stream response: " + err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// 这里处理的是 openai 最后一个流响应,其 delta 为空,有 finish_reason 字段
|
||||
// 因此相比较于 google 官方的流响应,由 openai 转换而来会多一个 parts 为空,finishReason 为 STOP 的响应
|
||||
// 而包含最后一段文本输出的响应(倒数第二个)的 finishReason 为 null
|
||||
// 暂不知是否有程序会不兼容。
|
||||
|
||||
geminiResponse := service.StreamResponseOpenAI2Gemini(&streamResponse, info)
|
||||
|
||||
// openai 流响应开头的空数据
|
||||
if geminiResponse == nil {
|
||||
return
|
||||
}
|
||||
|
||||
geminiResponseStr, err := common.Marshal(geminiResponse)
|
||||
if err != nil {
|
||||
common.SysError("error marshalling gemini response: " + err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// 发送最终的 Gemini 响应
|
||||
c.Render(-1, common.CustomEvent{Data: "data: " + string(geminiResponseStr)})
|
||||
if flusher, ok := c.Writer.(http.Flusher); ok {
|
||||
flusher.Flush()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -180,12 +180,15 @@ func OpenaiHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Respo
|
||||
if err != nil {
|
||||
return nil, types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError)
|
||||
}
|
||||
if common.DebugEnabled {
|
||||
println("upstream response body:", string(responseBody))
|
||||
}
|
||||
err = common.Unmarshal(responseBody, &simpleResponse)
|
||||
if err != nil {
|
||||
return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
|
||||
}
|
||||
if simpleResponse.Error != nil && simpleResponse.Error.Type != "" {
|
||||
return nil, types.WithOpenAIError(*simpleResponse.Error, resp.StatusCode)
|
||||
if oaiError := simpleResponse.GetOpenAIError(); oaiError != nil && oaiError.Type != "" {
|
||||
return nil, types.WithOpenAIError(*oaiError, resp.StatusCode)
|
||||
}
|
||||
|
||||
forceFormat := false
|
||||
@@ -223,6 +226,13 @@ func OpenaiHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Respo
|
||||
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
|
||||
}
|
||||
responseBody = claudeRespStr
|
||||
case relaycommon.RelayFormatGemini:
|
||||
geminiResp := service.ResponseOpenAI2Gemini(&simpleResponse, info)
|
||||
geminiRespStr, err := common.Marshal(geminiResp)
|
||||
if err != nil {
|
||||
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
|
||||
}
|
||||
responseBody = geminiRespStr
|
||||
}
|
||||
|
||||
common.IOCopyBytesGracefully(c, resp, responseBody)
|
||||
|
||||
@@ -28,8 +28,8 @@ func OaiResponsesHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http
|
||||
if err != nil {
|
||||
return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
|
||||
}
|
||||
if responsesResponse.Error != nil {
|
||||
return nil, types.WithOpenAIError(*responsesResponse.Error, resp.StatusCode)
|
||||
if oaiError := responsesResponse.GetOpenAIError(); oaiError != nil && oaiError.Type != "" {
|
||||
return nil, types.WithOpenAIError(*oaiError, resp.StatusCode)
|
||||
}
|
||||
|
||||
// 写入新的 response body
|
||||
@@ -37,9 +37,14 @@ func OaiResponsesHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http
|
||||
|
||||
// compute usage
|
||||
usage := dto.Usage{}
|
||||
usage.PromptTokens = responsesResponse.Usage.InputTokens
|
||||
usage.CompletionTokens = responsesResponse.Usage.OutputTokens
|
||||
usage.TotalTokens = responsesResponse.Usage.TotalTokens
|
||||
if responsesResponse.Usage != nil {
|
||||
usage.PromptTokens = responsesResponse.Usage.InputTokens
|
||||
usage.CompletionTokens = responsesResponse.Usage.OutputTokens
|
||||
usage.TotalTokens = responsesResponse.Usage.TotalTokens
|
||||
if responsesResponse.Usage.InputTokensDetails != nil {
|
||||
usage.PromptTokensDetails.CachedTokens = responsesResponse.Usage.InputTokensDetails.CachedTokens
|
||||
}
|
||||
}
|
||||
// 解析 Tools 用量
|
||||
for _, tool := range responsesResponse.Tools {
|
||||
info.ResponsesUsageInfo.BuiltInTools[common.Interface2String(tool["type"])].CallCount++
|
||||
@@ -64,9 +69,14 @@ func OaiResponsesStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp
|
||||
sendResponsesStreamData(c, streamResponse, data)
|
||||
switch streamResponse.Type {
|
||||
case "response.completed":
|
||||
usage.PromptTokens = streamResponse.Response.Usage.InputTokens
|
||||
usage.CompletionTokens = streamResponse.Response.Usage.OutputTokens
|
||||
usage.TotalTokens = streamResponse.Response.Usage.TotalTokens
|
||||
if streamResponse.Response.Usage != nil {
|
||||
usage.PromptTokens = streamResponse.Response.Usage.InputTokens
|
||||
usage.CompletionTokens = streamResponse.Response.Usage.OutputTokens
|
||||
usage.TotalTokens = streamResponse.Response.Usage.TotalTokens
|
||||
if streamResponse.Response.Usage.InputTokensDetails != nil {
|
||||
usage.PromptTokensDetails.CachedTokens = streamResponse.Response.Usage.InputTokensDetails.CachedTokens
|
||||
}
|
||||
}
|
||||
case "response.output_text.delta":
|
||||
// 处理输出文本
|
||||
responseTextBuilder.WriteString(streamResponse.Delta)
|
||||
|
||||
@@ -17,6 +17,11 @@ import (
|
||||
type Adaptor struct {
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
|
||||
//TODO implement me
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
|
||||
//TODO implement me
|
||||
panic("implement me")
|
||||
|
||||
@@ -18,30 +18,6 @@ import (
|
||||
// https://developers.generativeai.google/api/rest/generativelanguage/models/generateMessage#request-body
|
||||
// https://developers.generativeai.google/api/rest/generativelanguage/models/generateMessage#response-body
|
||||
|
||||
func requestOpenAI2PaLM(textRequest dto.GeneralOpenAIRequest) *PaLMChatRequest {
|
||||
palmRequest := PaLMChatRequest{
|
||||
Prompt: PaLMPrompt{
|
||||
Messages: make([]PaLMChatMessage, 0, len(textRequest.Messages)),
|
||||
},
|
||||
Temperature: textRequest.Temperature,
|
||||
CandidateCount: textRequest.N,
|
||||
TopP: textRequest.TopP,
|
||||
TopK: textRequest.MaxTokens,
|
||||
}
|
||||
for _, message := range textRequest.Messages {
|
||||
palmMessage := PaLMChatMessage{
|
||||
Content: message.StringContent(),
|
||||
}
|
||||
if message.Role == "user" {
|
||||
palmMessage.Author = "0"
|
||||
} else {
|
||||
palmMessage.Author = "1"
|
||||
}
|
||||
palmRequest.Prompt.Messages = append(palmRequest.Prompt.Messages, palmMessage)
|
||||
}
|
||||
return &palmRequest
|
||||
}
|
||||
|
||||
func responsePaLM2OpenAI(response *PaLMChatResponse) *dto.OpenAITextResponse {
|
||||
fullTextResponse := dto.OpenAITextResponse{
|
||||
Choices: make([]dto.OpenAITextResponseChoice, 0, len(response.Candidates)),
|
||||
|
||||
@@ -17,6 +17,11 @@ import (
|
||||
type Adaptor struct {
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
|
||||
//TODO implement me
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
|
||||
//TODO implement me
|
||||
panic("implement me")
|
||||
|
||||
@@ -16,6 +16,6 @@ func requestOpenAI2Perplexity(request dto.GeneralOpenAIRequest) *dto.GeneralOpen
|
||||
Messages: messages,
|
||||
Temperature: request.Temperature,
|
||||
TopP: request.TopP,
|
||||
MaxTokens: request.MaxTokens,
|
||||
MaxTokens: request.GetMaxTokens(),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -18,6 +18,11 @@ import (
|
||||
type Adaptor struct {
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
|
||||
//TODO implement me
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, req *dto.ClaudeRequest) (any, error) {
|
||||
adaptor := openai.Adaptor{}
|
||||
return adaptor.ConvertClaudeRequest(c, info, req)
|
||||
|
||||
@@ -25,6 +25,11 @@ type Adaptor struct {
|
||||
Timestamp int64
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
|
||||
//TODO implement me
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
|
||||
//TODO implement me
|
||||
panic("implement me")
|
||||
|
||||
@@ -35,6 +35,7 @@ var claudeModelMap = map[string]string{
|
||||
"claude-3-7-sonnet-20250219": "claude-3-7-sonnet@20250219",
|
||||
"claude-sonnet-4-20250514": "claude-sonnet-4@20250514",
|
||||
"claude-opus-4-20250514": "claude-opus-4@20250514",
|
||||
"claude-opus-4-1-20250805": "claude-opus-4-1@20250805",
|
||||
}
|
||||
|
||||
const anthropicVersion = "vertex-2023-10-16"
|
||||
@@ -44,6 +45,11 @@ type Adaptor struct {
|
||||
AccountCredentials Credentials
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertGeminiRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeminiChatRequest) (any, error) {
|
||||
geminiAdaptor := gemini.Adaptor{}
|
||||
return geminiAdaptor.ConvertGeminiRequest(c, info, request)
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.ClaudeRequest) (any, error) {
|
||||
if v, ok := claudeModelMap[info.UpstreamModelName]; ok {
|
||||
c.Set("request_model", v)
|
||||
@@ -69,8 +75,9 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
|
||||
a.RequestMode = RequestModeClaude
|
||||
} else if strings.Contains(info.UpstreamModelName, "llama") {
|
||||
a.RequestMode = RequestModeLlama
|
||||
} else {
|
||||
a.RequestMode = RequestModeGemini
|
||||
}
|
||||
a.RequestMode = RequestModeGemini
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
@@ -231,7 +238,7 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
|
||||
} else {
|
||||
switch a.RequestMode {
|
||||
case RequestModeClaude:
|
||||
err, usage = claude.ClaudeHandler(c, resp, claude.RequestModeMessage, info)
|
||||
err, usage = claude.ClaudeHandler(c, resp, info, claude.RequestModeMessage)
|
||||
case RequestModeGemini:
|
||||
if info.RelayMode == constant.RelayModeGemini {
|
||||
usage, err = gemini.GeminiTextGenerationHandler(c, info, resp)
|
||||
|
||||
@@ -36,7 +36,12 @@ var Cache = asynccache.NewAsyncCache(asynccache.Options{
|
||||
})
|
||||
|
||||
func getAccessToken(a *Adaptor, info *relaycommon.RelayInfo) (string, error) {
|
||||
cacheKey := fmt.Sprintf("access-token-%d", info.ChannelId)
|
||||
var cacheKey string
|
||||
if info.ChannelIsMultiKey {
|
||||
cacheKey = fmt.Sprintf("access-token-%d-%d", info.ChannelId, info.ChannelMultiKeyIndex)
|
||||
} else {
|
||||
cacheKey = fmt.Sprintf("access-token-%d", info.ChannelId)
|
||||
}
|
||||
val, err := Cache.Get(cacheKey)
|
||||
if err == nil {
|
||||
return val.(string), nil
|
||||
|
||||
@@ -23,10 +23,14 @@ import (
|
||||
type Adaptor struct {
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
|
||||
func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
|
||||
//TODO implement me
|
||||
panic("implement me")
|
||||
return nil, nil
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, req *dto.ClaudeRequest) (any, error) {
|
||||
adaptor := openai.Adaptor{}
|
||||
return adaptor.ConvertClaudeRequest(c, info, req)
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
|
||||
@@ -191,6 +195,10 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
return fmt.Sprintf("%s/api/v3/embeddings", info.BaseUrl), nil
|
||||
case constant.RelayModeImagesGenerations:
|
||||
return fmt.Sprintf("%s/api/v3/images/generations", info.BaseUrl), nil
|
||||
case constant.RelayModeImagesEdits:
|
||||
return fmt.Sprintf("%s/api/v3/images/edits", info.BaseUrl), nil
|
||||
case constant.RelayModeRerank:
|
||||
return fmt.Sprintf("%s/api/v3/rerank", info.BaseUrl), nil
|
||||
default:
|
||||
}
|
||||
return "", fmt.Errorf("unsupported relay mode: %d", info.RelayMode)
|
||||
@@ -227,18 +235,8 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
|
||||
switch info.RelayMode {
|
||||
case constant.RelayModeChatCompletions:
|
||||
if info.IsStream {
|
||||
usage, err = openai.OaiStreamHandler(c, info, resp)
|
||||
} else {
|
||||
usage, err = openai.OpenaiHandler(c, info, resp)
|
||||
}
|
||||
case constant.RelayModeEmbeddings:
|
||||
usage, err = openai.OpenaiHandler(c, info, resp)
|
||||
case constant.RelayModeImagesGenerations, constant.RelayModeImagesEdits:
|
||||
usage, err = openai.OpenaiHandlerWithUsage(c, info, resp)
|
||||
}
|
||||
adaptor := openai.Adaptor{}
|
||||
usage, err = adaptor.DoResponse(c, resp, info)
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
@@ -19,6 +19,11 @@ import (
|
||||
type Adaptor struct {
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
|
||||
//TODO implement me
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
|
||||
//TODO implement me
|
||||
//panic("implement me")
|
||||
|
||||
@@ -17,6 +17,11 @@ type Adaptor struct {
|
||||
request *dto.GeneralOpenAIRequest
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
|
||||
//TODO implement me
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
|
||||
//TODO implement me
|
||||
panic("implement me")
|
||||
|
||||
@@ -48,7 +48,7 @@ func requestOpenAI2Xunfei(request dto.GeneralOpenAIRequest, xunfeiAppId string,
|
||||
xunfeiRequest.Parameter.Chat.Domain = domain
|
||||
xunfeiRequest.Parameter.Chat.Temperature = request.Temperature
|
||||
xunfeiRequest.Parameter.Chat.TopK = request.N
|
||||
xunfeiRequest.Parameter.Chat.MaxTokens = request.MaxTokens
|
||||
xunfeiRequest.Parameter.Chat.MaxTokens = request.GetMaxTokens()
|
||||
xunfeiRequest.Payload.Message.Text = messages
|
||||
return &xunfeiRequest
|
||||
}
|
||||
|
||||
@@ -16,6 +16,11 @@ import (
|
||||
type Adaptor struct {
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
|
||||
//TODO implement me
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
|
||||
//TODO implement me
|
||||
panic("implement me")
|
||||
|
||||
@@ -18,6 +18,11 @@ import (
|
||||
type Adaptor struct {
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
|
||||
//TODO implement me
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
|
||||
//TODO implement me
|
||||
panic("implement me")
|
||||
@@ -49,8 +54,7 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
|
||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
|
||||
channel.SetupApiRequestHeader(info, c, req)
|
||||
token := getZhipuToken(info.ApiKey)
|
||||
req.Set("Authorization", token)
|
||||
req.Set("Authorization", "Bearer "+info.ApiKey)
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -1,69 +1,10 @@
|
||||
package zhipu_4v
|
||||
|
||||
import (
|
||||
"github.com/golang-jwt/jwt"
|
||||
"one-api/common"
|
||||
"one-api/dto"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// https://open.bigmodel.cn/doc/api#chatglm_std
|
||||
// chatglm_std, chatglm_lite
|
||||
// https://open.bigmodel.cn/api/paas/v3/model-api/chatglm_std/invoke
|
||||
// https://open.bigmodel.cn/api/paas/v3/model-api/chatglm_std/sse-invoke
|
||||
|
||||
var zhipuTokens sync.Map
|
||||
var expSeconds int64 = 24 * 3600
|
||||
|
||||
func getZhipuToken(apikey string) string {
|
||||
data, ok := zhipuTokens.Load(apikey)
|
||||
if ok {
|
||||
tokenData := data.(tokenData)
|
||||
if time.Now().Before(tokenData.ExpiryTime) {
|
||||
return tokenData.Token
|
||||
}
|
||||
}
|
||||
|
||||
split := strings.Split(apikey, ".")
|
||||
if len(split) != 2 {
|
||||
common.SysError("invalid zhipu key: " + apikey)
|
||||
return ""
|
||||
}
|
||||
|
||||
id := split[0]
|
||||
secret := split[1]
|
||||
|
||||
expMillis := time.Now().Add(time.Duration(expSeconds)*time.Second).UnixNano() / 1e6
|
||||
expiryTime := time.Now().Add(time.Duration(expSeconds) * time.Second)
|
||||
|
||||
timestamp := time.Now().UnixNano() / 1e6
|
||||
|
||||
payload := jwt.MapClaims{
|
||||
"api_key": id,
|
||||
"exp": expMillis,
|
||||
"timestamp": timestamp,
|
||||
}
|
||||
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, payload)
|
||||
|
||||
token.Header["alg"] = "HS256"
|
||||
token.Header["sign_type"] = "SIGN"
|
||||
|
||||
tokenString, err := token.SignedString([]byte(secret))
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
zhipuTokens.Store(apikey, tokenData{
|
||||
Token: tokenString,
|
||||
ExpiryTime: expiryTime,
|
||||
})
|
||||
|
||||
return tokenString
|
||||
}
|
||||
|
||||
func requestOpenAI2Zhipu(request dto.GeneralOpenAIRequest) *dto.GeneralOpenAIRequest {
|
||||
messages := make([]dto.Message, 0, len(request.Messages))
|
||||
for _, message := range request.Messages {
|
||||
@@ -105,7 +46,7 @@ func requestOpenAI2Zhipu(request dto.GeneralOpenAIRequest) *dto.GeneralOpenAIReq
|
||||
Messages: messages,
|
||||
Temperature: request.Temperature,
|
||||
TopP: request.TopP,
|
||||
MaxTokens: request.MaxTokens,
|
||||
MaxTokens: request.GetMaxTokens(),
|
||||
Stop: Stop,
|
||||
Tools: request.Tools,
|
||||
ToolChoice: request.ToolChoice,
|
||||
|
||||
@@ -40,7 +40,7 @@ func ClaudeHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
|
||||
// get & validate textRequest 获取并验证文本请求
|
||||
textRequest, err := getAndValidateClaudeRequest(c)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeInvalidRequest)
|
||||
return types.NewError(err, types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
|
||||
if textRequest.Stream {
|
||||
@@ -49,18 +49,18 @@ func ClaudeHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
|
||||
|
||||
err = helper.ModelMappedHelper(c, relayInfo, textRequest)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeChannelModelMappedError)
|
||||
return types.NewError(err, types.ErrorCodeChannelModelMappedError, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
|
||||
promptTokens, err := getClaudePromptTokens(textRequest, relayInfo)
|
||||
// count messages token error 计算promptTokens错误
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeCountTokenFailed)
|
||||
return types.NewError(err, types.ErrorCodeCountTokenFailed, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
|
||||
priceData, err := helper.ModelPriceHelper(c, relayInfo, promptTokens, int(textRequest.MaxTokens))
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeModelPriceError)
|
||||
return types.NewError(err, types.ErrorCodeModelPriceError, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
|
||||
// pre-consume quota 预消耗配额
|
||||
@@ -77,7 +77,7 @@ func ClaudeHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
|
||||
|
||||
adaptor := GetAdaptor(relayInfo.ApiType)
|
||||
if adaptor == nil {
|
||||
return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType)
|
||||
return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
adaptor.Init(relayInfo)
|
||||
|
||||
@@ -111,17 +111,17 @@ func ClaudeHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
|
||||
if model_setting.GetGlobalSettings().PassThroughRequestEnabled || relayInfo.ChannelSetting.PassThroughBodyEnabled {
|
||||
body, err := common.GetRequestBody(c)
|
||||
if err != nil {
|
||||
return types.NewErrorWithStatusCode(err, types.ErrorCodeReadRequestBodyFailed, http.StatusBadRequest)
|
||||
return types.NewErrorWithStatusCode(err, types.ErrorCodeReadRequestBodyFailed, http.StatusBadRequest, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
requestBody = bytes.NewBuffer(body)
|
||||
} else {
|
||||
convertedRequest, err := adaptor.ConvertClaudeRequest(c, relayInfo, textRequest)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeConvertRequestFailed)
|
||||
return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
jsonData, err := common.Marshal(convertedRequest)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeConvertRequestFailed)
|
||||
return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
|
||||
// apply param override
|
||||
@@ -133,7 +133,7 @@ func ClaudeHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
|
||||
}
|
||||
jsonData, err = common.Marshal(reqMap)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid)
|
||||
return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -60,25 +60,28 @@ type ResponsesUsageInfo struct {
|
||||
}
|
||||
|
||||
type RelayInfo struct {
|
||||
ChannelType int
|
||||
ChannelId int
|
||||
TokenId int
|
||||
TokenKey string
|
||||
UserId int
|
||||
UsingGroup string // 使用的分组
|
||||
UserGroup string // 用户所在分组
|
||||
TokenUnlimited bool
|
||||
StartTime time.Time
|
||||
FirstResponseTime time.Time
|
||||
isFirstResponse bool
|
||||
ChannelType int
|
||||
ChannelId int
|
||||
ChannelIsMultiKey bool // 是否多密钥
|
||||
ChannelMultiKeyIndex int // 多密钥索引
|
||||
TokenId int
|
||||
TokenKey string
|
||||
UserId int
|
||||
UsingGroup string // 使用的分组
|
||||
UserGroup string // 用户所在分组
|
||||
TokenUnlimited bool
|
||||
StartTime time.Time
|
||||
FirstResponseTime time.Time
|
||||
isFirstResponse bool
|
||||
//SendLastReasoningResponse bool
|
||||
ApiType int
|
||||
IsStream bool
|
||||
IsPlayground bool
|
||||
UsePrice bool
|
||||
RelayMode int
|
||||
UpstreamModelName string
|
||||
OriginModelName string
|
||||
ApiType int
|
||||
IsStream bool
|
||||
IsGeminiBatchEmbedding bool
|
||||
IsPlayground bool
|
||||
UsePrice bool
|
||||
RelayMode int
|
||||
UpstreamModelName string
|
||||
OriginModelName string
|
||||
//RecodeModelName string
|
||||
RequestURLPath string
|
||||
ApiVersion string
|
||||
@@ -88,6 +91,7 @@ type RelayInfo struct {
|
||||
BaseUrl string
|
||||
SupportStreamOptions bool
|
||||
ShouldIncludeUsage bool
|
||||
DisablePing bool // 是否禁止向下游发送自定义 Ping
|
||||
IsModelMapped bool
|
||||
ClientWs *websocket.Conn
|
||||
TargetWs *websocket.Conn
|
||||
@@ -222,6 +226,9 @@ func GenRelayInfo(c *gin.Context) *RelayInfo {
|
||||
userId := common.GetContextKeyInt(c, constant.ContextKeyUserId)
|
||||
tokenUnlimited := common.GetContextKeyBool(c, constant.ContextKeyTokenUnlimited)
|
||||
startTime := common.GetContextKeyTime(c, constant.ContextKeyRequestStartTime)
|
||||
if startTime.IsZero() {
|
||||
startTime = time.Now()
|
||||
}
|
||||
// firstResponseTime = time.Now() - 1 second
|
||||
|
||||
apiType, _ := common.ChannelType2APIType(channelType)
|
||||
@@ -259,6 +266,9 @@ func GenRelayInfo(c *gin.Context) *RelayInfo {
|
||||
IsFirstThinkingContent: true,
|
||||
SendLastThinkingContent: false,
|
||||
},
|
||||
|
||||
ChannelIsMultiKey: common.GetContextKeyBool(c, constant.ContextKeyChannelIsMultiKey),
|
||||
ChannelMultiKeyIndex: common.GetContextKeyInt(c, constant.ContextKeyChannelMultiKeyIndex),
|
||||
}
|
||||
if strings.HasPrefix(c.Request.URL.Path, "/pg") {
|
||||
info.IsPlayground = true
|
||||
|
||||
@@ -41,17 +41,17 @@ func EmbeddingHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
|
||||
err := common.UnmarshalBodyReusable(c, &embeddingRequest)
|
||||
if err != nil {
|
||||
common.LogError(c, fmt.Sprintf("getAndValidateTextRequest failed: %s", err.Error()))
|
||||
return types.NewError(err, types.ErrorCodeInvalidRequest)
|
||||
return types.NewError(err, types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
|
||||
err = validateEmbeddingRequest(c, relayInfo, *embeddingRequest)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeInvalidRequest)
|
||||
return types.NewError(err, types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
|
||||
err = helper.ModelMappedHelper(c, relayInfo, embeddingRequest)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeChannelModelMappedError)
|
||||
return types.NewError(err, types.ErrorCodeChannelModelMappedError, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
|
||||
promptToken := getEmbeddingPromptToken(*embeddingRequest)
|
||||
@@ -59,7 +59,7 @@ func EmbeddingHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
|
||||
|
||||
priceData, err := helper.ModelPriceHelper(c, relayInfo, promptToken, 0)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeModelPriceError)
|
||||
return types.NewError(err, types.ErrorCodeModelPriceError, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
// pre-consume quota 预消耗配额
|
||||
preConsumedQuota, userQuota, newAPIError := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo)
|
||||
@@ -74,18 +74,17 @@ func EmbeddingHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
|
||||
|
||||
adaptor := GetAdaptor(relayInfo.ApiType)
|
||||
if adaptor == nil {
|
||||
return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType)
|
||||
return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
adaptor.Init(relayInfo)
|
||||
|
||||
convertedRequest, err := adaptor.ConvertEmbeddingRequest(c, relayInfo, *embeddingRequest)
|
||||
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeConvertRequestFailed)
|
||||
return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
jsonData, err := json.Marshal(convertedRequest)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeConvertRequestFailed)
|
||||
return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
requestBody := bytes.NewBuffer(jsonData)
|
||||
statusCodeMappingStr := c.GetString("status_code_mapping")
|
||||
|
||||
@@ -20,8 +20,8 @@ import (
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func getAndValidateGeminiRequest(c *gin.Context) (*gemini.GeminiChatRequest, error) {
|
||||
request := &gemini.GeminiChatRequest{}
|
||||
func getAndValidateGeminiRequest(c *gin.Context) (*dto.GeminiChatRequest, error) {
|
||||
request := &dto.GeminiChatRequest{}
|
||||
err := common.UnmarshalBodyReusable(c, request)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -44,7 +44,7 @@ func checkGeminiStreamMode(c *gin.Context, relayInfo *relaycommon.RelayInfo) {
|
||||
// }
|
||||
}
|
||||
|
||||
func checkGeminiInputSensitive(textRequest *gemini.GeminiChatRequest) ([]string, error) {
|
||||
func checkGeminiInputSensitive(textRequest *dto.GeminiChatRequest) ([]string, error) {
|
||||
var inputTexts []string
|
||||
for _, content := range textRequest.Contents {
|
||||
for _, part := range content.Parts {
|
||||
@@ -61,7 +61,7 @@ func checkGeminiInputSensitive(textRequest *gemini.GeminiChatRequest) ([]string,
|
||||
return sensitiveWords, err
|
||||
}
|
||||
|
||||
func getGeminiInputTokens(req *gemini.GeminiChatRequest, info *relaycommon.RelayInfo) int {
|
||||
func getGeminiInputTokens(req *dto.GeminiChatRequest, info *relaycommon.RelayInfo) int {
|
||||
// 计算输入 token 数量
|
||||
var inputTexts []string
|
||||
for _, content := range req.Contents {
|
||||
@@ -78,9 +78,13 @@ func getGeminiInputTokens(req *gemini.GeminiChatRequest, info *relaycommon.Relay
|
||||
return inputTokens
|
||||
}
|
||||
|
||||
func isNoThinkingRequest(req *gemini.GeminiChatRequest) bool {
|
||||
func isNoThinkingRequest(req *dto.GeminiChatRequest) bool {
|
||||
if req.GenerationConfig.ThinkingConfig != nil && req.GenerationConfig.ThinkingConfig.ThinkingBudget != nil {
|
||||
return *req.GenerationConfig.ThinkingConfig.ThinkingBudget <= 0
|
||||
configBudget := req.GenerationConfig.ThinkingConfig.ThinkingBudget
|
||||
if configBudget != nil && *configBudget == 0 {
|
||||
// 如果思考预算为 0,则认为是非思考请求
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
@@ -109,7 +113,7 @@ func GeminiHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
|
||||
req, err := getAndValidateGeminiRequest(c)
|
||||
if err != nil {
|
||||
common.LogError(c, fmt.Sprintf("getAndValidateGeminiRequest error: %s", err.Error()))
|
||||
return types.NewError(err, types.ErrorCodeInvalidRequest)
|
||||
return types.NewError(err, types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
|
||||
relayInfo := relaycommon.GenRelayInfoGemini(c)
|
||||
@@ -121,14 +125,14 @@ func GeminiHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
|
||||
sensitiveWords, err := checkGeminiInputSensitive(req)
|
||||
if err != nil {
|
||||
common.LogWarn(c, fmt.Sprintf("user sensitive words detected: %s", strings.Join(sensitiveWords, ", ")))
|
||||
return types.NewError(err, types.ErrorCodeSensitiveWordsDetected)
|
||||
return types.NewError(err, types.ErrorCodeSensitiveWordsDetected, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
}
|
||||
|
||||
// model mapped 模型映射
|
||||
err = helper.ModelMappedHelper(c, relayInfo, req)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeChannelModelMappedError)
|
||||
return types.NewError(err, types.ErrorCodeChannelModelMappedError, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
|
||||
if value, exists := c.Get("prompt_tokens"); exists {
|
||||
@@ -159,7 +163,7 @@ func GeminiHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
|
||||
|
||||
priceData, err := helper.ModelPriceHelper(c, relayInfo, relayInfo.PromptTokens, int(req.GenerationConfig.MaxOutputTokens))
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeModelPriceError)
|
||||
return types.NewError(err, types.ErrorCodeModelPriceError, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
|
||||
// pre consume quota
|
||||
@@ -175,7 +179,7 @@ func GeminiHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
|
||||
|
||||
adaptor := GetAdaptor(relayInfo.ApiType)
|
||||
if adaptor == nil {
|
||||
return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType)
|
||||
return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
|
||||
adaptor.Init(relayInfo)
|
||||
@@ -198,13 +202,18 @@ func GeminiHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
|
||||
if model_setting.GetGlobalSettings().PassThroughRequestEnabled || relayInfo.ChannelSetting.PassThroughBodyEnabled {
|
||||
body, err := common.GetRequestBody(c)
|
||||
if err != nil {
|
||||
return types.NewErrorWithStatusCode(err, types.ErrorCodeReadRequestBodyFailed, http.StatusBadRequest)
|
||||
return types.NewErrorWithStatusCode(err, types.ErrorCodeReadRequestBodyFailed, http.StatusBadRequest, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
requestBody = bytes.NewReader(body)
|
||||
} else {
|
||||
jsonData, err := common.Marshal(req)
|
||||
// 使用 ConvertGeminiRequest 转换请求格式
|
||||
convertedRequest, err := adaptor.ConvertGeminiRequest(c, relayInfo, req)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeConvertRequestFailed)
|
||||
return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
jsonData, err := common.Marshal(convertedRequest)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
|
||||
// apply param override
|
||||
@@ -216,7 +225,7 @@ func GeminiHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
|
||||
}
|
||||
jsonData, err = common.Marshal(reqMap)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid)
|
||||
return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -255,3 +264,118 @@ func GeminiHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
|
||||
postConsumeQuota(c, relayInfo, usage.(*dto.Usage), preConsumedQuota, userQuota, priceData, "")
|
||||
return nil
|
||||
}
|
||||
|
||||
func GeminiEmbeddingHandler(c *gin.Context) (newAPIError *types.NewAPIError) {
|
||||
relayInfo := relaycommon.GenRelayInfoGemini(c)
|
||||
|
||||
isBatch := strings.HasSuffix(c.Request.URL.Path, "batchEmbedContents")
|
||||
relayInfo.IsGeminiBatchEmbedding = isBatch
|
||||
|
||||
var promptTokens int
|
||||
var req any
|
||||
var err error
|
||||
var inputTexts []string
|
||||
|
||||
if isBatch {
|
||||
batchRequest := &dto.GeminiBatchEmbeddingRequest{}
|
||||
err = common.UnmarshalBodyReusable(c, batchRequest)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
req = batchRequest
|
||||
for _, r := range batchRequest.Requests {
|
||||
for _, part := range r.Content.Parts {
|
||||
if part.Text != "" {
|
||||
inputTexts = append(inputTexts, part.Text)
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
singleRequest := &dto.GeminiEmbeddingRequest{}
|
||||
err = common.UnmarshalBodyReusable(c, singleRequest)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
req = singleRequest
|
||||
for _, part := range singleRequest.Content.Parts {
|
||||
if part.Text != "" {
|
||||
inputTexts = append(inputTexts, part.Text)
|
||||
}
|
||||
}
|
||||
}
|
||||
promptTokens = service.CountTokenInput(strings.Join(inputTexts, "\n"), relayInfo.UpstreamModelName)
|
||||
relayInfo.SetPromptTokens(promptTokens)
|
||||
c.Set("prompt_tokens", promptTokens)
|
||||
|
||||
err = helper.ModelMappedHelper(c, relayInfo, req)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeChannelModelMappedError, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
|
||||
priceData, err := helper.ModelPriceHelper(c, relayInfo, relayInfo.PromptTokens, 0)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeModelPriceError, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
|
||||
preConsumedQuota, userQuota, newAPIError := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo)
|
||||
if newAPIError != nil {
|
||||
return newAPIError
|
||||
}
|
||||
defer func() {
|
||||
if newAPIError != nil {
|
||||
returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota)
|
||||
}
|
||||
}()
|
||||
|
||||
adaptor := GetAdaptor(relayInfo.ApiType)
|
||||
if adaptor == nil {
|
||||
return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
adaptor.Init(relayInfo)
|
||||
|
||||
var requestBody io.Reader
|
||||
jsonData, err := common.Marshal(req)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
|
||||
// apply param override
|
||||
if len(relayInfo.ParamOverride) > 0 {
|
||||
reqMap := make(map[string]interface{})
|
||||
_ = common.Unmarshal(jsonData, &reqMap)
|
||||
for key, value := range relayInfo.ParamOverride {
|
||||
reqMap[key] = value
|
||||
}
|
||||
jsonData, err = common.Marshal(reqMap)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
}
|
||||
requestBody = bytes.NewReader(jsonData)
|
||||
|
||||
resp, err := adaptor.DoRequest(c, relayInfo, requestBody)
|
||||
if err != nil {
|
||||
common.LogError(c, "Do gemini request failed: "+err.Error())
|
||||
return types.NewOpenAIError(err, types.ErrorCodeDoRequestFailed, http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
statusCodeMappingStr := c.GetString("status_code_mapping")
|
||||
var httpResp *http.Response
|
||||
if resp != nil {
|
||||
httpResp = resp.(*http.Response)
|
||||
if httpResp.StatusCode != http.StatusOK {
|
||||
newAPIError = service.RelayErrorHandler(httpResp, false)
|
||||
service.ResetStatusCode(newAPIError, statusCodeMappingStr)
|
||||
return newAPIError
|
||||
}
|
||||
}
|
||||
|
||||
usage, openaiErr := adaptor.DoResponse(c, resp.(*http.Response), relayInfo)
|
||||
if openaiErr != nil {
|
||||
service.ResetStatusCode(openaiErr, statusCodeMappingStr)
|
||||
return openaiErr
|
||||
}
|
||||
|
||||
postConsumeQuota(c, relayInfo, usage.(*dto.Usage), preConsumedQuota, userQuota, priceData, "")
|
||||
return nil
|
||||
}
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user