mirror of
https://github.com/QuantumNous/new-api.git
synced 2026-04-17 21:37:26 +00:00
Compare commits
90 Commits
v0.8.8.0-a
...
v0.8.8.3
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
9c079d04a8 | ||
|
|
c9d4cdc57e | ||
|
|
12b4e80d4b | ||
|
|
6e2a04f374 | ||
|
|
8357b15fec | ||
|
|
ecdd9d1ccb | ||
|
|
10b04416c1 | ||
|
|
398ae7156b | ||
|
|
d85eeabf11 | ||
|
|
c056a7ad7c | ||
|
|
c784a70277 | ||
|
|
e6c87907d5 | ||
|
|
71e9290142 | ||
|
|
74ec34da67 | ||
|
|
7188749cb3 | ||
|
|
c28add55db | ||
|
|
78f34a8245 | ||
|
|
97d6f10f15 | ||
|
|
afefc4caca | ||
|
|
6abbd036f8 | ||
|
|
ef0db0f914 | ||
|
|
e01986fdd4 | ||
|
|
a0c6ebe2d8 | ||
|
|
d2183af23f | ||
|
|
953f1bdc3c | ||
|
|
e2429f20f8 | ||
|
|
f0945da4fb | ||
|
|
8df3de9ae5 | ||
|
|
277cc1cac8 | ||
|
|
07a92293e4 | ||
|
|
f995e31d04 | ||
|
|
9758a9e60d | ||
|
|
6f56696af2 | ||
|
|
345fbdf3d2 | ||
|
|
ce031f7d15 | ||
|
|
bd6b811183 | ||
|
|
196bafff03 | ||
|
|
f20b558e22 | ||
|
|
54447bf227 | ||
|
|
fc09051d8b | ||
|
|
1f5ef24ecd | ||
|
|
b1faf42529 | ||
|
|
6a85206e32 | ||
|
|
e3d3e697d3 | ||
|
|
db9b333930 | ||
|
|
f7b284ad73 | ||
|
|
e1970e8a66 | ||
|
|
0cd93d67ff | ||
|
|
6e806e21bd | ||
|
|
a8462c1b70 | ||
|
|
706ea8b649 | ||
|
|
95d46d1dfc | ||
|
|
010f27678d | ||
|
|
d87117a2cf | ||
|
|
4ed92a94a1 | ||
|
|
821ea34a3c | ||
|
|
ecb3d01376 | ||
|
|
e322ed4f05 | ||
|
|
bcf7e78665 | ||
|
|
0cb2bb2ea7 | ||
|
|
c5d97597c4 | ||
|
|
fe9acb6c59 | ||
|
|
bca78beb1b | ||
|
|
a8a42cbfa8 | ||
|
|
19df2ac234 | ||
|
|
e7524c85c2 | ||
|
|
a4356727e9 | ||
|
|
f15a53fae4 | ||
|
|
8e3cf2eaab | ||
|
|
c51ec3135b | ||
|
|
2469c439b1 | ||
|
|
1297addfb1 | ||
|
|
d6cbf43373 | ||
|
|
df647e7b42 | ||
|
|
fe16d05fbb | ||
|
|
1430c05b6c | ||
|
|
b25841e50d | ||
|
|
b704fc9254 | ||
|
|
352da66bd1 | ||
|
|
8205ad2cd0 | ||
|
|
e162b9c169 | ||
|
|
77e3502028 | ||
|
|
ae0461692c | ||
|
|
13bdb80958 | ||
|
|
6f74e7b738 | ||
|
|
eaee89f77a | ||
|
|
6103888610 | ||
|
|
7af3fb5ae4 | ||
|
|
3ac54b2178 | ||
|
|
5621755655 |
@@ -9,6 +9,7 @@ import (
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
)
|
||||
|
||||
type stringWriter interface {
|
||||
@@ -52,6 +53,8 @@ type CustomEvent struct {
|
||||
Id string
|
||||
Retry uint
|
||||
Data interface{}
|
||||
|
||||
Mutex sync.Mutex
|
||||
}
|
||||
|
||||
func encode(writer io.Writer, event CustomEvent) error {
|
||||
@@ -73,6 +76,8 @@ func (r CustomEvent) Render(w http.ResponseWriter) error {
|
||||
}
|
||||
|
||||
func (r CustomEvent) WriteContentType(w http.ResponseWriter) {
|
||||
r.Mutex.Lock()
|
||||
defer r.Mutex.Unlock()
|
||||
header := w.Header()
|
||||
header["Content-Type"] = contentType
|
||||
|
||||
|
||||
@@ -4,7 +4,10 @@ import (
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"math/rand"
|
||||
"net/url"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
@@ -95,3 +98,95 @@ func GetJsonString(data any) string {
|
||||
b, _ := json.Marshal(data)
|
||||
return string(b)
|
||||
}
|
||||
|
||||
// MaskSensitiveInfo masks sensitive information like URLs, IPs in a string
|
||||
// Example:
|
||||
// http://example.com -> http://***.com
|
||||
// https://api.test.org/v1/users/123?key=secret -> https://***.org/***/***/?key=***
|
||||
// https://sub.domain.co.uk/path/to/resource -> https://***.co.uk/***/***
|
||||
// 192.168.1.1 -> ***.***.***.***
|
||||
func MaskSensitiveInfo(str string) string {
|
||||
// Mask URLs
|
||||
urlPattern := regexp.MustCompile(`(http|https)://[^\s/$.?#].[^\s]*`)
|
||||
str = urlPattern.ReplaceAllStringFunc(str, func(urlStr string) string {
|
||||
u, err := url.Parse(urlStr)
|
||||
if err != nil {
|
||||
return urlStr
|
||||
}
|
||||
|
||||
host := u.Host
|
||||
if host == "" {
|
||||
return urlStr
|
||||
}
|
||||
|
||||
// Split host by dots
|
||||
parts := strings.Split(host, ".")
|
||||
if len(parts) < 2 {
|
||||
// If less than 2 parts, just mask the whole host
|
||||
return u.Scheme + "://***" + u.Path
|
||||
}
|
||||
|
||||
// Keep the TLD (Top Level Domain) and mask the rest
|
||||
var maskedHost string
|
||||
if len(parts) == 2 {
|
||||
// example.com -> ***.com
|
||||
maskedHost = "***." + parts[len(parts)-1]
|
||||
} else {
|
||||
// Handle cases like sub.domain.co.uk or api.example.com
|
||||
// Keep last 2 parts if they look like country code TLD (co.uk, com.cn, etc.)
|
||||
lastPart := parts[len(parts)-1]
|
||||
secondLastPart := parts[len(parts)-2]
|
||||
|
||||
if len(lastPart) == 2 && len(secondLastPart) <= 3 {
|
||||
// Likely country code TLD like co.uk, com.cn
|
||||
maskedHost = "***." + secondLastPart + "." + lastPart
|
||||
} else {
|
||||
// Regular TLD like .com, .org
|
||||
maskedHost = "***." + lastPart
|
||||
}
|
||||
}
|
||||
|
||||
result := u.Scheme + "://" + maskedHost
|
||||
|
||||
// Mask path
|
||||
if u.Path != "" && u.Path != "/" {
|
||||
pathParts := strings.Split(strings.Trim(u.Path, "/"), "/")
|
||||
maskedPathParts := make([]string, len(pathParts))
|
||||
for i := range pathParts {
|
||||
if pathParts[i] != "" {
|
||||
maskedPathParts[i] = "***"
|
||||
}
|
||||
}
|
||||
if len(maskedPathParts) > 0 {
|
||||
result += "/" + strings.Join(maskedPathParts, "/")
|
||||
}
|
||||
} else if u.Path == "/" {
|
||||
result += "/"
|
||||
}
|
||||
|
||||
// Mask query parameters
|
||||
if u.RawQuery != "" {
|
||||
values, err := url.ParseQuery(u.RawQuery)
|
||||
if err != nil {
|
||||
// If can't parse query, just mask the whole query string
|
||||
result += "?***"
|
||||
} else {
|
||||
maskedParams := make([]string, 0, len(values))
|
||||
for key := range values {
|
||||
maskedParams = append(maskedParams, key+"=***")
|
||||
}
|
||||
if len(maskedParams) > 0 {
|
||||
result += "?" + strings.Join(maskedParams, "&")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
})
|
||||
|
||||
// Mask IP addresses
|
||||
ipPattern := regexp.MustCompile(`\b(?:\d{1,3}\.){3}\d{1,3}\b`)
|
||||
str = ipPattern.ReplaceAllString(str, "***.***.***.***")
|
||||
|
||||
return str
|
||||
}
|
||||
|
||||
150
common/totp.go
Normal file
150
common/totp.go
Normal file
@@ -0,0 +1,150 @@
|
||||
package common
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"fmt"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/pquerna/otp"
|
||||
"github.com/pquerna/otp/totp"
|
||||
)
|
||||
|
||||
const (
|
||||
// 备用码配置
|
||||
BackupCodeLength = 8 // 备用码长度
|
||||
BackupCodeCount = 4 // 生成备用码数量
|
||||
|
||||
// 限制配置
|
||||
MaxFailAttempts = 5 // 最大失败尝试次数
|
||||
LockoutDuration = 300 // 锁定时间(秒)
|
||||
)
|
||||
|
||||
// GenerateTOTPSecret 生成TOTP密钥和配置
|
||||
func GenerateTOTPSecret(accountName string) (*otp.Key, error) {
|
||||
issuer := Get2FAIssuer()
|
||||
return totp.Generate(totp.GenerateOpts{
|
||||
Issuer: issuer,
|
||||
AccountName: accountName,
|
||||
Period: 30,
|
||||
Digits: otp.DigitsSix,
|
||||
Algorithm: otp.AlgorithmSHA1,
|
||||
})
|
||||
}
|
||||
|
||||
// ValidateTOTPCode 验证TOTP验证码
|
||||
func ValidateTOTPCode(secret, code string) bool {
|
||||
// 清理验证码格式
|
||||
cleanCode := strings.ReplaceAll(code, " ", "")
|
||||
if len(cleanCode) != 6 {
|
||||
return false
|
||||
}
|
||||
|
||||
// 验证验证码
|
||||
return totp.Validate(cleanCode, secret)
|
||||
}
|
||||
|
||||
// GenerateBackupCodes 生成备用恢复码
|
||||
func GenerateBackupCodes() ([]string, error) {
|
||||
codes := make([]string, BackupCodeCount)
|
||||
|
||||
for i := 0; i < BackupCodeCount; i++ {
|
||||
code, err := generateRandomBackupCode()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
codes[i] = code
|
||||
}
|
||||
|
||||
return codes, nil
|
||||
}
|
||||
|
||||
// generateRandomBackupCode 生成单个备用码
|
||||
func generateRandomBackupCode() (string, error) {
|
||||
const charset = "ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
|
||||
code := make([]byte, BackupCodeLength)
|
||||
|
||||
for i := range code {
|
||||
randomBytes := make([]byte, 1)
|
||||
_, err := rand.Read(randomBytes)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
code[i] = charset[int(randomBytes[0])%len(charset)]
|
||||
}
|
||||
|
||||
// 格式化为 XXXX-XXXX 格式
|
||||
return fmt.Sprintf("%s-%s", string(code[:4]), string(code[4:])), nil
|
||||
}
|
||||
|
||||
// ValidateBackupCode 验证备用码格式
|
||||
func ValidateBackupCode(code string) bool {
|
||||
// 移除所有分隔符并转为大写
|
||||
cleanCode := strings.ToUpper(strings.ReplaceAll(code, "-", ""))
|
||||
if len(cleanCode) != BackupCodeLength {
|
||||
return false
|
||||
}
|
||||
|
||||
// 检查字符是否合法
|
||||
for _, char := range cleanCode {
|
||||
if !((char >= 'A' && char <= 'Z') || (char >= '0' && char <= '9')) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// NormalizeBackupCode 标准化备用码格式
|
||||
func NormalizeBackupCode(code string) string {
|
||||
cleanCode := strings.ToUpper(strings.ReplaceAll(code, "-", ""))
|
||||
if len(cleanCode) == BackupCodeLength {
|
||||
return fmt.Sprintf("%s-%s", cleanCode[:4], cleanCode[4:])
|
||||
}
|
||||
return code
|
||||
}
|
||||
|
||||
// HashBackupCode 对备用码进行哈希
|
||||
func HashBackupCode(code string) (string, error) {
|
||||
normalizedCode := NormalizeBackupCode(code)
|
||||
return Password2Hash(normalizedCode)
|
||||
}
|
||||
|
||||
// Get2FAIssuer 获取2FA发行者名称
|
||||
func Get2FAIssuer() string {
|
||||
return SystemName
|
||||
}
|
||||
|
||||
// getEnvOrDefault 获取环境变量或默认值
|
||||
func getEnvOrDefault(key, defaultValue string) string {
|
||||
if value, exists := os.LookupEnv(key); exists {
|
||||
return value
|
||||
}
|
||||
return defaultValue
|
||||
}
|
||||
|
||||
// ValidateNumericCode 验证数字验证码格式
|
||||
func ValidateNumericCode(code string) (string, error) {
|
||||
// 移除空格
|
||||
code = strings.ReplaceAll(code, " ", "")
|
||||
|
||||
if len(code) != 6 {
|
||||
return "", fmt.Errorf("验证码必须是6位数字")
|
||||
}
|
||||
|
||||
// 检查是否为纯数字
|
||||
if _, err := strconv.Atoi(code); err != nil {
|
||||
return "", fmt.Errorf("验证码只能包含数字")
|
||||
}
|
||||
|
||||
return code, nil
|
||||
}
|
||||
|
||||
// GenerateQRCodeData 生成二维码数据
|
||||
func GenerateQRCodeData(secret, username string) string {
|
||||
issuer := Get2FAIssuer()
|
||||
accountName := fmt.Sprintf("%s (%s)", username, issuer)
|
||||
return fmt.Sprintf("otpauth://totp/%s:%s?secret=%s&issuer=%s&digits=6&period=30",
|
||||
issuer, accountName, secret, issuer)
|
||||
}
|
||||
@@ -49,6 +49,7 @@ const (
|
||||
ChannelTypeCoze = 49
|
||||
ChannelTypeKling = 50
|
||||
ChannelTypeJimeng = 51
|
||||
ChannelTypeVidu = 52
|
||||
ChannelTypeDummy // this one is only for count, do not add any channel after this
|
||||
|
||||
)
|
||||
@@ -106,4 +107,5 @@ var ChannelBaseURLs = []string{
|
||||
"https://api.coze.cn", //49
|
||||
"https://api.klingai.com", //50
|
||||
"https://visual.volcengineapi.com", //51
|
||||
"https://api.vidu.cn", //52
|
||||
}
|
||||
|
||||
@@ -69,6 +69,12 @@ func testChannel(channel *model.Channel, testModel string) testResult {
|
||||
newAPIError: nil,
|
||||
}
|
||||
}
|
||||
if channel.Type == constant.ChannelTypeVidu {
|
||||
return testResult{
|
||||
localErr: errors.New("vidu channel test is not supported"),
|
||||
newAPIError: nil,
|
||||
}
|
||||
}
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
|
||||
@@ -203,7 +209,7 @@ func testChannel(channel *model.Channel, testModel string) testResult {
|
||||
return testResult{
|
||||
context: c,
|
||||
localErr: err,
|
||||
newAPIError: types.NewError(err, types.ErrorCodeDoRequestFailed),
|
||||
newAPIError: types.NewOpenAIError(err, types.ErrorCodeDoRequestFailed, http.StatusInternalServerError),
|
||||
}
|
||||
}
|
||||
var httpResp *http.Response
|
||||
@@ -214,7 +220,7 @@ func testChannel(channel *model.Channel, testModel string) testResult {
|
||||
return testResult{
|
||||
context: c,
|
||||
localErr: err,
|
||||
newAPIError: types.NewError(err, types.ErrorCodeBadResponse),
|
||||
newAPIError: types.NewOpenAIError(err, types.ErrorCodeBadResponse, http.StatusInternalServerError),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -230,7 +236,7 @@ func testChannel(channel *model.Channel, testModel string) testResult {
|
||||
return testResult{
|
||||
context: c,
|
||||
localErr: errors.New("usage is nil"),
|
||||
newAPIError: types.NewError(errors.New("usage is nil"), types.ErrorCodeBadResponseBody),
|
||||
newAPIError: types.NewOpenAIError(errors.New("usage is nil"), types.ErrorCodeBadResponseBody, http.StatusInternalServerError),
|
||||
}
|
||||
}
|
||||
usage := usageA.(*dto.Usage)
|
||||
@@ -240,7 +246,7 @@ func testChannel(channel *model.Channel, testModel string) testResult {
|
||||
return testResult{
|
||||
context: c,
|
||||
localErr: err,
|
||||
newAPIError: types.NewError(err, types.ErrorCodeReadResponseBodyFailed),
|
||||
newAPIError: types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError),
|
||||
}
|
||||
}
|
||||
info.PromptTokens = usage.PromptTokens
|
||||
@@ -326,8 +332,11 @@ func TestChannel(c *gin.Context) {
|
||||
}
|
||||
channel, err := model.CacheGetChannel(channelId)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
channel, err = model.GetChannelById(channelId, true)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
}
|
||||
//defer func() {
|
||||
// if channel.ChannelInfo.IsMultiKey {
|
||||
@@ -411,7 +420,7 @@ func testAllChannels(notify bool) error {
|
||||
if common.AutomaticDisableChannelEnabled && !shouldBanChannel {
|
||||
if milliseconds > disableThreshold {
|
||||
err := errors.New(fmt.Sprintf("响应时间 %.2fs 超过阈值 %.2fs", float64(milliseconds)/1000.0, float64(disableThreshold)/1000.0))
|
||||
newAPIError = types.NewError(err, types.ErrorCodeChannelResponseTimeExceeded)
|
||||
newAPIError = types.NewOpenAIError(err, types.ErrorCodeChannelResponseTimeExceeded, http.StatusRequestTimeout)
|
||||
shouldBanChannel = true
|
||||
}
|
||||
}
|
||||
|
||||
@@ -36,11 +36,30 @@ type OpenAIModel struct {
|
||||
Parent string `json:"parent"`
|
||||
}
|
||||
|
||||
type GoogleOpenAICompatibleModels []struct {
|
||||
Name string `json:"name"`
|
||||
Version string `json:"version"`
|
||||
DisplayName string `json:"displayName"`
|
||||
Description string `json:"description,omitempty"`
|
||||
InputTokenLimit int `json:"inputTokenLimit"`
|
||||
OutputTokenLimit int `json:"outputTokenLimit"`
|
||||
SupportedGenerationMethods []string `json:"supportedGenerationMethods"`
|
||||
Temperature float64 `json:"temperature,omitempty"`
|
||||
TopP float64 `json:"topP,omitempty"`
|
||||
TopK int `json:"topK,omitempty"`
|
||||
MaxTemperature int `json:"maxTemperature,omitempty"`
|
||||
}
|
||||
|
||||
type OpenAIModelsResponse struct {
|
||||
Data []OpenAIModel `json:"data"`
|
||||
Success bool `json:"success"`
|
||||
}
|
||||
|
||||
type GoogleOpenAICompatibleResponse struct {
|
||||
Models []GoogleOpenAICompatibleModels `json:"models"`
|
||||
NextPageToken string `json:"nextPageToken"`
|
||||
}
|
||||
|
||||
func parseStatusFilter(statusParam string) int {
|
||||
switch strings.ToLower(statusParam) {
|
||||
case "enabled", "1":
|
||||
@@ -52,6 +71,13 @@ func parseStatusFilter(statusParam string) int {
|
||||
}
|
||||
}
|
||||
|
||||
func clearChannelInfo(channel *model.Channel) {
|
||||
if channel.ChannelInfo.IsMultiKey {
|
||||
channel.ChannelInfo.MultiKeyDisabledReason = nil
|
||||
channel.ChannelInfo.MultiKeyDisabledTime = nil
|
||||
}
|
||||
}
|
||||
|
||||
func GetAllChannels(c *gin.Context) {
|
||||
pageInfo := common.GetPageQuery(c)
|
||||
channelData := make([]*model.Channel, 0)
|
||||
@@ -126,6 +152,10 @@ func GetAllChannels(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
for _, datum := range channelData {
|
||||
clearChannelInfo(datum)
|
||||
}
|
||||
|
||||
countQuery := model.DB.Model(&model.Channel{})
|
||||
if statusFilter == common.ChannelStatusEnabled {
|
||||
countQuery = countQuery.Where("status = ?", common.ChannelStatusEnabled)
|
||||
@@ -168,26 +198,59 @@ func FetchUpstreamModels(c *gin.Context) {
|
||||
if channel.GetBaseURL() != "" {
|
||||
baseURL = channel.GetBaseURL()
|
||||
}
|
||||
url := fmt.Sprintf("%s/v1/models", baseURL)
|
||||
|
||||
var url string
|
||||
switch channel.Type {
|
||||
case constant.ChannelTypeGemini:
|
||||
url = fmt.Sprintf("%s/v1beta/openai/models", baseURL)
|
||||
// curl https://example.com/v1beta/models?key=$GEMINI_API_KEY
|
||||
url = fmt.Sprintf("%s/v1beta/openai/models?key=%s", baseURL, channel.Key)
|
||||
case constant.ChannelTypeAli:
|
||||
url = fmt.Sprintf("%s/compatible-mode/v1/models", baseURL)
|
||||
default:
|
||||
url = fmt.Sprintf("%s/v1/models", baseURL)
|
||||
}
|
||||
|
||||
// 获取响应体 - 根据渠道类型决定是否添加 AuthHeader
|
||||
var body []byte
|
||||
if channel.Type == constant.ChannelTypeGemini {
|
||||
body, err = GetResponseBody("GET", url, channel, nil) // I don't know why, but Gemini requires no AuthHeader
|
||||
} else {
|
||||
body, err = GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key))
|
||||
}
|
||||
body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key))
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
var result OpenAIModelsResponse
|
||||
if err = json.Unmarshal(body, &result); err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": fmt.Sprintf("解析响应失败: %s", err.Error()),
|
||||
})
|
||||
return
|
||||
var parseSuccess bool
|
||||
|
||||
// 适配特殊格式
|
||||
switch channel.Type {
|
||||
case constant.ChannelTypeGemini:
|
||||
var googleResult GoogleOpenAICompatibleResponse
|
||||
if err = json.Unmarshal(body, &googleResult); err == nil {
|
||||
// 转换Google格式到OpenAI格式
|
||||
for _, model := range googleResult.Models {
|
||||
for _, gModel := range model {
|
||||
result.Data = append(result.Data, OpenAIModel{
|
||||
ID: gModel.Name,
|
||||
})
|
||||
}
|
||||
}
|
||||
parseSuccess = true
|
||||
}
|
||||
}
|
||||
|
||||
// 如果解析失败,尝试OpenAI格式
|
||||
if !parseSuccess {
|
||||
if err = json.Unmarshal(body, &result); err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": fmt.Sprintf("解析响应失败: %s", err.Error()),
|
||||
})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
var ids []string
|
||||
@@ -319,6 +382,10 @@ func SearchChannels(c *gin.Context) {
|
||||
|
||||
pagedData := channelData[startIdx:endIdx]
|
||||
|
||||
for _, datum := range pagedData {
|
||||
clearChannelInfo(datum)
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
@@ -342,6 +409,9 @@ func GetChannel(c *gin.Context) {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
if channel != nil {
|
||||
clearChannelInfo(channel)
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
@@ -669,6 +739,7 @@ func DeleteChannelBatch(c *gin.Context) {
|
||||
type PatchChannel struct {
|
||||
model.Channel
|
||||
MultiKeyMode *string `json:"multi_key_mode"`
|
||||
KeyMode *string `json:"key_mode"` // 多key模式下密钥覆盖或者追加
|
||||
}
|
||||
|
||||
func UpdateChannel(c *gin.Context) {
|
||||
@@ -688,7 +759,7 @@ func UpdateChannel(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
// Preserve existing ChannelInfo to ensure multi-key channels keep correct state even if the client does not send ChannelInfo in the request.
|
||||
originChannel, err := model.GetChannelById(channel.Id, false)
|
||||
originChannel, err := model.GetChannelById(channel.Id, true)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
@@ -704,6 +775,69 @@ func UpdateChannel(c *gin.Context) {
|
||||
if channel.MultiKeyMode != nil && *channel.MultiKeyMode != "" {
|
||||
channel.ChannelInfo.MultiKeyMode = constant.MultiKeyMode(*channel.MultiKeyMode)
|
||||
}
|
||||
|
||||
// 处理多key模式下的密钥追加/覆盖逻辑
|
||||
if channel.KeyMode != nil && channel.ChannelInfo.IsMultiKey {
|
||||
switch *channel.KeyMode {
|
||||
case "append":
|
||||
// 追加模式:将新密钥添加到现有密钥列表
|
||||
if originChannel.Key != "" {
|
||||
var newKeys []string
|
||||
var existingKeys []string
|
||||
|
||||
// 解析现有密钥
|
||||
if strings.HasPrefix(strings.TrimSpace(originChannel.Key), "[") {
|
||||
// JSON数组格式
|
||||
var arr []json.RawMessage
|
||||
if err := json.Unmarshal([]byte(strings.TrimSpace(originChannel.Key)), &arr); err == nil {
|
||||
existingKeys = make([]string, len(arr))
|
||||
for i, v := range arr {
|
||||
existingKeys[i] = string(v)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// 换行分隔格式
|
||||
existingKeys = strings.Split(strings.Trim(originChannel.Key, "\n"), "\n")
|
||||
}
|
||||
|
||||
// 处理 Vertex AI 的特殊情况
|
||||
if channel.Type == constant.ChannelTypeVertexAi {
|
||||
// 尝试解析新密钥为JSON数组
|
||||
if strings.HasPrefix(strings.TrimSpace(channel.Key), "[") {
|
||||
array, err := getVertexArrayKeys(channel.Key)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "追加密钥解析失败: " + err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
newKeys = array
|
||||
} else {
|
||||
// 单个JSON密钥
|
||||
newKeys = []string{channel.Key}
|
||||
}
|
||||
// 合并密钥
|
||||
allKeys := append(existingKeys, newKeys...)
|
||||
channel.Key = strings.Join(allKeys, "\n")
|
||||
} else {
|
||||
// 普通渠道的处理
|
||||
inputKeys := strings.Split(channel.Key, "\n")
|
||||
for _, key := range inputKeys {
|
||||
key = strings.TrimSpace(key)
|
||||
if key != "" {
|
||||
newKeys = append(newKeys, key)
|
||||
}
|
||||
}
|
||||
// 合并密钥
|
||||
allKeys := append(existingKeys, newKeys...)
|
||||
channel.Key = strings.Join(allKeys, "\n")
|
||||
}
|
||||
}
|
||||
case "replace":
|
||||
// 覆盖模式:直接使用新密钥(默认行为,不需要特殊处理)
|
||||
}
|
||||
}
|
||||
err = channel.Update()
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
@@ -711,6 +845,7 @@ func UpdateChannel(c *gin.Context) {
|
||||
}
|
||||
model.InitChannelCache()
|
||||
channel.Key = ""
|
||||
clearChannelInfo(&channel.Channel)
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
@@ -914,3 +1049,409 @@ func CopyChannel(c *gin.Context) {
|
||||
// success
|
||||
c.JSON(http.StatusOK, gin.H{"success": true, "message": "", "data": gin.H{"id": clone.Id}})
|
||||
}
|
||||
|
||||
// MultiKeyManageRequest represents the request for multi-key management operations
|
||||
type MultiKeyManageRequest struct {
|
||||
ChannelId int `json:"channel_id"`
|
||||
Action string `json:"action"` // "disable_key", "enable_key", "delete_disabled_keys", "get_key_status"
|
||||
KeyIndex *int `json:"key_index,omitempty"` // for disable_key and enable_key actions
|
||||
Page int `json:"page,omitempty"` // for get_key_status pagination
|
||||
PageSize int `json:"page_size,omitempty"` // for get_key_status pagination
|
||||
Status *int `json:"status,omitempty"` // for get_key_status filtering: 1=enabled, 2=manual_disabled, 3=auto_disabled, nil=all
|
||||
}
|
||||
|
||||
// MultiKeyStatusResponse represents the response for key status query
|
||||
type MultiKeyStatusResponse struct {
|
||||
Keys []KeyStatus `json:"keys"`
|
||||
Total int `json:"total"`
|
||||
Page int `json:"page"`
|
||||
PageSize int `json:"page_size"`
|
||||
TotalPages int `json:"total_pages"`
|
||||
// Statistics
|
||||
EnabledCount int `json:"enabled_count"`
|
||||
ManualDisabledCount int `json:"manual_disabled_count"`
|
||||
AutoDisabledCount int `json:"auto_disabled_count"`
|
||||
}
|
||||
|
||||
type KeyStatus struct {
|
||||
Index int `json:"index"`
|
||||
Status int `json:"status"` // 1: enabled, 2: disabled
|
||||
DisabledTime int64 `json:"disabled_time,omitempty"`
|
||||
Reason string `json:"reason,omitempty"`
|
||||
KeyPreview string `json:"key_preview"` // first 10 chars of key for identification
|
||||
}
|
||||
|
||||
// ManageMultiKeys handles multi-key management operations
|
||||
func ManageMultiKeys(c *gin.Context) {
|
||||
request := MultiKeyManageRequest{}
|
||||
err := c.ShouldBindJSON(&request)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
channel, err := model.GetChannelById(request.ChannelId, true)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "渠道不存在",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
if !channel.ChannelInfo.IsMultiKey {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "该渠道不是多密钥模式",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
switch request.Action {
|
||||
case "get_key_status":
|
||||
keys := channel.GetKeys()
|
||||
|
||||
// Default pagination parameters
|
||||
page := request.Page
|
||||
pageSize := request.PageSize
|
||||
if page <= 0 {
|
||||
page = 1
|
||||
}
|
||||
if pageSize <= 0 {
|
||||
pageSize = 50 // Default page size
|
||||
}
|
||||
|
||||
// Statistics for all keys (unchanged by filtering)
|
||||
var enabledCount, manualDisabledCount, autoDisabledCount int
|
||||
|
||||
// Build all key status data first
|
||||
var allKeyStatusList []KeyStatus
|
||||
for i, key := range keys {
|
||||
status := 1 // default enabled
|
||||
var disabledTime int64
|
||||
var reason string
|
||||
|
||||
if channel.ChannelInfo.MultiKeyStatusList != nil {
|
||||
if s, exists := channel.ChannelInfo.MultiKeyStatusList[i]; exists {
|
||||
status = s
|
||||
}
|
||||
}
|
||||
|
||||
// Count for statistics (all keys)
|
||||
switch status {
|
||||
case 1:
|
||||
enabledCount++
|
||||
case 2:
|
||||
manualDisabledCount++
|
||||
case 3:
|
||||
autoDisabledCount++
|
||||
}
|
||||
|
||||
if status != 1 {
|
||||
if channel.ChannelInfo.MultiKeyDisabledTime != nil {
|
||||
disabledTime = channel.ChannelInfo.MultiKeyDisabledTime[i]
|
||||
}
|
||||
if channel.ChannelInfo.MultiKeyDisabledReason != nil {
|
||||
reason = channel.ChannelInfo.MultiKeyDisabledReason[i]
|
||||
}
|
||||
}
|
||||
|
||||
// Create key preview (first 10 chars)
|
||||
keyPreview := key
|
||||
if len(key) > 10 {
|
||||
keyPreview = key[:10] + "..."
|
||||
}
|
||||
|
||||
allKeyStatusList = append(allKeyStatusList, KeyStatus{
|
||||
Index: i,
|
||||
Status: status,
|
||||
DisabledTime: disabledTime,
|
||||
Reason: reason,
|
||||
KeyPreview: keyPreview,
|
||||
})
|
||||
}
|
||||
|
||||
// Apply status filter if specified
|
||||
var filteredKeyStatusList []KeyStatus
|
||||
if request.Status != nil {
|
||||
for _, keyStatus := range allKeyStatusList {
|
||||
if keyStatus.Status == *request.Status {
|
||||
filteredKeyStatusList = append(filteredKeyStatusList, keyStatus)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
filteredKeyStatusList = allKeyStatusList
|
||||
}
|
||||
|
||||
// Calculate pagination based on filtered results
|
||||
filteredTotal := len(filteredKeyStatusList)
|
||||
totalPages := (filteredTotal + pageSize - 1) / pageSize
|
||||
if totalPages == 0 {
|
||||
totalPages = 1
|
||||
}
|
||||
if page > totalPages {
|
||||
page = totalPages
|
||||
}
|
||||
|
||||
// Calculate range for current page
|
||||
start := (page - 1) * pageSize
|
||||
end := start + pageSize
|
||||
if end > filteredTotal {
|
||||
end = filteredTotal
|
||||
}
|
||||
|
||||
// Get the page data
|
||||
var pageKeyStatusList []KeyStatus
|
||||
if start < filteredTotal {
|
||||
pageKeyStatusList = filteredKeyStatusList[start:end]
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": MultiKeyStatusResponse{
|
||||
Keys: pageKeyStatusList,
|
||||
Total: filteredTotal, // Total of filtered results
|
||||
Page: page,
|
||||
PageSize: pageSize,
|
||||
TotalPages: totalPages,
|
||||
EnabledCount: enabledCount, // Overall statistics
|
||||
ManualDisabledCount: manualDisabledCount, // Overall statistics
|
||||
AutoDisabledCount: autoDisabledCount, // Overall statistics
|
||||
},
|
||||
})
|
||||
return
|
||||
|
||||
case "disable_key":
|
||||
if request.KeyIndex == nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "未指定要禁用的密钥索引",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
keyIndex := *request.KeyIndex
|
||||
if keyIndex < 0 || keyIndex >= channel.ChannelInfo.MultiKeySize {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "密钥索引超出范围",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
if channel.ChannelInfo.MultiKeyStatusList == nil {
|
||||
channel.ChannelInfo.MultiKeyStatusList = make(map[int]int)
|
||||
}
|
||||
if channel.ChannelInfo.MultiKeyDisabledTime == nil {
|
||||
channel.ChannelInfo.MultiKeyDisabledTime = make(map[int]int64)
|
||||
}
|
||||
if channel.ChannelInfo.MultiKeyDisabledReason == nil {
|
||||
channel.ChannelInfo.MultiKeyDisabledReason = make(map[int]string)
|
||||
}
|
||||
|
||||
channel.ChannelInfo.MultiKeyStatusList[keyIndex] = 2 // disabled
|
||||
|
||||
err = channel.Update()
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
model.InitChannelCache()
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "密钥已禁用",
|
||||
})
|
||||
return
|
||||
|
||||
case "enable_key":
|
||||
if request.KeyIndex == nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "未指定要启用的密钥索引",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
keyIndex := *request.KeyIndex
|
||||
if keyIndex < 0 || keyIndex >= channel.ChannelInfo.MultiKeySize {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "密钥索引超出范围",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 从状态列表中删除该密钥的记录,使其回到默认启用状态
|
||||
if channel.ChannelInfo.MultiKeyStatusList != nil {
|
||||
delete(channel.ChannelInfo.MultiKeyStatusList, keyIndex)
|
||||
}
|
||||
if channel.ChannelInfo.MultiKeyDisabledTime != nil {
|
||||
delete(channel.ChannelInfo.MultiKeyDisabledTime, keyIndex)
|
||||
}
|
||||
if channel.ChannelInfo.MultiKeyDisabledReason != nil {
|
||||
delete(channel.ChannelInfo.MultiKeyDisabledReason, keyIndex)
|
||||
}
|
||||
|
||||
err = channel.Update()
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
model.InitChannelCache()
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "密钥已启用",
|
||||
})
|
||||
return
|
||||
|
||||
case "enable_all_keys":
|
||||
// 清空所有禁用状态,使所有密钥回到默认启用状态
|
||||
var enabledCount int
|
||||
if channel.ChannelInfo.MultiKeyStatusList != nil {
|
||||
enabledCount = len(channel.ChannelInfo.MultiKeyStatusList)
|
||||
}
|
||||
|
||||
channel.ChannelInfo.MultiKeyStatusList = make(map[int]int)
|
||||
channel.ChannelInfo.MultiKeyDisabledTime = make(map[int]int64)
|
||||
channel.ChannelInfo.MultiKeyDisabledReason = make(map[int]string)
|
||||
|
||||
err = channel.Update()
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
model.InitChannelCache()
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": fmt.Sprintf("已启用 %d 个密钥", enabledCount),
|
||||
})
|
||||
return
|
||||
|
||||
case "disable_all_keys":
|
||||
// 禁用所有启用的密钥
|
||||
if channel.ChannelInfo.MultiKeyStatusList == nil {
|
||||
channel.ChannelInfo.MultiKeyStatusList = make(map[int]int)
|
||||
}
|
||||
if channel.ChannelInfo.MultiKeyDisabledTime == nil {
|
||||
channel.ChannelInfo.MultiKeyDisabledTime = make(map[int]int64)
|
||||
}
|
||||
if channel.ChannelInfo.MultiKeyDisabledReason == nil {
|
||||
channel.ChannelInfo.MultiKeyDisabledReason = make(map[int]string)
|
||||
}
|
||||
|
||||
var disabledCount int
|
||||
for i := 0; i < channel.ChannelInfo.MultiKeySize; i++ {
|
||||
status := 1 // default enabled
|
||||
if s, exists := channel.ChannelInfo.MultiKeyStatusList[i]; exists {
|
||||
status = s
|
||||
}
|
||||
|
||||
// 只禁用当前启用的密钥
|
||||
if status == 1 {
|
||||
channel.ChannelInfo.MultiKeyStatusList[i] = 2 // disabled
|
||||
disabledCount++
|
||||
}
|
||||
}
|
||||
|
||||
if disabledCount == 0 {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "没有可禁用的密钥",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
err = channel.Update()
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
model.InitChannelCache()
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": fmt.Sprintf("已禁用 %d 个密钥", disabledCount),
|
||||
})
|
||||
return
|
||||
|
||||
case "delete_disabled_keys":
|
||||
keys := channel.GetKeys()
|
||||
var remainingKeys []string
|
||||
var deletedCount int
|
||||
var newStatusList = make(map[int]int)
|
||||
var newDisabledTime = make(map[int]int64)
|
||||
var newDisabledReason = make(map[int]string)
|
||||
|
||||
newIndex := 0
|
||||
for i, key := range keys {
|
||||
status := 1 // default enabled
|
||||
if channel.ChannelInfo.MultiKeyStatusList != nil {
|
||||
if s, exists := channel.ChannelInfo.MultiKeyStatusList[i]; exists {
|
||||
status = s
|
||||
}
|
||||
}
|
||||
|
||||
// 只删除自动禁用(status == 3)的密钥,保留启用(status == 1)和手动禁用(status == 2)的密钥
|
||||
if status == 3 {
|
||||
deletedCount++
|
||||
} else {
|
||||
remainingKeys = append(remainingKeys, key)
|
||||
// 保留非自动禁用密钥的状态信息,重新索引
|
||||
if status != 1 {
|
||||
newStatusList[newIndex] = status
|
||||
if channel.ChannelInfo.MultiKeyDisabledTime != nil {
|
||||
if t, exists := channel.ChannelInfo.MultiKeyDisabledTime[i]; exists {
|
||||
newDisabledTime[newIndex] = t
|
||||
}
|
||||
}
|
||||
if channel.ChannelInfo.MultiKeyDisabledReason != nil {
|
||||
if r, exists := channel.ChannelInfo.MultiKeyDisabledReason[i]; exists {
|
||||
newDisabledReason[newIndex] = r
|
||||
}
|
||||
}
|
||||
}
|
||||
newIndex++
|
||||
}
|
||||
}
|
||||
|
||||
if deletedCount == 0 {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "没有需要删除的自动禁用密钥",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Update channel with remaining keys
|
||||
channel.Key = strings.Join(remainingKeys, "\n")
|
||||
channel.ChannelInfo.MultiKeySize = len(remainingKeys)
|
||||
channel.ChannelInfo.MultiKeyStatusList = newStatusList
|
||||
channel.ChannelInfo.MultiKeyDisabledTime = newDisabledTime
|
||||
channel.ChannelInfo.MultiKeyDisabledReason = newDisabledReason
|
||||
|
||||
err = channel.Update()
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
model.InitChannelCache()
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": fmt.Sprintf("已删除 %d 个自动禁用的密钥", deletedCount),
|
||||
"data": deletedCount,
|
||||
})
|
||||
return
|
||||
|
||||
default:
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "不支持的操作",
|
||||
})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
@@ -28,19 +28,19 @@ func Playground(c *gin.Context) {
|
||||
|
||||
useAccessToken := c.GetBool("use_access_token")
|
||||
if useAccessToken {
|
||||
newAPIError = types.NewError(errors.New("暂不支持使用 access token"), types.ErrorCodeAccessDenied)
|
||||
newAPIError = types.NewError(errors.New("暂不支持使用 access token"), types.ErrorCodeAccessDenied, types.ErrOptionWithSkipRetry())
|
||||
return
|
||||
}
|
||||
|
||||
playgroundRequest := &dto.PlayGroundRequest{}
|
||||
err := common.UnmarshalBodyReusable(c, playgroundRequest)
|
||||
if err != nil {
|
||||
newAPIError = types.NewError(err, types.ErrorCodeInvalidRequest)
|
||||
newAPIError = types.NewError(err, types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
|
||||
return
|
||||
}
|
||||
|
||||
if playgroundRequest.Model == "" {
|
||||
newAPIError = types.NewError(errors.New("请选择模型"), types.ErrorCodeInvalidRequest)
|
||||
newAPIError = types.NewError(errors.New("请选择模型"), types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
|
||||
return
|
||||
}
|
||||
c.Set("original_model", playgroundRequest.Model)
|
||||
@@ -51,7 +51,7 @@ func Playground(c *gin.Context) {
|
||||
group = userGroup
|
||||
} else {
|
||||
if !setting.GroupInUserUsableGroups(group) && group != userGroup {
|
||||
newAPIError = types.NewError(errors.New("无权访问该分组"), types.ErrorCodeAccessDenied)
|
||||
newAPIError = types.NewError(errors.New("无权访问该分组"), types.ErrorCodeAccessDenied, types.ErrOptionWithSkipRetry())
|
||||
return
|
||||
}
|
||||
c.Set("group", group)
|
||||
@@ -62,7 +62,7 @@ func Playground(c *gin.Context) {
|
||||
// Write user context to ensure acceptUnsetRatio is available
|
||||
userCache, err := model.GetUserCache(userId)
|
||||
if err != nil {
|
||||
newAPIError = types.NewError(err, types.ErrorCodeQueryDataError)
|
||||
newAPIError = types.NewError(err, types.ErrorCodeQueryDataError, types.ErrOptionWithSkipRetry())
|
||||
return
|
||||
}
|
||||
userCache.WriteContext(c)
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"one-api/common"
|
||||
"one-api/model"
|
||||
"strconv"
|
||||
"unicode/utf8"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
@@ -63,7 +64,7 @@ func AddRedemption(c *gin.Context) {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
if len(redemption.Name) == 0 || len(redemption.Name) > 20 {
|
||||
if utf8.RuneCountInString(redemption.Name) == 0 || utf8.RuneCountInString(redemption.Name) > 20 {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "兑换码名称长度必须在1-20之间",
|
||||
|
||||
@@ -47,7 +47,7 @@ func relayHandler(c *gin.Context, relayMode int) *types.NewAPIError {
|
||||
err = relay.TextHelper(c)
|
||||
}
|
||||
|
||||
if constant2.ErrorLogEnabled && err != nil {
|
||||
if constant2.ErrorLogEnabled && err != nil && types.IsRecordErrorLog(err) {
|
||||
// 保存错误日志到mysql中
|
||||
userId := c.GetInt("id")
|
||||
tokenName := c.GetString("token_name")
|
||||
@@ -56,14 +56,21 @@ func relayHandler(c *gin.Context, relayMode int) *types.NewAPIError {
|
||||
userGroup := c.GetString("group")
|
||||
channelId := c.GetInt("channel_id")
|
||||
other := make(map[string]interface{})
|
||||
other["error_type"] = err.ErrorType
|
||||
other["error_type"] = err.GetErrorType()
|
||||
other["error_code"] = err.GetErrorCode()
|
||||
other["status_code"] = err.StatusCode
|
||||
other["channel_id"] = channelId
|
||||
other["channel_name"] = c.GetString("channel_name")
|
||||
other["channel_type"] = c.GetInt("channel_type")
|
||||
|
||||
model.RecordErrorLog(c, userId, channelId, modelName, tokenName, err.Error(), tokenId, 0, false, userGroup, other)
|
||||
adminInfo := make(map[string]interface{})
|
||||
adminInfo["use_channel"] = c.GetStringSlice("use_channel")
|
||||
isMultiKey := common.GetContextKeyBool(c, constant.ContextKeyChannelIsMultiKey)
|
||||
if isMultiKey {
|
||||
adminInfo["is_multi_key"] = true
|
||||
adminInfo["multi_key_index"] = common.GetContextKeyInt(c, constant.ContextKeyChannelMultiKeyIndex)
|
||||
}
|
||||
other["admin_info"] = adminInfo
|
||||
model.RecordErrorLog(c, userId, channelId, modelName, tokenName, err.MaskSensitiveError(), tokenId, 0, false, userGroup, other)
|
||||
}
|
||||
|
||||
return err
|
||||
@@ -128,7 +135,7 @@ func WssRelay(c *gin.Context) {
|
||||
defer ws.Close()
|
||||
|
||||
if err != nil {
|
||||
helper.WssError(c, ws, types.NewError(err, types.ErrorCodeGetChannelFailed).ToOpenAIError())
|
||||
helper.WssError(c, ws, types.NewError(err, types.ErrorCodeGetChannelFailed, types.ErrOptionWithSkipRetry()).ToOpenAIError())
|
||||
return
|
||||
}
|
||||
|
||||
@@ -259,10 +266,10 @@ func getChannel(c *gin.Context, group, originalModel string, retryCount int) (*m
|
||||
}
|
||||
channel, selectGroup, err := model.CacheGetRandomSatisfiedChannel(c, group, originalModel, retryCount)
|
||||
if err != nil {
|
||||
return nil, types.NewError(errors.New(fmt.Sprintf("获取分组 %s 下模型 %s 的可用渠道失败(retry): %s", selectGroup, originalModel, err.Error())), types.ErrorCodeGetChannelFailed)
|
||||
return nil, types.NewError(errors.New(fmt.Sprintf("获取分组 %s 下模型 %s 的可用渠道失败(retry): %s", selectGroup, originalModel, err.Error())), types.ErrorCodeGetChannelFailed, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
if channel == nil {
|
||||
return nil, types.NewError(errors.New(fmt.Sprintf("分组 %s 下模型 %s 的可用渠道不存在(数据库一致性已被破坏,retry)", selectGroup, originalModel)), types.ErrorCodeGetChannelFailed)
|
||||
return nil, types.NewError(errors.New(fmt.Sprintf("分组 %s 下模型 %s 的可用渠道不存在(数据库一致性已被破坏,retry)", selectGroup, originalModel)), types.ErrorCodeGetChannelFailed, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
newAPIError := middleware.SetupContextForSelectedChannel(c, channel, originalModel)
|
||||
if newAPIError != nil {
|
||||
@@ -278,7 +285,7 @@ func shouldRetry(c *gin.Context, openaiErr *types.NewAPIError, retryTimes int) b
|
||||
if types.IsChannelError(openaiErr) {
|
||||
return true
|
||||
}
|
||||
if types.IsLocalError(openaiErr) {
|
||||
if types.IsSkipRetryError(openaiErr) {
|
||||
return false
|
||||
}
|
||||
if retryTimes <= 0 {
|
||||
|
||||
@@ -83,7 +83,7 @@ func updateVideoSingleTask(ctx context.Context, adaptor channel.TaskAdaptor, cha
|
||||
taskResult := &relaycommon.TaskInfo{}
|
||||
// try parse as New API response format
|
||||
var responseItems dto.TaskResponse[model.Task]
|
||||
if err = json.Unmarshal(responseBody, &responseItems); err == nil {
|
||||
if err = json.Unmarshal(responseBody, &responseItems); err == nil && responseItems.IsSuccess() {
|
||||
t := responseItems.Data
|
||||
taskResult.TaskID = t.TaskID
|
||||
taskResult.Status = string(t.Status)
|
||||
|
||||
553
controller/twofa.go
Normal file
553
controller/twofa.go
Normal file
@@ -0,0 +1,553 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/model"
|
||||
"strconv"
|
||||
|
||||
"github.com/gin-contrib/sessions"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// Setup2FARequest 设置2FA请求结构
|
||||
type Setup2FARequest struct {
|
||||
Code string `json:"code" binding:"required"`
|
||||
}
|
||||
|
||||
// Verify2FARequest 验证2FA请求结构
|
||||
type Verify2FARequest struct {
|
||||
Code string `json:"code" binding:"required"`
|
||||
}
|
||||
|
||||
// Setup2FAResponse 设置2FA响应结构
|
||||
type Setup2FAResponse struct {
|
||||
Secret string `json:"secret"`
|
||||
QRCodeData string `json:"qr_code_data"`
|
||||
BackupCodes []string `json:"backup_codes"`
|
||||
}
|
||||
|
||||
// Setup2FA 初始化2FA设置
|
||||
func Setup2FA(c *gin.Context) {
|
||||
userId := c.GetInt("id")
|
||||
|
||||
// 检查用户是否已经启用2FA
|
||||
existing, err := model.GetTwoFAByUserId(userId)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
if existing != nil && existing.IsEnabled {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "用户已启用2FA,请先禁用后重新设置",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 如果存在已禁用的2FA记录,先删除它
|
||||
if existing != nil && !existing.IsEnabled {
|
||||
if err := existing.Delete(); err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
existing = nil // 重置为nil,后续将创建新记录
|
||||
}
|
||||
|
||||
// 获取用户信息
|
||||
user, err := model.GetUserById(userId, false)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
// 生成TOTP密钥
|
||||
key, err := common.GenerateTOTPSecret(user.Username)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "生成2FA密钥失败",
|
||||
})
|
||||
common.SysError("生成TOTP密钥失败: " + err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// 生成备用码
|
||||
backupCodes, err := common.GenerateBackupCodes()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "生成备用码失败",
|
||||
})
|
||||
common.SysError("生成备用码失败: " + err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// 生成二维码数据
|
||||
qrCodeData := common.GenerateQRCodeData(key.Secret(), user.Username)
|
||||
|
||||
// 创建或更新2FA记录(暂未启用)
|
||||
twoFA := &model.TwoFA{
|
||||
UserId: userId,
|
||||
Secret: key.Secret(),
|
||||
IsEnabled: false,
|
||||
}
|
||||
|
||||
if existing != nil {
|
||||
// 更新现有记录
|
||||
twoFA.Id = existing.Id
|
||||
err = twoFA.Update()
|
||||
} else {
|
||||
// 创建新记录
|
||||
err = twoFA.Create()
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
// 创建备用码记录
|
||||
if err := model.CreateBackupCodes(userId, backupCodes); err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "保存备用码失败",
|
||||
})
|
||||
common.SysError("保存备用码失败: " + err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// 记录操作日志
|
||||
model.RecordLog(userId, model.LogTypeSystem, "开始设置两步验证")
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "2FA设置初始化成功,请使用认证器扫描二维码并输入验证码完成设置",
|
||||
"data": Setup2FAResponse{
|
||||
Secret: key.Secret(),
|
||||
QRCodeData: qrCodeData,
|
||||
BackupCodes: backupCodes,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
// Enable2FA 启用2FA
|
||||
func Enable2FA(c *gin.Context) {
|
||||
var req Setup2FARequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "参数错误",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
userId := c.GetInt("id")
|
||||
|
||||
// 获取2FA记录
|
||||
twoFA, err := model.GetTwoFAByUserId(userId)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
if twoFA == nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "请先完成2FA初始化设置",
|
||||
})
|
||||
return
|
||||
}
|
||||
if twoFA.IsEnabled {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "2FA已经启用",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 验证TOTP验证码
|
||||
cleanCode, err := common.ValidateNumericCode(req.Code)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
if !common.ValidateTOTPCode(twoFA.Secret, cleanCode) {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "验证码或备用码错误,请重试",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 启用2FA
|
||||
if err := twoFA.Enable(); err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
// 记录操作日志
|
||||
model.RecordLog(userId, model.LogTypeSystem, "成功启用两步验证")
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "两步验证启用成功",
|
||||
})
|
||||
}
|
||||
|
||||
// Disable2FA 禁用2FA
|
||||
func Disable2FA(c *gin.Context) {
|
||||
var req Verify2FARequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "参数错误",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
userId := c.GetInt("id")
|
||||
|
||||
// 获取2FA记录
|
||||
twoFA, err := model.GetTwoFAByUserId(userId)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
if twoFA == nil || !twoFA.IsEnabled {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "用户未启用2FA",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 验证TOTP验证码或备用码
|
||||
cleanCode, err := common.ValidateNumericCode(req.Code)
|
||||
isValidTOTP := false
|
||||
isValidBackup := false
|
||||
|
||||
if err == nil {
|
||||
// 尝试验证TOTP
|
||||
isValidTOTP, _ = twoFA.ValidateTOTPAndUpdateUsage(cleanCode)
|
||||
}
|
||||
|
||||
if !isValidTOTP {
|
||||
// 尝试验证备用码
|
||||
isValidBackup, err = twoFA.ValidateBackupCodeAndUpdateUsage(req.Code)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if !isValidTOTP && !isValidBackup {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "验证码或备用码错误,请重试",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 禁用2FA
|
||||
if err := model.DisableTwoFA(userId); err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
// 记录操作日志
|
||||
model.RecordLog(userId, model.LogTypeSystem, "禁用两步验证")
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "两步验证已禁用",
|
||||
})
|
||||
}
|
||||
|
||||
// Get2FAStatus 获取用户2FA状态
|
||||
func Get2FAStatus(c *gin.Context) {
|
||||
userId := c.GetInt("id")
|
||||
|
||||
twoFA, err := model.GetTwoFAByUserId(userId)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
status := map[string]interface{}{
|
||||
"enabled": false,
|
||||
"locked": false,
|
||||
}
|
||||
|
||||
if twoFA != nil {
|
||||
status["enabled"] = twoFA.IsEnabled
|
||||
status["locked"] = twoFA.IsLocked()
|
||||
if twoFA.IsEnabled {
|
||||
// 获取剩余备用码数量
|
||||
backupCount, err := model.GetUnusedBackupCodeCount(userId)
|
||||
if err != nil {
|
||||
common.SysError("获取备用码数量失败: " + err.Error())
|
||||
} else {
|
||||
status["backup_codes_remaining"] = backupCount
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": status,
|
||||
})
|
||||
}
|
||||
|
||||
// RegenerateBackupCodes 重新生成备用码
|
||||
func RegenerateBackupCodes(c *gin.Context) {
|
||||
var req Verify2FARequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "参数错误",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
userId := c.GetInt("id")
|
||||
|
||||
// 获取2FA记录
|
||||
twoFA, err := model.GetTwoFAByUserId(userId)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
if twoFA == nil || !twoFA.IsEnabled {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "用户未启用2FA",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 验证TOTP验证码
|
||||
cleanCode, err := common.ValidateNumericCode(req.Code)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
valid, err := twoFA.ValidateTOTPAndUpdateUsage(cleanCode)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
if !valid {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "验证码或备用码错误,请重试",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 生成新的备用码
|
||||
backupCodes, err := common.GenerateBackupCodes()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "生成备用码失败",
|
||||
})
|
||||
common.SysError("生成备用码失败: " + err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// 保存新的备用码
|
||||
if err := model.CreateBackupCodes(userId, backupCodes); err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "保存备用码失败",
|
||||
})
|
||||
common.SysError("保存备用码失败: " + err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// 记录操作日志
|
||||
model.RecordLog(userId, model.LogTypeSystem, "重新生成两步验证备用码")
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "备用码重新生成成功",
|
||||
"data": map[string]interface{}{
|
||||
"backup_codes": backupCodes,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
// Verify2FALogin 登录时验证2FA
|
||||
func Verify2FALogin(c *gin.Context) {
|
||||
var req Verify2FARequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "参数错误",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 从会话中获取pending用户信息
|
||||
session := sessions.Default(c)
|
||||
pendingUserId := session.Get("pending_user_id")
|
||||
if pendingUserId == nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "会话已过期,请重新登录",
|
||||
})
|
||||
return
|
||||
}
|
||||
userId, ok := pendingUserId.(int)
|
||||
if !ok {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "会话数据无效,请重新登录",
|
||||
})
|
||||
return
|
||||
}
|
||||
// 获取用户信息
|
||||
user, err := model.GetUserById(userId, false)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "用户不存在",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 获取2FA记录
|
||||
twoFA, err := model.GetTwoFAByUserId(user.Id)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
if twoFA == nil || !twoFA.IsEnabled {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "用户未启用2FA",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 验证TOTP验证码或备用码
|
||||
cleanCode, err := common.ValidateNumericCode(req.Code)
|
||||
isValidTOTP := false
|
||||
isValidBackup := false
|
||||
|
||||
if err == nil {
|
||||
// 尝试验证TOTP
|
||||
isValidTOTP, _ = twoFA.ValidateTOTPAndUpdateUsage(cleanCode)
|
||||
}
|
||||
|
||||
if !isValidTOTP {
|
||||
// 尝试验证备用码
|
||||
isValidBackup, err = twoFA.ValidateBackupCodeAndUpdateUsage(req.Code)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if !isValidTOTP && !isValidBackup {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "验证码或备用码错误,请重试",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 2FA验证成功,清理pending会话信息并完成登录
|
||||
session.Delete("pending_username")
|
||||
session.Delete("pending_user_id")
|
||||
session.Save()
|
||||
|
||||
setupLogin(user, c)
|
||||
}
|
||||
|
||||
// Admin2FAStats 管理员获取2FA统计信息
|
||||
func Admin2FAStats(c *gin.Context) {
|
||||
stats, err := model.GetTwoFAStats()
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": stats,
|
||||
})
|
||||
}
|
||||
|
||||
// AdminDisable2FA 管理员强制禁用用户2FA
|
||||
func AdminDisable2FA(c *gin.Context) {
|
||||
userIdStr := c.Param("id")
|
||||
userId, err := strconv.Atoi(userIdStr)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "用户ID格式错误",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 检查目标用户权限
|
||||
targetUser, err := model.GetUserById(userId, false)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
myRole := c.GetInt("role")
|
||||
if myRole <= targetUser.Role && myRole != common.RoleRootUser {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "无权操作同级或更高级用户的2FA设置",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 禁用2FA
|
||||
if err := model.DisableTwoFA(userId); err != nil {
|
||||
if errors.Is(err, model.ErrTwoFANotEnabled) {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "用户未启用2FA",
|
||||
})
|
||||
return
|
||||
}
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
// 记录操作日志
|
||||
adminId := c.GetInt("id")
|
||||
model.RecordLog(userId, model.LogTypeManage,
|
||||
fmt.Sprintf("管理员(ID:%d)强制禁用了用户的两步验证", adminId))
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "用户2FA已被强制禁用",
|
||||
})
|
||||
}
|
||||
@@ -62,6 +62,32 @@ func Login(c *gin.Context) {
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 检查是否启用2FA
|
||||
if model.IsTwoFAEnabled(user.Id) {
|
||||
// 设置pending session,等待2FA验证
|
||||
session := sessions.Default(c)
|
||||
session.Set("pending_username", user.Username)
|
||||
session.Set("pending_user_id", user.Id)
|
||||
err := session.Save()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"message": "无法保存会话信息,请重试",
|
||||
"success": false,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"message": "请输入两步验证码",
|
||||
"success": true,
|
||||
"data": map[string]interface{}{
|
||||
"require_2fa": true,
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
setupLogin(&user, c)
|
||||
}
|
||||
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
package dto
|
||||
|
||||
type ChannelSettings struct {
|
||||
ForceFormat bool `json:"force_format,omitempty"`
|
||||
ThinkingToContent bool `json:"thinking_to_content,omitempty"`
|
||||
Proxy string `json:"proxy"`
|
||||
ForceFormat bool `json:"force_format,omitempty"`
|
||||
ThinkingToContent bool `json:"thinking_to_content,omitempty"`
|
||||
Proxy string `json:"proxy"`
|
||||
PassThroughBodyEnabled bool `json:"pass_through_body_enabled,omitempty"`
|
||||
SystemPrompt string `json:"system_prompt,omitempty"`
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@ package dto
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"one-api/common"
|
||||
"one-api/types"
|
||||
)
|
||||
@@ -284,14 +285,9 @@ func (c *ClaudeRequest) ParseSystem() []ClaudeMediaMessage {
|
||||
return mediaContent
|
||||
}
|
||||
|
||||
type ClaudeError struct {
|
||||
Type string `json:"type,omitempty"`
|
||||
Message string `json:"message,omitempty"`
|
||||
}
|
||||
|
||||
type ClaudeErrorWithStatusCode struct {
|
||||
Error ClaudeError `json:"error"`
|
||||
StatusCode int `json:"status_code"`
|
||||
Error types.ClaudeError `json:"error"`
|
||||
StatusCode int `json:"status_code"`
|
||||
LocalError bool
|
||||
}
|
||||
|
||||
@@ -303,7 +299,7 @@ type ClaudeResponse struct {
|
||||
Completion string `json:"completion,omitempty"`
|
||||
StopReason string `json:"stop_reason,omitempty"`
|
||||
Model string `json:"model,omitempty"`
|
||||
Error *types.ClaudeError `json:"error,omitempty"`
|
||||
Error any `json:"error,omitempty"`
|
||||
Usage *ClaudeUsage `json:"usage,omitempty"`
|
||||
Index *int `json:"index,omitempty"`
|
||||
ContentBlock *ClaudeMediaMessage `json:"content_block,omitempty"`
|
||||
@@ -324,6 +320,42 @@ func (c *ClaudeResponse) GetIndex() int {
|
||||
return *c.Index
|
||||
}
|
||||
|
||||
// GetClaudeError 从动态错误类型中提取ClaudeError结构
|
||||
func (c *ClaudeResponse) GetClaudeError() *types.ClaudeError {
|
||||
if c.Error == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
switch err := c.Error.(type) {
|
||||
case types.ClaudeError:
|
||||
return &err
|
||||
case *types.ClaudeError:
|
||||
return err
|
||||
case map[string]interface{}:
|
||||
// 处理从JSON解析来的map结构
|
||||
claudeErr := &types.ClaudeError{}
|
||||
if errType, ok := err["type"].(string); ok {
|
||||
claudeErr.Type = errType
|
||||
}
|
||||
if errMsg, ok := err["message"].(string); ok {
|
||||
claudeErr.Message = errMsg
|
||||
}
|
||||
return claudeErr
|
||||
case string:
|
||||
// 处理简单字符串错误
|
||||
return &types.ClaudeError{
|
||||
Type: "error",
|
||||
Message: err,
|
||||
}
|
||||
default:
|
||||
// 未知类型,尝试转换为字符串
|
||||
return &types.ClaudeError{
|
||||
Type: "unknown_error",
|
||||
Message: fmt.Sprintf("%v", err),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type ClaudeUsage struct {
|
||||
InputTokens int `json:"input_tokens"`
|
||||
CacheCreationInputTokens int `json:"cache_creation_input_tokens"`
|
||||
|
||||
@@ -1,6 +1,9 @@
|
||||
package gemini
|
||||
package dto
|
||||
|
||||
import "encoding/json"
|
||||
import (
|
||||
"encoding/json"
|
||||
"one-api/common"
|
||||
)
|
||||
|
||||
type GeminiChatRequest struct {
|
||||
Contents []GeminiChatContent `json:"contents"`
|
||||
@@ -32,7 +35,7 @@ func (g *GeminiInlineData) UnmarshalJSON(data []byte) error {
|
||||
MimeTypeSnake string `json:"mime_type"`
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(data, &aux); err != nil {
|
||||
if err := common.Unmarshal(data, &aux); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -53,7 +56,7 @@ type FunctionCall struct {
|
||||
Arguments any `json:"args"`
|
||||
}
|
||||
|
||||
type FunctionResponse struct {
|
||||
type GeminiFunctionResponse struct {
|
||||
Name string `json:"name"`
|
||||
Response map[string]interface{} `json:"response"`
|
||||
}
|
||||
@@ -78,7 +81,7 @@ type GeminiPart struct {
|
||||
Thought bool `json:"thought,omitempty"`
|
||||
InlineData *GeminiInlineData `json:"inlineData,omitempty"`
|
||||
FunctionCall *FunctionCall `json:"functionCall,omitempty"`
|
||||
FunctionResponse *FunctionResponse `json:"functionResponse,omitempty"`
|
||||
FunctionResponse *GeminiFunctionResponse `json:"functionResponse,omitempty"`
|
||||
FileData *GeminiFileData `json:"fileData,omitempty"`
|
||||
ExecutableCode *GeminiPartExecutableCode `json:"executableCode,omitempty"`
|
||||
CodeExecutionResult *GeminiPartCodeExecutionResult `json:"codeExecutionResult,omitempty"`
|
||||
@@ -93,7 +96,7 @@ func (p *GeminiPart) UnmarshalJSON(data []byte) error {
|
||||
InlineDataSnake *GeminiInlineData `json:"inline_data,omitempty"` // snake_case variant
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(data, &aux); err != nil {
|
||||
if err := common.Unmarshal(data, &aux); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -7,15 +7,15 @@ import (
|
||||
)
|
||||
|
||||
type ResponseFormat struct {
|
||||
Type string `json:"type,omitempty"`
|
||||
JsonSchema *FormatJsonSchema `json:"json_schema,omitempty"`
|
||||
Type string `json:"type,omitempty"`
|
||||
JsonSchema json.RawMessage `json:"json_schema,omitempty"`
|
||||
}
|
||||
|
||||
type FormatJsonSchema struct {
|
||||
Description string `json:"description,omitempty"`
|
||||
Name string `json:"name"`
|
||||
Schema any `json:"schema,omitempty"`
|
||||
Strict any `json:"strict,omitempty"`
|
||||
Description string `json:"description,omitempty"`
|
||||
Name string `json:"name"`
|
||||
Schema any `json:"schema,omitempty"`
|
||||
Strict json.RawMessage `json:"strict,omitempty"`
|
||||
}
|
||||
|
||||
type GeneralOpenAIRequest struct {
|
||||
@@ -73,6 +73,15 @@ func (r *GeneralOpenAIRequest) ToMap() map[string]any {
|
||||
return result
|
||||
}
|
||||
|
||||
func (r *GeneralOpenAIRequest) GetSystemRoleName() string {
|
||||
if strings.HasPrefix(r.Model, "o") {
|
||||
if !strings.HasPrefix(r.Model, "o1-mini") && !strings.HasPrefix(r.Model, "o1-preview") {
|
||||
return "developer"
|
||||
}
|
||||
}
|
||||
return "system"
|
||||
}
|
||||
|
||||
type ToolCallRequest struct {
|
||||
ID string `json:"id,omitempty"`
|
||||
Type string `json:"type"`
|
||||
|
||||
@@ -2,12 +2,18 @@ package dto
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"one-api/types"
|
||||
)
|
||||
|
||||
type SimpleResponse struct {
|
||||
Usage `json:"usage"`
|
||||
Error *OpenAIError `json:"error"`
|
||||
Error any `json:"error"`
|
||||
}
|
||||
|
||||
// GetOpenAIError 从动态错误类型中提取OpenAIError结构
|
||||
func (s *SimpleResponse) GetOpenAIError() *types.OpenAIError {
|
||||
return GetOpenAIError(s.Error)
|
||||
}
|
||||
|
||||
type TextResponse struct {
|
||||
@@ -31,10 +37,15 @@ type OpenAITextResponse struct {
|
||||
Object string `json:"object"`
|
||||
Created any `json:"created"`
|
||||
Choices []OpenAITextResponseChoice `json:"choices"`
|
||||
Error *types.OpenAIError `json:"error,omitempty"`
|
||||
Error any `json:"error,omitempty"`
|
||||
Usage `json:"usage"`
|
||||
}
|
||||
|
||||
// GetOpenAIError 从动态错误类型中提取OpenAIError结构
|
||||
func (o *OpenAITextResponse) GetOpenAIError() *types.OpenAIError {
|
||||
return GetOpenAIError(o.Error)
|
||||
}
|
||||
|
||||
type OpenAIEmbeddingResponseItem struct {
|
||||
Object string `json:"object"`
|
||||
Index int `json:"index"`
|
||||
@@ -217,7 +228,7 @@ type OpenAIResponsesResponse struct {
|
||||
Object string `json:"object"`
|
||||
CreatedAt int `json:"created_at"`
|
||||
Status string `json:"status"`
|
||||
Error *types.OpenAIError `json:"error,omitempty"`
|
||||
Error any `json:"error,omitempty"`
|
||||
IncompleteDetails *IncompleteDetails `json:"incomplete_details,omitempty"`
|
||||
Instructions string `json:"instructions"`
|
||||
MaxOutputTokens int `json:"max_output_tokens"`
|
||||
@@ -237,6 +248,11 @@ type OpenAIResponsesResponse struct {
|
||||
Metadata json.RawMessage `json:"metadata"`
|
||||
}
|
||||
|
||||
// GetOpenAIError 从动态错误类型中提取OpenAIError结构
|
||||
func (o *OpenAIResponsesResponse) GetOpenAIError() *types.OpenAIError {
|
||||
return GetOpenAIError(o.Error)
|
||||
}
|
||||
|
||||
type IncompleteDetails struct {
|
||||
Reasoning string `json:"reasoning"`
|
||||
}
|
||||
@@ -276,3 +292,45 @@ type ResponsesStreamResponse struct {
|
||||
Delta string `json:"delta,omitempty"`
|
||||
Item *ResponsesOutput `json:"item,omitempty"`
|
||||
}
|
||||
|
||||
// GetOpenAIError 从动态错误类型中提取OpenAIError结构
|
||||
func GetOpenAIError(errorField any) *types.OpenAIError {
|
||||
if errorField == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
switch err := errorField.(type) {
|
||||
case types.OpenAIError:
|
||||
return &err
|
||||
case *types.OpenAIError:
|
||||
return err
|
||||
case map[string]interface{}:
|
||||
// 处理从JSON解析来的map结构
|
||||
openaiErr := &types.OpenAIError{}
|
||||
if errType, ok := err["type"].(string); ok {
|
||||
openaiErr.Type = errType
|
||||
}
|
||||
if errMsg, ok := err["message"].(string); ok {
|
||||
openaiErr.Message = errMsg
|
||||
}
|
||||
if errParam, ok := err["param"].(string); ok {
|
||||
openaiErr.Param = errParam
|
||||
}
|
||||
if errCode, ok := err["code"]; ok {
|
||||
openaiErr.Code = errCode
|
||||
}
|
||||
return openaiErr
|
||||
case string:
|
||||
// 处理简单字符串错误
|
||||
return &types.OpenAIError{
|
||||
Type: "error",
|
||||
Message: err,
|
||||
}
|
||||
default:
|
||||
// 未知类型,尝试转换为字符串
|
||||
return &types.OpenAIError{
|
||||
Type: "unknown_error",
|
||||
Message: fmt.Sprintf("%v", err),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
2
go.mod
2
go.mod
@@ -45,6 +45,7 @@ require (
|
||||
github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.5 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.5 // indirect
|
||||
github.com/aws/smithy-go v1.20.2 // indirect
|
||||
github.com/boombuler/barcode v1.1.0 // indirect
|
||||
github.com/bytedance/sonic v1.11.6 // indirect
|
||||
github.com/bytedance/sonic/loader v0.1.1 // indirect
|
||||
github.com/cespare/xxhash/v2 v2.3.0 // indirect
|
||||
@@ -79,6 +80,7 @@ require (
|
||||
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
|
||||
github.com/modern-go/reflect2 v1.0.2 // indirect
|
||||
github.com/pelletier/go-toml/v2 v2.2.1 // indirect
|
||||
github.com/pquerna/otp v1.5.0 // indirect
|
||||
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect
|
||||
github.com/tklauser/go-sysconf v0.3.12 // indirect
|
||||
github.com/tklauser/numcpus v0.6.1 // indirect
|
||||
|
||||
6
go.sum
6
go.sum
@@ -20,6 +20,10 @@ github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.7.4 h1:JgHnonzbnA3pbqj76w
|
||||
github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.7.4/go.mod h1:nZspkhg+9p8iApLFoyAqfyuMP0F38acy2Hm3r5r95Cg=
|
||||
github.com/aws/smithy-go v1.20.2 h1:tbp628ireGtzcHDDmLT/6ADHidqnwgF57XOXZe6tp4Q=
|
||||
github.com/aws/smithy-go v1.20.2/go.mod h1:krry+ya/rV9RDcV/Q16kpu6ypI4K2czasz0NC3qS14E=
|
||||
github.com/boombuler/barcode v1.0.1-0.20190219062509-6c824513bacc h1:biVzkmvwrH8WK8raXaxBx6fRVTlJILwEwQGL1I/ByEI=
|
||||
github.com/boombuler/barcode v1.0.1-0.20190219062509-6c824513bacc/go.mod h1:paBWMcWSl3LHKBqUq+rly7CNSldXjb2rDl3JlRe0mD8=
|
||||
github.com/boombuler/barcode v1.1.0 h1:ChaYjBR63fr4LFyGn8E8nt7dBSt3MiU3zMOZqFvVkHo=
|
||||
github.com/boombuler/barcode v1.1.0/go.mod h1:paBWMcWSl3LHKBqUq+rly7CNSldXjb2rDl3JlRe0mD8=
|
||||
github.com/bytedance/gopkg v0.0.0-20220118071334-3db87571198b h1:LTGVFpNmNHhj0vhOlfgWueFJ32eK9blaIlHR2ciXOT0=
|
||||
github.com/bytedance/gopkg v0.0.0-20220118071334-3db87571198b/go.mod h1:2ZlV9BaUH4+NXIBF0aMdKKAnHTzqH+iMU4KUjAbL23Q=
|
||||
github.com/bytedance/sonic v1.11.6 h1:oUp34TzMlL+OY1OUWxHqsdkgC/Zfc85zGqw9siXjrc0=
|
||||
@@ -169,6 +173,8 @@ github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
|
||||
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/pquerna/otp v1.5.0 h1:NMMR+WrmaqXU4EzdGJEE1aUUI0AMRzsp96fFFWNPwxs=
|
||||
github.com/pquerna/otp v1.5.0/go.mod h1:dkJfzwRKNiegxyNb54X/3fLwhCynbMspSyWKnvi1AEg=
|
||||
github.com/remyoudompheng/bigfft v0.0.0-20200410134404-eec4a21b6bb0/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
|
||||
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE=
|
||||
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
|
||||
|
||||
@@ -585,6 +585,19 @@
|
||||
"渠道权重": "渠道权重",
|
||||
"渠道额外设置": "渠道额外设置",
|
||||
"此项可选,用于配置渠道特定设置,为一个 JSON 字符串,例如:": "此项可选,用于配置渠道特定设置,为一个 JSON 字符串,例如:",
|
||||
"强制格式化": "强制格式化",
|
||||
"强制将响应格式化为 OpenAI 标准格式(只适用于OpenAI渠道类型)": "强制将响应格式化为 OpenAI 标准格式(只适用于OpenAI渠道类型)",
|
||||
"思考内容转换": "思考内容转换",
|
||||
"将 reasoning_content 转换为 <think> 标签拼接到内容中": "将 reasoning_content 转换为 <think> 标签拼接到内容中",
|
||||
"透传请求体": "透传请求体",
|
||||
"启用请求体透传功能": "启用请求体透传功能",
|
||||
"代理地址": "代理地址",
|
||||
"例如: socks5://user:pass@host:port": "例如: socks5://user:pass@host:port",
|
||||
"用于配置网络代理": "用于配置网络代理",
|
||||
"用于配置网络代理,支持 socks5 协议": "用于配置网络代理,支持 socks5 协议",
|
||||
"系统提示词": "系统提示词",
|
||||
"输入系统提示词,用户的系统提示词将优先于此设置": "输入系统提示词,用户的系统提示词将优先于此设置",
|
||||
"用户优先:如果用户在请求中指定了系统提示词,将优先使用用户的设置": "用户优先:如果用户在请求中指定了系统提示词,将优先使用用户的设置",
|
||||
"参数覆盖": "参数覆盖",
|
||||
"此项可选,用于覆盖请求参数。不支持覆盖 stream 参数。为一个 JSON 字符串,例如:": "此项可选,用于覆盖请求参数。不支持覆盖 stream 参数。为一个 JSON 字符串,例如:",
|
||||
"请输入组织org-xxx": "请输入组织org-xxx",
|
||||
|
||||
@@ -122,6 +122,7 @@ func authHelper(c *gin.Context, minRole int) {
|
||||
c.Set("role", role)
|
||||
c.Set("id", id)
|
||||
c.Set("group", session.Get("group"))
|
||||
c.Set("user_group", session.Get("group"))
|
||||
c.Set("use_access_token", useAccessToken)
|
||||
|
||||
//userCache, err := model.GetUserCache(id.(int))
|
||||
|
||||
@@ -100,6 +100,10 @@ func Distribute() func(c *gin.Context) {
|
||||
}
|
||||
|
||||
if shouldSelectChannel {
|
||||
if modelRequest.Model == "" {
|
||||
abortWithOpenAiMessage(c, http.StatusBadRequest, "未指定模型名称,模型名称不能为空")
|
||||
return
|
||||
}
|
||||
var selectGroup string
|
||||
channel, selectGroup, err = model.CacheGetRandomSatisfiedChannel(c, userGroup, modelRequest.Model, 0)
|
||||
if err != nil {
|
||||
@@ -107,18 +111,17 @@ func Distribute() func(c *gin.Context) {
|
||||
if userGroup == "auto" {
|
||||
showGroup = fmt.Sprintf("auto(%s)", selectGroup)
|
||||
}
|
||||
message := fmt.Sprintf("获取分组 %s 下模型 %s 的可用渠道失败(distributor): %s", showGroup, modelRequest.Model, err.Error())
|
||||
message := fmt.Sprintf("获取分组 %s 下模型 %s 的可用渠道失败(数据库一致性已被破坏,distributor): %s", showGroup, modelRequest.Model, err.Error())
|
||||
// 如果错误,但是渠道不为空,说明是数据库一致性问题
|
||||
if channel != nil {
|
||||
common.SysError(fmt.Sprintf("渠道不存在:%d", channel.Id))
|
||||
message = "数据库一致性已被破坏,请联系管理员"
|
||||
}
|
||||
// 如果错误,而且渠道为空,说明是没有可用渠道
|
||||
//if channel != nil {
|
||||
// common.SysError(fmt.Sprintf("渠道不存在:%d", channel.Id))
|
||||
// message = "数据库一致性已被破坏,请联系管理员"
|
||||
//}
|
||||
abortWithOpenAiMessage(c, http.StatusServiceUnavailable, message)
|
||||
return
|
||||
}
|
||||
if channel == nil {
|
||||
abortWithOpenAiMessage(c, http.StatusServiceUnavailable, fmt.Sprintf("分组 %s 下模型 %s 的可用渠道不存在(数据库一致性已被破坏,distributor)", userGroup, modelRequest.Model))
|
||||
abortWithOpenAiMessage(c, http.StatusServiceUnavailable, fmt.Sprintf("分组 %s 下模型 %s 无可用渠道(distributor)", userGroup, modelRequest.Model))
|
||||
return
|
||||
}
|
||||
}
|
||||
@@ -244,7 +247,7 @@ func getModelRequest(c *gin.Context) (*ModelRequest, bool, error) {
|
||||
func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, modelName string) *types.NewAPIError {
|
||||
c.Set("original_model", modelName) // for retry
|
||||
if channel == nil {
|
||||
return types.NewError(errors.New("channel is nil"), types.ErrorCodeGetChannelFailed)
|
||||
return types.NewError(errors.New("channel is nil"), types.ErrorCodeGetChannelFailed, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
common.SetContextKey(c, constant.ContextKeyChannelId, channel.Id)
|
||||
common.SetContextKey(c, constant.ContextKeyChannelName, channel.Name)
|
||||
@@ -266,6 +269,9 @@ func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, mode
|
||||
if channel.ChannelInfo.IsMultiKey {
|
||||
common.SetContextKey(c, constant.ContextKeyChannelIsMultiKey, true)
|
||||
common.SetContextKey(c, constant.ContextKeyChannelMultiKeyIndex, index)
|
||||
} else {
|
||||
// 必须设置为 false,否则在重试到单个 key 的时候会导致日志显示错误
|
||||
common.SetContextKey(c, constant.ContextKeyChannelIsMultiKey, false)
|
||||
}
|
||||
// c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", key))
|
||||
common.SetContextKey(c, constant.ContextKeyChannelKey, key)
|
||||
|
||||
@@ -136,7 +136,7 @@ func GetRandomSatisfiedChannel(group string, model string, retry int) (*Channel,
|
||||
}
|
||||
}
|
||||
} else {
|
||||
return nil, errors.New("channel not found")
|
||||
return nil, nil
|
||||
}
|
||||
err = DB.First(&channel, "id = ?", channel.Id).Error
|
||||
return &channel, err
|
||||
@@ -284,6 +284,21 @@ func FixAbility() (int, int, error) {
|
||||
return 0, 0, errors.New("已经有一个修复任务在运行中,请稍后再试")
|
||||
}
|
||||
defer fixLock.Unlock()
|
||||
|
||||
// truncate abilities table
|
||||
if common.UsingSQLite {
|
||||
err := DB.Exec("DELETE FROM abilities").Error
|
||||
if err != nil {
|
||||
common.SysError(fmt.Sprintf("Delete abilities failed: %s", err.Error()))
|
||||
return 0, 0, err
|
||||
}
|
||||
} else {
|
||||
err := DB.Exec("TRUNCATE TABLE abilities").Error
|
||||
if err != nil {
|
||||
common.SysError(fmt.Sprintf("Truncate abilities failed: %s", err.Error()))
|
||||
return 0, 0, err
|
||||
}
|
||||
}
|
||||
var channels []*Channel
|
||||
// Find all channels
|
||||
err := DB.Model(&Channel{}).Find(&channels).Error
|
||||
|
||||
@@ -41,19 +41,25 @@ type Channel struct {
|
||||
Priority *int64 `json:"priority" gorm:"bigint;default:0"`
|
||||
AutoBan *int `json:"auto_ban" gorm:"default:1"`
|
||||
OtherInfo string `json:"other_info"`
|
||||
Settings string `json:"settings"`
|
||||
Tag *string `json:"tag" gorm:"index"`
|
||||
Setting *string `json:"setting" gorm:"type:text"` // 渠道额外设置
|
||||
ParamOverride *string `json:"param_override" gorm:"type:text"`
|
||||
// add after v0.8.5
|
||||
ChannelInfo ChannelInfo `json:"channel_info" gorm:"type:json"`
|
||||
|
||||
// cache info
|
||||
Keys []string `json:"-" gorm:"-"`
|
||||
}
|
||||
|
||||
type ChannelInfo struct {
|
||||
IsMultiKey bool `json:"is_multi_key"` // 是否多Key模式
|
||||
MultiKeySize int `json:"multi_key_size"` // 多Key模式下的Key数量
|
||||
MultiKeyStatusList map[int]int `json:"multi_key_status_list"` // key状态列表,key index -> status
|
||||
MultiKeyPollingIndex int `json:"multi_key_polling_index"` // 多Key模式下轮询的key索引
|
||||
MultiKeyMode constant.MultiKeyMode `json:"multi_key_mode"`
|
||||
IsMultiKey bool `json:"is_multi_key"` // 是否多Key模式
|
||||
MultiKeySize int `json:"multi_key_size"` // 多Key模式下的Key数量
|
||||
MultiKeyStatusList map[int]int `json:"multi_key_status_list"` // key状态列表,key index -> status
|
||||
MultiKeyDisabledReason map[int]string `json:"multi_key_disabled_reason,omitempty"` // key禁用原因列表,key index -> reason
|
||||
MultiKeyDisabledTime map[int]int64 `json:"multi_key_disabled_time,omitempty"` // key禁用时间列表,key index -> time
|
||||
MultiKeyPollingIndex int `json:"multi_key_polling_index"` // 多Key模式下轮询的key索引
|
||||
MultiKeyMode constant.MultiKeyMode `json:"multi_key_mode"`
|
||||
}
|
||||
|
||||
// Value implements driver.Valuer interface
|
||||
@@ -67,15 +73,18 @@ func (c *ChannelInfo) Scan(value interface{}) error {
|
||||
return common.Unmarshal(bytesValue, c)
|
||||
}
|
||||
|
||||
func (channel *Channel) getKeys() []string {
|
||||
func (channel *Channel) GetKeys() []string {
|
||||
if channel.Key == "" {
|
||||
return []string{}
|
||||
}
|
||||
if len(channel.Keys) > 0 {
|
||||
return channel.Keys
|
||||
}
|
||||
trimmed := strings.TrimSpace(channel.Key)
|
||||
// If the key starts with '[', try to parse it as a JSON array (e.g., for Vertex AI scenarios)
|
||||
if strings.HasPrefix(trimmed, "[") {
|
||||
var arr []json.RawMessage
|
||||
if err := json.Unmarshal([]byte(trimmed), &arr); err == nil {
|
||||
if err := common.Unmarshal([]byte(trimmed), &arr); err == nil {
|
||||
res := make([]string, len(arr))
|
||||
for i, v := range arr {
|
||||
res[i] = string(v)
|
||||
@@ -95,7 +104,7 @@ func (channel *Channel) GetNextEnabledKey() (string, int, *types.NewAPIError) {
|
||||
}
|
||||
|
||||
// Obtain all keys (split by \n)
|
||||
keys := channel.getKeys()
|
||||
keys := channel.GetKeys()
|
||||
if len(keys) == 0 {
|
||||
// No keys available, return error, should disable the channel
|
||||
return "", 0, types.NewError(errors.New("no keys available"), types.ErrorCodeChannelNoAvailableKey)
|
||||
@@ -138,7 +147,7 @@ func (channel *Channel) GetNextEnabledKey() (string, int, *types.NewAPIError) {
|
||||
|
||||
channelInfo, err := CacheGetChannelInfo(channel.Id)
|
||||
if err != nil {
|
||||
return "", 0, types.NewError(err, types.ErrorCodeGetChannelFailed)
|
||||
return "", 0, types.NewError(err, types.ErrorCodeGetChannelFailed, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
//println("before polling index:", channel.ChannelInfo.MultiKeyPollingIndex)
|
||||
defer func() {
|
||||
@@ -197,7 +206,7 @@ func (channel *Channel) GetGroups() []string {
|
||||
func (channel *Channel) GetOtherInfo() map[string]interface{} {
|
||||
otherInfo := make(map[string]interface{})
|
||||
if channel.OtherInfo != "" {
|
||||
err := json.Unmarshal([]byte(channel.OtherInfo), &otherInfo)
|
||||
err := common.Unmarshal([]byte(channel.OtherInfo), &otherInfo)
|
||||
if err != nil {
|
||||
common.SysError("failed to unmarshal other info: " + err.Error())
|
||||
}
|
||||
@@ -425,7 +434,7 @@ func (channel *Channel) Update() error {
|
||||
trimmed := strings.TrimSpace(keyStr)
|
||||
if strings.HasPrefix(trimmed, "[") {
|
||||
var arr []json.RawMessage
|
||||
if err := json.Unmarshal([]byte(trimmed), &arr); err == nil {
|
||||
if err := common.Unmarshal([]byte(trimmed), &arr); err == nil {
|
||||
keys = make([]string, len(arr))
|
||||
for i, v := range arr {
|
||||
keys[i] = string(v)
|
||||
@@ -522,8 +531,8 @@ func CleanupChannelPollingLocks() {
|
||||
})
|
||||
}
|
||||
|
||||
func handlerMultiKeyUpdate(channel *Channel, usingKey string, status int) {
|
||||
keys := channel.getKeys()
|
||||
func handlerMultiKeyUpdate(channel *Channel, usingKey string, status int, reason string) {
|
||||
keys := channel.GetKeys()
|
||||
if len(keys) == 0 {
|
||||
channel.Status = status
|
||||
} else {
|
||||
@@ -541,6 +550,14 @@ func handlerMultiKeyUpdate(channel *Channel, usingKey string, status int) {
|
||||
delete(channel.ChannelInfo.MultiKeyStatusList, keyIndex)
|
||||
} else {
|
||||
channel.ChannelInfo.MultiKeyStatusList[keyIndex] = status
|
||||
if channel.ChannelInfo.MultiKeyDisabledReason == nil {
|
||||
channel.ChannelInfo.MultiKeyDisabledReason = make(map[int]string)
|
||||
}
|
||||
if channel.ChannelInfo.MultiKeyDisabledTime == nil {
|
||||
channel.ChannelInfo.MultiKeyDisabledTime = make(map[int]int64)
|
||||
}
|
||||
channel.ChannelInfo.MultiKeyDisabledReason[keyIndex] = reason
|
||||
channel.ChannelInfo.MultiKeyDisabledTime[keyIndex] = common.GetTimestamp()
|
||||
}
|
||||
if len(channel.ChannelInfo.MultiKeyStatusList) >= channel.ChannelInfo.MultiKeySize {
|
||||
channel.Status = common.ChannelStatusAutoDisabled
|
||||
@@ -563,7 +580,7 @@ func UpdateChannelStatus(channelId int, usingKey string, status int, reason stri
|
||||
}
|
||||
if channelCache.ChannelInfo.IsMultiKey {
|
||||
// 如果是多Key模式,更新缓存中的状态
|
||||
handlerMultiKeyUpdate(channelCache, usingKey, status)
|
||||
handlerMultiKeyUpdate(channelCache, usingKey, status, reason)
|
||||
//CacheUpdateChannel(channelCache)
|
||||
//return true
|
||||
} else {
|
||||
@@ -571,10 +588,6 @@ func UpdateChannelStatus(channelId int, usingKey string, status int, reason stri
|
||||
if channelCache.Status == status {
|
||||
return false
|
||||
}
|
||||
// 如果缓存渠道不存在(说明已经被禁用),且要设置的状态不为启用,直接返回
|
||||
if status != common.ChannelStatusEnabled {
|
||||
return false
|
||||
}
|
||||
CacheUpdateChannelStatus(channelId, status)
|
||||
}
|
||||
}
|
||||
@@ -598,7 +611,7 @@ func UpdateChannelStatus(channelId int, usingKey string, status int, reason stri
|
||||
|
||||
if channel.ChannelInfo.IsMultiKey {
|
||||
beforeStatus := channel.Status
|
||||
handlerMultiKeyUpdate(channel, usingKey, status)
|
||||
handlerMultiKeyUpdate(channel, usingKey, status, reason)
|
||||
if beforeStatus != channel.Status {
|
||||
shouldUpdateAbilities = true
|
||||
}
|
||||
@@ -778,7 +791,7 @@ func SearchTags(keyword string, group string, model string, idSort bool) ([]*str
|
||||
func (channel *Channel) ValidateSettings() error {
|
||||
channelParams := &dto.ChannelSettings{}
|
||||
if channel.Setting != nil && *channel.Setting != "" {
|
||||
err := json.Unmarshal([]byte(*channel.Setting), channelParams)
|
||||
err := common.Unmarshal([]byte(*channel.Setting), channelParams)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -789,7 +802,7 @@ func (channel *Channel) ValidateSettings() error {
|
||||
func (channel *Channel) GetSetting() dto.ChannelSettings {
|
||||
setting := dto.ChannelSettings{}
|
||||
if channel.Setting != nil && *channel.Setting != "" {
|
||||
err := json.Unmarshal([]byte(*channel.Setting), &setting)
|
||||
err := common.Unmarshal([]byte(*channel.Setting), &setting)
|
||||
if err != nil {
|
||||
common.SysError("failed to unmarshal setting: " + err.Error())
|
||||
channel.Setting = nil // 清空设置以避免后续错误
|
||||
@@ -800,7 +813,7 @@ func (channel *Channel) GetSetting() dto.ChannelSettings {
|
||||
}
|
||||
|
||||
func (channel *Channel) SetSetting(setting dto.ChannelSettings) {
|
||||
settingBytes, err := json.Marshal(setting)
|
||||
settingBytes, err := common.Marshal(setting)
|
||||
if err != nil {
|
||||
common.SysError("failed to marshal setting: " + err.Error())
|
||||
return
|
||||
@@ -811,7 +824,7 @@ func (channel *Channel) SetSetting(setting dto.ChannelSettings) {
|
||||
func (channel *Channel) GetParamOverride() map[string]interface{} {
|
||||
paramOverride := make(map[string]interface{})
|
||||
if channel.ParamOverride != nil && *channel.ParamOverride != "" {
|
||||
err := json.Unmarshal([]byte(*channel.ParamOverride), ¶mOverride)
|
||||
err := common.Unmarshal([]byte(*channel.ParamOverride), ¶mOverride)
|
||||
if err != nil {
|
||||
common.SysError("failed to unmarshal param override: " + err.Error())
|
||||
}
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"one-api/common"
|
||||
"one-api/constant"
|
||||
"one-api/setting"
|
||||
"sort"
|
||||
"strings"
|
||||
@@ -66,6 +67,20 @@ func InitChannelCache() {
|
||||
|
||||
channelSyncLock.Lock()
|
||||
group2model2channels = newGroup2model2channels
|
||||
//channelsIDM = newChannelId2channel
|
||||
for i, channel := range newChannelId2channel {
|
||||
if channel.ChannelInfo.IsMultiKey {
|
||||
channel.Keys = channel.GetKeys()
|
||||
if channel.ChannelInfo.MultiKeyMode == constant.MultiKeyModePolling {
|
||||
if oldChannel, ok := channelsIDM[i]; ok {
|
||||
// 存在旧的渠道,如果是多key且轮询,保留轮询索引信息
|
||||
if oldChannel.ChannelInfo.IsMultiKey && oldChannel.ChannelInfo.MultiKeyMode == constant.MultiKeyModePolling {
|
||||
channel.ChannelInfo.MultiKeyPollingIndex = oldChannel.ChannelInfo.MultiKeyPollingIndex
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
channelsIDM = newChannelId2channel
|
||||
channelSyncLock.Unlock()
|
||||
common.SysLog("channels synced from database")
|
||||
@@ -130,7 +145,7 @@ func getRandomSatisfiedChannel(group string, model string, retry int) (*Channel,
|
||||
channels := group2model2channels[group][model]
|
||||
|
||||
if len(channels) == 0 {
|
||||
return nil, errors.New("channel not found")
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
if len(channels) == 1 {
|
||||
@@ -203,9 +218,6 @@ func CacheGetChannel(id int) (*Channel, error) {
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("渠道# %d,已不存在", id)
|
||||
}
|
||||
if c.Status != common.ChannelStatusEnabled {
|
||||
return nil, fmt.Errorf("渠道# %d,已被禁用", id)
|
||||
}
|
||||
return c, nil
|
||||
}
|
||||
|
||||
@@ -224,9 +236,6 @@ func CacheGetChannelInfo(id int) (*ChannelInfo, error) {
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("渠道# %d,已不存在", id)
|
||||
}
|
||||
if c.Status != common.ChannelStatusEnabled {
|
||||
return nil, fmt.Errorf("渠道# %d,已被禁用", id)
|
||||
}
|
||||
return &c.ChannelInfo, nil
|
||||
}
|
||||
|
||||
@@ -239,6 +248,20 @@ func CacheUpdateChannelStatus(id int, status int) {
|
||||
if channel, ok := channelsIDM[id]; ok {
|
||||
channel.Status = status
|
||||
}
|
||||
if status != common.ChannelStatusEnabled {
|
||||
// delete the channel from group2model2channels
|
||||
for group, model2channels := range group2model2channels {
|
||||
for model, channels := range model2channels {
|
||||
for i, channelId := range channels {
|
||||
if channelId == id {
|
||||
// remove the channel from the slice
|
||||
group2model2channels[group][model] = append(channels[:i], channels[i+1:]...)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func CacheUpdateChannel(channel *Channel) {
|
||||
|
||||
@@ -251,6 +251,8 @@ func migrateDB() error {
|
||||
&QuotaData{},
|
||||
&Task{},
|
||||
&Setup{},
|
||||
&TwoFA{},
|
||||
&TwoFABackupCode{},
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -277,6 +279,8 @@ func migrateDBFast() error {
|
||||
{&QuotaData{}, "QuotaData"},
|
||||
{&Task{}, "Task"},
|
||||
{&Setup{}, "Setup"},
|
||||
{&TwoFA{}, "TwoFA"},
|
||||
{&TwoFABackupCode{}, "TwoFABackupCode"},
|
||||
}
|
||||
// 动态计算migration数量,确保errChan缓冲区足够大
|
||||
errChan := make(chan error, len(migrations))
|
||||
|
||||
322
model/twofa.go
Normal file
322
model/twofa.go
Normal file
@@ -0,0 +1,322 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"one-api/common"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
var ErrTwoFANotEnabled = errors.New("用户未启用2FA")
|
||||
|
||||
// TwoFA 用户2FA设置表
|
||||
type TwoFA struct {
|
||||
Id int `json:"id" gorm:"primaryKey"`
|
||||
UserId int `json:"user_id" gorm:"unique;not null;index"`
|
||||
Secret string `json:"-" gorm:"type:varchar(255);not null"` // TOTP密钥,不返回给前端
|
||||
IsEnabled bool `json:"is_enabled" gorm:"default:false"`
|
||||
FailedAttempts int `json:"failed_attempts" gorm:"default:0"`
|
||||
LockedUntil *time.Time `json:"locked_until,omitempty"`
|
||||
LastUsedAt *time.Time `json:"last_used_at,omitempty"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
DeletedAt gorm.DeletedAt `json:"-" gorm:"index"`
|
||||
}
|
||||
|
||||
// TwoFABackupCode 备用码使用记录表
|
||||
type TwoFABackupCode struct {
|
||||
Id int `json:"id" gorm:"primaryKey"`
|
||||
UserId int `json:"user_id" gorm:"not null;index"`
|
||||
CodeHash string `json:"-" gorm:"type:varchar(255);not null"` // 备用码哈希
|
||||
IsUsed bool `json:"is_used" gorm:"default:false"`
|
||||
UsedAt *time.Time `json:"used_at,omitempty"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
DeletedAt gorm.DeletedAt `json:"-" gorm:"index"`
|
||||
}
|
||||
|
||||
// GetTwoFAByUserId 根据用户ID获取2FA设置
|
||||
func GetTwoFAByUserId(userId int) (*TwoFA, error) {
|
||||
if userId == 0 {
|
||||
return nil, errors.New("用户ID不能为空")
|
||||
}
|
||||
|
||||
var twoFA TwoFA
|
||||
err := DB.Where("user_id = ?", userId).First(&twoFA).Error
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, nil // 返回nil表示未设置2FA
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &twoFA, nil
|
||||
}
|
||||
|
||||
// IsTwoFAEnabled 检查用户是否启用了2FA
|
||||
func IsTwoFAEnabled(userId int) bool {
|
||||
twoFA, err := GetTwoFAByUserId(userId)
|
||||
if err != nil || twoFA == nil {
|
||||
return false
|
||||
}
|
||||
return twoFA.IsEnabled
|
||||
}
|
||||
|
||||
// CreateTwoFA 创建2FA设置
|
||||
func (t *TwoFA) Create() error {
|
||||
// 检查用户是否已存在2FA设置
|
||||
existing, err := GetTwoFAByUserId(t.UserId)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if existing != nil {
|
||||
return errors.New("用户已存在2FA设置")
|
||||
}
|
||||
|
||||
// 验证用户存在
|
||||
var user User
|
||||
if err := DB.First(&user, t.UserId).Error; err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return errors.New("用户不存在")
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
return DB.Create(t).Error
|
||||
}
|
||||
|
||||
// Update 更新2FA设置
|
||||
func (t *TwoFA) Update() error {
|
||||
if t.Id == 0 {
|
||||
return errors.New("2FA记录ID不能为空")
|
||||
}
|
||||
return DB.Save(t).Error
|
||||
}
|
||||
|
||||
// Delete 删除2FA设置
|
||||
func (t *TwoFA) Delete() error {
|
||||
if t.Id == 0 {
|
||||
return errors.New("2FA记录ID不能为空")
|
||||
}
|
||||
|
||||
// 使用事务确保原子性
|
||||
return DB.Transaction(func(tx *gorm.DB) error {
|
||||
// 同时删除相关的备用码记录(硬删除)
|
||||
if err := tx.Unscoped().Where("user_id = ?", t.UserId).Delete(&TwoFABackupCode{}).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 硬删除2FA记录
|
||||
return tx.Unscoped().Delete(t).Error
|
||||
})
|
||||
}
|
||||
|
||||
// ResetFailedAttempts 重置失败尝试次数
|
||||
func (t *TwoFA) ResetFailedAttempts() error {
|
||||
t.FailedAttempts = 0
|
||||
t.LockedUntil = nil
|
||||
return t.Update()
|
||||
}
|
||||
|
||||
// IncrementFailedAttempts 增加失败尝试次数
|
||||
func (t *TwoFA) IncrementFailedAttempts() error {
|
||||
t.FailedAttempts++
|
||||
|
||||
// 检查是否需要锁定
|
||||
if t.FailedAttempts >= common.MaxFailAttempts {
|
||||
lockUntil := time.Now().Add(time.Duration(common.LockoutDuration) * time.Second)
|
||||
t.LockedUntil = &lockUntil
|
||||
}
|
||||
|
||||
return t.Update()
|
||||
}
|
||||
|
||||
// IsLocked 检查账户是否被锁定
|
||||
func (t *TwoFA) IsLocked() bool {
|
||||
if t.LockedUntil == nil {
|
||||
return false
|
||||
}
|
||||
return time.Now().Before(*t.LockedUntil)
|
||||
}
|
||||
|
||||
// CreateBackupCodes 创建备用码
|
||||
func CreateBackupCodes(userId int, codes []string) error {
|
||||
return DB.Transaction(func(tx *gorm.DB) error {
|
||||
// 先删除现有的备用码
|
||||
if err := tx.Where("user_id = ?", userId).Delete(&TwoFABackupCode{}).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 创建新的备用码记录
|
||||
for _, code := range codes {
|
||||
hashedCode, err := common.HashBackupCode(code)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
backupCode := TwoFABackupCode{
|
||||
UserId: userId,
|
||||
CodeHash: hashedCode,
|
||||
IsUsed: false,
|
||||
}
|
||||
|
||||
if err := tx.Create(&backupCode).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
// ValidateBackupCode 验证并使用备用码
|
||||
func ValidateBackupCode(userId int, code string) (bool, error) {
|
||||
if !common.ValidateBackupCode(code) {
|
||||
return false, errors.New("验证码或备用码不正确")
|
||||
}
|
||||
|
||||
normalizedCode := common.NormalizeBackupCode(code)
|
||||
|
||||
// 查找未使用的备用码
|
||||
var backupCodes []TwoFABackupCode
|
||||
if err := DB.Where("user_id = ? AND is_used = false", userId).Find(&backupCodes).Error; err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
// 验证备用码
|
||||
for _, bc := range backupCodes {
|
||||
if common.ValidatePasswordAndHash(normalizedCode, bc.CodeHash) {
|
||||
// 标记为已使用
|
||||
now := time.Now()
|
||||
bc.IsUsed = true
|
||||
bc.UsedAt = &now
|
||||
|
||||
if err := DB.Save(&bc).Error; err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// GetUnusedBackupCodeCount 获取未使用的备用码数量
|
||||
func GetUnusedBackupCodeCount(userId int) (int, error) {
|
||||
var count int64
|
||||
err := DB.Model(&TwoFABackupCode{}).Where("user_id = ? AND is_used = false", userId).Count(&count).Error
|
||||
return int(count), err
|
||||
}
|
||||
|
||||
// DisableTwoFA 禁用用户的2FA
|
||||
func DisableTwoFA(userId int) error {
|
||||
twoFA, err := GetTwoFAByUserId(userId)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if twoFA == nil {
|
||||
return ErrTwoFANotEnabled
|
||||
}
|
||||
|
||||
// 删除2FA设置和备用码
|
||||
return twoFA.Delete()
|
||||
}
|
||||
|
||||
// EnableTwoFA 启用2FA
|
||||
func (t *TwoFA) Enable() error {
|
||||
t.IsEnabled = true
|
||||
t.FailedAttempts = 0
|
||||
t.LockedUntil = nil
|
||||
return t.Update()
|
||||
}
|
||||
|
||||
// ValidateTOTPAndUpdateUsage 验证TOTP并更新使用记录
|
||||
func (t *TwoFA) ValidateTOTPAndUpdateUsage(code string) (bool, error) {
|
||||
// 检查是否被锁定
|
||||
if t.IsLocked() {
|
||||
return false, fmt.Errorf("账户已被锁定,请在%v后重试", t.LockedUntil.Format("2006-01-02 15:04:05"))
|
||||
}
|
||||
|
||||
// 验证TOTP码
|
||||
if !common.ValidateTOTPCode(t.Secret, code) {
|
||||
// 增加失败次数
|
||||
if err := t.IncrementFailedAttempts(); err != nil {
|
||||
common.SysError("更新2FA失败次数失败: " + err.Error())
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// 验证成功,重置失败次数并更新最后使用时间
|
||||
now := time.Now()
|
||||
t.FailedAttempts = 0
|
||||
t.LockedUntil = nil
|
||||
t.LastUsedAt = &now
|
||||
|
||||
if err := t.Update(); err != nil {
|
||||
common.SysError("更新2FA使用记录失败: " + err.Error())
|
||||
}
|
||||
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// ValidateBackupCodeAndUpdateUsage 验证备用码并更新使用记录
|
||||
func (t *TwoFA) ValidateBackupCodeAndUpdateUsage(code string) (bool, error) {
|
||||
// 检查是否被锁定
|
||||
if t.IsLocked() {
|
||||
return false, fmt.Errorf("账户已被锁定,请在%v后重试", t.LockedUntil.Format("2006-01-02 15:04:05"))
|
||||
}
|
||||
|
||||
// 验证备用码
|
||||
valid, err := ValidateBackupCode(t.UserId, code)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
if !valid {
|
||||
// 增加失败次数
|
||||
if err := t.IncrementFailedAttempts(); err != nil {
|
||||
common.SysError("更新2FA失败次数失败: " + err.Error())
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// 验证成功,重置失败次数并更新最后使用时间
|
||||
now := time.Now()
|
||||
t.FailedAttempts = 0
|
||||
t.LockedUntil = nil
|
||||
t.LastUsedAt = &now
|
||||
|
||||
if err := t.Update(); err != nil {
|
||||
common.SysError("更新2FA使用记录失败: " + err.Error())
|
||||
}
|
||||
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// GetTwoFAStats 获取2FA统计信息(管理员使用)
|
||||
func GetTwoFAStats() (map[string]interface{}, error) {
|
||||
var totalUsers, enabledUsers int64
|
||||
|
||||
// 总用户数
|
||||
if err := DB.Model(&User{}).Count(&totalUsers).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 启用2FA的用户数
|
||||
if err := DB.Model(&TwoFA{}).Where("is_enabled = true").Count(&enabledUsers).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
enabledRate := float64(0)
|
||||
if totalUsers > 0 {
|
||||
enabledRate = float64(enabledUsers) / float64(totalUsers) * 100
|
||||
}
|
||||
|
||||
return map[string]interface{}{
|
||||
"total_users": totalUsers,
|
||||
"enabled_users": enabledUsers,
|
||||
"enabled_rate": fmt.Sprintf("%.1f%%", enabledRate),
|
||||
}, nil
|
||||
}
|
||||
@@ -62,7 +62,7 @@ func AudioHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
|
||||
|
||||
if err != nil {
|
||||
common.LogError(c, fmt.Sprintf("getAndValidAudioRequest failed: %s", err.Error()))
|
||||
return types.NewError(err, types.ErrorCodeInvalidRequest)
|
||||
return types.NewError(err, types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
|
||||
promptTokens := 0
|
||||
@@ -75,7 +75,7 @@ func AudioHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
|
||||
|
||||
priceData, err := helper.ModelPriceHelper(c, relayInfo, preConsumedTokens, 0)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeModelPriceError)
|
||||
return types.NewError(err, types.ErrorCodeModelPriceError, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
|
||||
preConsumedQuota, userQuota, openaiErr := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo)
|
||||
@@ -90,18 +90,18 @@ func AudioHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
|
||||
|
||||
err = helper.ModelMappedHelper(c, relayInfo, audioRequest)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeChannelModelMappedError)
|
||||
return types.NewError(err, types.ErrorCodeChannelModelMappedError, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
|
||||
adaptor := GetAdaptor(relayInfo.ApiType)
|
||||
if adaptor == nil {
|
||||
return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType)
|
||||
return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
adaptor.Init(relayInfo)
|
||||
|
||||
ioReader, err := adaptor.ConvertAudioRequest(c, relayInfo, *audioRequest)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeConvertRequestFailed)
|
||||
return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
|
||||
resp, err := adaptor.DoRequest(c, relayInfo, ioReader)
|
||||
|
||||
@@ -26,6 +26,7 @@ type Adaptor interface {
|
||||
GetModelList() []string
|
||||
GetChannelName() string
|
||||
ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.ClaudeRequest) (any, error)
|
||||
ConvertGeminiRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeminiChatRequest) (any, error)
|
||||
}
|
||||
|
||||
type TaskAdaptor interface {
|
||||
|
||||
@@ -18,6 +18,11 @@ import (
|
||||
type Adaptor struct {
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
|
||||
//TODO implement me
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
|
||||
//TODO implement me
|
||||
panic("implement me")
|
||||
|
||||
@@ -132,12 +132,12 @@ func aliImageHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rela
|
||||
var aliTaskResponse AliResponse
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeReadResponseBodyFailed), nil
|
||||
return types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError), nil
|
||||
}
|
||||
common.CloseResponseBodyGracefully(resp)
|
||||
err = json.Unmarshal(responseBody, &aliTaskResponse)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeBadResponseBody), nil
|
||||
return types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError), nil
|
||||
}
|
||||
|
||||
if aliTaskResponse.Message != "" {
|
||||
|
||||
@@ -34,14 +34,14 @@ func ConvertRerankRequest(request dto.RerankRequest) *AliRerankRequest {
|
||||
func RerankHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*types.NewAPIError, *dto.Usage) {
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeReadResponseBodyFailed), nil
|
||||
return types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError), nil
|
||||
}
|
||||
common.CloseResponseBodyGracefully(resp)
|
||||
|
||||
var aliResponse AliRerankResponse
|
||||
err = json.Unmarshal(responseBody, &aliResponse)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeBadResponseBody), nil
|
||||
return types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError), nil
|
||||
}
|
||||
|
||||
if aliResponse.Code != "" {
|
||||
|
||||
@@ -43,7 +43,7 @@ func aliEmbeddingHandler(c *gin.Context, resp *http.Response) (*types.NewAPIErro
|
||||
var fullTextResponse dto.FlexibleEmbeddingResponse
|
||||
err := json.NewDecoder(resp.Body).Decode(&fullTextResponse)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeBadResponseBody), nil
|
||||
return types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError), nil
|
||||
}
|
||||
|
||||
common.CloseResponseBodyGracefully(resp)
|
||||
@@ -179,12 +179,12 @@ func aliHandler(c *gin.Context, resp *http.Response) (*types.NewAPIError, *dto.U
|
||||
var aliResponse AliResponse
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeReadResponseBodyFailed), nil
|
||||
return types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError), nil
|
||||
}
|
||||
common.CloseResponseBodyGracefully(resp)
|
||||
err = json.Unmarshal(responseBody, &aliResponse)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeBadResponseBody), nil
|
||||
return types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError), nil
|
||||
}
|
||||
if aliResponse.Code != "" {
|
||||
return types.WithOpenAIError(types.OpenAIError{
|
||||
|
||||
@@ -223,7 +223,7 @@ func doRequest(c *gin.Context, req *http.Request, info *common.RelayInfo) (*http
|
||||
helper.SetEventStreamHeaders(c)
|
||||
// 处理流式请求的 ping 保活
|
||||
generalSettings := operation_setting.GetGeneralSetting()
|
||||
if generalSettings.PingIntervalEnabled {
|
||||
if generalSettings.PingIntervalEnabled && !info.DisablePing {
|
||||
pingInterval := time.Duration(generalSettings.PingIntervalSeconds) * time.Second
|
||||
stopPinger = startPingKeepAlive(c, pingInterval)
|
||||
// 使用defer确保在任何情况下都能停止ping goroutine
|
||||
|
||||
@@ -22,6 +22,11 @@ type Adaptor struct {
|
||||
RequestMode int
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
|
||||
//TODO implement me
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.ClaudeRequest) (any, error) {
|
||||
c.Set("request_model", request.Model)
|
||||
c.Set("converted_request", request)
|
||||
|
||||
@@ -18,6 +18,11 @@ import (
|
||||
type Adaptor struct {
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
|
||||
//TODO implement me
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
|
||||
//TODO implement me
|
||||
panic("implement me")
|
||||
|
||||
@@ -18,6 +18,11 @@ import (
|
||||
type Adaptor struct {
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
|
||||
//TODO implement me
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
|
||||
//TODO implement me
|
||||
panic("implement me")
|
||||
@@ -43,15 +48,15 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
|
||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
|
||||
channel.SetupApiRequestHeader(info, c, req)
|
||||
keyParts := strings.Split(info.ApiKey, "|")
|
||||
keyParts := strings.Split(info.ApiKey, "|")
|
||||
if len(keyParts) == 0 || keyParts[0] == "" {
|
||||
return errors.New("invalid API key: authorization token is required")
|
||||
}
|
||||
if len(keyParts) > 1 {
|
||||
if keyParts[1] != "" {
|
||||
req.Set("appid", keyParts[1])
|
||||
}
|
||||
}
|
||||
return errors.New("invalid API key: authorization token is required")
|
||||
}
|
||||
if len(keyParts) > 1 {
|
||||
if keyParts[1] != "" {
|
||||
req.Set("appid", keyParts[1])
|
||||
}
|
||||
}
|
||||
req.Set("Authorization", "Bearer "+keyParts[0])
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -24,6 +24,11 @@ type Adaptor struct {
|
||||
RequestMode int
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
|
||||
//TODO implement me
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.ClaudeRequest) (any, error) {
|
||||
return request, nil
|
||||
}
|
||||
|
||||
@@ -612,8 +612,8 @@ func HandleStreamResponseData(c *gin.Context, info *relaycommon.RelayInfo, claud
|
||||
common.SysError("error unmarshalling stream response: " + err.Error())
|
||||
return types.NewError(err, types.ErrorCodeBadResponseBody)
|
||||
}
|
||||
if claudeResponse.Error != nil && claudeResponse.Error.Type != "" {
|
||||
return types.WithClaudeError(*claudeResponse.Error, http.StatusInternalServerError)
|
||||
if claudeError := claudeResponse.GetClaudeError(); claudeError != nil && claudeError.Type != "" {
|
||||
return types.WithClaudeError(*claudeError, http.StatusInternalServerError)
|
||||
}
|
||||
if info.RelayFormat == relaycommon.RelayFormatClaude {
|
||||
FormatClaudeResponseInfo(requestMode, &claudeResponse, nil, claudeInfo)
|
||||
@@ -704,8 +704,8 @@ func HandleClaudeResponseData(c *gin.Context, info *relaycommon.RelayInfo, claud
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeBadResponseBody)
|
||||
}
|
||||
if claudeResponse.Error != nil && claudeResponse.Error.Type != "" {
|
||||
return types.WithClaudeError(*claudeResponse.Error, http.StatusInternalServerError)
|
||||
if claudeError := claudeResponse.GetClaudeError(); claudeError != nil && claudeError.Type != "" {
|
||||
return types.WithClaudeError(*claudeError, http.StatusInternalServerError)
|
||||
}
|
||||
if requestMode == RequestModeCompletion {
|
||||
completionTokens := service.CountTextToken(claudeResponse.Completion, info.OriginModelName)
|
||||
|
||||
@@ -18,6 +18,11 @@ import (
|
||||
type Adaptor struct {
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
|
||||
//TODO implement me
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
|
||||
//TODO implement me
|
||||
panic("implement me")
|
||||
|
||||
@@ -17,6 +17,11 @@ import (
|
||||
type Adaptor struct {
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
|
||||
//TODO implement me
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
|
||||
//TODO implement me
|
||||
panic("implement me")
|
||||
|
||||
@@ -18,6 +18,11 @@ import (
|
||||
type Adaptor struct {
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *common.RelayInfo, *dto.GeminiChatRequest) (any, error) {
|
||||
//TODO implement me
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
// ConvertAudioRequest implements channel.Adaptor.
|
||||
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *common.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
|
||||
@@ -19,6 +19,11 @@ import (
|
||||
type Adaptor struct {
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
|
||||
//TODO implement me
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
|
||||
//TODO implement me
|
||||
panic("implement me")
|
||||
|
||||
@@ -24,6 +24,11 @@ type Adaptor struct {
|
||||
BotType int
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
|
||||
//TODO implement me
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
|
||||
//TODO implement me
|
||||
panic("implement me")
|
||||
|
||||
@@ -1,14 +1,13 @@
|
||||
package gemini
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/dto"
|
||||
"one-api/relay/channel"
|
||||
"one-api/relay/channel/openai"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/relay/constant"
|
||||
"one-api/setting/model_setting"
|
||||
@@ -21,10 +20,33 @@ import (
|
||||
type Adaptor struct {
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
|
||||
//TODO implement me
|
||||
panic("implement me")
|
||||
return nil, nil
|
||||
func (a *Adaptor) ConvertGeminiRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeminiChatRequest) (any, error) {
|
||||
if len(request.Contents) > 0 {
|
||||
for i, content := range request.Contents {
|
||||
if i == 0 {
|
||||
if request.Contents[0].Role == "" {
|
||||
request.Contents[0].Role = "user"
|
||||
}
|
||||
}
|
||||
for _, part := range content.Parts {
|
||||
if part.FileData != nil {
|
||||
if part.FileData.MimeType == "" && strings.Contains(part.FileData.FileUri, "www.youtube.com") {
|
||||
part.FileData.MimeType = "video/webm"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return request, nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, req *dto.ClaudeRequest) (any, error) {
|
||||
adaptor := openai.Adaptor{}
|
||||
oaiReq, err := adaptor.ConvertClaudeRequest(c, info, req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return a.ConvertOpenAIRequest(c, info, oaiReq.(*dto.GeneralOpenAIRequest))
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
|
||||
@@ -49,13 +71,13 @@ func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInf
|
||||
}
|
||||
|
||||
// build gemini imagen request
|
||||
geminiRequest := GeminiImageRequest{
|
||||
Instances: []GeminiImageInstance{
|
||||
geminiRequest := dto.GeminiImageRequest{
|
||||
Instances: []dto.GeminiImageInstance{
|
||||
{
|
||||
Prompt: request.Prompt,
|
||||
},
|
||||
},
|
||||
Parameters: GeminiImageParameters{
|
||||
Parameters: dto.GeminiImageParameters{
|
||||
SampleCount: request.N,
|
||||
AspectRatio: aspectRatio,
|
||||
PersonGeneration: "allow_adult", // default allow adult
|
||||
@@ -136,9 +158,9 @@ func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.Rela
|
||||
}
|
||||
|
||||
// only process the first input
|
||||
geminiRequest := GeminiEmbeddingRequest{
|
||||
Content: GeminiChatContent{
|
||||
Parts: []GeminiPart{
|
||||
geminiRequest := dto.GeminiEmbeddingRequest{
|
||||
Content: dto.GeminiChatContent{
|
||||
Parts: []dto.GeminiPart{
|
||||
{
|
||||
Text: inputs[0],
|
||||
},
|
||||
@@ -171,6 +193,7 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
|
||||
if info.RelayMode == constant.RelayModeGemini {
|
||||
if info.IsStream {
|
||||
info.DisablePing = true
|
||||
return GeminiTextGenerationStreamHandler(c, info, resp)
|
||||
} else {
|
||||
return GeminiTextGenerationHandler(c, info, resp)
|
||||
@@ -208,60 +231,6 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
|
||||
return nil, types.NewError(errors.New("not implemented"), types.ErrorCodeBadResponseBody)
|
||||
}
|
||||
|
||||
func GeminiImageHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
|
||||
responseBody, readErr := io.ReadAll(resp.Body)
|
||||
if readErr != nil {
|
||||
return nil, types.NewError(readErr, types.ErrorCodeBadResponseBody)
|
||||
}
|
||||
_ = resp.Body.Close()
|
||||
|
||||
var geminiResponse GeminiImageResponse
|
||||
if jsonErr := json.Unmarshal(responseBody, &geminiResponse); jsonErr != nil {
|
||||
return nil, types.NewError(jsonErr, types.ErrorCodeBadResponseBody)
|
||||
}
|
||||
|
||||
if len(geminiResponse.Predictions) == 0 {
|
||||
return nil, types.NewError(errors.New("no images generated"), types.ErrorCodeBadResponseBody)
|
||||
}
|
||||
|
||||
// convert to openai format response
|
||||
openAIResponse := dto.ImageResponse{
|
||||
Created: common.GetTimestamp(),
|
||||
Data: make([]dto.ImageData, 0, len(geminiResponse.Predictions)),
|
||||
}
|
||||
|
||||
for _, prediction := range geminiResponse.Predictions {
|
||||
if prediction.RaiFilteredReason != "" {
|
||||
continue // skip filtered image
|
||||
}
|
||||
openAIResponse.Data = append(openAIResponse.Data, dto.ImageData{
|
||||
B64Json: prediction.BytesBase64Encoded,
|
||||
})
|
||||
}
|
||||
|
||||
jsonResponse, jsonErr := json.Marshal(openAIResponse)
|
||||
if jsonErr != nil {
|
||||
return nil, types.NewError(jsonErr, types.ErrorCodeBadResponseBody)
|
||||
}
|
||||
|
||||
c.Writer.Header().Set("Content-Type", "application/json")
|
||||
c.Writer.WriteHeader(resp.StatusCode)
|
||||
_, _ = c.Writer.Write(jsonResponse)
|
||||
|
||||
// https://github.com/google-gemini/cookbook/blob/719a27d752aac33f39de18a8d3cb42a70874917e/quickstarts/Counting_Tokens.ipynb
|
||||
// each image has fixed 258 tokens
|
||||
const imageTokens = 258
|
||||
generatedImages := len(openAIResponse.Data)
|
||||
|
||||
usage := &dto.Usage{
|
||||
PromptTokens: imageTokens * generatedImages, // each generated image has fixed 258 tokens
|
||||
CompletionTokens: 0, // image generation does not calculate completion tokens
|
||||
TotalTokens: imageTokens * generatedImages,
|
||||
}
|
||||
|
||||
return usage, nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetModelList() []string {
|
||||
return ModelList
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package gemini
|
||||
|
||||
import (
|
||||
"github.com/pkg/errors"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
@@ -20,7 +21,7 @@ func GeminiTextGenerationHandler(c *gin.Context, info *relaycommon.RelayInfo, re
|
||||
// 读取响应体
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
|
||||
return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
if common.DebugEnabled {
|
||||
@@ -28,10 +29,10 @@ func GeminiTextGenerationHandler(c *gin.Context, info *relaycommon.RelayInfo, re
|
||||
}
|
||||
|
||||
// 解析为 Gemini 原生响应格式
|
||||
var geminiResponse GeminiChatResponse
|
||||
var geminiResponse dto.GeminiChatResponse
|
||||
err = common.Unmarshal(responseBody, &geminiResponse)
|
||||
if err != nil {
|
||||
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
|
||||
return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
// 计算使用量(基于 UsageMetadata)
|
||||
@@ -54,7 +55,7 @@ func GeminiTextGenerationHandler(c *gin.Context, info *relaycommon.RelayInfo, re
|
||||
// 直接返回 Gemini 原生格式的 JSON 响应
|
||||
jsonResponse, err := common.Marshal(geminiResponse)
|
||||
if err != nil {
|
||||
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
|
||||
return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
common.IOCopyBytesGracefully(c, resp, jsonResponse)
|
||||
@@ -71,7 +72,7 @@ func GeminiTextGenerationStreamHandler(c *gin.Context, info *relaycommon.RelayIn
|
||||
responseText := strings.Builder{}
|
||||
|
||||
helper.StreamScannerHandler(c, resp, info, func(data string) bool {
|
||||
var geminiResponse GeminiChatResponse
|
||||
var geminiResponse dto.GeminiChatResponse
|
||||
err := common.UnmarshalJsonStr(data, &geminiResponse)
|
||||
if err != nil {
|
||||
common.LogError(c, "error unmarshalling stream response: "+err.Error())
|
||||
@@ -110,10 +111,14 @@ func GeminiTextGenerationStreamHandler(c *gin.Context, info *relaycommon.RelayIn
|
||||
if err != nil {
|
||||
common.LogError(c, err.Error())
|
||||
}
|
||||
|
||||
info.SendResponseCount++
|
||||
return true
|
||||
})
|
||||
|
||||
if info.SendResponseCount == 0 {
|
||||
return nil, types.NewOpenAIError(errors.New("no response received from Gemini API"), types.ErrorCodeEmptyResponse, http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
if imageCount != 0 {
|
||||
if usage.CompletionTokens == 0 {
|
||||
usage.CompletionTokens = imageCount * 258
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"one-api/common"
|
||||
"one-api/constant"
|
||||
"one-api/dto"
|
||||
"one-api/relay/channel/openai"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/relay/helper"
|
||||
"one-api/service"
|
||||
@@ -80,7 +81,7 @@ func clampThinkingBudget(modelName string, budget int) int {
|
||||
return budget
|
||||
}
|
||||
|
||||
func ThinkingAdaptor(geminiRequest *GeminiChatRequest, info *relaycommon.RelayInfo) {
|
||||
func ThinkingAdaptor(geminiRequest *dto.GeminiChatRequest, info *relaycommon.RelayInfo) {
|
||||
if model_setting.GetGeminiSettings().ThinkingAdapterEnabled {
|
||||
modelName := info.UpstreamModelName
|
||||
isNew25Pro := strings.HasPrefix(modelName, "gemini-2.5-pro") &&
|
||||
@@ -92,7 +93,7 @@ func ThinkingAdaptor(geminiRequest *GeminiChatRequest, info *relaycommon.RelayIn
|
||||
if len(parts) == 2 && parts[1] != "" {
|
||||
if budgetTokens, err := strconv.Atoi(parts[1]); err == nil {
|
||||
clampedBudget := clampThinkingBudget(modelName, budgetTokens)
|
||||
geminiRequest.GenerationConfig.ThinkingConfig = &GeminiThinkingConfig{
|
||||
geminiRequest.GenerationConfig.ThinkingConfig = &dto.GeminiThinkingConfig{
|
||||
ThinkingBudget: common.GetPointer(clampedBudget),
|
||||
IncludeThoughts: true,
|
||||
}
|
||||
@@ -112,11 +113,11 @@ func ThinkingAdaptor(geminiRequest *GeminiChatRequest, info *relaycommon.RelayIn
|
||||
}
|
||||
|
||||
if isUnsupported {
|
||||
geminiRequest.GenerationConfig.ThinkingConfig = &GeminiThinkingConfig{
|
||||
geminiRequest.GenerationConfig.ThinkingConfig = &dto.GeminiThinkingConfig{
|
||||
IncludeThoughts: true,
|
||||
}
|
||||
} else {
|
||||
geminiRequest.GenerationConfig.ThinkingConfig = &GeminiThinkingConfig{
|
||||
geminiRequest.GenerationConfig.ThinkingConfig = &dto.GeminiThinkingConfig{
|
||||
IncludeThoughts: true,
|
||||
}
|
||||
if geminiRequest.GenerationConfig.MaxOutputTokens > 0 {
|
||||
@@ -127,7 +128,7 @@ func ThinkingAdaptor(geminiRequest *GeminiChatRequest, info *relaycommon.RelayIn
|
||||
}
|
||||
} else if strings.HasSuffix(modelName, "-nothinking") {
|
||||
if !isNew25Pro {
|
||||
geminiRequest.GenerationConfig.ThinkingConfig = &GeminiThinkingConfig{
|
||||
geminiRequest.GenerationConfig.ThinkingConfig = &dto.GeminiThinkingConfig{
|
||||
ThinkingBudget: common.GetPointer(0),
|
||||
}
|
||||
}
|
||||
@@ -136,11 +137,11 @@ func ThinkingAdaptor(geminiRequest *GeminiChatRequest, info *relaycommon.RelayIn
|
||||
}
|
||||
|
||||
// Setting safety to the lowest possible values since Gemini is already powerless enough
|
||||
func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon.RelayInfo) (*GeminiChatRequest, error) {
|
||||
func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon.RelayInfo) (*dto.GeminiChatRequest, error) {
|
||||
|
||||
geminiRequest := GeminiChatRequest{
|
||||
Contents: make([]GeminiChatContent, 0, len(textRequest.Messages)),
|
||||
GenerationConfig: GeminiChatGenerationConfig{
|
||||
geminiRequest := dto.GeminiChatRequest{
|
||||
Contents: make([]dto.GeminiChatContent, 0, len(textRequest.Messages)),
|
||||
GenerationConfig: dto.GeminiChatGenerationConfig{
|
||||
Temperature: textRequest.Temperature,
|
||||
TopP: textRequest.TopP,
|
||||
MaxOutputTokens: textRequest.MaxTokens,
|
||||
@@ -157,9 +158,9 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon
|
||||
|
||||
ThinkingAdaptor(&geminiRequest, info)
|
||||
|
||||
safetySettings := make([]GeminiChatSafetySettings, 0, len(SafetySettingList))
|
||||
safetySettings := make([]dto.GeminiChatSafetySettings, 0, len(SafetySettingList))
|
||||
for _, category := range SafetySettingList {
|
||||
safetySettings = append(safetySettings, GeminiChatSafetySettings{
|
||||
safetySettings = append(safetySettings, dto.GeminiChatSafetySettings{
|
||||
Category: category,
|
||||
Threshold: model_setting.GetGeminiSafetySetting(category),
|
||||
})
|
||||
@@ -197,17 +198,17 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon
|
||||
functions = append(functions, tool.Function)
|
||||
}
|
||||
if codeExecution {
|
||||
geminiRequest.Tools = append(geminiRequest.Tools, GeminiChatTool{
|
||||
geminiRequest.Tools = append(geminiRequest.Tools, dto.GeminiChatTool{
|
||||
CodeExecution: make(map[string]string),
|
||||
})
|
||||
}
|
||||
if googleSearch {
|
||||
geminiRequest.Tools = append(geminiRequest.Tools, GeminiChatTool{
|
||||
geminiRequest.Tools = append(geminiRequest.Tools, dto.GeminiChatTool{
|
||||
GoogleSearch: make(map[string]string),
|
||||
})
|
||||
}
|
||||
if len(functions) > 0 {
|
||||
geminiRequest.Tools = append(geminiRequest.Tools, GeminiChatTool{
|
||||
geminiRequest.Tools = append(geminiRequest.Tools, dto.GeminiChatTool{
|
||||
FunctionDeclarations: functions,
|
||||
})
|
||||
}
|
||||
@@ -219,9 +220,13 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon
|
||||
if textRequest.ResponseFormat != nil && (textRequest.ResponseFormat.Type == "json_schema" || textRequest.ResponseFormat.Type == "json_object") {
|
||||
geminiRequest.GenerationConfig.ResponseMimeType = "application/json"
|
||||
|
||||
if textRequest.ResponseFormat.JsonSchema != nil && textRequest.ResponseFormat.JsonSchema.Schema != nil {
|
||||
cleanedSchema := removeAdditionalPropertiesWithDepth(textRequest.ResponseFormat.JsonSchema.Schema, 0)
|
||||
geminiRequest.GenerationConfig.ResponseSchema = cleanedSchema
|
||||
if len(textRequest.ResponseFormat.JsonSchema) > 0 {
|
||||
// 先将json.RawMessage解析
|
||||
var jsonSchema dto.FormatJsonSchema
|
||||
if err := common.Unmarshal(textRequest.ResponseFormat.JsonSchema, &jsonSchema); err == nil {
|
||||
cleanedSchema := removeAdditionalPropertiesWithDepth(jsonSchema.Schema, 0)
|
||||
geminiRequest.GenerationConfig.ResponseSchema = cleanedSchema
|
||||
}
|
||||
}
|
||||
}
|
||||
tool_call_ids := make(map[string]string)
|
||||
@@ -233,7 +238,7 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon
|
||||
continue
|
||||
} else if message.Role == "tool" || message.Role == "function" {
|
||||
if len(geminiRequest.Contents) == 0 || geminiRequest.Contents[len(geminiRequest.Contents)-1].Role == "model" {
|
||||
geminiRequest.Contents = append(geminiRequest.Contents, GeminiChatContent{
|
||||
geminiRequest.Contents = append(geminiRequest.Contents, dto.GeminiChatContent{
|
||||
Role: "user",
|
||||
})
|
||||
}
|
||||
@@ -260,18 +265,18 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon
|
||||
}
|
||||
}
|
||||
|
||||
functionResp := &FunctionResponse{
|
||||
functionResp := &dto.GeminiFunctionResponse{
|
||||
Name: name,
|
||||
Response: contentMap,
|
||||
}
|
||||
|
||||
*parts = append(*parts, GeminiPart{
|
||||
*parts = append(*parts, dto.GeminiPart{
|
||||
FunctionResponse: functionResp,
|
||||
})
|
||||
continue
|
||||
}
|
||||
var parts []GeminiPart
|
||||
content := GeminiChatContent{
|
||||
var parts []dto.GeminiPart
|
||||
content := dto.GeminiChatContent{
|
||||
Role: message.Role,
|
||||
}
|
||||
// isToolCall := false
|
||||
@@ -285,8 +290,8 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon
|
||||
return nil, fmt.Errorf("invalid arguments for function %s, args: %s", call.Function.Name, call.Function.Arguments)
|
||||
}
|
||||
}
|
||||
toolCall := GeminiPart{
|
||||
FunctionCall: &FunctionCall{
|
||||
toolCall := dto.GeminiPart{
|
||||
FunctionCall: &dto.FunctionCall{
|
||||
FunctionName: call.Function.Name,
|
||||
Arguments: args,
|
||||
},
|
||||
@@ -303,7 +308,7 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon
|
||||
if part.Text == "" {
|
||||
continue
|
||||
}
|
||||
parts = append(parts, GeminiPart{
|
||||
parts = append(parts, dto.GeminiPart{
|
||||
Text: part.Text,
|
||||
})
|
||||
} else if part.Type == dto.ContentTypeImageURL {
|
||||
@@ -326,8 +331,8 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon
|
||||
return nil, fmt.Errorf("mime type is not supported by Gemini: '%s', url: '%s', supported types are: %v", fileData.MimeType, url, getSupportedMimeTypesList())
|
||||
}
|
||||
|
||||
parts = append(parts, GeminiPart{
|
||||
InlineData: &GeminiInlineData{
|
||||
parts = append(parts, dto.GeminiPart{
|
||||
InlineData: &dto.GeminiInlineData{
|
||||
MimeType: fileData.MimeType, // 使用原始的 MimeType,因为大小写可能对API有意义
|
||||
Data: fileData.Base64Data,
|
||||
},
|
||||
@@ -337,8 +342,8 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("decode base64 image data failed: %s", err.Error())
|
||||
}
|
||||
parts = append(parts, GeminiPart{
|
||||
InlineData: &GeminiInlineData{
|
||||
parts = append(parts, dto.GeminiPart{
|
||||
InlineData: &dto.GeminiInlineData{
|
||||
MimeType: format,
|
||||
Data: base64String,
|
||||
},
|
||||
@@ -352,8 +357,8 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("decode base64 file data failed: %s", err.Error())
|
||||
}
|
||||
parts = append(parts, GeminiPart{
|
||||
InlineData: &GeminiInlineData{
|
||||
parts = append(parts, dto.GeminiPart{
|
||||
InlineData: &dto.GeminiInlineData{
|
||||
MimeType: format,
|
||||
Data: base64String,
|
||||
},
|
||||
@@ -366,8 +371,8 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("decode base64 audio data failed: %s", err.Error())
|
||||
}
|
||||
parts = append(parts, GeminiPart{
|
||||
InlineData: &GeminiInlineData{
|
||||
parts = append(parts, dto.GeminiPart{
|
||||
InlineData: &dto.GeminiInlineData{
|
||||
MimeType: "audio/" + part.GetInputAudio().Format,
|
||||
Data: base64String,
|
||||
},
|
||||
@@ -387,8 +392,8 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon
|
||||
}
|
||||
|
||||
if len(system_content) > 0 {
|
||||
geminiRequest.SystemInstructions = &GeminiChatContent{
|
||||
Parts: []GeminiPart{
|
||||
geminiRequest.SystemInstructions = &dto.GeminiChatContent{
|
||||
Parts: []dto.GeminiPart{
|
||||
{
|
||||
Text: strings.Join(system_content, "\n"),
|
||||
},
|
||||
@@ -631,7 +636,7 @@ func unescapeMapOrSlice(data interface{}) interface{} {
|
||||
return data
|
||||
}
|
||||
|
||||
func getResponseToolCall(item *GeminiPart) *dto.ToolCallResponse {
|
||||
func getResponseToolCall(item *dto.GeminiPart) *dto.ToolCallResponse {
|
||||
var argsBytes []byte
|
||||
var err error
|
||||
if result, ok := item.FunctionCall.Arguments.(map[string]interface{}); ok {
|
||||
@@ -653,7 +658,7 @@ func getResponseToolCall(item *GeminiPart) *dto.ToolCallResponse {
|
||||
}
|
||||
}
|
||||
|
||||
func responseGeminiChat2OpenAI(c *gin.Context, response *GeminiChatResponse) *dto.OpenAITextResponse {
|
||||
func responseGeminiChat2OpenAI(c *gin.Context, response *dto.GeminiChatResponse) *dto.OpenAITextResponse {
|
||||
fullTextResponse := dto.OpenAITextResponse{
|
||||
Id: helper.GetResponseID(c),
|
||||
Object: "chat.completion",
|
||||
@@ -720,10 +725,9 @@ func responseGeminiChat2OpenAI(c *gin.Context, response *GeminiChatResponse) *dt
|
||||
return &fullTextResponse
|
||||
}
|
||||
|
||||
func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) (*dto.ChatCompletionsStreamResponse, bool, bool) {
|
||||
func streamResponseGeminiChat2OpenAI(geminiResponse *dto.GeminiChatResponse) (*dto.ChatCompletionsStreamResponse, bool) {
|
||||
choices := make([]dto.ChatCompletionsStreamResponseChoice, 0, len(geminiResponse.Candidates))
|
||||
isStop := false
|
||||
hasImage := false
|
||||
for _, candidate := range geminiResponse.Candidates {
|
||||
if candidate.FinishReason != nil && *candidate.FinishReason == "STOP" {
|
||||
isStop = true
|
||||
@@ -732,7 +736,7 @@ func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) (*dto.C
|
||||
choice := dto.ChatCompletionsStreamResponseChoice{
|
||||
Index: int(candidate.Index),
|
||||
Delta: dto.ChatCompletionsStreamResponseChoiceDelta{
|
||||
Role: "assistant",
|
||||
//Role: "assistant",
|
||||
},
|
||||
}
|
||||
var texts []string
|
||||
@@ -754,7 +758,6 @@ func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) (*dto.C
|
||||
if strings.HasPrefix(part.InlineData.MimeType, "image") {
|
||||
imgText := ""
|
||||
texts = append(texts, imgText)
|
||||
hasImage = true
|
||||
}
|
||||
} else if part.FunctionCall != nil {
|
||||
isTools = true
|
||||
@@ -791,28 +794,59 @@ func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) (*dto.C
|
||||
var response dto.ChatCompletionsStreamResponse
|
||||
response.Object = "chat.completion.chunk"
|
||||
response.Choices = choices
|
||||
return &response, isStop, hasImage
|
||||
return &response, isStop
|
||||
}
|
||||
|
||||
func handleStream(c *gin.Context, info *relaycommon.RelayInfo, resp *dto.ChatCompletionsStreamResponse) error {
|
||||
streamData, err := common.Marshal(resp)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal stream response: %w", err)
|
||||
}
|
||||
err = openai.HandleStreamFormat(c, info, string(streamData), info.ChannelSetting.ForceFormat, info.ChannelSetting.ThinkingToContent)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to handle stream format: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func handleFinalStream(c *gin.Context, info *relaycommon.RelayInfo, resp *dto.ChatCompletionsStreamResponse) error {
|
||||
streamData, err := common.Marshal(resp)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal stream response: %w", err)
|
||||
}
|
||||
openai.HandleFinalResponse(c, info, string(streamData), resp.Id, resp.Created, resp.Model, resp.GetSystemFingerprint(), resp.Usage, false)
|
||||
return nil
|
||||
}
|
||||
|
||||
func GeminiChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
|
||||
// responseText := ""
|
||||
id := helper.GetResponseID(c)
|
||||
createAt := common.GetTimestamp()
|
||||
responseText := strings.Builder{}
|
||||
var usage = &dto.Usage{}
|
||||
var imageCount int
|
||||
|
||||
helper.StreamScannerHandler(c, resp, info, func(data string) bool {
|
||||
var geminiResponse GeminiChatResponse
|
||||
var geminiResponse dto.GeminiChatResponse
|
||||
err := common.UnmarshalJsonStr(data, &geminiResponse)
|
||||
if err != nil {
|
||||
common.LogError(c, "error unmarshalling stream response: "+err.Error())
|
||||
return false
|
||||
}
|
||||
|
||||
response, isStop, hasImage := streamResponseGeminiChat2OpenAI(&geminiResponse)
|
||||
if hasImage {
|
||||
imageCount++
|
||||
for _, candidate := range geminiResponse.Candidates {
|
||||
for _, part := range candidate.Content.Parts {
|
||||
if part.InlineData != nil && part.InlineData.MimeType != "" {
|
||||
imageCount++
|
||||
}
|
||||
if part.Text != "" {
|
||||
responseText.WriteString(part.Text)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
response, isStop := streamResponseGeminiChat2OpenAI(&geminiResponse)
|
||||
|
||||
response.Id = id
|
||||
response.Created = createAt
|
||||
response.Model = info.UpstreamModelName
|
||||
@@ -829,18 +863,30 @@ func GeminiChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *
|
||||
}
|
||||
}
|
||||
}
|
||||
err = helper.ObjectData(c, response)
|
||||
|
||||
if info.SendResponseCount == 0 {
|
||||
// send first response
|
||||
err = handleStream(c, info, helper.GenerateStartEmptyResponse(id, createAt, info.UpstreamModelName, nil))
|
||||
if err != nil {
|
||||
common.LogError(c, err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
err = handleStream(c, info, response)
|
||||
if err != nil {
|
||||
common.LogError(c, err.Error())
|
||||
}
|
||||
if isStop {
|
||||
response := helper.GenerateStopResponse(id, createAt, info.UpstreamModelName, constant.FinishReasonStop)
|
||||
helper.ObjectData(c, response)
|
||||
_ = handleStream(c, info, helper.GenerateStopResponse(id, createAt, info.UpstreamModelName, constant.FinishReasonStop))
|
||||
}
|
||||
return true
|
||||
})
|
||||
|
||||
var response *dto.ChatCompletionsStreamResponse
|
||||
if info.SendResponseCount == 0 {
|
||||
// 空补全,报错不计费
|
||||
// empty response, throw an error
|
||||
return nil, types.NewOpenAIError(errors.New("no response received from Gemini API"), types.ErrorCodeEmptyResponse, http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
if imageCount != 0 {
|
||||
if usage.CompletionTokens == 0 {
|
||||
@@ -851,14 +897,24 @@ func GeminiChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *
|
||||
usage.PromptTokensDetails.TextTokens = usage.PromptTokens
|
||||
usage.CompletionTokens = usage.TotalTokens - usage.PromptTokens
|
||||
|
||||
if info.ShouldIncludeUsage {
|
||||
response = helper.GenerateFinalUsageResponse(id, createAt, info.UpstreamModelName, *usage)
|
||||
err := helper.ObjectData(c, response)
|
||||
if err != nil {
|
||||
common.SysError("send final response failed: " + err.Error())
|
||||
if usage.CompletionTokens == 0 {
|
||||
str := responseText.String()
|
||||
if len(str) > 0 {
|
||||
usage = service.ResponseText2Usage(responseText.String(), info.UpstreamModelName, info.PromptTokens)
|
||||
} else {
|
||||
// 空补全,不需要使用量
|
||||
usage = &dto.Usage{}
|
||||
}
|
||||
}
|
||||
helper.Done(c)
|
||||
|
||||
response := helper.GenerateFinalUsageResponse(id, createAt, info.UpstreamModelName, *usage)
|
||||
err := handleFinalStream(c, info, response)
|
||||
if err != nil {
|
||||
common.SysError("send final response failed: " + err.Error())
|
||||
}
|
||||
//if info.RelayFormat == relaycommon.RelayFormatOpenAI {
|
||||
// helper.Done(c)
|
||||
//}
|
||||
//resp.Body.Close()
|
||||
return usage, nil
|
||||
}
|
||||
@@ -866,19 +922,19 @@ func GeminiChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *
|
||||
func GeminiChatHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
|
||||
return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
|
||||
}
|
||||
common.CloseResponseBodyGracefully(resp)
|
||||
if common.DebugEnabled {
|
||||
println(string(responseBody))
|
||||
}
|
||||
var geminiResponse GeminiChatResponse
|
||||
var geminiResponse dto.GeminiChatResponse
|
||||
err = common.Unmarshal(responseBody, &geminiResponse)
|
||||
if err != nil {
|
||||
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
|
||||
return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
|
||||
}
|
||||
if len(geminiResponse.Candidates) == 0 {
|
||||
return nil, types.NewError(errors.New("no candidates returned"), types.ErrorCodeBadResponseBody)
|
||||
return nil, types.NewOpenAIError(errors.New("no candidates returned"), types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
|
||||
}
|
||||
fullTextResponse := responseGeminiChat2OpenAI(c, &geminiResponse)
|
||||
fullTextResponse.Model = info.UpstreamModelName
|
||||
@@ -915,12 +971,12 @@ func GeminiEmbeddingHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *h
|
||||
|
||||
responseBody, readErr := io.ReadAll(resp.Body)
|
||||
if readErr != nil {
|
||||
return nil, types.NewError(readErr, types.ErrorCodeBadResponseBody)
|
||||
return nil, types.NewOpenAIError(readErr, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
var geminiResponse GeminiEmbeddingResponse
|
||||
var geminiResponse dto.GeminiEmbeddingResponse
|
||||
if jsonErr := common.Unmarshal(responseBody, &geminiResponse); jsonErr != nil {
|
||||
return nil, types.NewError(jsonErr, types.ErrorCodeBadResponseBody)
|
||||
return nil, types.NewOpenAIError(jsonErr, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
// convert to openai format response
|
||||
@@ -950,9 +1006,63 @@ func GeminiEmbeddingHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *h
|
||||
|
||||
jsonResponse, jsonErr := common.Marshal(openAIResponse)
|
||||
if jsonErr != nil {
|
||||
return nil, types.NewError(jsonErr, types.ErrorCodeBadResponseBody)
|
||||
return nil, types.NewOpenAIError(jsonErr, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
common.IOCopyBytesGracefully(c, resp, jsonResponse)
|
||||
return usage, nil
|
||||
}
|
||||
|
||||
func GeminiImageHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
|
||||
responseBody, readErr := io.ReadAll(resp.Body)
|
||||
if readErr != nil {
|
||||
return nil, types.NewOpenAIError(readErr, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
|
||||
}
|
||||
_ = resp.Body.Close()
|
||||
|
||||
var geminiResponse dto.GeminiImageResponse
|
||||
if jsonErr := common.Unmarshal(responseBody, &geminiResponse); jsonErr != nil {
|
||||
return nil, types.NewOpenAIError(jsonErr, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
if len(geminiResponse.Predictions) == 0 {
|
||||
return nil, types.NewOpenAIError(errors.New("no images generated"), types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
// convert to openai format response
|
||||
openAIResponse := dto.ImageResponse{
|
||||
Created: common.GetTimestamp(),
|
||||
Data: make([]dto.ImageData, 0, len(geminiResponse.Predictions)),
|
||||
}
|
||||
|
||||
for _, prediction := range geminiResponse.Predictions {
|
||||
if prediction.RaiFilteredReason != "" {
|
||||
continue // skip filtered image
|
||||
}
|
||||
openAIResponse.Data = append(openAIResponse.Data, dto.ImageData{
|
||||
B64Json: prediction.BytesBase64Encoded,
|
||||
})
|
||||
}
|
||||
|
||||
jsonResponse, jsonErr := json.Marshal(openAIResponse)
|
||||
if jsonErr != nil {
|
||||
return nil, types.NewError(jsonErr, types.ErrorCodeBadResponseBody)
|
||||
}
|
||||
|
||||
c.Writer.Header().Set("Content-Type", "application/json")
|
||||
c.Writer.WriteHeader(resp.StatusCode)
|
||||
_, _ = c.Writer.Write(jsonResponse)
|
||||
|
||||
// https://github.com/google-gemini/cookbook/blob/719a27d752aac33f39de18a8d3cb42a70874917e/quickstarts/Counting_Tokens.ipynb
|
||||
// each image has fixed 258 tokens
|
||||
const imageTokens = 258
|
||||
generatedImages := len(openAIResponse.Data)
|
||||
|
||||
usage := &dto.Usage{
|
||||
PromptTokens: imageTokens * generatedImages, // each generated image has fixed 258 tokens
|
||||
CompletionTokens: 0, // image generation does not calculate completion tokens
|
||||
TotalTokens: imageTokens * generatedImages,
|
||||
}
|
||||
|
||||
return usage, nil
|
||||
}
|
||||
|
||||
@@ -4,7 +4,6 @@ import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/dto"
|
||||
@@ -13,11 +12,18 @@ import (
|
||||
relaycommon "one-api/relay/common"
|
||||
relayconstant "one-api/relay/constant"
|
||||
"one-api/types"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type Adaptor struct {
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
|
||||
//TODO implement me
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
@@ -52,13 +52,13 @@ func jimengImageHandler(c *gin.Context, resp *http.Response, info *relaycommon.R
|
||||
var jimengResponse ImageResponse
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, types.NewError(err, types.ErrorCodeReadResponseBodyFailed)
|
||||
return nil, types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError)
|
||||
}
|
||||
common.CloseResponseBodyGracefully(resp)
|
||||
|
||||
err = json.Unmarshal(responseBody, &jimengResponse)
|
||||
if err != nil {
|
||||
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
|
||||
return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
// Check if the response indicates an error
|
||||
|
||||
@@ -19,6 +19,11 @@ import (
|
||||
type Adaptor struct {
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
|
||||
//TODO implement me
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
|
||||
//TODO implement me
|
||||
panic("implement me")
|
||||
|
||||
@@ -16,6 +16,11 @@ import (
|
||||
type Adaptor struct {
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
|
||||
//TODO implement me
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
|
||||
//TODO implement me
|
||||
panic("implement me")
|
||||
|
||||
@@ -18,6 +18,11 @@ import (
|
||||
type Adaptor struct {
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
|
||||
//TODO implement me
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
|
||||
//TODO implement me
|
||||
panic("implement me")
|
||||
|
||||
@@ -17,10 +17,21 @@ import (
|
||||
type Adaptor struct {
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
|
||||
func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
|
||||
//TODO implement me
|
||||
panic("implement me")
|
||||
return nil, nil
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.ClaudeRequest) (any, error) {
|
||||
openaiAdaptor := openai.Adaptor{}
|
||||
openaiRequest, err := openaiAdaptor.ConvertClaudeRequest(c, info, request)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
openaiRequest.(*dto.GeneralOpenAIRequest).StreamOptions = &dto.StreamOptions{
|
||||
IncludeUsage: true,
|
||||
}
|
||||
return requestOpenAI2Ollama(openaiRequest.(*dto.GeneralOpenAIRequest))
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
|
||||
@@ -37,6 +48,9 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
if info.RelayFormat == relaycommon.RelayFormatClaude {
|
||||
return info.BaseUrl + "/v1/chat/completions", nil
|
||||
}
|
||||
switch info.RelayMode {
|
||||
case relayconstant.RelayModeEmbeddings:
|
||||
return info.BaseUrl + "/api/embed", nil
|
||||
@@ -55,7 +69,7 @@ func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayIn
|
||||
if request == nil {
|
||||
return nil, errors.New("request is nil")
|
||||
}
|
||||
return requestOpenAI2Ollama(*request)
|
||||
return requestOpenAI2Ollama(request)
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
|
||||
@@ -76,11 +90,12 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
|
||||
if info.IsStream {
|
||||
usage, err = openai.OaiStreamHandler(c, info, resp)
|
||||
} else {
|
||||
if info.RelayMode == relayconstant.RelayModeEmbeddings {
|
||||
usage, err = ollamaEmbeddingHandler(c, info, resp)
|
||||
switch info.RelayMode {
|
||||
case relayconstant.RelayModeEmbeddings:
|
||||
usage, err = ollamaEmbeddingHandler(c, info, resp)
|
||||
default:
|
||||
if info.IsStream {
|
||||
usage, err = openai.OaiStreamHandler(c, info, resp)
|
||||
} else {
|
||||
usage, err = openai.OpenaiHandler(c, info, resp)
|
||||
}
|
||||
|
||||
@@ -14,7 +14,7 @@ import (
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func requestOpenAI2Ollama(request dto.GeneralOpenAIRequest) (*OllamaRequest, error) {
|
||||
func requestOpenAI2Ollama(request *dto.GeneralOpenAIRequest) (*OllamaRequest, error) {
|
||||
messages := make([]dto.Message, 0, len(request.Messages))
|
||||
for _, message := range request.Messages {
|
||||
if !message.IsStringContent() {
|
||||
@@ -92,15 +92,15 @@ func ollamaEmbeddingHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *h
|
||||
var ollamaEmbeddingResponse OllamaEmbeddingResponse
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
|
||||
return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
|
||||
}
|
||||
common.CloseResponseBodyGracefully(resp)
|
||||
err = common.Unmarshal(responseBody, &ollamaEmbeddingResponse)
|
||||
if err != nil {
|
||||
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
|
||||
return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
|
||||
}
|
||||
if ollamaEmbeddingResponse.Error != "" {
|
||||
return nil, types.NewError(fmt.Errorf("ollama error: %s", ollamaEmbeddingResponse.Error), types.ErrorCodeBadResponseBody)
|
||||
return nil, types.NewOpenAIError(fmt.Errorf("ollama error: %s", ollamaEmbeddingResponse.Error), types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
|
||||
}
|
||||
flattenedEmbeddings := flattenEmbeddings(ollamaEmbeddingResponse.Embedding)
|
||||
data := make([]dto.OpenAIEmbeddingResponseItem, 0, 1)
|
||||
@@ -121,7 +121,7 @@ func ollamaEmbeddingHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *h
|
||||
}
|
||||
doResponseBody, err := common.Marshal(embeddingResponse)
|
||||
if err != nil {
|
||||
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
|
||||
return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
|
||||
}
|
||||
common.IOCopyBytesGracefully(c, resp, doResponseBody)
|
||||
return usage, nil
|
||||
|
||||
@@ -34,6 +34,15 @@ type Adaptor struct {
|
||||
ResponseFormat string
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertGeminiRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeminiChatRequest) (any, error) {
|
||||
// 使用 service.GeminiToOpenAIRequest 转换请求格式
|
||||
openaiRequest, err := service.GeminiToOpenAIRequest(request, info)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return a.ConvertOpenAIRequest(c, info, openaiRequest)
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.ClaudeRequest) (any, error) {
|
||||
//if !strings.Contains(request.Model, "claude") {
|
||||
// return nil, fmt.Errorf("you are using openai channel type with path /v1/messages, only claude model supported convert, but got %s", request.Model)
|
||||
@@ -64,7 +73,7 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
if info.RelayFormat == relaycommon.RelayFormatClaude {
|
||||
if info.RelayFormat == relaycommon.RelayFormatClaude || info.RelayFormat == relaycommon.RelayFormatGemini {
|
||||
return fmt.Sprintf("%s/v1/chat/completions", info.BaseUrl), nil
|
||||
}
|
||||
if info.RelayMode == relayconstant.RelayModeRealtime {
|
||||
|
||||
@@ -2,6 +2,8 @@ package openai
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/dto"
|
||||
relaycommon "one-api/relay/common"
|
||||
@@ -14,20 +16,23 @@ import (
|
||||
)
|
||||
|
||||
// 辅助函数
|
||||
func handleStreamFormat(c *gin.Context, info *relaycommon.RelayInfo, data string, forceFormat bool, thinkToContent bool) error {
|
||||
func HandleStreamFormat(c *gin.Context, info *relaycommon.RelayInfo, data string, forceFormat bool, thinkToContent bool) error {
|
||||
info.SendResponseCount++
|
||||
|
||||
switch info.RelayFormat {
|
||||
case relaycommon.RelayFormatOpenAI:
|
||||
return sendStreamData(c, info, data, forceFormat, thinkToContent)
|
||||
case relaycommon.RelayFormatClaude:
|
||||
return handleClaudeFormat(c, data, info)
|
||||
case relaycommon.RelayFormatGemini:
|
||||
return handleGeminiFormat(c, data, info)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func handleClaudeFormat(c *gin.Context, data string, info *relaycommon.RelayInfo) error {
|
||||
var streamResponse dto.ChatCompletionsStreamResponse
|
||||
if err := json.Unmarshal(common.StringToByteSlice(data), &streamResponse); err != nil {
|
||||
if err := common.Unmarshal(common.StringToByteSlice(data), &streamResponse); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -41,6 +46,36 @@ func handleClaudeFormat(c *gin.Context, data string, info *relaycommon.RelayInfo
|
||||
return nil
|
||||
}
|
||||
|
||||
func handleGeminiFormat(c *gin.Context, data string, info *relaycommon.RelayInfo) error {
|
||||
var streamResponse dto.ChatCompletionsStreamResponse
|
||||
if err := common.Unmarshal(common.StringToByteSlice(data), &streamResponse); err != nil {
|
||||
common.LogError(c, "failed to unmarshal stream response: "+err.Error())
|
||||
return err
|
||||
}
|
||||
|
||||
geminiResponse := service.StreamResponseOpenAI2Gemini(&streamResponse, info)
|
||||
|
||||
// 如果返回 nil,表示没有实际内容,跳过发送
|
||||
if geminiResponse == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
geminiResponseStr, err := common.Marshal(geminiResponse)
|
||||
if err != nil {
|
||||
common.LogError(c, "failed to marshal gemini response: "+err.Error())
|
||||
return err
|
||||
}
|
||||
|
||||
// send gemini format response
|
||||
c.Render(-1, common.CustomEvent{Data: "data: " + string(geminiResponseStr)})
|
||||
if flusher, ok := c.Writer.(http.Flusher); ok {
|
||||
flusher.Flush()
|
||||
} else {
|
||||
return errors.New("streaming error: flusher not found")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func ProcessStreamResponse(streamResponse dto.ChatCompletionsStreamResponse, responseTextBuilder *strings.Builder, toolCount *int) error {
|
||||
for _, choice := range streamResponse.Choices {
|
||||
responseTextBuilder.WriteString(choice.Delta.GetContentString())
|
||||
@@ -158,7 +193,7 @@ func handleLastResponse(lastStreamData string, responseId *string, createAt *int
|
||||
return nil
|
||||
}
|
||||
|
||||
func handleFinalResponse(c *gin.Context, info *relaycommon.RelayInfo, lastStreamData string,
|
||||
func HandleFinalResponse(c *gin.Context, info *relaycommon.RelayInfo, lastStreamData string,
|
||||
responseId string, createAt int64, model string, systemFingerprint string,
|
||||
usage *dto.Usage, containStreamUsage bool) {
|
||||
|
||||
@@ -174,7 +209,7 @@ func handleFinalResponse(c *gin.Context, info *relaycommon.RelayInfo, lastStream
|
||||
case relaycommon.RelayFormatClaude:
|
||||
info.ClaudeConvertInfo.Done = true
|
||||
var streamResponse dto.ChatCompletionsStreamResponse
|
||||
if err := json.Unmarshal(common.StringToByteSlice(lastStreamData), &streamResponse); err != nil {
|
||||
if err := common.Unmarshal(common.StringToByteSlice(lastStreamData), &streamResponse); err != nil {
|
||||
common.SysError("error unmarshalling stream response: " + err.Error())
|
||||
return
|
||||
}
|
||||
@@ -183,7 +218,38 @@ func handleFinalResponse(c *gin.Context, info *relaycommon.RelayInfo, lastStream
|
||||
|
||||
claudeResponses := service.StreamResponseOpenAI2Claude(&streamResponse, info)
|
||||
for _, resp := range claudeResponses {
|
||||
helper.ClaudeData(c, *resp)
|
||||
_ = helper.ClaudeData(c, *resp)
|
||||
}
|
||||
|
||||
case relaycommon.RelayFormatGemini:
|
||||
var streamResponse dto.ChatCompletionsStreamResponse
|
||||
if err := common.Unmarshal(common.StringToByteSlice(lastStreamData), &streamResponse); err != nil {
|
||||
common.SysError("error unmarshalling stream response: " + err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// 这里处理的是 openai 最后一个流响应,其 delta 为空,有 finish_reason 字段
|
||||
// 因此相比较于 google 官方的流响应,由 openai 转换而来会多一个 parts 为空,finishReason 为 STOP 的响应
|
||||
// 而包含最后一段文本输出的响应(倒数第二个)的 finishReason 为 null
|
||||
// 暂不知是否有程序会不兼容。
|
||||
|
||||
geminiResponse := service.StreamResponseOpenAI2Gemini(&streamResponse, info)
|
||||
|
||||
// openai 流响应开头的空数据
|
||||
if geminiResponse == nil {
|
||||
return
|
||||
}
|
||||
|
||||
geminiResponseStr, err := common.Marshal(geminiResponse)
|
||||
if err != nil {
|
||||
common.SysError("error marshalling gemini response: " + err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// 发送最终的 Gemini 响应
|
||||
c.Render(-1, common.CustomEvent{Data: "data: " + string(geminiResponseStr)})
|
||||
if flusher, ok := c.Writer.(http.Flusher); ok {
|
||||
flusher.Flush()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -109,7 +109,7 @@ func sendStreamData(c *gin.Context, info *relaycommon.RelayInfo, data string, fo
|
||||
func OaiStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
|
||||
if resp == nil || resp.Body == nil {
|
||||
common.LogError(c, "invalid response or response body")
|
||||
return nil, types.NewError(fmt.Errorf("invalid response"), types.ErrorCodeBadResponse)
|
||||
return nil, types.NewOpenAIError(fmt.Errorf("invalid response"), types.ErrorCodeBadResponse, http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
defer common.CloseResponseBodyGracefully(resp)
|
||||
@@ -123,30 +123,19 @@ func OaiStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Re
|
||||
var toolCount int
|
||||
var usage = &dto.Usage{}
|
||||
var streamItems []string // store stream items
|
||||
var forceFormat bool
|
||||
var thinkToContent bool
|
||||
|
||||
if info.ChannelSetting.ForceFormat {
|
||||
forceFormat = true
|
||||
}
|
||||
|
||||
if info.ChannelSetting.ThinkingToContent {
|
||||
thinkToContent = true
|
||||
}
|
||||
|
||||
var (
|
||||
lastStreamData string
|
||||
)
|
||||
var lastStreamData string
|
||||
|
||||
helper.StreamScannerHandler(c, resp, info, func(data string) bool {
|
||||
if lastStreamData != "" {
|
||||
err := handleStreamFormat(c, info, lastStreamData, forceFormat, thinkToContent)
|
||||
err := HandleStreamFormat(c, info, lastStreamData, info.ChannelSetting.ForceFormat, info.ChannelSetting.ThinkingToContent)
|
||||
if err != nil {
|
||||
common.SysError("error handling stream format: " + err.Error())
|
||||
}
|
||||
}
|
||||
lastStreamData = data
|
||||
streamItems = append(streamItems, data)
|
||||
if len(data) > 0 {
|
||||
lastStreamData = data
|
||||
streamItems = append(streamItems, data)
|
||||
}
|
||||
return true
|
||||
})
|
||||
|
||||
@@ -154,16 +143,18 @@ func OaiStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Re
|
||||
shouldSendLastResp := true
|
||||
if err := handleLastResponse(lastStreamData, &responseId, &createAt, &systemFingerprint, &model, &usage,
|
||||
&containStreamUsage, info, &shouldSendLastResp); err != nil {
|
||||
common.SysError("error handling last response: " + err.Error())
|
||||
common.LogError(c, fmt.Sprintf("error handling last response: %s, lastStreamData: [%s]", err.Error(), lastStreamData))
|
||||
}
|
||||
|
||||
if shouldSendLastResp && info.RelayFormat == relaycommon.RelayFormatOpenAI {
|
||||
_ = sendStreamData(c, info, lastStreamData, forceFormat, thinkToContent)
|
||||
if info.RelayFormat == relaycommon.RelayFormatOpenAI {
|
||||
if shouldSendLastResp {
|
||||
_ = sendStreamData(c, info, lastStreamData, info.ChannelSetting.ForceFormat, info.ChannelSetting.ThinkingToContent)
|
||||
}
|
||||
}
|
||||
|
||||
// 处理token计算
|
||||
if err := processTokens(info.RelayMode, streamItems, &responseTextBuilder, &toolCount); err != nil {
|
||||
common.SysError("error processing tokens: " + err.Error())
|
||||
common.LogError(c, "error processing tokens: "+err.Error())
|
||||
}
|
||||
|
||||
if !containStreamUsage {
|
||||
@@ -176,8 +167,7 @@ func OaiStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Re
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
handleFinalResponse(c, info, lastStreamData, responseId, createAt, model, systemFingerprint, usage, containStreamUsage)
|
||||
HandleFinalResponse(c, info, lastStreamData, responseId, createAt, model, systemFingerprint, usage, containStreamUsage)
|
||||
|
||||
return usage, nil
|
||||
}
|
||||
@@ -188,14 +178,14 @@ func OpenaiHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Respo
|
||||
var simpleResponse dto.OpenAITextResponse
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, types.NewError(err, types.ErrorCodeReadResponseBodyFailed)
|
||||
return nil, types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError)
|
||||
}
|
||||
err = common.Unmarshal(responseBody, &simpleResponse)
|
||||
if err != nil {
|
||||
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
|
||||
return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
|
||||
}
|
||||
if simpleResponse.Error != nil && simpleResponse.Error.Type != "" {
|
||||
return nil, types.WithOpenAIError(*simpleResponse.Error, resp.StatusCode)
|
||||
if oaiError := simpleResponse.GetOpenAIError(); oaiError != nil && oaiError.Type != "" {
|
||||
return nil, types.WithOpenAIError(*oaiError, resp.StatusCode)
|
||||
}
|
||||
|
||||
forceFormat := false
|
||||
@@ -233,6 +223,13 @@ func OpenaiHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Respo
|
||||
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
|
||||
}
|
||||
responseBody = claudeRespStr
|
||||
case relaycommon.RelayFormatGemini:
|
||||
geminiResp := service.ResponseOpenAI2Gemini(&simpleResponse, info)
|
||||
geminiRespStr, err := common.Marshal(geminiResp)
|
||||
if err != nil {
|
||||
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
|
||||
}
|
||||
responseBody = geminiRespStr
|
||||
}
|
||||
|
||||
common.IOCopyBytesGracefully(c, resp, responseBody)
|
||||
@@ -273,7 +270,7 @@ func OpenaiSTTHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
|
||||
}
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeReadResponseBodyFailed), nil
|
||||
return types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError), nil
|
||||
}
|
||||
// 写入新的 response body
|
||||
common.IOCopyBytesGracefully(c, resp, responseBody)
|
||||
@@ -557,13 +554,13 @@ func OpenaiHandlerWithUsage(c *gin.Context, info *relaycommon.RelayInfo, resp *h
|
||||
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, types.NewError(err, types.ErrorCodeReadResponseBodyFailed)
|
||||
return nil, types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
var usageResp dto.SimpleResponse
|
||||
err = common.Unmarshal(responseBody, &usageResp)
|
||||
if err != nil {
|
||||
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
|
||||
return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
// 写入新的 response body
|
||||
|
||||
@@ -22,14 +22,14 @@ func OaiResponsesHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http
|
||||
var responsesResponse dto.OpenAIResponsesResponse
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, types.NewError(err, types.ErrorCodeReadResponseBodyFailed)
|
||||
return nil, types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError)
|
||||
}
|
||||
err = common.Unmarshal(responseBody, &responsesResponse)
|
||||
if err != nil {
|
||||
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
|
||||
return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
|
||||
}
|
||||
if responsesResponse.Error != nil {
|
||||
return nil, types.WithOpenAIError(*responsesResponse.Error, resp.StatusCode)
|
||||
if oaiError := responsesResponse.GetOpenAIError(); oaiError != nil && oaiError.Type != "" {
|
||||
return nil, types.WithOpenAIError(*oaiError, resp.StatusCode)
|
||||
}
|
||||
|
||||
// 写入新的 response body
|
||||
|
||||
@@ -17,6 +17,11 @@ import (
|
||||
type Adaptor struct {
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
|
||||
//TODO implement me
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
|
||||
//TODO implement me
|
||||
panic("implement me")
|
||||
|
||||
@@ -127,13 +127,13 @@ func palmStreamHandler(c *gin.Context, resp *http.Response) (*types.NewAPIError,
|
||||
func palmHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, types.NewError(err, types.ErrorCodeReadResponseBodyFailed)
|
||||
return nil, types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError)
|
||||
}
|
||||
common.CloseResponseBodyGracefully(resp)
|
||||
var palmResponse PaLMChatResponse
|
||||
err = json.Unmarshal(responseBody, &palmResponse)
|
||||
if err != nil {
|
||||
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
|
||||
return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
|
||||
}
|
||||
if palmResponse.Error.Code != 0 || len(palmResponse.Candidates) == 0 {
|
||||
return nil, types.WithOpenAIError(types.OpenAIError{
|
||||
|
||||
@@ -17,6 +17,11 @@ import (
|
||||
type Adaptor struct {
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
|
||||
//TODO implement me
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
|
||||
//TODO implement me
|
||||
panic("implement me")
|
||||
|
||||
@@ -18,20 +18,24 @@ import (
|
||||
type Adaptor struct {
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
|
||||
func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
|
||||
//TODO implement me
|
||||
panic("implement me")
|
||||
return nil, nil
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, req *dto.ClaudeRequest) (any, error) {
|
||||
adaptor := openai.Adaptor{}
|
||||
return adaptor.ConvertClaudeRequest(c, info, req)
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
|
||||
//TODO implement me
|
||||
return nil, errors.New("not implemented")
|
||||
return nil, errors.New("not supported")
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
|
||||
//TODO implement me
|
||||
return nil, errors.New("not implemented")
|
||||
adaptor := openai.Adaptor{}
|
||||
return adaptor.ConvertImageRequest(c, info, request)
|
||||
}
|
||||
|
||||
func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
|
||||
@@ -47,7 +51,7 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
} else if info.RelayMode == constant.RelayModeCompletions {
|
||||
return fmt.Sprintf("%s/v1/completions", info.BaseUrl), nil
|
||||
}
|
||||
return "", errors.New("invalid relay mode")
|
||||
return fmt.Sprintf("%s/v1/chat/completions", info.BaseUrl), nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
|
||||
@@ -81,16 +85,19 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
|
||||
switch info.RelayMode {
|
||||
case constant.RelayModeRerank:
|
||||
usage, err = siliconflowRerankHandler(c, info, resp)
|
||||
case constant.RelayModeEmbeddings:
|
||||
usage, err = openai.OpenaiHandler(c, info, resp)
|
||||
case constant.RelayModeCompletions:
|
||||
fallthrough
|
||||
case constant.RelayModeChatCompletions:
|
||||
fallthrough
|
||||
default:
|
||||
if info.IsStream {
|
||||
usage, err = openai.OaiStreamHandler(c, info, resp)
|
||||
} else {
|
||||
usage, err = openai.OpenaiHandler(c, info, resp)
|
||||
}
|
||||
case constant.RelayModeEmbeddings:
|
||||
usage, err = openai.OpenaiHandler(c, info, resp)
|
||||
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
@@ -15,13 +15,13 @@ import (
|
||||
func siliconflowRerankHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, types.NewError(err, types.ErrorCodeReadResponseBodyFailed)
|
||||
return nil, types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError)
|
||||
}
|
||||
common.CloseResponseBodyGracefully(resp)
|
||||
var siliconflowResp SFRerankResponse
|
||||
err = json.Unmarshal(responseBody, &siliconflowResp)
|
||||
if err != nil {
|
||||
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
|
||||
return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
|
||||
}
|
||||
usage := &dto.Usage{
|
||||
PromptTokens: siliconflowResp.Meta.Tokens.InputTokens,
|
||||
|
||||
285
relay/channel/task/vidu/adaptor.go
Normal file
285
relay/channel/task/vidu/adaptor.go
Normal file
@@ -0,0 +1,285 @@
|
||||
package vidu
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
"one-api/constant"
|
||||
"one-api/dto"
|
||||
"one-api/model"
|
||||
"one-api/relay/channel"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/service"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
// ============================
|
||||
// Request / Response structures
|
||||
// ============================
|
||||
|
||||
type SubmitReq struct {
|
||||
Prompt string `json:"prompt"`
|
||||
Model string `json:"model,omitempty"`
|
||||
Mode string `json:"mode,omitempty"`
|
||||
Image string `json:"image,omitempty"`
|
||||
Size string `json:"size,omitempty"`
|
||||
Duration int `json:"duration,omitempty"`
|
||||
Metadata map[string]interface{} `json:"metadata,omitempty"`
|
||||
}
|
||||
|
||||
type requestPayload struct {
|
||||
Model string `json:"model"`
|
||||
Images []string `json:"images"`
|
||||
Prompt string `json:"prompt,omitempty"`
|
||||
Duration int `json:"duration,omitempty"`
|
||||
Seed int `json:"seed,omitempty"`
|
||||
Resolution string `json:"resolution,omitempty"`
|
||||
MovementAmplitude string `json:"movement_amplitude,omitempty"`
|
||||
Bgm bool `json:"bgm,omitempty"`
|
||||
Payload string `json:"payload,omitempty"`
|
||||
CallbackUrl string `json:"callback_url,omitempty"`
|
||||
}
|
||||
|
||||
type responsePayload struct {
|
||||
TaskId string `json:"task_id"`
|
||||
State string `json:"state"`
|
||||
Model string `json:"model"`
|
||||
Images []string `json:"images"`
|
||||
Prompt string `json:"prompt"`
|
||||
Duration int `json:"duration"`
|
||||
Seed int `json:"seed"`
|
||||
Resolution string `json:"resolution"`
|
||||
Bgm bool `json:"bgm"`
|
||||
MovementAmplitude string `json:"movement_amplitude"`
|
||||
Payload string `json:"payload"`
|
||||
CreatedAt string `json:"created_at"`
|
||||
}
|
||||
|
||||
type taskResultResponse struct {
|
||||
State string `json:"state"`
|
||||
ErrCode string `json:"err_code"`
|
||||
Credits int `json:"credits"`
|
||||
Payload string `json:"payload"`
|
||||
Creations []creation `json:"creations"`
|
||||
}
|
||||
|
||||
type creation struct {
|
||||
ID string `json:"id"`
|
||||
URL string `json:"url"`
|
||||
CoverURL string `json:"cover_url"`
|
||||
}
|
||||
|
||||
// ============================
|
||||
// Adaptor implementation
|
||||
// ============================
|
||||
|
||||
type TaskAdaptor struct {
|
||||
ChannelType int
|
||||
baseURL string
|
||||
}
|
||||
|
||||
func (a *TaskAdaptor) Init(info *relaycommon.TaskRelayInfo) {
|
||||
a.ChannelType = info.ChannelType
|
||||
a.baseURL = info.BaseUrl
|
||||
}
|
||||
|
||||
func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.TaskRelayInfo) *dto.TaskError {
|
||||
var req SubmitReq
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
return service.TaskErrorWrapper(err, "invalid_request_body", http.StatusBadRequest)
|
||||
}
|
||||
|
||||
if req.Prompt == "" {
|
||||
return service.TaskErrorWrapperLocal(fmt.Errorf("prompt is required"), "missing_prompt", http.StatusBadRequest)
|
||||
}
|
||||
|
||||
if req.Image != "" {
|
||||
info.Action = constant.TaskActionGenerate
|
||||
} else {
|
||||
info.Action = constant.TaskActionTextGenerate
|
||||
}
|
||||
|
||||
c.Set("task_request", req)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, _ *relaycommon.TaskRelayInfo) (io.Reader, error) {
|
||||
v, exists := c.Get("task_request")
|
||||
if !exists {
|
||||
return nil, fmt.Errorf("request not found in context")
|
||||
}
|
||||
req := v.(SubmitReq)
|
||||
|
||||
body, err := a.convertToRequestPayload(&req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if len(body.Images) == 0 {
|
||||
c.Set("action", constant.TaskActionTextGenerate)
|
||||
}
|
||||
|
||||
data, err := json.Marshal(body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return bytes.NewReader(data), nil
|
||||
}
|
||||
|
||||
func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.TaskRelayInfo) (string, error) {
|
||||
var path string
|
||||
switch info.Action {
|
||||
case constant.TaskActionGenerate:
|
||||
path = "/img2video"
|
||||
default:
|
||||
path = "/text2video"
|
||||
}
|
||||
return fmt.Sprintf("%s/ent/v2%s", a.baseURL, path), nil
|
||||
}
|
||||
|
||||
func (a *TaskAdaptor) BuildRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.TaskRelayInfo) error {
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Accept", "application/json")
|
||||
req.Header.Set("Authorization", "Token "+info.ApiKey)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *TaskAdaptor) DoRequest(c *gin.Context, info *relaycommon.TaskRelayInfo, requestBody io.Reader) (*http.Response, error) {
|
||||
if action := c.GetString("action"); action != "" {
|
||||
info.Action = action
|
||||
}
|
||||
return channel.DoTaskApiRequest(a, c, info, requestBody)
|
||||
}
|
||||
|
||||
func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, _ *relaycommon.TaskRelayInfo) (taskID string, taskData []byte, taskErr *dto.TaskError) {
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
taskErr = service.TaskErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
var vResp responsePayload
|
||||
err = json.Unmarshal(responseBody, &vResp)
|
||||
if err != nil {
|
||||
taskErr = service.TaskErrorWrapper(errors.Wrap(err, fmt.Sprintf("%s", responseBody)), "unmarshal_response_failed", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
if vResp.State == "failed" {
|
||||
taskErr = service.TaskErrorWrapperLocal(fmt.Errorf("task failed"), "task_failed", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, vResp)
|
||||
return vResp.TaskId, responseBody, nil
|
||||
}
|
||||
|
||||
func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http.Response, error) {
|
||||
taskID, ok := body["task_id"].(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid task_id")
|
||||
}
|
||||
|
||||
url := fmt.Sprintf("%s/ent/v2/tasks/%s/creations", baseUrl, taskID)
|
||||
|
||||
req, err := http.NewRequest(http.MethodGet, url, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
req.Header.Set("Accept", "application/json")
|
||||
req.Header.Set("Authorization", "Token "+key)
|
||||
|
||||
return service.GetHttpClient().Do(req)
|
||||
}
|
||||
|
||||
func (a *TaskAdaptor) GetModelList() []string {
|
||||
return []string{"viduq1", "vidu2.0", "vidu1.5"}
|
||||
}
|
||||
|
||||
func (a *TaskAdaptor) GetChannelName() string {
|
||||
return "vidu"
|
||||
}
|
||||
|
||||
// ============================
|
||||
// helpers
|
||||
// ============================
|
||||
|
||||
func (a *TaskAdaptor) convertToRequestPayload(req *SubmitReq) (*requestPayload, error) {
|
||||
var images []string
|
||||
if req.Image != "" {
|
||||
images = []string{req.Image}
|
||||
}
|
||||
|
||||
r := requestPayload{
|
||||
Model: defaultString(req.Model, "viduq1"),
|
||||
Images: images,
|
||||
Prompt: req.Prompt,
|
||||
Duration: defaultInt(req.Duration, 5),
|
||||
Resolution: defaultString(req.Size, "1080p"),
|
||||
MovementAmplitude: "auto",
|
||||
Bgm: false,
|
||||
}
|
||||
metadata := req.Metadata
|
||||
medaBytes, err := json.Marshal(metadata)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "metadata marshal metadata failed")
|
||||
}
|
||||
err = json.Unmarshal(medaBytes, &r)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "unmarshal metadata failed")
|
||||
}
|
||||
return &r, nil
|
||||
}
|
||||
|
||||
func defaultString(value, defaultValue string) string {
|
||||
if value == "" {
|
||||
return defaultValue
|
||||
}
|
||||
return value
|
||||
}
|
||||
|
||||
func defaultInt(value, defaultValue int) int {
|
||||
if value == 0 {
|
||||
return defaultValue
|
||||
}
|
||||
return value
|
||||
}
|
||||
|
||||
func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, error) {
|
||||
taskInfo := &relaycommon.TaskInfo{}
|
||||
|
||||
var taskResp taskResultResponse
|
||||
err := json.Unmarshal(respBody, &taskResp)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to unmarshal response body")
|
||||
}
|
||||
|
||||
state := taskResp.State
|
||||
switch state {
|
||||
case "created", "queueing":
|
||||
taskInfo.Status = model.TaskStatusSubmitted
|
||||
case "processing":
|
||||
taskInfo.Status = model.TaskStatusInProgress
|
||||
case "success":
|
||||
taskInfo.Status = model.TaskStatusSuccess
|
||||
if len(taskResp.Creations) > 0 {
|
||||
taskInfo.Url = taskResp.Creations[0].URL
|
||||
}
|
||||
case "failed":
|
||||
taskInfo.Status = model.TaskStatusFailure
|
||||
if taskResp.ErrCode != "" {
|
||||
taskInfo.Reason = taskResp.ErrCode
|
||||
}
|
||||
default:
|
||||
return nil, fmt.Errorf("unknown task state: %s", state)
|
||||
}
|
||||
|
||||
return taskInfo, nil
|
||||
}
|
||||
@@ -25,6 +25,11 @@ type Adaptor struct {
|
||||
Timestamp int64
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
|
||||
//TODO implement me
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
|
||||
//TODO implement me
|
||||
panic("implement me")
|
||||
|
||||
@@ -136,12 +136,12 @@ func tencentHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Resp
|
||||
var tencentSb TencentChatResponseSB
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, types.NewError(err, types.ErrorCodeReadResponseBodyFailed)
|
||||
return nil, types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError)
|
||||
}
|
||||
common.CloseResponseBodyGracefully(resp)
|
||||
err = json.Unmarshal(responseBody, &tencentSb)
|
||||
if err != nil {
|
||||
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
|
||||
return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
|
||||
}
|
||||
if tencentSb.Response.Error.Code != 0 {
|
||||
return nil, types.WithOpenAIError(types.OpenAIError{
|
||||
|
||||
@@ -44,6 +44,11 @@ type Adaptor struct {
|
||||
AccountCredentials Credentials
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertGeminiRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeminiChatRequest) (any, error) {
|
||||
geminiAdaptor := gemini.Adaptor{}
|
||||
return geminiAdaptor.ConvertGeminiRequest(c, info, request)
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.ClaudeRequest) (any, error) {
|
||||
if v, ok := claudeModelMap[info.UpstreamModelName]; ok {
|
||||
c.Set("request_model", v)
|
||||
@@ -67,10 +72,10 @@ func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInf
|
||||
func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
|
||||
if strings.HasPrefix(info.UpstreamModelName, "claude") {
|
||||
a.RequestMode = RequestModeClaude
|
||||
} else if strings.HasPrefix(info.UpstreamModelName, "gemini") {
|
||||
a.RequestMode = RequestModeGemini
|
||||
} else if strings.Contains(info.UpstreamModelName, "llama") {
|
||||
a.RequestMode = RequestModeLlama
|
||||
} else {
|
||||
a.RequestMode = RequestModeGemini
|
||||
}
|
||||
}
|
||||
|
||||
@@ -83,6 +88,7 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
a.AccountCredentials = *adc
|
||||
suffix := ""
|
||||
if a.RequestMode == RequestModeGemini {
|
||||
|
||||
if model_setting.GetGeminiSettings().ThinkingAdapterEnabled {
|
||||
// 新增逻辑:处理 -thinking-<budget> 格式
|
||||
if strings.Contains(info.UpstreamModelName, "-thinking-") {
|
||||
@@ -100,6 +106,11 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
} else {
|
||||
suffix = "generateContent"
|
||||
}
|
||||
|
||||
if strings.HasPrefix(info.UpstreamModelName, "imagen") {
|
||||
suffix = "predict"
|
||||
}
|
||||
|
||||
if region == "global" {
|
||||
return fmt.Sprintf(
|
||||
"https://aiplatform.googleapis.com/v1/projects/%s/locations/global/publishers/google/models/%s:%s",
|
||||
@@ -231,6 +242,9 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
|
||||
if info.RelayMode == constant.RelayModeGemini {
|
||||
usage, err = gemini.GeminiTextGenerationHandler(c, info, resp)
|
||||
} else {
|
||||
if strings.HasPrefix(info.UpstreamModelName, "imagen") {
|
||||
return gemini.GeminiImageHandler(c, info, resp)
|
||||
}
|
||||
usage, err = gemini.GeminiChatHandler(c, info, resp)
|
||||
}
|
||||
case RequestModeLlama:
|
||||
|
||||
@@ -36,7 +36,12 @@ var Cache = asynccache.NewAsyncCache(asynccache.Options{
|
||||
})
|
||||
|
||||
func getAccessToken(a *Adaptor, info *relaycommon.RelayInfo) (string, error) {
|
||||
cacheKey := fmt.Sprintf("access-token-%d", info.ChannelId)
|
||||
var cacheKey string
|
||||
if info.ChannelIsMultiKey {
|
||||
cacheKey = fmt.Sprintf("access-token-%d-%d", info.ChannelId, info.ChannelMultiKeyIndex)
|
||||
} else {
|
||||
cacheKey = fmt.Sprintf("access-token-%d", info.ChannelId)
|
||||
}
|
||||
val, err := Cache.Get(cacheKey)
|
||||
if err == nil {
|
||||
return val.(string), nil
|
||||
|
||||
@@ -23,6 +23,11 @@ import (
|
||||
type Adaptor struct {
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
|
||||
//TODO implement me
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
|
||||
//TODO implement me
|
||||
panic("implement me")
|
||||
|
||||
@@ -19,6 +19,11 @@ import (
|
||||
type Adaptor struct {
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
|
||||
//TODO implement me
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
|
||||
//TODO implement me
|
||||
//panic("implement me")
|
||||
|
||||
@@ -17,6 +17,11 @@ type Adaptor struct {
|
||||
request *dto.GeneralOpenAIRequest
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
|
||||
//TODO implement me
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
|
||||
//TODO implement me
|
||||
panic("implement me")
|
||||
|
||||
@@ -16,6 +16,11 @@ import (
|
||||
type Adaptor struct {
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
|
||||
//TODO implement me
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
|
||||
//TODO implement me
|
||||
panic("implement me")
|
||||
|
||||
@@ -220,12 +220,12 @@ func zhipuHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Respon
|
||||
var zhipuResponse ZhipuResponse
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, types.NewError(err, types.ErrorCodeReadResponseBodyFailed)
|
||||
return nil, types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError)
|
||||
}
|
||||
common.CloseResponseBodyGracefully(resp)
|
||||
err = json.Unmarshal(responseBody, &zhipuResponse)
|
||||
if err != nil {
|
||||
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
|
||||
return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
|
||||
}
|
||||
if !zhipuResponse.Success {
|
||||
return nil, types.WithOpenAIError(types.OpenAIError{
|
||||
|
||||
@@ -18,6 +18,11 @@ import (
|
||||
type Adaptor struct {
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
|
||||
//TODO implement me
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
|
||||
//TODO implement me
|
||||
panic("implement me")
|
||||
|
||||
@@ -40,7 +40,7 @@ func ClaudeHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
|
||||
// get & validate textRequest 获取并验证文本请求
|
||||
textRequest, err := getAndValidateClaudeRequest(c)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeInvalidRequest)
|
||||
return types.NewError(err, types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
|
||||
if textRequest.Stream {
|
||||
@@ -49,18 +49,18 @@ func ClaudeHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
|
||||
|
||||
err = helper.ModelMappedHelper(c, relayInfo, textRequest)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeChannelModelMappedError)
|
||||
return types.NewError(err, types.ErrorCodeChannelModelMappedError, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
|
||||
promptTokens, err := getClaudePromptTokens(textRequest, relayInfo)
|
||||
// count messages token error 计算promptTokens错误
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeCountTokenFailed)
|
||||
return types.NewError(err, types.ErrorCodeCountTokenFailed, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
|
||||
priceData, err := helper.ModelPriceHelper(c, relayInfo, promptTokens, int(textRequest.MaxTokens))
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeModelPriceError)
|
||||
return types.NewError(err, types.ErrorCodeModelPriceError, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
|
||||
// pre-consume quota 预消耗配额
|
||||
@@ -77,10 +77,9 @@ func ClaudeHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
|
||||
|
||||
adaptor := GetAdaptor(relayInfo.ApiType)
|
||||
if adaptor == nil {
|
||||
return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType)
|
||||
return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
adaptor.Init(relayInfo)
|
||||
var requestBody io.Reader
|
||||
|
||||
if textRequest.MaxTokens == 0 {
|
||||
textRequest.MaxTokens = uint(model_setting.GetClaudeSettings().GetDefaultMaxTokens(textRequest.Model))
|
||||
@@ -108,18 +107,41 @@ func ClaudeHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
|
||||
relayInfo.UpstreamModelName = textRequest.Model
|
||||
}
|
||||
|
||||
convertedRequest, err := adaptor.ConvertClaudeRequest(c, relayInfo, textRequest)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeConvertRequestFailed)
|
||||
var requestBody io.Reader
|
||||
if model_setting.GetGlobalSettings().PassThroughRequestEnabled || relayInfo.ChannelSetting.PassThroughBodyEnabled {
|
||||
body, err := common.GetRequestBody(c)
|
||||
if err != nil {
|
||||
return types.NewErrorWithStatusCode(err, types.ErrorCodeReadRequestBodyFailed, http.StatusBadRequest, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
requestBody = bytes.NewBuffer(body)
|
||||
} else {
|
||||
convertedRequest, err := adaptor.ConvertClaudeRequest(c, relayInfo, textRequest)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
jsonData, err := common.Marshal(convertedRequest)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
|
||||
// apply param override
|
||||
if len(relayInfo.ParamOverride) > 0 {
|
||||
reqMap := make(map[string]interface{})
|
||||
_ = common.Unmarshal(jsonData, &reqMap)
|
||||
for key, value := range relayInfo.ParamOverride {
|
||||
reqMap[key] = value
|
||||
}
|
||||
jsonData, err = common.Marshal(reqMap)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
}
|
||||
|
||||
if common.DebugEnabled {
|
||||
println("requestBody: ", string(jsonData))
|
||||
}
|
||||
requestBody = bytes.NewBuffer(jsonData)
|
||||
}
|
||||
jsonData, err := common.Marshal(convertedRequest)
|
||||
if common.DebugEnabled {
|
||||
println("requestBody: ", string(jsonData))
|
||||
}
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeConvertRequestFailed)
|
||||
}
|
||||
requestBody = bytes.NewBuffer(jsonData)
|
||||
|
||||
statusCodeMappingStr := c.GetString("status_code_mapping")
|
||||
var httpResp *http.Response
|
||||
|
||||
@@ -60,17 +60,19 @@ type ResponsesUsageInfo struct {
|
||||
}
|
||||
|
||||
type RelayInfo struct {
|
||||
ChannelType int
|
||||
ChannelId int
|
||||
TokenId int
|
||||
TokenKey string
|
||||
UserId int
|
||||
UsingGroup string // 使用的分组
|
||||
UserGroup string // 用户所在分组
|
||||
TokenUnlimited bool
|
||||
StartTime time.Time
|
||||
FirstResponseTime time.Time
|
||||
isFirstResponse bool
|
||||
ChannelType int
|
||||
ChannelId int
|
||||
ChannelIsMultiKey bool // 是否多密钥
|
||||
ChannelMultiKeyIndex int // 多密钥索引
|
||||
TokenId int
|
||||
TokenKey string
|
||||
UserId int
|
||||
UsingGroup string // 使用的分组
|
||||
UserGroup string // 用户所在分组
|
||||
TokenUnlimited bool
|
||||
StartTime time.Time
|
||||
FirstResponseTime time.Time
|
||||
isFirstResponse bool
|
||||
//SendLastReasoningResponse bool
|
||||
ApiType int
|
||||
IsStream bool
|
||||
@@ -88,6 +90,7 @@ type RelayInfo struct {
|
||||
BaseUrl string
|
||||
SupportStreamOptions bool
|
||||
ShouldIncludeUsage bool
|
||||
DisablePing bool // 是否禁止向下游发送自定义 Ping
|
||||
IsModelMapped bool
|
||||
ClientWs *websocket.Conn
|
||||
TargetWs *websocket.Conn
|
||||
@@ -259,6 +262,9 @@ func GenRelayInfo(c *gin.Context) *RelayInfo {
|
||||
IsFirstThinkingContent: true,
|
||||
SendLastThinkingContent: false,
|
||||
},
|
||||
|
||||
ChannelIsMultiKey: common.GetContextKeyBool(c, constant.ContextKeyChannelIsMultiKey),
|
||||
ChannelMultiKeyIndex: common.GetContextKeyInt(c, constant.ContextKeyChannelMultiKeyIndex),
|
||||
}
|
||||
if strings.HasPrefix(c.Request.URL.Path, "/pg") {
|
||||
info.IsPlayground = true
|
||||
|
||||
@@ -16,7 +16,7 @@ import (
|
||||
func RerankHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, types.NewError(err, types.ErrorCodeReadResponseBodyFailed)
|
||||
return nil, types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError)
|
||||
}
|
||||
common.CloseResponseBodyGracefully(resp)
|
||||
if common.DebugEnabled {
|
||||
@@ -27,7 +27,7 @@ func RerankHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Respo
|
||||
var xinRerankResponse xinference.XinRerankResponse
|
||||
err = common.Unmarshal(responseBody, &xinRerankResponse)
|
||||
if err != nil {
|
||||
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
|
||||
return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
|
||||
}
|
||||
jinaRespResults := make([]dto.RerankResponseResult, len(xinRerankResponse.Results))
|
||||
for i, result := range xinRerankResponse.Results {
|
||||
@@ -62,7 +62,7 @@ func RerankHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Respo
|
||||
} else {
|
||||
err = common.Unmarshal(responseBody, &jinaResp)
|
||||
if err != nil {
|
||||
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
|
||||
return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
|
||||
}
|
||||
jinaResp.Usage.PromptTokens = jinaResp.Usage.TotalTokens
|
||||
}
|
||||
|
||||
@@ -41,17 +41,17 @@ func EmbeddingHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
|
||||
err := common.UnmarshalBodyReusable(c, &embeddingRequest)
|
||||
if err != nil {
|
||||
common.LogError(c, fmt.Sprintf("getAndValidateTextRequest failed: %s", err.Error()))
|
||||
return types.NewError(err, types.ErrorCodeInvalidRequest)
|
||||
return types.NewError(err, types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
|
||||
err = validateEmbeddingRequest(c, relayInfo, *embeddingRequest)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeInvalidRequest)
|
||||
return types.NewError(err, types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
|
||||
err = helper.ModelMappedHelper(c, relayInfo, embeddingRequest)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeChannelModelMappedError)
|
||||
return types.NewError(err, types.ErrorCodeChannelModelMappedError, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
|
||||
promptToken := getEmbeddingPromptToken(*embeddingRequest)
|
||||
@@ -59,7 +59,7 @@ func EmbeddingHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
|
||||
|
||||
priceData, err := helper.ModelPriceHelper(c, relayInfo, promptToken, 0)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeModelPriceError)
|
||||
return types.NewError(err, types.ErrorCodeModelPriceError, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
// pre-consume quota 预消耗配额
|
||||
preConsumedQuota, userQuota, newAPIError := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo)
|
||||
@@ -74,18 +74,17 @@ func EmbeddingHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
|
||||
|
||||
adaptor := GetAdaptor(relayInfo.ApiType)
|
||||
if adaptor == nil {
|
||||
return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType)
|
||||
return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
adaptor.Init(relayInfo)
|
||||
|
||||
convertedRequest, err := adaptor.ConvertEmbeddingRequest(c, relayInfo, *embeddingRequest)
|
||||
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeConvertRequestFailed)
|
||||
return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
jsonData, err := json.Marshal(convertedRequest)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeConvertRequestFailed)
|
||||
return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
requestBody := bytes.NewBuffer(jsonData)
|
||||
statusCodeMappingStr := c.GetString("status_code_mapping")
|
||||
|
||||
@@ -2,9 +2,9 @@ package relay
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/dto"
|
||||
@@ -20,8 +20,8 @@ import (
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func getAndValidateGeminiRequest(c *gin.Context) (*gemini.GeminiChatRequest, error) {
|
||||
request := &gemini.GeminiChatRequest{}
|
||||
func getAndValidateGeminiRequest(c *gin.Context) (*dto.GeminiChatRequest, error) {
|
||||
request := &dto.GeminiChatRequest{}
|
||||
err := common.UnmarshalBodyReusable(c, request)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -44,7 +44,7 @@ func checkGeminiStreamMode(c *gin.Context, relayInfo *relaycommon.RelayInfo) {
|
||||
// }
|
||||
}
|
||||
|
||||
func checkGeminiInputSensitive(textRequest *gemini.GeminiChatRequest) ([]string, error) {
|
||||
func checkGeminiInputSensitive(textRequest *dto.GeminiChatRequest) ([]string, error) {
|
||||
var inputTexts []string
|
||||
for _, content := range textRequest.Contents {
|
||||
for _, part := range content.Parts {
|
||||
@@ -61,7 +61,7 @@ func checkGeminiInputSensitive(textRequest *gemini.GeminiChatRequest) ([]string,
|
||||
return sensitiveWords, err
|
||||
}
|
||||
|
||||
func getGeminiInputTokens(req *gemini.GeminiChatRequest, info *relaycommon.RelayInfo) int {
|
||||
func getGeminiInputTokens(req *dto.GeminiChatRequest, info *relaycommon.RelayInfo) int {
|
||||
// 计算输入 token 数量
|
||||
var inputTexts []string
|
||||
for _, content := range req.Contents {
|
||||
@@ -78,9 +78,13 @@ func getGeminiInputTokens(req *gemini.GeminiChatRequest, info *relaycommon.Relay
|
||||
return inputTokens
|
||||
}
|
||||
|
||||
func isNoThinkingRequest(req *gemini.GeminiChatRequest) bool {
|
||||
func isNoThinkingRequest(req *dto.GeminiChatRequest) bool {
|
||||
if req.GenerationConfig.ThinkingConfig != nil && req.GenerationConfig.ThinkingConfig.ThinkingBudget != nil {
|
||||
return *req.GenerationConfig.ThinkingConfig.ThinkingBudget <= 0
|
||||
configBudget := req.GenerationConfig.ThinkingConfig.ThinkingBudget
|
||||
if configBudget != nil && *configBudget == 0 {
|
||||
// 如果思考预算为 0,则认为是非思考请求
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
@@ -109,7 +113,7 @@ func GeminiHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
|
||||
req, err := getAndValidateGeminiRequest(c)
|
||||
if err != nil {
|
||||
common.LogError(c, fmt.Sprintf("getAndValidateGeminiRequest error: %s", err.Error()))
|
||||
return types.NewError(err, types.ErrorCodeInvalidRequest)
|
||||
return types.NewError(err, types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
|
||||
relayInfo := relaycommon.GenRelayInfoGemini(c)
|
||||
@@ -121,14 +125,14 @@ func GeminiHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
|
||||
sensitiveWords, err := checkGeminiInputSensitive(req)
|
||||
if err != nil {
|
||||
common.LogWarn(c, fmt.Sprintf("user sensitive words detected: %s", strings.Join(sensitiveWords, ", ")))
|
||||
return types.NewError(err, types.ErrorCodeSensitiveWordsDetected)
|
||||
return types.NewError(err, types.ErrorCodeSensitiveWordsDetected, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
}
|
||||
|
||||
// model mapped 模型映射
|
||||
err = helper.ModelMappedHelper(c, relayInfo, req)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeChannelModelMappedError)
|
||||
return types.NewError(err, types.ErrorCodeChannelModelMappedError, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
|
||||
if value, exists := c.Get("prompt_tokens"); exists {
|
||||
@@ -159,7 +163,7 @@ func GeminiHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
|
||||
|
||||
priceData, err := helper.ModelPriceHelper(c, relayInfo, relayInfo.PromptTokens, int(req.GenerationConfig.MaxOutputTokens))
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeModelPriceError)
|
||||
return types.NewError(err, types.ErrorCodeModelPriceError, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
|
||||
// pre consume quota
|
||||
@@ -175,7 +179,7 @@ func GeminiHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
|
||||
|
||||
adaptor := GetAdaptor(relayInfo.ApiType)
|
||||
if adaptor == nil {
|
||||
return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType)
|
||||
return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
|
||||
adaptor.Init(relayInfo)
|
||||
@@ -194,19 +198,47 @@ func GeminiHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
|
||||
}
|
||||
}
|
||||
|
||||
requestBody, err := json.Marshal(req)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeConvertRequestFailed)
|
||||
var requestBody io.Reader
|
||||
if model_setting.GetGlobalSettings().PassThroughRequestEnabled || relayInfo.ChannelSetting.PassThroughBodyEnabled {
|
||||
body, err := common.GetRequestBody(c)
|
||||
if err != nil {
|
||||
return types.NewErrorWithStatusCode(err, types.ErrorCodeReadRequestBodyFailed, http.StatusBadRequest, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
requestBody = bytes.NewReader(body)
|
||||
} else {
|
||||
// 使用 ConvertGeminiRequest 转换请求格式
|
||||
convertedRequest, err := adaptor.ConvertGeminiRequest(c, relayInfo, req)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
jsonData, err := common.Marshal(convertedRequest)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
|
||||
// apply param override
|
||||
if len(relayInfo.ParamOverride) > 0 {
|
||||
reqMap := make(map[string]interface{})
|
||||
_ = common.Unmarshal(jsonData, &reqMap)
|
||||
for key, value := range relayInfo.ParamOverride {
|
||||
reqMap[key] = value
|
||||
}
|
||||
jsonData, err = common.Marshal(reqMap)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
}
|
||||
|
||||
if common.DebugEnabled {
|
||||
println("Gemini request body: %s", string(jsonData))
|
||||
}
|
||||
requestBody = bytes.NewReader(jsonData)
|
||||
}
|
||||
|
||||
if common.DebugEnabled {
|
||||
println("Gemini request body: %s", string(requestBody))
|
||||
}
|
||||
|
||||
resp, err := adaptor.DoRequest(c, relayInfo, bytes.NewReader(requestBody))
|
||||
resp, err := adaptor.DoRequest(c, relayInfo, requestBody)
|
||||
if err != nil {
|
||||
common.LogError(c, "Do gemini request failed: "+err.Error())
|
||||
return types.NewError(err, types.ErrorCodeDoRequestFailed)
|
||||
return types.NewOpenAIError(err, types.ErrorCodeDoRequestFailed, http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
statusCodeMappingStr := c.GetString("status_code_mapping")
|
||||
|
||||
@@ -139,6 +139,24 @@ func GetLocalRealtimeID(c *gin.Context) string {
|
||||
return fmt.Sprintf("evt_%s", logID)
|
||||
}
|
||||
|
||||
func GenerateStartEmptyResponse(id string, createAt int64, model string, systemFingerprint *string) *dto.ChatCompletionsStreamResponse {
|
||||
return &dto.ChatCompletionsStreamResponse{
|
||||
Id: id,
|
||||
Object: "chat.completion.chunk",
|
||||
Created: createAt,
|
||||
Model: model,
|
||||
SystemFingerprint: systemFingerprint,
|
||||
Choices: []dto.ChatCompletionsStreamResponseChoice{
|
||||
{
|
||||
Delta: dto.ChatCompletionsStreamResponseChoiceDelta{
|
||||
Role: "assistant",
|
||||
Content: common.GetPointer(""),
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func GenerateStopResponse(id string, createAt int64, model string, finishReason string) *dto.ChatCompletionsStreamResponse {
|
||||
return &dto.ChatCompletionsStreamResponse{
|
||||
Id: id,
|
||||
|
||||
@@ -54,7 +54,7 @@ func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon
|
||||
)
|
||||
|
||||
generalSettings := operation_setting.GetGeneralSetting()
|
||||
pingEnabled := generalSettings.PingIntervalEnabled
|
||||
pingEnabled := generalSettings.PingIntervalEnabled && !info.DisablePing
|
||||
pingInterval := time.Duration(generalSettings.PingIntervalSeconds) * time.Second
|
||||
if pingInterval <= 0 {
|
||||
pingInterval = DefaultPingInterval
|
||||
@@ -234,6 +234,12 @@ func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon
|
||||
case <-stopChan:
|
||||
return
|
||||
}
|
||||
} else {
|
||||
// done, 处理完成标志,直接退出停止读取剩余数据防止出错
|
||||
if common.DebugEnabled {
|
||||
println("received [DONE], stopping scanner")
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -16,6 +16,7 @@ import (
|
||||
"one-api/relay/helper"
|
||||
"one-api/service"
|
||||
"one-api/setting"
|
||||
"one-api/setting/model_setting"
|
||||
"one-api/types"
|
||||
"strings"
|
||||
|
||||
@@ -114,17 +115,17 @@ func ImageHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
|
||||
imageRequest, err := getAndValidImageRequest(c, relayInfo)
|
||||
if err != nil {
|
||||
common.LogError(c, fmt.Sprintf("getAndValidImageRequest failed: %s", err.Error()))
|
||||
return types.NewError(err, types.ErrorCodeInvalidRequest)
|
||||
return types.NewError(err, types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
|
||||
err = helper.ModelMappedHelper(c, relayInfo, imageRequest)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeChannelModelMappedError)
|
||||
return types.NewError(err, types.ErrorCodeChannelModelMappedError, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
|
||||
priceData, err := helper.ModelPriceHelper(c, relayInfo, len(imageRequest.Prompt), 0)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeModelPriceError)
|
||||
return types.NewError(err, types.ErrorCodeModelPriceError, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
var preConsumedQuota int
|
||||
var quota int
|
||||
@@ -172,37 +173,58 @@ func ImageHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
|
||||
quota = int(priceData.ModelPrice * priceData.GroupRatioInfo.GroupRatio * common.QuotaPerUnit)
|
||||
userQuota, err = model.GetUserQuota(relayInfo.UserId, false)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeQueryDataError)
|
||||
return types.NewError(err, types.ErrorCodeQueryDataError, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
if userQuota-quota < 0 {
|
||||
return types.NewError(fmt.Errorf("image pre-consumed quota failed, user quota: %s, need quota: %s", common.FormatQuota(userQuota), common.FormatQuota(quota)), types.ErrorCodeInsufficientUserQuota)
|
||||
return types.NewError(fmt.Errorf("image pre-consumed quota failed, user quota: %s, need quota: %s", common.FormatQuota(userQuota), common.FormatQuota(quota)), types.ErrorCodeInsufficientUserQuota, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
}
|
||||
|
||||
adaptor := GetAdaptor(relayInfo.ApiType)
|
||||
if adaptor == nil {
|
||||
return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType)
|
||||
return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
adaptor.Init(relayInfo)
|
||||
|
||||
var requestBody io.Reader
|
||||
|
||||
convertedRequest, err := adaptor.ConvertImageRequest(c, relayInfo, *imageRequest)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeConvertRequestFailed)
|
||||
}
|
||||
if relayInfo.RelayMode == relayconstant.RelayModeImagesEdits {
|
||||
requestBody = convertedRequest.(io.Reader)
|
||||
} else {
|
||||
jsonData, err := json.Marshal(convertedRequest)
|
||||
if model_setting.GetGlobalSettings().PassThroughRequestEnabled || relayInfo.ChannelSetting.PassThroughBodyEnabled {
|
||||
body, err := common.GetRequestBody(c)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeConvertRequestFailed)
|
||||
return types.NewErrorWithStatusCode(err, types.ErrorCodeReadRequestBodyFailed, http.StatusBadRequest, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
requestBody = bytes.NewBuffer(jsonData)
|
||||
}
|
||||
requestBody = bytes.NewBuffer(body)
|
||||
} else {
|
||||
convertedRequest, err := adaptor.ConvertImageRequest(c, relayInfo, *imageRequest)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
if relayInfo.RelayMode == relayconstant.RelayModeImagesEdits {
|
||||
requestBody = convertedRequest.(io.Reader)
|
||||
} else {
|
||||
jsonData, err := json.Marshal(convertedRequest)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
|
||||
if common.DebugEnabled {
|
||||
println(fmt.Sprintf("image request body: %s", requestBody))
|
||||
// apply param override
|
||||
if len(relayInfo.ParamOverride) > 0 {
|
||||
reqMap := make(map[string]interface{})
|
||||
_ = common.Unmarshal(jsonData, &reqMap)
|
||||
for key, value := range relayInfo.ParamOverride {
|
||||
reqMap[key] = value
|
||||
}
|
||||
jsonData, err = common.Marshal(reqMap)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
}
|
||||
|
||||
if common.DebugEnabled {
|
||||
println(fmt.Sprintf("image request body: %s", string(jsonData)))
|
||||
}
|
||||
requestBody = bytes.NewBuffer(jsonData)
|
||||
}
|
||||
}
|
||||
|
||||
statusCodeMappingStr := c.GetString("status_code_mapping")
|
||||
|
||||
@@ -2,7 +2,6 @@ package relay
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
@@ -91,9 +90,8 @@ func TextHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
|
||||
|
||||
// get & validate textRequest 获取并验证文本请求
|
||||
textRequest, err := getAndValidateTextRequest(c, relayInfo)
|
||||
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeInvalidRequest)
|
||||
return types.NewError(err, types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
|
||||
if textRequest.WebSearchOptions != nil {
|
||||
@@ -104,13 +102,13 @@ func TextHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
|
||||
words, err := checkRequestSensitive(textRequest, relayInfo)
|
||||
if err != nil {
|
||||
common.LogWarn(c, fmt.Sprintf("user sensitive words detected: %s", strings.Join(words, ", ")))
|
||||
return types.NewError(err, types.ErrorCodeSensitiveWordsDetected)
|
||||
return types.NewError(err, types.ErrorCodeSensitiveWordsDetected, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
}
|
||||
|
||||
err = helper.ModelMappedHelper(c, relayInfo, textRequest)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeChannelModelMappedError)
|
||||
return types.NewError(err, types.ErrorCodeChannelModelMappedError, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
|
||||
// 获取 promptTokens,如果上下文中已经存在,则直接使用
|
||||
@@ -122,14 +120,14 @@ func TextHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
|
||||
promptTokens, err = getPromptTokens(textRequest, relayInfo)
|
||||
// count messages token error 计算promptTokens错误
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeCountTokenFailed)
|
||||
return types.NewError(err, types.ErrorCodeCountTokenFailed, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
c.Set("prompt_tokens", promptTokens)
|
||||
}
|
||||
|
||||
priceData, err := helper.ModelPriceHelper(c, relayInfo, promptTokens, int(math.Max(float64(textRequest.MaxTokens), float64(textRequest.MaxCompletionTokens))))
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeModelPriceError)
|
||||
return types.NewError(err, types.ErrorCodeModelPriceError, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
|
||||
// pre-consume quota 预消耗配额
|
||||
@@ -166,25 +164,49 @@ func TextHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
|
||||
|
||||
adaptor := GetAdaptor(relayInfo.ApiType)
|
||||
if adaptor == nil {
|
||||
return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType)
|
||||
return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
adaptor.Init(relayInfo)
|
||||
var requestBody io.Reader
|
||||
|
||||
if model_setting.GetGlobalSettings().PassThroughRequestEnabled {
|
||||
if model_setting.GetGlobalSettings().PassThroughRequestEnabled || relayInfo.ChannelSetting.PassThroughBodyEnabled {
|
||||
body, err := common.GetRequestBody(c)
|
||||
if err != nil {
|
||||
return types.NewErrorWithStatusCode(err, types.ErrorCodeReadRequestBodyFailed, http.StatusBadRequest)
|
||||
return types.NewErrorWithStatusCode(err, types.ErrorCodeReadRequestBodyFailed, http.StatusBadRequest, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
if common.DebugEnabled {
|
||||
println("requestBody: ", string(body))
|
||||
}
|
||||
requestBody = bytes.NewBuffer(body)
|
||||
} else {
|
||||
convertedRequest, err := adaptor.ConvertOpenAIRequest(c, relayInfo, textRequest)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeConvertRequestFailed)
|
||||
return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
jsonData, err := json.Marshal(convertedRequest)
|
||||
|
||||
if relayInfo.ChannelSetting.SystemPrompt != "" {
|
||||
// 如果有系统提示,则将其添加到请求中
|
||||
request := convertedRequest.(*dto.GeneralOpenAIRequest)
|
||||
containSystemPrompt := false
|
||||
for _, message := range request.Messages {
|
||||
if message.Role == request.GetSystemRoleName() {
|
||||
containSystemPrompt = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !containSystemPrompt {
|
||||
// 如果没有系统提示,则添加系统提示
|
||||
systemMessage := dto.Message{
|
||||
Role: request.GetSystemRoleName(),
|
||||
Content: relayInfo.ChannelSetting.SystemPrompt,
|
||||
}
|
||||
request.Messages = append([]dto.Message{systemMessage}, request.Messages...)
|
||||
}
|
||||
}
|
||||
|
||||
jsonData, err := common.Marshal(convertedRequest)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeConvertRequestFailed)
|
||||
return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
|
||||
// apply param override
|
||||
@@ -196,7 +218,7 @@ func TextHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
|
||||
}
|
||||
jsonData, err = common.Marshal(reqMap)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid)
|
||||
return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -208,7 +230,6 @@ func TextHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
|
||||
|
||||
var httpResp *http.Response
|
||||
resp, err := adaptor.DoRequest(c, relayInfo, requestBody)
|
||||
|
||||
if err != nil {
|
||||
return types.NewOpenAIError(err, types.ErrorCodeDoRequestFailed, http.StatusInternalServerError)
|
||||
}
|
||||
@@ -281,13 +302,13 @@ func checkRequestSensitive(textRequest *dto.GeneralOpenAIRequest, info *relaycom
|
||||
func preConsumeQuota(c *gin.Context, preConsumedQuota int, relayInfo *relaycommon.RelayInfo) (int, int, *types.NewAPIError) {
|
||||
userQuota, err := model.GetUserQuota(relayInfo.UserId, false)
|
||||
if err != nil {
|
||||
return 0, 0, types.NewError(err, types.ErrorCodeQueryDataError)
|
||||
return 0, 0, types.NewError(err, types.ErrorCodeQueryDataError, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
if userQuota <= 0 {
|
||||
return 0, 0, types.NewErrorWithStatusCode(errors.New("user quota is not enough"), types.ErrorCodeInsufficientUserQuota, http.StatusForbidden)
|
||||
return 0, 0, types.NewErrorWithStatusCode(errors.New("user quota is not enough"), types.ErrorCodeInsufficientUserQuota, http.StatusForbidden, types.ErrOptionWithSkipRetry(), types.ErrOptionWithNoRecordErrorLog())
|
||||
}
|
||||
if userQuota-preConsumedQuota < 0 {
|
||||
return 0, 0, types.NewErrorWithStatusCode(fmt.Errorf("pre-consume quota failed, user quota: %s, need quota: %s", common.FormatQuota(userQuota), common.FormatQuota(preConsumedQuota)), types.ErrorCodeInsufficientUserQuota, http.StatusForbidden)
|
||||
return 0, 0, types.NewErrorWithStatusCode(fmt.Errorf("pre-consume quota failed, user quota: %s, need quota: %s", common.FormatQuota(userQuota), common.FormatQuota(preConsumedQuota)), types.ErrorCodeInsufficientUserQuota, http.StatusForbidden, types.ErrOptionWithSkipRetry(), types.ErrOptionWithNoRecordErrorLog())
|
||||
}
|
||||
relayInfo.UserQuota = userQuota
|
||||
if userQuota > 100*preConsumedQuota {
|
||||
@@ -311,11 +332,11 @@ func preConsumeQuota(c *gin.Context, preConsumedQuota int, relayInfo *relaycommo
|
||||
if preConsumedQuota > 0 {
|
||||
err := service.PreConsumeTokenQuota(relayInfo, preConsumedQuota)
|
||||
if err != nil {
|
||||
return 0, 0, types.NewErrorWithStatusCode(err, types.ErrorCodePreConsumeTokenQuotaFailed, http.StatusForbidden)
|
||||
return 0, 0, types.NewErrorWithStatusCode(err, types.ErrorCodePreConsumeTokenQuotaFailed, http.StatusForbidden, types.ErrOptionWithSkipRetry(), types.ErrOptionWithNoRecordErrorLog())
|
||||
}
|
||||
err = model.DecreaseUserQuota(relayInfo.UserId, preConsumedQuota)
|
||||
if err != nil {
|
||||
return 0, 0, types.NewError(err, types.ErrorCodeUpdateDataError)
|
||||
return 0, 0, types.NewError(err, types.ErrorCodeUpdateDataError, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
}
|
||||
return preConsumedQuota, userQuota, nil
|
||||
@@ -494,6 +515,9 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
|
||||
common.LogError(ctx, fmt.Sprintf("total tokens is 0, cannot consume quota, userId %d, channelId %d, "+
|
||||
"tokenId %d, model %s, pre-consumed quota %d", relayInfo.UserId, relayInfo.ChannelId, relayInfo.TokenId, modelName, preConsumedQuota))
|
||||
} else {
|
||||
if !ratio.IsZero() && quota == 0 {
|
||||
quota = 1
|
||||
}
|
||||
model.UpdateUserUsedQuotaAndRequestCount(relayInfo.UserId, quota)
|
||||
model.UpdateChannelUsedQuota(relayInfo.ChannelId, quota)
|
||||
}
|
||||
|
||||
@@ -27,6 +27,7 @@ import (
|
||||
taskjimeng "one-api/relay/channel/task/jimeng"
|
||||
"one-api/relay/channel/task/kling"
|
||||
"one-api/relay/channel/task/suno"
|
||||
taskVidu "one-api/relay/channel/task/vidu"
|
||||
"one-api/relay/channel/tencent"
|
||||
"one-api/relay/channel/vertex"
|
||||
"one-api/relay/channel/volcengine"
|
||||
@@ -122,6 +123,8 @@ func GetTaskAdaptor(platform constant.TaskPlatform) channel.TaskAdaptor {
|
||||
return &kling.TaskAdaptor{}
|
||||
case constant.ChannelTypeJimeng:
|
||||
return &taskjimeng.TaskAdaptor{}
|
||||
case constant.ChannelTypeVidu:
|
||||
return &taskVidu.TaskAdaptor{}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
|
||||
@@ -3,12 +3,14 @@ package relay
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/dto"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/relay/helper"
|
||||
"one-api/service"
|
||||
"one-api/setting/model_setting"
|
||||
"one-api/types"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
@@ -29,21 +31,21 @@ func RerankHelper(c *gin.Context, relayMode int) (newAPIError *types.NewAPIError
|
||||
err := common.UnmarshalBodyReusable(c, &rerankRequest)
|
||||
if err != nil {
|
||||
common.LogError(c, fmt.Sprintf("getAndValidateTextRequest failed: %s", err.Error()))
|
||||
return types.NewError(err, types.ErrorCodeInvalidRequest)
|
||||
return types.NewError(err, types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
|
||||
relayInfo := relaycommon.GenRelayInfoRerank(c, rerankRequest)
|
||||
|
||||
if rerankRequest.Query == "" {
|
||||
return types.NewError(fmt.Errorf("query is empty"), types.ErrorCodeInvalidRequest)
|
||||
return types.NewError(fmt.Errorf("query is empty"), types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
if len(rerankRequest.Documents) == 0 {
|
||||
return types.NewError(fmt.Errorf("documents is empty"), types.ErrorCodeInvalidRequest)
|
||||
return types.NewError(fmt.Errorf("documents is empty"), types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
|
||||
err = helper.ModelMappedHelper(c, relayInfo, rerankRequest)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeChannelModelMappedError)
|
||||
return types.NewError(err, types.ErrorCodeChannelModelMappedError, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
|
||||
promptToken := getRerankPromptToken(*rerankRequest)
|
||||
@@ -51,7 +53,7 @@ func RerankHelper(c *gin.Context, relayMode int) (newAPIError *types.NewAPIError
|
||||
|
||||
priceData, err := helper.ModelPriceHelper(c, relayInfo, promptToken, 0)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeModelPriceError)
|
||||
return types.NewError(err, types.ErrorCodeModelPriceError, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
// pre-consume quota 预消耗配额
|
||||
preConsumedQuota, userQuota, newAPIError := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo)
|
||||
@@ -66,22 +68,46 @@ func RerankHelper(c *gin.Context, relayMode int) (newAPIError *types.NewAPIError
|
||||
|
||||
adaptor := GetAdaptor(relayInfo.ApiType)
|
||||
if adaptor == nil {
|
||||
return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType)
|
||||
return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
adaptor.Init(relayInfo)
|
||||
|
||||
convertedRequest, err := adaptor.ConvertRerankRequest(c, relayInfo.RelayMode, *rerankRequest)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeConvertRequestFailed)
|
||||
}
|
||||
jsonData, err := common.Marshal(convertedRequest)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeConvertRequestFailed)
|
||||
}
|
||||
requestBody := bytes.NewBuffer(jsonData)
|
||||
if common.DebugEnabled {
|
||||
println(fmt.Sprintf("Rerank request body: %s", requestBody.String()))
|
||||
var requestBody io.Reader
|
||||
if model_setting.GetGlobalSettings().PassThroughRequestEnabled || relayInfo.ChannelSetting.PassThroughBodyEnabled {
|
||||
body, err := common.GetRequestBody(c)
|
||||
if err != nil {
|
||||
return types.NewErrorWithStatusCode(err, types.ErrorCodeReadRequestBodyFailed, http.StatusBadRequest, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
requestBody = bytes.NewBuffer(body)
|
||||
} else {
|
||||
convertedRequest, err := adaptor.ConvertRerankRequest(c, relayInfo.RelayMode, *rerankRequest)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
jsonData, err := common.Marshal(convertedRequest)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
|
||||
// apply param override
|
||||
if len(relayInfo.ParamOverride) > 0 {
|
||||
reqMap := make(map[string]interface{})
|
||||
_ = common.Unmarshal(jsonData, &reqMap)
|
||||
for key, value := range relayInfo.ParamOverride {
|
||||
reqMap[key] = value
|
||||
}
|
||||
jsonData, err = common.Marshal(reqMap)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
}
|
||||
|
||||
if common.DebugEnabled {
|
||||
println(fmt.Sprintf("Rerank request body: %s", string(jsonData)))
|
||||
}
|
||||
requestBody = bytes.NewBuffer(jsonData)
|
||||
}
|
||||
|
||||
resp, err := adaptor.DoRequest(c, relayInfo, requestBody)
|
||||
if err != nil {
|
||||
return types.NewOpenAIError(err, types.ErrorCodeDoRequestFailed, http.StatusInternalServerError)
|
||||
|
||||
@@ -51,7 +51,7 @@ func ResponsesHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
|
||||
req, err := getAndValidateResponsesRequest(c)
|
||||
if err != nil {
|
||||
common.LogError(c, fmt.Sprintf("getAndValidateResponsesRequest error: %s", err.Error()))
|
||||
return types.NewError(err, types.ErrorCodeInvalidRequest)
|
||||
return types.NewError(err, types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
|
||||
relayInfo := relaycommon.GenRelayInfoResponses(c, req)
|
||||
@@ -60,13 +60,13 @@ func ResponsesHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
|
||||
sensitiveWords, err := checkInputSensitive(req, relayInfo)
|
||||
if err != nil {
|
||||
common.LogWarn(c, fmt.Sprintf("user sensitive words detected: %s", strings.Join(sensitiveWords, ", ")))
|
||||
return types.NewError(err, types.ErrorCodeSensitiveWordsDetected)
|
||||
return types.NewError(err, types.ErrorCodeSensitiveWordsDetected, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
}
|
||||
|
||||
err = helper.ModelMappedHelper(c, relayInfo, req)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeChannelModelMappedError)
|
||||
return types.NewError(err, types.ErrorCodeChannelModelMappedError, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
|
||||
if value, exists := c.Get("prompt_tokens"); exists {
|
||||
@@ -79,7 +79,7 @@ func ResponsesHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
|
||||
|
||||
priceData, err := helper.ModelPriceHelper(c, relayInfo, relayInfo.PromptTokens, int(req.MaxOutputTokens))
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeModelPriceError)
|
||||
return types.NewError(err, types.ErrorCodeModelPriceError, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
// pre consume quota
|
||||
preConsumedQuota, userQuota, newAPIError := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo)
|
||||
@@ -93,38 +93,38 @@ func ResponsesHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
|
||||
}()
|
||||
adaptor := GetAdaptor(relayInfo.ApiType)
|
||||
if adaptor == nil {
|
||||
return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType)
|
||||
return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
adaptor.Init(relayInfo)
|
||||
var requestBody io.Reader
|
||||
if model_setting.GetGlobalSettings().PassThroughRequestEnabled {
|
||||
body, err := common.GetRequestBody(c)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeReadRequestBodyFailed)
|
||||
return types.NewError(err, types.ErrorCodeReadRequestBodyFailed, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
requestBody = bytes.NewBuffer(body)
|
||||
} else {
|
||||
convertedRequest, err := adaptor.ConvertOpenAIResponsesRequest(c, relayInfo, *req)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeConvertRequestFailed)
|
||||
return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
jsonData, err := json.Marshal(convertedRequest)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeConvertRequestFailed)
|
||||
return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
// apply param override
|
||||
if len(relayInfo.ParamOverride) > 0 {
|
||||
reqMap := make(map[string]interface{})
|
||||
err = json.Unmarshal(jsonData, &reqMap)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid)
|
||||
return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
for key, value := range relayInfo.ParamOverride {
|
||||
reqMap[key] = value
|
||||
}
|
||||
jsonData, err = json.Marshal(reqMap)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeConvertRequestFailed)
|
||||
return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -24,12 +24,12 @@ func WssHelper(c *gin.Context, ws *websocket.Conn) (newAPIError *types.NewAPIErr
|
||||
|
||||
err := helper.ModelMappedHelper(c, relayInfo, nil)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeChannelModelMappedError)
|
||||
return types.NewError(err, types.ErrorCodeChannelModelMappedError, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
|
||||
priceData, err := helper.ModelPriceHelper(c, relayInfo, 0, 0)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeModelPriceError)
|
||||
return types.NewError(err, types.ErrorCodeModelPriceError, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
|
||||
// pre-consume quota 预消耗配额
|
||||
@@ -46,7 +46,7 @@ func WssHelper(c *gin.Context, ws *websocket.Conn) (newAPIError *types.NewAPIErr
|
||||
|
||||
adaptor := GetAdaptor(relayInfo.ApiType)
|
||||
if adaptor == nil {
|
||||
return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType)
|
||||
return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
adaptor.Init(relayInfo)
|
||||
//var requestBody io.Reader
|
||||
|
||||
@@ -44,6 +44,7 @@ func SetApiRouter(router *gin.Engine) {
|
||||
{
|
||||
userRoute.POST("/register", middleware.CriticalRateLimit(), middleware.TurnstileCheck(), controller.Register)
|
||||
userRoute.POST("/login", middleware.CriticalRateLimit(), middleware.TurnstileCheck(), controller.Login)
|
||||
userRoute.POST("/login/2fa", middleware.CriticalRateLimit(), controller.Verify2FALogin)
|
||||
//userRoute.POST("/tokenlog", middleware.CriticalRateLimit(), controller.TokenLog)
|
||||
userRoute.GET("/logout", controller.Logout)
|
||||
userRoute.GET("/epay/notify", controller.EpayNotify)
|
||||
@@ -66,6 +67,13 @@ func SetApiRouter(router *gin.Engine) {
|
||||
selfRoute.POST("/stripe/amount", controller.RequestStripeAmount)
|
||||
selfRoute.POST("/aff_transfer", controller.TransferAffQuota)
|
||||
selfRoute.PUT("/setting", controller.UpdateUserSetting)
|
||||
|
||||
// 2FA routes
|
||||
selfRoute.GET("/2fa/status", controller.Get2FAStatus)
|
||||
selfRoute.POST("/2fa/setup", controller.Setup2FA)
|
||||
selfRoute.POST("/2fa/enable", controller.Enable2FA)
|
||||
selfRoute.POST("/2fa/disable", controller.Disable2FA)
|
||||
selfRoute.POST("/2fa/backup_codes", controller.RegenerateBackupCodes)
|
||||
}
|
||||
|
||||
adminRoute := userRoute.Group("/")
|
||||
@@ -78,6 +86,10 @@ func SetApiRouter(router *gin.Engine) {
|
||||
adminRoute.POST("/manage", controller.ManageUser)
|
||||
adminRoute.PUT("/", controller.UpdateUser)
|
||||
adminRoute.DELETE("/:id", controller.DeleteUser)
|
||||
|
||||
// Admin 2FA routes
|
||||
adminRoute.GET("/2fa/stats", controller.Admin2FAStats)
|
||||
adminRoute.DELETE("/:id/2fa", controller.AdminDisable2FA)
|
||||
}
|
||||
}
|
||||
optionRoute := apiRouter.Group("/option")
|
||||
@@ -120,6 +132,7 @@ func SetApiRouter(router *gin.Engine) {
|
||||
channelRoute.POST("/batch/tag", controller.BatchSetChannelTag)
|
||||
channelRoute.GET("/tag/models", controller.GetTagModels)
|
||||
channelRoute.POST("/copy/:id", controller.CopyChannel)
|
||||
channelRoute.POST("/multi_key/manage", controller.ManageMultiKeys)
|
||||
}
|
||||
tokenRoute := apiRouter.Group("/token")
|
||||
tokenRoute.Use(middleware.UserAuth())
|
||||
|
||||
@@ -45,7 +45,7 @@ func ShouldDisableChannel(channelType int, err *types.NewAPIError) bool {
|
||||
if types.IsChannelError(err) {
|
||||
return true
|
||||
}
|
||||
if types.IsLocalError(err) {
|
||||
if types.IsSkipRetryError(err) {
|
||||
return false
|
||||
}
|
||||
if err.StatusCode == http.StatusUnauthorized {
|
||||
|
||||
@@ -188,28 +188,6 @@ func ClaudeToOpenAIRequest(claudeRequest dto.ClaudeRequest, info *relaycommon.Re
|
||||
return &openAIRequest, nil
|
||||
}
|
||||
|
||||
func OpenAIErrorToClaudeError(openAIError *dto.OpenAIErrorWithStatusCode) *dto.ClaudeErrorWithStatusCode {
|
||||
claudeError := dto.ClaudeError{
|
||||
Type: "new_api_error",
|
||||
Message: openAIError.Error.Message,
|
||||
}
|
||||
return &dto.ClaudeErrorWithStatusCode{
|
||||
Error: claudeError,
|
||||
StatusCode: openAIError.StatusCode,
|
||||
}
|
||||
}
|
||||
|
||||
func ClaudeErrorToOpenAIError(claudeError *dto.ClaudeErrorWithStatusCode) *dto.OpenAIErrorWithStatusCode {
|
||||
openAIError := dto.OpenAIError{
|
||||
Message: claudeError.Error.Message,
|
||||
Type: "new_api_error",
|
||||
}
|
||||
return &dto.OpenAIErrorWithStatusCode{
|
||||
Error: openAIError,
|
||||
StatusCode: claudeError.StatusCode,
|
||||
}
|
||||
}
|
||||
|
||||
func generateStopBlock(index int) *dto.ClaudeResponse {
|
||||
return &dto.ClaudeResponse{
|
||||
Type: "content_block_stop",
|
||||
@@ -251,22 +229,54 @@ func StreamResponseOpenAI2Claude(openAIResponse *dto.ChatCompletionsStreamRespon
|
||||
resp.SetIndex(0)
|
||||
claudeResponses = append(claudeResponses, resp)
|
||||
} else {
|
||||
//resp := &dto.ClaudeResponse{
|
||||
// Type: "content_block_start",
|
||||
// ContentBlock: &dto.ClaudeMediaMessage{
|
||||
// Type: "text",
|
||||
// Text: common.GetPointer[string](""),
|
||||
// },
|
||||
//}
|
||||
//resp.SetIndex(0)
|
||||
//claudeResponses = append(claudeResponses, resp)
|
||||
|
||||
}
|
||||
// 判断首个响应是否存在内容(非标准的 OpenAI 响应)
|
||||
if len(openAIResponse.Choices) > 0 && len(openAIResponse.Choices[0].Delta.GetContentString()) > 0 {
|
||||
claudeResponses = append(claudeResponses, &dto.ClaudeResponse{
|
||||
Index: &info.ClaudeConvertInfo.Index,
|
||||
Type: "content_block_start",
|
||||
ContentBlock: &dto.ClaudeMediaMessage{
|
||||
Type: "text",
|
||||
Text: common.GetPointer[string](""),
|
||||
},
|
||||
})
|
||||
claudeResponses = append(claudeResponses, &dto.ClaudeResponse{
|
||||
Type: "content_block_delta",
|
||||
Delta: &dto.ClaudeMediaMessage{
|
||||
Type: "text",
|
||||
Text: common.GetPointer[string](openAIResponse.Choices[0].Delta.GetContentString()),
|
||||
},
|
||||
})
|
||||
info.ClaudeConvertInfo.LastMessagesType = relaycommon.LastMessageTypeText
|
||||
}
|
||||
return claudeResponses
|
||||
}
|
||||
|
||||
if len(openAIResponse.Choices) == 0 {
|
||||
// no choices
|
||||
// TODO: handle this case
|
||||
// 可能为非标准的 OpenAI 响应,判断是否已经完成
|
||||
if info.Done {
|
||||
claudeResponses = append(claudeResponses, generateStopBlock(info.ClaudeConvertInfo.Index))
|
||||
oaiUsage := info.ClaudeConvertInfo.Usage
|
||||
if oaiUsage != nil {
|
||||
claudeResponses = append(claudeResponses, &dto.ClaudeResponse{
|
||||
Type: "message_delta",
|
||||
Usage: &dto.ClaudeUsage{
|
||||
InputTokens: oaiUsage.PromptTokens,
|
||||
OutputTokens: oaiUsage.CompletionTokens,
|
||||
CacheCreationInputTokens: oaiUsage.PromptTokensDetails.CachedCreationTokens,
|
||||
CacheReadInputTokens: oaiUsage.PromptTokensDetails.CachedTokens,
|
||||
},
|
||||
Delta: &dto.ClaudeMediaMessage{
|
||||
StopReason: common.GetPointer[string](stopReasonOpenAI2Claude(info.FinishReason)),
|
||||
},
|
||||
})
|
||||
}
|
||||
claudeResponses = append(claudeResponses, &dto.ClaudeResponse{
|
||||
Type: "message_stop",
|
||||
})
|
||||
}
|
||||
return claudeResponses
|
||||
} else {
|
||||
chosenChoice := openAIResponse.Choices[0]
|
||||
@@ -438,3 +448,353 @@ func toJSONString(v interface{}) string {
|
||||
}
|
||||
return string(b)
|
||||
}
|
||||
|
||||
func GeminiToOpenAIRequest(geminiRequest *dto.GeminiChatRequest, info *relaycommon.RelayInfo) (*dto.GeneralOpenAIRequest, error) {
|
||||
openaiRequest := &dto.GeneralOpenAIRequest{
|
||||
Model: info.UpstreamModelName,
|
||||
Stream: info.IsStream,
|
||||
}
|
||||
|
||||
// 转换 messages
|
||||
var messages []dto.Message
|
||||
for _, content := range geminiRequest.Contents {
|
||||
message := dto.Message{
|
||||
Role: convertGeminiRoleToOpenAI(content.Role),
|
||||
}
|
||||
|
||||
// 处理 parts
|
||||
var mediaContents []dto.MediaContent
|
||||
var toolCalls []dto.ToolCallRequest
|
||||
for _, part := range content.Parts {
|
||||
if part.Text != "" {
|
||||
mediaContent := dto.MediaContent{
|
||||
Type: "text",
|
||||
Text: part.Text,
|
||||
}
|
||||
mediaContents = append(mediaContents, mediaContent)
|
||||
} else if part.InlineData != nil {
|
||||
mediaContent := dto.MediaContent{
|
||||
Type: "image_url",
|
||||
ImageUrl: &dto.MessageImageUrl{
|
||||
Url: fmt.Sprintf("data:%s;base64,%s", part.InlineData.MimeType, part.InlineData.Data),
|
||||
Detail: "auto",
|
||||
MimeType: part.InlineData.MimeType,
|
||||
},
|
||||
}
|
||||
mediaContents = append(mediaContents, mediaContent)
|
||||
} else if part.FileData != nil {
|
||||
mediaContent := dto.MediaContent{
|
||||
Type: "image_url",
|
||||
ImageUrl: &dto.MessageImageUrl{
|
||||
Url: part.FileData.FileUri,
|
||||
Detail: "auto",
|
||||
MimeType: part.FileData.MimeType,
|
||||
},
|
||||
}
|
||||
mediaContents = append(mediaContents, mediaContent)
|
||||
} else if part.FunctionCall != nil {
|
||||
// 处理 Gemini 的工具调用
|
||||
toolCall := dto.ToolCallRequest{
|
||||
ID: fmt.Sprintf("call_%d", len(toolCalls)+1), // 生成唯一ID
|
||||
Type: "function",
|
||||
Function: dto.FunctionRequest{
|
||||
Name: part.FunctionCall.FunctionName,
|
||||
Arguments: toJSONString(part.FunctionCall.Arguments),
|
||||
},
|
||||
}
|
||||
toolCalls = append(toolCalls, toolCall)
|
||||
} else if part.FunctionResponse != nil {
|
||||
// 处理 Gemini 的工具响应,创建单独的 tool 消息
|
||||
toolMessage := dto.Message{
|
||||
Role: "tool",
|
||||
ToolCallId: fmt.Sprintf("call_%d", len(toolCalls)), // 使用对应的调用ID
|
||||
}
|
||||
toolMessage.SetStringContent(toJSONString(part.FunctionResponse.Response))
|
||||
messages = append(messages, toolMessage)
|
||||
}
|
||||
}
|
||||
|
||||
// 设置消息内容
|
||||
if len(toolCalls) > 0 {
|
||||
// 如果有工具调用,设置工具调用
|
||||
message.SetToolCalls(toolCalls)
|
||||
} else if len(mediaContents) == 1 && mediaContents[0].Type == "text" {
|
||||
// 如果只有一个文本内容,直接设置字符串
|
||||
message.Content = mediaContents[0].Text
|
||||
} else if len(mediaContents) > 0 {
|
||||
// 如果有多个内容或包含媒体,设置为数组
|
||||
message.SetMediaContent(mediaContents)
|
||||
}
|
||||
|
||||
// 只有当消息有内容或工具调用时才添加
|
||||
if len(message.ParseContent()) > 0 || len(message.ToolCalls) > 0 {
|
||||
messages = append(messages, message)
|
||||
}
|
||||
}
|
||||
|
||||
openaiRequest.Messages = messages
|
||||
|
||||
if geminiRequest.GenerationConfig.Temperature != nil {
|
||||
openaiRequest.Temperature = geminiRequest.GenerationConfig.Temperature
|
||||
}
|
||||
if geminiRequest.GenerationConfig.TopP > 0 {
|
||||
openaiRequest.TopP = geminiRequest.GenerationConfig.TopP
|
||||
}
|
||||
if geminiRequest.GenerationConfig.TopK > 0 {
|
||||
openaiRequest.TopK = int(geminiRequest.GenerationConfig.TopK)
|
||||
}
|
||||
if geminiRequest.GenerationConfig.MaxOutputTokens > 0 {
|
||||
openaiRequest.MaxTokens = geminiRequest.GenerationConfig.MaxOutputTokens
|
||||
}
|
||||
// gemini stop sequences 最多 5 个,openai stop 最多 4 个
|
||||
if len(geminiRequest.GenerationConfig.StopSequences) > 0 {
|
||||
openaiRequest.Stop = geminiRequest.GenerationConfig.StopSequences[:4]
|
||||
}
|
||||
if geminiRequest.GenerationConfig.CandidateCount > 0 {
|
||||
openaiRequest.N = geminiRequest.GenerationConfig.CandidateCount
|
||||
}
|
||||
|
||||
// 转换工具调用
|
||||
if len(geminiRequest.Tools) > 0 {
|
||||
var tools []dto.ToolCallRequest
|
||||
for _, tool := range geminiRequest.Tools {
|
||||
if tool.FunctionDeclarations != nil {
|
||||
// 将 Gemini 的 FunctionDeclarations 转换为 OpenAI 的 ToolCallRequest
|
||||
functionDeclarations, ok := tool.FunctionDeclarations.([]dto.FunctionRequest)
|
||||
if ok {
|
||||
for _, function := range functionDeclarations {
|
||||
openAITool := dto.ToolCallRequest{
|
||||
Type: "function",
|
||||
Function: dto.FunctionRequest{
|
||||
Name: function.Name,
|
||||
Description: function.Description,
|
||||
Parameters: function.Parameters,
|
||||
},
|
||||
}
|
||||
tools = append(tools, openAITool)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if len(tools) > 0 {
|
||||
openaiRequest.Tools = tools
|
||||
}
|
||||
}
|
||||
|
||||
// gemini system instructions
|
||||
if geminiRequest.SystemInstructions != nil {
|
||||
// 将系统指令作为第一条消息插入
|
||||
systemMessage := dto.Message{
|
||||
Role: "system",
|
||||
Content: extractTextFromGeminiParts(geminiRequest.SystemInstructions.Parts),
|
||||
}
|
||||
openaiRequest.Messages = append([]dto.Message{systemMessage}, openaiRequest.Messages...)
|
||||
}
|
||||
|
||||
return openaiRequest, nil
|
||||
}
|
||||
|
||||
func convertGeminiRoleToOpenAI(geminiRole string) string {
|
||||
switch geminiRole {
|
||||
case "user":
|
||||
return "user"
|
||||
case "model":
|
||||
return "assistant"
|
||||
case "function":
|
||||
return "function"
|
||||
default:
|
||||
return "user"
|
||||
}
|
||||
}
|
||||
|
||||
func extractTextFromGeminiParts(parts []dto.GeminiPart) string {
|
||||
var texts []string
|
||||
for _, part := range parts {
|
||||
if part.Text != "" {
|
||||
texts = append(texts, part.Text)
|
||||
}
|
||||
}
|
||||
return strings.Join(texts, "\n")
|
||||
}
|
||||
|
||||
// ResponseOpenAI2Gemini 将 OpenAI 响应转换为 Gemini 格式
|
||||
func ResponseOpenAI2Gemini(openAIResponse *dto.OpenAITextResponse, info *relaycommon.RelayInfo) *dto.GeminiChatResponse {
|
||||
geminiResponse := &dto.GeminiChatResponse{
|
||||
Candidates: make([]dto.GeminiChatCandidate, 0, len(openAIResponse.Choices)),
|
||||
PromptFeedback: dto.GeminiChatPromptFeedback{
|
||||
SafetyRatings: []dto.GeminiChatSafetyRating{},
|
||||
},
|
||||
UsageMetadata: dto.GeminiUsageMetadata{
|
||||
PromptTokenCount: openAIResponse.PromptTokens,
|
||||
CandidatesTokenCount: openAIResponse.CompletionTokens,
|
||||
TotalTokenCount: openAIResponse.PromptTokens + openAIResponse.CompletionTokens,
|
||||
},
|
||||
}
|
||||
|
||||
for _, choice := range openAIResponse.Choices {
|
||||
candidate := dto.GeminiChatCandidate{
|
||||
Index: int64(choice.Index),
|
||||
SafetyRatings: []dto.GeminiChatSafetyRating{},
|
||||
}
|
||||
|
||||
// 设置结束原因
|
||||
var finishReason string
|
||||
switch choice.FinishReason {
|
||||
case "stop":
|
||||
finishReason = "STOP"
|
||||
case "length":
|
||||
finishReason = "MAX_TOKENS"
|
||||
case "content_filter":
|
||||
finishReason = "SAFETY"
|
||||
case "tool_calls":
|
||||
finishReason = "STOP"
|
||||
default:
|
||||
finishReason = "STOP"
|
||||
}
|
||||
candidate.FinishReason = &finishReason
|
||||
|
||||
// 转换消息内容
|
||||
content := dto.GeminiChatContent{
|
||||
Role: "model",
|
||||
Parts: make([]dto.GeminiPart, 0),
|
||||
}
|
||||
|
||||
// 处理工具调用
|
||||
toolCalls := choice.Message.ParseToolCalls()
|
||||
if len(toolCalls) > 0 {
|
||||
for _, toolCall := range toolCalls {
|
||||
// 解析参数
|
||||
var args map[string]interface{}
|
||||
if toolCall.Function.Arguments != "" {
|
||||
if err := json.Unmarshal([]byte(toolCall.Function.Arguments), &args); err != nil {
|
||||
args = map[string]interface{}{"arguments": toolCall.Function.Arguments}
|
||||
}
|
||||
} else {
|
||||
args = make(map[string]interface{})
|
||||
}
|
||||
|
||||
part := dto.GeminiPart{
|
||||
FunctionCall: &dto.FunctionCall{
|
||||
FunctionName: toolCall.Function.Name,
|
||||
Arguments: args,
|
||||
},
|
||||
}
|
||||
content.Parts = append(content.Parts, part)
|
||||
}
|
||||
} else {
|
||||
// 处理文本内容
|
||||
textContent := choice.Message.StringContent()
|
||||
if textContent != "" {
|
||||
part := dto.GeminiPart{
|
||||
Text: textContent,
|
||||
}
|
||||
content.Parts = append(content.Parts, part)
|
||||
}
|
||||
}
|
||||
|
||||
candidate.Content = content
|
||||
geminiResponse.Candidates = append(geminiResponse.Candidates, candidate)
|
||||
}
|
||||
|
||||
return geminiResponse
|
||||
}
|
||||
|
||||
// StreamResponseOpenAI2Gemini 将 OpenAI 流式响应转换为 Gemini 格式
|
||||
func StreamResponseOpenAI2Gemini(openAIResponse *dto.ChatCompletionsStreamResponse, info *relaycommon.RelayInfo) *dto.GeminiChatResponse {
|
||||
// 检查是否有实际内容或结束标志
|
||||
hasContent := false
|
||||
hasFinishReason := false
|
||||
for _, choice := range openAIResponse.Choices {
|
||||
if len(choice.Delta.GetContentString()) > 0 || (choice.Delta.ToolCalls != nil && len(choice.Delta.ToolCalls) > 0) {
|
||||
hasContent = true
|
||||
}
|
||||
if choice.FinishReason != nil {
|
||||
hasFinishReason = true
|
||||
}
|
||||
}
|
||||
|
||||
// 如果没有实际内容且没有结束标志,跳过。主要针对 openai 流响应开头的空数据
|
||||
if !hasContent && !hasFinishReason {
|
||||
return nil
|
||||
}
|
||||
|
||||
geminiResponse := &dto.GeminiChatResponse{
|
||||
Candidates: make([]dto.GeminiChatCandidate, 0, len(openAIResponse.Choices)),
|
||||
PromptFeedback: dto.GeminiChatPromptFeedback{
|
||||
SafetyRatings: []dto.GeminiChatSafetyRating{},
|
||||
},
|
||||
UsageMetadata: dto.GeminiUsageMetadata{
|
||||
PromptTokenCount: info.PromptTokens,
|
||||
CandidatesTokenCount: 0, // 流式响应中可能没有完整的 usage 信息
|
||||
TotalTokenCount: info.PromptTokens,
|
||||
},
|
||||
}
|
||||
|
||||
for _, choice := range openAIResponse.Choices {
|
||||
candidate := dto.GeminiChatCandidate{
|
||||
Index: int64(choice.Index),
|
||||
SafetyRatings: []dto.GeminiChatSafetyRating{},
|
||||
}
|
||||
|
||||
// 设置结束原因
|
||||
if choice.FinishReason != nil {
|
||||
var finishReason string
|
||||
switch *choice.FinishReason {
|
||||
case "stop":
|
||||
finishReason = "STOP"
|
||||
case "length":
|
||||
finishReason = "MAX_TOKENS"
|
||||
case "content_filter":
|
||||
finishReason = "SAFETY"
|
||||
case "tool_calls":
|
||||
finishReason = "STOP"
|
||||
default:
|
||||
finishReason = "STOP"
|
||||
}
|
||||
candidate.FinishReason = &finishReason
|
||||
}
|
||||
|
||||
// 转换消息内容
|
||||
content := dto.GeminiChatContent{
|
||||
Role: "model",
|
||||
Parts: make([]dto.GeminiPart, 0),
|
||||
}
|
||||
|
||||
// 处理工具调用
|
||||
if choice.Delta.ToolCalls != nil {
|
||||
for _, toolCall := range choice.Delta.ToolCalls {
|
||||
// 解析参数
|
||||
var args map[string]interface{}
|
||||
if toolCall.Function.Arguments != "" {
|
||||
if err := json.Unmarshal([]byte(toolCall.Function.Arguments), &args); err != nil {
|
||||
args = map[string]interface{}{"arguments": toolCall.Function.Arguments}
|
||||
}
|
||||
} else {
|
||||
args = make(map[string]interface{})
|
||||
}
|
||||
|
||||
part := dto.GeminiPart{
|
||||
FunctionCall: &dto.FunctionCall{
|
||||
FunctionName: toolCall.Function.Name,
|
||||
Arguments: args,
|
||||
},
|
||||
}
|
||||
content.Parts = append(content.Parts, part)
|
||||
}
|
||||
} else {
|
||||
// 处理文本内容
|
||||
textContent := choice.Delta.GetContentString()
|
||||
if textContent != "" {
|
||||
part := dto.GeminiPart{
|
||||
Text: textContent,
|
||||
}
|
||||
content.Parts = append(content.Parts, part)
|
||||
}
|
||||
}
|
||||
|
||||
candidate.Content = content
|
||||
geminiResponse.Candidates = append(geminiResponse.Candidates, candidate)
|
||||
}
|
||||
|
||||
return geminiResponse
|
||||
}
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
@@ -63,7 +62,7 @@ func ClaudeErrorWrapper(err error, code string, statusCode int) *dto.ClaudeError
|
||||
text = "请求上游地址失败"
|
||||
}
|
||||
}
|
||||
claudeError := dto.ClaudeError{
|
||||
claudeError := types.ClaudeError{
|
||||
Message: text,
|
||||
Type: "new_api_error",
|
||||
}
|
||||
@@ -80,10 +79,7 @@ func ClaudeErrorWrapperLocal(err error, code string, statusCode int) *dto.Claude
|
||||
}
|
||||
|
||||
func RelayErrorHandler(resp *http.Response, showBodyWhenFail bool) (newApiErr *types.NewAPIError) {
|
||||
newApiErr = &types.NewAPIError{
|
||||
StatusCode: resp.StatusCode,
|
||||
ErrorType: types.ErrorTypeOpenAIError,
|
||||
}
|
||||
newApiErr = types.InitOpenAIError(types.ErrorCodeBadResponseStatusCode, resp.StatusCode)
|
||||
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
@@ -105,8 +101,7 @@ func RelayErrorHandler(resp *http.Response, showBodyWhenFail bool) (newApiErr *t
|
||||
// General format error (OpenAI, Anthropic, Gemini, etc.)
|
||||
newApiErr = types.WithOpenAIError(errResponse.Error, resp.StatusCode)
|
||||
} else {
|
||||
newApiErr = types.NewErrorWithStatusCode(errors.New(errResponse.ToMessage()), types.ErrorCodeBadResponseStatusCode, resp.StatusCode)
|
||||
newApiErr.ErrorType = types.ErrorTypeOpenAIError
|
||||
newApiErr = types.NewOpenAIError(errors.New(errResponse.ToMessage()), types.ErrorCodeBadResponseStatusCode, resp.StatusCode)
|
||||
}
|
||||
return
|
||||
}
|
||||
@@ -116,7 +111,7 @@ func ResetStatusCode(newApiErr *types.NewAPIError, statusCodeMappingStr string)
|
||||
return
|
||||
}
|
||||
statusCodeMapping := make(map[string]string)
|
||||
err := json.Unmarshal([]byte(statusCodeMappingStr), &statusCodeMapping)
|
||||
err := common.Unmarshal([]byte(statusCodeMappingStr), &statusCodeMapping)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
@@ -3,6 +3,7 @@ package setting
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"math"
|
||||
"one-api/common"
|
||||
"sync"
|
||||
)
|
||||
@@ -58,6 +59,9 @@ func CheckModelRequestRateLimitGroup(jsonStr string) error {
|
||||
if limits[0] < 0 || limits[1] < 1 {
|
||||
return fmt.Errorf("group %s has negative rate limit values: [%d, %d]", group, limits[0], limits[1])
|
||||
}
|
||||
if limits[0] > math.MaxInt32 || limits[1] > math.MaxInt32 {
|
||||
return fmt.Errorf("group %s [%d, %d] has max rate limits value 2147483647", group, limits[0], limits[1])
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
|
||||
147
types/error.go
147
types/error.go
@@ -4,6 +4,7 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"strings"
|
||||
)
|
||||
|
||||
@@ -15,8 +16,8 @@ type OpenAIError struct {
|
||||
}
|
||||
|
||||
type ClaudeError struct {
|
||||
Message string `json:"message,omitempty"`
|
||||
Type string `json:"type,omitempty"`
|
||||
Message string `json:"message,omitempty"`
|
||||
}
|
||||
|
||||
type ErrorType string
|
||||
@@ -28,6 +29,7 @@ const (
|
||||
ErrorTypeMidjourneyError ErrorType = "midjourney_error"
|
||||
ErrorTypeGeminiError ErrorType = "gemini_error"
|
||||
ErrorTypeRerankError ErrorType = "rerank_error"
|
||||
ErrorTypeUpstreamError ErrorType = "upstream_error"
|
||||
)
|
||||
|
||||
type ErrorCode string
|
||||
@@ -62,6 +64,7 @@ const (
|
||||
ErrorCodeBadResponseStatusCode ErrorCode = "bad_response_status_code"
|
||||
ErrorCodeBadResponse ErrorCode = "bad_response"
|
||||
ErrorCodeBadResponseBody ErrorCode = "bad_response_body"
|
||||
ErrorCodeEmptyResponse ErrorCode = "empty_response"
|
||||
|
||||
// sql error
|
||||
ErrorCodeQueryDataError ErrorCode = "query_data_error"
|
||||
@@ -73,11 +76,13 @@ const (
|
||||
)
|
||||
|
||||
type NewAPIError struct {
|
||||
Err error
|
||||
RelayError any
|
||||
ErrorType ErrorType
|
||||
errorCode ErrorCode
|
||||
StatusCode int
|
||||
Err error
|
||||
RelayError any
|
||||
skipRetry bool
|
||||
recordErrorLog *bool
|
||||
errorType ErrorType
|
||||
errorCode ErrorCode
|
||||
StatusCode int
|
||||
}
|
||||
|
||||
func (e *NewAPIError) GetErrorCode() ErrorCode {
|
||||
@@ -87,6 +92,13 @@ func (e *NewAPIError) GetErrorCode() ErrorCode {
|
||||
return e.errorCode
|
||||
}
|
||||
|
||||
func (e *NewAPIError) GetErrorType() ErrorType {
|
||||
if e == nil {
|
||||
return ""
|
||||
}
|
||||
return e.errorType
|
||||
}
|
||||
|
||||
func (e *NewAPIError) Error() string {
|
||||
if e == nil {
|
||||
return ""
|
||||
@@ -98,19 +110,30 @@ func (e *NewAPIError) Error() string {
|
||||
return e.Err.Error()
|
||||
}
|
||||
|
||||
func (e *NewAPIError) MaskSensitiveError() string {
|
||||
if e == nil {
|
||||
return ""
|
||||
}
|
||||
if e.Err == nil {
|
||||
return string(e.errorCode)
|
||||
}
|
||||
return common.MaskSensitiveInfo(e.Err.Error())
|
||||
}
|
||||
|
||||
func (e *NewAPIError) SetMessage(message string) {
|
||||
e.Err = errors.New(message)
|
||||
}
|
||||
|
||||
func (e *NewAPIError) ToOpenAIError() OpenAIError {
|
||||
switch e.ErrorType {
|
||||
var result OpenAIError
|
||||
switch e.errorType {
|
||||
case ErrorTypeOpenAIError:
|
||||
if openAIError, ok := e.RelayError.(OpenAIError); ok {
|
||||
return openAIError
|
||||
result = openAIError
|
||||
}
|
||||
case ErrorTypeClaudeError:
|
||||
if claudeError, ok := e.RelayError.(ClaudeError); ok {
|
||||
return OpenAIError{
|
||||
result = OpenAIError{
|
||||
Message: e.Error(),
|
||||
Type: claudeError.Type,
|
||||
Param: "",
|
||||
@@ -118,85 +141,122 @@ func (e *NewAPIError) ToOpenAIError() OpenAIError {
|
||||
}
|
||||
}
|
||||
}
|
||||
return OpenAIError{
|
||||
result = OpenAIError{
|
||||
Message: e.Error(),
|
||||
Type: string(e.ErrorType),
|
||||
Type: string(e.errorType),
|
||||
Param: "",
|
||||
Code: e.errorCode,
|
||||
}
|
||||
result.Message = common.MaskSensitiveInfo(result.Message)
|
||||
return result
|
||||
}
|
||||
|
||||
func (e *NewAPIError) ToClaudeError() ClaudeError {
|
||||
switch e.ErrorType {
|
||||
var result ClaudeError
|
||||
switch e.errorType {
|
||||
case ErrorTypeOpenAIError:
|
||||
openAIError := e.RelayError.(OpenAIError)
|
||||
return ClaudeError{
|
||||
result = ClaudeError{
|
||||
Message: e.Error(),
|
||||
Type: fmt.Sprintf("%v", openAIError.Code),
|
||||
}
|
||||
case ErrorTypeClaudeError:
|
||||
return e.RelayError.(ClaudeError)
|
||||
result = e.RelayError.(ClaudeError)
|
||||
default:
|
||||
return ClaudeError{
|
||||
result = ClaudeError{
|
||||
Message: e.Error(),
|
||||
Type: string(e.ErrorType),
|
||||
Type: string(e.errorType),
|
||||
}
|
||||
}
|
||||
result.Message = common.MaskSensitiveInfo(result.Message)
|
||||
return result
|
||||
}
|
||||
|
||||
func NewError(err error, errorCode ErrorCode) *NewAPIError {
|
||||
return &NewAPIError{
|
||||
type NewAPIErrorOptions func(*NewAPIError)
|
||||
|
||||
func NewError(err error, errorCode ErrorCode, ops ...NewAPIErrorOptions) *NewAPIError {
|
||||
e := &NewAPIError{
|
||||
Err: err,
|
||||
RelayError: nil,
|
||||
ErrorType: ErrorTypeNewAPIError,
|
||||
errorType: ErrorTypeNewAPIError,
|
||||
StatusCode: http.StatusInternalServerError,
|
||||
errorCode: errorCode,
|
||||
}
|
||||
for _, op := range ops {
|
||||
op(e)
|
||||
}
|
||||
return e
|
||||
}
|
||||
|
||||
func NewOpenAIError(err error, errorCode ErrorCode, statusCode int) *NewAPIError {
|
||||
func NewOpenAIError(err error, errorCode ErrorCode, statusCode int, ops ...NewAPIErrorOptions) *NewAPIError {
|
||||
openaiError := OpenAIError{
|
||||
Message: err.Error(),
|
||||
Type: string(errorCode),
|
||||
}
|
||||
return WithOpenAIError(openaiError, statusCode)
|
||||
return WithOpenAIError(openaiError, statusCode, ops...)
|
||||
}
|
||||
|
||||
func NewErrorWithStatusCode(err error, errorCode ErrorCode, statusCode int) *NewAPIError {
|
||||
return &NewAPIError{
|
||||
func InitOpenAIError(errorCode ErrorCode, statusCode int, ops ...NewAPIErrorOptions) *NewAPIError {
|
||||
openaiError := OpenAIError{
|
||||
Type: string(errorCode),
|
||||
}
|
||||
return WithOpenAIError(openaiError, statusCode, ops...)
|
||||
}
|
||||
|
||||
func NewErrorWithStatusCode(err error, errorCode ErrorCode, statusCode int, ops ...NewAPIErrorOptions) *NewAPIError {
|
||||
e := &NewAPIError{
|
||||
Err: err,
|
||||
RelayError: OpenAIError{
|
||||
Message: err.Error(),
|
||||
Type: string(errorCode),
|
||||
},
|
||||
ErrorType: ErrorTypeNewAPIError,
|
||||
errorType: ErrorTypeNewAPIError,
|
||||
StatusCode: statusCode,
|
||||
errorCode: errorCode,
|
||||
}
|
||||
for _, op := range ops {
|
||||
op(e)
|
||||
}
|
||||
|
||||
return e
|
||||
}
|
||||
|
||||
func WithOpenAIError(openAIError OpenAIError, statusCode int) *NewAPIError {
|
||||
func WithOpenAIError(openAIError OpenAIError, statusCode int, ops ...NewAPIErrorOptions) *NewAPIError {
|
||||
code, ok := openAIError.Code.(string)
|
||||
if !ok {
|
||||
code = fmt.Sprintf("%v", openAIError.Code)
|
||||
}
|
||||
return &NewAPIError{
|
||||
if openAIError.Type == "" {
|
||||
openAIError.Type = "upstream_error"
|
||||
}
|
||||
e := &NewAPIError{
|
||||
RelayError: openAIError,
|
||||
ErrorType: ErrorTypeOpenAIError,
|
||||
errorType: ErrorTypeOpenAIError,
|
||||
StatusCode: statusCode,
|
||||
Err: errors.New(openAIError.Message),
|
||||
errorCode: ErrorCode(code),
|
||||
}
|
||||
for _, op := range ops {
|
||||
op(e)
|
||||
}
|
||||
return e
|
||||
}
|
||||
|
||||
func WithClaudeError(claudeError ClaudeError, statusCode int) *NewAPIError {
|
||||
return &NewAPIError{
|
||||
func WithClaudeError(claudeError ClaudeError, statusCode int, ops ...NewAPIErrorOptions) *NewAPIError {
|
||||
if claudeError.Type == "" {
|
||||
claudeError.Type = "upstream_error"
|
||||
}
|
||||
e := &NewAPIError{
|
||||
RelayError: claudeError,
|
||||
ErrorType: ErrorTypeClaudeError,
|
||||
errorType: ErrorTypeClaudeError,
|
||||
StatusCode: statusCode,
|
||||
Err: errors.New(claudeError.Message),
|
||||
errorCode: ErrorCode(claudeError.Type),
|
||||
}
|
||||
for _, op := range ops {
|
||||
op(e)
|
||||
}
|
||||
return e
|
||||
}
|
||||
|
||||
func IsChannelError(err *NewAPIError) bool {
|
||||
@@ -206,10 +266,33 @@ func IsChannelError(err *NewAPIError) bool {
|
||||
return strings.HasPrefix(string(err.errorCode), "channel:")
|
||||
}
|
||||
|
||||
func IsLocalError(err *NewAPIError) bool {
|
||||
func IsSkipRetryError(err *NewAPIError) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
return err.ErrorType == ErrorTypeNewAPIError
|
||||
return err.skipRetry
|
||||
}
|
||||
|
||||
func ErrOptionWithSkipRetry() NewAPIErrorOptions {
|
||||
return func(e *NewAPIError) {
|
||||
e.skipRetry = true
|
||||
}
|
||||
}
|
||||
|
||||
func ErrOptionWithNoRecordErrorLog() NewAPIErrorOptions {
|
||||
return func(e *NewAPIError) {
|
||||
e.recordErrorLog = common.GetPointer(false)
|
||||
}
|
||||
}
|
||||
|
||||
func IsRecordErrorLog(e *NewAPIError) bool {
|
||||
if e == nil {
|
||||
return false
|
||||
}
|
||||
if e.recordErrorLog == nil {
|
||||
// default to true if not set
|
||||
return true
|
||||
}
|
||||
return *e.recordErrorLog
|
||||
}
|
||||
|
||||
@@ -21,6 +21,7 @@
|
||||
"lucide-react": "^0.511.0",
|
||||
"marked": "^4.1.1",
|
||||
"mermaid": "^11.6.0",
|
||||
"qrcode.react": "^4.2.0",
|
||||
"react": "^18.2.0",
|
||||
"react-dom": "^18.2.0",
|
||||
"react-dropzone": "^14.2.3",
|
||||
@@ -1492,6 +1493,8 @@
|
||||
|
||||
"punycode": ["punycode@2.3.1", "", {}, "sha512-vYt7UD1U9Wg6138shLtLOvdAu+8DsC/ilFtEVHcH+wydcSpNE20AfSOduf6MkRFahL5FY7X1oU7nKVZFtfq8Fg=="],
|
||||
|
||||
"qrcode.react": ["qrcode.react@4.2.0", "", { "peerDependencies": { "react": "^16.8.0 || ^17.0.0 || ^18.0.0 || ^19.0.0" } }, "sha512-QpgqWi8rD9DsS9EP3z7BT+5lY5SFhsqGjpgW5DY/i3mK4M9DTBNz3ErMi8BWYEfI3L0d8GIbGmcdFAS1uIRGjA=="],
|
||||
|
||||
"quansync": ["quansync@0.2.10", "", {}, "sha512-t41VRkMYbkHyCYmOvx/6URnN80H7k4X0lLdBMGsz+maAwrJQYB1djpV6vHrQIBE0WBSGqhtEHrK9U3DWWH8v7A=="],
|
||||
|
||||
"query-string": ["query-string@9.2.0", "", { "dependencies": { "decode-uri-component": "^0.4.1", "filter-obj": "^5.1.0", "split-on-first": "^3.0.0" } }, "sha512-YIRhrHujoQxhexwRLxfy3VSjOXmvZRd2nyw1PwL1UUqZ/ys1dEZd1+NSgXkne2l/4X/7OXkigEAuhTX0g/ivJQ=="],
|
||||
@@ -1502,7 +1505,7 @@
|
||||
|
||||
"rc-checkbox": ["rc-checkbox@3.5.0", "", { "dependencies": { "@babel/runtime": "^7.10.1", "classnames": "^2.3.2", "rc-util": "^5.25.2" }, "peerDependencies": { "react": ">=16.9.0", "react-dom": ">=16.9.0" } }, "sha512-aOAQc3E98HteIIsSqm6Xk2FPKIER6+5vyEFMZfo73TqM+VVAIqOkHoPjgKLqSNtVLWScoaM7vY2ZrGEheI79yg=="],
|
||||
|
||||
"rc-collapse": ["rc-collapse@3.9.0", "", { "dependencies": { "@babel/runtime": "^7.10.1", "classnames": "2.x", "rc-motion": "^2.3.4", "rc-util": "^5.27.0" }, "peerDependencies": { "react": ">=16.9.0", "react-dom": ">=16.9.0" } }, "sha512-swDdz4QZ4dFTo4RAUMLL50qP0EY62N2kvmk2We5xYdRwcRn8WcYtuetCJpwpaCbUfUt5+huLpVxhvmnK+PHrkA=="],
|
||||
"rc-collapse": ["rc-collapse@4.0.0", "", { "dependencies": { "@babel/runtime": "^7.10.1", "classnames": "2.x", "rc-motion": "^2.3.4", "rc-util": "^5.27.0" }, "peerDependencies": { "react": ">=16.9.0", "react-dom": ">=16.9.0" } }, "sha512-SwoOByE39/3oIokDs/BnkqI+ltwirZbP8HZdq1/3SkPSBi7xDdvWHTp7cpNI9ullozkR6mwTWQi6/E/9huQVrA=="],
|
||||
|
||||
"rc-dialog": ["rc-dialog@9.6.0", "", { "dependencies": { "@babel/runtime": "^7.10.1", "@rc-component/portal": "^1.0.0-8", "classnames": "^2.2.6", "rc-motion": "^2.3.0", "rc-util": "^5.21.0" }, "peerDependencies": { "react": ">=16.9.0", "react-dom": ">=16.9.0" } }, "sha512-ApoVi9Z8PaCQg6FsUzS8yvBEQy0ZL2PkuvAgrmohPkN3okps5WZ5WQWPc1RNuiOKaAYv8B97ACdsFU5LizzCqg=="],
|
||||
|
||||
@@ -1946,8 +1949,6 @@
|
||||
|
||||
"@lobehub/ui/lucide-react": ["lucide-react@0.484.0", "", { "peerDependencies": { "react": "^16.5.1 || ^17.0.0 || ^18.0.0 || ^19.0.0" } }, "sha512-oZy8coK9kZzvqhSgfbGkPtTgyjpBvs3ukLgDPv14dSOZtBtboryWF5o8i3qen7QbGg7JhiJBz5mK1p8YoMZTLQ=="],
|
||||
|
||||
"@lobehub/ui/rc-collapse": ["rc-collapse@4.0.0", "", { "dependencies": { "@babel/runtime": "^7.10.1", "classnames": "2.x", "rc-motion": "^2.3.4", "rc-util": "^5.27.0" }, "peerDependencies": { "react": ">=16.9.0", "react-dom": ">=16.9.0" } }, "sha512-SwoOByE39/3oIokDs/BnkqI+ltwirZbP8HZdq1/3SkPSBi7xDdvWHTp7cpNI9ullozkR6mwTWQi6/E/9huQVrA=="],
|
||||
|
||||
"@radix-ui/react-dismissable-layer/@radix-ui/react-compose-refs": ["@radix-ui/react-compose-refs@1.0.0", "", { "dependencies": { "@babel/runtime": "^7.13.10" }, "peerDependencies": { "react": "^16.8 || ^17.0 || ^18.0" } }, "sha512-0KaSv6sx787/hK3eF53iOkiSLwAGlFMx5lotrqD2pTjB18KbybKoEIgkNZTKC60YECDQTKGTRcDBILwZVqVKvA=="],
|
||||
|
||||
"@radix-ui/react-popper/@floating-ui/react-dom": ["@floating-ui/react-dom@0.7.2", "", { "dependencies": { "@floating-ui/dom": "^0.5.3", "use-isomorphic-layout-effect": "^1.1.1" }, "peerDependencies": { "react": ">=16.8.0", "react-dom": ">=16.8.0" } }, "sha512-1T0sJcpHgX/u4I1OzIEhlcrvkUN8ln39nz7fMoE/2HDHrPiMFoOGR7++GYyfUmIQHkkrTinaeQsO3XWubjSvGg=="],
|
||||
@@ -1964,6 +1965,8 @@
|
||||
|
||||
"@visactor/vrender-kits/roughjs": ["roughjs@4.5.2", "", { "dependencies": { "path-data-parser": "^0.1.0", "points-on-curve": "^0.2.0", "points-on-path": "^0.2.1" } }, "sha512-2xSlLDKdsWyFxrveYWk9YQ/Y9UfK38EAMRNkYkMqYBJvPX8abCa9PN0x3w02H8Oa6/0bcZICJU+U95VumPqseg=="],
|
||||
|
||||
"antd/rc-collapse": ["rc-collapse@3.9.0", "", { "dependencies": { "@babel/runtime": "^7.10.1", "classnames": "2.x", "rc-motion": "^2.3.4", "rc-util": "^5.27.0" }, "peerDependencies": { "react": ">=16.9.0", "react-dom": ">=16.9.0" } }, "sha512-swDdz4QZ4dFTo4RAUMLL50qP0EY62N2kvmk2We5xYdRwcRn8WcYtuetCJpwpaCbUfUt5+huLpVxhvmnK+PHrkA=="],
|
||||
|
||||
"antd/scroll-into-view-if-needed": ["scroll-into-view-if-needed@3.1.0", "", { "dependencies": { "compute-scroll-into-view": "^3.0.2" } }, "sha512-49oNpRjWRvnU8NyGVmUaYG4jtTkNonFZI86MmGRDqBphEK2EXT9gdEUoQPZhuBM8yWHxCWbobltqYO5M4XrUvQ=="],
|
||||
|
||||
"chokidar/glob-parent": ["glob-parent@5.1.2", "", { "dependencies": { "is-glob": "^4.0.1" } }, "sha512-AOIgSQCepiJYwP3ARnGx+5VnTu2HBYdzbGP45eLw1vr3zB3vZLeyed1sC9hnbcOc9/SrMyM5RPQrkGz4aS9Zow=="],
|
||||
|
||||
@@ -21,6 +21,7 @@
|
||||
"lucide-react": "^0.511.0",
|
||||
"marked": "^4.1.1",
|
||||
"mermaid": "^11.6.0",
|
||||
"qrcode.react": "^4.2.0",
|
||||
"react": "^18.2.0",
|
||||
"react-dom": "^18.2.0",
|
||||
"react-dropzone": "^14.2.3",
|
||||
|
||||
@@ -50,6 +50,7 @@ import { IconGithubLogo, IconMail, IconLock } from '@douyinfe/semi-icons';
|
||||
import OIDCIcon from '../common/logo/OIDCIcon.js';
|
||||
import WeChatIcon from '../common/logo/WeChatIcon.js';
|
||||
import LinuxDoIcon from '../common/logo/LinuxDoIcon.js';
|
||||
import TwoFAVerification from './TwoFAVerification.js';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
const LoginForm = () => {
|
||||
@@ -78,6 +79,7 @@ const LoginForm = () => {
|
||||
const [resetPasswordLoading, setResetPasswordLoading] = useState(false);
|
||||
const [otherLoginOptionsLoading, setOtherLoginOptionsLoading] = useState(false);
|
||||
const [wechatCodeSubmitLoading, setWechatCodeSubmitLoading] = useState(false);
|
||||
const [showTwoFA, setShowTwoFA] = useState(false);
|
||||
|
||||
const logo = getLogo();
|
||||
const systemName = getSystemName();
|
||||
@@ -162,6 +164,13 @@ const LoginForm = () => {
|
||||
);
|
||||
const { success, message, data } = res.data;
|
||||
if (success) {
|
||||
// 检查是否需要2FA验证
|
||||
if (data && data.require_2fa) {
|
||||
setShowTwoFA(true);
|
||||
setLoginLoading(false);
|
||||
return;
|
||||
}
|
||||
|
||||
userDispatch({ type: 'login', payload: data });
|
||||
setUserData(data);
|
||||
updateAPI();
|
||||
@@ -280,6 +289,21 @@ const LoginForm = () => {
|
||||
setOtherLoginOptionsLoading(false);
|
||||
};
|
||||
|
||||
// 2FA验证成功处理
|
||||
const handle2FASuccess = (data) => {
|
||||
userDispatch({ type: 'login', payload: data });
|
||||
setUserData(data);
|
||||
updateAPI();
|
||||
showSuccess('登录成功!');
|
||||
navigate('/console');
|
||||
};
|
||||
|
||||
// 返回登录页面
|
||||
const handleBackToLogin = () => {
|
||||
setShowTwoFA(false);
|
||||
setInputs({ username: '', password: '', wechat_verification_code: '' });
|
||||
};
|
||||
|
||||
const renderOAuthOptions = () => {
|
||||
return (
|
||||
<div className="flex flex-col items-center">
|
||||
@@ -537,6 +561,35 @@ const LoginForm = () => {
|
||||
);
|
||||
};
|
||||
|
||||
// 2FA验证弹窗
|
||||
const render2FAModal = () => {
|
||||
return (
|
||||
<Modal
|
||||
title={
|
||||
<div className="flex items-center">
|
||||
<div className="w-8 h-8 rounded-full bg-green-100 dark:bg-green-900 flex items-center justify-center mr-3">
|
||||
<svg className="w-4 h-4 text-green-600 dark:text-green-400" fill="currentColor" viewBox="0 0 20 20">
|
||||
<path fillRule="evenodd" d="M6 8a2 2 0 11-4 0 2 2 0 014 0zM8 7a1 1 0 100 2h8a1 1 0 100-2H8zM6 14a2 2 0 11-4 0 2 2 0 014 0zM8 13a1 1 0 100 2h8a1 1 0 100-2H8z" clipRule="evenodd" />
|
||||
</svg>
|
||||
</div>
|
||||
两步验证
|
||||
</div>
|
||||
}
|
||||
visible={showTwoFA}
|
||||
onCancel={handleBackToLogin}
|
||||
footer={null}
|
||||
width={450}
|
||||
centered
|
||||
>
|
||||
<TwoFAVerification
|
||||
onSuccess={handle2FASuccess}
|
||||
onBack={handleBackToLogin}
|
||||
isModal={true}
|
||||
/>
|
||||
</Modal>
|
||||
);
|
||||
};
|
||||
|
||||
return (
|
||||
<div className="relative overflow-hidden bg-gray-100 flex items-center justify-center py-12 px-4 sm:px-6 lg:px-8">
|
||||
{/* 背景模糊晕染球 */}
|
||||
@@ -547,6 +600,7 @@ const LoginForm = () => {
|
||||
? renderEmailLoginForm()
|
||||
: renderOAuthOptions()}
|
||||
{renderWeChatLoginModal()}
|
||||
{render2FAModal()}
|
||||
|
||||
{turnstileEnabled && (
|
||||
<div className="flex justify-center mt-6">
|
||||
|
||||
230
web/src/components/auth/TwoFAVerification.js
Normal file
230
web/src/components/auth/TwoFAVerification.js
Normal file
@@ -0,0 +1,230 @@
|
||||
/*
|
||||
Copyright (C) 2025 QuantumNous
|
||||
|
||||
This program is free software: you can redistribute it and/or modify
|
||||
it under the terms of the GNU Affero General Public License as
|
||||
published by the Free Software Foundation, either version 3 of the
|
||||
License, or (at your option) any later version.
|
||||
|
||||
This program is distributed in the hope that it will be useful,
|
||||
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
GNU Affero General Public License for more details.
|
||||
|
||||
You should have received a copy of the GNU Affero General Public License
|
||||
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
For commercial licensing, please contact support@quantumnous.com
|
||||
*/
|
||||
import { API, showError, showSuccess } from '../../helpers';
|
||||
import { Button, Card, Divider, Form, Input, Typography } from '@douyinfe/semi-ui';
|
||||
import React, { useState } from 'react';
|
||||
|
||||
const { Title, Text, Paragraph } = Typography;
|
||||
|
||||
const TwoFAVerification = ({ onSuccess, onBack, isModal = false }) => {
|
||||
const [loading, setLoading] = useState(false);
|
||||
const [useBackupCode, setUseBackupCode] = useState(false);
|
||||
const [verificationCode, setVerificationCode] = useState('');
|
||||
|
||||
const handleSubmit = async () => {
|
||||
if (!verificationCode) {
|
||||
showError('请输入验证码');
|
||||
return;
|
||||
}
|
||||
// Validate code format
|
||||
if (useBackupCode && verificationCode.length !== 8) {
|
||||
showError('备用码必须是8位');
|
||||
return;
|
||||
} else if (!useBackupCode && !/^\d{6}$/.test(verificationCode)) {
|
||||
showError('验证码必须是6位数字');
|
||||
return;
|
||||
}
|
||||
|
||||
setLoading(true);
|
||||
try {
|
||||
const res = await API.post('/api/user/login/2fa', {
|
||||
code: verificationCode
|
||||
});
|
||||
|
||||
if (res.data.success) {
|
||||
showSuccess('登录成功');
|
||||
// 保存用户信息到本地存储
|
||||
localStorage.setItem('user', JSON.stringify(res.data.data));
|
||||
if (onSuccess) {
|
||||
onSuccess(res.data.data);
|
||||
}
|
||||
} else {
|
||||
showError(res.data.message);
|
||||
}
|
||||
} catch (error) {
|
||||
showError('验证失败,请重试');
|
||||
} finally {
|
||||
setLoading(false);
|
||||
}
|
||||
};
|
||||
|
||||
const handleKeyPress = (e) => {
|
||||
if (e.key === 'Enter') {
|
||||
handleSubmit();
|
||||
}
|
||||
};
|
||||
|
||||
if (isModal) {
|
||||
return (
|
||||
<div className="space-y-4">
|
||||
<Paragraph className="text-gray-600 dark:text-gray-300">
|
||||
请输入认证器应用显示的验证码完成登录
|
||||
</Paragraph>
|
||||
|
||||
<Form onSubmit={handleSubmit}>
|
||||
<Form.Input
|
||||
field="code"
|
||||
label={useBackupCode ? "备用码" : "验证码"}
|
||||
placeholder={useBackupCode ? "请输入8位备用码" : "请输入6位验证码"}
|
||||
value={verificationCode}
|
||||
onChange={setVerificationCode}
|
||||
onKeyPress={handleKeyPress}
|
||||
size="large"
|
||||
style={{ marginBottom: 16 }}
|
||||
autoFocus
|
||||
/>
|
||||
|
||||
<Button
|
||||
htmlType="submit"
|
||||
type="primary"
|
||||
loading={loading}
|
||||
block
|
||||
size="large"
|
||||
style={{ marginBottom: 16 }}
|
||||
>
|
||||
验证并登录
|
||||
</Button>
|
||||
</Form>
|
||||
|
||||
<Divider />
|
||||
|
||||
<div style={{ textAlign: 'center' }}>
|
||||
<Button
|
||||
theme="borderless"
|
||||
type="tertiary"
|
||||
onClick={() => {
|
||||
setUseBackupCode(!useBackupCode);
|
||||
setVerificationCode('');
|
||||
}}
|
||||
style={{ marginRight: 16, color: '#1890ff', padding: 0 }}
|
||||
>
|
||||
{useBackupCode ? '使用认证器验证码' : '使用备用码'}
|
||||
</Button>
|
||||
|
||||
{onBack && (
|
||||
<Button
|
||||
theme="borderless"
|
||||
type="tertiary"
|
||||
onClick={onBack}
|
||||
style={{ color: '#1890ff', padding: 0 }}
|
||||
>
|
||||
返回登录
|
||||
</Button>
|
||||
)}
|
||||
</div>
|
||||
|
||||
<div className="bg-gray-50 dark:bg-gray-800 rounded-lg p-3">
|
||||
<Text size="small" type="secondary">
|
||||
<strong>提示:</strong>
|
||||
<br />
|
||||
• 验证码每30秒更新一次
|
||||
<br />
|
||||
• 如果无法获取验证码,请使用备用码
|
||||
<br />
|
||||
• 每个备用码只能使用一次
|
||||
</Text>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
return (
|
||||
<div style={{
|
||||
display: 'flex',
|
||||
justifyContent: 'center',
|
||||
alignItems: 'center',
|
||||
minHeight: '60vh'
|
||||
}}>
|
||||
<Card style={{ width: 400, padding: 24 }}>
|
||||
<div style={{ textAlign: 'center', marginBottom: 24 }}>
|
||||
<Title heading={3}>两步验证</Title>
|
||||
<Paragraph type="secondary">
|
||||
请输入认证器应用显示的验证码完成登录
|
||||
</Paragraph>
|
||||
</div>
|
||||
|
||||
<Form onSubmit={handleSubmit}>
|
||||
<Form.Input
|
||||
field="code"
|
||||
label={useBackupCode ? "备用码" : "验证码"}
|
||||
placeholder={useBackupCode ? "请输入8位备用码" : "请输入6位验证码"}
|
||||
value={verificationCode}
|
||||
onChange={setVerificationCode}
|
||||
onKeyPress={handleKeyPress}
|
||||
size="large"
|
||||
style={{ marginBottom: 16 }}
|
||||
autoFocus
|
||||
/>
|
||||
|
||||
<Button
|
||||
htmlType="submit"
|
||||
type="primary"
|
||||
loading={loading}
|
||||
block
|
||||
size="large"
|
||||
style={{ marginBottom: 16 }}
|
||||
>
|
||||
验证并登录
|
||||
</Button>
|
||||
</Form>
|
||||
|
||||
<Divider />
|
||||
|
||||
<div style={{ textAlign: 'center' }}>
|
||||
<Button
|
||||
theme="borderless"
|
||||
type="tertiary"
|
||||
onClick={() => {
|
||||
setUseBackupCode(!useBackupCode);
|
||||
setVerificationCode('');
|
||||
}}
|
||||
style={{ marginRight: 16, color: '#1890ff', padding: 0 }}
|
||||
>
|
||||
{useBackupCode ? '使用认证器验证码' : '使用备用码'}
|
||||
</Button>
|
||||
|
||||
{onBack && (
|
||||
<Button
|
||||
theme="borderless"
|
||||
type="tertiary"
|
||||
onClick={onBack}
|
||||
style={{ color: '#1890ff', padding: 0 }}
|
||||
>
|
||||
返回登录
|
||||
</Button>
|
||||
)}
|
||||
</div>
|
||||
|
||||
<div style={{ marginTop: 24, padding: 16, background: '#f6f8fa', borderRadius: 6 }}>
|
||||
<Text size="small" type="secondary">
|
||||
<strong>提示:</strong>
|
||||
<br />
|
||||
• 验证码每30秒更新一次
|
||||
<br />
|
||||
• 如果无法获取验证码,请使用备用码
|
||||
<br />
|
||||
• 每个备用码只能使用一次
|
||||
</Text>
|
||||
</div>
|
||||
</Card>
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
||||
export default TwoFAVerification;
|
||||
622
web/src/components/common/JSONEditor.js
Normal file
622
web/src/components/common/JSONEditor.js
Normal file
@@ -0,0 +1,622 @@
|
||||
import React, { useState, useEffect, useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import {
|
||||
Space,
|
||||
Button,
|
||||
Form,
|
||||
Card,
|
||||
Typography,
|
||||
Banner,
|
||||
Row,
|
||||
Col,
|
||||
InputNumber,
|
||||
Switch,
|
||||
Select,
|
||||
Input,
|
||||
} from '@douyinfe/semi-ui';
|
||||
import {
|
||||
IconCode,
|
||||
IconEdit,
|
||||
IconPlus,
|
||||
IconDelete,
|
||||
IconSetting,
|
||||
} from '@douyinfe/semi-icons';
|
||||
|
||||
const { Text } = Typography;
|
||||
|
||||
const JSONEditor = ({
|
||||
value = '',
|
||||
onChange,
|
||||
field,
|
||||
label,
|
||||
placeholder,
|
||||
extraText,
|
||||
showClear = true,
|
||||
template,
|
||||
templateLabel,
|
||||
editorType = 'keyValue', // keyValue, object, region
|
||||
autosize = true,
|
||||
rules = [],
|
||||
formApi = null,
|
||||
...props
|
||||
}) => {
|
||||
const { t } = useTranslation();
|
||||
|
||||
// 初始化JSON数据
|
||||
const [jsonData, setJsonData] = useState(() => {
|
||||
// 初始化时解析JSON数据
|
||||
if (value && value.trim()) {
|
||||
try {
|
||||
const parsed = JSON.parse(value);
|
||||
return parsed;
|
||||
} catch (error) {
|
||||
return {};
|
||||
}
|
||||
}
|
||||
return {};
|
||||
});
|
||||
|
||||
// 根据键数量决定默认编辑模式
|
||||
const [editMode, setEditMode] = useState(() => {
|
||||
// 如果初始JSON数据的键数量大于10个,则默认使用手动模式
|
||||
if (value && value.trim()) {
|
||||
try {
|
||||
const parsed = JSON.parse(value);
|
||||
const keyCount = Object.keys(parsed).length;
|
||||
return keyCount > 10 ? 'manual' : 'visual';
|
||||
} catch (error) {
|
||||
// JSON无效时默认显示手动编辑模式
|
||||
return 'manual';
|
||||
}
|
||||
}
|
||||
return 'visual';
|
||||
});
|
||||
const [jsonError, setJsonError] = useState('');
|
||||
|
||||
// 数据同步 - 当value变化时总是更新jsonData(如果JSON有效)
|
||||
useEffect(() => {
|
||||
try {
|
||||
const parsed = value && value.trim() ? JSON.parse(value) : {};
|
||||
setJsonData(parsed);
|
||||
setJsonError('');
|
||||
} catch (error) {
|
||||
console.log('JSON解析失败:', error.message);
|
||||
setJsonError(error.message);
|
||||
// JSON格式错误时不更新jsonData
|
||||
}
|
||||
}, [value]);
|
||||
|
||||
|
||||
// 处理可视化编辑的数据变化
|
||||
const handleVisualChange = useCallback((newData) => {
|
||||
setJsonData(newData);
|
||||
setJsonError('');
|
||||
const jsonString = Object.keys(newData).length === 0 ? '' : JSON.stringify(newData, null, 2);
|
||||
|
||||
// 通过formApi设置值(如果提供的话)
|
||||
if (formApi && field) {
|
||||
formApi.setValue(field, jsonString);
|
||||
}
|
||||
|
||||
onChange?.(jsonString);
|
||||
}, [onChange, formApi, field]);
|
||||
|
||||
// 处理手动编辑的数据变化
|
||||
const handleManualChange = useCallback((newValue) => {
|
||||
onChange?.(newValue);
|
||||
// 验证JSON格式
|
||||
if (newValue && newValue.trim()) {
|
||||
try {
|
||||
const parsed = JSON.parse(newValue);
|
||||
setJsonError('');
|
||||
// 预先准备可视化数据,但不立即应用
|
||||
// 这样切换到可视化模式时数据已经准备好了
|
||||
} catch (error) {
|
||||
setJsonError(error.message);
|
||||
}
|
||||
} else {
|
||||
setJsonError('');
|
||||
}
|
||||
}, [onChange]);
|
||||
|
||||
// 切换编辑模式
|
||||
const toggleEditMode = useCallback(() => {
|
||||
if (editMode === 'visual') {
|
||||
// 从可视化模式切换到手动模式
|
||||
setEditMode('manual');
|
||||
} else {
|
||||
// 从手动模式切换到可视化模式,需要验证JSON
|
||||
try {
|
||||
const parsed = value && value.trim() ? JSON.parse(value) : {};
|
||||
setJsonData(parsed);
|
||||
setJsonError('');
|
||||
setEditMode('visual');
|
||||
} catch (error) {
|
||||
setJsonError(error.message);
|
||||
// JSON格式错误时不切换模式
|
||||
return;
|
||||
}
|
||||
}
|
||||
}, [editMode, value]);
|
||||
|
||||
// 添加键值对
|
||||
const addKeyValue = useCallback(() => {
|
||||
const newData = { ...jsonData };
|
||||
const keys = Object.keys(newData);
|
||||
let newKey = 'key';
|
||||
let counter = 1;
|
||||
while (newData.hasOwnProperty(newKey)) {
|
||||
newKey = `key${counter}`;
|
||||
counter++;
|
||||
}
|
||||
newData[newKey] = '';
|
||||
handleVisualChange(newData);
|
||||
}, [jsonData, handleVisualChange]);
|
||||
|
||||
// 删除键值对
|
||||
const removeKeyValue = useCallback((keyToRemove) => {
|
||||
const newData = { ...jsonData };
|
||||
delete newData[keyToRemove];
|
||||
handleVisualChange(newData);
|
||||
}, [jsonData, handleVisualChange]);
|
||||
|
||||
// 更新键名
|
||||
const updateKey = useCallback((oldKey, newKey) => {
|
||||
if (oldKey === newKey) return;
|
||||
const newData = { ...jsonData };
|
||||
const value = newData[oldKey];
|
||||
delete newData[oldKey];
|
||||
newData[newKey] = value;
|
||||
handleVisualChange(newData);
|
||||
}, [jsonData, handleVisualChange]);
|
||||
|
||||
// 更新值
|
||||
const updateValue = useCallback((key, newValue) => {
|
||||
const newData = { ...jsonData };
|
||||
newData[key] = newValue;
|
||||
handleVisualChange(newData);
|
||||
}, [jsonData, handleVisualChange]);
|
||||
|
||||
// 填入模板
|
||||
const fillTemplate = useCallback(() => {
|
||||
if (template) {
|
||||
const templateString = JSON.stringify(template, null, 2);
|
||||
|
||||
// 通过formApi设置值(如果提供的话)
|
||||
if (formApi && field) {
|
||||
formApi.setValue(field, templateString);
|
||||
}
|
||||
|
||||
// 无论哪种模式都要更新值
|
||||
onChange?.(templateString);
|
||||
|
||||
// 如果是可视化模式,同时更新jsonData
|
||||
if (editMode === 'visual') {
|
||||
setJsonData(template);
|
||||
}
|
||||
|
||||
// 清除错误状态
|
||||
setJsonError('');
|
||||
}
|
||||
}, [template, onChange, editMode, formApi, field]);
|
||||
|
||||
// 渲染键值对编辑器
|
||||
const renderKeyValueEditor = () => {
|
||||
if (typeof jsonData !== 'object' || jsonData === null) {
|
||||
return (
|
||||
<div className="text-center py-6 px-4">
|
||||
<div className="text-gray-400 mb-2">
|
||||
<IconCode size={32} />
|
||||
</div>
|
||||
<Text type="tertiary" className="text-gray-500 text-sm">
|
||||
{t('无效的JSON数据,请检查格式')}
|
||||
</Text>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
const entries = Object.entries(jsonData);
|
||||
|
||||
return (
|
||||
<div className="space-y-1">
|
||||
{entries.length === 0 && (
|
||||
<div className="text-center py-6 px-4">
|
||||
<div className="text-gray-400 mb-2">
|
||||
<IconCode size={32} />
|
||||
</div>
|
||||
<Text type="tertiary" className="text-gray-500 text-sm">
|
||||
{t('暂无数据,点击下方按钮添加键值对')}
|
||||
</Text>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{entries.map(([key, value], index) => (
|
||||
<Card key={index} className="!p-3 !border-gray-200 !rounded-md hover:shadow-sm transition-shadow duration-200">
|
||||
<Row gutter={12} align="middle">
|
||||
<Col span={10}>
|
||||
<div className="space-y-1">
|
||||
<Text type="tertiary" size="small">{t('键名')}</Text>
|
||||
<Input
|
||||
placeholder={t('键名')}
|
||||
value={key}
|
||||
onChange={(newKey) => updateKey(key, newKey)}
|
||||
size="small"
|
||||
/>
|
||||
</div>
|
||||
</Col>
|
||||
<Col span={11}>
|
||||
<div className="space-y-1">
|
||||
<Text type="tertiary" size="small">{t('值')}</Text>
|
||||
<Input
|
||||
placeholder={t('值')}
|
||||
value={value}
|
||||
onChange={(newValue) => updateValue(key, newValue)}
|
||||
size="small"
|
||||
/>
|
||||
</div>
|
||||
</Col>
|
||||
<Col span={3}>
|
||||
<div className="flex justify-center pt-4">
|
||||
<Button
|
||||
icon={<IconDelete />}
|
||||
type="danger"
|
||||
theme="borderless"
|
||||
size="small"
|
||||
onClick={() => removeKeyValue(key)}
|
||||
className="hover:bg-red-50"
|
||||
/>
|
||||
</div>
|
||||
</Col>
|
||||
</Row>
|
||||
</Card>
|
||||
))}
|
||||
|
||||
<div className="flex justify-center pt-1">
|
||||
<Button
|
||||
icon={<IconPlus />}
|
||||
onClick={addKeyValue}
|
||||
size="small"
|
||||
theme="solid"
|
||||
type="primary"
|
||||
className="shadow-sm hover:shadow-md transition-shadow px-4"
|
||||
>
|
||||
{t('添加键值对')}
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
||||
// 渲染对象编辑器(用于复杂JSON)
|
||||
const renderObjectEditor = () => {
|
||||
const entries = Object.entries(jsonData);
|
||||
|
||||
return (
|
||||
<div className="space-y-1">
|
||||
{entries.length === 0 && (
|
||||
<div className="text-center py-6 px-4">
|
||||
<div className="text-gray-400 mb-2">
|
||||
<IconSetting size={32} />
|
||||
</div>
|
||||
<Text type="tertiary" className="text-gray-500 text-sm">
|
||||
{t('暂无参数,点击下方按钮添加请求参数')}
|
||||
</Text>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{entries.map(([key, value], index) => (
|
||||
<Card key={index} className="!p-3 !border-gray-200 !rounded-md hover:shadow-sm transition-shadow duration-200">
|
||||
<Row gutter={12} align="middle">
|
||||
<Col span={8}>
|
||||
<div className="space-y-1">
|
||||
<Text type="tertiary" size="small">{t('参数名')}</Text>
|
||||
<Input
|
||||
placeholder={t('参数名')}
|
||||
value={key}
|
||||
onChange={(newKey) => updateKey(key, newKey)}
|
||||
size="small"
|
||||
/>
|
||||
</div>
|
||||
</Col>
|
||||
<Col span={13}>
|
||||
<div className="space-y-1">
|
||||
<Text type="tertiary" size="small">{t('参数值')} ({typeof value})</Text>
|
||||
{renderValueInput(key, value)}
|
||||
</div>
|
||||
</Col>
|
||||
<Col span={3}>
|
||||
<div className="flex justify-center pt-4">
|
||||
<Button
|
||||
icon={<IconDelete />}
|
||||
type="danger"
|
||||
theme="borderless"
|
||||
size="small"
|
||||
onClick={() => removeKeyValue(key)}
|
||||
className="hover:bg-red-50"
|
||||
/>
|
||||
</div>
|
||||
</Col>
|
||||
</Row>
|
||||
</Card>
|
||||
))}
|
||||
|
||||
<div className="flex justify-center pt-1">
|
||||
<Button
|
||||
icon={<IconPlus />}
|
||||
onClick={addKeyValue}
|
||||
size="small"
|
||||
theme="solid"
|
||||
type="primary"
|
||||
className="shadow-sm hover:shadow-md transition-shadow px-4"
|
||||
>
|
||||
{t('添加参数')}
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
||||
// 渲染参数值输入控件
|
||||
const renderValueInput = (key, value) => {
|
||||
const valueType = typeof value;
|
||||
|
||||
if (valueType === 'boolean') {
|
||||
return (
|
||||
<div className="flex items-center">
|
||||
<Switch
|
||||
checked={value}
|
||||
onChange={(newValue) => updateValue(key, newValue)}
|
||||
size="small"
|
||||
/>
|
||||
<Text type="tertiary" size="small" className="ml-2">
|
||||
{value ? t('true') : t('false')}
|
||||
</Text>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
if (valueType === 'number') {
|
||||
return (
|
||||
<InputNumber
|
||||
value={value}
|
||||
onChange={(newValue) => updateValue(key, newValue)}
|
||||
size="small"
|
||||
style={{ width: '100%' }}
|
||||
step={key === 'temperature' ? 0.1 : 1}
|
||||
precision={key === 'temperature' ? 2 : 0}
|
||||
placeholder={t('输入数字')}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
// 字符串类型或其他类型
|
||||
return (
|
||||
<Input
|
||||
placeholder={t('参数值')}
|
||||
value={String(value)}
|
||||
onChange={(newValue) => {
|
||||
// 尝试转换为适当的类型
|
||||
let convertedValue = newValue;
|
||||
if (newValue === 'true') convertedValue = true;
|
||||
else if (newValue === 'false') convertedValue = false;
|
||||
else if (!isNaN(newValue) && newValue !== '' && newValue !== '0') {
|
||||
convertedValue = Number(newValue);
|
||||
}
|
||||
|
||||
updateValue(key, convertedValue);
|
||||
}}
|
||||
size="small"
|
||||
/>
|
||||
);
|
||||
};
|
||||
|
||||
// 渲染区域编辑器(特殊格式)
|
||||
const renderRegionEditor = () => {
|
||||
const entries = Object.entries(jsonData);
|
||||
const defaultEntry = entries.find(([key]) => key === 'default');
|
||||
const modelEntries = entries.filter(([key]) => key !== 'default');
|
||||
|
||||
return (
|
||||
<div className="space-y-1">
|
||||
{/* 默认区域 */}
|
||||
<Card className="!p-2 !border-blue-200 !bg-blue-50">
|
||||
<div className="flex items-center mb-1">
|
||||
<Text strong size="small" className="text-blue-700">{t('默认区域')}</Text>
|
||||
</div>
|
||||
<Input
|
||||
placeholder={t('默认区域,如: us-central1')}
|
||||
value={defaultEntry ? defaultEntry[1] : ''}
|
||||
onChange={(value) => updateValue('default', value)}
|
||||
size="small"
|
||||
/>
|
||||
</Card>
|
||||
|
||||
{/* 模型专用区域 */}
|
||||
<div className="space-y-1">
|
||||
<Text strong size="small">{t('模型专用区域')}</Text>
|
||||
{modelEntries.map(([modelName, region], index) => (
|
||||
<Card key={index} className="!p-3 !border-gray-200 !rounded-md hover:shadow-sm transition-shadow duration-200">
|
||||
<Row gutter={12} align="middle">
|
||||
<Col span={10}>
|
||||
<div className="space-y-1">
|
||||
<Text type="tertiary" size="small">{t('模型名称')}</Text>
|
||||
<Input
|
||||
placeholder={t('模型名称')}
|
||||
value={modelName}
|
||||
onChange={(newKey) => updateKey(modelName, newKey)}
|
||||
size="small"
|
||||
/>
|
||||
</div>
|
||||
</Col>
|
||||
<Col span={11}>
|
||||
<div className="space-y-1">
|
||||
<Text type="tertiary" size="small">{t('区域')}</Text>
|
||||
<Input
|
||||
placeholder={t('区域')}
|
||||
value={region}
|
||||
onChange={(newValue) => updateValue(modelName, newValue)}
|
||||
size="small"
|
||||
/>
|
||||
</div>
|
||||
</Col>
|
||||
<Col span={3}>
|
||||
<div className="flex justify-center pt-4">
|
||||
<Button
|
||||
icon={<IconDelete />}
|
||||
type="danger"
|
||||
theme="borderless"
|
||||
size="small"
|
||||
onClick={() => removeKeyValue(modelName)}
|
||||
className="hover:bg-red-50"
|
||||
/>
|
||||
</div>
|
||||
</Col>
|
||||
</Row>
|
||||
</Card>
|
||||
))}
|
||||
|
||||
<div className="flex justify-center pt-1">
|
||||
<Button
|
||||
icon={<IconPlus />}
|
||||
onClick={addKeyValue}
|
||||
size="small"
|
||||
theme="solid"
|
||||
type="primary"
|
||||
className="shadow-sm hover:shadow-md transition-shadow px-4"
|
||||
>
|
||||
{t('添加模型区域')}
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
||||
// 渲染可视化编辑器
|
||||
const renderVisualEditor = () => {
|
||||
switch (editorType) {
|
||||
case 'region':
|
||||
return renderRegionEditor();
|
||||
case 'object':
|
||||
return renderObjectEditor();
|
||||
case 'keyValue':
|
||||
default:
|
||||
return renderKeyValueEditor();
|
||||
}
|
||||
};
|
||||
|
||||
const hasJsonError = jsonError && jsonError.trim() !== '';
|
||||
|
||||
return (
|
||||
<div className="space-y-1">
|
||||
{/* Label统一显示在上方 */}
|
||||
{label && (
|
||||
<div className="flex items-center">
|
||||
<Text className="text-sm font-medium text-gray-900">{label}</Text>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{/* 编辑模式切换 */}
|
||||
<div className="flex items-center justify-between p-2 bg-gray-50 rounded-md">
|
||||
<div className="flex items-center gap-2">
|
||||
{editMode === 'visual' && (
|
||||
<Text type="tertiary" size="small" className="bg-blue-100 text-blue-700 px-2 py-0.5 rounded text-xs">
|
||||
{t('可视化模式')}
|
||||
</Text>
|
||||
)}
|
||||
{editMode === 'manual' && (
|
||||
<Text type="tertiary" size="small" className="bg-green-100 text-green-700 px-2 py-0.5 rounded text-xs">
|
||||
{t('手动编辑模式')}
|
||||
</Text>
|
||||
)}
|
||||
</div>
|
||||
<div className="flex items-center gap-2">
|
||||
{template && templateLabel && (
|
||||
<Button
|
||||
size="small"
|
||||
type="tertiary"
|
||||
onClick={fillTemplate}
|
||||
className="!text-semi-color-primary hover:bg-blue-50 text-xs"
|
||||
>
|
||||
{templateLabel}
|
||||
</Button>
|
||||
)}
|
||||
<Space size="tight">
|
||||
<Button
|
||||
size="small"
|
||||
type={editMode === 'visual' ? 'primary' : 'tertiary'}
|
||||
icon={<IconEdit />}
|
||||
onClick={toggleEditMode}
|
||||
disabled={editMode === 'manual' && hasJsonError}
|
||||
className={editMode === 'visual' ? 'shadow-sm' : ''}
|
||||
>
|
||||
{t('可视化')}
|
||||
</Button>
|
||||
<Button
|
||||
size="small"
|
||||
type={editMode === 'manual' ? 'primary' : 'tertiary'}
|
||||
icon={<IconCode />}
|
||||
onClick={toggleEditMode}
|
||||
className={editMode === 'manual' ? 'shadow-sm' : ''}
|
||||
>
|
||||
{t('手动编辑')}
|
||||
</Button>
|
||||
</Space>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* JSON错误提示 */}
|
||||
{hasJsonError && (
|
||||
<Banner
|
||||
type="danger"
|
||||
description={`JSON 格式错误: ${jsonError}`}
|
||||
className="!rounded-md text-sm"
|
||||
/>
|
||||
)}
|
||||
|
||||
{/* 编辑器内容 */}
|
||||
{editMode === 'visual' ? (
|
||||
<div>
|
||||
<Card className="!p-3 !border-gray-200 !shadow-sm !rounded-md bg-white">
|
||||
{renderVisualEditor()}
|
||||
</Card>
|
||||
{/* 可视化模式下的额外文本显示在下方 */}
|
||||
{extraText && (
|
||||
<div className="text-xs text-gray-600 mt-0.5">
|
||||
{extraText}
|
||||
</div>
|
||||
)}
|
||||
{/* 隐藏的Form字段用于验证和数据绑定 */}
|
||||
<Form.Input
|
||||
field={field}
|
||||
value={value}
|
||||
rules={rules}
|
||||
style={{ display: 'none' }}
|
||||
noLabel={true}
|
||||
{...props}
|
||||
/>
|
||||
</div>
|
||||
) : (
|
||||
<Form.TextArea
|
||||
field={field}
|
||||
placeholder={placeholder}
|
||||
value={value}
|
||||
onChange={handleManualChange}
|
||||
showClear={showClear}
|
||||
rows={Math.max(8, value ? value.split('\n').length : 8)}
|
||||
rules={rules}
|
||||
noLabel={true}
|
||||
{...props}
|
||||
/>
|
||||
)}
|
||||
|
||||
{/* 额外文本在手动编辑模式下显示 */}
|
||||
{extraText && editMode === 'manual' && (
|
||||
<div className="text-xs text-gray-600">
|
||||
{extraText}
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
||||
export default JSONEditor;
|
||||
@@ -33,7 +33,7 @@ import {
|
||||
Settings,
|
||||
} from 'lucide-react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { renderGroupOption, modelSelectFilter } from '../../helpers';
|
||||
import { renderGroupOption, selectFilter } from '../../helpers';
|
||||
import ParameterControl from './ParameterControl';
|
||||
import ImageUrlInput from './ImageUrlInput';
|
||||
import ConfigManager from './ConfigManager';
|
||||
@@ -173,7 +173,7 @@ const SettingsPanel = ({
|
||||
name='model'
|
||||
required
|
||||
selection
|
||||
filter={modelSelectFilter}
|
||||
filter={selectFilter}
|
||||
autoClearSearchValue={false}
|
||||
onChange={(value) => onInputChange('model', value)}
|
||||
value={inputs.model}
|
||||
|
||||
@@ -36,6 +36,7 @@ import {
|
||||
renderModelTag,
|
||||
getModelCategories
|
||||
} from '../../helpers';
|
||||
import TwoFASetting from './TwoFASetting';
|
||||
import Turnstile from 'react-turnstile';
|
||||
import { UserContext } from '../../context/User';
|
||||
import { useTheme } from '../../context/Theme';
|
||||
@@ -1041,6 +1042,9 @@ const PersonalSetting = () => {
|
||||
</div>
|
||||
</Card>
|
||||
|
||||
{/* 两步验证设置 */}
|
||||
<TwoFASetting />
|
||||
|
||||
{/* 危险区域 */}
|
||||
<Card
|
||||
className="!rounded-xl border-red-200 w-full"
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user