mirror of
https://github.com/QuantumNous/new-api.git
synced 2026-04-17 22:37:27 +00:00
Compare commits
74 Commits
v0.8.8.0-a
...
v0.8.8.1
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
78f34a8245 | ||
|
|
97d6f10f15 | ||
|
|
afefc4caca | ||
|
|
6abbd036f8 | ||
|
|
ef0db0f914 | ||
|
|
e01986fdd4 | ||
|
|
a0c6ebe2d8 | ||
|
|
d2183af23f | ||
|
|
953f1bdc3c | ||
|
|
e2429f20f8 | ||
|
|
f0945da4fb | ||
|
|
8df3de9ae5 | ||
|
|
277cc1cac8 | ||
|
|
07a92293e4 | ||
|
|
f995e31d04 | ||
|
|
9758a9e60d | ||
|
|
6f56696af2 | ||
|
|
345fbdf3d2 | ||
|
|
ce031f7d15 | ||
|
|
bd6b811183 | ||
|
|
196bafff03 | ||
|
|
f20b558e22 | ||
|
|
54447bf227 | ||
|
|
fc09051d8b | ||
|
|
1f5ef24ecd | ||
|
|
b1faf42529 | ||
|
|
6a85206e32 | ||
|
|
e3d3e697d3 | ||
|
|
db9b333930 | ||
|
|
f7b284ad73 | ||
|
|
e1970e8a66 | ||
|
|
0cd93d67ff | ||
|
|
6e806e21bd | ||
|
|
a8462c1b70 | ||
|
|
706ea8b649 | ||
|
|
95d46d1dfc | ||
|
|
010f27678d | ||
|
|
d87117a2cf | ||
|
|
4ed92a94a1 | ||
|
|
821ea34a3c | ||
|
|
ecb3d01376 | ||
|
|
e322ed4f05 | ||
|
|
bcf7e78665 | ||
|
|
0cb2bb2ea7 | ||
|
|
c5d97597c4 | ||
|
|
fe9acb6c59 | ||
|
|
bca78beb1b | ||
|
|
a8a42cbfa8 | ||
|
|
19df2ac234 | ||
|
|
e7524c85c2 | ||
|
|
a4356727e9 | ||
|
|
f15a53fae4 | ||
|
|
8e3cf2eaab | ||
|
|
c51ec3135b | ||
|
|
2469c439b1 | ||
|
|
1297addfb1 | ||
|
|
d6cbf43373 | ||
|
|
df647e7b42 | ||
|
|
fe16d05fbb | ||
|
|
1430c05b6c | ||
|
|
b25841e50d | ||
|
|
b704fc9254 | ||
|
|
352da66bd1 | ||
|
|
8205ad2cd0 | ||
|
|
e162b9c169 | ||
|
|
77e3502028 | ||
|
|
ae0461692c | ||
|
|
13bdb80958 | ||
|
|
6f74e7b738 | ||
|
|
eaee89f77a | ||
|
|
6103888610 | ||
|
|
7af3fb5ae4 | ||
|
|
3ac54b2178 | ||
|
|
5621755655 |
@@ -9,6 +9,7 @@ import (
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
)
|
||||
|
||||
type stringWriter interface {
|
||||
@@ -52,6 +53,8 @@ type CustomEvent struct {
|
||||
Id string
|
||||
Retry uint
|
||||
Data interface{}
|
||||
|
||||
Mutex sync.Mutex
|
||||
}
|
||||
|
||||
func encode(writer io.Writer, event CustomEvent) error {
|
||||
@@ -73,6 +76,8 @@ func (r CustomEvent) Render(w http.ResponseWriter) error {
|
||||
}
|
||||
|
||||
func (r CustomEvent) WriteContentType(w http.ResponseWriter) {
|
||||
r.Mutex.Lock()
|
||||
defer r.Mutex.Unlock()
|
||||
header := w.Header()
|
||||
header["Content-Type"] = contentType
|
||||
|
||||
|
||||
@@ -4,7 +4,10 @@ import (
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"math/rand"
|
||||
"net/url"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
@@ -95,3 +98,95 @@ func GetJsonString(data any) string {
|
||||
b, _ := json.Marshal(data)
|
||||
return string(b)
|
||||
}
|
||||
|
||||
// MaskSensitiveInfo masks sensitive information like URLs, IPs in a string
|
||||
// Example:
|
||||
// http://example.com -> http://***.com
|
||||
// https://api.test.org/v1/users/123?key=secret -> https://***.org/***/***/?key=***
|
||||
// https://sub.domain.co.uk/path/to/resource -> https://***.co.uk/***/***
|
||||
// 192.168.1.1 -> ***.***.***.***
|
||||
func MaskSensitiveInfo(str string) string {
|
||||
// Mask URLs
|
||||
urlPattern := regexp.MustCompile(`(http|https)://[^\s/$.?#].[^\s]*`)
|
||||
str = urlPattern.ReplaceAllStringFunc(str, func(urlStr string) string {
|
||||
u, err := url.Parse(urlStr)
|
||||
if err != nil {
|
||||
return urlStr
|
||||
}
|
||||
|
||||
host := u.Host
|
||||
if host == "" {
|
||||
return urlStr
|
||||
}
|
||||
|
||||
// Split host by dots
|
||||
parts := strings.Split(host, ".")
|
||||
if len(parts) < 2 {
|
||||
// If less than 2 parts, just mask the whole host
|
||||
return u.Scheme + "://***" + u.Path
|
||||
}
|
||||
|
||||
// Keep the TLD (Top Level Domain) and mask the rest
|
||||
var maskedHost string
|
||||
if len(parts) == 2 {
|
||||
// example.com -> ***.com
|
||||
maskedHost = "***." + parts[len(parts)-1]
|
||||
} else {
|
||||
// Handle cases like sub.domain.co.uk or api.example.com
|
||||
// Keep last 2 parts if they look like country code TLD (co.uk, com.cn, etc.)
|
||||
lastPart := parts[len(parts)-1]
|
||||
secondLastPart := parts[len(parts)-2]
|
||||
|
||||
if len(lastPart) == 2 && len(secondLastPart) <= 3 {
|
||||
// Likely country code TLD like co.uk, com.cn
|
||||
maskedHost = "***." + secondLastPart + "." + lastPart
|
||||
} else {
|
||||
// Regular TLD like .com, .org
|
||||
maskedHost = "***." + lastPart
|
||||
}
|
||||
}
|
||||
|
||||
result := u.Scheme + "://" + maskedHost
|
||||
|
||||
// Mask path
|
||||
if u.Path != "" && u.Path != "/" {
|
||||
pathParts := strings.Split(strings.Trim(u.Path, "/"), "/")
|
||||
maskedPathParts := make([]string, len(pathParts))
|
||||
for i := range pathParts {
|
||||
if pathParts[i] != "" {
|
||||
maskedPathParts[i] = "***"
|
||||
}
|
||||
}
|
||||
if len(maskedPathParts) > 0 {
|
||||
result += "/" + strings.Join(maskedPathParts, "/")
|
||||
}
|
||||
} else if u.Path == "/" {
|
||||
result += "/"
|
||||
}
|
||||
|
||||
// Mask query parameters
|
||||
if u.RawQuery != "" {
|
||||
values, err := url.ParseQuery(u.RawQuery)
|
||||
if err != nil {
|
||||
// If can't parse query, just mask the whole query string
|
||||
result += "?***"
|
||||
} else {
|
||||
maskedParams := make([]string, 0, len(values))
|
||||
for key := range values {
|
||||
maskedParams = append(maskedParams, key+"=***")
|
||||
}
|
||||
if len(maskedParams) > 0 {
|
||||
result += "?" + strings.Join(maskedParams, "&")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
})
|
||||
|
||||
// Mask IP addresses
|
||||
ipPattern := regexp.MustCompile(`\b(?:\d{1,3}\.){3}\d{1,3}\b`)
|
||||
str = ipPattern.ReplaceAllString(str, "***.***.***.***")
|
||||
|
||||
return str
|
||||
}
|
||||
|
||||
@@ -49,6 +49,7 @@ const (
|
||||
ChannelTypeCoze = 49
|
||||
ChannelTypeKling = 50
|
||||
ChannelTypeJimeng = 51
|
||||
ChannelTypeVidu = 52
|
||||
ChannelTypeDummy // this one is only for count, do not add any channel after this
|
||||
|
||||
)
|
||||
@@ -106,4 +107,5 @@ var ChannelBaseURLs = []string{
|
||||
"https://api.coze.cn", //49
|
||||
"https://api.klingai.com", //50
|
||||
"https://visual.volcengineapi.com", //51
|
||||
"https://api.vidu.cn", //52
|
||||
}
|
||||
|
||||
@@ -69,6 +69,12 @@ func testChannel(channel *model.Channel, testModel string) testResult {
|
||||
newAPIError: nil,
|
||||
}
|
||||
}
|
||||
if channel.Type == constant.ChannelTypeVidu {
|
||||
return testResult{
|
||||
localErr: errors.New("vidu channel test is not supported"),
|
||||
newAPIError: nil,
|
||||
}
|
||||
}
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
|
||||
@@ -203,7 +209,7 @@ func testChannel(channel *model.Channel, testModel string) testResult {
|
||||
return testResult{
|
||||
context: c,
|
||||
localErr: err,
|
||||
newAPIError: types.NewError(err, types.ErrorCodeDoRequestFailed),
|
||||
newAPIError: types.NewOpenAIError(err, types.ErrorCodeDoRequestFailed, http.StatusInternalServerError),
|
||||
}
|
||||
}
|
||||
var httpResp *http.Response
|
||||
@@ -214,7 +220,7 @@ func testChannel(channel *model.Channel, testModel string) testResult {
|
||||
return testResult{
|
||||
context: c,
|
||||
localErr: err,
|
||||
newAPIError: types.NewError(err, types.ErrorCodeBadResponse),
|
||||
newAPIError: types.NewOpenAIError(err, types.ErrorCodeBadResponse, http.StatusInternalServerError),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -230,7 +236,7 @@ func testChannel(channel *model.Channel, testModel string) testResult {
|
||||
return testResult{
|
||||
context: c,
|
||||
localErr: errors.New("usage is nil"),
|
||||
newAPIError: types.NewError(errors.New("usage is nil"), types.ErrorCodeBadResponseBody),
|
||||
newAPIError: types.NewOpenAIError(errors.New("usage is nil"), types.ErrorCodeBadResponseBody, http.StatusInternalServerError),
|
||||
}
|
||||
}
|
||||
usage := usageA.(*dto.Usage)
|
||||
@@ -240,7 +246,7 @@ func testChannel(channel *model.Channel, testModel string) testResult {
|
||||
return testResult{
|
||||
context: c,
|
||||
localErr: err,
|
||||
newAPIError: types.NewError(err, types.ErrorCodeReadResponseBodyFailed),
|
||||
newAPIError: types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError),
|
||||
}
|
||||
}
|
||||
info.PromptTokens = usage.PromptTokens
|
||||
@@ -326,8 +332,11 @@ func TestChannel(c *gin.Context) {
|
||||
}
|
||||
channel, err := model.CacheGetChannel(channelId)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
channel, err = model.GetChannelById(channelId, true)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
}
|
||||
//defer func() {
|
||||
// if channel.ChannelInfo.IsMultiKey {
|
||||
@@ -411,7 +420,7 @@ func testAllChannels(notify bool) error {
|
||||
if common.AutomaticDisableChannelEnabled && !shouldBanChannel {
|
||||
if milliseconds > disableThreshold {
|
||||
err := errors.New(fmt.Sprintf("响应时间 %.2fs 超过阈值 %.2fs", float64(milliseconds)/1000.0, float64(disableThreshold)/1000.0))
|
||||
newAPIError = types.NewError(err, types.ErrorCodeChannelResponseTimeExceeded)
|
||||
newAPIError = types.NewOpenAIError(err, types.ErrorCodeChannelResponseTimeExceeded, http.StatusRequestTimeout)
|
||||
shouldBanChannel = true
|
||||
}
|
||||
}
|
||||
|
||||
@@ -669,6 +669,7 @@ func DeleteChannelBatch(c *gin.Context) {
|
||||
type PatchChannel struct {
|
||||
model.Channel
|
||||
MultiKeyMode *string `json:"multi_key_mode"`
|
||||
KeyMode *string `json:"key_mode"` // 多key模式下密钥覆盖或者追加
|
||||
}
|
||||
|
||||
func UpdateChannel(c *gin.Context) {
|
||||
@@ -688,7 +689,7 @@ func UpdateChannel(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
// Preserve existing ChannelInfo to ensure multi-key channels keep correct state even if the client does not send ChannelInfo in the request.
|
||||
originChannel, err := model.GetChannelById(channel.Id, false)
|
||||
originChannel, err := model.GetChannelById(channel.Id, true)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
@@ -704,6 +705,69 @@ func UpdateChannel(c *gin.Context) {
|
||||
if channel.MultiKeyMode != nil && *channel.MultiKeyMode != "" {
|
||||
channel.ChannelInfo.MultiKeyMode = constant.MultiKeyMode(*channel.MultiKeyMode)
|
||||
}
|
||||
|
||||
// 处理多key模式下的密钥追加/覆盖逻辑
|
||||
if channel.KeyMode != nil && channel.ChannelInfo.IsMultiKey {
|
||||
switch *channel.KeyMode {
|
||||
case "append":
|
||||
// 追加模式:将新密钥添加到现有密钥列表
|
||||
if originChannel.Key != "" {
|
||||
var newKeys []string
|
||||
var existingKeys []string
|
||||
|
||||
// 解析现有密钥
|
||||
if strings.HasPrefix(strings.TrimSpace(originChannel.Key), "[") {
|
||||
// JSON数组格式
|
||||
var arr []json.RawMessage
|
||||
if err := json.Unmarshal([]byte(strings.TrimSpace(originChannel.Key)), &arr); err == nil {
|
||||
existingKeys = make([]string, len(arr))
|
||||
for i, v := range arr {
|
||||
existingKeys[i] = string(v)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// 换行分隔格式
|
||||
existingKeys = strings.Split(strings.Trim(originChannel.Key, "\n"), "\n")
|
||||
}
|
||||
|
||||
// 处理 Vertex AI 的特殊情况
|
||||
if channel.Type == constant.ChannelTypeVertexAi {
|
||||
// 尝试解析新密钥为JSON数组
|
||||
if strings.HasPrefix(strings.TrimSpace(channel.Key), "[") {
|
||||
array, err := getVertexArrayKeys(channel.Key)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "追加密钥解析失败: " + err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
newKeys = array
|
||||
} else {
|
||||
// 单个JSON密钥
|
||||
newKeys = []string{channel.Key}
|
||||
}
|
||||
// 合并密钥
|
||||
allKeys := append(existingKeys, newKeys...)
|
||||
channel.Key = strings.Join(allKeys, "\n")
|
||||
} else {
|
||||
// 普通渠道的处理
|
||||
inputKeys := strings.Split(channel.Key, "\n")
|
||||
for _, key := range inputKeys {
|
||||
key = strings.TrimSpace(key)
|
||||
if key != "" {
|
||||
newKeys = append(newKeys, key)
|
||||
}
|
||||
}
|
||||
// 合并密钥
|
||||
allKeys := append(existingKeys, newKeys...)
|
||||
channel.Key = strings.Join(allKeys, "\n")
|
||||
}
|
||||
}
|
||||
case "replace":
|
||||
// 覆盖模式:直接使用新密钥(默认行为,不需要特殊处理)
|
||||
}
|
||||
}
|
||||
err = channel.Update()
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
|
||||
@@ -28,19 +28,19 @@ func Playground(c *gin.Context) {
|
||||
|
||||
useAccessToken := c.GetBool("use_access_token")
|
||||
if useAccessToken {
|
||||
newAPIError = types.NewError(errors.New("暂不支持使用 access token"), types.ErrorCodeAccessDenied)
|
||||
newAPIError = types.NewError(errors.New("暂不支持使用 access token"), types.ErrorCodeAccessDenied, types.ErrOptionWithSkipRetry())
|
||||
return
|
||||
}
|
||||
|
||||
playgroundRequest := &dto.PlayGroundRequest{}
|
||||
err := common.UnmarshalBodyReusable(c, playgroundRequest)
|
||||
if err != nil {
|
||||
newAPIError = types.NewError(err, types.ErrorCodeInvalidRequest)
|
||||
newAPIError = types.NewError(err, types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
|
||||
return
|
||||
}
|
||||
|
||||
if playgroundRequest.Model == "" {
|
||||
newAPIError = types.NewError(errors.New("请选择模型"), types.ErrorCodeInvalidRequest)
|
||||
newAPIError = types.NewError(errors.New("请选择模型"), types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
|
||||
return
|
||||
}
|
||||
c.Set("original_model", playgroundRequest.Model)
|
||||
@@ -51,7 +51,7 @@ func Playground(c *gin.Context) {
|
||||
group = userGroup
|
||||
} else {
|
||||
if !setting.GroupInUserUsableGroups(group) && group != userGroup {
|
||||
newAPIError = types.NewError(errors.New("无权访问该分组"), types.ErrorCodeAccessDenied)
|
||||
newAPIError = types.NewError(errors.New("无权访问该分组"), types.ErrorCodeAccessDenied, types.ErrOptionWithSkipRetry())
|
||||
return
|
||||
}
|
||||
c.Set("group", group)
|
||||
@@ -62,7 +62,7 @@ func Playground(c *gin.Context) {
|
||||
// Write user context to ensure acceptUnsetRatio is available
|
||||
userCache, err := model.GetUserCache(userId)
|
||||
if err != nil {
|
||||
newAPIError = types.NewError(err, types.ErrorCodeQueryDataError)
|
||||
newAPIError = types.NewError(err, types.ErrorCodeQueryDataError, types.ErrOptionWithSkipRetry())
|
||||
return
|
||||
}
|
||||
userCache.WriteContext(c)
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"one-api/common"
|
||||
"one-api/model"
|
||||
"strconv"
|
||||
"unicode/utf8"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
@@ -63,7 +64,7 @@ func AddRedemption(c *gin.Context) {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
if len(redemption.Name) == 0 || len(redemption.Name) > 20 {
|
||||
if utf8.RuneCountInString(redemption.Name) == 0 || utf8.RuneCountInString(redemption.Name) > 20 {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "兑换码名称长度必须在1-20之间",
|
||||
|
||||
@@ -47,7 +47,7 @@ func relayHandler(c *gin.Context, relayMode int) *types.NewAPIError {
|
||||
err = relay.TextHelper(c)
|
||||
}
|
||||
|
||||
if constant2.ErrorLogEnabled && err != nil {
|
||||
if constant2.ErrorLogEnabled && err != nil && types.IsRecordErrorLog(err) {
|
||||
// 保存错误日志到mysql中
|
||||
userId := c.GetInt("id")
|
||||
tokenName := c.GetString("token_name")
|
||||
@@ -56,14 +56,21 @@ func relayHandler(c *gin.Context, relayMode int) *types.NewAPIError {
|
||||
userGroup := c.GetString("group")
|
||||
channelId := c.GetInt("channel_id")
|
||||
other := make(map[string]interface{})
|
||||
other["error_type"] = err.ErrorType
|
||||
other["error_type"] = err.GetErrorType()
|
||||
other["error_code"] = err.GetErrorCode()
|
||||
other["status_code"] = err.StatusCode
|
||||
other["channel_id"] = channelId
|
||||
other["channel_name"] = c.GetString("channel_name")
|
||||
other["channel_type"] = c.GetInt("channel_type")
|
||||
|
||||
model.RecordErrorLog(c, userId, channelId, modelName, tokenName, err.Error(), tokenId, 0, false, userGroup, other)
|
||||
adminInfo := make(map[string]interface{})
|
||||
adminInfo["use_channel"] = c.GetStringSlice("use_channel")
|
||||
isMultiKey := common.GetContextKeyBool(c, constant.ContextKeyChannelIsMultiKey)
|
||||
if isMultiKey {
|
||||
adminInfo["is_multi_key"] = true
|
||||
adminInfo["multi_key_index"] = common.GetContextKeyInt(c, constant.ContextKeyChannelMultiKeyIndex)
|
||||
}
|
||||
other["admin_info"] = adminInfo
|
||||
model.RecordErrorLog(c, userId, channelId, modelName, tokenName, err.MaskSensitiveError(), tokenId, 0, false, userGroup, other)
|
||||
}
|
||||
|
||||
return err
|
||||
@@ -128,7 +135,7 @@ func WssRelay(c *gin.Context) {
|
||||
defer ws.Close()
|
||||
|
||||
if err != nil {
|
||||
helper.WssError(c, ws, types.NewError(err, types.ErrorCodeGetChannelFailed).ToOpenAIError())
|
||||
helper.WssError(c, ws, types.NewError(err, types.ErrorCodeGetChannelFailed, types.ErrOptionWithSkipRetry()).ToOpenAIError())
|
||||
return
|
||||
}
|
||||
|
||||
@@ -259,10 +266,10 @@ func getChannel(c *gin.Context, group, originalModel string, retryCount int) (*m
|
||||
}
|
||||
channel, selectGroup, err := model.CacheGetRandomSatisfiedChannel(c, group, originalModel, retryCount)
|
||||
if err != nil {
|
||||
return nil, types.NewError(errors.New(fmt.Sprintf("获取分组 %s 下模型 %s 的可用渠道失败(retry): %s", selectGroup, originalModel, err.Error())), types.ErrorCodeGetChannelFailed)
|
||||
return nil, types.NewError(errors.New(fmt.Sprintf("获取分组 %s 下模型 %s 的可用渠道失败(retry): %s", selectGroup, originalModel, err.Error())), types.ErrorCodeGetChannelFailed, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
if channel == nil {
|
||||
return nil, types.NewError(errors.New(fmt.Sprintf("分组 %s 下模型 %s 的可用渠道不存在(数据库一致性已被破坏,retry)", selectGroup, originalModel)), types.ErrorCodeGetChannelFailed)
|
||||
return nil, types.NewError(errors.New(fmt.Sprintf("分组 %s 下模型 %s 的可用渠道不存在(数据库一致性已被破坏,retry)", selectGroup, originalModel)), types.ErrorCodeGetChannelFailed, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
newAPIError := middleware.SetupContextForSelectedChannel(c, channel, originalModel)
|
||||
if newAPIError != nil {
|
||||
@@ -278,7 +285,7 @@ func shouldRetry(c *gin.Context, openaiErr *types.NewAPIError, retryTimes int) b
|
||||
if types.IsChannelError(openaiErr) {
|
||||
return true
|
||||
}
|
||||
if types.IsLocalError(openaiErr) {
|
||||
if types.IsSkipRetryError(openaiErr) {
|
||||
return false
|
||||
}
|
||||
if retryTimes <= 0 {
|
||||
|
||||
@@ -83,7 +83,7 @@ func updateVideoSingleTask(ctx context.Context, adaptor channel.TaskAdaptor, cha
|
||||
taskResult := &relaycommon.TaskInfo{}
|
||||
// try parse as New API response format
|
||||
var responseItems dto.TaskResponse[model.Task]
|
||||
if err = json.Unmarshal(responseBody, &responseItems); err == nil {
|
||||
if err = json.Unmarshal(responseBody, &responseItems); err == nil && responseItems.IsSuccess() {
|
||||
t := responseItems.Data
|
||||
taskResult.TaskID = t.TaskID
|
||||
taskResult.Status = string(t.Status)
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
package dto
|
||||
|
||||
type ChannelSettings struct {
|
||||
ForceFormat bool `json:"force_format,omitempty"`
|
||||
ThinkingToContent bool `json:"thinking_to_content,omitempty"`
|
||||
Proxy string `json:"proxy"`
|
||||
ForceFormat bool `json:"force_format,omitempty"`
|
||||
ThinkingToContent bool `json:"thinking_to_content,omitempty"`
|
||||
Proxy string `json:"proxy"`
|
||||
PassThroughBodyEnabled bool `json:"pass_through_body_enabled,omitempty"`
|
||||
SystemPrompt string `json:"system_prompt,omitempty"`
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@ package dto
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"one-api/common"
|
||||
"one-api/types"
|
||||
)
|
||||
@@ -284,14 +285,9 @@ func (c *ClaudeRequest) ParseSystem() []ClaudeMediaMessage {
|
||||
return mediaContent
|
||||
}
|
||||
|
||||
type ClaudeError struct {
|
||||
Type string `json:"type,omitempty"`
|
||||
Message string `json:"message,omitempty"`
|
||||
}
|
||||
|
||||
type ClaudeErrorWithStatusCode struct {
|
||||
Error ClaudeError `json:"error"`
|
||||
StatusCode int `json:"status_code"`
|
||||
Error types.ClaudeError `json:"error"`
|
||||
StatusCode int `json:"status_code"`
|
||||
LocalError bool
|
||||
}
|
||||
|
||||
@@ -303,7 +299,7 @@ type ClaudeResponse struct {
|
||||
Completion string `json:"completion,omitempty"`
|
||||
StopReason string `json:"stop_reason,omitempty"`
|
||||
Model string `json:"model,omitempty"`
|
||||
Error *types.ClaudeError `json:"error,omitempty"`
|
||||
Error any `json:"error,omitempty"`
|
||||
Usage *ClaudeUsage `json:"usage,omitempty"`
|
||||
Index *int `json:"index,omitempty"`
|
||||
ContentBlock *ClaudeMediaMessage `json:"content_block,omitempty"`
|
||||
@@ -324,6 +320,42 @@ func (c *ClaudeResponse) GetIndex() int {
|
||||
return *c.Index
|
||||
}
|
||||
|
||||
// GetClaudeError 从动态错误类型中提取ClaudeError结构
|
||||
func (c *ClaudeResponse) GetClaudeError() *types.ClaudeError {
|
||||
if c.Error == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
switch err := c.Error.(type) {
|
||||
case types.ClaudeError:
|
||||
return &err
|
||||
case *types.ClaudeError:
|
||||
return err
|
||||
case map[string]interface{}:
|
||||
// 处理从JSON解析来的map结构
|
||||
claudeErr := &types.ClaudeError{}
|
||||
if errType, ok := err["type"].(string); ok {
|
||||
claudeErr.Type = errType
|
||||
}
|
||||
if errMsg, ok := err["message"].(string); ok {
|
||||
claudeErr.Message = errMsg
|
||||
}
|
||||
return claudeErr
|
||||
case string:
|
||||
// 处理简单字符串错误
|
||||
return &types.ClaudeError{
|
||||
Type: "error",
|
||||
Message: err,
|
||||
}
|
||||
default:
|
||||
// 未知类型,尝试转换为字符串
|
||||
return &types.ClaudeError{
|
||||
Type: "unknown_error",
|
||||
Message: fmt.Sprintf("%v", err),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type ClaudeUsage struct {
|
||||
InputTokens int `json:"input_tokens"`
|
||||
CacheCreationInputTokens int `json:"cache_creation_input_tokens"`
|
||||
|
||||
@@ -1,6 +1,9 @@
|
||||
package gemini
|
||||
package dto
|
||||
|
||||
import "encoding/json"
|
||||
import (
|
||||
"encoding/json"
|
||||
"one-api/common"
|
||||
)
|
||||
|
||||
type GeminiChatRequest struct {
|
||||
Contents []GeminiChatContent `json:"contents"`
|
||||
@@ -32,7 +35,7 @@ func (g *GeminiInlineData) UnmarshalJSON(data []byte) error {
|
||||
MimeTypeSnake string `json:"mime_type"`
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(data, &aux); err != nil {
|
||||
if err := common.Unmarshal(data, &aux); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -53,7 +56,7 @@ type FunctionCall struct {
|
||||
Arguments any `json:"args"`
|
||||
}
|
||||
|
||||
type FunctionResponse struct {
|
||||
type GeminiFunctionResponse struct {
|
||||
Name string `json:"name"`
|
||||
Response map[string]interface{} `json:"response"`
|
||||
}
|
||||
@@ -78,7 +81,7 @@ type GeminiPart struct {
|
||||
Thought bool `json:"thought,omitempty"`
|
||||
InlineData *GeminiInlineData `json:"inlineData,omitempty"`
|
||||
FunctionCall *FunctionCall `json:"functionCall,omitempty"`
|
||||
FunctionResponse *FunctionResponse `json:"functionResponse,omitempty"`
|
||||
FunctionResponse *GeminiFunctionResponse `json:"functionResponse,omitempty"`
|
||||
FileData *GeminiFileData `json:"fileData,omitempty"`
|
||||
ExecutableCode *GeminiPartExecutableCode `json:"executableCode,omitempty"`
|
||||
CodeExecutionResult *GeminiPartCodeExecutionResult `json:"codeExecutionResult,omitempty"`
|
||||
@@ -93,7 +96,7 @@ func (p *GeminiPart) UnmarshalJSON(data []byte) error {
|
||||
InlineDataSnake *GeminiInlineData `json:"inline_data,omitempty"` // snake_case variant
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(data, &aux); err != nil {
|
||||
if err := common.Unmarshal(data, &aux); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -7,15 +7,15 @@ import (
|
||||
)
|
||||
|
||||
type ResponseFormat struct {
|
||||
Type string `json:"type,omitempty"`
|
||||
JsonSchema *FormatJsonSchema `json:"json_schema,omitempty"`
|
||||
Type string `json:"type,omitempty"`
|
||||
JsonSchema json.RawMessage `json:"json_schema,omitempty"`
|
||||
}
|
||||
|
||||
type FormatJsonSchema struct {
|
||||
Description string `json:"description,omitempty"`
|
||||
Name string `json:"name"`
|
||||
Schema any `json:"schema,omitempty"`
|
||||
Strict any `json:"strict,omitempty"`
|
||||
Description string `json:"description,omitempty"`
|
||||
Name string `json:"name"`
|
||||
Schema any `json:"schema,omitempty"`
|
||||
Strict json.RawMessage `json:"strict,omitempty"`
|
||||
}
|
||||
|
||||
type GeneralOpenAIRequest struct {
|
||||
@@ -73,6 +73,15 @@ func (r *GeneralOpenAIRequest) ToMap() map[string]any {
|
||||
return result
|
||||
}
|
||||
|
||||
func (r *GeneralOpenAIRequest) GetSystemRoleName() string {
|
||||
if strings.HasPrefix(r.Model, "o") {
|
||||
if !strings.HasPrefix(r.Model, "o1-mini") && !strings.HasPrefix(r.Model, "o1-preview") {
|
||||
return "developer"
|
||||
}
|
||||
}
|
||||
return "system"
|
||||
}
|
||||
|
||||
type ToolCallRequest struct {
|
||||
ID string `json:"id,omitempty"`
|
||||
Type string `json:"type"`
|
||||
|
||||
@@ -2,12 +2,18 @@ package dto
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"one-api/types"
|
||||
)
|
||||
|
||||
type SimpleResponse struct {
|
||||
Usage `json:"usage"`
|
||||
Error *OpenAIError `json:"error"`
|
||||
Error any `json:"error"`
|
||||
}
|
||||
|
||||
// GetOpenAIError 从动态错误类型中提取OpenAIError结构
|
||||
func (s *SimpleResponse) GetOpenAIError() *types.OpenAIError {
|
||||
return GetOpenAIError(s.Error)
|
||||
}
|
||||
|
||||
type TextResponse struct {
|
||||
@@ -31,10 +37,15 @@ type OpenAITextResponse struct {
|
||||
Object string `json:"object"`
|
||||
Created any `json:"created"`
|
||||
Choices []OpenAITextResponseChoice `json:"choices"`
|
||||
Error *types.OpenAIError `json:"error,omitempty"`
|
||||
Error any `json:"error,omitempty"`
|
||||
Usage `json:"usage"`
|
||||
}
|
||||
|
||||
// GetOpenAIError 从动态错误类型中提取OpenAIError结构
|
||||
func (o *OpenAITextResponse) GetOpenAIError() *types.OpenAIError {
|
||||
return GetOpenAIError(o.Error)
|
||||
}
|
||||
|
||||
type OpenAIEmbeddingResponseItem struct {
|
||||
Object string `json:"object"`
|
||||
Index int `json:"index"`
|
||||
@@ -217,7 +228,7 @@ type OpenAIResponsesResponse struct {
|
||||
Object string `json:"object"`
|
||||
CreatedAt int `json:"created_at"`
|
||||
Status string `json:"status"`
|
||||
Error *types.OpenAIError `json:"error,omitempty"`
|
||||
Error any `json:"error,omitempty"`
|
||||
IncompleteDetails *IncompleteDetails `json:"incomplete_details,omitempty"`
|
||||
Instructions string `json:"instructions"`
|
||||
MaxOutputTokens int `json:"max_output_tokens"`
|
||||
@@ -237,6 +248,11 @@ type OpenAIResponsesResponse struct {
|
||||
Metadata json.RawMessage `json:"metadata"`
|
||||
}
|
||||
|
||||
// GetOpenAIError 从动态错误类型中提取OpenAIError结构
|
||||
func (o *OpenAIResponsesResponse) GetOpenAIError() *types.OpenAIError {
|
||||
return GetOpenAIError(o.Error)
|
||||
}
|
||||
|
||||
type IncompleteDetails struct {
|
||||
Reasoning string `json:"reasoning"`
|
||||
}
|
||||
@@ -276,3 +292,45 @@ type ResponsesStreamResponse struct {
|
||||
Delta string `json:"delta,omitempty"`
|
||||
Item *ResponsesOutput `json:"item,omitempty"`
|
||||
}
|
||||
|
||||
// GetOpenAIError 从动态错误类型中提取OpenAIError结构
|
||||
func GetOpenAIError(errorField any) *types.OpenAIError {
|
||||
if errorField == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
switch err := errorField.(type) {
|
||||
case types.OpenAIError:
|
||||
return &err
|
||||
case *types.OpenAIError:
|
||||
return err
|
||||
case map[string]interface{}:
|
||||
// 处理从JSON解析来的map结构
|
||||
openaiErr := &types.OpenAIError{}
|
||||
if errType, ok := err["type"].(string); ok {
|
||||
openaiErr.Type = errType
|
||||
}
|
||||
if errMsg, ok := err["message"].(string); ok {
|
||||
openaiErr.Message = errMsg
|
||||
}
|
||||
if errParam, ok := err["param"].(string); ok {
|
||||
openaiErr.Param = errParam
|
||||
}
|
||||
if errCode, ok := err["code"]; ok {
|
||||
openaiErr.Code = errCode
|
||||
}
|
||||
return openaiErr
|
||||
case string:
|
||||
// 处理简单字符串错误
|
||||
return &types.OpenAIError{
|
||||
Type: "error",
|
||||
Message: err,
|
||||
}
|
||||
default:
|
||||
// 未知类型,尝试转换为字符串
|
||||
return &types.OpenAIError{
|
||||
Type: "unknown_error",
|
||||
Message: fmt.Sprintf("%v", err),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -585,6 +585,19 @@
|
||||
"渠道权重": "渠道权重",
|
||||
"渠道额外设置": "渠道额外设置",
|
||||
"此项可选,用于配置渠道特定设置,为一个 JSON 字符串,例如:": "此项可选,用于配置渠道特定设置,为一个 JSON 字符串,例如:",
|
||||
"强制格式化": "强制格式化",
|
||||
"强制将响应格式化为 OpenAI 标准格式(只适用于OpenAI渠道类型)": "强制将响应格式化为 OpenAI 标准格式(只适用于OpenAI渠道类型)",
|
||||
"思考内容转换": "思考内容转换",
|
||||
"将 reasoning_content 转换为 <think> 标签拼接到内容中": "将 reasoning_content 转换为 <think> 标签拼接到内容中",
|
||||
"透传请求体": "透传请求体",
|
||||
"启用请求体透传功能": "启用请求体透传功能",
|
||||
"代理地址": "代理地址",
|
||||
"例如: socks5://user:pass@host:port": "例如: socks5://user:pass@host:port",
|
||||
"用于配置网络代理": "用于配置网络代理",
|
||||
"用于配置网络代理,支持 socks5 协议": "用于配置网络代理,支持 socks5 协议",
|
||||
"系统提示词": "系统提示词",
|
||||
"输入系统提示词,用户的系统提示词将优先于此设置": "输入系统提示词,用户的系统提示词将优先于此设置",
|
||||
"用户优先:如果用户在请求中指定了系统提示词,将优先使用用户的设置": "用户优先:如果用户在请求中指定了系统提示词,将优先使用用户的设置",
|
||||
"参数覆盖": "参数覆盖",
|
||||
"此项可选,用于覆盖请求参数。不支持覆盖 stream 参数。为一个 JSON 字符串,例如:": "此项可选,用于覆盖请求参数。不支持覆盖 stream 参数。为一个 JSON 字符串,例如:",
|
||||
"请输入组织org-xxx": "请输入组织org-xxx",
|
||||
|
||||
@@ -122,6 +122,7 @@ func authHelper(c *gin.Context, minRole int) {
|
||||
c.Set("role", role)
|
||||
c.Set("id", id)
|
||||
c.Set("group", session.Get("group"))
|
||||
c.Set("user_group", session.Get("group"))
|
||||
c.Set("use_access_token", useAccessToken)
|
||||
|
||||
//userCache, err := model.GetUserCache(id.(int))
|
||||
|
||||
@@ -100,6 +100,10 @@ func Distribute() func(c *gin.Context) {
|
||||
}
|
||||
|
||||
if shouldSelectChannel {
|
||||
if modelRequest.Model == "" {
|
||||
abortWithOpenAiMessage(c, http.StatusBadRequest, "未指定模型名称,模型名称不能为空")
|
||||
return
|
||||
}
|
||||
var selectGroup string
|
||||
channel, selectGroup, err = model.CacheGetRandomSatisfiedChannel(c, userGroup, modelRequest.Model, 0)
|
||||
if err != nil {
|
||||
@@ -107,18 +111,17 @@ func Distribute() func(c *gin.Context) {
|
||||
if userGroup == "auto" {
|
||||
showGroup = fmt.Sprintf("auto(%s)", selectGroup)
|
||||
}
|
||||
message := fmt.Sprintf("获取分组 %s 下模型 %s 的可用渠道失败(distributor): %s", showGroup, modelRequest.Model, err.Error())
|
||||
message := fmt.Sprintf("获取分组 %s 下模型 %s 的可用渠道失败(数据库一致性已被破坏,distributor): %s", showGroup, modelRequest.Model, err.Error())
|
||||
// 如果错误,但是渠道不为空,说明是数据库一致性问题
|
||||
if channel != nil {
|
||||
common.SysError(fmt.Sprintf("渠道不存在:%d", channel.Id))
|
||||
message = "数据库一致性已被破坏,请联系管理员"
|
||||
}
|
||||
// 如果错误,而且渠道为空,说明是没有可用渠道
|
||||
//if channel != nil {
|
||||
// common.SysError(fmt.Sprintf("渠道不存在:%d", channel.Id))
|
||||
// message = "数据库一致性已被破坏,请联系管理员"
|
||||
//}
|
||||
abortWithOpenAiMessage(c, http.StatusServiceUnavailable, message)
|
||||
return
|
||||
}
|
||||
if channel == nil {
|
||||
abortWithOpenAiMessage(c, http.StatusServiceUnavailable, fmt.Sprintf("分组 %s 下模型 %s 的可用渠道不存在(数据库一致性已被破坏,distributor)", userGroup, modelRequest.Model))
|
||||
abortWithOpenAiMessage(c, http.StatusServiceUnavailable, fmt.Sprintf("分组 %s 下模型 %s 无可用渠道(distributor)", userGroup, modelRequest.Model))
|
||||
return
|
||||
}
|
||||
}
|
||||
@@ -244,7 +247,7 @@ func getModelRequest(c *gin.Context) (*ModelRequest, bool, error) {
|
||||
func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, modelName string) *types.NewAPIError {
|
||||
c.Set("original_model", modelName) // for retry
|
||||
if channel == nil {
|
||||
return types.NewError(errors.New("channel is nil"), types.ErrorCodeGetChannelFailed)
|
||||
return types.NewError(errors.New("channel is nil"), types.ErrorCodeGetChannelFailed, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
common.SetContextKey(c, constant.ContextKeyChannelId, channel.Id)
|
||||
common.SetContextKey(c, constant.ContextKeyChannelName, channel.Name)
|
||||
@@ -266,6 +269,9 @@ func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, mode
|
||||
if channel.ChannelInfo.IsMultiKey {
|
||||
common.SetContextKey(c, constant.ContextKeyChannelIsMultiKey, true)
|
||||
common.SetContextKey(c, constant.ContextKeyChannelMultiKeyIndex, index)
|
||||
} else {
|
||||
// 必须设置为 false,否则在重试到单个 key 的时候会导致日志显示错误
|
||||
common.SetContextKey(c, constant.ContextKeyChannelIsMultiKey, false)
|
||||
}
|
||||
// c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", key))
|
||||
common.SetContextKey(c, constant.ContextKeyChannelKey, key)
|
||||
|
||||
@@ -136,7 +136,7 @@ func GetRandomSatisfiedChannel(group string, model string, retry int) (*Channel,
|
||||
}
|
||||
}
|
||||
} else {
|
||||
return nil, errors.New("channel not found")
|
||||
return nil, nil
|
||||
}
|
||||
err = DB.First(&channel, "id = ?", channel.Id).Error
|
||||
return &channel, err
|
||||
|
||||
@@ -75,7 +75,7 @@ func (channel *Channel) getKeys() []string {
|
||||
// If the key starts with '[', try to parse it as a JSON array (e.g., for Vertex AI scenarios)
|
||||
if strings.HasPrefix(trimmed, "[") {
|
||||
var arr []json.RawMessage
|
||||
if err := json.Unmarshal([]byte(trimmed), &arr); err == nil {
|
||||
if err := common.Unmarshal([]byte(trimmed), &arr); err == nil {
|
||||
res := make([]string, len(arr))
|
||||
for i, v := range arr {
|
||||
res[i] = string(v)
|
||||
@@ -138,7 +138,7 @@ func (channel *Channel) GetNextEnabledKey() (string, int, *types.NewAPIError) {
|
||||
|
||||
channelInfo, err := CacheGetChannelInfo(channel.Id)
|
||||
if err != nil {
|
||||
return "", 0, types.NewError(err, types.ErrorCodeGetChannelFailed)
|
||||
return "", 0, types.NewError(err, types.ErrorCodeGetChannelFailed, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
//println("before polling index:", channel.ChannelInfo.MultiKeyPollingIndex)
|
||||
defer func() {
|
||||
@@ -197,7 +197,7 @@ func (channel *Channel) GetGroups() []string {
|
||||
func (channel *Channel) GetOtherInfo() map[string]interface{} {
|
||||
otherInfo := make(map[string]interface{})
|
||||
if channel.OtherInfo != "" {
|
||||
err := json.Unmarshal([]byte(channel.OtherInfo), &otherInfo)
|
||||
err := common.Unmarshal([]byte(channel.OtherInfo), &otherInfo)
|
||||
if err != nil {
|
||||
common.SysError("failed to unmarshal other info: " + err.Error())
|
||||
}
|
||||
@@ -425,7 +425,7 @@ func (channel *Channel) Update() error {
|
||||
trimmed := strings.TrimSpace(keyStr)
|
||||
if strings.HasPrefix(trimmed, "[") {
|
||||
var arr []json.RawMessage
|
||||
if err := json.Unmarshal([]byte(trimmed), &arr); err == nil {
|
||||
if err := common.Unmarshal([]byte(trimmed), &arr); err == nil {
|
||||
keys = make([]string, len(arr))
|
||||
for i, v := range arr {
|
||||
keys[i] = string(v)
|
||||
@@ -571,10 +571,6 @@ func UpdateChannelStatus(channelId int, usingKey string, status int, reason stri
|
||||
if channelCache.Status == status {
|
||||
return false
|
||||
}
|
||||
// 如果缓存渠道不存在(说明已经被禁用),且要设置的状态不为启用,直接返回
|
||||
if status != common.ChannelStatusEnabled {
|
||||
return false
|
||||
}
|
||||
CacheUpdateChannelStatus(channelId, status)
|
||||
}
|
||||
}
|
||||
@@ -778,7 +774,7 @@ func SearchTags(keyword string, group string, model string, idSort bool) ([]*str
|
||||
func (channel *Channel) ValidateSettings() error {
|
||||
channelParams := &dto.ChannelSettings{}
|
||||
if channel.Setting != nil && *channel.Setting != "" {
|
||||
err := json.Unmarshal([]byte(*channel.Setting), channelParams)
|
||||
err := common.Unmarshal([]byte(*channel.Setting), channelParams)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -789,7 +785,7 @@ func (channel *Channel) ValidateSettings() error {
|
||||
func (channel *Channel) GetSetting() dto.ChannelSettings {
|
||||
setting := dto.ChannelSettings{}
|
||||
if channel.Setting != nil && *channel.Setting != "" {
|
||||
err := json.Unmarshal([]byte(*channel.Setting), &setting)
|
||||
err := common.Unmarshal([]byte(*channel.Setting), &setting)
|
||||
if err != nil {
|
||||
common.SysError("failed to unmarshal setting: " + err.Error())
|
||||
channel.Setting = nil // 清空设置以避免后续错误
|
||||
@@ -800,7 +796,7 @@ func (channel *Channel) GetSetting() dto.ChannelSettings {
|
||||
}
|
||||
|
||||
func (channel *Channel) SetSetting(setting dto.ChannelSettings) {
|
||||
settingBytes, err := json.Marshal(setting)
|
||||
settingBytes, err := common.Marshal(setting)
|
||||
if err != nil {
|
||||
common.SysError("failed to marshal setting: " + err.Error())
|
||||
return
|
||||
@@ -811,7 +807,7 @@ func (channel *Channel) SetSetting(setting dto.ChannelSettings) {
|
||||
func (channel *Channel) GetParamOverride() map[string]interface{} {
|
||||
paramOverride := make(map[string]interface{})
|
||||
if channel.ParamOverride != nil && *channel.ParamOverride != "" {
|
||||
err := json.Unmarshal([]byte(*channel.ParamOverride), ¶mOverride)
|
||||
err := common.Unmarshal([]byte(*channel.ParamOverride), ¶mOverride)
|
||||
if err != nil {
|
||||
common.SysError("failed to unmarshal param override: " + err.Error())
|
||||
}
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"one-api/common"
|
||||
"one-api/constant"
|
||||
"one-api/setting"
|
||||
"sort"
|
||||
"strings"
|
||||
@@ -66,6 +67,15 @@ func InitChannelCache() {
|
||||
|
||||
channelSyncLock.Lock()
|
||||
group2model2channels = newGroup2model2channels
|
||||
//channelsIDM = newChannelId2channel
|
||||
for i, channel := range newChannelId2channel {
|
||||
if oldChannel, ok := channelsIDM[i]; ok {
|
||||
// 存在旧的渠道,如果是多key且轮询,保留轮询索引信息
|
||||
if oldChannel.ChannelInfo.IsMultiKey && oldChannel.ChannelInfo.MultiKeyMode == constant.MultiKeyModePolling {
|
||||
channel.ChannelInfo.MultiKeyPollingIndex = oldChannel.ChannelInfo.MultiKeyPollingIndex
|
||||
}
|
||||
}
|
||||
}
|
||||
channelsIDM = newChannelId2channel
|
||||
channelSyncLock.Unlock()
|
||||
common.SysLog("channels synced from database")
|
||||
@@ -130,7 +140,7 @@ func getRandomSatisfiedChannel(group string, model string, retry int) (*Channel,
|
||||
channels := group2model2channels[group][model]
|
||||
|
||||
if len(channels) == 0 {
|
||||
return nil, errors.New("channel not found")
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
if len(channels) == 1 {
|
||||
@@ -203,9 +213,6 @@ func CacheGetChannel(id int) (*Channel, error) {
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("渠道# %d,已不存在", id)
|
||||
}
|
||||
if c.Status != common.ChannelStatusEnabled {
|
||||
return nil, fmt.Errorf("渠道# %d,已被禁用", id)
|
||||
}
|
||||
return c, nil
|
||||
}
|
||||
|
||||
@@ -224,9 +231,6 @@ func CacheGetChannelInfo(id int) (*ChannelInfo, error) {
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("渠道# %d,已不存在", id)
|
||||
}
|
||||
if c.Status != common.ChannelStatusEnabled {
|
||||
return nil, fmt.Errorf("渠道# %d,已被禁用", id)
|
||||
}
|
||||
return &c.ChannelInfo, nil
|
||||
}
|
||||
|
||||
@@ -239,6 +243,20 @@ func CacheUpdateChannelStatus(id int, status int) {
|
||||
if channel, ok := channelsIDM[id]; ok {
|
||||
channel.Status = status
|
||||
}
|
||||
if status != common.ChannelStatusEnabled {
|
||||
// delete the channel from group2model2channels
|
||||
for group, model2channels := range group2model2channels {
|
||||
for model, channels := range model2channels {
|
||||
for i, channelId := range channels {
|
||||
if channelId == id {
|
||||
// remove the channel from the slice
|
||||
group2model2channels[group][model] = append(channels[:i], channels[i+1:]...)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func CacheUpdateChannel(channel *Channel) {
|
||||
|
||||
@@ -62,7 +62,7 @@ func AudioHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
|
||||
|
||||
if err != nil {
|
||||
common.LogError(c, fmt.Sprintf("getAndValidAudioRequest failed: %s", err.Error()))
|
||||
return types.NewError(err, types.ErrorCodeInvalidRequest)
|
||||
return types.NewError(err, types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
|
||||
promptTokens := 0
|
||||
@@ -75,7 +75,7 @@ func AudioHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
|
||||
|
||||
priceData, err := helper.ModelPriceHelper(c, relayInfo, preConsumedTokens, 0)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeModelPriceError)
|
||||
return types.NewError(err, types.ErrorCodeModelPriceError, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
|
||||
preConsumedQuota, userQuota, openaiErr := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo)
|
||||
@@ -90,18 +90,18 @@ func AudioHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
|
||||
|
||||
err = helper.ModelMappedHelper(c, relayInfo, audioRequest)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeChannelModelMappedError)
|
||||
return types.NewError(err, types.ErrorCodeChannelModelMappedError, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
|
||||
adaptor := GetAdaptor(relayInfo.ApiType)
|
||||
if adaptor == nil {
|
||||
return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType)
|
||||
return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
adaptor.Init(relayInfo)
|
||||
|
||||
ioReader, err := adaptor.ConvertAudioRequest(c, relayInfo, *audioRequest)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeConvertRequestFailed)
|
||||
return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
|
||||
resp, err := adaptor.DoRequest(c, relayInfo, ioReader)
|
||||
|
||||
@@ -26,6 +26,7 @@ type Adaptor interface {
|
||||
GetModelList() []string
|
||||
GetChannelName() string
|
||||
ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.ClaudeRequest) (any, error)
|
||||
ConvertGeminiRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeminiChatRequest) (any, error)
|
||||
}
|
||||
|
||||
type TaskAdaptor interface {
|
||||
|
||||
@@ -18,6 +18,11 @@ import (
|
||||
type Adaptor struct {
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
|
||||
//TODO implement me
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
|
||||
//TODO implement me
|
||||
panic("implement me")
|
||||
|
||||
@@ -132,12 +132,12 @@ func aliImageHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rela
|
||||
var aliTaskResponse AliResponse
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeReadResponseBodyFailed), nil
|
||||
return types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError), nil
|
||||
}
|
||||
common.CloseResponseBodyGracefully(resp)
|
||||
err = json.Unmarshal(responseBody, &aliTaskResponse)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeBadResponseBody), nil
|
||||
return types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError), nil
|
||||
}
|
||||
|
||||
if aliTaskResponse.Message != "" {
|
||||
|
||||
@@ -34,14 +34,14 @@ func ConvertRerankRequest(request dto.RerankRequest) *AliRerankRequest {
|
||||
func RerankHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*types.NewAPIError, *dto.Usage) {
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeReadResponseBodyFailed), nil
|
||||
return types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError), nil
|
||||
}
|
||||
common.CloseResponseBodyGracefully(resp)
|
||||
|
||||
var aliResponse AliRerankResponse
|
||||
err = json.Unmarshal(responseBody, &aliResponse)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeBadResponseBody), nil
|
||||
return types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError), nil
|
||||
}
|
||||
|
||||
if aliResponse.Code != "" {
|
||||
|
||||
@@ -43,7 +43,7 @@ func aliEmbeddingHandler(c *gin.Context, resp *http.Response) (*types.NewAPIErro
|
||||
var fullTextResponse dto.FlexibleEmbeddingResponse
|
||||
err := json.NewDecoder(resp.Body).Decode(&fullTextResponse)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeBadResponseBody), nil
|
||||
return types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError), nil
|
||||
}
|
||||
|
||||
common.CloseResponseBodyGracefully(resp)
|
||||
@@ -179,12 +179,12 @@ func aliHandler(c *gin.Context, resp *http.Response) (*types.NewAPIError, *dto.U
|
||||
var aliResponse AliResponse
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeReadResponseBodyFailed), nil
|
||||
return types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError), nil
|
||||
}
|
||||
common.CloseResponseBodyGracefully(resp)
|
||||
err = json.Unmarshal(responseBody, &aliResponse)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeBadResponseBody), nil
|
||||
return types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError), nil
|
||||
}
|
||||
if aliResponse.Code != "" {
|
||||
return types.WithOpenAIError(types.OpenAIError{
|
||||
|
||||
@@ -223,7 +223,7 @@ func doRequest(c *gin.Context, req *http.Request, info *common.RelayInfo) (*http
|
||||
helper.SetEventStreamHeaders(c)
|
||||
// 处理流式请求的 ping 保活
|
||||
generalSettings := operation_setting.GetGeneralSetting()
|
||||
if generalSettings.PingIntervalEnabled {
|
||||
if generalSettings.PingIntervalEnabled && !info.DisablePing {
|
||||
pingInterval := time.Duration(generalSettings.PingIntervalSeconds) * time.Second
|
||||
stopPinger = startPingKeepAlive(c, pingInterval)
|
||||
// 使用defer确保在任何情况下都能停止ping goroutine
|
||||
|
||||
@@ -22,6 +22,11 @@ type Adaptor struct {
|
||||
RequestMode int
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
|
||||
//TODO implement me
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.ClaudeRequest) (any, error) {
|
||||
c.Set("request_model", request.Model)
|
||||
c.Set("converted_request", request)
|
||||
|
||||
@@ -18,6 +18,11 @@ import (
|
||||
type Adaptor struct {
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
|
||||
//TODO implement me
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
|
||||
//TODO implement me
|
||||
panic("implement me")
|
||||
|
||||
@@ -18,6 +18,11 @@ import (
|
||||
type Adaptor struct {
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
|
||||
//TODO implement me
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
|
||||
//TODO implement me
|
||||
panic("implement me")
|
||||
@@ -43,15 +48,15 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
|
||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
|
||||
channel.SetupApiRequestHeader(info, c, req)
|
||||
keyParts := strings.Split(info.ApiKey, "|")
|
||||
keyParts := strings.Split(info.ApiKey, "|")
|
||||
if len(keyParts) == 0 || keyParts[0] == "" {
|
||||
return errors.New("invalid API key: authorization token is required")
|
||||
}
|
||||
if len(keyParts) > 1 {
|
||||
if keyParts[1] != "" {
|
||||
req.Set("appid", keyParts[1])
|
||||
}
|
||||
}
|
||||
return errors.New("invalid API key: authorization token is required")
|
||||
}
|
||||
if len(keyParts) > 1 {
|
||||
if keyParts[1] != "" {
|
||||
req.Set("appid", keyParts[1])
|
||||
}
|
||||
}
|
||||
req.Set("Authorization", "Bearer "+keyParts[0])
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -24,6 +24,11 @@ type Adaptor struct {
|
||||
RequestMode int
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
|
||||
//TODO implement me
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.ClaudeRequest) (any, error) {
|
||||
return request, nil
|
||||
}
|
||||
|
||||
@@ -612,8 +612,8 @@ func HandleStreamResponseData(c *gin.Context, info *relaycommon.RelayInfo, claud
|
||||
common.SysError("error unmarshalling stream response: " + err.Error())
|
||||
return types.NewError(err, types.ErrorCodeBadResponseBody)
|
||||
}
|
||||
if claudeResponse.Error != nil && claudeResponse.Error.Type != "" {
|
||||
return types.WithClaudeError(*claudeResponse.Error, http.StatusInternalServerError)
|
||||
if claudeError := claudeResponse.GetClaudeError(); claudeError != nil && claudeError.Type != "" {
|
||||
return types.WithClaudeError(*claudeError, http.StatusInternalServerError)
|
||||
}
|
||||
if info.RelayFormat == relaycommon.RelayFormatClaude {
|
||||
FormatClaudeResponseInfo(requestMode, &claudeResponse, nil, claudeInfo)
|
||||
@@ -704,8 +704,8 @@ func HandleClaudeResponseData(c *gin.Context, info *relaycommon.RelayInfo, claud
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeBadResponseBody)
|
||||
}
|
||||
if claudeResponse.Error != nil && claudeResponse.Error.Type != "" {
|
||||
return types.WithClaudeError(*claudeResponse.Error, http.StatusInternalServerError)
|
||||
if claudeError := claudeResponse.GetClaudeError(); claudeError != nil && claudeError.Type != "" {
|
||||
return types.WithClaudeError(*claudeError, http.StatusInternalServerError)
|
||||
}
|
||||
if requestMode == RequestModeCompletion {
|
||||
completionTokens := service.CountTextToken(claudeResponse.Completion, info.OriginModelName)
|
||||
|
||||
@@ -18,6 +18,11 @@ import (
|
||||
type Adaptor struct {
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
|
||||
//TODO implement me
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
|
||||
//TODO implement me
|
||||
panic("implement me")
|
||||
|
||||
@@ -17,6 +17,11 @@ import (
|
||||
type Adaptor struct {
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
|
||||
//TODO implement me
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
|
||||
//TODO implement me
|
||||
panic("implement me")
|
||||
|
||||
@@ -18,6 +18,11 @@ import (
|
||||
type Adaptor struct {
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *common.RelayInfo, *dto.GeminiChatRequest) (any, error) {
|
||||
//TODO implement me
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
// ConvertAudioRequest implements channel.Adaptor.
|
||||
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *common.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
|
||||
@@ -19,6 +19,11 @@ import (
|
||||
type Adaptor struct {
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
|
||||
//TODO implement me
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
|
||||
//TODO implement me
|
||||
panic("implement me")
|
||||
|
||||
@@ -24,6 +24,11 @@ type Adaptor struct {
|
||||
BotType int
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
|
||||
//TODO implement me
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
|
||||
//TODO implement me
|
||||
panic("implement me")
|
||||
|
||||
@@ -1,14 +1,13 @@
|
||||
package gemini
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/dto"
|
||||
"one-api/relay/channel"
|
||||
"one-api/relay/channel/openai"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/relay/constant"
|
||||
"one-api/setting/model_setting"
|
||||
@@ -21,10 +20,33 @@ import (
|
||||
type Adaptor struct {
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
|
||||
//TODO implement me
|
||||
panic("implement me")
|
||||
return nil, nil
|
||||
func (a *Adaptor) ConvertGeminiRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeminiChatRequest) (any, error) {
|
||||
if len(request.Contents) > 0 {
|
||||
for i, content := range request.Contents {
|
||||
if i == 0 {
|
||||
if request.Contents[0].Role == "" {
|
||||
request.Contents[0].Role = "user"
|
||||
}
|
||||
}
|
||||
for _, part := range content.Parts {
|
||||
if part.FileData != nil {
|
||||
if part.FileData.MimeType == "" && strings.Contains(part.FileData.FileUri, "www.youtube.com") {
|
||||
part.FileData.MimeType = "video/webm"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return request, nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, req *dto.ClaudeRequest) (any, error) {
|
||||
adaptor := openai.Adaptor{}
|
||||
oaiReq, err := adaptor.ConvertClaudeRequest(c, info, req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return a.ConvertOpenAIRequest(c, info, oaiReq.(*dto.GeneralOpenAIRequest))
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
|
||||
@@ -49,13 +71,13 @@ func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInf
|
||||
}
|
||||
|
||||
// build gemini imagen request
|
||||
geminiRequest := GeminiImageRequest{
|
||||
Instances: []GeminiImageInstance{
|
||||
geminiRequest := dto.GeminiImageRequest{
|
||||
Instances: []dto.GeminiImageInstance{
|
||||
{
|
||||
Prompt: request.Prompt,
|
||||
},
|
||||
},
|
||||
Parameters: GeminiImageParameters{
|
||||
Parameters: dto.GeminiImageParameters{
|
||||
SampleCount: request.N,
|
||||
AspectRatio: aspectRatio,
|
||||
PersonGeneration: "allow_adult", // default allow adult
|
||||
@@ -136,9 +158,9 @@ func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.Rela
|
||||
}
|
||||
|
||||
// only process the first input
|
||||
geminiRequest := GeminiEmbeddingRequest{
|
||||
Content: GeminiChatContent{
|
||||
Parts: []GeminiPart{
|
||||
geminiRequest := dto.GeminiEmbeddingRequest{
|
||||
Content: dto.GeminiChatContent{
|
||||
Parts: []dto.GeminiPart{
|
||||
{
|
||||
Text: inputs[0],
|
||||
},
|
||||
@@ -171,6 +193,7 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
|
||||
if info.RelayMode == constant.RelayModeGemini {
|
||||
if info.IsStream {
|
||||
info.DisablePing = true
|
||||
return GeminiTextGenerationStreamHandler(c, info, resp)
|
||||
} else {
|
||||
return GeminiTextGenerationHandler(c, info, resp)
|
||||
@@ -208,60 +231,6 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
|
||||
return nil, types.NewError(errors.New("not implemented"), types.ErrorCodeBadResponseBody)
|
||||
}
|
||||
|
||||
func GeminiImageHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
|
||||
responseBody, readErr := io.ReadAll(resp.Body)
|
||||
if readErr != nil {
|
||||
return nil, types.NewError(readErr, types.ErrorCodeBadResponseBody)
|
||||
}
|
||||
_ = resp.Body.Close()
|
||||
|
||||
var geminiResponse GeminiImageResponse
|
||||
if jsonErr := json.Unmarshal(responseBody, &geminiResponse); jsonErr != nil {
|
||||
return nil, types.NewError(jsonErr, types.ErrorCodeBadResponseBody)
|
||||
}
|
||||
|
||||
if len(geminiResponse.Predictions) == 0 {
|
||||
return nil, types.NewError(errors.New("no images generated"), types.ErrorCodeBadResponseBody)
|
||||
}
|
||||
|
||||
// convert to openai format response
|
||||
openAIResponse := dto.ImageResponse{
|
||||
Created: common.GetTimestamp(),
|
||||
Data: make([]dto.ImageData, 0, len(geminiResponse.Predictions)),
|
||||
}
|
||||
|
||||
for _, prediction := range geminiResponse.Predictions {
|
||||
if prediction.RaiFilteredReason != "" {
|
||||
continue // skip filtered image
|
||||
}
|
||||
openAIResponse.Data = append(openAIResponse.Data, dto.ImageData{
|
||||
B64Json: prediction.BytesBase64Encoded,
|
||||
})
|
||||
}
|
||||
|
||||
jsonResponse, jsonErr := json.Marshal(openAIResponse)
|
||||
if jsonErr != nil {
|
||||
return nil, types.NewError(jsonErr, types.ErrorCodeBadResponseBody)
|
||||
}
|
||||
|
||||
c.Writer.Header().Set("Content-Type", "application/json")
|
||||
c.Writer.WriteHeader(resp.StatusCode)
|
||||
_, _ = c.Writer.Write(jsonResponse)
|
||||
|
||||
// https://github.com/google-gemini/cookbook/blob/719a27d752aac33f39de18a8d3cb42a70874917e/quickstarts/Counting_Tokens.ipynb
|
||||
// each image has fixed 258 tokens
|
||||
const imageTokens = 258
|
||||
generatedImages := len(openAIResponse.Data)
|
||||
|
||||
usage := &dto.Usage{
|
||||
PromptTokens: imageTokens * generatedImages, // each generated image has fixed 258 tokens
|
||||
CompletionTokens: 0, // image generation does not calculate completion tokens
|
||||
TotalTokens: imageTokens * generatedImages,
|
||||
}
|
||||
|
||||
return usage, nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetModelList() []string {
|
||||
return ModelList
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package gemini
|
||||
|
||||
import (
|
||||
"github.com/pkg/errors"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
@@ -20,7 +21,7 @@ func GeminiTextGenerationHandler(c *gin.Context, info *relaycommon.RelayInfo, re
|
||||
// 读取响应体
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
|
||||
return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
if common.DebugEnabled {
|
||||
@@ -28,10 +29,10 @@ func GeminiTextGenerationHandler(c *gin.Context, info *relaycommon.RelayInfo, re
|
||||
}
|
||||
|
||||
// 解析为 Gemini 原生响应格式
|
||||
var geminiResponse GeminiChatResponse
|
||||
var geminiResponse dto.GeminiChatResponse
|
||||
err = common.Unmarshal(responseBody, &geminiResponse)
|
||||
if err != nil {
|
||||
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
|
||||
return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
// 计算使用量(基于 UsageMetadata)
|
||||
@@ -54,7 +55,7 @@ func GeminiTextGenerationHandler(c *gin.Context, info *relaycommon.RelayInfo, re
|
||||
// 直接返回 Gemini 原生格式的 JSON 响应
|
||||
jsonResponse, err := common.Marshal(geminiResponse)
|
||||
if err != nil {
|
||||
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
|
||||
return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
common.IOCopyBytesGracefully(c, resp, jsonResponse)
|
||||
@@ -71,7 +72,7 @@ func GeminiTextGenerationStreamHandler(c *gin.Context, info *relaycommon.RelayIn
|
||||
responseText := strings.Builder{}
|
||||
|
||||
helper.StreamScannerHandler(c, resp, info, func(data string) bool {
|
||||
var geminiResponse GeminiChatResponse
|
||||
var geminiResponse dto.GeminiChatResponse
|
||||
err := common.UnmarshalJsonStr(data, &geminiResponse)
|
||||
if err != nil {
|
||||
common.LogError(c, "error unmarshalling stream response: "+err.Error())
|
||||
@@ -110,10 +111,14 @@ func GeminiTextGenerationStreamHandler(c *gin.Context, info *relaycommon.RelayIn
|
||||
if err != nil {
|
||||
common.LogError(c, err.Error())
|
||||
}
|
||||
|
||||
info.SendResponseCount++
|
||||
return true
|
||||
})
|
||||
|
||||
if info.SendResponseCount == 0 {
|
||||
return nil, types.NewOpenAIError(errors.New("no response received from Gemini API"), types.ErrorCodeEmptyResponse, http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
if imageCount != 0 {
|
||||
if usage.CompletionTokens == 0 {
|
||||
usage.CompletionTokens = imageCount * 258
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"one-api/common"
|
||||
"one-api/constant"
|
||||
"one-api/dto"
|
||||
"one-api/relay/channel/openai"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/relay/helper"
|
||||
"one-api/service"
|
||||
@@ -80,7 +81,7 @@ func clampThinkingBudget(modelName string, budget int) int {
|
||||
return budget
|
||||
}
|
||||
|
||||
func ThinkingAdaptor(geminiRequest *GeminiChatRequest, info *relaycommon.RelayInfo) {
|
||||
func ThinkingAdaptor(geminiRequest *dto.GeminiChatRequest, info *relaycommon.RelayInfo) {
|
||||
if model_setting.GetGeminiSettings().ThinkingAdapterEnabled {
|
||||
modelName := info.UpstreamModelName
|
||||
isNew25Pro := strings.HasPrefix(modelName, "gemini-2.5-pro") &&
|
||||
@@ -92,7 +93,7 @@ func ThinkingAdaptor(geminiRequest *GeminiChatRequest, info *relaycommon.RelayIn
|
||||
if len(parts) == 2 && parts[1] != "" {
|
||||
if budgetTokens, err := strconv.Atoi(parts[1]); err == nil {
|
||||
clampedBudget := clampThinkingBudget(modelName, budgetTokens)
|
||||
geminiRequest.GenerationConfig.ThinkingConfig = &GeminiThinkingConfig{
|
||||
geminiRequest.GenerationConfig.ThinkingConfig = &dto.GeminiThinkingConfig{
|
||||
ThinkingBudget: common.GetPointer(clampedBudget),
|
||||
IncludeThoughts: true,
|
||||
}
|
||||
@@ -112,11 +113,11 @@ func ThinkingAdaptor(geminiRequest *GeminiChatRequest, info *relaycommon.RelayIn
|
||||
}
|
||||
|
||||
if isUnsupported {
|
||||
geminiRequest.GenerationConfig.ThinkingConfig = &GeminiThinkingConfig{
|
||||
geminiRequest.GenerationConfig.ThinkingConfig = &dto.GeminiThinkingConfig{
|
||||
IncludeThoughts: true,
|
||||
}
|
||||
} else {
|
||||
geminiRequest.GenerationConfig.ThinkingConfig = &GeminiThinkingConfig{
|
||||
geminiRequest.GenerationConfig.ThinkingConfig = &dto.GeminiThinkingConfig{
|
||||
IncludeThoughts: true,
|
||||
}
|
||||
if geminiRequest.GenerationConfig.MaxOutputTokens > 0 {
|
||||
@@ -127,7 +128,7 @@ func ThinkingAdaptor(geminiRequest *GeminiChatRequest, info *relaycommon.RelayIn
|
||||
}
|
||||
} else if strings.HasSuffix(modelName, "-nothinking") {
|
||||
if !isNew25Pro {
|
||||
geminiRequest.GenerationConfig.ThinkingConfig = &GeminiThinkingConfig{
|
||||
geminiRequest.GenerationConfig.ThinkingConfig = &dto.GeminiThinkingConfig{
|
||||
ThinkingBudget: common.GetPointer(0),
|
||||
}
|
||||
}
|
||||
@@ -136,11 +137,11 @@ func ThinkingAdaptor(geminiRequest *GeminiChatRequest, info *relaycommon.RelayIn
|
||||
}
|
||||
|
||||
// Setting safety to the lowest possible values since Gemini is already powerless enough
|
||||
func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon.RelayInfo) (*GeminiChatRequest, error) {
|
||||
func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon.RelayInfo) (*dto.GeminiChatRequest, error) {
|
||||
|
||||
geminiRequest := GeminiChatRequest{
|
||||
Contents: make([]GeminiChatContent, 0, len(textRequest.Messages)),
|
||||
GenerationConfig: GeminiChatGenerationConfig{
|
||||
geminiRequest := dto.GeminiChatRequest{
|
||||
Contents: make([]dto.GeminiChatContent, 0, len(textRequest.Messages)),
|
||||
GenerationConfig: dto.GeminiChatGenerationConfig{
|
||||
Temperature: textRequest.Temperature,
|
||||
TopP: textRequest.TopP,
|
||||
MaxOutputTokens: textRequest.MaxTokens,
|
||||
@@ -157,9 +158,9 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon
|
||||
|
||||
ThinkingAdaptor(&geminiRequest, info)
|
||||
|
||||
safetySettings := make([]GeminiChatSafetySettings, 0, len(SafetySettingList))
|
||||
safetySettings := make([]dto.GeminiChatSafetySettings, 0, len(SafetySettingList))
|
||||
for _, category := range SafetySettingList {
|
||||
safetySettings = append(safetySettings, GeminiChatSafetySettings{
|
||||
safetySettings = append(safetySettings, dto.GeminiChatSafetySettings{
|
||||
Category: category,
|
||||
Threshold: model_setting.GetGeminiSafetySetting(category),
|
||||
})
|
||||
@@ -197,17 +198,17 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon
|
||||
functions = append(functions, tool.Function)
|
||||
}
|
||||
if codeExecution {
|
||||
geminiRequest.Tools = append(geminiRequest.Tools, GeminiChatTool{
|
||||
geminiRequest.Tools = append(geminiRequest.Tools, dto.GeminiChatTool{
|
||||
CodeExecution: make(map[string]string),
|
||||
})
|
||||
}
|
||||
if googleSearch {
|
||||
geminiRequest.Tools = append(geminiRequest.Tools, GeminiChatTool{
|
||||
geminiRequest.Tools = append(geminiRequest.Tools, dto.GeminiChatTool{
|
||||
GoogleSearch: make(map[string]string),
|
||||
})
|
||||
}
|
||||
if len(functions) > 0 {
|
||||
geminiRequest.Tools = append(geminiRequest.Tools, GeminiChatTool{
|
||||
geminiRequest.Tools = append(geminiRequest.Tools, dto.GeminiChatTool{
|
||||
FunctionDeclarations: functions,
|
||||
})
|
||||
}
|
||||
@@ -219,9 +220,13 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon
|
||||
if textRequest.ResponseFormat != nil && (textRequest.ResponseFormat.Type == "json_schema" || textRequest.ResponseFormat.Type == "json_object") {
|
||||
geminiRequest.GenerationConfig.ResponseMimeType = "application/json"
|
||||
|
||||
if textRequest.ResponseFormat.JsonSchema != nil && textRequest.ResponseFormat.JsonSchema.Schema != nil {
|
||||
cleanedSchema := removeAdditionalPropertiesWithDepth(textRequest.ResponseFormat.JsonSchema.Schema, 0)
|
||||
geminiRequest.GenerationConfig.ResponseSchema = cleanedSchema
|
||||
if len(textRequest.ResponseFormat.JsonSchema) > 0 {
|
||||
// 先将json.RawMessage解析
|
||||
var jsonSchema dto.FormatJsonSchema
|
||||
if err := common.Unmarshal(textRequest.ResponseFormat.JsonSchema, &jsonSchema); err == nil {
|
||||
cleanedSchema := removeAdditionalPropertiesWithDepth(jsonSchema.Schema, 0)
|
||||
geminiRequest.GenerationConfig.ResponseSchema = cleanedSchema
|
||||
}
|
||||
}
|
||||
}
|
||||
tool_call_ids := make(map[string]string)
|
||||
@@ -233,7 +238,7 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon
|
||||
continue
|
||||
} else if message.Role == "tool" || message.Role == "function" {
|
||||
if len(geminiRequest.Contents) == 0 || geminiRequest.Contents[len(geminiRequest.Contents)-1].Role == "model" {
|
||||
geminiRequest.Contents = append(geminiRequest.Contents, GeminiChatContent{
|
||||
geminiRequest.Contents = append(geminiRequest.Contents, dto.GeminiChatContent{
|
||||
Role: "user",
|
||||
})
|
||||
}
|
||||
@@ -260,18 +265,18 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon
|
||||
}
|
||||
}
|
||||
|
||||
functionResp := &FunctionResponse{
|
||||
functionResp := &dto.GeminiFunctionResponse{
|
||||
Name: name,
|
||||
Response: contentMap,
|
||||
}
|
||||
|
||||
*parts = append(*parts, GeminiPart{
|
||||
*parts = append(*parts, dto.GeminiPart{
|
||||
FunctionResponse: functionResp,
|
||||
})
|
||||
continue
|
||||
}
|
||||
var parts []GeminiPart
|
||||
content := GeminiChatContent{
|
||||
var parts []dto.GeminiPart
|
||||
content := dto.GeminiChatContent{
|
||||
Role: message.Role,
|
||||
}
|
||||
// isToolCall := false
|
||||
@@ -285,8 +290,8 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon
|
||||
return nil, fmt.Errorf("invalid arguments for function %s, args: %s", call.Function.Name, call.Function.Arguments)
|
||||
}
|
||||
}
|
||||
toolCall := GeminiPart{
|
||||
FunctionCall: &FunctionCall{
|
||||
toolCall := dto.GeminiPart{
|
||||
FunctionCall: &dto.FunctionCall{
|
||||
FunctionName: call.Function.Name,
|
||||
Arguments: args,
|
||||
},
|
||||
@@ -303,7 +308,7 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon
|
||||
if part.Text == "" {
|
||||
continue
|
||||
}
|
||||
parts = append(parts, GeminiPart{
|
||||
parts = append(parts, dto.GeminiPart{
|
||||
Text: part.Text,
|
||||
})
|
||||
} else if part.Type == dto.ContentTypeImageURL {
|
||||
@@ -326,8 +331,8 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon
|
||||
return nil, fmt.Errorf("mime type is not supported by Gemini: '%s', url: '%s', supported types are: %v", fileData.MimeType, url, getSupportedMimeTypesList())
|
||||
}
|
||||
|
||||
parts = append(parts, GeminiPart{
|
||||
InlineData: &GeminiInlineData{
|
||||
parts = append(parts, dto.GeminiPart{
|
||||
InlineData: &dto.GeminiInlineData{
|
||||
MimeType: fileData.MimeType, // 使用原始的 MimeType,因为大小写可能对API有意义
|
||||
Data: fileData.Base64Data,
|
||||
},
|
||||
@@ -337,8 +342,8 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("decode base64 image data failed: %s", err.Error())
|
||||
}
|
||||
parts = append(parts, GeminiPart{
|
||||
InlineData: &GeminiInlineData{
|
||||
parts = append(parts, dto.GeminiPart{
|
||||
InlineData: &dto.GeminiInlineData{
|
||||
MimeType: format,
|
||||
Data: base64String,
|
||||
},
|
||||
@@ -352,8 +357,8 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("decode base64 file data failed: %s", err.Error())
|
||||
}
|
||||
parts = append(parts, GeminiPart{
|
||||
InlineData: &GeminiInlineData{
|
||||
parts = append(parts, dto.GeminiPart{
|
||||
InlineData: &dto.GeminiInlineData{
|
||||
MimeType: format,
|
||||
Data: base64String,
|
||||
},
|
||||
@@ -366,8 +371,8 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("decode base64 audio data failed: %s", err.Error())
|
||||
}
|
||||
parts = append(parts, GeminiPart{
|
||||
InlineData: &GeminiInlineData{
|
||||
parts = append(parts, dto.GeminiPart{
|
||||
InlineData: &dto.GeminiInlineData{
|
||||
MimeType: "audio/" + part.GetInputAudio().Format,
|
||||
Data: base64String,
|
||||
},
|
||||
@@ -387,8 +392,8 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon
|
||||
}
|
||||
|
||||
if len(system_content) > 0 {
|
||||
geminiRequest.SystemInstructions = &GeminiChatContent{
|
||||
Parts: []GeminiPart{
|
||||
geminiRequest.SystemInstructions = &dto.GeminiChatContent{
|
||||
Parts: []dto.GeminiPart{
|
||||
{
|
||||
Text: strings.Join(system_content, "\n"),
|
||||
},
|
||||
@@ -631,7 +636,7 @@ func unescapeMapOrSlice(data interface{}) interface{} {
|
||||
return data
|
||||
}
|
||||
|
||||
func getResponseToolCall(item *GeminiPart) *dto.ToolCallResponse {
|
||||
func getResponseToolCall(item *dto.GeminiPart) *dto.ToolCallResponse {
|
||||
var argsBytes []byte
|
||||
var err error
|
||||
if result, ok := item.FunctionCall.Arguments.(map[string]interface{}); ok {
|
||||
@@ -653,7 +658,7 @@ func getResponseToolCall(item *GeminiPart) *dto.ToolCallResponse {
|
||||
}
|
||||
}
|
||||
|
||||
func responseGeminiChat2OpenAI(c *gin.Context, response *GeminiChatResponse) *dto.OpenAITextResponse {
|
||||
func responseGeminiChat2OpenAI(c *gin.Context, response *dto.GeminiChatResponse) *dto.OpenAITextResponse {
|
||||
fullTextResponse := dto.OpenAITextResponse{
|
||||
Id: helper.GetResponseID(c),
|
||||
Object: "chat.completion",
|
||||
@@ -720,10 +725,9 @@ func responseGeminiChat2OpenAI(c *gin.Context, response *GeminiChatResponse) *dt
|
||||
return &fullTextResponse
|
||||
}
|
||||
|
||||
func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) (*dto.ChatCompletionsStreamResponse, bool, bool) {
|
||||
func streamResponseGeminiChat2OpenAI(geminiResponse *dto.GeminiChatResponse) (*dto.ChatCompletionsStreamResponse, bool) {
|
||||
choices := make([]dto.ChatCompletionsStreamResponseChoice, 0, len(geminiResponse.Candidates))
|
||||
isStop := false
|
||||
hasImage := false
|
||||
for _, candidate := range geminiResponse.Candidates {
|
||||
if candidate.FinishReason != nil && *candidate.FinishReason == "STOP" {
|
||||
isStop = true
|
||||
@@ -732,7 +736,7 @@ func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) (*dto.C
|
||||
choice := dto.ChatCompletionsStreamResponseChoice{
|
||||
Index: int(candidate.Index),
|
||||
Delta: dto.ChatCompletionsStreamResponseChoiceDelta{
|
||||
Role: "assistant",
|
||||
//Role: "assistant",
|
||||
},
|
||||
}
|
||||
var texts []string
|
||||
@@ -754,7 +758,6 @@ func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) (*dto.C
|
||||
if strings.HasPrefix(part.InlineData.MimeType, "image") {
|
||||
imgText := ""
|
||||
texts = append(texts, imgText)
|
||||
hasImage = true
|
||||
}
|
||||
} else if part.FunctionCall != nil {
|
||||
isTools = true
|
||||
@@ -791,28 +794,59 @@ func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) (*dto.C
|
||||
var response dto.ChatCompletionsStreamResponse
|
||||
response.Object = "chat.completion.chunk"
|
||||
response.Choices = choices
|
||||
return &response, isStop, hasImage
|
||||
return &response, isStop
|
||||
}
|
||||
|
||||
func handleStream(c *gin.Context, info *relaycommon.RelayInfo, resp *dto.ChatCompletionsStreamResponse) error {
|
||||
streamData, err := common.Marshal(resp)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal stream response: %w", err)
|
||||
}
|
||||
err = openai.HandleStreamFormat(c, info, string(streamData), info.ChannelSetting.ForceFormat, info.ChannelSetting.ThinkingToContent)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to handle stream format: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func handleFinalStream(c *gin.Context, info *relaycommon.RelayInfo, resp *dto.ChatCompletionsStreamResponse) error {
|
||||
streamData, err := common.Marshal(resp)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal stream response: %w", err)
|
||||
}
|
||||
openai.HandleFinalResponse(c, info, string(streamData), resp.Id, resp.Created, resp.Model, resp.GetSystemFingerprint(), resp.Usage, info.ShouldIncludeUsage)
|
||||
return nil
|
||||
}
|
||||
|
||||
func GeminiChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
|
||||
// responseText := ""
|
||||
id := helper.GetResponseID(c)
|
||||
createAt := common.GetTimestamp()
|
||||
responseText := strings.Builder{}
|
||||
var usage = &dto.Usage{}
|
||||
var imageCount int
|
||||
|
||||
helper.StreamScannerHandler(c, resp, info, func(data string) bool {
|
||||
var geminiResponse GeminiChatResponse
|
||||
var geminiResponse dto.GeminiChatResponse
|
||||
err := common.UnmarshalJsonStr(data, &geminiResponse)
|
||||
if err != nil {
|
||||
common.LogError(c, "error unmarshalling stream response: "+err.Error())
|
||||
return false
|
||||
}
|
||||
|
||||
response, isStop, hasImage := streamResponseGeminiChat2OpenAI(&geminiResponse)
|
||||
if hasImage {
|
||||
imageCount++
|
||||
for _, candidate := range geminiResponse.Candidates {
|
||||
for _, part := range candidate.Content.Parts {
|
||||
if part.InlineData != nil && part.InlineData.MimeType != "" {
|
||||
imageCount++
|
||||
}
|
||||
if part.Text != "" {
|
||||
responseText.WriteString(part.Text)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
response, isStop := streamResponseGeminiChat2OpenAI(&geminiResponse)
|
||||
|
||||
response.Id = id
|
||||
response.Created = createAt
|
||||
response.Model = info.UpstreamModelName
|
||||
@@ -829,18 +863,30 @@ func GeminiChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *
|
||||
}
|
||||
}
|
||||
}
|
||||
err = helper.ObjectData(c, response)
|
||||
|
||||
if info.SendResponseCount == 0 {
|
||||
// send first response
|
||||
err = handleStream(c, info, helper.GenerateStartEmptyResponse(id, createAt, info.UpstreamModelName, nil))
|
||||
if err != nil {
|
||||
common.LogError(c, err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
err = handleStream(c, info, response)
|
||||
if err != nil {
|
||||
common.LogError(c, err.Error())
|
||||
}
|
||||
if isStop {
|
||||
response := helper.GenerateStopResponse(id, createAt, info.UpstreamModelName, constant.FinishReasonStop)
|
||||
helper.ObjectData(c, response)
|
||||
_ = handleStream(c, info, helper.GenerateStopResponse(id, createAt, info.UpstreamModelName, constant.FinishReasonStop))
|
||||
}
|
||||
return true
|
||||
})
|
||||
|
||||
var response *dto.ChatCompletionsStreamResponse
|
||||
if info.SendResponseCount == 0 {
|
||||
// 空补全,报错不计费
|
||||
// empty response, throw an error
|
||||
return nil, types.NewOpenAIError(errors.New("no response received from Gemini API"), types.ErrorCodeEmptyResponse, http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
if imageCount != 0 {
|
||||
if usage.CompletionTokens == 0 {
|
||||
@@ -851,14 +897,24 @@ func GeminiChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *
|
||||
usage.PromptTokensDetails.TextTokens = usage.PromptTokens
|
||||
usage.CompletionTokens = usage.TotalTokens - usage.PromptTokens
|
||||
|
||||
if info.ShouldIncludeUsage {
|
||||
response = helper.GenerateFinalUsageResponse(id, createAt, info.UpstreamModelName, *usage)
|
||||
err := helper.ObjectData(c, response)
|
||||
if err != nil {
|
||||
common.SysError("send final response failed: " + err.Error())
|
||||
if usage.CompletionTokens == 0 {
|
||||
str := responseText.String()
|
||||
if len(str) > 0 {
|
||||
usage = service.ResponseText2Usage(responseText.String(), info.UpstreamModelName, info.PromptTokens)
|
||||
} else {
|
||||
// 空补全,不需要使用量
|
||||
usage = &dto.Usage{}
|
||||
}
|
||||
}
|
||||
helper.Done(c)
|
||||
|
||||
response := helper.GenerateFinalUsageResponse(id, createAt, info.UpstreamModelName, *usage)
|
||||
err := handleFinalStream(c, info, response)
|
||||
if err != nil {
|
||||
common.SysError("send final response failed: " + err.Error())
|
||||
}
|
||||
//if info.RelayFormat == relaycommon.RelayFormatOpenAI {
|
||||
// helper.Done(c)
|
||||
//}
|
||||
//resp.Body.Close()
|
||||
return usage, nil
|
||||
}
|
||||
@@ -866,19 +922,19 @@ func GeminiChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *
|
||||
func GeminiChatHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
|
||||
return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
|
||||
}
|
||||
common.CloseResponseBodyGracefully(resp)
|
||||
if common.DebugEnabled {
|
||||
println(string(responseBody))
|
||||
}
|
||||
var geminiResponse GeminiChatResponse
|
||||
var geminiResponse dto.GeminiChatResponse
|
||||
err = common.Unmarshal(responseBody, &geminiResponse)
|
||||
if err != nil {
|
||||
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
|
||||
return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
|
||||
}
|
||||
if len(geminiResponse.Candidates) == 0 {
|
||||
return nil, types.NewError(errors.New("no candidates returned"), types.ErrorCodeBadResponseBody)
|
||||
return nil, types.NewOpenAIError(errors.New("no candidates returned"), types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
|
||||
}
|
||||
fullTextResponse := responseGeminiChat2OpenAI(c, &geminiResponse)
|
||||
fullTextResponse.Model = info.UpstreamModelName
|
||||
@@ -915,12 +971,12 @@ func GeminiEmbeddingHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *h
|
||||
|
||||
responseBody, readErr := io.ReadAll(resp.Body)
|
||||
if readErr != nil {
|
||||
return nil, types.NewError(readErr, types.ErrorCodeBadResponseBody)
|
||||
return nil, types.NewOpenAIError(readErr, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
var geminiResponse GeminiEmbeddingResponse
|
||||
var geminiResponse dto.GeminiEmbeddingResponse
|
||||
if jsonErr := common.Unmarshal(responseBody, &geminiResponse); jsonErr != nil {
|
||||
return nil, types.NewError(jsonErr, types.ErrorCodeBadResponseBody)
|
||||
return nil, types.NewOpenAIError(jsonErr, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
// convert to openai format response
|
||||
@@ -950,9 +1006,63 @@ func GeminiEmbeddingHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *h
|
||||
|
||||
jsonResponse, jsonErr := common.Marshal(openAIResponse)
|
||||
if jsonErr != nil {
|
||||
return nil, types.NewError(jsonErr, types.ErrorCodeBadResponseBody)
|
||||
return nil, types.NewOpenAIError(jsonErr, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
common.IOCopyBytesGracefully(c, resp, jsonResponse)
|
||||
return usage, nil
|
||||
}
|
||||
|
||||
func GeminiImageHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
|
||||
responseBody, readErr := io.ReadAll(resp.Body)
|
||||
if readErr != nil {
|
||||
return nil, types.NewOpenAIError(readErr, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
|
||||
}
|
||||
_ = resp.Body.Close()
|
||||
|
||||
var geminiResponse dto.GeminiImageResponse
|
||||
if jsonErr := common.Unmarshal(responseBody, &geminiResponse); jsonErr != nil {
|
||||
return nil, types.NewOpenAIError(jsonErr, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
if len(geminiResponse.Predictions) == 0 {
|
||||
return nil, types.NewOpenAIError(errors.New("no images generated"), types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
// convert to openai format response
|
||||
openAIResponse := dto.ImageResponse{
|
||||
Created: common.GetTimestamp(),
|
||||
Data: make([]dto.ImageData, 0, len(geminiResponse.Predictions)),
|
||||
}
|
||||
|
||||
for _, prediction := range geminiResponse.Predictions {
|
||||
if prediction.RaiFilteredReason != "" {
|
||||
continue // skip filtered image
|
||||
}
|
||||
openAIResponse.Data = append(openAIResponse.Data, dto.ImageData{
|
||||
B64Json: prediction.BytesBase64Encoded,
|
||||
})
|
||||
}
|
||||
|
||||
jsonResponse, jsonErr := json.Marshal(openAIResponse)
|
||||
if jsonErr != nil {
|
||||
return nil, types.NewError(jsonErr, types.ErrorCodeBadResponseBody)
|
||||
}
|
||||
|
||||
c.Writer.Header().Set("Content-Type", "application/json")
|
||||
c.Writer.WriteHeader(resp.StatusCode)
|
||||
_, _ = c.Writer.Write(jsonResponse)
|
||||
|
||||
// https://github.com/google-gemini/cookbook/blob/719a27d752aac33f39de18a8d3cb42a70874917e/quickstarts/Counting_Tokens.ipynb
|
||||
// each image has fixed 258 tokens
|
||||
const imageTokens = 258
|
||||
generatedImages := len(openAIResponse.Data)
|
||||
|
||||
usage := &dto.Usage{
|
||||
PromptTokens: imageTokens * generatedImages, // each generated image has fixed 258 tokens
|
||||
CompletionTokens: 0, // image generation does not calculate completion tokens
|
||||
TotalTokens: imageTokens * generatedImages,
|
||||
}
|
||||
|
||||
return usage, nil
|
||||
}
|
||||
|
||||
@@ -4,7 +4,6 @@ import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/dto"
|
||||
@@ -13,11 +12,18 @@ import (
|
||||
relaycommon "one-api/relay/common"
|
||||
relayconstant "one-api/relay/constant"
|
||||
"one-api/types"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type Adaptor struct {
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
|
||||
//TODO implement me
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
@@ -52,13 +52,13 @@ func jimengImageHandler(c *gin.Context, resp *http.Response, info *relaycommon.R
|
||||
var jimengResponse ImageResponse
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, types.NewError(err, types.ErrorCodeReadResponseBodyFailed)
|
||||
return nil, types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError)
|
||||
}
|
||||
common.CloseResponseBodyGracefully(resp)
|
||||
|
||||
err = json.Unmarshal(responseBody, &jimengResponse)
|
||||
if err != nil {
|
||||
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
|
||||
return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
// Check if the response indicates an error
|
||||
|
||||
@@ -19,6 +19,11 @@ import (
|
||||
type Adaptor struct {
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
|
||||
//TODO implement me
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
|
||||
//TODO implement me
|
||||
panic("implement me")
|
||||
|
||||
@@ -16,6 +16,11 @@ import (
|
||||
type Adaptor struct {
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
|
||||
//TODO implement me
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
|
||||
//TODO implement me
|
||||
panic("implement me")
|
||||
|
||||
@@ -18,6 +18,11 @@ import (
|
||||
type Adaptor struct {
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
|
||||
//TODO implement me
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
|
||||
//TODO implement me
|
||||
panic("implement me")
|
||||
|
||||
@@ -17,10 +17,21 @@ import (
|
||||
type Adaptor struct {
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
|
||||
func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
|
||||
//TODO implement me
|
||||
panic("implement me")
|
||||
return nil, nil
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.ClaudeRequest) (any, error) {
|
||||
openaiAdaptor := openai.Adaptor{}
|
||||
openaiRequest, err := openaiAdaptor.ConvertClaudeRequest(c, info, request)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
openaiRequest.(*dto.GeneralOpenAIRequest).StreamOptions = &dto.StreamOptions{
|
||||
IncludeUsage: true,
|
||||
}
|
||||
return requestOpenAI2Ollama(openaiRequest.(*dto.GeneralOpenAIRequest))
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
|
||||
@@ -37,6 +48,9 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
if info.RelayFormat == relaycommon.RelayFormatClaude {
|
||||
return info.BaseUrl + "/v1/chat/completions", nil
|
||||
}
|
||||
switch info.RelayMode {
|
||||
case relayconstant.RelayModeEmbeddings:
|
||||
return info.BaseUrl + "/api/embed", nil
|
||||
@@ -55,7 +69,7 @@ func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayIn
|
||||
if request == nil {
|
||||
return nil, errors.New("request is nil")
|
||||
}
|
||||
return requestOpenAI2Ollama(*request)
|
||||
return requestOpenAI2Ollama(request)
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
|
||||
@@ -76,11 +90,12 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
|
||||
if info.IsStream {
|
||||
usage, err = openai.OaiStreamHandler(c, info, resp)
|
||||
} else {
|
||||
if info.RelayMode == relayconstant.RelayModeEmbeddings {
|
||||
usage, err = ollamaEmbeddingHandler(c, info, resp)
|
||||
switch info.RelayMode {
|
||||
case relayconstant.RelayModeEmbeddings:
|
||||
usage, err = ollamaEmbeddingHandler(c, info, resp)
|
||||
default:
|
||||
if info.IsStream {
|
||||
usage, err = openai.OaiStreamHandler(c, info, resp)
|
||||
} else {
|
||||
usage, err = openai.OpenaiHandler(c, info, resp)
|
||||
}
|
||||
|
||||
@@ -14,7 +14,7 @@ import (
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func requestOpenAI2Ollama(request dto.GeneralOpenAIRequest) (*OllamaRequest, error) {
|
||||
func requestOpenAI2Ollama(request *dto.GeneralOpenAIRequest) (*OllamaRequest, error) {
|
||||
messages := make([]dto.Message, 0, len(request.Messages))
|
||||
for _, message := range request.Messages {
|
||||
if !message.IsStringContent() {
|
||||
@@ -92,15 +92,15 @@ func ollamaEmbeddingHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *h
|
||||
var ollamaEmbeddingResponse OllamaEmbeddingResponse
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
|
||||
return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
|
||||
}
|
||||
common.CloseResponseBodyGracefully(resp)
|
||||
err = common.Unmarshal(responseBody, &ollamaEmbeddingResponse)
|
||||
if err != nil {
|
||||
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
|
||||
return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
|
||||
}
|
||||
if ollamaEmbeddingResponse.Error != "" {
|
||||
return nil, types.NewError(fmt.Errorf("ollama error: %s", ollamaEmbeddingResponse.Error), types.ErrorCodeBadResponseBody)
|
||||
return nil, types.NewOpenAIError(fmt.Errorf("ollama error: %s", ollamaEmbeddingResponse.Error), types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
|
||||
}
|
||||
flattenedEmbeddings := flattenEmbeddings(ollamaEmbeddingResponse.Embedding)
|
||||
data := make([]dto.OpenAIEmbeddingResponseItem, 0, 1)
|
||||
@@ -121,7 +121,7 @@ func ollamaEmbeddingHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *h
|
||||
}
|
||||
doResponseBody, err := common.Marshal(embeddingResponse)
|
||||
if err != nil {
|
||||
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
|
||||
return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
|
||||
}
|
||||
common.IOCopyBytesGracefully(c, resp, doResponseBody)
|
||||
return usage, nil
|
||||
|
||||
@@ -34,6 +34,15 @@ type Adaptor struct {
|
||||
ResponseFormat string
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertGeminiRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeminiChatRequest) (any, error) {
|
||||
// 使用 service.GeminiToOpenAIRequest 转换请求格式
|
||||
openaiRequest, err := service.GeminiToOpenAIRequest(request, info)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return a.ConvertOpenAIRequest(c, info, openaiRequest)
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.ClaudeRequest) (any, error) {
|
||||
//if !strings.Contains(request.Model, "claude") {
|
||||
// return nil, fmt.Errorf("you are using openai channel type with path /v1/messages, only claude model supported convert, but got %s", request.Model)
|
||||
@@ -64,7 +73,7 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
if info.RelayFormat == relaycommon.RelayFormatClaude {
|
||||
if info.RelayFormat == relaycommon.RelayFormatClaude || info.RelayFormat == relaycommon.RelayFormatGemini {
|
||||
return fmt.Sprintf("%s/v1/chat/completions", info.BaseUrl), nil
|
||||
}
|
||||
if info.RelayMode == relayconstant.RelayModeRealtime {
|
||||
|
||||
@@ -2,6 +2,8 @@ package openai
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/dto"
|
||||
relaycommon "one-api/relay/common"
|
||||
@@ -14,20 +16,23 @@ import (
|
||||
)
|
||||
|
||||
// 辅助函数
|
||||
func handleStreamFormat(c *gin.Context, info *relaycommon.RelayInfo, data string, forceFormat bool, thinkToContent bool) error {
|
||||
func HandleStreamFormat(c *gin.Context, info *relaycommon.RelayInfo, data string, forceFormat bool, thinkToContent bool) error {
|
||||
info.SendResponseCount++
|
||||
|
||||
switch info.RelayFormat {
|
||||
case relaycommon.RelayFormatOpenAI:
|
||||
return sendStreamData(c, info, data, forceFormat, thinkToContent)
|
||||
case relaycommon.RelayFormatClaude:
|
||||
return handleClaudeFormat(c, data, info)
|
||||
case relaycommon.RelayFormatGemini:
|
||||
return handleGeminiFormat(c, data, info)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func handleClaudeFormat(c *gin.Context, data string, info *relaycommon.RelayInfo) error {
|
||||
var streamResponse dto.ChatCompletionsStreamResponse
|
||||
if err := json.Unmarshal(common.StringToByteSlice(data), &streamResponse); err != nil {
|
||||
if err := common.Unmarshal(common.StringToByteSlice(data), &streamResponse); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -41,6 +46,36 @@ func handleClaudeFormat(c *gin.Context, data string, info *relaycommon.RelayInfo
|
||||
return nil
|
||||
}
|
||||
|
||||
func handleGeminiFormat(c *gin.Context, data string, info *relaycommon.RelayInfo) error {
|
||||
var streamResponse dto.ChatCompletionsStreamResponse
|
||||
if err := common.Unmarshal(common.StringToByteSlice(data), &streamResponse); err != nil {
|
||||
common.LogError(c, "failed to unmarshal stream response: "+err.Error())
|
||||
return err
|
||||
}
|
||||
|
||||
geminiResponse := service.StreamResponseOpenAI2Gemini(&streamResponse, info)
|
||||
|
||||
// 如果返回 nil,表示没有实际内容,跳过发送
|
||||
if geminiResponse == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
geminiResponseStr, err := common.Marshal(geminiResponse)
|
||||
if err != nil {
|
||||
common.LogError(c, "failed to marshal gemini response: "+err.Error())
|
||||
return err
|
||||
}
|
||||
|
||||
// send gemini format response
|
||||
c.Render(-1, common.CustomEvent{Data: "data: " + string(geminiResponseStr)})
|
||||
if flusher, ok := c.Writer.(http.Flusher); ok {
|
||||
flusher.Flush()
|
||||
} else {
|
||||
return errors.New("streaming error: flusher not found")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func ProcessStreamResponse(streamResponse dto.ChatCompletionsStreamResponse, responseTextBuilder *strings.Builder, toolCount *int) error {
|
||||
for _, choice := range streamResponse.Choices {
|
||||
responseTextBuilder.WriteString(choice.Delta.GetContentString())
|
||||
@@ -158,7 +193,7 @@ func handleLastResponse(lastStreamData string, responseId *string, createAt *int
|
||||
return nil
|
||||
}
|
||||
|
||||
func handleFinalResponse(c *gin.Context, info *relaycommon.RelayInfo, lastStreamData string,
|
||||
func HandleFinalResponse(c *gin.Context, info *relaycommon.RelayInfo, lastStreamData string,
|
||||
responseId string, createAt int64, model string, systemFingerprint string,
|
||||
usage *dto.Usage, containStreamUsage bool) {
|
||||
|
||||
@@ -174,7 +209,7 @@ func handleFinalResponse(c *gin.Context, info *relaycommon.RelayInfo, lastStream
|
||||
case relaycommon.RelayFormatClaude:
|
||||
info.ClaudeConvertInfo.Done = true
|
||||
var streamResponse dto.ChatCompletionsStreamResponse
|
||||
if err := json.Unmarshal(common.StringToByteSlice(lastStreamData), &streamResponse); err != nil {
|
||||
if err := common.Unmarshal(common.StringToByteSlice(lastStreamData), &streamResponse); err != nil {
|
||||
common.SysError("error unmarshalling stream response: " + err.Error())
|
||||
return
|
||||
}
|
||||
@@ -183,7 +218,38 @@ func handleFinalResponse(c *gin.Context, info *relaycommon.RelayInfo, lastStream
|
||||
|
||||
claudeResponses := service.StreamResponseOpenAI2Claude(&streamResponse, info)
|
||||
for _, resp := range claudeResponses {
|
||||
helper.ClaudeData(c, *resp)
|
||||
_ = helper.ClaudeData(c, *resp)
|
||||
}
|
||||
|
||||
case relaycommon.RelayFormatGemini:
|
||||
var streamResponse dto.ChatCompletionsStreamResponse
|
||||
if err := common.Unmarshal(common.StringToByteSlice(lastStreamData), &streamResponse); err != nil {
|
||||
common.SysError("error unmarshalling stream response: " + err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// 这里处理的是 openai 最后一个流响应,其 delta 为空,有 finish_reason 字段
|
||||
// 因此相比较于 google 官方的流响应,由 openai 转换而来会多一个 parts 为空,finishReason 为 STOP 的响应
|
||||
// 而包含最后一段文本输出的响应(倒数第二个)的 finishReason 为 null
|
||||
// 暂不知是否有程序会不兼容。
|
||||
|
||||
geminiResponse := service.StreamResponseOpenAI2Gemini(&streamResponse, info)
|
||||
|
||||
// openai 流响应开头的空数据
|
||||
if geminiResponse == nil {
|
||||
return
|
||||
}
|
||||
|
||||
geminiResponseStr, err := common.Marshal(geminiResponse)
|
||||
if err != nil {
|
||||
common.SysError("error marshalling gemini response: " + err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// 发送最终的 Gemini 响应
|
||||
c.Render(-1, common.CustomEvent{Data: "data: " + string(geminiResponseStr)})
|
||||
if flusher, ok := c.Writer.(http.Flusher); ok {
|
||||
flusher.Flush()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -109,7 +109,7 @@ func sendStreamData(c *gin.Context, info *relaycommon.RelayInfo, data string, fo
|
||||
func OaiStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
|
||||
if resp == nil || resp.Body == nil {
|
||||
common.LogError(c, "invalid response or response body")
|
||||
return nil, types.NewError(fmt.Errorf("invalid response"), types.ErrorCodeBadResponse)
|
||||
return nil, types.NewOpenAIError(fmt.Errorf("invalid response"), types.ErrorCodeBadResponse, http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
defer common.CloseResponseBodyGracefully(resp)
|
||||
@@ -123,30 +123,19 @@ func OaiStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Re
|
||||
var toolCount int
|
||||
var usage = &dto.Usage{}
|
||||
var streamItems []string // store stream items
|
||||
var forceFormat bool
|
||||
var thinkToContent bool
|
||||
|
||||
if info.ChannelSetting.ForceFormat {
|
||||
forceFormat = true
|
||||
}
|
||||
|
||||
if info.ChannelSetting.ThinkingToContent {
|
||||
thinkToContent = true
|
||||
}
|
||||
|
||||
var (
|
||||
lastStreamData string
|
||||
)
|
||||
var lastStreamData string
|
||||
|
||||
helper.StreamScannerHandler(c, resp, info, func(data string) bool {
|
||||
if lastStreamData != "" {
|
||||
err := handleStreamFormat(c, info, lastStreamData, forceFormat, thinkToContent)
|
||||
err := HandleStreamFormat(c, info, lastStreamData, info.ChannelSetting.ForceFormat, info.ChannelSetting.ThinkingToContent)
|
||||
if err != nil {
|
||||
common.SysError("error handling stream format: " + err.Error())
|
||||
}
|
||||
}
|
||||
lastStreamData = data
|
||||
streamItems = append(streamItems, data)
|
||||
if len(data) > 0 {
|
||||
lastStreamData = data
|
||||
streamItems = append(streamItems, data)
|
||||
}
|
||||
return true
|
||||
})
|
||||
|
||||
@@ -154,16 +143,18 @@ func OaiStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Re
|
||||
shouldSendLastResp := true
|
||||
if err := handleLastResponse(lastStreamData, &responseId, &createAt, &systemFingerprint, &model, &usage,
|
||||
&containStreamUsage, info, &shouldSendLastResp); err != nil {
|
||||
common.SysError("error handling last response: " + err.Error())
|
||||
common.LogError(c, fmt.Sprintf("error handling last response: %s, lastStreamData: [%s]", err.Error(), lastStreamData))
|
||||
}
|
||||
|
||||
if shouldSendLastResp && info.RelayFormat == relaycommon.RelayFormatOpenAI {
|
||||
_ = sendStreamData(c, info, lastStreamData, forceFormat, thinkToContent)
|
||||
if info.RelayFormat == relaycommon.RelayFormatOpenAI {
|
||||
if shouldSendLastResp {
|
||||
_ = sendStreamData(c, info, lastStreamData, info.ChannelSetting.ForceFormat, info.ChannelSetting.ThinkingToContent)
|
||||
}
|
||||
}
|
||||
|
||||
// 处理token计算
|
||||
if err := processTokens(info.RelayMode, streamItems, &responseTextBuilder, &toolCount); err != nil {
|
||||
common.SysError("error processing tokens: " + err.Error())
|
||||
common.LogError(c, "error processing tokens: "+err.Error())
|
||||
}
|
||||
|
||||
if !containStreamUsage {
|
||||
@@ -176,8 +167,7 @@ func OaiStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Re
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
handleFinalResponse(c, info, lastStreamData, responseId, createAt, model, systemFingerprint, usage, containStreamUsage)
|
||||
HandleFinalResponse(c, info, lastStreamData, responseId, createAt, model, systemFingerprint, usage, containStreamUsage)
|
||||
|
||||
return usage, nil
|
||||
}
|
||||
@@ -188,14 +178,14 @@ func OpenaiHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Respo
|
||||
var simpleResponse dto.OpenAITextResponse
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, types.NewError(err, types.ErrorCodeReadResponseBodyFailed)
|
||||
return nil, types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError)
|
||||
}
|
||||
err = common.Unmarshal(responseBody, &simpleResponse)
|
||||
if err != nil {
|
||||
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
|
||||
return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
|
||||
}
|
||||
if simpleResponse.Error != nil && simpleResponse.Error.Type != "" {
|
||||
return nil, types.WithOpenAIError(*simpleResponse.Error, resp.StatusCode)
|
||||
if oaiError := simpleResponse.GetOpenAIError(); oaiError != nil && oaiError.Type != "" {
|
||||
return nil, types.WithOpenAIError(*oaiError, resp.StatusCode)
|
||||
}
|
||||
|
||||
forceFormat := false
|
||||
@@ -233,6 +223,13 @@ func OpenaiHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Respo
|
||||
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
|
||||
}
|
||||
responseBody = claudeRespStr
|
||||
case relaycommon.RelayFormatGemini:
|
||||
geminiResp := service.ResponseOpenAI2Gemini(&simpleResponse, info)
|
||||
geminiRespStr, err := common.Marshal(geminiResp)
|
||||
if err != nil {
|
||||
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
|
||||
}
|
||||
responseBody = geminiRespStr
|
||||
}
|
||||
|
||||
common.IOCopyBytesGracefully(c, resp, responseBody)
|
||||
@@ -273,7 +270,7 @@ func OpenaiSTTHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
|
||||
}
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeReadResponseBodyFailed), nil
|
||||
return types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError), nil
|
||||
}
|
||||
// 写入新的 response body
|
||||
common.IOCopyBytesGracefully(c, resp, responseBody)
|
||||
@@ -557,13 +554,13 @@ func OpenaiHandlerWithUsage(c *gin.Context, info *relaycommon.RelayInfo, resp *h
|
||||
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, types.NewError(err, types.ErrorCodeReadResponseBodyFailed)
|
||||
return nil, types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
var usageResp dto.SimpleResponse
|
||||
err = common.Unmarshal(responseBody, &usageResp)
|
||||
if err != nil {
|
||||
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
|
||||
return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
// 写入新的 response body
|
||||
|
||||
@@ -22,14 +22,14 @@ func OaiResponsesHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http
|
||||
var responsesResponse dto.OpenAIResponsesResponse
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, types.NewError(err, types.ErrorCodeReadResponseBodyFailed)
|
||||
return nil, types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError)
|
||||
}
|
||||
err = common.Unmarshal(responseBody, &responsesResponse)
|
||||
if err != nil {
|
||||
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
|
||||
return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
|
||||
}
|
||||
if responsesResponse.Error != nil {
|
||||
return nil, types.WithOpenAIError(*responsesResponse.Error, resp.StatusCode)
|
||||
if oaiError := responsesResponse.GetOpenAIError(); oaiError != nil && oaiError.Type != "" {
|
||||
return nil, types.WithOpenAIError(*oaiError, resp.StatusCode)
|
||||
}
|
||||
|
||||
// 写入新的 response body
|
||||
|
||||
@@ -17,6 +17,11 @@ import (
|
||||
type Adaptor struct {
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
|
||||
//TODO implement me
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
|
||||
//TODO implement me
|
||||
panic("implement me")
|
||||
|
||||
@@ -127,13 +127,13 @@ func palmStreamHandler(c *gin.Context, resp *http.Response) (*types.NewAPIError,
|
||||
func palmHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, types.NewError(err, types.ErrorCodeReadResponseBodyFailed)
|
||||
return nil, types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError)
|
||||
}
|
||||
common.CloseResponseBodyGracefully(resp)
|
||||
var palmResponse PaLMChatResponse
|
||||
err = json.Unmarshal(responseBody, &palmResponse)
|
||||
if err != nil {
|
||||
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
|
||||
return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
|
||||
}
|
||||
if palmResponse.Error.Code != 0 || len(palmResponse.Candidates) == 0 {
|
||||
return nil, types.WithOpenAIError(types.OpenAIError{
|
||||
|
||||
@@ -17,6 +17,11 @@ import (
|
||||
type Adaptor struct {
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
|
||||
//TODO implement me
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
|
||||
//TODO implement me
|
||||
panic("implement me")
|
||||
|
||||
@@ -18,20 +18,24 @@ import (
|
||||
type Adaptor struct {
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
|
||||
func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
|
||||
//TODO implement me
|
||||
panic("implement me")
|
||||
return nil, nil
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, req *dto.ClaudeRequest) (any, error) {
|
||||
adaptor := openai.Adaptor{}
|
||||
return adaptor.ConvertClaudeRequest(c, info, req)
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
|
||||
//TODO implement me
|
||||
return nil, errors.New("not implemented")
|
||||
return nil, errors.New("not supported")
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
|
||||
//TODO implement me
|
||||
return nil, errors.New("not implemented")
|
||||
adaptor := openai.Adaptor{}
|
||||
return adaptor.ConvertImageRequest(c, info, request)
|
||||
}
|
||||
|
||||
func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
|
||||
@@ -47,7 +51,7 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
} else if info.RelayMode == constant.RelayModeCompletions {
|
||||
return fmt.Sprintf("%s/v1/completions", info.BaseUrl), nil
|
||||
}
|
||||
return "", errors.New("invalid relay mode")
|
||||
return fmt.Sprintf("%s/v1/chat/completions", info.BaseUrl), nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
|
||||
@@ -81,16 +85,19 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
|
||||
switch info.RelayMode {
|
||||
case constant.RelayModeRerank:
|
||||
usage, err = siliconflowRerankHandler(c, info, resp)
|
||||
case constant.RelayModeEmbeddings:
|
||||
usage, err = openai.OpenaiHandler(c, info, resp)
|
||||
case constant.RelayModeCompletions:
|
||||
fallthrough
|
||||
case constant.RelayModeChatCompletions:
|
||||
fallthrough
|
||||
default:
|
||||
if info.IsStream {
|
||||
usage, err = openai.OaiStreamHandler(c, info, resp)
|
||||
} else {
|
||||
usage, err = openai.OpenaiHandler(c, info, resp)
|
||||
}
|
||||
case constant.RelayModeEmbeddings:
|
||||
usage, err = openai.OpenaiHandler(c, info, resp)
|
||||
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
@@ -15,13 +15,13 @@ import (
|
||||
func siliconflowRerankHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, types.NewError(err, types.ErrorCodeReadResponseBodyFailed)
|
||||
return nil, types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError)
|
||||
}
|
||||
common.CloseResponseBodyGracefully(resp)
|
||||
var siliconflowResp SFRerankResponse
|
||||
err = json.Unmarshal(responseBody, &siliconflowResp)
|
||||
if err != nil {
|
||||
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
|
||||
return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
|
||||
}
|
||||
usage := &dto.Usage{
|
||||
PromptTokens: siliconflowResp.Meta.Tokens.InputTokens,
|
||||
|
||||
285
relay/channel/task/vidu/adaptor.go
Normal file
285
relay/channel/task/vidu/adaptor.go
Normal file
@@ -0,0 +1,285 @@
|
||||
package vidu
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
"one-api/constant"
|
||||
"one-api/dto"
|
||||
"one-api/model"
|
||||
"one-api/relay/channel"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/service"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
// ============================
|
||||
// Request / Response structures
|
||||
// ============================
|
||||
|
||||
type SubmitReq struct {
|
||||
Prompt string `json:"prompt"`
|
||||
Model string `json:"model,omitempty"`
|
||||
Mode string `json:"mode,omitempty"`
|
||||
Image string `json:"image,omitempty"`
|
||||
Size string `json:"size,omitempty"`
|
||||
Duration int `json:"duration,omitempty"`
|
||||
Metadata map[string]interface{} `json:"metadata,omitempty"`
|
||||
}
|
||||
|
||||
type requestPayload struct {
|
||||
Model string `json:"model"`
|
||||
Images []string `json:"images"`
|
||||
Prompt string `json:"prompt,omitempty"`
|
||||
Duration int `json:"duration,omitempty"`
|
||||
Seed int `json:"seed,omitempty"`
|
||||
Resolution string `json:"resolution,omitempty"`
|
||||
MovementAmplitude string `json:"movement_amplitude,omitempty"`
|
||||
Bgm bool `json:"bgm,omitempty"`
|
||||
Payload string `json:"payload,omitempty"`
|
||||
CallbackUrl string `json:"callback_url,omitempty"`
|
||||
}
|
||||
|
||||
type responsePayload struct {
|
||||
TaskId string `json:"task_id"`
|
||||
State string `json:"state"`
|
||||
Model string `json:"model"`
|
||||
Images []string `json:"images"`
|
||||
Prompt string `json:"prompt"`
|
||||
Duration int `json:"duration"`
|
||||
Seed int `json:"seed"`
|
||||
Resolution string `json:"resolution"`
|
||||
Bgm bool `json:"bgm"`
|
||||
MovementAmplitude string `json:"movement_amplitude"`
|
||||
Payload string `json:"payload"`
|
||||
CreatedAt string `json:"created_at"`
|
||||
}
|
||||
|
||||
type taskResultResponse struct {
|
||||
State string `json:"state"`
|
||||
ErrCode string `json:"err_code"`
|
||||
Credits int `json:"credits"`
|
||||
Payload string `json:"payload"`
|
||||
Creations []creation `json:"creations"`
|
||||
}
|
||||
|
||||
type creation struct {
|
||||
ID string `json:"id"`
|
||||
URL string `json:"url"`
|
||||
CoverURL string `json:"cover_url"`
|
||||
}
|
||||
|
||||
// ============================
|
||||
// Adaptor implementation
|
||||
// ============================
|
||||
|
||||
type TaskAdaptor struct {
|
||||
ChannelType int
|
||||
baseURL string
|
||||
}
|
||||
|
||||
func (a *TaskAdaptor) Init(info *relaycommon.TaskRelayInfo) {
|
||||
a.ChannelType = info.ChannelType
|
||||
a.baseURL = info.BaseUrl
|
||||
}
|
||||
|
||||
func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.TaskRelayInfo) *dto.TaskError {
|
||||
var req SubmitReq
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
return service.TaskErrorWrapper(err, "invalid_request_body", http.StatusBadRequest)
|
||||
}
|
||||
|
||||
if req.Prompt == "" {
|
||||
return service.TaskErrorWrapperLocal(fmt.Errorf("prompt is required"), "missing_prompt", http.StatusBadRequest)
|
||||
}
|
||||
|
||||
if req.Image != "" {
|
||||
info.Action = constant.TaskActionGenerate
|
||||
} else {
|
||||
info.Action = constant.TaskActionTextGenerate
|
||||
}
|
||||
|
||||
c.Set("task_request", req)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, _ *relaycommon.TaskRelayInfo) (io.Reader, error) {
|
||||
v, exists := c.Get("task_request")
|
||||
if !exists {
|
||||
return nil, fmt.Errorf("request not found in context")
|
||||
}
|
||||
req := v.(SubmitReq)
|
||||
|
||||
body, err := a.convertToRequestPayload(&req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if len(body.Images) == 0 {
|
||||
c.Set("action", constant.TaskActionTextGenerate)
|
||||
}
|
||||
|
||||
data, err := json.Marshal(body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return bytes.NewReader(data), nil
|
||||
}
|
||||
|
||||
func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.TaskRelayInfo) (string, error) {
|
||||
var path string
|
||||
switch info.Action {
|
||||
case constant.TaskActionGenerate:
|
||||
path = "/img2video"
|
||||
default:
|
||||
path = "/text2video"
|
||||
}
|
||||
return fmt.Sprintf("%s/ent/v2%s", a.baseURL, path), nil
|
||||
}
|
||||
|
||||
func (a *TaskAdaptor) BuildRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.TaskRelayInfo) error {
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Accept", "application/json")
|
||||
req.Header.Set("Authorization", "Token "+info.ApiKey)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *TaskAdaptor) DoRequest(c *gin.Context, info *relaycommon.TaskRelayInfo, requestBody io.Reader) (*http.Response, error) {
|
||||
if action := c.GetString("action"); action != "" {
|
||||
info.Action = action
|
||||
}
|
||||
return channel.DoTaskApiRequest(a, c, info, requestBody)
|
||||
}
|
||||
|
||||
func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, _ *relaycommon.TaskRelayInfo) (taskID string, taskData []byte, taskErr *dto.TaskError) {
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
taskErr = service.TaskErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
var vResp responsePayload
|
||||
err = json.Unmarshal(responseBody, &vResp)
|
||||
if err != nil {
|
||||
taskErr = service.TaskErrorWrapper(errors.Wrap(err, fmt.Sprintf("%s", responseBody)), "unmarshal_response_failed", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
if vResp.State == "failed" {
|
||||
taskErr = service.TaskErrorWrapperLocal(fmt.Errorf("task failed"), "task_failed", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, vResp)
|
||||
return vResp.TaskId, responseBody, nil
|
||||
}
|
||||
|
||||
func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http.Response, error) {
|
||||
taskID, ok := body["task_id"].(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid task_id")
|
||||
}
|
||||
|
||||
url := fmt.Sprintf("%s/ent/v2/tasks/%s/creations", baseUrl, taskID)
|
||||
|
||||
req, err := http.NewRequest(http.MethodGet, url, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
req.Header.Set("Accept", "application/json")
|
||||
req.Header.Set("Authorization", "Token "+key)
|
||||
|
||||
return service.GetHttpClient().Do(req)
|
||||
}
|
||||
|
||||
func (a *TaskAdaptor) GetModelList() []string {
|
||||
return []string{"viduq1", "vidu2.0", "vidu1.5"}
|
||||
}
|
||||
|
||||
func (a *TaskAdaptor) GetChannelName() string {
|
||||
return "vidu"
|
||||
}
|
||||
|
||||
// ============================
|
||||
// helpers
|
||||
// ============================
|
||||
|
||||
func (a *TaskAdaptor) convertToRequestPayload(req *SubmitReq) (*requestPayload, error) {
|
||||
var images []string
|
||||
if req.Image != "" {
|
||||
images = []string{req.Image}
|
||||
}
|
||||
|
||||
r := requestPayload{
|
||||
Model: defaultString(req.Model, "viduq1"),
|
||||
Images: images,
|
||||
Prompt: req.Prompt,
|
||||
Duration: defaultInt(req.Duration, 5),
|
||||
Resolution: defaultString(req.Size, "1080p"),
|
||||
MovementAmplitude: "auto",
|
||||
Bgm: false,
|
||||
}
|
||||
metadata := req.Metadata
|
||||
medaBytes, err := json.Marshal(metadata)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "metadata marshal metadata failed")
|
||||
}
|
||||
err = json.Unmarshal(medaBytes, &r)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "unmarshal metadata failed")
|
||||
}
|
||||
return &r, nil
|
||||
}
|
||||
|
||||
func defaultString(value, defaultValue string) string {
|
||||
if value == "" {
|
||||
return defaultValue
|
||||
}
|
||||
return value
|
||||
}
|
||||
|
||||
func defaultInt(value, defaultValue int) int {
|
||||
if value == 0 {
|
||||
return defaultValue
|
||||
}
|
||||
return value
|
||||
}
|
||||
|
||||
func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, error) {
|
||||
taskInfo := &relaycommon.TaskInfo{}
|
||||
|
||||
var taskResp taskResultResponse
|
||||
err := json.Unmarshal(respBody, &taskResp)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to unmarshal response body")
|
||||
}
|
||||
|
||||
state := taskResp.State
|
||||
switch state {
|
||||
case "created", "queueing":
|
||||
taskInfo.Status = model.TaskStatusSubmitted
|
||||
case "processing":
|
||||
taskInfo.Status = model.TaskStatusInProgress
|
||||
case "success":
|
||||
taskInfo.Status = model.TaskStatusSuccess
|
||||
if len(taskResp.Creations) > 0 {
|
||||
taskInfo.Url = taskResp.Creations[0].URL
|
||||
}
|
||||
case "failed":
|
||||
taskInfo.Status = model.TaskStatusFailure
|
||||
if taskResp.ErrCode != "" {
|
||||
taskInfo.Reason = taskResp.ErrCode
|
||||
}
|
||||
default:
|
||||
return nil, fmt.Errorf("unknown task state: %s", state)
|
||||
}
|
||||
|
||||
return taskInfo, nil
|
||||
}
|
||||
@@ -25,6 +25,11 @@ type Adaptor struct {
|
||||
Timestamp int64
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
|
||||
//TODO implement me
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
|
||||
//TODO implement me
|
||||
panic("implement me")
|
||||
|
||||
@@ -136,12 +136,12 @@ func tencentHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Resp
|
||||
var tencentSb TencentChatResponseSB
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, types.NewError(err, types.ErrorCodeReadResponseBodyFailed)
|
||||
return nil, types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError)
|
||||
}
|
||||
common.CloseResponseBodyGracefully(resp)
|
||||
err = json.Unmarshal(responseBody, &tencentSb)
|
||||
if err != nil {
|
||||
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
|
||||
return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
|
||||
}
|
||||
if tencentSb.Response.Error.Code != 0 {
|
||||
return nil, types.WithOpenAIError(types.OpenAIError{
|
||||
|
||||
@@ -44,6 +44,11 @@ type Adaptor struct {
|
||||
AccountCredentials Credentials
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertGeminiRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeminiChatRequest) (any, error) {
|
||||
geminiAdaptor := gemini.Adaptor{}
|
||||
return geminiAdaptor.ConvertGeminiRequest(c, info, request)
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.ClaudeRequest) (any, error) {
|
||||
if v, ok := claudeModelMap[info.UpstreamModelName]; ok {
|
||||
c.Set("request_model", v)
|
||||
@@ -67,10 +72,10 @@ func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInf
|
||||
func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
|
||||
if strings.HasPrefix(info.UpstreamModelName, "claude") {
|
||||
a.RequestMode = RequestModeClaude
|
||||
} else if strings.HasPrefix(info.UpstreamModelName, "gemini") {
|
||||
a.RequestMode = RequestModeGemini
|
||||
} else if strings.Contains(info.UpstreamModelName, "llama") {
|
||||
a.RequestMode = RequestModeLlama
|
||||
} else {
|
||||
a.RequestMode = RequestModeGemini
|
||||
}
|
||||
}
|
||||
|
||||
@@ -83,6 +88,7 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
a.AccountCredentials = *adc
|
||||
suffix := ""
|
||||
if a.RequestMode == RequestModeGemini {
|
||||
|
||||
if model_setting.GetGeminiSettings().ThinkingAdapterEnabled {
|
||||
// 新增逻辑:处理 -thinking-<budget> 格式
|
||||
if strings.Contains(info.UpstreamModelName, "-thinking-") {
|
||||
@@ -100,6 +106,11 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
} else {
|
||||
suffix = "generateContent"
|
||||
}
|
||||
|
||||
if strings.HasPrefix(info.UpstreamModelName, "imagen") {
|
||||
suffix = "predict"
|
||||
}
|
||||
|
||||
if region == "global" {
|
||||
return fmt.Sprintf(
|
||||
"https://aiplatform.googleapis.com/v1/projects/%s/locations/global/publishers/google/models/%s:%s",
|
||||
@@ -231,6 +242,9 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
|
||||
if info.RelayMode == constant.RelayModeGemini {
|
||||
usage, err = gemini.GeminiTextGenerationHandler(c, info, resp)
|
||||
} else {
|
||||
if strings.HasPrefix(info.UpstreamModelName, "imagen") {
|
||||
return gemini.GeminiImageHandler(c, info, resp)
|
||||
}
|
||||
usage, err = gemini.GeminiChatHandler(c, info, resp)
|
||||
}
|
||||
case RequestModeLlama:
|
||||
|
||||
@@ -23,6 +23,11 @@ import (
|
||||
type Adaptor struct {
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
|
||||
//TODO implement me
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
|
||||
//TODO implement me
|
||||
panic("implement me")
|
||||
|
||||
@@ -19,6 +19,11 @@ import (
|
||||
type Adaptor struct {
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
|
||||
//TODO implement me
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
|
||||
//TODO implement me
|
||||
//panic("implement me")
|
||||
|
||||
@@ -17,6 +17,11 @@ type Adaptor struct {
|
||||
request *dto.GeneralOpenAIRequest
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
|
||||
//TODO implement me
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
|
||||
//TODO implement me
|
||||
panic("implement me")
|
||||
|
||||
@@ -16,6 +16,11 @@ import (
|
||||
type Adaptor struct {
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
|
||||
//TODO implement me
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
|
||||
//TODO implement me
|
||||
panic("implement me")
|
||||
|
||||
@@ -220,12 +220,12 @@ func zhipuHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Respon
|
||||
var zhipuResponse ZhipuResponse
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, types.NewError(err, types.ErrorCodeReadResponseBodyFailed)
|
||||
return nil, types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError)
|
||||
}
|
||||
common.CloseResponseBodyGracefully(resp)
|
||||
err = json.Unmarshal(responseBody, &zhipuResponse)
|
||||
if err != nil {
|
||||
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
|
||||
return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
|
||||
}
|
||||
if !zhipuResponse.Success {
|
||||
return nil, types.WithOpenAIError(types.OpenAIError{
|
||||
|
||||
@@ -18,6 +18,11 @@ import (
|
||||
type Adaptor struct {
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
|
||||
//TODO implement me
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
|
||||
//TODO implement me
|
||||
panic("implement me")
|
||||
|
||||
@@ -40,7 +40,7 @@ func ClaudeHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
|
||||
// get & validate textRequest 获取并验证文本请求
|
||||
textRequest, err := getAndValidateClaudeRequest(c)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeInvalidRequest)
|
||||
return types.NewError(err, types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
|
||||
if textRequest.Stream {
|
||||
@@ -49,18 +49,18 @@ func ClaudeHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
|
||||
|
||||
err = helper.ModelMappedHelper(c, relayInfo, textRequest)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeChannelModelMappedError)
|
||||
return types.NewError(err, types.ErrorCodeChannelModelMappedError, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
|
||||
promptTokens, err := getClaudePromptTokens(textRequest, relayInfo)
|
||||
// count messages token error 计算promptTokens错误
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeCountTokenFailed)
|
||||
return types.NewError(err, types.ErrorCodeCountTokenFailed, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
|
||||
priceData, err := helper.ModelPriceHelper(c, relayInfo, promptTokens, int(textRequest.MaxTokens))
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeModelPriceError)
|
||||
return types.NewError(err, types.ErrorCodeModelPriceError, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
|
||||
// pre-consume quota 预消耗配额
|
||||
@@ -77,10 +77,9 @@ func ClaudeHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
|
||||
|
||||
adaptor := GetAdaptor(relayInfo.ApiType)
|
||||
if adaptor == nil {
|
||||
return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType)
|
||||
return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
adaptor.Init(relayInfo)
|
||||
var requestBody io.Reader
|
||||
|
||||
if textRequest.MaxTokens == 0 {
|
||||
textRequest.MaxTokens = uint(model_setting.GetClaudeSettings().GetDefaultMaxTokens(textRequest.Model))
|
||||
@@ -108,18 +107,41 @@ func ClaudeHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
|
||||
relayInfo.UpstreamModelName = textRequest.Model
|
||||
}
|
||||
|
||||
convertedRequest, err := adaptor.ConvertClaudeRequest(c, relayInfo, textRequest)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeConvertRequestFailed)
|
||||
var requestBody io.Reader
|
||||
if model_setting.GetGlobalSettings().PassThroughRequestEnabled || relayInfo.ChannelSetting.PassThroughBodyEnabled {
|
||||
body, err := common.GetRequestBody(c)
|
||||
if err != nil {
|
||||
return types.NewErrorWithStatusCode(err, types.ErrorCodeReadRequestBodyFailed, http.StatusBadRequest, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
requestBody = bytes.NewBuffer(body)
|
||||
} else {
|
||||
convertedRequest, err := adaptor.ConvertClaudeRequest(c, relayInfo, textRequest)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
jsonData, err := common.Marshal(convertedRequest)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
|
||||
// apply param override
|
||||
if len(relayInfo.ParamOverride) > 0 {
|
||||
reqMap := make(map[string]interface{})
|
||||
_ = common.Unmarshal(jsonData, &reqMap)
|
||||
for key, value := range relayInfo.ParamOverride {
|
||||
reqMap[key] = value
|
||||
}
|
||||
jsonData, err = common.Marshal(reqMap)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
}
|
||||
|
||||
if common.DebugEnabled {
|
||||
println("requestBody: ", string(jsonData))
|
||||
}
|
||||
requestBody = bytes.NewBuffer(jsonData)
|
||||
}
|
||||
jsonData, err := common.Marshal(convertedRequest)
|
||||
if common.DebugEnabled {
|
||||
println("requestBody: ", string(jsonData))
|
||||
}
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeConvertRequestFailed)
|
||||
}
|
||||
requestBody = bytes.NewBuffer(jsonData)
|
||||
|
||||
statusCodeMappingStr := c.GetString("status_code_mapping")
|
||||
var httpResp *http.Response
|
||||
|
||||
@@ -88,6 +88,7 @@ type RelayInfo struct {
|
||||
BaseUrl string
|
||||
SupportStreamOptions bool
|
||||
ShouldIncludeUsage bool
|
||||
DisablePing bool // 是否禁止向下游发送自定义 Ping
|
||||
IsModelMapped bool
|
||||
ClientWs *websocket.Conn
|
||||
TargetWs *websocket.Conn
|
||||
|
||||
@@ -16,7 +16,7 @@ import (
|
||||
func RerankHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, types.NewError(err, types.ErrorCodeReadResponseBodyFailed)
|
||||
return nil, types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError)
|
||||
}
|
||||
common.CloseResponseBodyGracefully(resp)
|
||||
if common.DebugEnabled {
|
||||
@@ -27,7 +27,7 @@ func RerankHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Respo
|
||||
var xinRerankResponse xinference.XinRerankResponse
|
||||
err = common.Unmarshal(responseBody, &xinRerankResponse)
|
||||
if err != nil {
|
||||
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
|
||||
return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
|
||||
}
|
||||
jinaRespResults := make([]dto.RerankResponseResult, len(xinRerankResponse.Results))
|
||||
for i, result := range xinRerankResponse.Results {
|
||||
@@ -62,7 +62,7 @@ func RerankHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Respo
|
||||
} else {
|
||||
err = common.Unmarshal(responseBody, &jinaResp)
|
||||
if err != nil {
|
||||
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
|
||||
return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
|
||||
}
|
||||
jinaResp.Usage.PromptTokens = jinaResp.Usage.TotalTokens
|
||||
}
|
||||
|
||||
@@ -41,17 +41,17 @@ func EmbeddingHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
|
||||
err := common.UnmarshalBodyReusable(c, &embeddingRequest)
|
||||
if err != nil {
|
||||
common.LogError(c, fmt.Sprintf("getAndValidateTextRequest failed: %s", err.Error()))
|
||||
return types.NewError(err, types.ErrorCodeInvalidRequest)
|
||||
return types.NewError(err, types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
|
||||
err = validateEmbeddingRequest(c, relayInfo, *embeddingRequest)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeInvalidRequest)
|
||||
return types.NewError(err, types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
|
||||
err = helper.ModelMappedHelper(c, relayInfo, embeddingRequest)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeChannelModelMappedError)
|
||||
return types.NewError(err, types.ErrorCodeChannelModelMappedError, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
|
||||
promptToken := getEmbeddingPromptToken(*embeddingRequest)
|
||||
@@ -59,7 +59,7 @@ func EmbeddingHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
|
||||
|
||||
priceData, err := helper.ModelPriceHelper(c, relayInfo, promptToken, 0)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeModelPriceError)
|
||||
return types.NewError(err, types.ErrorCodeModelPriceError, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
// pre-consume quota 预消耗配额
|
||||
preConsumedQuota, userQuota, newAPIError := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo)
|
||||
@@ -74,18 +74,17 @@ func EmbeddingHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
|
||||
|
||||
adaptor := GetAdaptor(relayInfo.ApiType)
|
||||
if adaptor == nil {
|
||||
return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType)
|
||||
return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
adaptor.Init(relayInfo)
|
||||
|
||||
convertedRequest, err := adaptor.ConvertEmbeddingRequest(c, relayInfo, *embeddingRequest)
|
||||
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeConvertRequestFailed)
|
||||
return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
jsonData, err := json.Marshal(convertedRequest)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeConvertRequestFailed)
|
||||
return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
requestBody := bytes.NewBuffer(jsonData)
|
||||
statusCodeMappingStr := c.GetString("status_code_mapping")
|
||||
|
||||
@@ -2,9 +2,9 @@ package relay
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/dto"
|
||||
@@ -20,8 +20,8 @@ import (
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func getAndValidateGeminiRequest(c *gin.Context) (*gemini.GeminiChatRequest, error) {
|
||||
request := &gemini.GeminiChatRequest{}
|
||||
func getAndValidateGeminiRequest(c *gin.Context) (*dto.GeminiChatRequest, error) {
|
||||
request := &dto.GeminiChatRequest{}
|
||||
err := common.UnmarshalBodyReusable(c, request)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -44,7 +44,7 @@ func checkGeminiStreamMode(c *gin.Context, relayInfo *relaycommon.RelayInfo) {
|
||||
// }
|
||||
}
|
||||
|
||||
func checkGeminiInputSensitive(textRequest *gemini.GeminiChatRequest) ([]string, error) {
|
||||
func checkGeminiInputSensitive(textRequest *dto.GeminiChatRequest) ([]string, error) {
|
||||
var inputTexts []string
|
||||
for _, content := range textRequest.Contents {
|
||||
for _, part := range content.Parts {
|
||||
@@ -61,7 +61,7 @@ func checkGeminiInputSensitive(textRequest *gemini.GeminiChatRequest) ([]string,
|
||||
return sensitiveWords, err
|
||||
}
|
||||
|
||||
func getGeminiInputTokens(req *gemini.GeminiChatRequest, info *relaycommon.RelayInfo) int {
|
||||
func getGeminiInputTokens(req *dto.GeminiChatRequest, info *relaycommon.RelayInfo) int {
|
||||
// 计算输入 token 数量
|
||||
var inputTexts []string
|
||||
for _, content := range req.Contents {
|
||||
@@ -78,9 +78,9 @@ func getGeminiInputTokens(req *gemini.GeminiChatRequest, info *relaycommon.Relay
|
||||
return inputTokens
|
||||
}
|
||||
|
||||
func isNoThinkingRequest(req *gemini.GeminiChatRequest) bool {
|
||||
func isNoThinkingRequest(req *dto.GeminiChatRequest) bool {
|
||||
if req.GenerationConfig.ThinkingConfig != nil && req.GenerationConfig.ThinkingConfig.ThinkingBudget != nil {
|
||||
return *req.GenerationConfig.ThinkingConfig.ThinkingBudget <= 0
|
||||
return *req.GenerationConfig.ThinkingConfig.ThinkingBudget == 0
|
||||
}
|
||||
return false
|
||||
}
|
||||
@@ -109,7 +109,7 @@ func GeminiHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
|
||||
req, err := getAndValidateGeminiRequest(c)
|
||||
if err != nil {
|
||||
common.LogError(c, fmt.Sprintf("getAndValidateGeminiRequest error: %s", err.Error()))
|
||||
return types.NewError(err, types.ErrorCodeInvalidRequest)
|
||||
return types.NewError(err, types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
|
||||
relayInfo := relaycommon.GenRelayInfoGemini(c)
|
||||
@@ -121,14 +121,14 @@ func GeminiHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
|
||||
sensitiveWords, err := checkGeminiInputSensitive(req)
|
||||
if err != nil {
|
||||
common.LogWarn(c, fmt.Sprintf("user sensitive words detected: %s", strings.Join(sensitiveWords, ", ")))
|
||||
return types.NewError(err, types.ErrorCodeSensitiveWordsDetected)
|
||||
return types.NewError(err, types.ErrorCodeSensitiveWordsDetected, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
}
|
||||
|
||||
// model mapped 模型映射
|
||||
err = helper.ModelMappedHelper(c, relayInfo, req)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeChannelModelMappedError)
|
||||
return types.NewError(err, types.ErrorCodeChannelModelMappedError, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
|
||||
if value, exists := c.Get("prompt_tokens"); exists {
|
||||
@@ -159,7 +159,7 @@ func GeminiHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
|
||||
|
||||
priceData, err := helper.ModelPriceHelper(c, relayInfo, relayInfo.PromptTokens, int(req.GenerationConfig.MaxOutputTokens))
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeModelPriceError)
|
||||
return types.NewError(err, types.ErrorCodeModelPriceError, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
|
||||
// pre consume quota
|
||||
@@ -175,7 +175,7 @@ func GeminiHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
|
||||
|
||||
adaptor := GetAdaptor(relayInfo.ApiType)
|
||||
if adaptor == nil {
|
||||
return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType)
|
||||
return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
|
||||
adaptor.Init(relayInfo)
|
||||
@@ -194,19 +194,47 @@ func GeminiHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
|
||||
}
|
||||
}
|
||||
|
||||
requestBody, err := json.Marshal(req)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeConvertRequestFailed)
|
||||
var requestBody io.Reader
|
||||
if model_setting.GetGlobalSettings().PassThroughRequestEnabled || relayInfo.ChannelSetting.PassThroughBodyEnabled {
|
||||
body, err := common.GetRequestBody(c)
|
||||
if err != nil {
|
||||
return types.NewErrorWithStatusCode(err, types.ErrorCodeReadRequestBodyFailed, http.StatusBadRequest, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
requestBody = bytes.NewReader(body)
|
||||
} else {
|
||||
// 使用 ConvertGeminiRequest 转换请求格式
|
||||
convertedRequest, err := adaptor.ConvertGeminiRequest(c, relayInfo, req)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
jsonData, err := common.Marshal(convertedRequest)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
|
||||
// apply param override
|
||||
if len(relayInfo.ParamOverride) > 0 {
|
||||
reqMap := make(map[string]interface{})
|
||||
_ = common.Unmarshal(jsonData, &reqMap)
|
||||
for key, value := range relayInfo.ParamOverride {
|
||||
reqMap[key] = value
|
||||
}
|
||||
jsonData, err = common.Marshal(reqMap)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
}
|
||||
|
||||
if common.DebugEnabled {
|
||||
println("Gemini request body: %s", string(jsonData))
|
||||
}
|
||||
requestBody = bytes.NewReader(jsonData)
|
||||
}
|
||||
|
||||
if common.DebugEnabled {
|
||||
println("Gemini request body: %s", string(requestBody))
|
||||
}
|
||||
|
||||
resp, err := adaptor.DoRequest(c, relayInfo, bytes.NewReader(requestBody))
|
||||
resp, err := adaptor.DoRequest(c, relayInfo, requestBody)
|
||||
if err != nil {
|
||||
common.LogError(c, "Do gemini request failed: "+err.Error())
|
||||
return types.NewError(err, types.ErrorCodeDoRequestFailed)
|
||||
return types.NewOpenAIError(err, types.ErrorCodeDoRequestFailed, http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
statusCodeMappingStr := c.GetString("status_code_mapping")
|
||||
|
||||
@@ -139,6 +139,24 @@ func GetLocalRealtimeID(c *gin.Context) string {
|
||||
return fmt.Sprintf("evt_%s", logID)
|
||||
}
|
||||
|
||||
func GenerateStartEmptyResponse(id string, createAt int64, model string, systemFingerprint *string) *dto.ChatCompletionsStreamResponse {
|
||||
return &dto.ChatCompletionsStreamResponse{
|
||||
Id: id,
|
||||
Object: "chat.completion.chunk",
|
||||
Created: createAt,
|
||||
Model: model,
|
||||
SystemFingerprint: systemFingerprint,
|
||||
Choices: []dto.ChatCompletionsStreamResponseChoice{
|
||||
{
|
||||
Delta: dto.ChatCompletionsStreamResponseChoiceDelta{
|
||||
Role: "assistant",
|
||||
Content: common.GetPointer(""),
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func GenerateStopResponse(id string, createAt int64, model string, finishReason string) *dto.ChatCompletionsStreamResponse {
|
||||
return &dto.ChatCompletionsStreamResponse{
|
||||
Id: id,
|
||||
|
||||
@@ -54,7 +54,7 @@ func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon
|
||||
)
|
||||
|
||||
generalSettings := operation_setting.GetGeneralSetting()
|
||||
pingEnabled := generalSettings.PingIntervalEnabled
|
||||
pingEnabled := generalSettings.PingIntervalEnabled && !info.DisablePing
|
||||
pingInterval := time.Duration(generalSettings.PingIntervalSeconds) * time.Second
|
||||
if pingInterval <= 0 {
|
||||
pingInterval = DefaultPingInterval
|
||||
@@ -234,6 +234,12 @@ func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon
|
||||
case <-stopChan:
|
||||
return
|
||||
}
|
||||
} else {
|
||||
// done, 处理完成标志,直接退出停止读取剩余数据防止出错
|
||||
if common.DebugEnabled {
|
||||
println("received [DONE], stopping scanner")
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -16,6 +16,7 @@ import (
|
||||
"one-api/relay/helper"
|
||||
"one-api/service"
|
||||
"one-api/setting"
|
||||
"one-api/setting/model_setting"
|
||||
"one-api/types"
|
||||
"strings"
|
||||
|
||||
@@ -114,17 +115,17 @@ func ImageHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
|
||||
imageRequest, err := getAndValidImageRequest(c, relayInfo)
|
||||
if err != nil {
|
||||
common.LogError(c, fmt.Sprintf("getAndValidImageRequest failed: %s", err.Error()))
|
||||
return types.NewError(err, types.ErrorCodeInvalidRequest)
|
||||
return types.NewError(err, types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
|
||||
err = helper.ModelMappedHelper(c, relayInfo, imageRequest)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeChannelModelMappedError)
|
||||
return types.NewError(err, types.ErrorCodeChannelModelMappedError, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
|
||||
priceData, err := helper.ModelPriceHelper(c, relayInfo, len(imageRequest.Prompt), 0)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeModelPriceError)
|
||||
return types.NewError(err, types.ErrorCodeModelPriceError, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
var preConsumedQuota int
|
||||
var quota int
|
||||
@@ -172,37 +173,58 @@ func ImageHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
|
||||
quota = int(priceData.ModelPrice * priceData.GroupRatioInfo.GroupRatio * common.QuotaPerUnit)
|
||||
userQuota, err = model.GetUserQuota(relayInfo.UserId, false)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeQueryDataError)
|
||||
return types.NewError(err, types.ErrorCodeQueryDataError, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
if userQuota-quota < 0 {
|
||||
return types.NewError(fmt.Errorf("image pre-consumed quota failed, user quota: %s, need quota: %s", common.FormatQuota(userQuota), common.FormatQuota(quota)), types.ErrorCodeInsufficientUserQuota)
|
||||
return types.NewError(fmt.Errorf("image pre-consumed quota failed, user quota: %s, need quota: %s", common.FormatQuota(userQuota), common.FormatQuota(quota)), types.ErrorCodeInsufficientUserQuota, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
}
|
||||
|
||||
adaptor := GetAdaptor(relayInfo.ApiType)
|
||||
if adaptor == nil {
|
||||
return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType)
|
||||
return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
adaptor.Init(relayInfo)
|
||||
|
||||
var requestBody io.Reader
|
||||
|
||||
convertedRequest, err := adaptor.ConvertImageRequest(c, relayInfo, *imageRequest)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeConvertRequestFailed)
|
||||
}
|
||||
if relayInfo.RelayMode == relayconstant.RelayModeImagesEdits {
|
||||
requestBody = convertedRequest.(io.Reader)
|
||||
} else {
|
||||
jsonData, err := json.Marshal(convertedRequest)
|
||||
if model_setting.GetGlobalSettings().PassThroughRequestEnabled || relayInfo.ChannelSetting.PassThroughBodyEnabled {
|
||||
body, err := common.GetRequestBody(c)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeConvertRequestFailed)
|
||||
return types.NewErrorWithStatusCode(err, types.ErrorCodeReadRequestBodyFailed, http.StatusBadRequest, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
requestBody = bytes.NewBuffer(jsonData)
|
||||
}
|
||||
requestBody = bytes.NewBuffer(body)
|
||||
} else {
|
||||
convertedRequest, err := adaptor.ConvertImageRequest(c, relayInfo, *imageRequest)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
if relayInfo.RelayMode == relayconstant.RelayModeImagesEdits {
|
||||
requestBody = convertedRequest.(io.Reader)
|
||||
} else {
|
||||
jsonData, err := json.Marshal(convertedRequest)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
|
||||
if common.DebugEnabled {
|
||||
println(fmt.Sprintf("image request body: %s", requestBody))
|
||||
// apply param override
|
||||
if len(relayInfo.ParamOverride) > 0 {
|
||||
reqMap := make(map[string]interface{})
|
||||
_ = common.Unmarshal(jsonData, &reqMap)
|
||||
for key, value := range relayInfo.ParamOverride {
|
||||
reqMap[key] = value
|
||||
}
|
||||
jsonData, err = common.Marshal(reqMap)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
}
|
||||
|
||||
if common.DebugEnabled {
|
||||
println(fmt.Sprintf("image request body: %s", string(jsonData)))
|
||||
}
|
||||
requestBody = bytes.NewBuffer(jsonData)
|
||||
}
|
||||
}
|
||||
|
||||
statusCodeMappingStr := c.GetString("status_code_mapping")
|
||||
|
||||
@@ -2,7 +2,6 @@ package relay
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
@@ -91,9 +90,8 @@ func TextHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
|
||||
|
||||
// get & validate textRequest 获取并验证文本请求
|
||||
textRequest, err := getAndValidateTextRequest(c, relayInfo)
|
||||
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeInvalidRequest)
|
||||
return types.NewError(err, types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
|
||||
if textRequest.WebSearchOptions != nil {
|
||||
@@ -104,13 +102,13 @@ func TextHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
|
||||
words, err := checkRequestSensitive(textRequest, relayInfo)
|
||||
if err != nil {
|
||||
common.LogWarn(c, fmt.Sprintf("user sensitive words detected: %s", strings.Join(words, ", ")))
|
||||
return types.NewError(err, types.ErrorCodeSensitiveWordsDetected)
|
||||
return types.NewError(err, types.ErrorCodeSensitiveWordsDetected, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
}
|
||||
|
||||
err = helper.ModelMappedHelper(c, relayInfo, textRequest)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeChannelModelMappedError)
|
||||
return types.NewError(err, types.ErrorCodeChannelModelMappedError, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
|
||||
// 获取 promptTokens,如果上下文中已经存在,则直接使用
|
||||
@@ -122,14 +120,14 @@ func TextHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
|
||||
promptTokens, err = getPromptTokens(textRequest, relayInfo)
|
||||
// count messages token error 计算promptTokens错误
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeCountTokenFailed)
|
||||
return types.NewError(err, types.ErrorCodeCountTokenFailed, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
c.Set("prompt_tokens", promptTokens)
|
||||
}
|
||||
|
||||
priceData, err := helper.ModelPriceHelper(c, relayInfo, promptTokens, int(math.Max(float64(textRequest.MaxTokens), float64(textRequest.MaxCompletionTokens))))
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeModelPriceError)
|
||||
return types.NewError(err, types.ErrorCodeModelPriceError, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
|
||||
// pre-consume quota 预消耗配额
|
||||
@@ -166,25 +164,49 @@ func TextHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
|
||||
|
||||
adaptor := GetAdaptor(relayInfo.ApiType)
|
||||
if adaptor == nil {
|
||||
return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType)
|
||||
return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
adaptor.Init(relayInfo)
|
||||
var requestBody io.Reader
|
||||
|
||||
if model_setting.GetGlobalSettings().PassThroughRequestEnabled {
|
||||
if model_setting.GetGlobalSettings().PassThroughRequestEnabled || relayInfo.ChannelSetting.PassThroughBodyEnabled {
|
||||
body, err := common.GetRequestBody(c)
|
||||
if err != nil {
|
||||
return types.NewErrorWithStatusCode(err, types.ErrorCodeReadRequestBodyFailed, http.StatusBadRequest)
|
||||
return types.NewErrorWithStatusCode(err, types.ErrorCodeReadRequestBodyFailed, http.StatusBadRequest, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
if common.DebugEnabled {
|
||||
println("requestBody: ", string(body))
|
||||
}
|
||||
requestBody = bytes.NewBuffer(body)
|
||||
} else {
|
||||
convertedRequest, err := adaptor.ConvertOpenAIRequest(c, relayInfo, textRequest)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeConvertRequestFailed)
|
||||
return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
jsonData, err := json.Marshal(convertedRequest)
|
||||
|
||||
if relayInfo.ChannelSetting.SystemPrompt != "" {
|
||||
// 如果有系统提示,则将其添加到请求中
|
||||
request := convertedRequest.(*dto.GeneralOpenAIRequest)
|
||||
containSystemPrompt := false
|
||||
for _, message := range request.Messages {
|
||||
if message.Role == request.GetSystemRoleName() {
|
||||
containSystemPrompt = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !containSystemPrompt {
|
||||
// 如果没有系统提示,则添加系统提示
|
||||
systemMessage := dto.Message{
|
||||
Role: request.GetSystemRoleName(),
|
||||
Content: relayInfo.ChannelSetting.SystemPrompt,
|
||||
}
|
||||
request.Messages = append([]dto.Message{systemMessage}, request.Messages...)
|
||||
}
|
||||
}
|
||||
|
||||
jsonData, err := common.Marshal(convertedRequest)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeConvertRequestFailed)
|
||||
return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
|
||||
// apply param override
|
||||
@@ -196,7 +218,7 @@ func TextHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
|
||||
}
|
||||
jsonData, err = common.Marshal(reqMap)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid)
|
||||
return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -208,7 +230,6 @@ func TextHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
|
||||
|
||||
var httpResp *http.Response
|
||||
resp, err := adaptor.DoRequest(c, relayInfo, requestBody)
|
||||
|
||||
if err != nil {
|
||||
return types.NewOpenAIError(err, types.ErrorCodeDoRequestFailed, http.StatusInternalServerError)
|
||||
}
|
||||
@@ -281,13 +302,13 @@ func checkRequestSensitive(textRequest *dto.GeneralOpenAIRequest, info *relaycom
|
||||
func preConsumeQuota(c *gin.Context, preConsumedQuota int, relayInfo *relaycommon.RelayInfo) (int, int, *types.NewAPIError) {
|
||||
userQuota, err := model.GetUserQuota(relayInfo.UserId, false)
|
||||
if err != nil {
|
||||
return 0, 0, types.NewError(err, types.ErrorCodeQueryDataError)
|
||||
return 0, 0, types.NewError(err, types.ErrorCodeQueryDataError, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
if userQuota <= 0 {
|
||||
return 0, 0, types.NewErrorWithStatusCode(errors.New("user quota is not enough"), types.ErrorCodeInsufficientUserQuota, http.StatusForbidden)
|
||||
return 0, 0, types.NewErrorWithStatusCode(errors.New("user quota is not enough"), types.ErrorCodeInsufficientUserQuota, http.StatusForbidden, types.ErrOptionWithSkipRetry(), types.ErrOptionWithNoRecordErrorLog())
|
||||
}
|
||||
if userQuota-preConsumedQuota < 0 {
|
||||
return 0, 0, types.NewErrorWithStatusCode(fmt.Errorf("pre-consume quota failed, user quota: %s, need quota: %s", common.FormatQuota(userQuota), common.FormatQuota(preConsumedQuota)), types.ErrorCodeInsufficientUserQuota, http.StatusForbidden)
|
||||
return 0, 0, types.NewErrorWithStatusCode(fmt.Errorf("pre-consume quota failed, user quota: %s, need quota: %s", common.FormatQuota(userQuota), common.FormatQuota(preConsumedQuota)), types.ErrorCodeInsufficientUserQuota, http.StatusForbidden, types.ErrOptionWithSkipRetry(), types.ErrOptionWithNoRecordErrorLog())
|
||||
}
|
||||
relayInfo.UserQuota = userQuota
|
||||
if userQuota > 100*preConsumedQuota {
|
||||
@@ -311,11 +332,11 @@ func preConsumeQuota(c *gin.Context, preConsumedQuota int, relayInfo *relaycommo
|
||||
if preConsumedQuota > 0 {
|
||||
err := service.PreConsumeTokenQuota(relayInfo, preConsumedQuota)
|
||||
if err != nil {
|
||||
return 0, 0, types.NewErrorWithStatusCode(err, types.ErrorCodePreConsumeTokenQuotaFailed, http.StatusForbidden)
|
||||
return 0, 0, types.NewErrorWithStatusCode(err, types.ErrorCodePreConsumeTokenQuotaFailed, http.StatusForbidden, types.ErrOptionWithSkipRetry(), types.ErrOptionWithNoRecordErrorLog())
|
||||
}
|
||||
err = model.DecreaseUserQuota(relayInfo.UserId, preConsumedQuota)
|
||||
if err != nil {
|
||||
return 0, 0, types.NewError(err, types.ErrorCodeUpdateDataError)
|
||||
return 0, 0, types.NewError(err, types.ErrorCodeUpdateDataError, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
}
|
||||
return preConsumedQuota, userQuota, nil
|
||||
@@ -494,6 +515,9 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
|
||||
common.LogError(ctx, fmt.Sprintf("total tokens is 0, cannot consume quota, userId %d, channelId %d, "+
|
||||
"tokenId %d, model %s, pre-consumed quota %d", relayInfo.UserId, relayInfo.ChannelId, relayInfo.TokenId, modelName, preConsumedQuota))
|
||||
} else {
|
||||
if !ratio.IsZero() && quota == 0 {
|
||||
quota = 1
|
||||
}
|
||||
model.UpdateUserUsedQuotaAndRequestCount(relayInfo.UserId, quota)
|
||||
model.UpdateChannelUsedQuota(relayInfo.ChannelId, quota)
|
||||
}
|
||||
|
||||
@@ -27,6 +27,7 @@ import (
|
||||
taskjimeng "one-api/relay/channel/task/jimeng"
|
||||
"one-api/relay/channel/task/kling"
|
||||
"one-api/relay/channel/task/suno"
|
||||
taskVidu "one-api/relay/channel/task/vidu"
|
||||
"one-api/relay/channel/tencent"
|
||||
"one-api/relay/channel/vertex"
|
||||
"one-api/relay/channel/volcengine"
|
||||
@@ -122,6 +123,8 @@ func GetTaskAdaptor(platform constant.TaskPlatform) channel.TaskAdaptor {
|
||||
return &kling.TaskAdaptor{}
|
||||
case constant.ChannelTypeJimeng:
|
||||
return &taskjimeng.TaskAdaptor{}
|
||||
case constant.ChannelTypeVidu:
|
||||
return &taskVidu.TaskAdaptor{}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
|
||||
@@ -3,12 +3,14 @@ package relay
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/dto"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/relay/helper"
|
||||
"one-api/service"
|
||||
"one-api/setting/model_setting"
|
||||
"one-api/types"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
@@ -29,21 +31,21 @@ func RerankHelper(c *gin.Context, relayMode int) (newAPIError *types.NewAPIError
|
||||
err := common.UnmarshalBodyReusable(c, &rerankRequest)
|
||||
if err != nil {
|
||||
common.LogError(c, fmt.Sprintf("getAndValidateTextRequest failed: %s", err.Error()))
|
||||
return types.NewError(err, types.ErrorCodeInvalidRequest)
|
||||
return types.NewError(err, types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
|
||||
relayInfo := relaycommon.GenRelayInfoRerank(c, rerankRequest)
|
||||
|
||||
if rerankRequest.Query == "" {
|
||||
return types.NewError(fmt.Errorf("query is empty"), types.ErrorCodeInvalidRequest)
|
||||
return types.NewError(fmt.Errorf("query is empty"), types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
if len(rerankRequest.Documents) == 0 {
|
||||
return types.NewError(fmt.Errorf("documents is empty"), types.ErrorCodeInvalidRequest)
|
||||
return types.NewError(fmt.Errorf("documents is empty"), types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
|
||||
err = helper.ModelMappedHelper(c, relayInfo, rerankRequest)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeChannelModelMappedError)
|
||||
return types.NewError(err, types.ErrorCodeChannelModelMappedError, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
|
||||
promptToken := getRerankPromptToken(*rerankRequest)
|
||||
@@ -51,7 +53,7 @@ func RerankHelper(c *gin.Context, relayMode int) (newAPIError *types.NewAPIError
|
||||
|
||||
priceData, err := helper.ModelPriceHelper(c, relayInfo, promptToken, 0)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeModelPriceError)
|
||||
return types.NewError(err, types.ErrorCodeModelPriceError, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
// pre-consume quota 预消耗配额
|
||||
preConsumedQuota, userQuota, newAPIError := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo)
|
||||
@@ -66,22 +68,46 @@ func RerankHelper(c *gin.Context, relayMode int) (newAPIError *types.NewAPIError
|
||||
|
||||
adaptor := GetAdaptor(relayInfo.ApiType)
|
||||
if adaptor == nil {
|
||||
return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType)
|
||||
return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
adaptor.Init(relayInfo)
|
||||
|
||||
convertedRequest, err := adaptor.ConvertRerankRequest(c, relayInfo.RelayMode, *rerankRequest)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeConvertRequestFailed)
|
||||
}
|
||||
jsonData, err := common.Marshal(convertedRequest)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeConvertRequestFailed)
|
||||
}
|
||||
requestBody := bytes.NewBuffer(jsonData)
|
||||
if common.DebugEnabled {
|
||||
println(fmt.Sprintf("Rerank request body: %s", requestBody.String()))
|
||||
var requestBody io.Reader
|
||||
if model_setting.GetGlobalSettings().PassThroughRequestEnabled || relayInfo.ChannelSetting.PassThroughBodyEnabled {
|
||||
body, err := common.GetRequestBody(c)
|
||||
if err != nil {
|
||||
return types.NewErrorWithStatusCode(err, types.ErrorCodeReadRequestBodyFailed, http.StatusBadRequest, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
requestBody = bytes.NewBuffer(body)
|
||||
} else {
|
||||
convertedRequest, err := adaptor.ConvertRerankRequest(c, relayInfo.RelayMode, *rerankRequest)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
jsonData, err := common.Marshal(convertedRequest)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
|
||||
// apply param override
|
||||
if len(relayInfo.ParamOverride) > 0 {
|
||||
reqMap := make(map[string]interface{})
|
||||
_ = common.Unmarshal(jsonData, &reqMap)
|
||||
for key, value := range relayInfo.ParamOverride {
|
||||
reqMap[key] = value
|
||||
}
|
||||
jsonData, err = common.Marshal(reqMap)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
}
|
||||
|
||||
if common.DebugEnabled {
|
||||
println(fmt.Sprintf("Rerank request body: %s", string(jsonData)))
|
||||
}
|
||||
requestBody = bytes.NewBuffer(jsonData)
|
||||
}
|
||||
|
||||
resp, err := adaptor.DoRequest(c, relayInfo, requestBody)
|
||||
if err != nil {
|
||||
return types.NewOpenAIError(err, types.ErrorCodeDoRequestFailed, http.StatusInternalServerError)
|
||||
|
||||
@@ -51,7 +51,7 @@ func ResponsesHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
|
||||
req, err := getAndValidateResponsesRequest(c)
|
||||
if err != nil {
|
||||
common.LogError(c, fmt.Sprintf("getAndValidateResponsesRequest error: %s", err.Error()))
|
||||
return types.NewError(err, types.ErrorCodeInvalidRequest)
|
||||
return types.NewError(err, types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
|
||||
relayInfo := relaycommon.GenRelayInfoResponses(c, req)
|
||||
@@ -60,13 +60,13 @@ func ResponsesHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
|
||||
sensitiveWords, err := checkInputSensitive(req, relayInfo)
|
||||
if err != nil {
|
||||
common.LogWarn(c, fmt.Sprintf("user sensitive words detected: %s", strings.Join(sensitiveWords, ", ")))
|
||||
return types.NewError(err, types.ErrorCodeSensitiveWordsDetected)
|
||||
return types.NewError(err, types.ErrorCodeSensitiveWordsDetected, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
}
|
||||
|
||||
err = helper.ModelMappedHelper(c, relayInfo, req)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeChannelModelMappedError)
|
||||
return types.NewError(err, types.ErrorCodeChannelModelMappedError, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
|
||||
if value, exists := c.Get("prompt_tokens"); exists {
|
||||
@@ -79,7 +79,7 @@ func ResponsesHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
|
||||
|
||||
priceData, err := helper.ModelPriceHelper(c, relayInfo, relayInfo.PromptTokens, int(req.MaxOutputTokens))
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeModelPriceError)
|
||||
return types.NewError(err, types.ErrorCodeModelPriceError, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
// pre consume quota
|
||||
preConsumedQuota, userQuota, newAPIError := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo)
|
||||
@@ -93,38 +93,38 @@ func ResponsesHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
|
||||
}()
|
||||
adaptor := GetAdaptor(relayInfo.ApiType)
|
||||
if adaptor == nil {
|
||||
return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType)
|
||||
return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
adaptor.Init(relayInfo)
|
||||
var requestBody io.Reader
|
||||
if model_setting.GetGlobalSettings().PassThroughRequestEnabled {
|
||||
body, err := common.GetRequestBody(c)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeReadRequestBodyFailed)
|
||||
return types.NewError(err, types.ErrorCodeReadRequestBodyFailed, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
requestBody = bytes.NewBuffer(body)
|
||||
} else {
|
||||
convertedRequest, err := adaptor.ConvertOpenAIResponsesRequest(c, relayInfo, *req)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeConvertRequestFailed)
|
||||
return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
jsonData, err := json.Marshal(convertedRequest)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeConvertRequestFailed)
|
||||
return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
// apply param override
|
||||
if len(relayInfo.ParamOverride) > 0 {
|
||||
reqMap := make(map[string]interface{})
|
||||
err = json.Unmarshal(jsonData, &reqMap)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid)
|
||||
return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
for key, value := range relayInfo.ParamOverride {
|
||||
reqMap[key] = value
|
||||
}
|
||||
jsonData, err = json.Marshal(reqMap)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeConvertRequestFailed)
|
||||
return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -24,12 +24,12 @@ func WssHelper(c *gin.Context, ws *websocket.Conn) (newAPIError *types.NewAPIErr
|
||||
|
||||
err := helper.ModelMappedHelper(c, relayInfo, nil)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeChannelModelMappedError)
|
||||
return types.NewError(err, types.ErrorCodeChannelModelMappedError, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
|
||||
priceData, err := helper.ModelPriceHelper(c, relayInfo, 0, 0)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeModelPriceError)
|
||||
return types.NewError(err, types.ErrorCodeModelPriceError, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
|
||||
// pre-consume quota 预消耗配额
|
||||
@@ -46,7 +46,7 @@ func WssHelper(c *gin.Context, ws *websocket.Conn) (newAPIError *types.NewAPIErr
|
||||
|
||||
adaptor := GetAdaptor(relayInfo.ApiType)
|
||||
if adaptor == nil {
|
||||
return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType)
|
||||
return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
adaptor.Init(relayInfo)
|
||||
//var requestBody io.Reader
|
||||
|
||||
@@ -45,7 +45,7 @@ func ShouldDisableChannel(channelType int, err *types.NewAPIError) bool {
|
||||
if types.IsChannelError(err) {
|
||||
return true
|
||||
}
|
||||
if types.IsLocalError(err) {
|
||||
if types.IsSkipRetryError(err) {
|
||||
return false
|
||||
}
|
||||
if err.StatusCode == http.StatusUnauthorized {
|
||||
|
||||
@@ -188,28 +188,6 @@ func ClaudeToOpenAIRequest(claudeRequest dto.ClaudeRequest, info *relaycommon.Re
|
||||
return &openAIRequest, nil
|
||||
}
|
||||
|
||||
func OpenAIErrorToClaudeError(openAIError *dto.OpenAIErrorWithStatusCode) *dto.ClaudeErrorWithStatusCode {
|
||||
claudeError := dto.ClaudeError{
|
||||
Type: "new_api_error",
|
||||
Message: openAIError.Error.Message,
|
||||
}
|
||||
return &dto.ClaudeErrorWithStatusCode{
|
||||
Error: claudeError,
|
||||
StatusCode: openAIError.StatusCode,
|
||||
}
|
||||
}
|
||||
|
||||
func ClaudeErrorToOpenAIError(claudeError *dto.ClaudeErrorWithStatusCode) *dto.OpenAIErrorWithStatusCode {
|
||||
openAIError := dto.OpenAIError{
|
||||
Message: claudeError.Error.Message,
|
||||
Type: "new_api_error",
|
||||
}
|
||||
return &dto.OpenAIErrorWithStatusCode{
|
||||
Error: openAIError,
|
||||
StatusCode: claudeError.StatusCode,
|
||||
}
|
||||
}
|
||||
|
||||
func generateStopBlock(index int) *dto.ClaudeResponse {
|
||||
return &dto.ClaudeResponse{
|
||||
Type: "content_block_stop",
|
||||
@@ -251,22 +229,54 @@ func StreamResponseOpenAI2Claude(openAIResponse *dto.ChatCompletionsStreamRespon
|
||||
resp.SetIndex(0)
|
||||
claudeResponses = append(claudeResponses, resp)
|
||||
} else {
|
||||
//resp := &dto.ClaudeResponse{
|
||||
// Type: "content_block_start",
|
||||
// ContentBlock: &dto.ClaudeMediaMessage{
|
||||
// Type: "text",
|
||||
// Text: common.GetPointer[string](""),
|
||||
// },
|
||||
//}
|
||||
//resp.SetIndex(0)
|
||||
//claudeResponses = append(claudeResponses, resp)
|
||||
|
||||
}
|
||||
// 判断首个响应是否存在内容(非标准的 OpenAI 响应)
|
||||
if len(openAIResponse.Choices) > 0 && len(openAIResponse.Choices[0].Delta.GetContentString()) > 0 {
|
||||
claudeResponses = append(claudeResponses, &dto.ClaudeResponse{
|
||||
Index: &info.ClaudeConvertInfo.Index,
|
||||
Type: "content_block_start",
|
||||
ContentBlock: &dto.ClaudeMediaMessage{
|
||||
Type: "text",
|
||||
Text: common.GetPointer[string](""),
|
||||
},
|
||||
})
|
||||
claudeResponses = append(claudeResponses, &dto.ClaudeResponse{
|
||||
Type: "content_block_delta",
|
||||
Delta: &dto.ClaudeMediaMessage{
|
||||
Type: "text",
|
||||
Text: common.GetPointer[string](openAIResponse.Choices[0].Delta.GetContentString()),
|
||||
},
|
||||
})
|
||||
info.ClaudeConvertInfo.LastMessagesType = relaycommon.LastMessageTypeText
|
||||
}
|
||||
return claudeResponses
|
||||
}
|
||||
|
||||
if len(openAIResponse.Choices) == 0 {
|
||||
// no choices
|
||||
// TODO: handle this case
|
||||
// 可能为非标准的 OpenAI 响应,判断是否已经完成
|
||||
if info.Done {
|
||||
claudeResponses = append(claudeResponses, generateStopBlock(info.ClaudeConvertInfo.Index))
|
||||
oaiUsage := info.ClaudeConvertInfo.Usage
|
||||
if oaiUsage != nil {
|
||||
claudeResponses = append(claudeResponses, &dto.ClaudeResponse{
|
||||
Type: "message_delta",
|
||||
Usage: &dto.ClaudeUsage{
|
||||
InputTokens: oaiUsage.PromptTokens,
|
||||
OutputTokens: oaiUsage.CompletionTokens,
|
||||
CacheCreationInputTokens: oaiUsage.PromptTokensDetails.CachedCreationTokens,
|
||||
CacheReadInputTokens: oaiUsage.PromptTokensDetails.CachedTokens,
|
||||
},
|
||||
Delta: &dto.ClaudeMediaMessage{
|
||||
StopReason: common.GetPointer[string](stopReasonOpenAI2Claude(info.FinishReason)),
|
||||
},
|
||||
})
|
||||
}
|
||||
claudeResponses = append(claudeResponses, &dto.ClaudeResponse{
|
||||
Type: "message_stop",
|
||||
})
|
||||
}
|
||||
return claudeResponses
|
||||
} else {
|
||||
chosenChoice := openAIResponse.Choices[0]
|
||||
@@ -438,3 +448,353 @@ func toJSONString(v interface{}) string {
|
||||
}
|
||||
return string(b)
|
||||
}
|
||||
|
||||
func GeminiToOpenAIRequest(geminiRequest *dto.GeminiChatRequest, info *relaycommon.RelayInfo) (*dto.GeneralOpenAIRequest, error) {
|
||||
openaiRequest := &dto.GeneralOpenAIRequest{
|
||||
Model: info.UpstreamModelName,
|
||||
Stream: info.IsStream,
|
||||
}
|
||||
|
||||
// 转换 messages
|
||||
var messages []dto.Message
|
||||
for _, content := range geminiRequest.Contents {
|
||||
message := dto.Message{
|
||||
Role: convertGeminiRoleToOpenAI(content.Role),
|
||||
}
|
||||
|
||||
// 处理 parts
|
||||
var mediaContents []dto.MediaContent
|
||||
var toolCalls []dto.ToolCallRequest
|
||||
for _, part := range content.Parts {
|
||||
if part.Text != "" {
|
||||
mediaContent := dto.MediaContent{
|
||||
Type: "text",
|
||||
Text: part.Text,
|
||||
}
|
||||
mediaContents = append(mediaContents, mediaContent)
|
||||
} else if part.InlineData != nil {
|
||||
mediaContent := dto.MediaContent{
|
||||
Type: "image_url",
|
||||
ImageUrl: &dto.MessageImageUrl{
|
||||
Url: fmt.Sprintf("data:%s;base64,%s", part.InlineData.MimeType, part.InlineData.Data),
|
||||
Detail: "auto",
|
||||
MimeType: part.InlineData.MimeType,
|
||||
},
|
||||
}
|
||||
mediaContents = append(mediaContents, mediaContent)
|
||||
} else if part.FileData != nil {
|
||||
mediaContent := dto.MediaContent{
|
||||
Type: "image_url",
|
||||
ImageUrl: &dto.MessageImageUrl{
|
||||
Url: part.FileData.FileUri,
|
||||
Detail: "auto",
|
||||
MimeType: part.FileData.MimeType,
|
||||
},
|
||||
}
|
||||
mediaContents = append(mediaContents, mediaContent)
|
||||
} else if part.FunctionCall != nil {
|
||||
// 处理 Gemini 的工具调用
|
||||
toolCall := dto.ToolCallRequest{
|
||||
ID: fmt.Sprintf("call_%d", len(toolCalls)+1), // 生成唯一ID
|
||||
Type: "function",
|
||||
Function: dto.FunctionRequest{
|
||||
Name: part.FunctionCall.FunctionName,
|
||||
Arguments: toJSONString(part.FunctionCall.Arguments),
|
||||
},
|
||||
}
|
||||
toolCalls = append(toolCalls, toolCall)
|
||||
} else if part.FunctionResponse != nil {
|
||||
// 处理 Gemini 的工具响应,创建单独的 tool 消息
|
||||
toolMessage := dto.Message{
|
||||
Role: "tool",
|
||||
ToolCallId: fmt.Sprintf("call_%d", len(toolCalls)), // 使用对应的调用ID
|
||||
}
|
||||
toolMessage.SetStringContent(toJSONString(part.FunctionResponse.Response))
|
||||
messages = append(messages, toolMessage)
|
||||
}
|
||||
}
|
||||
|
||||
// 设置消息内容
|
||||
if len(toolCalls) > 0 {
|
||||
// 如果有工具调用,设置工具调用
|
||||
message.SetToolCalls(toolCalls)
|
||||
} else if len(mediaContents) == 1 && mediaContents[0].Type == "text" {
|
||||
// 如果只有一个文本内容,直接设置字符串
|
||||
message.Content = mediaContents[0].Text
|
||||
} else if len(mediaContents) > 0 {
|
||||
// 如果有多个内容或包含媒体,设置为数组
|
||||
message.SetMediaContent(mediaContents)
|
||||
}
|
||||
|
||||
// 只有当消息有内容或工具调用时才添加
|
||||
if len(message.ParseContent()) > 0 || len(message.ToolCalls) > 0 {
|
||||
messages = append(messages, message)
|
||||
}
|
||||
}
|
||||
|
||||
openaiRequest.Messages = messages
|
||||
|
||||
if geminiRequest.GenerationConfig.Temperature != nil {
|
||||
openaiRequest.Temperature = geminiRequest.GenerationConfig.Temperature
|
||||
}
|
||||
if geminiRequest.GenerationConfig.TopP > 0 {
|
||||
openaiRequest.TopP = geminiRequest.GenerationConfig.TopP
|
||||
}
|
||||
if geminiRequest.GenerationConfig.TopK > 0 {
|
||||
openaiRequest.TopK = int(geminiRequest.GenerationConfig.TopK)
|
||||
}
|
||||
if geminiRequest.GenerationConfig.MaxOutputTokens > 0 {
|
||||
openaiRequest.MaxTokens = geminiRequest.GenerationConfig.MaxOutputTokens
|
||||
}
|
||||
// gemini stop sequences 最多 5 个,openai stop 最多 4 个
|
||||
if len(geminiRequest.GenerationConfig.StopSequences) > 0 {
|
||||
openaiRequest.Stop = geminiRequest.GenerationConfig.StopSequences[:4]
|
||||
}
|
||||
if geminiRequest.GenerationConfig.CandidateCount > 0 {
|
||||
openaiRequest.N = geminiRequest.GenerationConfig.CandidateCount
|
||||
}
|
||||
|
||||
// 转换工具调用
|
||||
if len(geminiRequest.Tools) > 0 {
|
||||
var tools []dto.ToolCallRequest
|
||||
for _, tool := range geminiRequest.Tools {
|
||||
if tool.FunctionDeclarations != nil {
|
||||
// 将 Gemini 的 FunctionDeclarations 转换为 OpenAI 的 ToolCallRequest
|
||||
functionDeclarations, ok := tool.FunctionDeclarations.([]dto.FunctionRequest)
|
||||
if ok {
|
||||
for _, function := range functionDeclarations {
|
||||
openAITool := dto.ToolCallRequest{
|
||||
Type: "function",
|
||||
Function: dto.FunctionRequest{
|
||||
Name: function.Name,
|
||||
Description: function.Description,
|
||||
Parameters: function.Parameters,
|
||||
},
|
||||
}
|
||||
tools = append(tools, openAITool)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if len(tools) > 0 {
|
||||
openaiRequest.Tools = tools
|
||||
}
|
||||
}
|
||||
|
||||
// gemini system instructions
|
||||
if geminiRequest.SystemInstructions != nil {
|
||||
// 将系统指令作为第一条消息插入
|
||||
systemMessage := dto.Message{
|
||||
Role: "system",
|
||||
Content: extractTextFromGeminiParts(geminiRequest.SystemInstructions.Parts),
|
||||
}
|
||||
openaiRequest.Messages = append([]dto.Message{systemMessage}, openaiRequest.Messages...)
|
||||
}
|
||||
|
||||
return openaiRequest, nil
|
||||
}
|
||||
|
||||
func convertGeminiRoleToOpenAI(geminiRole string) string {
|
||||
switch geminiRole {
|
||||
case "user":
|
||||
return "user"
|
||||
case "model":
|
||||
return "assistant"
|
||||
case "function":
|
||||
return "function"
|
||||
default:
|
||||
return "user"
|
||||
}
|
||||
}
|
||||
|
||||
func extractTextFromGeminiParts(parts []dto.GeminiPart) string {
|
||||
var texts []string
|
||||
for _, part := range parts {
|
||||
if part.Text != "" {
|
||||
texts = append(texts, part.Text)
|
||||
}
|
||||
}
|
||||
return strings.Join(texts, "\n")
|
||||
}
|
||||
|
||||
// ResponseOpenAI2Gemini 将 OpenAI 响应转换为 Gemini 格式
|
||||
func ResponseOpenAI2Gemini(openAIResponse *dto.OpenAITextResponse, info *relaycommon.RelayInfo) *dto.GeminiChatResponse {
|
||||
geminiResponse := &dto.GeminiChatResponse{
|
||||
Candidates: make([]dto.GeminiChatCandidate, 0, len(openAIResponse.Choices)),
|
||||
PromptFeedback: dto.GeminiChatPromptFeedback{
|
||||
SafetyRatings: []dto.GeminiChatSafetyRating{},
|
||||
},
|
||||
UsageMetadata: dto.GeminiUsageMetadata{
|
||||
PromptTokenCount: openAIResponse.PromptTokens,
|
||||
CandidatesTokenCount: openAIResponse.CompletionTokens,
|
||||
TotalTokenCount: openAIResponse.PromptTokens + openAIResponse.CompletionTokens,
|
||||
},
|
||||
}
|
||||
|
||||
for _, choice := range openAIResponse.Choices {
|
||||
candidate := dto.GeminiChatCandidate{
|
||||
Index: int64(choice.Index),
|
||||
SafetyRatings: []dto.GeminiChatSafetyRating{},
|
||||
}
|
||||
|
||||
// 设置结束原因
|
||||
var finishReason string
|
||||
switch choice.FinishReason {
|
||||
case "stop":
|
||||
finishReason = "STOP"
|
||||
case "length":
|
||||
finishReason = "MAX_TOKENS"
|
||||
case "content_filter":
|
||||
finishReason = "SAFETY"
|
||||
case "tool_calls":
|
||||
finishReason = "STOP"
|
||||
default:
|
||||
finishReason = "STOP"
|
||||
}
|
||||
candidate.FinishReason = &finishReason
|
||||
|
||||
// 转换消息内容
|
||||
content := dto.GeminiChatContent{
|
||||
Role: "model",
|
||||
Parts: make([]dto.GeminiPart, 0),
|
||||
}
|
||||
|
||||
// 处理工具调用
|
||||
toolCalls := choice.Message.ParseToolCalls()
|
||||
if len(toolCalls) > 0 {
|
||||
for _, toolCall := range toolCalls {
|
||||
// 解析参数
|
||||
var args map[string]interface{}
|
||||
if toolCall.Function.Arguments != "" {
|
||||
if err := json.Unmarshal([]byte(toolCall.Function.Arguments), &args); err != nil {
|
||||
args = map[string]interface{}{"arguments": toolCall.Function.Arguments}
|
||||
}
|
||||
} else {
|
||||
args = make(map[string]interface{})
|
||||
}
|
||||
|
||||
part := dto.GeminiPart{
|
||||
FunctionCall: &dto.FunctionCall{
|
||||
FunctionName: toolCall.Function.Name,
|
||||
Arguments: args,
|
||||
},
|
||||
}
|
||||
content.Parts = append(content.Parts, part)
|
||||
}
|
||||
} else {
|
||||
// 处理文本内容
|
||||
textContent := choice.Message.StringContent()
|
||||
if textContent != "" {
|
||||
part := dto.GeminiPart{
|
||||
Text: textContent,
|
||||
}
|
||||
content.Parts = append(content.Parts, part)
|
||||
}
|
||||
}
|
||||
|
||||
candidate.Content = content
|
||||
geminiResponse.Candidates = append(geminiResponse.Candidates, candidate)
|
||||
}
|
||||
|
||||
return geminiResponse
|
||||
}
|
||||
|
||||
// StreamResponseOpenAI2Gemini 将 OpenAI 流式响应转换为 Gemini 格式
|
||||
func StreamResponseOpenAI2Gemini(openAIResponse *dto.ChatCompletionsStreamResponse, info *relaycommon.RelayInfo) *dto.GeminiChatResponse {
|
||||
// 检查是否有实际内容或结束标志
|
||||
hasContent := false
|
||||
hasFinishReason := false
|
||||
for _, choice := range openAIResponse.Choices {
|
||||
if len(choice.Delta.GetContentString()) > 0 || (choice.Delta.ToolCalls != nil && len(choice.Delta.ToolCalls) > 0) {
|
||||
hasContent = true
|
||||
}
|
||||
if choice.FinishReason != nil {
|
||||
hasFinishReason = true
|
||||
}
|
||||
}
|
||||
|
||||
// 如果没有实际内容且没有结束标志,跳过。主要针对 openai 流响应开头的空数据
|
||||
if !hasContent && !hasFinishReason {
|
||||
return nil
|
||||
}
|
||||
|
||||
geminiResponse := &dto.GeminiChatResponse{
|
||||
Candidates: make([]dto.GeminiChatCandidate, 0, len(openAIResponse.Choices)),
|
||||
PromptFeedback: dto.GeminiChatPromptFeedback{
|
||||
SafetyRatings: []dto.GeminiChatSafetyRating{},
|
||||
},
|
||||
UsageMetadata: dto.GeminiUsageMetadata{
|
||||
PromptTokenCount: info.PromptTokens,
|
||||
CandidatesTokenCount: 0, // 流式响应中可能没有完整的 usage 信息
|
||||
TotalTokenCount: info.PromptTokens,
|
||||
},
|
||||
}
|
||||
|
||||
for _, choice := range openAIResponse.Choices {
|
||||
candidate := dto.GeminiChatCandidate{
|
||||
Index: int64(choice.Index),
|
||||
SafetyRatings: []dto.GeminiChatSafetyRating{},
|
||||
}
|
||||
|
||||
// 设置结束原因
|
||||
if choice.FinishReason != nil {
|
||||
var finishReason string
|
||||
switch *choice.FinishReason {
|
||||
case "stop":
|
||||
finishReason = "STOP"
|
||||
case "length":
|
||||
finishReason = "MAX_TOKENS"
|
||||
case "content_filter":
|
||||
finishReason = "SAFETY"
|
||||
case "tool_calls":
|
||||
finishReason = "STOP"
|
||||
default:
|
||||
finishReason = "STOP"
|
||||
}
|
||||
candidate.FinishReason = &finishReason
|
||||
}
|
||||
|
||||
// 转换消息内容
|
||||
content := dto.GeminiChatContent{
|
||||
Role: "model",
|
||||
Parts: make([]dto.GeminiPart, 0),
|
||||
}
|
||||
|
||||
// 处理工具调用
|
||||
if choice.Delta.ToolCalls != nil {
|
||||
for _, toolCall := range choice.Delta.ToolCalls {
|
||||
// 解析参数
|
||||
var args map[string]interface{}
|
||||
if toolCall.Function.Arguments != "" {
|
||||
if err := json.Unmarshal([]byte(toolCall.Function.Arguments), &args); err != nil {
|
||||
args = map[string]interface{}{"arguments": toolCall.Function.Arguments}
|
||||
}
|
||||
} else {
|
||||
args = make(map[string]interface{})
|
||||
}
|
||||
|
||||
part := dto.GeminiPart{
|
||||
FunctionCall: &dto.FunctionCall{
|
||||
FunctionName: toolCall.Function.Name,
|
||||
Arguments: args,
|
||||
},
|
||||
}
|
||||
content.Parts = append(content.Parts, part)
|
||||
}
|
||||
} else {
|
||||
// 处理文本内容
|
||||
textContent := choice.Delta.GetContentString()
|
||||
if textContent != "" {
|
||||
part := dto.GeminiPart{
|
||||
Text: textContent,
|
||||
}
|
||||
content.Parts = append(content.Parts, part)
|
||||
}
|
||||
}
|
||||
|
||||
candidate.Content = content
|
||||
geminiResponse.Candidates = append(geminiResponse.Candidates, candidate)
|
||||
}
|
||||
|
||||
return geminiResponse
|
||||
}
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
@@ -63,7 +62,7 @@ func ClaudeErrorWrapper(err error, code string, statusCode int) *dto.ClaudeError
|
||||
text = "请求上游地址失败"
|
||||
}
|
||||
}
|
||||
claudeError := dto.ClaudeError{
|
||||
claudeError := types.ClaudeError{
|
||||
Message: text,
|
||||
Type: "new_api_error",
|
||||
}
|
||||
@@ -80,10 +79,7 @@ func ClaudeErrorWrapperLocal(err error, code string, statusCode int) *dto.Claude
|
||||
}
|
||||
|
||||
func RelayErrorHandler(resp *http.Response, showBodyWhenFail bool) (newApiErr *types.NewAPIError) {
|
||||
newApiErr = &types.NewAPIError{
|
||||
StatusCode: resp.StatusCode,
|
||||
ErrorType: types.ErrorTypeOpenAIError,
|
||||
}
|
||||
newApiErr = types.InitOpenAIError(types.ErrorCodeBadResponseStatusCode, resp.StatusCode)
|
||||
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
@@ -105,8 +101,7 @@ func RelayErrorHandler(resp *http.Response, showBodyWhenFail bool) (newApiErr *t
|
||||
// General format error (OpenAI, Anthropic, Gemini, etc.)
|
||||
newApiErr = types.WithOpenAIError(errResponse.Error, resp.StatusCode)
|
||||
} else {
|
||||
newApiErr = types.NewErrorWithStatusCode(errors.New(errResponse.ToMessage()), types.ErrorCodeBadResponseStatusCode, resp.StatusCode)
|
||||
newApiErr.ErrorType = types.ErrorTypeOpenAIError
|
||||
newApiErr = types.NewOpenAIError(errors.New(errResponse.ToMessage()), types.ErrorCodeBadResponseStatusCode, resp.StatusCode)
|
||||
}
|
||||
return
|
||||
}
|
||||
@@ -116,7 +111,7 @@ func ResetStatusCode(newApiErr *types.NewAPIError, statusCodeMappingStr string)
|
||||
return
|
||||
}
|
||||
statusCodeMapping := make(map[string]string)
|
||||
err := json.Unmarshal([]byte(statusCodeMappingStr), &statusCodeMapping)
|
||||
err := common.Unmarshal([]byte(statusCodeMappingStr), &statusCodeMapping)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
@@ -3,6 +3,7 @@ package setting
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"math"
|
||||
"one-api/common"
|
||||
"sync"
|
||||
)
|
||||
@@ -58,6 +59,9 @@ func CheckModelRequestRateLimitGroup(jsonStr string) error {
|
||||
if limits[0] < 0 || limits[1] < 1 {
|
||||
return fmt.Errorf("group %s has negative rate limit values: [%d, %d]", group, limits[0], limits[1])
|
||||
}
|
||||
if limits[0] > math.MaxInt32 || limits[1] > math.MaxInt32 {
|
||||
return fmt.Errorf("group %s [%d, %d] has max rate limits value 2147483647", group, limits[0], limits[1])
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
|
||||
147
types/error.go
147
types/error.go
@@ -4,6 +4,7 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"strings"
|
||||
)
|
||||
|
||||
@@ -15,8 +16,8 @@ type OpenAIError struct {
|
||||
}
|
||||
|
||||
type ClaudeError struct {
|
||||
Message string `json:"message,omitempty"`
|
||||
Type string `json:"type,omitempty"`
|
||||
Message string `json:"message,omitempty"`
|
||||
}
|
||||
|
||||
type ErrorType string
|
||||
@@ -28,6 +29,7 @@ const (
|
||||
ErrorTypeMidjourneyError ErrorType = "midjourney_error"
|
||||
ErrorTypeGeminiError ErrorType = "gemini_error"
|
||||
ErrorTypeRerankError ErrorType = "rerank_error"
|
||||
ErrorTypeUpstreamError ErrorType = "upstream_error"
|
||||
)
|
||||
|
||||
type ErrorCode string
|
||||
@@ -62,6 +64,7 @@ const (
|
||||
ErrorCodeBadResponseStatusCode ErrorCode = "bad_response_status_code"
|
||||
ErrorCodeBadResponse ErrorCode = "bad_response"
|
||||
ErrorCodeBadResponseBody ErrorCode = "bad_response_body"
|
||||
ErrorCodeEmptyResponse ErrorCode = "empty_response"
|
||||
|
||||
// sql error
|
||||
ErrorCodeQueryDataError ErrorCode = "query_data_error"
|
||||
@@ -73,11 +76,13 @@ const (
|
||||
)
|
||||
|
||||
type NewAPIError struct {
|
||||
Err error
|
||||
RelayError any
|
||||
ErrorType ErrorType
|
||||
errorCode ErrorCode
|
||||
StatusCode int
|
||||
Err error
|
||||
RelayError any
|
||||
skipRetry bool
|
||||
recordErrorLog *bool
|
||||
errorType ErrorType
|
||||
errorCode ErrorCode
|
||||
StatusCode int
|
||||
}
|
||||
|
||||
func (e *NewAPIError) GetErrorCode() ErrorCode {
|
||||
@@ -87,6 +92,13 @@ func (e *NewAPIError) GetErrorCode() ErrorCode {
|
||||
return e.errorCode
|
||||
}
|
||||
|
||||
func (e *NewAPIError) GetErrorType() ErrorType {
|
||||
if e == nil {
|
||||
return ""
|
||||
}
|
||||
return e.errorType
|
||||
}
|
||||
|
||||
func (e *NewAPIError) Error() string {
|
||||
if e == nil {
|
||||
return ""
|
||||
@@ -98,19 +110,30 @@ func (e *NewAPIError) Error() string {
|
||||
return e.Err.Error()
|
||||
}
|
||||
|
||||
func (e *NewAPIError) MaskSensitiveError() string {
|
||||
if e == nil {
|
||||
return ""
|
||||
}
|
||||
if e.Err == nil {
|
||||
return string(e.errorCode)
|
||||
}
|
||||
return common.MaskSensitiveInfo(e.Err.Error())
|
||||
}
|
||||
|
||||
func (e *NewAPIError) SetMessage(message string) {
|
||||
e.Err = errors.New(message)
|
||||
}
|
||||
|
||||
func (e *NewAPIError) ToOpenAIError() OpenAIError {
|
||||
switch e.ErrorType {
|
||||
var result OpenAIError
|
||||
switch e.errorType {
|
||||
case ErrorTypeOpenAIError:
|
||||
if openAIError, ok := e.RelayError.(OpenAIError); ok {
|
||||
return openAIError
|
||||
result = openAIError
|
||||
}
|
||||
case ErrorTypeClaudeError:
|
||||
if claudeError, ok := e.RelayError.(ClaudeError); ok {
|
||||
return OpenAIError{
|
||||
result = OpenAIError{
|
||||
Message: e.Error(),
|
||||
Type: claudeError.Type,
|
||||
Param: "",
|
||||
@@ -118,85 +141,122 @@ func (e *NewAPIError) ToOpenAIError() OpenAIError {
|
||||
}
|
||||
}
|
||||
}
|
||||
return OpenAIError{
|
||||
result = OpenAIError{
|
||||
Message: e.Error(),
|
||||
Type: string(e.ErrorType),
|
||||
Type: string(e.errorType),
|
||||
Param: "",
|
||||
Code: e.errorCode,
|
||||
}
|
||||
result.Message = common.MaskSensitiveInfo(result.Message)
|
||||
return result
|
||||
}
|
||||
|
||||
func (e *NewAPIError) ToClaudeError() ClaudeError {
|
||||
switch e.ErrorType {
|
||||
var result ClaudeError
|
||||
switch e.errorType {
|
||||
case ErrorTypeOpenAIError:
|
||||
openAIError := e.RelayError.(OpenAIError)
|
||||
return ClaudeError{
|
||||
result = ClaudeError{
|
||||
Message: e.Error(),
|
||||
Type: fmt.Sprintf("%v", openAIError.Code),
|
||||
}
|
||||
case ErrorTypeClaudeError:
|
||||
return e.RelayError.(ClaudeError)
|
||||
result = e.RelayError.(ClaudeError)
|
||||
default:
|
||||
return ClaudeError{
|
||||
result = ClaudeError{
|
||||
Message: e.Error(),
|
||||
Type: string(e.ErrorType),
|
||||
Type: string(e.errorType),
|
||||
}
|
||||
}
|
||||
result.Message = common.MaskSensitiveInfo(result.Message)
|
||||
return result
|
||||
}
|
||||
|
||||
func NewError(err error, errorCode ErrorCode) *NewAPIError {
|
||||
return &NewAPIError{
|
||||
type NewAPIErrorOptions func(*NewAPIError)
|
||||
|
||||
func NewError(err error, errorCode ErrorCode, ops ...NewAPIErrorOptions) *NewAPIError {
|
||||
e := &NewAPIError{
|
||||
Err: err,
|
||||
RelayError: nil,
|
||||
ErrorType: ErrorTypeNewAPIError,
|
||||
errorType: ErrorTypeNewAPIError,
|
||||
StatusCode: http.StatusInternalServerError,
|
||||
errorCode: errorCode,
|
||||
}
|
||||
for _, op := range ops {
|
||||
op(e)
|
||||
}
|
||||
return e
|
||||
}
|
||||
|
||||
func NewOpenAIError(err error, errorCode ErrorCode, statusCode int) *NewAPIError {
|
||||
func NewOpenAIError(err error, errorCode ErrorCode, statusCode int, ops ...NewAPIErrorOptions) *NewAPIError {
|
||||
openaiError := OpenAIError{
|
||||
Message: err.Error(),
|
||||
Type: string(errorCode),
|
||||
}
|
||||
return WithOpenAIError(openaiError, statusCode)
|
||||
return WithOpenAIError(openaiError, statusCode, ops...)
|
||||
}
|
||||
|
||||
func NewErrorWithStatusCode(err error, errorCode ErrorCode, statusCode int) *NewAPIError {
|
||||
return &NewAPIError{
|
||||
func InitOpenAIError(errorCode ErrorCode, statusCode int, ops ...NewAPIErrorOptions) *NewAPIError {
|
||||
openaiError := OpenAIError{
|
||||
Type: string(errorCode),
|
||||
}
|
||||
return WithOpenAIError(openaiError, statusCode, ops...)
|
||||
}
|
||||
|
||||
func NewErrorWithStatusCode(err error, errorCode ErrorCode, statusCode int, ops ...NewAPIErrorOptions) *NewAPIError {
|
||||
e := &NewAPIError{
|
||||
Err: err,
|
||||
RelayError: OpenAIError{
|
||||
Message: err.Error(),
|
||||
Type: string(errorCode),
|
||||
},
|
||||
ErrorType: ErrorTypeNewAPIError,
|
||||
errorType: ErrorTypeNewAPIError,
|
||||
StatusCode: statusCode,
|
||||
errorCode: errorCode,
|
||||
}
|
||||
for _, op := range ops {
|
||||
op(e)
|
||||
}
|
||||
|
||||
return e
|
||||
}
|
||||
|
||||
func WithOpenAIError(openAIError OpenAIError, statusCode int) *NewAPIError {
|
||||
func WithOpenAIError(openAIError OpenAIError, statusCode int, ops ...NewAPIErrorOptions) *NewAPIError {
|
||||
code, ok := openAIError.Code.(string)
|
||||
if !ok {
|
||||
code = fmt.Sprintf("%v", openAIError.Code)
|
||||
}
|
||||
return &NewAPIError{
|
||||
if openAIError.Type == "" {
|
||||
openAIError.Type = "upstream_error"
|
||||
}
|
||||
e := &NewAPIError{
|
||||
RelayError: openAIError,
|
||||
ErrorType: ErrorTypeOpenAIError,
|
||||
errorType: ErrorTypeOpenAIError,
|
||||
StatusCode: statusCode,
|
||||
Err: errors.New(openAIError.Message),
|
||||
errorCode: ErrorCode(code),
|
||||
}
|
||||
for _, op := range ops {
|
||||
op(e)
|
||||
}
|
||||
return e
|
||||
}
|
||||
|
||||
func WithClaudeError(claudeError ClaudeError, statusCode int) *NewAPIError {
|
||||
return &NewAPIError{
|
||||
func WithClaudeError(claudeError ClaudeError, statusCode int, ops ...NewAPIErrorOptions) *NewAPIError {
|
||||
if claudeError.Type == "" {
|
||||
claudeError.Type = "upstream_error"
|
||||
}
|
||||
e := &NewAPIError{
|
||||
RelayError: claudeError,
|
||||
ErrorType: ErrorTypeClaudeError,
|
||||
errorType: ErrorTypeClaudeError,
|
||||
StatusCode: statusCode,
|
||||
Err: errors.New(claudeError.Message),
|
||||
errorCode: ErrorCode(claudeError.Type),
|
||||
}
|
||||
for _, op := range ops {
|
||||
op(e)
|
||||
}
|
||||
return e
|
||||
}
|
||||
|
||||
func IsChannelError(err *NewAPIError) bool {
|
||||
@@ -206,10 +266,33 @@ func IsChannelError(err *NewAPIError) bool {
|
||||
return strings.HasPrefix(string(err.errorCode), "channel:")
|
||||
}
|
||||
|
||||
func IsLocalError(err *NewAPIError) bool {
|
||||
func IsSkipRetryError(err *NewAPIError) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
return err.ErrorType == ErrorTypeNewAPIError
|
||||
return err.skipRetry
|
||||
}
|
||||
|
||||
func ErrOptionWithSkipRetry() NewAPIErrorOptions {
|
||||
return func(e *NewAPIError) {
|
||||
e.skipRetry = true
|
||||
}
|
||||
}
|
||||
|
||||
func ErrOptionWithNoRecordErrorLog() NewAPIErrorOptions {
|
||||
return func(e *NewAPIError) {
|
||||
e.recordErrorLog = common.GetPointer(false)
|
||||
}
|
||||
}
|
||||
|
||||
func IsRecordErrorLog(e *NewAPIError) bool {
|
||||
if e == nil {
|
||||
return false
|
||||
}
|
||||
if e.recordErrorLog == nil {
|
||||
// default to true if not set
|
||||
return true
|
||||
}
|
||||
return *e.recordErrorLog
|
||||
}
|
||||
|
||||
622
web/src/components/common/JSONEditor.js
Normal file
622
web/src/components/common/JSONEditor.js
Normal file
@@ -0,0 +1,622 @@
|
||||
import React, { useState, useEffect, useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import {
|
||||
Space,
|
||||
Button,
|
||||
Form,
|
||||
Card,
|
||||
Typography,
|
||||
Banner,
|
||||
Row,
|
||||
Col,
|
||||
InputNumber,
|
||||
Switch,
|
||||
Select,
|
||||
Input,
|
||||
} from '@douyinfe/semi-ui';
|
||||
import {
|
||||
IconCode,
|
||||
IconEdit,
|
||||
IconPlus,
|
||||
IconDelete,
|
||||
IconSetting,
|
||||
} from '@douyinfe/semi-icons';
|
||||
|
||||
const { Text } = Typography;
|
||||
|
||||
const JSONEditor = ({
|
||||
value = '',
|
||||
onChange,
|
||||
field,
|
||||
label,
|
||||
placeholder,
|
||||
extraText,
|
||||
showClear = true,
|
||||
template,
|
||||
templateLabel,
|
||||
editorType = 'keyValue', // keyValue, object, region
|
||||
autosize = true,
|
||||
rules = [],
|
||||
formApi = null,
|
||||
...props
|
||||
}) => {
|
||||
const { t } = useTranslation();
|
||||
|
||||
// 初始化JSON数据
|
||||
const [jsonData, setJsonData] = useState(() => {
|
||||
// 初始化时解析JSON数据
|
||||
if (value && value.trim()) {
|
||||
try {
|
||||
const parsed = JSON.parse(value);
|
||||
return parsed;
|
||||
} catch (error) {
|
||||
return {};
|
||||
}
|
||||
}
|
||||
return {};
|
||||
});
|
||||
|
||||
// 根据键数量决定默认编辑模式
|
||||
const [editMode, setEditMode] = useState(() => {
|
||||
// 如果初始JSON数据的键数量大于10个,则默认使用手动模式
|
||||
if (value && value.trim()) {
|
||||
try {
|
||||
const parsed = JSON.parse(value);
|
||||
const keyCount = Object.keys(parsed).length;
|
||||
return keyCount > 10 ? 'manual' : 'visual';
|
||||
} catch (error) {
|
||||
// JSON无效时默认显示手动编辑模式
|
||||
return 'manual';
|
||||
}
|
||||
}
|
||||
return 'visual';
|
||||
});
|
||||
const [jsonError, setJsonError] = useState('');
|
||||
|
||||
// 数据同步 - 当value变化时总是更新jsonData(如果JSON有效)
|
||||
useEffect(() => {
|
||||
try {
|
||||
const parsed = value && value.trim() ? JSON.parse(value) : {};
|
||||
setJsonData(parsed);
|
||||
setJsonError('');
|
||||
} catch (error) {
|
||||
console.log('JSON解析失败:', error.message);
|
||||
setJsonError(error.message);
|
||||
// JSON格式错误时不更新jsonData
|
||||
}
|
||||
}, [value]);
|
||||
|
||||
|
||||
// 处理可视化编辑的数据变化
|
||||
const handleVisualChange = useCallback((newData) => {
|
||||
setJsonData(newData);
|
||||
setJsonError('');
|
||||
const jsonString = Object.keys(newData).length === 0 ? '' : JSON.stringify(newData, null, 2);
|
||||
|
||||
// 通过formApi设置值(如果提供的话)
|
||||
if (formApi && field) {
|
||||
formApi.setValue(field, jsonString);
|
||||
}
|
||||
|
||||
onChange?.(jsonString);
|
||||
}, [onChange, formApi, field]);
|
||||
|
||||
// 处理手动编辑的数据变化
|
||||
const handleManualChange = useCallback((newValue) => {
|
||||
onChange?.(newValue);
|
||||
// 验证JSON格式
|
||||
if (newValue && newValue.trim()) {
|
||||
try {
|
||||
const parsed = JSON.parse(newValue);
|
||||
setJsonError('');
|
||||
// 预先准备可视化数据,但不立即应用
|
||||
// 这样切换到可视化模式时数据已经准备好了
|
||||
} catch (error) {
|
||||
setJsonError(error.message);
|
||||
}
|
||||
} else {
|
||||
setJsonError('');
|
||||
}
|
||||
}, [onChange]);
|
||||
|
||||
// 切换编辑模式
|
||||
const toggleEditMode = useCallback(() => {
|
||||
if (editMode === 'visual') {
|
||||
// 从可视化模式切换到手动模式
|
||||
setEditMode('manual');
|
||||
} else {
|
||||
// 从手动模式切换到可视化模式,需要验证JSON
|
||||
try {
|
||||
const parsed = value && value.trim() ? JSON.parse(value) : {};
|
||||
setJsonData(parsed);
|
||||
setJsonError('');
|
||||
setEditMode('visual');
|
||||
} catch (error) {
|
||||
setJsonError(error.message);
|
||||
// JSON格式错误时不切换模式
|
||||
return;
|
||||
}
|
||||
}
|
||||
}, [editMode, value]);
|
||||
|
||||
// 添加键值对
|
||||
const addKeyValue = useCallback(() => {
|
||||
const newData = { ...jsonData };
|
||||
const keys = Object.keys(newData);
|
||||
let newKey = 'key';
|
||||
let counter = 1;
|
||||
while (newData.hasOwnProperty(newKey)) {
|
||||
newKey = `key${counter}`;
|
||||
counter++;
|
||||
}
|
||||
newData[newKey] = '';
|
||||
handleVisualChange(newData);
|
||||
}, [jsonData, handleVisualChange]);
|
||||
|
||||
// 删除键值对
|
||||
const removeKeyValue = useCallback((keyToRemove) => {
|
||||
const newData = { ...jsonData };
|
||||
delete newData[keyToRemove];
|
||||
handleVisualChange(newData);
|
||||
}, [jsonData, handleVisualChange]);
|
||||
|
||||
// 更新键名
|
||||
const updateKey = useCallback((oldKey, newKey) => {
|
||||
if (oldKey === newKey) return;
|
||||
const newData = { ...jsonData };
|
||||
const value = newData[oldKey];
|
||||
delete newData[oldKey];
|
||||
newData[newKey] = value;
|
||||
handleVisualChange(newData);
|
||||
}, [jsonData, handleVisualChange]);
|
||||
|
||||
// 更新值
|
||||
const updateValue = useCallback((key, newValue) => {
|
||||
const newData = { ...jsonData };
|
||||
newData[key] = newValue;
|
||||
handleVisualChange(newData);
|
||||
}, [jsonData, handleVisualChange]);
|
||||
|
||||
// 填入模板
|
||||
const fillTemplate = useCallback(() => {
|
||||
if (template) {
|
||||
const templateString = JSON.stringify(template, null, 2);
|
||||
|
||||
// 通过formApi设置值(如果提供的话)
|
||||
if (formApi && field) {
|
||||
formApi.setValue(field, templateString);
|
||||
}
|
||||
|
||||
// 无论哪种模式都要更新值
|
||||
onChange?.(templateString);
|
||||
|
||||
// 如果是可视化模式,同时更新jsonData
|
||||
if (editMode === 'visual') {
|
||||
setJsonData(template);
|
||||
}
|
||||
|
||||
// 清除错误状态
|
||||
setJsonError('');
|
||||
}
|
||||
}, [template, onChange, editMode, formApi, field]);
|
||||
|
||||
// 渲染键值对编辑器
|
||||
const renderKeyValueEditor = () => {
|
||||
if (typeof jsonData !== 'object' || jsonData === null) {
|
||||
return (
|
||||
<div className="text-center py-6 px-4">
|
||||
<div className="text-gray-400 mb-2">
|
||||
<IconCode size={32} />
|
||||
</div>
|
||||
<Text type="tertiary" className="text-gray-500 text-sm">
|
||||
{t('无效的JSON数据,请检查格式')}
|
||||
</Text>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
const entries = Object.entries(jsonData);
|
||||
|
||||
return (
|
||||
<div className="space-y-1">
|
||||
{entries.length === 0 && (
|
||||
<div className="text-center py-6 px-4">
|
||||
<div className="text-gray-400 mb-2">
|
||||
<IconCode size={32} />
|
||||
</div>
|
||||
<Text type="tertiary" className="text-gray-500 text-sm">
|
||||
{t('暂无数据,点击下方按钮添加键值对')}
|
||||
</Text>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{entries.map(([key, value], index) => (
|
||||
<Card key={index} className="!p-3 !border-gray-200 !rounded-md hover:shadow-sm transition-shadow duration-200">
|
||||
<Row gutter={12} align="middle">
|
||||
<Col span={10}>
|
||||
<div className="space-y-1">
|
||||
<Text type="tertiary" size="small">{t('键名')}</Text>
|
||||
<Input
|
||||
placeholder={t('键名')}
|
||||
value={key}
|
||||
onChange={(newKey) => updateKey(key, newKey)}
|
||||
size="small"
|
||||
/>
|
||||
</div>
|
||||
</Col>
|
||||
<Col span={11}>
|
||||
<div className="space-y-1">
|
||||
<Text type="tertiary" size="small">{t('值')}</Text>
|
||||
<Input
|
||||
placeholder={t('值')}
|
||||
value={value}
|
||||
onChange={(newValue) => updateValue(key, newValue)}
|
||||
size="small"
|
||||
/>
|
||||
</div>
|
||||
</Col>
|
||||
<Col span={3}>
|
||||
<div className="flex justify-center pt-4">
|
||||
<Button
|
||||
icon={<IconDelete />}
|
||||
type="danger"
|
||||
theme="borderless"
|
||||
size="small"
|
||||
onClick={() => removeKeyValue(key)}
|
||||
className="hover:bg-red-50"
|
||||
/>
|
||||
</div>
|
||||
</Col>
|
||||
</Row>
|
||||
</Card>
|
||||
))}
|
||||
|
||||
<div className="flex justify-center pt-1">
|
||||
<Button
|
||||
icon={<IconPlus />}
|
||||
onClick={addKeyValue}
|
||||
size="small"
|
||||
theme="solid"
|
||||
type="primary"
|
||||
className="shadow-sm hover:shadow-md transition-shadow px-4"
|
||||
>
|
||||
{t('添加键值对')}
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
||||
// 渲染对象编辑器(用于复杂JSON)
|
||||
const renderObjectEditor = () => {
|
||||
const entries = Object.entries(jsonData);
|
||||
|
||||
return (
|
||||
<div className="space-y-1">
|
||||
{entries.length === 0 && (
|
||||
<div className="text-center py-6 px-4">
|
||||
<div className="text-gray-400 mb-2">
|
||||
<IconSetting size={32} />
|
||||
</div>
|
||||
<Text type="tertiary" className="text-gray-500 text-sm">
|
||||
{t('暂无参数,点击下方按钮添加请求参数')}
|
||||
</Text>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{entries.map(([key, value], index) => (
|
||||
<Card key={index} className="!p-3 !border-gray-200 !rounded-md hover:shadow-sm transition-shadow duration-200">
|
||||
<Row gutter={12} align="middle">
|
||||
<Col span={8}>
|
||||
<div className="space-y-1">
|
||||
<Text type="tertiary" size="small">{t('参数名')}</Text>
|
||||
<Input
|
||||
placeholder={t('参数名')}
|
||||
value={key}
|
||||
onChange={(newKey) => updateKey(key, newKey)}
|
||||
size="small"
|
||||
/>
|
||||
</div>
|
||||
</Col>
|
||||
<Col span={13}>
|
||||
<div className="space-y-1">
|
||||
<Text type="tertiary" size="small">{t('参数值')} ({typeof value})</Text>
|
||||
{renderValueInput(key, value)}
|
||||
</div>
|
||||
</Col>
|
||||
<Col span={3}>
|
||||
<div className="flex justify-center pt-4">
|
||||
<Button
|
||||
icon={<IconDelete />}
|
||||
type="danger"
|
||||
theme="borderless"
|
||||
size="small"
|
||||
onClick={() => removeKeyValue(key)}
|
||||
className="hover:bg-red-50"
|
||||
/>
|
||||
</div>
|
||||
</Col>
|
||||
</Row>
|
||||
</Card>
|
||||
))}
|
||||
|
||||
<div className="flex justify-center pt-1">
|
||||
<Button
|
||||
icon={<IconPlus />}
|
||||
onClick={addKeyValue}
|
||||
size="small"
|
||||
theme="solid"
|
||||
type="primary"
|
||||
className="shadow-sm hover:shadow-md transition-shadow px-4"
|
||||
>
|
||||
{t('添加参数')}
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
||||
// 渲染参数值输入控件
|
||||
const renderValueInput = (key, value) => {
|
||||
const valueType = typeof value;
|
||||
|
||||
if (valueType === 'boolean') {
|
||||
return (
|
||||
<div className="flex items-center">
|
||||
<Switch
|
||||
checked={value}
|
||||
onChange={(newValue) => updateValue(key, newValue)}
|
||||
size="small"
|
||||
/>
|
||||
<Text type="tertiary" size="small" className="ml-2">
|
||||
{value ? t('true') : t('false')}
|
||||
</Text>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
if (valueType === 'number') {
|
||||
return (
|
||||
<InputNumber
|
||||
value={value}
|
||||
onChange={(newValue) => updateValue(key, newValue)}
|
||||
size="small"
|
||||
style={{ width: '100%' }}
|
||||
step={key === 'temperature' ? 0.1 : 1}
|
||||
precision={key === 'temperature' ? 2 : 0}
|
||||
placeholder={t('输入数字')}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
// 字符串类型或其他类型
|
||||
return (
|
||||
<Input
|
||||
placeholder={t('参数值')}
|
||||
value={String(value)}
|
||||
onChange={(newValue) => {
|
||||
// 尝试转换为适当的类型
|
||||
let convertedValue = newValue;
|
||||
if (newValue === 'true') convertedValue = true;
|
||||
else if (newValue === 'false') convertedValue = false;
|
||||
else if (!isNaN(newValue) && newValue !== '' && newValue !== '0') {
|
||||
convertedValue = Number(newValue);
|
||||
}
|
||||
|
||||
updateValue(key, convertedValue);
|
||||
}}
|
||||
size="small"
|
||||
/>
|
||||
);
|
||||
};
|
||||
|
||||
// 渲染区域编辑器(特殊格式)
|
||||
const renderRegionEditor = () => {
|
||||
const entries = Object.entries(jsonData);
|
||||
const defaultEntry = entries.find(([key]) => key === 'default');
|
||||
const modelEntries = entries.filter(([key]) => key !== 'default');
|
||||
|
||||
return (
|
||||
<div className="space-y-1">
|
||||
{/* 默认区域 */}
|
||||
<Card className="!p-2 !border-blue-200 !bg-blue-50">
|
||||
<div className="flex items-center mb-1">
|
||||
<Text strong size="small" className="text-blue-700">{t('默认区域')}</Text>
|
||||
</div>
|
||||
<Input
|
||||
placeholder={t('默认区域,如: us-central1')}
|
||||
value={defaultEntry ? defaultEntry[1] : ''}
|
||||
onChange={(value) => updateValue('default', value)}
|
||||
size="small"
|
||||
/>
|
||||
</Card>
|
||||
|
||||
{/* 模型专用区域 */}
|
||||
<div className="space-y-1">
|
||||
<Text strong size="small">{t('模型专用区域')}</Text>
|
||||
{modelEntries.map(([modelName, region], index) => (
|
||||
<Card key={index} className="!p-3 !border-gray-200 !rounded-md hover:shadow-sm transition-shadow duration-200">
|
||||
<Row gutter={12} align="middle">
|
||||
<Col span={10}>
|
||||
<div className="space-y-1">
|
||||
<Text type="tertiary" size="small">{t('模型名称')}</Text>
|
||||
<Input
|
||||
placeholder={t('模型名称')}
|
||||
value={modelName}
|
||||
onChange={(newKey) => updateKey(modelName, newKey)}
|
||||
size="small"
|
||||
/>
|
||||
</div>
|
||||
</Col>
|
||||
<Col span={11}>
|
||||
<div className="space-y-1">
|
||||
<Text type="tertiary" size="small">{t('区域')}</Text>
|
||||
<Input
|
||||
placeholder={t('区域')}
|
||||
value={region}
|
||||
onChange={(newValue) => updateValue(modelName, newValue)}
|
||||
size="small"
|
||||
/>
|
||||
</div>
|
||||
</Col>
|
||||
<Col span={3}>
|
||||
<div className="flex justify-center pt-4">
|
||||
<Button
|
||||
icon={<IconDelete />}
|
||||
type="danger"
|
||||
theme="borderless"
|
||||
size="small"
|
||||
onClick={() => removeKeyValue(modelName)}
|
||||
className="hover:bg-red-50"
|
||||
/>
|
||||
</div>
|
||||
</Col>
|
||||
</Row>
|
||||
</Card>
|
||||
))}
|
||||
|
||||
<div className="flex justify-center pt-1">
|
||||
<Button
|
||||
icon={<IconPlus />}
|
||||
onClick={addKeyValue}
|
||||
size="small"
|
||||
theme="solid"
|
||||
type="primary"
|
||||
className="shadow-sm hover:shadow-md transition-shadow px-4"
|
||||
>
|
||||
{t('添加模型区域')}
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
||||
// 渲染可视化编辑器
|
||||
const renderVisualEditor = () => {
|
||||
switch (editorType) {
|
||||
case 'region':
|
||||
return renderRegionEditor();
|
||||
case 'object':
|
||||
return renderObjectEditor();
|
||||
case 'keyValue':
|
||||
default:
|
||||
return renderKeyValueEditor();
|
||||
}
|
||||
};
|
||||
|
||||
const hasJsonError = jsonError && jsonError.trim() !== '';
|
||||
|
||||
return (
|
||||
<div className="space-y-1">
|
||||
{/* Label统一显示在上方 */}
|
||||
{label && (
|
||||
<div className="flex items-center">
|
||||
<Text className="text-sm font-medium text-gray-900">{label}</Text>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{/* 编辑模式切换 */}
|
||||
<div className="flex items-center justify-between p-2 bg-gray-50 rounded-md">
|
||||
<div className="flex items-center gap-2">
|
||||
{editMode === 'visual' && (
|
||||
<Text type="tertiary" size="small" className="bg-blue-100 text-blue-700 px-2 py-0.5 rounded text-xs">
|
||||
{t('可视化模式')}
|
||||
</Text>
|
||||
)}
|
||||
{editMode === 'manual' && (
|
||||
<Text type="tertiary" size="small" className="bg-green-100 text-green-700 px-2 py-0.5 rounded text-xs">
|
||||
{t('手动编辑模式')}
|
||||
</Text>
|
||||
)}
|
||||
</div>
|
||||
<div className="flex items-center gap-2">
|
||||
{template && templateLabel && (
|
||||
<Button
|
||||
size="small"
|
||||
type="tertiary"
|
||||
onClick={fillTemplate}
|
||||
className="!text-semi-color-primary hover:bg-blue-50 text-xs"
|
||||
>
|
||||
{templateLabel}
|
||||
</Button>
|
||||
)}
|
||||
<Space size="tight">
|
||||
<Button
|
||||
size="small"
|
||||
type={editMode === 'visual' ? 'primary' : 'tertiary'}
|
||||
icon={<IconEdit />}
|
||||
onClick={toggleEditMode}
|
||||
disabled={editMode === 'manual' && hasJsonError}
|
||||
className={editMode === 'visual' ? 'shadow-sm' : ''}
|
||||
>
|
||||
{t('可视化')}
|
||||
</Button>
|
||||
<Button
|
||||
size="small"
|
||||
type={editMode === 'manual' ? 'primary' : 'tertiary'}
|
||||
icon={<IconCode />}
|
||||
onClick={toggleEditMode}
|
||||
className={editMode === 'manual' ? 'shadow-sm' : ''}
|
||||
>
|
||||
{t('手动编辑')}
|
||||
</Button>
|
||||
</Space>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* JSON错误提示 */}
|
||||
{hasJsonError && (
|
||||
<Banner
|
||||
type="danger"
|
||||
description={`JSON 格式错误: ${jsonError}`}
|
||||
className="!rounded-md text-sm"
|
||||
/>
|
||||
)}
|
||||
|
||||
{/* 编辑器内容 */}
|
||||
{editMode === 'visual' ? (
|
||||
<div>
|
||||
<Card className="!p-3 !border-gray-200 !shadow-sm !rounded-md bg-white">
|
||||
{renderVisualEditor()}
|
||||
</Card>
|
||||
{/* 可视化模式下的额外文本显示在下方 */}
|
||||
{extraText && (
|
||||
<div className="text-xs text-gray-600 mt-0.5">
|
||||
{extraText}
|
||||
</div>
|
||||
)}
|
||||
{/* 隐藏的Form字段用于验证和数据绑定 */}
|
||||
<Form.Input
|
||||
field={field}
|
||||
value={value}
|
||||
rules={rules}
|
||||
style={{ display: 'none' }}
|
||||
noLabel={true}
|
||||
{...props}
|
||||
/>
|
||||
</div>
|
||||
) : (
|
||||
<Form.TextArea
|
||||
field={field}
|
||||
placeholder={placeholder}
|
||||
value={value}
|
||||
onChange={handleManualChange}
|
||||
showClear={showClear}
|
||||
rows={Math.max(8, value ? value.split('\n').length : 8)}
|
||||
rules={rules}
|
||||
noLabel={true}
|
||||
{...props}
|
||||
/>
|
||||
)}
|
||||
|
||||
{/* 额外文本在手动编辑模式下显示 */}
|
||||
{extraText && editMode === 'manual' && (
|
||||
<div className="text-xs text-gray-600">
|
||||
{extraText}
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
||||
export default JSONEditor;
|
||||
@@ -33,7 +33,7 @@ import {
|
||||
Settings,
|
||||
} from 'lucide-react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { renderGroupOption, modelSelectFilter } from '../../helpers';
|
||||
import { renderGroupOption, selectFilter } from '../../helpers';
|
||||
import ParameterControl from './ParameterControl';
|
||||
import ImageUrlInput from './ImageUrlInput';
|
||||
import ConfigManager from './ConfigManager';
|
||||
@@ -173,7 +173,7 @@ const SettingsPanel = ({
|
||||
name='model'
|
||||
required
|
||||
selection
|
||||
filter={modelSelectFilter}
|
||||
filter={selectFilter}
|
||||
autoClearSearchValue={false}
|
||||
onChange={(value) => onInputChange('model', value)}
|
||||
value={inputs.model}
|
||||
|
||||
@@ -46,7 +46,9 @@ import {
|
||||
Col,
|
||||
Highlight,
|
||||
} from '@douyinfe/semi-ui';
|
||||
import { getChannelModels, copy, getChannelIcon, getModelCategories, modelSelectFilter } from '../../../../helpers';
|
||||
import { getChannelModels, copy, getChannelIcon, getModelCategories, selectFilter } from '../../../../helpers';
|
||||
import ModelSelectModal from './ModelSelectModal';
|
||||
import JSONEditor from '../../../common/JSONEditor';
|
||||
import {
|
||||
IconSave,
|
||||
IconClose,
|
||||
@@ -68,7 +70,9 @@ const STATUS_CODE_MAPPING_EXAMPLE = {
|
||||
};
|
||||
|
||||
const REGION_EXAMPLE = {
|
||||
default: 'us-central1',
|
||||
"default": 'global',
|
||||
"gemini-1.5-pro-002": "europe-west2",
|
||||
"gemini-1.5-flash-002": "europe-west2",
|
||||
'claude-3-5-sonnet-20240620': 'europe-west1',
|
||||
};
|
||||
|
||||
@@ -121,6 +125,12 @@ const EditChannelModal = (props) => {
|
||||
weight: 0,
|
||||
tag: '',
|
||||
multi_key_mode: 'random',
|
||||
// 渠道额外设置的默认值
|
||||
force_format: false,
|
||||
thinking_to_content: false,
|
||||
proxy: '',
|
||||
pass_through_body_enabled: false,
|
||||
system_prompt: '',
|
||||
};
|
||||
const [batch, setBatch] = useState(false);
|
||||
const [multiToSingle, setMultiToSingle] = useState(false);
|
||||
@@ -135,6 +145,8 @@ const EditChannelModal = (props) => {
|
||||
const [customModel, setCustomModel] = useState('');
|
||||
const [modalImageUrl, setModalImageUrl] = useState('');
|
||||
const [isModalOpenurl, setIsModalOpenurl] = useState(false);
|
||||
const [modelModalVisible, setModelModalVisible] = useState(false);
|
||||
const [fetchedModels, setFetchedModels] = useState([]);
|
||||
const formApiRef = useRef(null);
|
||||
const [vertexKeys, setVertexKeys] = useState([]);
|
||||
const [vertexFileList, setVertexFileList] = useState([]);
|
||||
@@ -142,8 +154,70 @@ const EditChannelModal = (props) => {
|
||||
const [isMultiKeyChannel, setIsMultiKeyChannel] = useState(false);
|
||||
const [channelSearchValue, setChannelSearchValue] = useState('');
|
||||
const [useManualInput, setUseManualInput] = useState(false); // 是否使用手动输入模式
|
||||
const [keyMode, setKeyMode] = useState('append'); // 密钥模式:replace(覆盖)或 append(追加)
|
||||
// 渠道额外设置状态
|
||||
const [channelSettings, setChannelSettings] = useState({
|
||||
force_format: false,
|
||||
thinking_to_content: false,
|
||||
proxy: '',
|
||||
pass_through_body_enabled: false,
|
||||
system_prompt: '',
|
||||
});
|
||||
const showApiConfigCard = inputs.type !== 45; // 控制是否显示 API 配置卡片(仅当渠道类型不是 豆包 时显示)
|
||||
const getInitValues = () => ({ ...originInputs });
|
||||
|
||||
// 处理渠道额外设置的更新
|
||||
const handleChannelSettingsChange = (key, value) => {
|
||||
// 更新内部状态
|
||||
setChannelSettings(prev => ({ ...prev, [key]: value }));
|
||||
|
||||
// 同步更新到表单字段
|
||||
if (formApiRef.current) {
|
||||
formApiRef.current.setValue(key, value);
|
||||
}
|
||||
|
||||
// 同步更新inputs状态
|
||||
setInputs(prev => ({ ...prev, [key]: value }));
|
||||
|
||||
// 生成setting JSON并更新
|
||||
const newSettings = { ...channelSettings, [key]: value };
|
||||
const settingsJson = JSON.stringify(newSettings);
|
||||
handleInputChange('setting', settingsJson);
|
||||
};
|
||||
|
||||
// 解析渠道设置JSON为单独的状态
|
||||
const parseChannelSettings = (settingJson) => {
|
||||
try {
|
||||
if (settingJson && settingJson.trim()) {
|
||||
const parsed = JSON.parse(settingJson);
|
||||
setChannelSettings({
|
||||
force_format: parsed.force_format || false,
|
||||
thinking_to_content: parsed.thinking_to_content || false,
|
||||
proxy: parsed.proxy || '',
|
||||
pass_through_body_enabled: parsed.pass_through_body_enabled || false,
|
||||
system_prompt: parsed.system_prompt || '',
|
||||
});
|
||||
} else {
|
||||
setChannelSettings({
|
||||
force_format: false,
|
||||
thinking_to_content: false,
|
||||
proxy: '',
|
||||
pass_through_body_enabled: false,
|
||||
system_prompt: '',
|
||||
});
|
||||
}
|
||||
} catch (error) {
|
||||
console.error('解析渠道设置失败:', error);
|
||||
setChannelSettings({
|
||||
force_format: false,
|
||||
thinking_to_content: false,
|
||||
proxy: '',
|
||||
pass_through_body_enabled: false,
|
||||
system_prompt: '',
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
const handleInputChange = (name, value) => {
|
||||
if (formApiRef.current) {
|
||||
formApiRef.current.setValue(name, value);
|
||||
@@ -256,6 +330,30 @@ const EditChannelModal = (props) => {
|
||||
setBatch(false);
|
||||
setMultiToSingle(false);
|
||||
}
|
||||
// 解析渠道额外设置并合并到data中
|
||||
if (data.setting) {
|
||||
try {
|
||||
const parsedSettings = JSON.parse(data.setting);
|
||||
data.force_format = parsedSettings.force_format || false;
|
||||
data.thinking_to_content = parsedSettings.thinking_to_content || false;
|
||||
data.proxy = parsedSettings.proxy || '';
|
||||
data.pass_through_body_enabled = parsedSettings.pass_through_body_enabled || false;
|
||||
data.system_prompt = parsedSettings.system_prompt || '';
|
||||
} catch (error) {
|
||||
console.error('解析渠道设置失败:', error);
|
||||
data.force_format = false;
|
||||
data.thinking_to_content = false;
|
||||
data.proxy = '';
|
||||
data.pass_through_body_enabled = false;
|
||||
}
|
||||
} else {
|
||||
data.force_format = false;
|
||||
data.thinking_to_content = false;
|
||||
data.proxy = '';
|
||||
data.pass_through_body_enabled = false;
|
||||
data.system_prompt = '';
|
||||
}
|
||||
|
||||
setInputs(data);
|
||||
if (formApiRef.current) {
|
||||
formApiRef.current.setValues(data);
|
||||
@@ -266,6 +364,14 @@ const EditChannelModal = (props) => {
|
||||
setAutoBan(true);
|
||||
}
|
||||
setBasicModels(getChannelModels(data.type));
|
||||
// 同步更新channelSettings状态显示
|
||||
setChannelSettings({
|
||||
force_format: data.force_format,
|
||||
thinking_to_content: data.thinking_to_content,
|
||||
proxy: data.proxy,
|
||||
pass_through_body_enabled: data.pass_through_body_enabled,
|
||||
system_prompt: data.system_prompt,
|
||||
});
|
||||
// console.log(data);
|
||||
} else {
|
||||
showError(message);
|
||||
@@ -279,7 +385,7 @@ const EditChannelModal = (props) => {
|
||||
// return;
|
||||
// }
|
||||
setLoading(true);
|
||||
const models = inputs['models'] || [];
|
||||
const models = [];
|
||||
let err = false;
|
||||
|
||||
if (isEdit) {
|
||||
@@ -320,8 +426,9 @@ const EditChannelModal = (props) => {
|
||||
}
|
||||
|
||||
if (!err) {
|
||||
handleInputChange(name, Array.from(new Set(models)));
|
||||
showSuccess(t('获取模型列表成功'));
|
||||
const uniqueModels = Array.from(new Set(models));
|
||||
setFetchedModels(uniqueModels);
|
||||
setModelModalVisible(true);
|
||||
} else {
|
||||
showError(t('获取模型列表失败'));
|
||||
}
|
||||
@@ -446,6 +553,20 @@ const EditChannelModal = (props) => {
|
||||
setUseManualInput(false);
|
||||
} else {
|
||||
formApiRef.current?.reset();
|
||||
// 重置渠道设置状态
|
||||
setChannelSettings({
|
||||
force_format: false,
|
||||
thinking_to_content: false,
|
||||
proxy: '',
|
||||
pass_through_body_enabled: false,
|
||||
system_prompt: '',
|
||||
});
|
||||
// 重置密钥模式状态
|
||||
setKeyMode('append');
|
||||
// 清空表单中的key_mode字段
|
||||
if (formApiRef.current) {
|
||||
formApiRef.current.setValue('key_mode', undefined);
|
||||
}
|
||||
}
|
||||
}, [props.visible, channelId]);
|
||||
|
||||
@@ -579,6 +700,24 @@ const EditChannelModal = (props) => {
|
||||
if (localInputs.type === 18 && localInputs.other === '') {
|
||||
localInputs.other = 'v2.1';
|
||||
}
|
||||
|
||||
// 生成渠道额外设置JSON
|
||||
const channelExtraSettings = {
|
||||
force_format: localInputs.force_format || false,
|
||||
thinking_to_content: localInputs.thinking_to_content || false,
|
||||
proxy: localInputs.proxy || '',
|
||||
pass_through_body_enabled: localInputs.pass_through_body_enabled || false,
|
||||
system_prompt: localInputs.system_prompt || '',
|
||||
};
|
||||
localInputs.setting = JSON.stringify(channelExtraSettings);
|
||||
|
||||
// 清理不需要发送到后端的字段
|
||||
delete localInputs.force_format;
|
||||
delete localInputs.thinking_to_content;
|
||||
delete localInputs.proxy;
|
||||
delete localInputs.pass_through_body_enabled;
|
||||
delete localInputs.system_prompt;
|
||||
|
||||
let res;
|
||||
localInputs.auto_ban = localInputs.auto_ban ? 1 : 0;
|
||||
localInputs.models = localInputs.models.join(',');
|
||||
@@ -593,6 +732,7 @@ const EditChannelModal = (props) => {
|
||||
res = await API.put(`/api/channel/`, {
|
||||
...localInputs,
|
||||
id: parseInt(channelId),
|
||||
key_mode: isMultiKeyChannel ? keyMode : undefined, // 只在多key模式下传递
|
||||
});
|
||||
} else {
|
||||
res = await API.post(`/api/channel/`, {
|
||||
@@ -655,69 +795,73 @@ const EditChannelModal = (props) => {
|
||||
const batchAllowed = !isEdit || isMultiKeyChannel;
|
||||
const batchExtra = batchAllowed ? (
|
||||
<Space>
|
||||
<Checkbox
|
||||
disabled={isEdit}
|
||||
checked={batch}
|
||||
onChange={(e) => {
|
||||
const checked = e.target.checked;
|
||||
{!isEdit && (
|
||||
<Checkbox
|
||||
disabled={isEdit}
|
||||
checked={batch}
|
||||
onChange={(e) => {
|
||||
const checked = e.target.checked;
|
||||
|
||||
if (!checked && vertexFileList.length > 1) {
|
||||
Modal.confirm({
|
||||
title: t('切换为单密钥模式'),
|
||||
content: t('将仅保留第一个密钥文件,其余文件将被移除,是否继续?'),
|
||||
onOk: () => {
|
||||
const firstFile = vertexFileList[0];
|
||||
const firstKey = vertexKeys[0] ? [vertexKeys[0]] : [];
|
||||
if (!checked && vertexFileList.length > 1) {
|
||||
Modal.confirm({
|
||||
title: t('切换为单密钥模式'),
|
||||
content: t('将仅保留第一个密钥文件,其余文件将被移除,是否继续?'),
|
||||
onOk: () => {
|
||||
const firstFile = vertexFileList[0];
|
||||
const firstKey = vertexKeys[0] ? [vertexKeys[0]] : [];
|
||||
|
||||
setVertexFileList([firstFile]);
|
||||
setVertexKeys(firstKey);
|
||||
setVertexFileList([firstFile]);
|
||||
setVertexKeys(firstKey);
|
||||
|
||||
formApiRef.current?.setValue('vertex_files', [firstFile]);
|
||||
setInputs((prev) => ({ ...prev, vertex_files: [firstFile] }));
|
||||
formApiRef.current?.setValue('vertex_files', [firstFile]);
|
||||
setInputs((prev) => ({ ...prev, vertex_files: [firstFile] }));
|
||||
|
||||
setBatch(false);
|
||||
setMultiToSingle(false);
|
||||
setMultiKeyMode('random');
|
||||
},
|
||||
onCancel: () => {
|
||||
setBatch(true);
|
||||
},
|
||||
centered: true,
|
||||
});
|
||||
return;
|
||||
}
|
||||
|
||||
setBatch(checked);
|
||||
if (!checked) {
|
||||
setMultiToSingle(false);
|
||||
setMultiKeyMode('random');
|
||||
} else {
|
||||
// 批量模式下禁用手动输入,并清空手动输入的内容
|
||||
setUseManualInput(false);
|
||||
if (inputs.type === 41) {
|
||||
// 清空手动输入的密钥内容
|
||||
if (formApiRef.current) {
|
||||
formApiRef.current.setValue('key', '');
|
||||
}
|
||||
handleInputChange('key', '');
|
||||
setBatch(false);
|
||||
setMultiToSingle(false);
|
||||
setMultiKeyMode('random');
|
||||
},
|
||||
onCancel: () => {
|
||||
setBatch(true);
|
||||
},
|
||||
centered: true,
|
||||
});
|
||||
return;
|
||||
}
|
||||
}
|
||||
}}
|
||||
>{t('批量创建')}</Checkbox>
|
||||
{/*{batch && (*/}
|
||||
{/* <Checkbox disabled={isEdit} checked={multiToSingle} onChange={() => {*/}
|
||||
{/* setMultiToSingle(prev => !prev);*/}
|
||||
{/* setInputs(prev => {*/}
|
||||
{/* const newInputs = { ...prev };*/}
|
||||
{/* if (!multiToSingle) {*/}
|
||||
{/* newInputs.multi_key_mode = multiKeyMode;*/}
|
||||
{/* } else {*/}
|
||||
{/* delete newInputs.multi_key_mode;*/}
|
||||
{/* }*/}
|
||||
{/* return newInputs;*/}
|
||||
{/* });*/}
|
||||
{/* }}>{t('密钥聚合模式')}</Checkbox>*/}
|
||||
{/*)}*/}
|
||||
|
||||
setBatch(checked);
|
||||
if (!checked) {
|
||||
setMultiToSingle(false);
|
||||
setMultiKeyMode('random');
|
||||
} else {
|
||||
// 批量模式下禁用手动输入,并清空手动输入的内容
|
||||
setUseManualInput(false);
|
||||
if (inputs.type === 41) {
|
||||
// 清空手动输入的密钥内容
|
||||
if (formApiRef.current) {
|
||||
formApiRef.current.setValue('key', '');
|
||||
}
|
||||
handleInputChange('key', '');
|
||||
}
|
||||
}
|
||||
}}
|
||||
>
|
||||
{t('批量创建')}
|
||||
</Checkbox>
|
||||
)}
|
||||
{batch && (
|
||||
<Checkbox disabled={isEdit} checked={multiToSingle} onChange={() => {
|
||||
setMultiToSingle(prev => !prev);
|
||||
setInputs(prev => {
|
||||
const newInputs = { ...prev };
|
||||
if (!multiToSingle) {
|
||||
newInputs.multi_key_mode = multiKeyMode;
|
||||
} else {
|
||||
delete newInputs.multi_key_mode;
|
||||
}
|
||||
return newInputs;
|
||||
});
|
||||
}}>{t('密钥聚合模式')}</Checkbox>
|
||||
)}
|
||||
</Space>
|
||||
) : null;
|
||||
|
||||
@@ -854,7 +998,7 @@ const EditChannelModal = (props) => {
|
||||
rules={[{ required: true, message: t('请选择渠道类型') }]}
|
||||
optionList={channelOptionList}
|
||||
style={{ width: '100%' }}
|
||||
filter={modelSelectFilter}
|
||||
filter={selectFilter}
|
||||
autoClearSearchValue={false}
|
||||
searchPosition='dropdown'
|
||||
onSearch={(value) => setChannelSearchValue(value)}
|
||||
@@ -900,7 +1044,16 @@ const EditChannelModal = (props) => {
|
||||
autosize
|
||||
autoComplete='new-password'
|
||||
onChange={(value) => handleInputChange('key', value)}
|
||||
extraText={batchExtra}
|
||||
extraText={
|
||||
<div className="flex items-center gap-2">
|
||||
{isEdit && isMultiKeyChannel && keyMode === 'append' && (
|
||||
<Text type="warning" size="small">
|
||||
{t('追加模式:新密钥将添加到现有密钥列表的末尾')}
|
||||
</Text>
|
||||
)}
|
||||
{batchExtra}
|
||||
</div>
|
||||
}
|
||||
showClear
|
||||
/>
|
||||
)
|
||||
@@ -967,6 +1120,11 @@ const EditChannelModal = (props) => {
|
||||
<Text type="tertiary" size="small">
|
||||
{t('请输入完整的 JSON 格式密钥内容')}
|
||||
</Text>
|
||||
{isEdit && isMultiKeyChannel && keyMode === 'append' && (
|
||||
<Text type="warning" size="small">
|
||||
{t('追加模式:新密钥将添加到现有密钥列表的末尾')}
|
||||
</Text>
|
||||
)}
|
||||
{batchExtra}
|
||||
</div>
|
||||
}
|
||||
@@ -1000,13 +1158,44 @@ const EditChannelModal = (props) => {
|
||||
rules={isEdit ? [] : [{ required: true, message: t('请输入密钥') }]}
|
||||
autoComplete='new-password'
|
||||
onChange={(value) => handleInputChange('key', value)}
|
||||
extraText={batchExtra}
|
||||
extraText={
|
||||
<div className="flex items-center gap-2">
|
||||
{isEdit && isMultiKeyChannel && keyMode === 'append' && (
|
||||
<Text type="warning" size="small">
|
||||
{t('追加模式:新密钥将添加到现有密钥列表的末尾')}
|
||||
</Text>
|
||||
)}
|
||||
{batchExtra}
|
||||
</div>
|
||||
}
|
||||
showClear
|
||||
/>
|
||||
)}
|
||||
</>
|
||||
)}
|
||||
|
||||
{isEdit && isMultiKeyChannel && (
|
||||
<Form.Select
|
||||
field='key_mode'
|
||||
label={t('密钥更新模式')}
|
||||
placeholder={t('请选择密钥更新模式')}
|
||||
optionList={[
|
||||
{ label: t('追加到现有密钥'), value: 'append' },
|
||||
{ label: t('覆盖现有密钥'), value: 'replace' },
|
||||
]}
|
||||
style={{ width: '100%' }}
|
||||
value={keyMode}
|
||||
onChange={(value) => setKeyMode(value)}
|
||||
extraText={
|
||||
<Text type="tertiary" size="small">
|
||||
{keyMode === 'replace'
|
||||
? t('覆盖模式:将完全替换现有的所有密钥')
|
||||
: t('追加模式:将新密钥添加到现有密钥列表末尾')
|
||||
}
|
||||
</Text>
|
||||
}
|
||||
/>
|
||||
)}
|
||||
{batch && multiToSingle && (
|
||||
<>
|
||||
<Form.Select
|
||||
@@ -1045,24 +1234,24 @@ const EditChannelModal = (props) => {
|
||||
)}
|
||||
|
||||
{inputs.type === 41 && (
|
||||
<Form.TextArea
|
||||
<JSONEditor
|
||||
field='other'
|
||||
label={t('部署地区')}
|
||||
placeholder={t(
|
||||
'请输入部署地区,例如:us-central1\n支持使用模型映射格式\n{\n "default": "us-central1",\n "claude-3-5-sonnet-20240620": "europe-west1"\n}'
|
||||
)}
|
||||
autosize
|
||||
value={inputs.other || ''}
|
||||
onChange={(value) => handleInputChange('other', value)}
|
||||
rules={[{ required: true, message: t('请填写部署地区') }]}
|
||||
template={REGION_EXAMPLE}
|
||||
templateLabel={t('填入模板')}
|
||||
editorType="region"
|
||||
formApi={formApiRef.current}
|
||||
extraText={
|
||||
<Text
|
||||
className="!text-semi-color-primary cursor-pointer"
|
||||
onClick={() => handleInputChange('other', JSON.stringify(REGION_EXAMPLE, null, 2))}
|
||||
>
|
||||
{t('填入模板')}
|
||||
<Text type="tertiary" size="small">
|
||||
{t('设置默认地区和特定模型的专用地区')}
|
||||
</Text>
|
||||
}
|
||||
showClear
|
||||
/>
|
||||
)}
|
||||
|
||||
@@ -1255,7 +1444,7 @@ const EditChannelModal = (props) => {
|
||||
placeholder={t('请选择该渠道所支持的模型')}
|
||||
rules={[{ required: true, message: t('请选择模型') }]}
|
||||
multiple
|
||||
filter={modelSelectFilter}
|
||||
filter={selectFilter}
|
||||
autoClearSearchValue={false}
|
||||
searchPosition='dropdown'
|
||||
optionList={modelOptions}
|
||||
@@ -1318,24 +1507,24 @@ const EditChannelModal = (props) => {
|
||||
showClear
|
||||
/>
|
||||
|
||||
<Form.TextArea
|
||||
<JSONEditor
|
||||
field='model_mapping'
|
||||
label={t('模型重定向')}
|
||||
placeholder={
|
||||
t('此项可选,用于修改请求体中的模型名称,为一个 JSON 字符串,键为请求中模型名称,值为要替换的模型名称,例如:') +
|
||||
`\n${JSON.stringify(MODEL_MAPPING_EXAMPLE, null, 2)}`
|
||||
}
|
||||
autosize
|
||||
value={inputs.model_mapping || ''}
|
||||
onChange={(value) => handleInputChange('model_mapping', value)}
|
||||
template={MODEL_MAPPING_EXAMPLE}
|
||||
templateLabel={t('填入模板')}
|
||||
editorType="keyValue"
|
||||
formApi={formApiRef.current}
|
||||
extraText={
|
||||
<Text
|
||||
className="!text-semi-color-primary cursor-pointer"
|
||||
onClick={() => handleInputChange('model_mapping', JSON.stringify(MODEL_MAPPING_EXAMPLE, null, 2))}
|
||||
>
|
||||
{t('填入模板')}
|
||||
<Text type="tertiary" size="small">
|
||||
{t('键为请求中的模型名称,值为要替换的模型名称')}
|
||||
</Text>
|
||||
}
|
||||
showClear
|
||||
/>
|
||||
</Card>
|
||||
|
||||
@@ -1400,7 +1589,7 @@ const EditChannelModal = (props) => {
|
||||
label={t('是否自动禁用')}
|
||||
checkedText={t('开')}
|
||||
uncheckedText={t('关')}
|
||||
onChange={(val) => setAutoBan(val)}
|
||||
onChange={(value) => setAutoBan(value)}
|
||||
extraText={t('仅当自动禁用开启时有效,关闭后不会自动禁用该渠道')}
|
||||
initValue={autoBan}
|
||||
/>
|
||||
@@ -1425,7 +1614,7 @@ const EditChannelModal = (props) => {
|
||||
showClear
|
||||
/>
|
||||
|
||||
<Form.TextArea
|
||||
<JSONEditor
|
||||
field='status_code_mapping'
|
||||
label={t('状态码复写')}
|
||||
placeholder={
|
||||
@@ -1433,45 +1622,78 @@ const EditChannelModal = (props) => {
|
||||
'\n' +
|
||||
JSON.stringify(STATUS_CODE_MAPPING_EXAMPLE, null, 2)
|
||||
}
|
||||
autosize
|
||||
value={inputs.status_code_mapping || ''}
|
||||
onChange={(value) => handleInputChange('status_code_mapping', value)}
|
||||
template={STATUS_CODE_MAPPING_EXAMPLE}
|
||||
templateLabel={t('填入模板')}
|
||||
editorType="keyValue"
|
||||
formApi={formApiRef.current}
|
||||
extraText={
|
||||
<Text
|
||||
className="!text-semi-color-primary cursor-pointer"
|
||||
onClick={() => handleInputChange('status_code_mapping', JSON.stringify(STATUS_CODE_MAPPING_EXAMPLE, null, 2))}
|
||||
>
|
||||
{t('填入模板')}
|
||||
<Text type="tertiary" size="small">
|
||||
{t('键为原状态码,值为要复写的状态码,仅影响本地判断')}
|
||||
</Text>
|
||||
}
|
||||
/>
|
||||
</Card>
|
||||
|
||||
{/* Channel Extra Settings Card */}
|
||||
<Card className="!rounded-2xl shadow-sm border-0 mb-6">
|
||||
{/* Header: Channel Extra Settings */}
|
||||
<div className="flex items-center mb-2">
|
||||
<Avatar size="small" color="violet" className="mr-2 shadow-md">
|
||||
<IconBolt size={16} />
|
||||
</Avatar>
|
||||
<div>
|
||||
<Text className="text-lg font-medium">{t('渠道额外设置')}</Text>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{inputs.type === 1 && (
|
||||
<Form.Switch
|
||||
field='force_format'
|
||||
label={t('强制格式化')}
|
||||
checkedText={t('开')}
|
||||
uncheckedText={t('关')}
|
||||
onChange={(value) => handleChannelSettingsChange('force_format', value)}
|
||||
extraText={t('强制将响应格式化为 OpenAI 标准格式(只适用于OpenAI渠道类型)')}
|
||||
/>
|
||||
)}
|
||||
|
||||
<Form.Switch
|
||||
field='thinking_to_content'
|
||||
label={t('思考内容转换')}
|
||||
checkedText={t('开')}
|
||||
uncheckedText={t('关')}
|
||||
onChange={(value) => handleChannelSettingsChange('thinking_to_content', value)}
|
||||
extraText={t('将 reasoning_content 转换为 <think> 标签拼接到内容中')}
|
||||
/>
|
||||
|
||||
<Form.Switch
|
||||
field='pass_through_body_enabled'
|
||||
label={t('透传请求体')}
|
||||
checkedText={t('开')}
|
||||
uncheckedText={t('关')}
|
||||
onChange={(value) => handleChannelSettingsChange('pass_through_body_enabled', value)}
|
||||
extraText={t('启用请求体透传功能')}
|
||||
/>
|
||||
|
||||
<Form.Input
|
||||
field='proxy'
|
||||
label={t('代理地址')}
|
||||
placeholder={t('例如: socks5://user:pass@host:port')}
|
||||
onChange={(value) => handleChannelSettingsChange('proxy', value)}
|
||||
showClear
|
||||
extraText={t('用于配置网络代理,支持 socks5 协议')}
|
||||
/>
|
||||
|
||||
<Form.TextArea
|
||||
field='setting'
|
||||
label={t('渠道额外设置')}
|
||||
placeholder={
|
||||
t('此项可选,用于配置渠道特定设置,为一个 JSON 字符串,例如:') +
|
||||
'\n{\n "force_format": true\n}'
|
||||
}
|
||||
field='system_prompt'
|
||||
label={t('系统提示词')}
|
||||
placeholder={t('输入系统提示词,用户的系统提示词将优先于此设置')}
|
||||
onChange={(value) => handleChannelSettingsChange('system_prompt', value)}
|
||||
autosize
|
||||
onChange={(value) => handleInputChange('setting', value)}
|
||||
extraText={(
|
||||
<Space wrap>
|
||||
<Text
|
||||
className="!text-semi-color-primary cursor-pointer"
|
||||
onClick={() => handleInputChange('setting', JSON.stringify({ force_format: true }, null, 2))}
|
||||
>
|
||||
{t('填入模板')}
|
||||
</Text>
|
||||
<Text
|
||||
className="!text-semi-color-primary cursor-pointer"
|
||||
onClick={() => window.open('https://github.com/QuantumNous/new-api/blob/main/docs/channel/other_setting.md')}
|
||||
>
|
||||
{t('设置说明')}
|
||||
</Text>
|
||||
</Space>
|
||||
)}
|
||||
showClear
|
||||
extraText={t('用户优先:如果用户在请求中指定了系统提示词,将优先使用用户的设置')}
|
||||
/>
|
||||
</Card>
|
||||
</div>
|
||||
@@ -1484,6 +1706,17 @@ const EditChannelModal = (props) => {
|
||||
onVisibleChange={(visible) => setIsModalOpenurl(visible)}
|
||||
/>
|
||||
</SideSheet>
|
||||
<ModelSelectModal
|
||||
visible={modelModalVisible}
|
||||
models={fetchedModels}
|
||||
selected={inputs.models}
|
||||
onConfirm={(selectedModels) => {
|
||||
handleInputChange('models', selectedModels);
|
||||
showSuccess(t('模型列表已更新'));
|
||||
setModelModalVisible(false);
|
||||
}}
|
||||
onCancel={() => setModelModalVisible(false)}
|
||||
/>
|
||||
</>
|
||||
);
|
||||
};
|
||||
|
||||
@@ -25,7 +25,7 @@ import {
|
||||
showSuccess,
|
||||
showWarning,
|
||||
verifyJSON,
|
||||
modelSelectFilter,
|
||||
selectFilter,
|
||||
} from '../../../../helpers';
|
||||
import {
|
||||
SideSheet,
|
||||
@@ -395,7 +395,7 @@ const EditTagModal = (props) => {
|
||||
label={t('模型')}
|
||||
placeholder={t('请选择该渠道所支持的模型,留空则不更改')}
|
||||
multiple
|
||||
filter={modelSelectFilter}
|
||||
filter={selectFilter}
|
||||
autoClearSearchValue={false}
|
||||
searchPosition='dropdown'
|
||||
optionList={modelOptions}
|
||||
|
||||
272
web/src/components/table/channels/modals/ModelSelectModal.jsx
Normal file
272
web/src/components/table/channels/modals/ModelSelectModal.jsx
Normal file
@@ -0,0 +1,272 @@
|
||||
import React, { useState, useEffect } from 'react';
|
||||
import { useIsMobile } from '../../../../hooks/common/useIsMobile.js';
|
||||
import { Modal, Checkbox, Spin, Input, Typography, Empty, Tabs, Collapse } from '@douyinfe/semi-ui';
|
||||
import {
|
||||
IllustrationNoResult,
|
||||
IllustrationNoResultDark
|
||||
} from '@douyinfe/semi-illustrations';
|
||||
import { IconSearch } from '@douyinfe/semi-icons';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { getModelCategories } from '../../../../helpers/render';
|
||||
|
||||
const ModelSelectModal = ({ visible, models = [], selected = [], onConfirm, onCancel }) => {
|
||||
const { t } = useTranslation();
|
||||
const [checkedList, setCheckedList] = useState(selected);
|
||||
const [keyword, setKeyword] = useState('');
|
||||
const [activeTab, setActiveTab] = useState('new');
|
||||
|
||||
const isMobile = useIsMobile();
|
||||
|
||||
const filteredModels = models.filter((m) => m.toLowerCase().includes(keyword.toLowerCase()));
|
||||
|
||||
// 分类模型:新获取的模型和已有模型
|
||||
const newModels = filteredModels.filter(model => !selected.includes(model));
|
||||
const existingModels = filteredModels.filter(model => selected.includes(model));
|
||||
|
||||
// 同步外部选中值
|
||||
useEffect(() => {
|
||||
if (visible) {
|
||||
setCheckedList(selected);
|
||||
}
|
||||
}, [visible, selected]);
|
||||
|
||||
// 当模型列表变化时,设置默认tab
|
||||
useEffect(() => {
|
||||
if (visible) {
|
||||
// 默认显示新获取模型tab,如果没有新模型则显示已有模型
|
||||
const hasNewModels = newModels.length > 0;
|
||||
setActiveTab(hasNewModels ? 'new' : 'existing');
|
||||
}
|
||||
}, [visible, newModels.length, selected]);
|
||||
|
||||
const handleOk = () => {
|
||||
onConfirm && onConfirm(checkedList);
|
||||
};
|
||||
|
||||
// 按厂商分类模型
|
||||
const categorizeModels = (models) => {
|
||||
const categories = getModelCategories(t);
|
||||
const categorizedModels = {};
|
||||
const uncategorizedModels = [];
|
||||
|
||||
models.forEach(model => {
|
||||
let foundCategory = false;
|
||||
for (const [key, category] of Object.entries(categories)) {
|
||||
if (key !== 'all' && category.filter({ model_name: model })) {
|
||||
if (!categorizedModels[key]) {
|
||||
categorizedModels[key] = {
|
||||
label: category.label,
|
||||
icon: category.icon,
|
||||
models: []
|
||||
};
|
||||
}
|
||||
categorizedModels[key].models.push(model);
|
||||
foundCategory = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (!foundCategory) {
|
||||
uncategorizedModels.push(model);
|
||||
}
|
||||
});
|
||||
|
||||
// 如果有未分类模型,添加到"其他"分类
|
||||
if (uncategorizedModels.length > 0) {
|
||||
categorizedModels['other'] = {
|
||||
label: t('其他'),
|
||||
icon: null,
|
||||
models: uncategorizedModels
|
||||
};
|
||||
}
|
||||
|
||||
return categorizedModels;
|
||||
};
|
||||
|
||||
const newModelsByCategory = categorizeModels(newModels);
|
||||
const existingModelsByCategory = categorizeModels(existingModels);
|
||||
|
||||
// Tab列表配置
|
||||
const tabList = [
|
||||
...(newModels.length > 0 ? [{
|
||||
tab: `${t('新获取的模型')} (${newModels.length})`,
|
||||
itemKey: 'new'
|
||||
}] : []),
|
||||
...(existingModels.length > 0 ? [{
|
||||
tab: `${t('已有的模型')} (${existingModels.length})`,
|
||||
itemKey: 'existing'
|
||||
}] : [])
|
||||
];
|
||||
|
||||
// 处理分类全选/取消全选
|
||||
const handleCategorySelectAll = (categoryModels, isChecked) => {
|
||||
let newCheckedList = [...checkedList];
|
||||
|
||||
if (isChecked) {
|
||||
// 全选:添加该分类下所有未选中的模型
|
||||
categoryModels.forEach(model => {
|
||||
if (!newCheckedList.includes(model)) {
|
||||
newCheckedList.push(model);
|
||||
}
|
||||
});
|
||||
} else {
|
||||
// 取消全选:移除该分类下所有已选中的模型
|
||||
newCheckedList = newCheckedList.filter(model => !categoryModels.includes(model));
|
||||
}
|
||||
|
||||
setCheckedList(newCheckedList);
|
||||
};
|
||||
|
||||
// 检查分类是否全选
|
||||
const isCategoryAllSelected = (categoryModels) => {
|
||||
return categoryModels.length > 0 && categoryModels.every(model => checkedList.includes(model));
|
||||
};
|
||||
|
||||
// 检查分类是否部分选中
|
||||
const isCategoryIndeterminate = (categoryModels) => {
|
||||
const selectedCount = categoryModels.filter(model => checkedList.includes(model)).length;
|
||||
return selectedCount > 0 && selectedCount < categoryModels.length;
|
||||
};
|
||||
|
||||
const renderModelsByCategory = (modelsByCategory, categoryKeyPrefix) => {
|
||||
const categoryEntries = Object.entries(modelsByCategory);
|
||||
if (categoryEntries.length === 0) return null;
|
||||
|
||||
// 生成所有面板的key,确保都展开
|
||||
const allActiveKeys = categoryEntries.map((_, index) => `${categoryKeyPrefix}_${index}`);
|
||||
|
||||
return (
|
||||
<Collapse activeKey={allActiveKeys}>
|
||||
{categoryEntries.map(([key, categoryData], index) => (
|
||||
<Collapse.Panel
|
||||
key={`${categoryKeyPrefix}_${index}`}
|
||||
itemKey={`${categoryKeyPrefix}_${index}`}
|
||||
header={`${categoryData.label} (${categoryData.models.length})`}
|
||||
extra={
|
||||
<Checkbox
|
||||
checked={isCategoryAllSelected(categoryData.models)}
|
||||
indeterminate={isCategoryIndeterminate(categoryData.models)}
|
||||
onChange={(e) => {
|
||||
e.stopPropagation(); // 防止触发面板折叠
|
||||
handleCategorySelectAll(categoryData.models, e.target.checked);
|
||||
}}
|
||||
onClick={(e) => e.stopPropagation()} // 防止点击checkbox时折叠面板
|
||||
/>
|
||||
}
|
||||
>
|
||||
<div className="flex items-center gap-2 mb-3">
|
||||
{categoryData.icon}
|
||||
<Typography.Text type="secondary" size="small">
|
||||
{t('已选择 {{selected}} / {{total}}', {
|
||||
selected: categoryData.models.filter(model => checkedList.includes(model)).length,
|
||||
total: categoryData.models.length
|
||||
})}
|
||||
</Typography.Text>
|
||||
</div>
|
||||
<div className="grid grid-cols-2 gap-x-4">
|
||||
{categoryData.models.map((model) => (
|
||||
<Checkbox key={model} value={model} className="my-1">
|
||||
{model}
|
||||
</Checkbox>
|
||||
))}
|
||||
</div>
|
||||
</Collapse.Panel>
|
||||
))}
|
||||
</Collapse>
|
||||
);
|
||||
};
|
||||
|
||||
return (
|
||||
<Modal
|
||||
header={
|
||||
<div className="flex flex-col sm:flex-row sm:items-center sm:justify-between gap-2 sm:gap-4 py-4">
|
||||
<Typography.Title heading={5} className="m-0">
|
||||
{t('选择模型')}
|
||||
</Typography.Title>
|
||||
<div className="flex-shrink-0">
|
||||
<Tabs
|
||||
type="slash"
|
||||
size="small"
|
||||
tabList={tabList}
|
||||
activeKey={activeTab}
|
||||
onChange={(key) => setActiveTab(key)}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
}
|
||||
visible={visible}
|
||||
onOk={handleOk}
|
||||
onCancel={onCancel}
|
||||
okText={t('确定')}
|
||||
cancelText={t('取消')}
|
||||
size={isMobile ? 'full-width' : 'large'}
|
||||
closeOnEsc
|
||||
maskClosable
|
||||
centered
|
||||
>
|
||||
<Input
|
||||
prefix={<IconSearch size={14} />}
|
||||
placeholder={t('搜索模型')}
|
||||
value={keyword}
|
||||
onChange={(v) => setKeyword(v)}
|
||||
showClear
|
||||
/>
|
||||
|
||||
<Spin spinning={!models || models.length === 0}>
|
||||
<div style={{ maxHeight: 400, overflowY: 'auto', paddingRight: 8 }}>
|
||||
{filteredModels.length === 0 ? (
|
||||
<Empty
|
||||
image={<IllustrationNoResult style={{ width: 150, height: 150 }} />}
|
||||
darkModeImage={<IllustrationNoResultDark style={{ width: 150, height: 150 }} />}
|
||||
description={t('暂无匹配模型')}
|
||||
style={{ padding: 30 }}
|
||||
/>
|
||||
) : (
|
||||
<Checkbox.Group value={checkedList} onChange={(vals) => setCheckedList(vals)}>
|
||||
{activeTab === 'new' && newModels.length > 0 && (
|
||||
<div>
|
||||
{renderModelsByCategory(newModelsByCategory, 'new')}
|
||||
</div>
|
||||
)}
|
||||
{activeTab === 'existing' && existingModels.length > 0 && (
|
||||
<div>
|
||||
{renderModelsByCategory(existingModelsByCategory, 'existing')}
|
||||
</div>
|
||||
)}
|
||||
</Checkbox.Group>
|
||||
)}
|
||||
</div>
|
||||
</Spin>
|
||||
|
||||
<Typography.Text type="secondary" size="small" className="block text-right mt-4">
|
||||
<div className="flex items-center justify-end gap-2">
|
||||
{(() => {
|
||||
const currentModels = activeTab === 'new' ? newModels : existingModels;
|
||||
const currentSelected = currentModels.filter(model => checkedList.includes(model)).length;
|
||||
const isAllSelected = currentModels.length > 0 && currentSelected === currentModels.length;
|
||||
const isIndeterminate = currentSelected > 0 && currentSelected < currentModels.length;
|
||||
|
||||
return (
|
||||
<>
|
||||
<span>
|
||||
{t('已选择 {{selected}} / {{total}}', {
|
||||
selected: currentSelected,
|
||||
total: currentModels.length
|
||||
})}
|
||||
</span>
|
||||
<Checkbox
|
||||
checked={isAllSelected}
|
||||
indeterminate={isIndeterminate}
|
||||
onChange={(e) => {
|
||||
handleCategorySelectAll(currentModels, e.target.checked);
|
||||
}}
|
||||
/>
|
||||
</>
|
||||
);
|
||||
})()}
|
||||
</div>
|
||||
</Typography.Text>
|
||||
</Modal>
|
||||
);
|
||||
};
|
||||
|
||||
export default ModelSelectModal;
|
||||
@@ -26,7 +26,7 @@ import {
|
||||
renderGroupOption,
|
||||
renderQuotaWithPrompt,
|
||||
getModelCategories,
|
||||
modelSelectFilter,
|
||||
selectFilter,
|
||||
} from '../../../../helpers';
|
||||
import { useIsMobile } from '../../../../hooks/common/useIsMobile.js';
|
||||
import {
|
||||
@@ -514,7 +514,7 @@ const EditTokenModal = (props) => {
|
||||
multiple
|
||||
optionList={models}
|
||||
extraText={t('非必要,不建议启用模型限制')}
|
||||
filter={modelSelectFilter}
|
||||
filter={selectFilter}
|
||||
autoClearSearchValue={false}
|
||||
searchPosition='dropdown'
|
||||
showClear
|
||||
|
||||
@@ -154,6 +154,11 @@ export const CHANNEL_OPTIONS = [
|
||||
color: 'blue',
|
||||
label: '即梦',
|
||||
},
|
||||
{
|
||||
value: 52,
|
||||
color: 'purple',
|
||||
label: 'Vidu',
|
||||
},
|
||||
];
|
||||
|
||||
export const MODEL_TABLE_PAGE_SIZE = 10;
|
||||
|
||||
@@ -883,12 +883,22 @@ export function renderQuotaWithAmount(amount) {
|
||||
}
|
||||
|
||||
export function renderQuota(quota, digits = 2) {
|
||||
|
||||
let quotaPerUnit = localStorage.getItem('quota_per_unit');
|
||||
let displayInCurrency = localStorage.getItem('display_in_currency');
|
||||
quotaPerUnit = parseFloat(quotaPerUnit);
|
||||
displayInCurrency = displayInCurrency === 'true';
|
||||
if (displayInCurrency) {
|
||||
return '$' + (quota / quotaPerUnit).toFixed(digits);
|
||||
const result = quota / quotaPerUnit;
|
||||
const fixedResult = result.toFixed(digits);
|
||||
|
||||
// 如果 toFixed 后结果为 0 但原始值不为 0,显示最小值
|
||||
if (parseFloat(fixedResult) === 0 && quota > 0 && result > 0) {
|
||||
const minValue = Math.pow(10, -digits);
|
||||
return '$' + minValue.toFixed(digits);
|
||||
}
|
||||
|
||||
return '$' + fixedResult;
|
||||
}
|
||||
return renderNumber(quota);
|
||||
}
|
||||
|
||||
@@ -560,12 +560,16 @@ export function setTableCompactMode(compact, tableKey = 'global') {
|
||||
|
||||
// -------------------------------
|
||||
// Select 组件统一过滤逻辑
|
||||
// 解决 label 为 ReactNode(带图标等)时无法用内置 filter 搜索的问题。
|
||||
// 使用方式: <Select filter={modelSelectFilter} ... />
|
||||
export const modelSelectFilter = (input, option) => {
|
||||
// 使用方式: <Select filter={selectFilter} ... />
|
||||
// 统一的 Select 搜索过滤逻辑 -- 支持同时匹配 option.value 与 option.label
|
||||
export const selectFilter = (input, option) => {
|
||||
if (!input) return true;
|
||||
const val = (option?.value || '').toString().toLowerCase();
|
||||
return val.includes(input.trim().toLowerCase());
|
||||
|
||||
const keyword = input.trim().toLowerCase();
|
||||
const valueText = (option?.value ?? '').toString().toLowerCase();
|
||||
const labelText = (option?.label ?? '').toString().toLowerCase();
|
||||
|
||||
return valueText.includes(keyword) || labelText.includes(keyword);
|
||||
};
|
||||
|
||||
// -------------------------------
|
||||
|
||||
@@ -60,6 +60,8 @@ export const useMjLogsData = () => {
|
||||
|
||||
// User and admin
|
||||
const isAdminUser = isAdmin();
|
||||
// Role-specific storage key to prevent different roles from overwriting each other
|
||||
const STORAGE_KEY = isAdminUser ? 'mj-logs-table-columns-admin' : 'mj-logs-table-columns-user';
|
||||
|
||||
// Modal states
|
||||
const [isModalOpen, setIsModalOpen] = useState(false);
|
||||
@@ -88,12 +90,18 @@ export const useMjLogsData = () => {
|
||||
|
||||
// Load saved column preferences from localStorage
|
||||
useEffect(() => {
|
||||
const savedColumns = localStorage.getItem('mj-logs-table-columns');
|
||||
const savedColumns = localStorage.getItem(STORAGE_KEY);
|
||||
if (savedColumns) {
|
||||
try {
|
||||
const parsed = JSON.parse(savedColumns);
|
||||
const defaults = getDefaultColumnVisibility();
|
||||
const merged = { ...defaults, ...parsed };
|
||||
|
||||
// For non-admin users, force-hide admin-only columns (does not touch admin settings)
|
||||
if (!isAdminUser) {
|
||||
merged[COLUMN_KEYS.CHANNEL] = false;
|
||||
merged[COLUMN_KEYS.SUBMIT_RESULT] = false;
|
||||
}
|
||||
setVisibleColumns(merged);
|
||||
} catch (e) {
|
||||
console.error('Failed to parse saved column preferences', e);
|
||||
@@ -134,7 +142,7 @@ export const useMjLogsData = () => {
|
||||
const initDefaultColumns = () => {
|
||||
const defaults = getDefaultColumnVisibility();
|
||||
setVisibleColumns(defaults);
|
||||
localStorage.setItem('mj-logs-table-columns', JSON.stringify(defaults));
|
||||
localStorage.setItem(STORAGE_KEY, JSON.stringify(defaults));
|
||||
};
|
||||
|
||||
// Handle column visibility change
|
||||
@@ -162,10 +170,10 @@ export const useMjLogsData = () => {
|
||||
setVisibleColumns(updatedColumns);
|
||||
};
|
||||
|
||||
// Update table when column visibility changes
|
||||
// Persist column settings to the role-specific STORAGE_KEY
|
||||
useEffect(() => {
|
||||
if (Object.keys(visibleColumns).length > 0) {
|
||||
localStorage.setItem('mj-logs-table-columns', JSON.stringify(visibleColumns));
|
||||
localStorage.setItem(STORAGE_KEY, JSON.stringify(visibleColumns));
|
||||
}
|
||||
}, [visibleColumns]);
|
||||
|
||||
|
||||
@@ -58,6 +58,8 @@ export const useTaskLogsData = () => {
|
||||
|
||||
// User and admin
|
||||
const isAdminUser = isAdmin();
|
||||
// Role-specific storage key to prevent different roles from overwriting each other
|
||||
const STORAGE_KEY = isAdminUser ? 'task-logs-table-columns-admin' : 'task-logs-table-columns-user';
|
||||
|
||||
// Modal state
|
||||
const [isModalOpen, setIsModalOpen] = useState(false);
|
||||
@@ -86,12 +88,17 @@ export const useTaskLogsData = () => {
|
||||
|
||||
// Load saved column preferences from localStorage
|
||||
useEffect(() => {
|
||||
const savedColumns = localStorage.getItem('task-logs-table-columns');
|
||||
const savedColumns = localStorage.getItem(STORAGE_KEY);
|
||||
if (savedColumns) {
|
||||
try {
|
||||
const parsed = JSON.parse(savedColumns);
|
||||
const defaults = getDefaultColumnVisibility();
|
||||
const merged = { ...defaults, ...parsed };
|
||||
|
||||
// For non-admin users, force-hide admin-only columns (does not touch admin settings)
|
||||
if (!isAdminUser) {
|
||||
merged[COLUMN_KEYS.CHANNEL] = false;
|
||||
}
|
||||
setVisibleColumns(merged);
|
||||
} catch (e) {
|
||||
console.error('Failed to parse saved column preferences', e);
|
||||
@@ -123,7 +130,7 @@ export const useTaskLogsData = () => {
|
||||
const initDefaultColumns = () => {
|
||||
const defaults = getDefaultColumnVisibility();
|
||||
setVisibleColumns(defaults);
|
||||
localStorage.setItem('task-logs-table-columns', JSON.stringify(defaults));
|
||||
localStorage.setItem(STORAGE_KEY, JSON.stringify(defaults));
|
||||
};
|
||||
|
||||
// Handle column visibility change
|
||||
@@ -148,10 +155,10 @@ export const useTaskLogsData = () => {
|
||||
setVisibleColumns(updatedColumns);
|
||||
};
|
||||
|
||||
// Update table when column visibility changes
|
||||
// Persist column settings to the role-specific STORAGE_KEY
|
||||
useEffect(() => {
|
||||
if (Object.keys(visibleColumns).length > 0) {
|
||||
localStorage.setItem('task-logs-table-columns', JSON.stringify(visibleColumns));
|
||||
localStorage.setItem(STORAGE_KEY, JSON.stringify(visibleColumns));
|
||||
}
|
||||
}, [visibleColumns]);
|
||||
|
||||
|
||||
@@ -74,6 +74,8 @@ export const useLogsData = () => {
|
||||
|
||||
// User and admin
|
||||
const isAdminUser = isAdmin();
|
||||
// Role-specific storage key to prevent different roles from overwriting each other
|
||||
const STORAGE_KEY = isAdminUser ? 'logs-table-columns-admin' : 'logs-table-columns-user';
|
||||
|
||||
// Statistics state
|
||||
const [stat, setStat] = useState({
|
||||
@@ -110,12 +112,19 @@ export const useLogsData = () => {
|
||||
|
||||
// Load saved column preferences from localStorage
|
||||
useEffect(() => {
|
||||
const savedColumns = localStorage.getItem('logs-table-columns');
|
||||
const savedColumns = localStorage.getItem(STORAGE_KEY);
|
||||
if (savedColumns) {
|
||||
try {
|
||||
const parsed = JSON.parse(savedColumns);
|
||||
const defaults = getDefaultColumnVisibility();
|
||||
const merged = { ...defaults, ...parsed };
|
||||
|
||||
// For non-admin users, force-hide admin-only columns (does not touch admin settings)
|
||||
if (!isAdminUser) {
|
||||
merged[COLUMN_KEYS.CHANNEL] = false;
|
||||
merged[COLUMN_KEYS.USERNAME] = false;
|
||||
merged[COLUMN_KEYS.RETRY] = false;
|
||||
}
|
||||
setVisibleColumns(merged);
|
||||
} catch (e) {
|
||||
console.error('Failed to parse saved column preferences', e);
|
||||
@@ -150,7 +159,7 @@ export const useLogsData = () => {
|
||||
const initDefaultColumns = () => {
|
||||
const defaults = getDefaultColumnVisibility();
|
||||
setVisibleColumns(defaults);
|
||||
localStorage.setItem('logs-table-columns', JSON.stringify(defaults));
|
||||
localStorage.setItem(STORAGE_KEY, JSON.stringify(defaults));
|
||||
};
|
||||
|
||||
// Handle column visibility change
|
||||
@@ -180,13 +189,10 @@ export const useLogsData = () => {
|
||||
setVisibleColumns(updatedColumns);
|
||||
};
|
||||
|
||||
// Update table when column visibility changes
|
||||
// Persist column settings to the role-specific STORAGE_KEY
|
||||
useEffect(() => {
|
||||
if (Object.keys(visibleColumns).length > 0) {
|
||||
localStorage.setItem(
|
||||
'logs-table-columns',
|
||||
JSON.stringify(visibleColumns),
|
||||
);
|
||||
localStorage.setItem(STORAGE_KEY, JSON.stringify(visibleColumns));
|
||||
}
|
||||
}, [visibleColumns]);
|
||||
|
||||
|
||||
@@ -1330,6 +1330,18 @@
|
||||
"API地址": "Base URL",
|
||||
"对于官方渠道,new-api已经内置地址,除非是第三方代理站点或者Azure的特殊接入地址,否则不需要填写": "For official channels, the new-api has a built-in address. Unless it is a third-party proxy site or a special Azure access address, there is no need to fill it in",
|
||||
"渠道额外设置": "Channel extra settings",
|
||||
"强制格式化": "Force format",
|
||||
"强制将响应格式化为 OpenAI 标准格式(只适用于OpenAI渠道类型)": "Force format responses to OpenAI standard format (Only for OpenAI channel types)",
|
||||
"思考内容转换": "Thinking content conversion",
|
||||
"将 reasoning_content 转换为 <think> 标签拼接到内容中": "Convert reasoning_content to <think> tags and append to content",
|
||||
"透传请求体": "Pass through body",
|
||||
"启用请求体透传功能": "Enable request body pass-through functionality",
|
||||
"代理地址": "Proxy address",
|
||||
"例如: socks5://user:pass@host:port": "e.g.: socks5://user:pass@host:port",
|
||||
"用于配置网络代理,支持 socks5 协议": "Used to configure network proxy, supports socks5 protocol",
|
||||
"系统提示词": "System Prompt",
|
||||
"输入系统提示词,用户的系统提示词将优先于此设置": "Enter system prompt, user's system prompt will take priority over this setting",
|
||||
"用户优先:如果用户在请求中指定了系统提示词,将优先使用用户的设置": "User priority: If the user specifies a system prompt in the request, the user's setting will be used first",
|
||||
"参数覆盖": "Parameters override",
|
||||
"模型请求速率限制": "Model request rate limit",
|
||||
"启用用户模型请求速率限制(可能会影响高并发性能)": "Enable user model request rate limit (may affect high concurrency performance)",
|
||||
@@ -1787,5 +1799,10 @@
|
||||
"显示第": "Showing",
|
||||
"条 - 第": "to",
|
||||
"条,共": "of",
|
||||
"条": "items"
|
||||
"条": "items",
|
||||
"选择模型": "Select model",
|
||||
"已选择 {{selected}} / {{total}}": "Selected {{selected}} / {{total}}",
|
||||
"新获取的模型": "New models",
|
||||
"已有的模型": "Existing models",
|
||||
"搜索模型": "Search models"
|
||||
}
|
||||
@@ -147,6 +147,7 @@ export default function RequestRateLimit(props) {
|
||||
label={t('用户每周期最多请求次数')}
|
||||
step={1}
|
||||
min={0}
|
||||
max={100000000}
|
||||
suffix={t('次')}
|
||||
extraText={t('包括失败请求的次数,0代表不限制')}
|
||||
field={'ModelRequestRateLimitCount'}
|
||||
@@ -163,6 +164,7 @@ export default function RequestRateLimit(props) {
|
||||
label={t('用户每周期最多请求完成次数')}
|
||||
step={1}
|
||||
min={1}
|
||||
max={100000000}
|
||||
suffix={t('次')}
|
||||
extraText={t('只包括请求成功的次数')}
|
||||
field={'ModelRequestRateLimitSuccessCount'}
|
||||
@@ -199,6 +201,7 @@ export default function RequestRateLimit(props) {
|
||||
<li>{t('使用 JSON 对象格式,格式为:{"组名": [最多请求次数, 最多请求完成次数]}')}</li>
|
||||
<li>{t('示例:{"default": [200, 100], "vip": [0, 1000]}。')}</li>
|
||||
<li>{t('[最多请求次数]必须大于等于0,[最多请求完成次数]必须大于等于1。')}</li>
|
||||
<li>{t('[最多请求次数]和[最多请求完成次数]的最大值为2147483647。')}</li>
|
||||
<li>{t('分组速率配置优先级高于全局速率限制。')}</li>
|
||||
<li>{t('限制周期统一使用上方配置的“限制周期”值。')}</li>
|
||||
</ul>
|
||||
|
||||
Reference in New Issue
Block a user